Run train op multiple times in tensorflow


Problem description

I have some fairly large batches on which I'd like to take multiple gradient steps. While I could easily do this with a Python for loop, I imagine there might be a more efficient method that doesn't involve transferring the data to the GPU on each iteration. I've tried putting the train op in the fetch list multiple times, but I'm not sure it's actually being run more than once (the runtime is exactly the same).

Recommended answer

If you have variable-sized batches, then a variable is a bad fit for storing them; instead, you can persist the data between `run` calls using persistent tensors. Here's a toy example:

import tensorflow as tf  # written against the legacy TF 1.x / 0.x API

dt = tf.int32
params = tf.Variable(tf.ones((), dtype=dt))
data_batches = [[1], [2, 3], [4, 5, 6]]

# op that uploads data to the TF runtime and saves it as a persistent tensor
data_saver_placeholder = tf.placeholder(dt)
tensor_handle_op = tf.get_session_handle(data_saver_placeholder)

# op that reads the persistent tensor back in, given its handle
data_placeholder, data = tf.get_session_tensor(dt)
train_op = tf.assign_add(params, tf.reduce_prod(data))
init_op = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init_op)

for batch in data_batches:
    # upload tensor to TF runtime and save its handle
    tensor_handle = sess.run(tensor_handle_op, feed_dict={data_saver_placeholder: batch})
    # run train op several times reusing same data
    for i in range(3):
        sess.run(train_op, feed_dict={data_placeholder: tensor_handle.handle})


assert sess.run(params) == 382
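As a sanity check of the final assert, the arithmetic can be reproduced in plain Python without TensorFlow: `params` starts at 1, and each of the three train-op runs per batch adds the product of that batch's elements.

```python
from functools import reduce
import operator

data_batches = [[1], [2, 3], [4, 5, 6]]
params = 1  # tf.ones((), dtype=tf.int32)
for batch in data_batches:
    batch_prod = reduce(operator.mul, batch)  # stands in for tf.reduce_prod(data)
    for _ in range(3):                        # three gradient steps on the same batch
        params += batch_prod                  # stands in for tf.assign_add
print(params)  # → 382  (1 + 3*1 + 3*6 + 3*120)
```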

