Understanding TensorFlow checkpoint loading?

Question

What is contained in a TF checkpoint? Estimators, for example, store a separate file that contains the GraphDef proto, so you can basically call tf.import_graph_def(), then create a tf.train.Saver() and restore a checkpoint into the graph. Now, if you have another GraphDef describing a completely different graph that just happens to share the exact same variable names, with matching variable dimensions, will you be able to load the checkpoint into that graph? In other words, is a checkpoint just a mapping from variable names to values, or does it assume something else about the graph that is checked during loading? And what if you try to load a checkpoint into a graph that is a subset of the original graph (i.e. tensor dimensions and names match, but some names are missing)?

Recommended answer

When do people start reading the documentation (?): https://www.tensorflow.org/mobile/prepare_models

The graph definition and the checkpoint are different concepts. You can load just the weights as long as the variable names and shapes match. If there is a mismatch, you get:

Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint.

However, restoring also works in a non-trivial case where the graph is completely different:

# Written for the TensorFlow 1.x graph/session API.
import numpy as np
import tensorflow as tf

test_data = np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1)

# First graph: a 1x1 convolution whose only variable is 'test/kernel'.
inputs = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
output = tf.layers.conv2d(inputs, 3, kernel_size=1, name='test', use_bias=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output, {inputs: test_data}))
    saver = tf.train.Saver()
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print(tf.trainable_variables())

# Discard the first graph.
tf.reset_default_graph()

# Second graph: completely different operations, but a variable with the
# same name and shape as the saved kernel.
inputs = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
W = tf.get_variable('test/kernel', shape=[1, 1, 1, 3])
output = inputs + W

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # Overwrites the freshly initialized W with the saved kernel values.
    saver.restore(sess, "/tmp/model.ckpt")
    print(sess.run(output, {inputs: test_data}))

In my case, I got:

# 1st version (original graph)
[[[[-0.         -0.         -0.        ]
   [-0.08429337 -1.0156475  -0.42691123]]

  [[-0.16858673 -2.031295   -0.85382247]
   [-0.2528801  -3.0469427  -1.2807337 ]]]]
# 2nd version (altered graph)
[[[[-0.08429337 -1.0156475  -0.42691123]
   [ 0.91570663 -0.01564753  0.57308877]]

  [[ 1.9157066   0.98435247  1.5730888 ]
   [ 2.9157066   1.9843525   2.5730886 ]]]]
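The subset case from the question follows the same logic. Here is a toy sketch in plain Python, with no TensorFlow involved, of the lookup that a tf.train.Saver performs in TF 1.x as far as I can tell: every variable in the current graph is looked up by name, its shape has to match, and entries the graph never asks for are ignored. The restore function, the dictionaries and the 'test/bias' entry are hypothetical and only illustrate the matching rule, not the real implementation.

```python
# Toy model of name-based checkpoint restoring. This is NOT TensorFlow code;
# `restore` and both dictionaries are hypothetical stand-ins.

def restore(graph_vars, checkpoint):
    """Return name -> values for every variable the graph declares.

    graph_vars: dict mapping variable name -> expected shape (tuple)
    checkpoint: dict mapping variable name -> (saved shape, saved values)
    """
    restored = {}
    for name, wanted_shape in graph_vars.items():
        if name not in checkpoint:
            # The graph asks for a variable that was never saved.
            raise KeyError("%s not found in checkpoint" % name)
        saved_shape, values = checkpoint[name]
        if saved_shape != wanted_shape:
            raise ValueError("shape mismatch for %s: %s vs %s"
                             % (name, saved_shape, wanted_shape))
        restored[name] = values
    return restored


# Pretend checkpoint: the kernel from the example plus a made-up bias.
checkpoint = {
    "test/kernel": ((1, 1, 1, 3), [-0.08, -1.02, -0.43]),
    "test/bias": ((3,), [0.0, 0.0, 0.0]),
}

# Same names and shapes: the operations in the graph are never consulted.
full = restore({"test/kernel": (1, 1, 1, 3), "test/bias": (3,)}, checkpoint)

# Subset graph: the unused 'test/bias' entry is simply ignored.
subset = restore({"test/kernel": (1, 1, 1, 3)}, checkpoint)
print(sorted(full), sorted(subset))
```

The reverse direction is the one that fails: a graph variable whose name is absent from the checkpoint, or whose shape differs, makes the restore raise an error.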
