TensorFlow lite:将模型转换为 tflite 后精度损失很大 [英] TensorFlow lite: High loss in accuracy after converting model to tflite

查看:95
本文介绍了TensorFlow lite:将模型转换为 tflite 后精度损失很大的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我一直在尝试使用 TFLite 来提高 Android 上的检测速度,但奇怪的是我的 .tflite 模型现在几乎只能检测到 1 个类别.

I have been trying TFLite to increase detection speed on Android but strangely my .tflite model now almost only detects 1 category.

我已经对重新训练移动网络后得到的 .pb 模型进行了测试,结果很好,但由于某种原因,当我将其转换为 .tflite 时,检测就差了...

I have done testing on the .pb model that I got after retraining a mobilenet and the results are good but for some reason, when I convert it to .tflite the detection is way off...

对于再培训,我使用了 诗人的 Tensorflow 2

For the retraining I used the retrain.py file from Tensorflow for poets 2

我正在使用以下命令重新训练、优化推理并将模型转换为 tflite:

I am using the following commands to retrain, optimize for inference and convert the model to tflite:

python retrain.py \
--image_dir ~/tf_files/tw/ \
--tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/feature_vector/1 \
--output_graph ~/new_training_dir/retrainedGraph.pb \
-–saved_model_dir ~/new_training_dir/model/ \
--how_many_training_steps 500 

sudo toco \
--input_file=retrainedGraph.pb \
--output_file=optimized_retrainedGraph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TENSORFLOW_GRAPHDEF \
--input_shape=1,224,224,3 \
--input_array=Placeholder \
--output_array=final_result \

sudo toco \
--input_file=optimized_retrainedGraph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=retrainedGraph.tflite \
--inference_type=FLOAT \
--inference_input_type=FLOAT \
--input_arrays=Placeholder \
--output_array=final_result \
--input_shapes=1,224,224,3

我在这里做错了什么吗?准确性的损失从何而来?

Am I doing anything wrong here? Where could the loss in accuracy come from?

推荐答案

我在尝试将 .pb 模型转换为 .lite 时遇到了同样的问题.

I faced the same issue while I was trying to convert a .pb model into .lite.

事实上,我的准确率会从 95 降到 30!

In fact, my accuracy would come down from 95 to 30!

事实证明我犯的错误不是在 .pb 到 .lite 的转换过程中,也不是在执行此操作的命令中.但实际上是在加载图像并对其进行预处理时,才将其传递到 lite 模型并使用

Turns out the mistake I was committing was not during the conversion of .pb to .lite or in the command involved to do so. But it was actually while loading the image and pre-processing it before it is passed into the lite model and inferred using

interpreter.invoke()

命令.

您看到的以下代码是我所说的预处理:

The below code you see is what I meant by pre-processing:

test_image=cv2.imread(file_name)
test_image=cv2.resize(test_image,(299,299),cv2.INTER_AREA)
test_image = np.expand_dims((test_image)/255, axis=0).astype(np.float32)
interpreter.set_tensor(input_tensor_index, test_image)
interpreter.invoke()
digit = np.argmax(output()[0])
#print(digit)
prediction=result[digit]

如您所见,使用imread()"读取图像后,有两个关键命令/预处理完成:

As you can see there are two crucial commands/pre-processing done on the image once it is read using "imread()":

i) 图像的大小应调整为训练期间使用的输入图像/张量的input_height"和input_width"值.在我的情况下 (inception-v3),input_height"和input_width"都是 299.(阅读此值的模型文档或在您用于训练或重新训练模型的文件中查找此变量)

i) The image should be resized to the size that is the "input_height" and "input_width" values of the input image/tensor that was used during the training. In my case (inception-v3) this was 299 for both "input_height" and "input_width". (Read the documentation of the model for this value or look for this variable in the file that you used to train or retrain the model)

ii) 上面代码中的下一个命令是:

ii) The next command in the above code is:

test_image = np.expand_dims((test_image)/255, axis=0).astype(np.float32)

我从公式"/模型代码中得到了这个:

I got this from the "formulae"/model code:

test_image = np.expand_dims((test_image-input_mean)/input_std, axis=0).astype(np.float32)

阅读文档后发现对于我的架构 input_mean = 0 和 input_std = 255.

Reading the documentation revealed that for my architecture input_mean = 0 and input_std = 255.

当我对代码进行上述更改时,我获得了预期的准确率 (90%).

When I did the said changes to my code, I got the accuracy that was expected (90%).

希望这会有所帮助.

这篇关于TensorFlow lite:将模型转换为 tflite 后精度损失很大的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆