Run multiple pre-trained Tensorflow nets at the same time


Question

What I would like to do is to run multiple pre-trained Tensorflow nets at the same time. Because the names of some variables inside each net can be the same, the common solution is to use a name scope when I create a net. However, the problem is that I have trained these models and save the trained variables inside several checkpoint files. After I use a name scope when I create the net, I cannot load variables from the checkpoint files.

For example, I have trained an AlexNet and I would like to compare two sets of variables, one set is from the epoch 10 (saved in the file epoch_10.ckpt) and another set is from the epoch 50 (saved in the file epoch_50.ckpt). Because these two are exactly the same net, the names of variables inside are identical. I can create two nets by using

with tf.name_scope("net1"):
    net1 = CreateAlexNet()
with tf.name_scope("net2"):
    net2 = CreateAlexNet()

However, I cannot load the trained variables from .ckpt files because when I trained this net, I did not use a name scope. Even though I can set the name scope to "net1" when I train the net, this prevents me from loading the variables for net2.

I have tried:

with tf.name_scope("net1"):
    mySaver.restore(sess, 'epoch_10.ckpt')
with tf.name_scope("net2"):
    mySaver.restore(sess, 'epoch_50.ckpt')

This does not work.

What is the best way to solve this problem?

Answer

The easiest solution is to create different sessions that use separate graphs for each model:

# Build a graph containing `net1`.
with tf.Graph().as_default() as net1_graph:
  net1 = CreateAlexNet()
  saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')

# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
  net2 = CreateAlexNet()
  saver2 = tf.train.Saver(...)
sess2 = tf.Session(graph=net2_graph)
saver2.restore(sess2, 'epoch_50.ckpt')
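
Each session then evaluates tensors from its own graph. As a minimal sketch of comparing the two checkpoints on the same input (assuming, hypothetically, that CreateAlexNet returns the output tensor and that each graph defines an input placeholder named "input"):

import numpy as np

# Hypothetical batch matching AlexNet's usual 227x227 RGB input shape.
batch = np.random.rand(1, 227, 227, 3).astype(np.float32)

# Feed the same batch to each session; "input:0" is an assumed placeholder name.
out1 = sess1.run(net1, feed_dict={net1_graph.get_tensor_by_name("input:0"): batch})
out2 = sess2.run(net2, feed_dict={net2_graph.get_tensor_by_name("input:0"): batch})

# Largest element-wise difference between the epoch-10 and epoch-50 outputs.
print(np.max(np.abs(out1 - out2)))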

If this doesn't work for some reason, and you have to use a single tf.Session (e.g. because you want to combine results from the two networks in another TensorFlow computation), the best solution is to:

  1. Create the different nets inside name scopes, as you are already doing, and
  2. Create separate tf.train.Saver instances for the two nets, with an additional argument that remaps the variable names.

When constructing the savers, you can pass a dictionary as the var_list argument, mapping the names of the variables in the checkpoint (i.e. without the name scope prefix) to the tf.Variable objects you've created in each model.

You can build the var_list programmatically, and you should be able to do something like the following:

with tf.name_scope("net1"):
  net1 = CreateAlexNet()
with tf.name_scope("net2"):
  net2 = CreateAlexNet()

# Strip off the "net1/" prefix to get the names of the variables in the checkpoint.
net1_varlist = {v.name.lstrip("net1/"): v
                for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)

# Strip off the "net2/" prefix to get the names of the variables in the checkpoint.
net2_varlist = {v.name.lstrip("net2/"): v
                for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)

# ...
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")
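
If a restore still fails with a "not found in checkpoint" error, you can check that the stripped keys match the names actually stored in the checkpoint, for example with tf.train.NewCheckpointReader (shown here on the epoch_10.ckpt file from above):

reader = tf.train.NewCheckpointReader("epoch_10.ckpt")
# Maps each checkpoint variable name (no scope prefix, no ":0" suffix) to its shape.
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)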
