Key <variable_name> not found in checkpoint (TensorFlow)
Question
I'm using TensorFlow v1.1 and I've been trying to figure out how to use my EMA'ed weights for inference, but no matter what I do I keep getting the error
Not found: Key W/ExponentialMovingAverage not found in checkpoint
even though the key exists when I loop through and print out all the tf.global_variables.
Here is a reproducible script heavily adapted from Facenet's unit test:
import os

import tensorflow as tf
import numpy as np

tf.reset_default_graph()

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
opt_op = optimizer.minimize(loss)

# Track the moving averages of all trainable variables.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
variables = tf.trainable_variables()
print(variables)
averages_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([opt_op]):
    train_op = tf.group(averages_op)

# Before starting, initialize the variables. We will 'run' this first.
init = tf.global_variables_initializer()
saver = tf.train.Saver(tf.trainable_variables())

# Launch the graph.
sess = tf.Session()
sess.run(init)

# Fit the line.
for _ in range(201):
    sess.run(train_op)

w_reference = sess.run('W/ExponentialMovingAverage:0')
b_reference = sess.run('b/ExponentialMovingAverage:0')
saver.save(sess, os.path.join("model_ex1"))

# Rebuild the graph from the saved meta graph in a fresh session.
tf.reset_default_graph()
tf.train.import_meta_graph("model_ex1.meta")
sess = tf.Session()

print('------------------------------------------------------')
for var in tf.global_variables():
    print('all variables: ' + var.op.name)
for var in tf.trainable_variables():
    print('normal variable: ' + var.op.name)
for var in tf.moving_average_variables():
    print('ema variable: ' + var.op.name)
print('------------------------------------------------------')

# Map each EMA checkpoint key onto the corresponding trainable variable.
mode = 1
restore_vars = {}
if mode == 0:
    ema = tf.train.ExponentialMovingAverage(1.0)
    for var in tf.trainable_variables():
        print('%s: %s' % (ema.average_name(var), var.op.name))
        restore_vars[ema.average_name(var)] = var
elif mode == 1:
    for var in tf.trainable_variables():
        ema_name = var.op.name + '/ExponentialMovingAverage'
        print('%s: %s' % (ema_name, var.op.name))
        restore_vars[ema_name] = var

saver = tf.train.Saver(restore_vars, name='ema_restore')
saver.restore(sess, os.path.join("model_ex1"))  # error happens here!

w_restored = sess.run('W:0')
b_restored = sess.run('b:0')
print(w_reference)
print(w_restored)
print(b_reference)
print(b_restored)
Recommended answer
The "Key not found in checkpoint" error means that the variable exists in your model in memory but not in the serialized checkpoint file on disk.
You should use the inspect_checkpoint tool to understand what tensors are being saved in your checkpoint, and why some exponential moving averages are not being saved here.
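A rough sketch of the same check done from Python, assuming TensorFlow 1.x: tf.train.NewCheckpointReader lists the tensor names stored in a checkpoint, and a set difference shows which requested keys are absent. The helper names checkpoint_keys and missing_keys are made up for this illustration, and the sample values are assumed rather than read from a real checkpoint. In the question's script the saving Saver is built from tf.trainable_variables(), and EMA shadow variables are not trainable, so they were probably never written to disk.

```python
def checkpoint_keys(prefix):
    """Return the sorted tensor names stored in the checkpoint at `prefix`."""
    # Imported lazily so that missing_keys() can be used without TensorFlow.
    import tensorflow as tf  # assumption: TensorFlow 1.x

    reader = tf.train.NewCheckpointReader(prefix)
    return sorted(reader.get_variable_to_shape_map())


def missing_keys(saved_keys, requested_keys):
    """Return the requested keys that the checkpoint does not contain."""
    return sorted(set(requested_keys) - set(saved_keys))


# Illustrative values only: the keys the question's restore map asks for,
# versus what a checkpoint saved from trainable variables alone would hold.
requested = ["W/ExponentialMovingAverage", "b/ExponentialMovingAverage"]
saved = ["W", "b"]  # assumed; replace with checkpoint_keys("./model_ex1")
print(missing_keys(saved, requested))
```

If the EMA keys do turn out to be missing, one option to try is building the saving Saver without a variable list (tf.train.Saver()), which by default covers all saveable variables, including the shadow variables.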
It's not clear from your repro example which line is supposed to trigger the error.