tensorflow:在多个检查点上运行模型评估 [英] tensorflow: run model evaluation over multiple checkpoints

查看:33
本文介绍了tensorflow:在多个检查点上运行模型评估的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在我当前的项目中,我训练一个模型并每 100 个迭代步骤保存一次检查点.检查点文件都保存在同一目录中(model.ckpt-100、model.ckpt-200、model.ckpt-300 等).之后,我想根据所有保存的检查点的验证数据来评估模型,而不仅仅是最新的.

In my current project I train a model and save checkpoints every 100 iteration steps. The checkpoint files are all saved to the same directory (model.ckpt-100, model.ckpt-200 , model.ckpt-300 etc). And after that I would like to evalute the model based on validation data for all the saved checkpoints, not just the latest one.

目前我用于恢复检查点文件的代码如下所示:

Currently my piece of code for restoring the checkpoint file looks like this:

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
ckpt_list = saver.last_checkpoints
print(ckpt_list)
if ckpt and ckpt.model_checkpoint_path:
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
    # extract global_step from it.
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    print('Succesfully loaded model from %s at step=%s.' %
            (ckpt.model_checkpoint_path, global_step))
else:
    print('No checkpoint file found')
    return

然而,这只会恢复最近保存的检查点文件.那么如何在所有保存的检查点文件上编写一个循环呢?我尝试使用 saver.last_checkpoints 获取检查点文件列表,但是返回的列表为空.

However, this restores only the latest saved checkpoint file. So how do I write a loop over all saved checkpoint files? I tried to get a list of the checkpoint files with saver.last_checkpoints, however, the returned list is empty.

任何帮助将不胜感激,提前致谢!

Any help would be highly appreciated, thanks in advance!

推荐答案

最快的解决方案:

tensor2tensor 有一个模块 utils 和一个脚本 avg_checkpoints.py 保存新检查点的平均权重.假设您有一个想要平均的检查点列表.您有 2 个使用选项:

Fastest solution:

tensor2tensor has a module utils with a script avg_checkpoints.py that saves the averaged weights in a new checkpoint. Let us say you have a list of checkpoints that you want to average. You have 2 options for usage:

  1. 来自命令行

  1. From command line

TRAIN_DIR=path_to_your_model_folder
FNC_PATH=path_to_tensor2tensor+'/utils/avg.checkpoints.py'
CKPTS=model.ckpt-10000,model.ckpt-20000,model.ckpt-100000

python3 $FNC_PATH --prefix=$TRAIN_DIR --checkpoints=$CKPTS \ 
    --output_path="${TRAIN_DIR}averaged.ckpt"

  • 来自您自己的代码(使用 os.system):

    import os
    os.system(
        "python3 "+FNC_DIR+" --prefix="+TRAIN_DIR+" --checkpoints="+CKPTS+
        " --output_path="+TRAIN_DIR+"averaged.ckpt"
    )
    

  • 作为指定检查点列表并使用 --checkpoints 参数的替代方法,您可以只使用 --num_checkpoints=10 来平均最后 10 个检查点.

    As an alternative to specifying a list of checkpoints and using the --checkpoints argument, you can just use --num_checkpoints=10 to average the last 10 checkpoints.

    这是一个不依赖于 tensor2tensor 的代码片段,但仍然可以平均可变数量的检查点(与 ted 的答案相反).假设 steps 是应该合并的检查点列表(例如 [10000, 20000, 30000, 40000]).

    Here is a code snippet that does not rely on tensor2tensor, but can still average a variable number of checkpoints (as opposed to ted's answer). Assume steps is a list of checkpoints that should be merged (e.g. [10000, 20000, 30000, 40000]).

    那么:

    # Restore all sessions and save the weight matrices
    values = []
    for step in steps:
        tf.reset_default_graph()
        path = model_path+'/model.ckpt-'+str(step)
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph(path+'.meta')
            saver.restore(sess, path)
            values.append(sess.run(tf.all_variables()))
    
    # Average weights
    variables = tf.all_variables()
    all_assign = []
    for ind, var in enumerate(variables):
        weights = np.concatenate(
            [np.expand_dims(w[ind],axis=0)  for w in values],
            axis=0
        )
        all_assign.append(tf.assign(var, np.mean(weights, axis=0))
    

    然后你可以继续,但是你喜欢,例如保存平均检查点:

    Then you can proceed, however you prefer, e.g. saving the averaged checkpoint:

    # Now save the new values into a separate checkpoint
    with tf.Session() as sess_test:
        sess_test.run(all_assign)
        saver = tf.train.Saver() 
        saver.save(sess_test, model_path+'/average_'+str(num_checkpoints))
    

    这篇关于tensorflow:在多个检查点上运行模型评估的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

    查看全文
    登录 关闭
    扫码关注1秒登录
    发送“验证码”获取 | 15天全站免登陆