Key &lt;variable_name&gt; not found in checkpoint Tensorflow


Problem description

I'm using Tensorflow v1.1 and I've been trying to figure out how to use my EMA'ed weights for inference, but no matter what I do I keep getting the error

Not found: Key W/ExponentialMovingAverage not found in checkpoint

even though the key exists when I loop through and print out all of the tf.global_variables

Here is a reproducible script heavily adapted from Facenet's unit test:

import os

import tensorflow as tf
import numpy as np


tf.reset_default_graph()

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
opt_op = optimizer.minimize(loss)

# Track the moving averages of all trainable variables.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
variables = tf.trainable_variables()
print(variables)
averages_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([opt_op]):
    train_op = tf.group(averages_op)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.global_variables_initializer()

saver = tf.train.Saver(tf.trainable_variables())

# Launch the graph.
sess = tf.Session()
sess.run(init)

# Fit the line.
for _ in range(201):
    sess.run(train_op)

w_reference = sess.run('W/ExponentialMovingAverage:0')
b_reference = sess.run('b/ExponentialMovingAverage:0')

saver.save(sess, os.path.join("model_ex1"))

tf.reset_default_graph()

tf.train.import_meta_graph("model_ex1.meta")
sess = tf.Session()

print('------------------------------------------------------')
for var in tf.global_variables():
    print('all variables: ' + var.op.name)
for var in tf.trainable_variables():
    print('normal variable: ' + var.op.name)
for var in tf.moving_average_variables():
    print('ema variable: ' + var.op.name)
print('------------------------------------------------------')

mode = 1
restore_vars = {}
if mode == 0:
    ema = tf.train.ExponentialMovingAverage(1.0)
    for var in tf.trainable_variables():
        print('%s: %s' % (ema.average_name(var), var.op.name))
        restore_vars[ema.average_name(var)] = var
elif mode == 1:
    for var in tf.trainable_variables():
        ema_name = var.op.name + '/ExponentialMovingAverage'
        print('%s: %s' % (ema_name, var.op.name))
        restore_vars[ema_name] = var

saver = tf.train.Saver(restore_vars, name='ema_restore')

saver.restore(sess, os.path.join("model_ex1")) # error happens here!

w_restored = sess.run('W:0')
b_restored = sess.run('b:0')

print(w_reference)
print(w_restored)
print(b_reference)
print(b_restored)

Answer

The key not found in checkpoint error means that the variable exists in your model in memory but not in the serialized checkpoint file on disk.
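In the script above, the saving `Saver` is built as `tf.train.Saver(tf.trainable_variables())`, so only `W` and `b` are serialized; the shadow variables created by `ema.apply` are global variables but not trainable, and never reach the checkpoint. A minimal pure-Python sketch of that mismatch (the variable names come from the script; the set arithmetic is illustrative, not a TensorFlow API):

```python
# Names of all variables in the graph after ema.apply() runs
# (the trainable variables plus their EMA shadow copies).
global_vars = ['W', 'b', 'W/ExponentialMovingAverage', 'b/ExponentialMovingAverage']

# The saving Saver was built from tf.trainable_variables(), so only
# these names are serialized into the checkpoint file on disk.
saved_keys = ['W', 'b']

# Any name the restore-time Saver asks for that is not among the saved
# keys produces the "Key ... not found in checkpoint" error.
missing = sorted(set(global_vars) - set(saved_keys))
print(missing)  # ['W/ExponentialMovingAverage', 'b/ExponentialMovingAverage']
```

Constructing the saving `Saver` with no `var_list` argument (so it defaults to all saveable variables) is one way to get the shadow variables onto disk.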

You should use the inspect_checkpoint tool to understand what tensors are being saved in your checkpoint, and why some exponential moving averages are not being saved here.
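As a sketch of the comparison that tool enables: list the tensor names stored in the checkpoint, build the restore-map keys that the script's `mode == 1` branch constructs, and diff the two. The helpers below are plain Python for illustration (they are not a TensorFlow API; the checkpoint key list is what the script's first `Saver` would actually have written):

```python
def ema_restore_keys(var_names):
    """Shadow-variable names following the script's mode 1 convention:
    <variable name> + '/ExponentialMovingAverage'."""
    return {name + '/ExponentialMovingAverage': name for name in var_names}

def not_in_checkpoint(checkpoint_keys, restore_map):
    """Keys a Saver built from restore_map would fail to find on disk."""
    present = set(checkpoint_keys)
    return sorted(k for k in restore_map if k not in present)

# Tensor names one would see when listing the checkpoint written by
# tf.train.Saver(tf.trainable_variables()) in the script:
checkpoint_keys = ['W', 'b']

restore_map = ema_restore_keys(['W', 'b'])
print(not_in_checkpoint(checkpoint_keys, restore_map))
# ['W/ExponentialMovingAverage', 'b/ExponentialMovingAverage']
```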

It's not clear from your repro example which line is supposed to trigger the error.

