TensorFlow对象检测API:从导出的模型检查点训练 [英] Tensorflow Object Detection API: Train from exported model checkpoint

查看:13
本文介绍了TensorFlow对象检测API:从导出的模型检查点训练的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我以前有一个导出的RetinanNet模型(最初来自对象检测动物园),它已经使用TensorFlow对象检测API(TensorFlow 2.4.1版)在自定义数据集上进行了微调。下面是导出模型的文件夹的外观。

对模型运行评估时(如下所示),MAP@0.5IOU为0.5。

python model_main_tf2.py --model_dir=exported-models/retinanet --pipeline_config_path=exported-models/retinanet/pipeline.config --checkpoint_dir=exported-models/retinanet/checkpoint

问题

由于不幸的情况,我没有培训模型时的培训文件夹。由于我最近得到了更多的数据,我想使用导出的模型作为进一步培训的起点,并已在pipeline.config中为新培训设置了pipeline.config

  fine_tune_checkpoint: "exported-models/retinanet/checkpoint/ckpt-0"
  num_steps: 25000
  startup_delay_steps: 0.0
  replicas_to_aggregate: 8
  max_number_of_boxes: 100
  unpad_groundtruth_tensors: false
  fine_tune_checkpoint_type: "detection"
  use_bfloat16: false
  fine_tune_checkpoint_version: V2

但是,当使用model_main_tf2.py脚本开始培训时,第一个检查点(位于步骤0)的分数很低--即使是在为导出的模型运行评估的同一数据集上也是如此。

我希望第一个检查点的分数(至少对于相同的测试集)与导出模型的分数相同。这种假设是错误的吗?在这种情况下,原因何在?

推荐答案

我终于找到了here

// Whether to load all checkpoint vars that match model variable names and
// sizes. This option is only available if `from_detection_checkpoint` is
// True.  This option is *not* supported for TF2 --- setting it to true
// will raise an error. **Instead, set fine_tune_checkpoint_type: 'full'.**
  optional bool load_all_detection_checkpoint_vars = 19 [default = false];

通过将fine_tune_checkpoint_type设置为";Full";,我获得了第一个检查点的正确地图(步骤为0)。

这篇关于TensorFlow对象检测API:从导出的模型检查点训练的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆