Run train op multiple times in tensorflow
Question
I have some fairly large batch sizes on which I'd like to take multiple gradient steps. While I could easily do this with a python for loop, I imagine that there might be a more efficient method that doesn't involve transferring the data to gpu on each iteration. I've tried putting the train op in the fetch list multiple times, but I'm not sure that it's actually being run more than once (the runtime is exactly the same).
Answer
If you have variable-sized batches, then a Variable is a bad fit for saving them; you can instead persist this data between run calls using persistent tensors. Here's a toy example:
import tensorflow as tf

dt = tf.int32  # dtype shared by the uploaded data and the parameter variable
params = tf.Variable(tf.ones_initializer((), dtype=dt))
data_batches = [[1], [2, 3], [4, 5, 6]]

# op that uploads data to TF and saves it as a persistent Tensor
data_saver_placeholder = tf.placeholder(dt)
tensor_handle_op = tf.get_session_handle(data_saver_placeholder)

data_placeholder, data = tf.get_session_tensor(dt)
train_op = tf.assign_add(params, tf.reduce_prod(data))
init_op = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init_op)
for batch in data_batches:
    # upload tensor to TF runtime and save its handle
    tensor_handle = sess.run(tensor_handle_op,
                             feed_dict={data_saver_placeholder: batch})
    # run train op several times reusing the same uploaded data
    for i in range(3):
        sess.run(train_op, feed_dict={data_placeholder: tensor_handle.handle})
assert sess.run(params) == 382
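As a sanity check on that final assertion, the toy example's arithmetic can be replayed in plain Python: params starts at 1, and each batch contributes the product of its elements three times, once per train op run.

```python
from functools import reduce
import operator

# Mirror the toy example: params starts at 1; each train op run
# adds the product of the current batch's elements.
params = 1
data_batches = [[1], [2, 3], [4, 5, 6]]
for batch in data_batches:
    batch_prod = reduce(operator.mul, batch, 1)  # 1, 6, 120
    for _ in range(3):  # three train op runs per batch
        params += batch_prod
print(params)  # 1 + 3*1 + 3*6 + 3*120 = 382
```

Note that the data is only fed (and uploaded) once per batch in the TensorFlow version; the three inner runs reuse the handle already living in the runtime.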