How to use the outputs of a PoseNet model in TFLite

Question

I am using the TFLite PoseNet model from here. It takes a 1*353*257*3 input image and returns four arrays with shapes 1*23*17*17, 1*23*17*34, 1*23*17*64 and 1*23*17*1. The model has an output stride of 16. How can I get the coordinates of all 17 pose keypoints on my input image? I have tried printing the confidence scores from the heatmap in the out1 array, but I get values close to 0.00 for every pixel. The code is below:

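import android.Manifest;
import android.app.Activity;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.Toast;

import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;

import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class MainActivity extends AppCompatActivity {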
private static final int CAMERA_REQUEST = 1888;
private ImageView imageView;
private static final int MY_CAMERA_PERMISSION_CODE = 100;
Interpreter tflite = null;
private String TAG = "rohit";
//private Canvas canvas;

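// One buffer per output tensor. With output stride 16 the grid is 23 rows x 17 columns.
// out1: keypoint heatmaps (17 keypoints); out2: offsets (17 y-offsets followed by 17 x-offsets).
// out3 and out4 are not needed to decode a single pose from the heatmaps and offsets.
Map<Integer, Object> outputMap = new HashMap<>();
float[][][][] out1 = new float[1][23][17][17];
float[][][][] out2 = new float[1][23][17][34];
float[][][][] out3 = new float[1][23][17][64];
float[][][][] out4 = new float[1][23][17][1];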

@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);
    String modelFile="multi_person_mobilenet_v1_075_float.tflite";
    try {
        tflite=new Interpreter(loadModelFile(MainActivity.this,modelFile));
    } catch (IOException e) {
        e.printStackTrace();
    }
    final Tensor no = tflite.getInputTensor(0);
    Log.d(TAG, "onCreate: Input shape"+ Arrays.toString(no.shape()));

    int c = tflite.getOutputTensorCount();
    Log.d(TAG, "onCreate: Output Count" +c );
    for (int i = 0; i < c; i++) {
        final Tensor output = tflite.getOutputTensor(i);
        Log.d(TAG, "onCreate: Output shape" + Arrays.toString(output.shape()));
    }
    this.imageView =  this.findViewById(R.id.imageView1);
    Button photoButton = this.findViewById(R.id.button1);
    photoButton.setOnClickListener(new View.OnClickListener() {

        @Override
        public void onClick(View v) {
            if (checkSelfPermission(Manifest.permission.CAMERA)
                    != PackageManager.PERMISSION_GRANTED) {
                requestPermissions(new String[]{Manifest.permission.CAMERA},
                        MY_CAMERA_PERMISSION_CODE);
            } else {
                Intent cameraIntent = new Intent(android.provider.MediaStore.ACTION_IMAGE_CAPTURE);
                startActivityForResult(cameraIntent, CAMERA_REQUEST);
            }
        }
    });
}

@Override
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
    super.onRequestPermissionsResult(requestCode, permissions, grantResults);
    if (requestCode == MY_CAMERA_PERMISSION_CODE) {
        if (grantResults[0] == PackageManager.PERMISSION_GRANTED) {
            Toast.makeText(this, "camera permission granted", Toast.LENGTH_LONG).show();
            Intent cameraIntent = new
                    Intent(android.provider.MediaStore.ACTION_IMAGE_CAPTURE);
            startActivityForResult(cameraIntent, CAMERA_REQUEST);
        } else {
            Toast.makeText(this, "camera permission denied", Toast.LENGTH_LONG).show();
        }
    }
}

@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
    if (requestCode == CAMERA_REQUEST && resultCode == Activity.RESULT_OK) {
        Bitmap photo = (Bitmap) data.getExtras().get("data");
        Log.d(TAG,"bhai:"+photo.getWidth()+":"+photo.getHeight());
        //imageView.setImageBitmap(photo);
        // The input tensor is [1][353][257][3] (NHWC): height 353, width 257.
        // createScaledBitmap takes (width, height), so 257 comes first.
        photo = Bitmap.createScaledBitmap(photo, 257, 353, true);
        photo = photo.copy(Bitmap.Config.ARGB_8888,true);
        Log.d(TAG, "onActivityResult: Bitmap resized");

        int width = photo.getWidth();   // 257
        int height = photo.getHeight(); // 353
        float[][][][] result = new float[1][height][width][3];
        int[] pixels = new int[width*height];
        photo.getPixels(pixels, 0, width, 0, 0, width, height);
        // getPixels fills the array row by row, so rows (y) go in the outer loop.
        // The float MobileNet PoseNet is normally fed RGB scaled to [-1, 1], not raw 0..255 values.
        for (int y = 0; y < height; y++)
        {
            for (int x = 0; x < width; x++)
            {
                int p = pixels[y * width + x];
                result[0][y][x][0] = ((p >> 16) & 0xff) / 127.5f - 1.0f;
                result[0][y][x][1] = ((p >> 8) & 0xff) / 127.5f - 1.0f;
                result[0][y][x][2] = (p & 0xff) / 127.5f - 1.0f;
            }
        }
        Object [] inputs = {result};
        //inputs[0] = inp;

        outputMap.put(0, out1);
        outputMap.put(1, out2);
        outputMap.put(2, out3);
        outputMap.put(3, out4);

        tflite.runForMultipleInputsOutputs(inputs,outputMap);
        out1 = (float[][][][]) outputMap.get(0);
        out2 = (float[][][][]) outputMap.get(1);
        out3 = (float[][][][]) outputMap.get(2);
        out4 = (float[][][][]) outputMap.get(3);

        Canvas canvas = new Canvas(photo);
        Paint p = new Paint();
        p.setColor(Color.RED);

        // scores[row][col][keypoint] = sigmoid(heatmap logit);
        // heatmap_pos[i] = {row, col} of the grid cell where keypoint i scores highest.
        float[][][] scores = new float[out1[0].length][out1[0][0].length][17];
        int[][] heatmap_pos = new int[17][2];

        for(int i=0;i<17;i++)
        {
            float max = -1;

            // j walks grid rows (y, 0..22), k walks grid columns (x, 0..16)
            for(int j=0;j<out1[0].length;j++)
            {
                for(int k=0;k<out1[0][0].length;k++)
                {
                    //  Log.d("mylog", "onActivityResult: "+out1[0][j][k][i]);
                    scores[j][k][i] = sigmoid(out1[0][j][k][i]);
                    if(max<scores[j][k][i])
                    {
                        max = scores[j][k][i];
                        heatmap_pos[i][0] = j;
                        heatmap_pos[i][1] = k;
                    }
                }
            }
            //     Log.d(TAG, "onActivityResult: "+max+"    "+heatmap_pos[i][0]+"    "+heatmap_pos[i][1]);
        }

        for(int i=0;i<17;i++)
        {
            Log.d("heatlog", "onActivityResult: "+heatmap_pos[i][0]+"    "+heatmap_pos[i][1]);
        }
        // PoseNet offsets: channel i is the y offset and channel i+17 the x offset of keypoint i.
        float[][] offset_vector = new float[17][2];
        float[][] keypoint_pos = new float[17][2];
        for(int i=0;i<17;i++)
        {
            offset_vector[i][0] = out2[0][heatmap_pos[i][0]][heatmap_pos[i][1]][i];
            offset_vector[i][1] = out2[0][heatmap_pos[i][0]][heatmap_pos[i][1]][i+17];
            Log.d("myoff",offset_vector[i][0]+":"+offset_vector[i][1]);
            // keypoint_pos[i] = {y, x} in the 257x353 resized bitmap: grid cell * output stride + offset
            keypoint_pos[i][0] = heatmap_pos[i][0]*16+offset_vector[i][0];
            keypoint_pos[i][1] = heatmap_pos[i][1]*16+offset_vector[i][1];
            Log.d(TAG, "onActivityResult: "+keypoint_pos[i][0]+"    "+keypoint_pos[i][1]);
            // drawCircle takes (cx, cy), i.e. x first; no extra centering shift is needed
            canvas.drawCircle(keypoint_pos[i][1], keypoint_pos[i][0], 5, p);
        }

        imageView.setImageBitmap(photo);
    }
}

private MappedByteBuffer loadModelFile(Activity activity, String MODEL_FILE) throws IOException {
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_FILE);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}

public float sigmoid(float value) {
    float p =  (float)(1.0 / (1 + Math.exp(-value)));
    return p;
}
}
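
For reference, the single-pose decoding the code above attempts (take the argmax of each keypoint's heatmap, then refine it with the offsets) can be sketched in plain Java with no Android or TFLite dependency. This is a minimal sketch, assuming the usual PoseNet layout seen in the tfjs-models decoder: heatmaps of shape [rows][cols][17], and offsets of shape [rows][cols][34] holding 17 y-offsets followed by 17 x-offsets. The class PoseNetDecode and its methods are hypothetical names for illustration only.

```java
import java.util.Arrays;

/** Hypothetical single-pose PoseNet decoder sketch: heatmap argmax + offset refinement. */
public class PoseNetDecode {

    /** Heatmap grid size for one input dimension, e.g. (353 - 1) / 16 + 1 = 23. */
    public static int outputSize(int inputSize, int outputStride) {
        return (inputSize - 1) / outputStride + 1;
    }

    public static float sigmoid(float x) {
        return (float) (1.0 / (1.0 + Math.exp(-x)));
    }

    /**
     * heatmaps: [gridH][gridW][numKeypoints] raw logits (batch dimension already dropped).
     * offsets:  [gridH][gridW][2 * numKeypoints]; channel k is the y offset, k + numKeypoints the x offset.
     * Returns [numKeypoints][3] = {y, x, score} in model-input pixel coordinates.
     */
    public static float[][] decodeSinglePose(float[][][] heatmaps, float[][][] offsets, int outputStride) {
        int gridH = heatmaps.length;
        int gridW = heatmaps[0].length;
        int numKeypoints = heatmaps[0][0].length;
        float[][] keypoints = new float[numKeypoints][3];
        for (int k = 0; k < numKeypoints; k++) {
            int bestY = 0;
            int bestX = 0;
            float best = Float.NEGATIVE_INFINITY;
            // Sigmoid is monotonic, so the argmax can be taken on the raw logits.
            for (int y = 0; y < gridH; y++) {
                for (int x = 0; x < gridW; x++) {
                    if (heatmaps[y][x][k] > best) {
                        best = heatmaps[y][x][k];
                        bestY = y;
                        bestX = x;
                    }
                }
            }
            keypoints[k][0] = bestY * outputStride + offsets[bestY][bestX][k];
            keypoints[k][1] = bestX * outputStride + offsets[bestY][bestX][k + numKeypoints];
            keypoints[k][2] = sigmoid(best);
        }
        return keypoints;
    }

    public static void main(String[] args) {
        int stride = 16;
        int gridH = outputSize(353, stride);
        int gridW = outputSize(257, stride);
        float[][][] heatmaps = new float[gridH][gridW][17];
        float[][][] offsets = new float[gridH][gridW][34];
        // Synthetic data: keypoint 0 peaks at grid cell (row 5, col 3).
        heatmaps[5][3][0] = 4.0f;
        offsets[5][3][0] = 2.5f;   // y offset
        offsets[5][3][17] = -1.0f; // x offset
        float[][] kp = decodeSinglePose(heatmaps, offsets, stride);
        // prints "grid 23x17, keypoint 0 at y=82.5, x=47.0"
        System.out.println("grid " + gridH + "x" + gridW + ", keypoint 0 at y=" + kp[0][0] + ", x=" + kp[0][1]);
        System.out.println(Arrays.toString(kp[0]));
    }
}
```

With the interpreter outputs above, the call would be decodeSinglePose(out1[0], out2[0], 16). The returned coordinates are in the 257x353 resized bitmap, so Canvas.drawCircle needs x = kp[i][1] and y = kp[i][0]; scale by originalWidth / 257 and originalHeight / 353 to map them back onto the original photo.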

Recommended answer

I think there is something wrong with this tflite model file, so I built the PoseNet tflite model myself from the model's weights. All the weights can be downloaded from tfjs-models: https://github.com/tensorflow/tfjs-models/tree/master/posenet

Then you can generate the model and do all the pre- and post-processing as in the following repo: https://github.com/zg9uagfv/tf_posenet
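
As an illustration of the pre-processing step, here is a minimal plain-Java sketch that packs ARGB pixels into the NHWC float input. It assumes the float MobileNet PoseNet expects RGB scaled to [-1, 1] (the v / 127.5 - 1 scaling used by the tfjs-models PoseNet MobileNet code; check it against the model you actually export) and that the pixels are ARGB ints in row-major order, as Bitmap.getPixels returns them. The class PoseNetPreprocess is a hypothetical name and is not taken from the tf_posenet repo.

```java
import java.util.Arrays;

/** Hypothetical sketch: ARGB_8888 pixels -> normalized [1][height][width][3] float input. */
public class PoseNetPreprocess {

    /**
     * pixels: row-major ARGB ints of length width * height (row y starts at index y * width).
     * Returns [1][height][width][3] with each RGB channel mapped from [0, 255] to [-1, 1].
     */
    public static float[][][][] toInputTensor(int[] pixels, int width, int height) {
        if (pixels.length != width * height) {
            throw new IllegalArgumentException("expected " + (width * height) + " pixels");
        }
        float[][][][] input = new float[1][height][width][3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int p = pixels[y * width + x];
                input[0][y][x][0] = normalize((p >> 16) & 0xff); // R
                input[0][y][x][1] = normalize((p >> 8) & 0xff);  // G
                input[0][y][x][2] = normalize(p & 0xff);         // B
            }
        }
        return input;
    }

    /** Maps 0..255 to -1..1 (v / 127.5 - 1). */
    static float normalize(int v) {
        return v / 127.5f - 1.0f;
    }

    public static void main(String[] args) {
        int[] pixels = {0xFFFF0000, 0xFF0000FF}; // one red pixel, one blue pixel
        float[][][][] t = toInputTensor(pixels, 2, 1);
        // prints "[1.0, -1.0, -1.0] [-1.0, -1.0, 1.0]"
        System.out.println(Arrays.toString(t[0][0][0]) + " " + Arrays.toString(t[0][0][1]));
    }
}
```

On Android the pixel array would come from photo.getPixels(pixels, 0, width, 0, 0, width, height) after scaling the bitmap to width 257 and height 353; putting rows in the outer loop keeps the tensor consistent with that row-major layout.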

After the PoseNet model is generated, you can export it to a .pb file or a .tflite file. I have done this successfully, and the resulting PoseNet model runs in my Android app on the GPU.
