Backpropagating through multiple forward passes


Problem description

In usual backprop, we forward-prop once, compute gradients, then apply them to update weights. But suppose we wish to forward-prop twice, and backprop through both, and apply gradients only then (skip on first).

Suppose the following:

import tensorflow as tf

x = tf.Variable([2.])
w = tf.Variable([4.])

with tf.GradientTape(persistent=True) as tape:
    w.assign(w * x)
    y = w * w  # (w * x)^2
print(tape.gradient(y, x))  # >>None

From docs, a tf.Variable is a stateful object, which blocks gradients, and weights are tf.Variables.
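
As an aside, a minimal sketch (not from the question, assuming import tensorflow as tf as above) of the same computation with the product kept as a plain tensor instead of assigned back into the variable - the gradient then flows:

x = tf.Variable([2.])
w = tf.Variable([4.])

with tf.GradientTape() as tape:
    w_new = w * x        # plain tensor; no Variable.assign in the taped path
    y = w_new * w_new    # (w * x)^2
print(tape.gradient(y, x))  # tf.Tensor([64.], shape=(1,), dtype=float32)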

Examples are differentiable hard attention (as opposed to RL), or simply passing a hidden state between layers in subsequent forward passes, as in the diagram below. Neither TF nor Keras has API-level support for stateful gradients, including RNNs, which only keep a stateful state tensor; the gradient does not flow beyond one batch.

How can this be done?

Solution

We'll need to apply tf.while_loop carefully; from help(TensorArray):

This class is meant to be used with dynamic iteration primitives such as while_loop and map_fn. It supports gradient back-propagation via special "flow" control flow dependencies.
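
As a quick self-contained illustration of that mechanism (a sketch, not the answer's code): gradients flow back through TensorArray writes made inside tf.while_loop.

import tensorflow as tf

x = tf.Variable([1., 2., 3.])

with tf.GradientTape() as tape:
    ta = tf.TensorArray(tf.float32, size=3)

    def body(i, ta):
        return i + 1, ta.write(i, x[i] * x[i])  # y_i = x_i^2

    _, ta = tf.while_loop(lambda i, _: i < 3, body, [tf.constant(0), ta])
    y = tf.reduce_sum(ta.stack())

print(tape.gradient(y, x))  # [2. 4. 6.] == 2*x; gradient flows through the loop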

We thus seek to write a loop such that all outputs we are to backpropagate through are written to a TensorArray. Code accomplishing this, and its high-level description, below. At bottom is a validating example.

Explanation:

  • Code borrows from K.rnn, rewritten for simplicity and relevance
  • For better understanding, I suggest inspecting K.rnn, SimpleRNNCell.call, and RNN.call.
  • model_rnn has a few needless checks for the sake of case 3; will link a cleaner version
  • The idea's as follows: we traverse the network first bottom-to-top, then left-to-right, and write the entire forward pass to a single TensorArray under a single tf.while_loop; this ensures TF caches tensor ops throughout for backpropagation.

import tensorflow as tf
from tensorflow.python.util import nest
from tensorflow.python.ops import array_ops, tensor_array_ops
from tensorflow.python.framework import ops


def model_rnn(model, inputs, states=None, swap_batch_timestep=True):
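    # step_function: run one forward pass of `model` on a single timestep slice,
    # returning (output, new_states); models without state outputs reuse `states`.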
    def step_function(inputs, states):
        out = model([inputs, *states], training=True)
        output, new_states = (out if isinstance(out, (tuple, list)) else
                              (out, states))
        return output, new_states

    def _swap_batch_timestep(input_t):
        # (samples, timesteps, channels) -> (timesteps, samples, channels)
        # iterating dim0 to feed (samples, channels) slices expected by RNN
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return array_ops.transpose(input_t, axes)

    if swap_batch_timestep:
        inputs = nest.map_structure(_swap_batch_timestep, inputs)

    if states is None:
        states = (tf.zeros(model.inputs[0].shape, dtype='float32'),)
    initial_states = states
    input_ta, output_ta, time, time_steps_t = _process_args(model, inputs)

    def _step(time, output_ta_t, *states):
        current_input = input_ta.read(time)
        output, new_states = step_function(current_input, tuple(states))

        flat_state = nest.flatten(states)
        flat_new_state = nest.flatten(new_states)
        for state, new_state in zip(flat_state, flat_new_state):
            if isinstance(new_state, ops.Tensor):
                new_state.set_shape(state.shape)

        output_ta_t = output_ta_t.write(time, output)
        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
        return (time + 1, output_ta_t) + tuple(new_states)

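    # Single tf.while_loop over all timesteps; every step's output is written to
    # the TensorArray, whose 'flow' dependency keeps it connected for backprop.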
    final_outputs = tf.while_loop(
        body=_step,
        loop_vars=(time, output_ta) + tuple(initial_states),
        cond=lambda time, *_: tf.math.less(time, time_steps_t))

    new_states = final_outputs[2:]
    output_ta = final_outputs[1]
    outputs = output_ta.stack()
    return outputs, new_states


def _process_args(model, inputs):
    time_steps_t = tf.constant(inputs.shape[0], dtype='int32')

    # assume single-input network (excluding states)
    input_ta = tensor_array_ops.TensorArray(
        dtype=inputs.dtype,
        size=time_steps_t,
        tensor_array_name='input_ta_0').unstack(inputs)

    # assume single-input network (excluding states)
    # if having states, infer info from non-state nodes
    output_ta = tensor_array_ops.TensorArray(
        dtype=model.outputs[0].dtype,
        size=time_steps_t,
        element_shape=model.outputs[0].shape,
        tensor_array_name='output_ta_0')

    time = tf.constant(0, dtype='int32', name='time')
    return input_ta, output_ta, time, time_steps_t


Examples & validation:

Case design: we feed the same input twice, which enables certain stateful vs stateless comparisons; results also hold for differing inputs.

  • Case 0: control; other cases must match this.
  • Case 1: fail; gradients don't match, even though outputs and loss do. Backprop fails when feeding the halved sequence.
  • Case 2: gradients match case 1. It may seem we've used only one tf.while_loop, but SimpleRNN uses one of its own for the 3 timesteps, and writes to a TensorArray that's discarded; this won't do. A workaround is to implement the SimpleRNN logic ourselves.
  • Case 3: perfect match.

Note that there's no such thing as a stateful RNN cell; statefulness is implemented in the RNN base class, and we've recreated it in model_rnn. This is likewise how any other layer is to be handled - feeding one step slice at a time for every forward pass.
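
For intuition, a minimal hypothetical sketch (not the validation code below) of what that statefulness amounts to at the cell level - carry the returned state into the next call yourself:

import tensorflow as tf

cell = tf.keras.layers.SimpleRNNCell(4)
x = tf.random.normal((2, 6, 4))           # (samples, timesteps, channels)
states = [tf.zeros((2, 4))]
for t in range(6):                        # feed one step slice at a time
    out, states = cell(x[:, t], states)   # reuse `states` to seed the next call/batch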

import random
import numpy as np
import tensorflow as tf

from tensorflow.keras.layers import Input, SimpleRNN, SimpleRNNCell
from tensorflow.keras.models import Model

def reset_seeds():
    random.seed(0)
    np.random.seed(1)
    tf.compat.v1.set_random_seed(2)  # graph-level seed
    tf.random.set_seed(3)  # global seed

def print_report(case, model, outs, loss, tape, idx=1):
    print("\nCASE #%s" % case)
    print("LOSS", loss)
    print("GRADS:\n", tape.gradient(loss, model.layers[idx].weights[0]))
    print("OUTS:\n", outs)


#%%# Make data ###############################################################
reset_seeds()
x0 = y0 = tf.constant(np.random.randn(2, 3, 4))
x0_2 = y0_2 = tf.concat([x0, x0], axis=1)
x00  = y00  = tf.stack([x0, x0], axis=0)

#%%# Case 0: Complete forward pass; control case #############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
out = SimpleRNN(4, return_sequences=True)(ipt)
model0 = Model(ipt, out)
model0.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs = model0(x0_2, training=True)
    loss = model0.compiled_loss(y0_2, outs)
print_report(0, model0, outs, loss, tape)

#%%# Case 1: Two passes, stateful RNN, direct feeding ########################
reset_seeds()
ipt = Input(batch_shape=(2, 3, 4))
out = SimpleRNN(4, return_sequences=True, stateful=True)(ipt)
model1 = Model(ipt, out)
model1.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs0 = model1(x0, training=True)
    tape.watch(outs0)  # cannot even diff otherwise
    outs1 = model1(x0, training=True)
    tape.watch(outs1)
    outs = tf.concat([outs0, outs1], axis=1)
    tape.watch(outs)
    loss = model1.compiled_loss(y0_2, outs)
print_report(1, model1, outs, loss, tape)

#%%# Case 2: Two passes, stateful RNN, model_rnn #############################
reset_seeds()
ipt = Input(batch_shape=(2, 3, 4))
out = SimpleRNN(4, return_sequences=True, stateful=True)(ipt)
model2 = Model(ipt, out)
model2.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs, _ = model_rnn(model2, x00, swap_batch_timestep=False)
    outs = tf.concat(list(outs), axis=1)
    loss = model2.compiled_loss(y0_2, outs)
print_report(2, model2, outs, loss, tape)

#%%# Case 3: Single pass, stateless RNN, model_rnn ###########################
reset_seeds()
ipt  = Input(batch_shape=(2, 4))
sipt = Input(batch_shape=(2, 4))
out, state = SimpleRNNCell(4)(ipt, sipt)
model3 = Model([ipt, sipt], [out, state])
model3.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs, _ = model_rnn(model3, x0_2)
    outs = tf.transpose(outs, (1, 0, 2))
    loss = model3.compiled_loss(y0_2, outs)
print_report(3, model3, outs, loss, tape, idx=2)


Vertical flow: we've validated horizontal (timewise) backpropagation; what about vertical?

To this end, we implement a stacked stateful RNN; results below. All outputs on my machine, here.

We've hereby validated both vertical and horizontal stateful backpropagation. This can be used to implement arbitrarily complex forward-prop logic with correct backprop. Applied example here.

#%%# Case 4: Complete forward pass; control case ############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
x   = SimpleRNN(4, return_sequences=True)(ipt)
out = SimpleRNN(4, return_sequences=True)(x)
model4 = Model(ipt, out)
model4.compile('sgd', 'mse')
#%%
with tf.GradientTape(persistent=True) as tape:
    outs = model4(x0_2, training=True)
    loss = model4.compiled_loss(y0_2, outs)
print("=" * 80)
print_report(4, model4, outs, loss, tape, idx=1)
print_report(4, model4, outs, loss, tape, idx=2)

#%%# Case 5: Two passes, stateless RNN; model_rnn ############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
out = SimpleRNN(4, return_sequences=True)(ipt)
model5a = Model(ipt, out)
model5a.compile('sgd', 'mse')

ipt  = Input(batch_shape=(2, 4))
sipt = Input(batch_shape=(2, 4))
out, state = SimpleRNNCell(4)(ipt, sipt)
model5b = Model([ipt, sipt], [out, state])
model5b.compile('sgd', 'mse')
#%%
with tf.GradientTape(persistent=True) as tape:
    outs = model5a(x0_2, training=True)
    outs, _ = model_rnn(model5b, outs)
    outs = tf.transpose(outs, (1, 0, 2))
    loss = model5a.compiled_loss(y0_2, outs)
print_report(5, model5a, outs, loss, tape)
print_report(5, model5b, outs, loss, tape, idx=2)
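
Finally, tying back to the original question: once the loss has been computed through both passes, the weights are updated once. A minimal sketch, reusing model2, x00, and y0_2 from Case 2 with a plain SGD optimizer:

opt = tf.keras.optimizers.SGD(learning_rate=0.01)

with tf.GradientTape() as tape:
    outs, _ = model_rnn(model2, x00, swap_batch_timestep=False)
    outs = tf.concat(list(outs), axis=1)
    loss = model2.compiled_loss(y0_2, outs)

grads = tape.gradient(loss, model2.trainable_weights)
opt.apply_gradients(zip(grads, model2.trainable_weights))  # single update after both forward passes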
