How to reduce the size of a TFLite model, or download and set it up programmatically?


Problem description

In my app I am trying to implement face recognition using a FaceNet model converted to TFLite, which weighs in at roughly 93 MB. This model ends up inflating my APK, so I am looking for alternative ways to deal with it.

The first thing I can think of is to compress the model in some way and then decompress it once the app is installed.
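As a minimal sketch of that compress-then-decompress idea, standard gzip from `java.util.zip` can round-trip the model bytes. Note two caveats: APK contents are already zip-compressed, and trained float weights are close to random, so the gain is typically modest. Class and method names here are illustrative.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

public class ModelGzip {
    // Compress a byte array with gzip (e.g. the raw .tflite contents).
    public static byte[] compress(byte[] data) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (GZIPOutputStream gz = new GZIPOutputStream(bos)) {
            gz.write(data);
        }
        return bos.toByteArray();
    }

    // Decompress back to the original bytes (e.g. on first app launch).
    public static byte[] decompress(byte[] gzipped) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (GZIPInputStream gz = new GZIPInputStream(new ByteArrayInputStream(gzipped))) {
            byte[] buf = new byte[8192];
            int n;
            while ((n = gz.read(buf)) != -1) {
                bos.write(buf, 0, n);
            }
        }
        return bos.toByteArray();
    }
}
```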

Another option is to upload the model to a server, download it at runtime, and then load it into my application. However, I don't know how to implement this:

By default, FaceNet allows initialization from the assets folder:

 var facenet = FaceNet(getAssets());

But if I download the model instead, how can I load it into my application?
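One way to do this (a sketch, under the assumption that the model file has already been saved somewhere on storage): TFLite's `Interpreter` accepts a `MappedByteBuffer` regardless of where the file lives, so a variant of `loadModelFile` that takes a `java.io.File` instead of an `AssetManager` is enough. The helper below is plain JVM code; the `FaceNet`/`Interpreter` wiring stays as in the class shown later.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

public class ModelLoader {
    // Memory-map a model file from arbitrary storage (e.g. after download),
    // mirroring the assets-based loadModelFile but without AssetFileDescriptor.
    public static MappedByteBuffer loadModelFile(File modelFile) throws IOException {
        try (FileInputStream inputStream = new FileInputStream(modelFile);
             FileChannel fileChannel = inputStream.getChannel()) {
            // The mapping stays valid after the channel is closed.
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
        }
    }
}
```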

Here is my FaceNet initialization code:

  public FaceNet(AssetManager assetManager) throws IOException {
        tfliteModel = loadModelFile(assetManager);
        tflite = new Interpreter(tfliteModel, tfliteOptions);
        imgData = ByteBuffer.allocateDirect(
                BATCH_SIZE
                        * IMAGE_HEIGHT
                        * IMAGE_WIDTH
                        * NUM_CHANNELS
                        * NUM_BYTES_PER_CHANNEL);
        imgData.order(ByteOrder.nativeOrder());
    }   


private MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
            AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH);
            FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
            FileChannel fileChannel = inputStream.getChannel();
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }

My FaceNet class:

public class FaceNet {
    private static final String MODEL_PATH = "facenet.tflite";

    private static final float IMAGE_MEAN = 127.5f;
    private static final float IMAGE_STD = 127.5f;

    private static final int BATCH_SIZE = 1;
    private static final int IMAGE_HEIGHT = 160;
    private static final int IMAGE_WIDTH = 160;
    private static final int NUM_CHANNELS = 3;
    private static final int NUM_BYTES_PER_CHANNEL = 4;
    private static final int EMBEDDING_SIZE = 512;

    private final int[] intValues = new int[IMAGE_HEIGHT * IMAGE_WIDTH];
    private ByteBuffer imgData;

    private MappedByteBuffer tfliteModel;
    private Interpreter tflite;
    private final Interpreter.Options tfliteOptions = new Interpreter.Options();

    public FaceNet(AssetManager assetManager) throws IOException {
        tfliteModel = loadModelFile(assetManager);
        tflite = new Interpreter(tfliteModel, tfliteOptions);
        imgData = ByteBuffer.allocateDirect(
                BATCH_SIZE
                        * IMAGE_HEIGHT
                        * IMAGE_WIDTH
                        * NUM_CHANNELS
                        * NUM_BYTES_PER_CHANNEL);
        imgData.order(ByteOrder.nativeOrder());
    }

    private MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
        AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    private void convertBitmapToByteBuffer(Bitmap bitmap) {
        if (imgData == null) {
            return;
        }
        imgData.rewind();
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        // Convert the image to floating point.
        int pixel = 0;
        for (int i = 0; i < IMAGE_HEIGHT; ++i) {
            for (int j = 0; j < IMAGE_WIDTH; ++j) {
                final int val = intValues[pixel++];
                addPixelValue(val);
            }
        }
    }

    private void addPixelValue(int pixelValue){
        //imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
        //imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
        //imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
        imgData.putFloat(((pixelValue >> 16) & 0xFF) / 255.0f);
        imgData.putFloat(((pixelValue >> 8) & 0xFF) / 255.0f);
        imgData.putFloat((pixelValue & 0xFF) / 255.0f);
    }

    public void inspectModel(){
        String tag = "Model Inspection";
        Log.i(tag, "Number of input tensors: " + String.valueOf(tflite.getInputTensorCount()));
        Log.i(tag, "Number of output tensors: " + String.valueOf(tflite.getOutputTensorCount()));

        Log.i(tag, tflite.getInputTensor(0).toString());
        Log.i(tag, "Input tensor data type: " + tflite.getInputTensor(0).dataType());
        Log.i(tag, "Input tensor shape: " + Arrays.toString(tflite.getInputTensor(0).shape()));
        Log.i(tag, "Output tensor 0 shape: " + Arrays.toString(tflite.getOutputTensor(0).shape()));
    }

    private Bitmap resizedBitmap(Bitmap bitmap, int height, int width){
        return Bitmap.createScaledBitmap(bitmap, width, height, true);
    }

    private Bitmap croppedBitmap(Bitmap bitmap, int upperCornerX, int upperCornerY, int height, int width){
        return Bitmap.createBitmap(bitmap, upperCornerX, upperCornerY, width, height);
    }

    private float[][] run(Bitmap bitmap){
        bitmap = resizedBitmap(bitmap, IMAGE_HEIGHT, IMAGE_WIDTH);
        convertBitmapToByteBuffer(bitmap);

        float[][] embeddings = new float[1][512];
        tflite.run(imgData, embeddings);

        return embeddings;
    }

    public double getSimilarityScore(Bitmap face1, Bitmap face2){
        float[][] face1_embedding = run(face1);
        float[][] face2_embedding = run(face2);

        double distance = 0.0;
        for (int i = 0; i < EMBEDDING_SIZE; i++){
            distance += (face1_embedding[0][i] - face2_embedding[0][i]) * (face1_embedding[0][i] - face2_embedding[0][i]);
        }
        distance = Math.sqrt(distance);

        return distance;
    }

    public void close(){
        if (tflite != null) {
            tflite.close();
            tflite = null;
        }
        tfliteModel = null;
    }

}
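`getSimilarityScore` above is just the Euclidean (L2) distance between the two embedding vectors, so a lower score means more similar faces. A minimal standalone version of that loop (with toy 3-d vectors in place of the 512-d embeddings):

```java
public class EmbeddingDistance {
    // Euclidean distance between two equal-length embedding vectors,
    // the same computation as the loop in getSimilarityScore.
    public static double distance(float[] a, float[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }
}
```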

Answer

I can't think of a way to shrink the model file itself, but looking at your class: `loadModelFile` ultimately just returns a `MappedByteBuffer` built from a `FileInputStream`. So to load the model from storage instead of assets, put the file in a `Facenet` folder on external storage and map it from a `FileInputStream` over that file. Here is a solution in Kotlin:

class FaceNetStorage @Throws(IOException::class)
constructor() {
    private val intValues = IntArray(IMAGE_HEIGHT * IMAGE_WIDTH)
    private var imgData: ByteBuffer? = null

    private var tfliteModel: MappedByteBuffer? = null
    private var tflite: Interpreter? = null
    private val tfliteOptions = Interpreter.Options()

    init {
        val str = Environment.getExternalStorageDirectory().toString()+"/Facenet"
        val sd_main = File(str)
        var success = true
        if (!sd_main.exists()) {
            success = sd_main.mkdir()
        }
        if (success) {
            val sd = File(str+"/"+MODEL_PATH)
            tfliteModel = loadModelFile(sd)
            tflite = Interpreter(tfliteModel!!, tfliteOptions)
            imgData = ByteBuffer.allocateDirect(
                    BATCH_SIZE
                            * IMAGE_HEIGHT
                            * IMAGE_WIDTH
                            * NUM_CHANNELS
                            * NUM_BYTES_PER_CHANNEL)
            imgData!!.order(ByteOrder.nativeOrder())
        }
    }

    @Throws(IOException::class)
    private fun loadModelFile(file: File): MappedByteBuffer {
        val inputStream = FileInputStream(file)
        val fileChannel = inputStream.channel
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size())
    }

    private fun convertBitmapToByteBuffer(bitmap: Bitmap) {
        if (imgData == null) {
            return
        }
        imgData!!.rewind()
        bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
        // Convert the image to floating point.
        var pixel = 0
        for (i in 0 until IMAGE_HEIGHT) {
            for (j in 0 until IMAGE_WIDTH) {
                val `val` = intValues[pixel++]
                addPixelValue(`val`)
            }
        }
    }

    private fun addPixelValue(pixelValue: Int) {
        imgData!!.putFloat((pixelValue shr 16 and 0xFF) / 255.0f)
        imgData!!.putFloat((pixelValue shr 8 and 0xFF) / 255.0f)
        imgData!!.putFloat((pixelValue and 0xFF) / 255.0f)
    }

    fun inspectModel() {
        val tag = "Model Inspection"
        Log.i(tag, "Number of input tensors: " + tflite!!.inputTensorCount.toString())
        Log.i(tag, "Number of output tensors: " + tflite!!.outputTensorCount.toString())

        Log.i(tag, tflite!!.getInputTensor(0).toString())
        Log.i(tag, "Input tensor data type: " + tflite!!.getInputTensor(0).dataType())
        Log.i(tag, "Input tensor shape: " + Arrays.toString(tflite!!.getInputTensor(0).shape()))
        Log.i(tag, "Output tensor 0 shape: " + Arrays.toString(tflite!!.getOutputTensor(0).shape()))
    }

    private fun resizedBitmap(bitmap: Bitmap, height: Int, width: Int): Bitmap {
        return Bitmap.createScaledBitmap(bitmap, width, height, true)
    }

    private fun croppedBitmap(bitmap: Bitmap, upperCornerX: Int, upperCornerY: Int, height: Int, width: Int): Bitmap {
        return Bitmap.createBitmap(bitmap, upperCornerX, upperCornerY, width, height)
    }

    private fun run(bitmap: Bitmap): Array<FloatArray> {
        var bitmap = bitmap
        bitmap = resizedBitmap(bitmap, IMAGE_HEIGHT, IMAGE_WIDTH)
        convertBitmapToByteBuffer(bitmap)

        val embeddings = Array(1) { FloatArray(512) }
        tflite!!.run(imgData, embeddings)

        return embeddings
    }

    fun getSimilarityScore(face1: Bitmap, face2: Bitmap): Double {
        val face1_embedding = run(face1)
        val face2_embedding = run(face2)

        var distance = 0.0
        for (i in 0 until EMBEDDING_SIZE) {
            distance += ((face1_embedding[0][i] - face2_embedding[0][i]) * (face1_embedding[0][i] - face2_embedding[0][i])).toDouble()
        }
        distance = Math.sqrt(distance)

        return distance
    }

    fun close() {
        if (tflite != null) {
            tflite!!.close()
            tflite = null
        }
        tfliteModel = null
    }

    companion object {
        private val MODEL_PATH = "facenet.tflite"

        private val IMAGE_MEAN = 127.5f
        private val IMAGE_STD = 127.5f

        private val BATCH_SIZE = 1
        private val IMAGE_HEIGHT = 160
        private val IMAGE_WIDTH = 160
        private val NUM_CHANNELS = 3
        private val NUM_BYTES_PER_CHANNEL = 4
        private val EMBEDDING_SIZE = 512
    }

}
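To get the file onto storage in the first place, here is a hedged sketch of the download step using only `java.net.HttpURLConnection`; the URL is a placeholder, and on current Android you would prefer an app-private location such as `getFilesDir()` over `Environment.getExternalStorageDirectory()`, which is deprecated.

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;

public class ModelDownloader {
    // Download the model once; on later launches the existing file is reused.
    public static File ensureModel(File target, String modelUrl) throws IOException {
        if (target.exists() && target.length() > 0) {
            return target; // already downloaded, skip the network entirely
        }
        HttpURLConnection conn = (HttpURLConnection) new URL(modelUrl).openConnection();
        try (InputStream in = conn.getInputStream();
             FileOutputStream out = new FileOutputStream(target)) {
            byte[] buf = new byte[8192];
            int n;
            while ((n = in.read(buf)) != -1) {
                out.write(buf, 0, n);
            }
        } finally {
            conn.disconnect();
        }
        return target;
    }
}
```

After `ensureModel` returns, the file can be handed to the storage-based constructor above. Run the download off the main thread (for example in a background executor), since Android forbids network I/O on the UI thread.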
