Tensorflow:有没有办法加载预训练模型而不必重新定义所有变量? [英] Tensorflow: is there a way to load a pretrained model without having to redefine all the variables?

查看:45
本文介绍了Tensorflow:有没有办法加载预训练模型而不必重新定义所有变量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试将我的代码拆分为不同的模块,一个用于训练模型,另一个用于分析模型中的权重.

I'm trying to split my code into different modules, one where the model is trained, another which analyzes the weights in the model.

当我使用

save_path = saver.save(sess, "checkpoints5/text8.ckpt")

它生成了 4 个文件,['checkpoint', 'text8.ckpt.data-00000-of-00001', 'text8.ckpt.meta', 'text8.ckpt.index']

It makes 4 files, ['checkpoint', 'text8.ckpt.data-00000-of-00001', 'text8.ckpt.meta', 'text8.ckpt.index']

我尝试使用此代码在单独的模块中恢复它

I tried restoring this in the separate module using this code

train_graph = tf.Graph()
with train_graph.as_default():
    saver = tf.train.Saver()


with tf.Session(graph=train_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('MODEL4'))
    embed_mat = sess.run(embedding)

但我收到此错误消息

ValueError                                Traceback (most recent call last)
<ipython-input-15-deaad9b67888> in <module>()
      1 train_graph = tf.Graph()
      2 with train_graph.as_default():
----> 3     saver = tf.train.Saver()
      4 
      5 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
   1309           time.time() + self._keep_checkpoint_every_n_hours * 3600)
   1310     elif not defer_build:
-> 1311       self.build()
   1312     if self.saver_def:
   1313       self._check_saver_def()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in build(self)
   1318     if context.executing_eagerly():
   1319       raise RuntimeError("Use save/restore instead of build in eager mode.")
-> 1320     self._build(self._filename, build_save=True, build_restore=True)
   1321 
   1322   def _build_eager(self, checkpoint_path, build_save, build_restore):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in _build(self, checkpoint_path, build_save, build_restore)
   1343           return
   1344         else:
-> 1345           raise ValueError("No variables to save")
   1346       self._is_empty = False
   1347 

ValueError: No variables to save

阅读完这个问题后,我似乎需要重新定义训练模型时使用的所有变量.

After reading up on this issue, it seems that I need to redefine all the variables used when I trained the model.

有没有办法不用重新定义一切就可以访问权重?权重只是数字,肯定有一种方法可以直接访问它们吗?

Is there a way to access the weights without having to redefine everything? The weights are just numbers, surely there must be a way to access them directly?

推荐答案

为了仅访问检查点中的变量,请查看 checkpoint_utils 库.它提供了三个有用的 api 函数:load_checkpointlist_variablesload_variable.我不确定是否有更好的方法,但您当然可以使用这些函数来提取检查点中所有变量的字典,如下所示:

For just accessing variables in checkpoints, please checkout the checkpoint_utils library. It provides three useful api function: load_checkpoint, list_variables and load_variable. I'm not sure if there is a better way but you can certainly use these functions to extract a dict of all variables in a checkpoint like this:

import tensorflow as tf

ckpt = 'checkpoints5/text8.ckpt'
var_dict = {name: tf.train.load_checkpoint(ckpt).get_tensor(name)
            for name, _ in tf.train.list_variables(ckpt)}
print(var_dict)

要加载预训练模型而不必重新定义所有变量,您需要的不仅仅是检查点.检查点只有变量,它不知道如何恢复这些变量,即如何将它们映射到图形,而没有实际图形(和适当的地图).SavedModel 更适合这种情况.它可以保存模型 MetaGraph 和所有变量.恢复保存的模型时,您不必手动重新定义所有内容.以下代码是仅使用 simple_save.

To load a pretrained model without having to redefine all the variables, you will need more than just checkpoints. A checkpoint has only variables and it doesn't how to restore these variables, i.e. how to map them to a graph, without an actual graph (and an appropriate map). SavedModel will be better for this scenario. It can save both the model MetaGraph and all variables. You don't have to manually redefine everything when restoring the saved model. The following code is an example using just the simple_save.

要保存经过训练的模型:

To save a trained model:

import tensorflow as tf

x = tf.placeholder(tf.float32)
y_ = tf.reshape(x, [-1, 1])
y_ = tf.layers.dense(y_, units=1)
loss = tf.losses.mean_squared_error(labels=x, predictions=y_)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for _ in range(10):
        sess.run(train_op, feed_dict={x: range(10)})
    # Let's check the bias here so that we can make sure
    # the model we restored later on is indeed our trained model here.
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))
    tf.saved_model.simple_save(sess, 'test', inputs={"x": x}, outputs={"y": y_})

要恢复保存的模型:

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # A model saved by simple_save will be treated as a graph for inference / serving,
    # i.e. uses the tag tag_constants.SERVING
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 'test')
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))

这篇关于Tensorflow:有没有办法加载预训练模型而不必重新定义所有变量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆