How to correctly implement lazy loading in TensorFlow?
Problem description
The following code (written while trying to replicate the code structure in https://danijar.com/structuring-your-tensorflow-models/)
import tensorflow as tf

class Model:
    def __init__(self, x):
        self.x = x
        self._output = None

    @property
    def output(self):
        if not self._output:
            weight = tf.Variable(tf.constant(4.0))
            bias = tf.Variable(tf.constant(2.0))
            self._output = tf.multiply(self.x, weight) + bias
        return self._output

def main():
    x = tf.placeholder(tf.float32)
    model = Model(x)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        output = sess.run(model.output, {x: 4.0})
        print(output)

if __name__ == '__main__':
    main()
gives an error. Part of it is as follows:
Caused by op 'Variable_1/read', defined at:
  File "example.py", line 27, in <module>
    main()
  File "example.py", line 23, in main
    output = sess.run(model.output, {x: 4.0})
  File "example.py", line 12, in output
    weight = tf.Variable(tf.Variable(tf.constant(4.0)))
FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_1
How do I resolve the issue?
Recommended answer
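The problem is that sess.run(tf.global_variables_initializer()) runs before the variables exist: they are only created on the following line, during the first access to model.output. tf.global_variables_initializer() initializes only the variables that are already in the graph when it is called, so the weight and bias variables created afterwards are left uninitialized.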
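To see why the order matters, here is a toy, pure-Python model of graph-mode initialization. FakeGraph, FakeVariable, make_initializer and run_model are hypothetical stand-ins written for this illustration, not TensorFlow APIs; the one behaviour they mimic is that an initializer only covers the variables that exist when it is created.

```python
# Toy, pure-Python model of TF1 graph-mode variable initialization.
# FakeGraph, FakeVariable, make_initializer and run_model are hypothetical
# stand-ins for illustration; they are NOT TensorFlow APIs.

class FakeGraph:
    def __init__(self):
        self.variables = []  # every variable created in this graph so far


class FakeVariable:
    def __init__(self, graph, initial_value):
        self.initial_value = initial_value
        self.value = None  # None means "not initialized yet"
        graph.variables.append(self)

    def read(self):
        if self.value is None:
            raise RuntimeError("Attempting to use uninitialized value")
        return self.value


def make_initializer(graph):
    # Snapshot the variables that exist *now*; later ones are not covered.
    snapshot = list(graph.variables)

    def run():
        for variable in snapshot:
            variable.value = variable.initial_value

    return run


def run_model(create_before_init):
    graph = FakeGraph()
    if create_before_init:
        weight = FakeVariable(graph, 4.0)  # like touching model.output first
        init = make_initializer(graph)
    else:
        init = make_initializer(graph)     # like the question's order
        weight = FakeVariable(graph, 4.0)
    init()
    return weight.read()


print(run_model(create_before_init=True))  # prints 4.0
try:
    run_model(create_before_init=False)
except RuntimeError as err:
    print(err)  # prints: Attempting to use uninitialized value
```

Creating the variable first and the initializer second succeeds; the reverse order reproduces the "uninitialized value" failure from the question.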
To fix the problem, access model.output (which creates the variables) before calling sess.run(tf.global_variables_initializer()). For example, the following code works:
import tensorflow as tf

class Model:
    def __init__(self, x):
        self.x = x
        self._output = None

    @property
    def output(self):
        # NOTE: You must use `if self._output is None` when `self._output` can
        # be a tensor, because `if self._output` on a tensor object will raise
        # an exception.
        if self._output is None:
            weight = tf.Variable(tf.constant(4.0))
            bias = tf.Variable(tf.constant(2.0))
            self._output = tf.multiply(self.x, weight) + bias
        return self._output

def main():
    x = tf.placeholder(tf.float32)
    model = Model(x)
    # The variables are created on this line.
    output_t = model.output
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        output = sess.run(output_t, {x: 4.0})
        print(output)

if __name__ == '__main__':
    main()
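One way to make this ordering automatic, in the spirit of the lazy-property decorator described in the linked article, is to cache each property on first access and touch it once inside __init__, so everything it creates exists as soon as the model object does. The sketch below is pure Python under stated assumptions: plain floats stand in for tf.Variable so it runs without TensorFlow, the `_cache_<name>` attribute name is a hypothetical choice, and build_count is a hypothetical counter used only to show that the body runs once. The cache is checked with hasattr() rather than truthiness, which sidesteps the tensor-as-bool problem mentioned in the NOTE above.

```python
import functools


def lazy_property(function):
    """Turn a method into a property that is computed once, on first access."""
    # Hypothetical cache attribute name, e.g. "_cache_output".
    attribute = "_cache_" + function.__name__

    @property
    @functools.wraps(function)
    def wrapper(self):
        # hasattr() never evaluates the cached value as a bool, so this also
        # works for objects such as TF tensors that refuse truth testing.
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return wrapper


class Model:
    def __init__(self, x):
        self.x = x
        self.build_count = 0  # hypothetical counter, for demonstration only
        # Touch the lazy property once so everything it creates exists as
        # soon as the model is constructed (i.e., before any initializer).
        self.output

    @lazy_property
    def output(self):
        self.build_count += 1
        weight, bias = 4.0, 2.0  # plain floats standing in for tf.Variable
        return self.x * weight + bias


model = Model(4.0)
print(model.output)       # prints 18.0
print(model.build_count)  # prints 1: the body ran only once
```

In the TensorFlow version of this pattern, the weight and bias variables would already be in the graph when Model(x) returns, so a tf.global_variables_initializer() call made afterwards would cover them regardless of when main() first reads model.output.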