如何在 tensorflow 中正确实现延迟加载? [英] How to correctly implement lazy loading in tensorflow?

查看:39
本文介绍了如何在 tensorflow 中正确实现延迟加载?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

以下代码(同时尝试在 https://danijar.com 中复制代码结构/structuring-your-tensorflow-models/ )

The following code (while trying to replicate code structure in https://danijar.com/structuring-your-tensorflow-models/ )

import tensorflow as tf

class Model:

    def __init__(self, x):
        self.x = x
        self._output = None

    @property
    def output(self):
        if not self._output:
            weight = tf.Variable(tf.constant(4.0))
            bias = tf.Variable(tf.constant(2.0))
            self._output = tf.multiply(self.x, weight) + bias
        return self._output

def main():
    x = tf.placeholder(tf.float32)
    model = Model(x)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        output = sess.run(model.output, {x: 4.0})
        print(output)

if __name__ == '__main__':
    main()

报错.部分内容如下:

gives an error. A part of it is as follows:

Caused by op 'Variable_1/read', defined at:
    File "example.py", line 27, in <module>
        main()
    File "example.py", line 23, in main
        output = sess.run(model.output, {x: 4.0})
    File "example.py", line 12, in output
        weight = tf.Variable(tf.Variable(tf.constant(4.0)))

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_1

我该如何解决这个问题?

How do I resolve the issue?

推荐答案

问题是对sess.run(tf.global_variables_initializer())的调用发生在之前变量在下一行第一次调用 model.output 时被创建.

The problem is that the call to sess.run(tf.global_variables_initializer()) happens before the variables are created, in the first call to model.output on the following line.

为了解决这个问题,你必须在调用 sess.run(tf.global_variables_initializer()) 之前以某种方式访问​​ model.output.例如,以下代码有效:

To fix the problem, you must somehow access model.output before calling sess.run(tf.global_variables_initializer()). For example, the following code works:

import tensorflow as tf

class Model:

    def __init__(self, x):
        self.x = x
        self._output = None

    @property
    def output(self):
        # NOTE: You must use `if self._output is None` when `self._output` can
        # be a tensor, because `if self._output` on a tensor object will raise
        # an exception.
        if self._output is None:
            weight = tf.Variable(tf.constant(4.0))
            bias = tf.Variable(tf.constant(2.0))
            self._output = tf.multiply(self.x, weight) + bias
        return self._output

def main():
    x = tf.placeholder(tf.float32)
    model = Model(x)

    # The variables are created on this line.
    output_t = model.output

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        output = sess.run(output_t, {x: 4.0})
        print(output)

if __name__ == '__main__':
    main()

这篇关于如何在 tensorflow 中正确实现延迟加载?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆