TensorFlow对象检测API:从导出的模型检查点训练 [英] Tensorflow Object Detection API: Train from exported model checkpoint
本文介绍了TensorFlow对象检测API:从导出的模型检查点训练的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我以前有一个导出的RetinanNet模型(最初来自对象检测动物园),它已经使用TensorFlow对象检测API(TensorFlow 2.4.1版)在自定义数据集上进行了微调。下面是导出模型的文件夹的外观。
对模型运行评估时(如下所示),MAP@0.5IOU为0.5。
python model_main_tf2.py --model_dir=exported-models/retinanet --pipeline_config_path=exported-models/retinanet/pipeline.config --checkpoint_dir=exported-models/retinanet/checkpoint
问题
由于不幸的情况,我没有培训模型时的培训文件夹。由于我最近得到了更多的数据,我想使用导出的模型作为进一步培训的起点,并已在pipeline.config
中为新培训设置了pipeline.config
:
fine_tune_checkpoint: "exported-models/retinanet/checkpoint/ckpt-0"
num_steps: 25000
startup_delay_steps: 0.0
replicas_to_aggregate: 8
max_number_of_boxes: 100
unpad_groundtruth_tensors: false
fine_tune_checkpoint_type: "detection"
use_bfloat16: false
fine_tune_checkpoint_version: V2
但是,当使用model_main_tf2.py
脚本开始培训时,第一个检查点(位于步骤0)的分数很低--即使是在为导出的模型运行评估的同一数据集上也是如此。
我希望第一个检查点的分数(至少对于相同的测试集)与导出模型的分数相同。这种假设是错误的吗?在这种情况下,原因何在?
推荐答案
我终于找到了here:
// Whether to load all checkpoint vars that match model variable names and
// sizes. This option is only available if `from_detection_checkpoint` is
// True. This option is *not* supported for TF2 --- setting it to true
// will raise an error. **Instead, set fine_tune_checkpoint_type: 'full'.**
optional bool load_all_detection_checkpoint_vars = 19 [default = false];
通过将fine_tune_checkpoint_type
设置为";Full";,我获得了第一个检查点的正确地图(步骤为0)。
这篇关于TensorFlow对象检测API:从导出的模型检查点训练的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文