Compute gradients for each time step of tf.while_loop


Problem Description

Given a TensorFlow tf.while_loop, how can I calculate the gradient of x_out with respect to all weights of the network for each time step?

import tensorflow as tf

network_input = tf.placeholder(tf.float32, [None])
steps = tf.constant(0.0)

weight_0 = tf.Variable(1.0)
layer_1 = network_input * weight_0

def condition(steps, x):
    return steps <= 5

def loop(steps, x_in):
    weight_1 = tf.Variable(1.0)
    x_out = x_in * weight_1
    steps += 1
    return [steps, x_out]

_, x_final = tf.while_loop(
    condition,
    loop,
    [steps, layer_1]
)

Some notes

  1. In my network the condition is dynamic. Different runs will execute the while loop a different number of times.
  2. Calling tf.gradients(x, tf.trainable_variables()) crashes with AttributeError: 'WhileContext' object has no attribute 'pred'. It seems the only way to use tf.gradients within the loop is to compute the gradient with respect to weight_1 and the current value of x_in for the current time step only, without backpropagating through time.
  3. In each time step, the network outputs a probability distribution over actions. The gradients are then needed for a policy gradient implementation.
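To make concrete what "gradients for each time step" means for the toy loop above, here is a minimal plain-Python sketch (illustrative names, not TensorFlow API) that unrolls the loop by hand and computes each step's contribution to d x_out / d weight_1 via manual backpropagation through time:

```python
def forward(x_in, weight, num_steps):
    """Unrolled forward pass of the toy loop: x_{t+1} = x_t * weight."""
    xs = [x_in]
    for _ in range(num_steps):
        xs.append(xs[-1] * weight)
    return xs  # xs[t] is the loop state after t steps

def per_step_weight_grads(x_in, weight, num_steps):
    """Backprop through time by hand.

    The multiplication at step t contributes
        x_t * weight**(T - 1 - t)
    to d x_T / d weight: the local gradient x_t, scaled by the
    Jacobian of the remaining T - 1 - t multiplications.
    """
    xs = forward(x_in, weight, num_steps)
    T = num_steps
    return [xs[t] * weight ** (T - 1 - t) for t in range(T)]

# condition `steps <= 5` with steps starting at 0 means 6 iterations
grads = per_step_weight_grads(x_in=2.0, weight=3.0, num_steps=6)
total = sum(grads)  # equals d/dw (x_in * w**6) = 6 * x_in * w**5
```

Summing the per-step contributions recovers the full gradient of x_final with respect to the shared weight, which is the quantity tf.gradients cannot deliver per step from inside the loop.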

Recommended Answer

Based on this and this, you can never call tf.gradients inside tf.while_loop in TensorFlow. I found this out the hard way when I was trying to build conjugate gradient descent entirely inside the TensorFlow graph.

But if I understand your model correctly, you could write your own version of an RNNCell and wrap it in tf.dynamic_rnn, though the actual cell implementation will be a little complex since you need to evaluate a condition dynamically at runtime.

For starters, you can take a look at TensorFlow's dynamic_rnn code here.

Alternatively, dynamic graphs have never been TensorFlow's strong suit, so consider using another framework such as PyTorch, or try out eager_execution and see if that helps.
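If PyTorch is an option, the dynamic condition becomes trivial, because the loop is ordinary Python executed eagerly. A minimal sketch of the question's toy loop (assuming PyTorch is installed; the input and weight values are illustrative):

```python
import torch

x_in = torch.tensor(2.0)                # stand-in for network_input
weight = torch.tensor(3.0, requires_grad=True)

x = x_in
steps = 0
per_step_grads = []
while steps <= 5:                       # dynamic condition, plain Python
    x = x * weight
    # Gradient of the current state w.r.t. the weight at this time step.
    # retain_graph/create_graph keep the graph alive across iterations.
    g, = torch.autograd.grad(x, weight, retain_graph=True, create_graph=True)
    per_step_grads.append(float(g))
    steps += 1
```

At the final step the accumulated gradient equals d/dw (x_in * w**6) = 6 * x_in * w**5, and intermediate steps give the per-time-step gradients needed for a policy gradient update.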
