Tensorflow while_loop for training


Question

In my problem I need to run GD with 1 example from the data on each training step. It's a known problem that session.run() has overhead, so training the model takes too long. In an attempt to avoid the overhead I tried to use while_loop and train the model on all the data with one run() call. But this approach doesn't work, and train_op doesn't execute even once. Below is a simple example of what I'm doing:

import tensorflow as tf

data = [k*1. for k in range(10)]
tf.reset_default_graph()

i = tf.Variable(0, name='loop_i')
q_x = tf.FIFOQueue(100000, tf.float32)
q_y = tf.FIFOQueue(100000, tf.float32)

x = q_x.dequeue()
y = q_y.dequeue()
w = tf.Variable(0.)
b = tf.Variable(0.)
loss = (tf.add(tf.multiply(x, w), b) - y)**2

gs = tf.Variable(0)

train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss, global_step=gs)

s = tf.Session()
s.run(tf.global_variables_initializer())

def cond(i):
    return i < 10

def body(i):
    return tf.tuple([tf.add(i, 1)], control_inputs=[train_op])


loop = tf.while_loop(cond, body, [i])

for _ in range(1):
    s.run(q_x.enqueue_many((data, )))
    s.run(q_y.enqueue_many((data, )))

s.run(loop)
s.close()
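
For reference, the conventional baseline whose overhead the question describes calls session.run() once per example. A minimal sketch of that pattern (the names here are illustrative, not taken from the code above):

import tensorflow as tf

# Baseline: one session.run() call per training example.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(0.)
b = tf.Variable(0.)
loss = (tf.add(tf.multiply(x, w), b) - y)**2
train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss)

data = [k*1. for k in range(10)]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for v in data:
        # Each call crosses the Python/runtime boundary -- this is the overhead.
        sess.run(train_op, feed_dict={x: v, y: v})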

What am I doing wrong? Or is there another way to solve this problem of too-expensive overhead?

Thanks!

Answer

The reason the model does not appear to train is because the input reading, gradient calculation, and the minimize() call are all defined outside (and hence, in dataflow terms, before) the body of the tf.while_loop(). This means that all of these parts of the model run only once, before the loop executes, and the loop itself has no effect.
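
To make the placement issue concrete, here is a small sketch (mine, not from the original answer) of a stateful op created inside the loop body, which therefore runs on every iteration:

import tensorflow as tf

counter = tf.Variable(0)

def cond(i):
    return i < 10

def body(i):
    # Created inside the body, so it is stamped into the loop frame and
    # re-executed on every iteration.
    increment = tf.assign_add(counter, 1)
    with tf.control_dependencies([increment]):
        return i + 1

loop = tf.while_loop(cond, body, [tf.constant(0)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(loop)
    print(sess.run(counter))  # 10: the in-body op ran once per iteration.

If increment were instead created outside body and merely referenced there, it would execute only once per session.run() call, which is the trap the question's code fell into.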

A slight refactoring—to move the dequeue() operations, gradient calculation, and minimize() call inside the loop—fixes the problem and allows your program to train:

optimizer = tf.train.GradientDescentOptimizer(0.05)

def cond(i):
    return i < 10

def body(i):
    # Dequeue a new example each iteration.
    x = q_x.dequeue()
    y = q_y.dequeue()

    # Compute the loss and gradient update based on the current example.
    loss = (tf.add(tf.multiply(x, w), b) - y)**2
    train_op = optimizer.minimize(loss, global_step=gs)

    # Ensure that the update is applied before continuing.
    return tf.tuple([tf.add(i, 1)], control_inputs=[train_op])

loop = tf.while_loop(cond, body, [i])
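
A design point worth noting: tf.while_loop() calls body only once, while the graph is being constructed, so the dequeue(), loss, and minimize() ops are added to the graph a single time; the runtime then re-executes those ops on every iteration of the loop.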


UPDATE: Here's a complete program that executes the while loop, based on the code in your question:

import tensorflow as tf

# Define a single queue with two components to store the input data.
q_data = tf.FIFOQueue(100000, [tf.float32, tf.float32])

# We will use these placeholders to enqueue input data.
placeholder_x = tf.placeholder(tf.float32, shape=[None])
placeholder_y = tf.placeholder(tf.float32, shape=[None])
enqueue_data_op = q_data.enqueue_many([placeholder_x, placeholder_y])

gs = tf.Variable(0)
w = tf.Variable(0.)
b = tf.Variable(0.)
optimizer = tf.train.GradientDescentOptimizer(0.05)

# Construct the while loop.
def cond(i):
    return i < 10

def body(i):
    # Dequeue a single new example each iteration.
    x, y = q_data.dequeue()
    # Compute the loss and gradient update based on the current example.
    loss = (tf.add(tf.multiply(x, w), b) - y) ** 2
    train_op = optimizer.minimize(loss, global_step=gs)
    # Ensure that the update is applied before continuing.
    with tf.control_dependencies([train_op]):
        return i + 1

loop = tf.while_loop(cond, body, [tf.constant(0)])

data = [k * 1. for k in range(10)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(1):
        # NOTE: Constructing the enqueue op ahead of time avoids adding
        # (potentially many) copies of `data` to the graph.
        sess.run(enqueue_data_op,
                 feed_dict={placeholder_x: data, placeholder_y: data})
    print(sess.run([gs, w, b]))  # Prints before-loop values.
    sess.run(loop)
    print(sess.run([gs, w, b]))  # Prints after-loop values.
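
If the loop runs as intended, the second print should report gs as 10, since minimize() increments the global step once per iteration, and w and b should have moved away from their initial zeros.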
