Is loading in eager TensorFlow broken right now?

Problem description

Weights in classes inheriting from tf.keras.Model seem unable to load at the moment. I am unable to load the weights from Example() outside of the class using checkpointing, so I tried to do it within the class, which by all accounts should work. It is able to save the weights, just as it can when saving Example() directly; however, it still can't load them. This is my model code:

class Example(tf.keras.Model):
    def __init__(self, cfg):
        super(Example, self).__init__()

        self.model = tf.keras.Sequential([
             ........layers.......
        ])

        # Create saver
        self.save_path = cfg.save_dir + cfg.extension
        self.ckpt_prefix = self.save_path + '/ckpt'
        self.saver = tf.train.Checkpoint(model=self.model)

    def call(self, x_in):
        x_out = self.model(x_in)
        return x_out

    def save(self):
        self.saver.save(file_prefix=self.ckpt_prefix)

    def load(self):
        self.saver.restore(tf.train.latest_checkpoint(self.save_path))

And this is what I use to check if it loads:

example = Example(cfg)
if Path(example.save_path).is_dir():
    print(example.weights)
    print(example.model.weights)
    example.load()
    print(example.weights)
    print(example.model.weights)

Output:

[]
[]
[]
[]

This was tested on both TensorFlow 1.3 and 2.0, and I can confirm that the weights are not empty after the first batch, and that checkpointing/saving does happen.

Answer

As it turns out, there are three different ways TensorFlow does checkpointing, depending on what is being checkpointed.

  1. The checkpointed object is just a variable. This is restored immediately after calling checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path)).

  2. The checkpointed object is a model with input shape defined. This is also restored immediately.

  3. The checkpointed object is a model without input shape defined. This is where the behaviour changes, as TensorFlow does a "delayed" restore, and will NOT restore the model weights until input is passed to the model.
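
For the first two cases no extra step is needed. Below is a minimal sketch of them; the variable value, layer sizes, and checkpoint paths are arbitrary assumptions, not taken from the question:

import tensorflow as tf

tf.enable_eager_execution()

# Case 1: a plain variable is restored as soon as restore() is called.
v = tf.Variable(3.0)
ckpt = tf.train.Checkpoint(v=v)
path = ckpt.save('./ckpt_var/ckpt')
v.assign(0.0)
ckpt.restore(path)
print(v.numpy())  # 3.0 -- no forward pass or other extra step needed

# Case 2: a model with its input shape defined builds its weights at
# construction time, so restore() also fills them in immediately.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(8,))
])
tf.train.Checkpoint(model=model).save('./ckpt_dense/ckpt')

restored = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(8,))
])
tf.train.Checkpoint(model=restored).restore(
    tf.train.latest_checkpoint('./ckpt_dense/'))
print(restored.weights != [])  # True -- restored without ever calling the model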

Here is a full example of the third case:

import os
import tensorflow as tf
import numpy as np

# Disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()

# Create model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])

# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)

# Make fake data
img = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
# Train
with tf.GradientTape() as tape:
    logits = model(img)
    loss = tf.losses.mean_squared_error(truth, logits)

# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)

# Save model
checkpoint_path = './ckpt/'
checkpoint.save('./ckpt/')

# Check if weights update
print("Are weights empty after training?", model.weights == [])

# Reset model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])

# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))

# This next line is REQUIRED to restore
#model(img)

print("Are weights empty after restoring from checkpoint?", model.weights == [])
print(status)
status.assert_existing_objects_matched()
status.assert_consumed()

With output:

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f6256b4ddd8>
Traceback (most recent call last):
  File "test.py", line 58, in <module>
    status.assert_consumed()
  File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed
    raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
  name: "VARIABLE_VALUE"
  full_name: "sequential/conv2d/kernel"
  checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
}

However, uncommenting the line model(img) will produce the following output:

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? False
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7ff62320fe48>

So input data needs to be passed to properly restore a shape invariant model.
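
Applied to the Example class from the question, one way to complete the delayed restore is to run a dummy forward pass right after restoring. This is only a sketch: the (1, 32, 32, 3) input shape is an assumption and should be replaced with whatever shape the real layers expect.

    def load(self):
        self.saver.restore(tf.train.latest_checkpoint(self.save_path))
        # The model has no input shape defined, so the restore above is
        # deferred. Building the model with a dummy input (shape assumed
        # here) triggers the actual weight restoration.
        dummy = tf.zeros((1, 32, 32, 3), dtype=tf.float32)
        self.model(dummy)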

References:

https://www.tensorflow.org/alpha/guide/checkpoints#delayed_restorations
https://github.com/tensorflow/tensorflow/issues/27937
