Compute gradients for each time step of tf.while_loop
Question
给定一个 TensorFlow tf.while_loop
,我如何计算 x_out
相对于网络的每个时间步的所有权重的梯度?
Given a TensorFlow tf.while_loop, how can I calculate the gradient of x_out with respect to all the weights of the network for each time step?
network_input = tf.placeholder(tf.float32, [None])
steps = tf.constant(0.0)
weight_0 = tf.Variable(1.0)
layer_1 = network_input * weight_0

def condition(steps, x):
    return steps <= 5

def loop(steps, x_in):
    weight_1 = tf.Variable(1.0)
    x_out = x_in * weight_1
    steps += 1
    return [steps, x_out]

_, x_final = tf.while_loop(
    condition,
    loop,
    [steps, layer_1]
)
A few notes
- In my network the condition is dynamic. Different runs will execute the while loop a different number of times.
- Calling tf.gradients(x, tf.trainable_variables()) crashes with AttributeError: 'WhileContext' object has no attribute 'pred'. It seems the only way to use tf.gradients within the loop is to calculate the gradient with respect to weight_1 and the current value of x_in for the current time step only, without backpropagating through time.
- In each time step, the network outputs a probability distribution over actions. The gradients are then needed for a policy gradient implementation.
Answer
You can't call tf.gradients inside tf.while_loop in TensorFlow, based on this and this. I found this out the hard way when I was trying to build conjugate gradient descent entirely inside the TensorFlow graph.
But if I understand your model correctly, you could make your own version of an RNNCell and wrap it in tf.dynamic_rnn; the actual cell implementation will be a little complex, though, since you need to evaluate the condition dynamically at runtime.
For starters, you can take a look at TensorFlow's dynamic_rnn code here.
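Very roughly, the idea could look like the sketch below. This is not the asker's model: the cell name StepCell, the use of tf.nn.dynamic_rnn, and the fixed number of steps are my own assumptions for illustration. A truly dynamic stop condition would still need sequence_length handling or masking on top of this.

import tensorflow as tf

class StepCell(tf.nn.rnn_cell.RNNCell):
    """Hypothetical cell that mirrors x_out = x_in * weight_1 per step."""

    def __init__(self):
        super(StepCell, self).__init__()
        self._weight = tf.Variable(1.0, name="weight_1")

    @property
    def state_size(self):
        return 1

    @property
    def output_size(self):
        return 1

    def call(self, inputs, state):
        # Ignore the per-step input; transform the carried state instead.
        new_state = state * self._weight
        return new_state, new_state

network_input = tf.placeholder(tf.float32, [None, 1])
weight_0 = tf.Variable(1.0, name="weight_0")
layer_1 = network_input * weight_0

num_steps = 6  # fixed unroll length; the dynamic condition is the hard part
batch_size = tf.shape(layer_1)[0]
dummy_inputs = tf.zeros(tf.stack([batch_size, num_steps, 1]))

cell = StepCell()
# dynamic_rnn unrolls the cell over the time dimension, so gradients can
# flow back through every step.
outputs, _ = tf.nn.dynamic_rnn(cell, dummy_inputs, initial_state=layer_1)

# Gradient of the output at each time step w.r.t. all trainable weights.
per_step_grads = [
    tf.gradients(outputs[:, t, :], tf.trainable_variables())
    for t in range(num_steps)
]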
Alternatively, dynamic graphs have never been TensorFlow's strong suit, so consider using another framework like PyTorch, or you can try out eager_execution and see if that helps.
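If eager execution is an option, a minimal sketch of per-step gradients might look like this (assuming a TF 1.x release that ships tf.enable_eager_execution and tf.GradientTape; the input values and variable names are placeholders, not the asker's network):

import tensorflow as tf

tf.enable_eager_execution()  # must be called before any other TF ops

weight_0 = tf.Variable(1.0)
weight_1 = tf.Variable(1.0)
network_input = tf.constant([2.0, 3.0])  # dummy data for illustration

with tf.GradientTape(persistent=True) as tape:
    x = network_input * weight_0
    xs = []
    step = 0
    while step <= 5:  # the condition can depend on runtime values here
        x = x * weight_1
        xs.append(x)
        step += 1

# One gradient per recorded time step, w.r.t. both weights.
per_step_grads = [tape.gradient(x_t, [weight_0, weight_1]) for x_t in xs]
del tape  # release the resources held by the persistent tape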