Issue in creating Tflite model populated with metadata (for object detection)


Problem description

I am trying to run a tflite model on Android for object detection. For the same,

  1. I have successfully trained the model with my image set, as follows:

(a) Training:

!python3 object_detection/model_main.py \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--model_dir=training/

(modifying the config file to point to where my specific TFrecords are mentioned)

(b) Export inference graph

!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_inference_graph.py \
--input_type=image_tensor \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--output_directory={output_directory} \
--trained_checkpoint_prefix={last_model_path}

(c) Create tflite ready graph

!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_tflite_ssd_graph.py \
  --pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
  --output_directory={output_directory} \
  --trained_checkpoint_prefix={last_model_path} \
  --add_postprocessing_op=true

  2. I have created a tflite model using tflite_convert from the graph file as follows:

!tflite_convert \
--graph_def_file=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/fine_tuned_model/tflite_graph.pb \
--output_file=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/fine_tuned_model/detect3.tflite \
--output_format=TFLITE \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--allow_custom_ops
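The four entries in --output_arrays are the outputs of the SSD post-processing op: bounding boxes, class indices, scores, and the number of valid detections. A minimal sketch of unpacking them, using made-up values in place of real interpreter outputs (unpack_detections is a hypothetical helper, not part of any library):

```python
# Hypothetical values standing in for the four tensors produced by
# TFLite_Detection_PostProcess (real code would read them from the
# interpreter with interpreter.get_tensor(...)).
boxes = [[[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]]]  # [1, N, 4]: ymin, xmin, ymax, xmax
classes = [[0.0, 2.0]]  # [1, N] label indices (stored as floats)
scores = [[0.9, 0.3]]   # [1, N] confidence scores
num = [2.0]             # [1]  number of valid detections

def unpack_detections(boxes, classes, scores, num, threshold=0.5):
    """Keep only detections whose score clears the threshold."""
    results = []
    for i in range(int(num[0])):
        if scores[0][i] >= threshold:
            results.append({
                "box": boxes[0][i],
                "class_id": int(classes[0][i]),
                "score": scores[0][i],
            })
    return results

print(unpack_detections(boxes, classes, scores, num))
# -> [{'box': [0.1, 0.2, 0.5, 0.6], 'class_id': 0, 'score': 0.9}]
```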

The above tflite model is validated independently and works fine (outside of Android).

There is now a requirement to populate the tflite model with metadata so that it can be processed by the sample Android code provided at the link below (otherwise I get an error when running on Android Studio: not a valid Zip file and does not have associated files).

https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/android/README.md

The sample .TFlite provided as part of the same link is populated with metadata and works fine.

When I try to use the following link: https://www.tensorflow.org/lite/convert/metadata#deep_dive_into_the_image_classification_example

populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
populator.populate()

to add metadata (the rest of the code is practically the same, with some changes to the meta description for object detection instead of image classification, and to specify the location of labelmap.txt), it gives the following error:

<ipython-input-6-173fc798ea6e> in <module>()
  1 populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
  ----> 2 populator.load_metadata_buffer(metadata_buf)
        3 populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
        4 populator.populate()

1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_lite_support/metadata/metadata.py in _validate_metadata(self, metadata_buf)
    540           "The number of output tensors ({0}) should match the number of "
    541           "output tensor metadata ({1})".format(num_output_tensors,
--> 542                                                 num_output_meta))
    543 
    544 

ValueError: The number of output tensors (4) should match the number of output tensor metadata (1)
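The check that raises this error can be sketched as follows (a simplified, hypothetical re-implementation of the validation in metadata.py, not the library's actual code):

```python
def validate_output_metadata(num_output_tensors, output_tensor_metadata):
    # Each output tensor of the model must have exactly one
    # TensorMetadata entry; a mismatch aborts metadata population.
    num_output_meta = len(output_tensor_metadata)
    if num_output_tensors != num_output_meta:
        raise ValueError(
            "The number of output tensors ({0}) should match the number of "
            "output tensor metadata ({1})".format(num_output_tensors,
                                                  num_output_meta))

# The SSD postprocessing op emits 4 tensors (boxes, classes, scores, count),
# but the image-classification example only writes metadata for 1 output:
try:
    validate_output_metadata(4, ["classification_output"])
except ValueError as e:
    print(e)
```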

The 4 output tensors are the ones mentioned in output_arrays in step 2 (someone may correct me there). I am not sure how to update the output tensor metadata accordingly.

Can anyone who has recently used a custom model for object detection (and then applied it on Android) help? Or help me understand how to update the tensor metadata to 4 entries instead of 1.

Answer

Update:

The Metadata Writer library has been released. It currently supports image classifiers and object detectors, and more supported tasks are on the way.

Here is an example to write metadata for an object detector model:

  1. Install the TFLite Support nightly Pypi package:

pip install tflite_support_nightly

  2. Write metadata into the model with the following script:

from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils
from tflite_support import metadata

ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "ssd_mobilenet_v1_1_default_1.tflite"
_LABEL_FILE = "labelmap.txt"
_SAVE_TO_PATH = "ssd_mobilenet_v1_1_default_1_metadata.tflite"

writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [127.5], [127.5], [_LABEL_FILE])
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

# Verify the populated metadata and associated files.
displayer = metadata.MetadataDisplayer.with_model_file(_SAVE_TO_PATH)
print("Metadata populated:")
print(displayer.get_metadata_json())
print("Associated file(s) populated:")
print(displayer.get_packed_associated_file_list())

---------- Previous answer that writes metadata manually --------

Here is a code snippet you can use to populate metadata for object detection models, which is compatible with the TFLite Android app.

import os

import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "SSD_Detector"
model_meta.description = (
    "Identify which of a known set of objects might be present and provide "
    "information about their positions within the given image or a video "
    "stream.")

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats

# Creates outputs info.
output_location_meta = _metadata_fb.TensorMetadataT()
output_location_meta.name = "location"
output_location_meta.description = "The locations of the detected boxes."
output_location_meta.content = _metadata_fb.ContentT()
output_location_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.BoundingBoxProperties)
output_location_meta.content.contentProperties = (
    _metadata_fb.BoundingBoxPropertiesT())
output_location_meta.content.contentProperties.index = [1, 0, 3, 2]
output_location_meta.content.contentProperties.type = (
    _metadata_fb.BoundingBoxType.BOUNDARIES)
output_location_meta.content.contentProperties.coordinateType = (
    _metadata_fb.CoordinateType.RATIO)
output_location_meta.content.range = _metadata_fb.ValueRangeT()
output_location_meta.content.range.min = 2
output_location_meta.content.range.max = 2

output_class_meta = _metadata_fb.TensorMetadataT()
output_class_meta.name = "category"
output_class_meta.description = "The categories of the detected boxes."
output_class_meta.content = _metadata_fb.ContentT()
output_class_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_class_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())
output_class_meta.content.range = _metadata_fb.ValueRangeT()
output_class_meta.content.range.min = 2
output_class_meta.content.range.max = 2
label_file = _metadata_fb.AssociatedFileT()
label_file.name = os.path.basename("label.txt")
label_file.description = "Label of objects that this model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS
output_class_meta.associatedFiles = [label_file]

output_score_meta = _metadata_fb.TensorMetadataT()
output_score_meta.name = "score"
output_score_meta.description = "The scores of the detected boxes."
output_score_meta.content = _metadata_fb.ContentT()
output_score_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_score_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())
output_score_meta.content.range = _metadata_fb.ValueRangeT()
output_score_meta.content.range.min = 2
output_score_meta.content.range.max = 2

output_number_meta = _metadata_fb.TensorMetadataT()
output_number_meta.name = "number of detections"
output_number_meta.description = "The number of the detected boxes."
output_number_meta.content = _metadata_fb.ContentT()
output_number_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_number_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())

# Creates subgraph info.
group = _metadata_fb.TensorGroupT()
group.name = "detection result"
group.tensorNames = [
    output_location_meta.name, output_class_meta.name,
    output_score_meta.name
]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [
    output_location_meta, output_class_meta, output_score_meta,
    output_number_meta
]
subgraph.outputTensorGroups = [group]
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
