Tensorflow, best way to save state in RNNs?

Problem description

I currently have the following code for a series of chained-together RNNs in TensorFlow. I am not using MultiRNN, since I want to do something later on with the output of each layer.

for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size]
        time_outputs = [None] * TIME_STEPS

        for t in range(TIME_STEPS):
            rnn_input = getTimeStep(rnn_outputs[r - 1], t)
            time_outputs[t], state = rnn_func(rnn_input, state)
            time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE))
            scope.reuse_variables()
        rnn_outputs[r] = tf.concat(1, time_outputs)

Currently I have a fixed number of time steps. However, I would like to change it to have only one time step but remember the state between batches. I would therefore need to create a state variable for each layer and assign it that layer's final state, something like this:

for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        saved_state = tf.get_variable('saved_state', ...)
        rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state)
        saved_state = tf.assign(saved_state, state)

Then, for each of the layers, I would need to evaluate the saved state in my sess.run call as well as calling my training function, and I would need to do this for every RNN layer. This seems like a hassle: I would have to track every saved state and evaluate it in run. run would then also need to copy the state from my GPU to host memory, which would be inefficient and unnecessary. Is there a better way of doing this?
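
For illustration, here is a minimal sketch of the bookkeeping this would imply (the variable names, shapes, and ops are hypothetical, not from my actual model): one non-trainable variable plus one assign op per layer, all of which would have to be fetched on every training step.

import tensorflow as tf

RNNS, BATCH_SIZE, STATE_SIZE = 3, 32, 128

# One saved-state variable per layer (hypothetical names and shapes).
saved_states = [tf.Variable(tf.zeros((BATCH_SIZE, STATE_SIZE)),
                            trainable=False, name='saved_state_%d' % r)
                for r in range(RNNS)]

# Stand-ins for the final state tensor that each layer produces.
final_states = [tf.placeholder(tf.float32, (BATCH_SIZE, STATE_SIZE))
                for _ in range(RNNS)]

# One assign op per layer; every one of them has to be fetched each step:
#   sess.run([train_op] + state_updates, feed_dict=...)
state_updates = [tf.assign(v, s) for v, s in zip(saved_states, final_states)]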

Recommended answer

Here is the code to update an LSTM's initial state, when state_is_tuple=True, by defining state variables. It also supports multiple layers.

We define two functions: one that gets the state variables with an initial zero state, and one that returns an operation we can pass to session.run in order to update the state variables with the LSTM's last hidden state.

def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    return tf.tuple(update_ops)

We can use that to update the LSTM's state after each batch. Note that I use tf.nn.dynamic_rnn for unrolling:

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
# The helpers above expect LSTMStateTuples, so use an LSTM cell here
# (a GRU cell has a single state tensor per layer and would not unpack).
cells = [tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
         for _ in range(num_layers)]
cell = tf.contrib.rnn.MultiRNNCell(cells)

# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})

The main difference to this answer is that state_is_tuple=True makes the LSTM's state an LSTMStateTuple containing two variables (the cell state and the hidden state) instead of just a single variable. Using multiple layers then makes the LSTM's state a tuple of LSTMStateTuples, one per layer.
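
To make that structure concrete, here is a small sketch (the layer count and sizes are arbitrary) of what get_state_variables returns for a two-layer LSTM:

# Two LSTM layers, 256 units each (arbitrary, for illustration).
cells = [tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
         for _ in range(2)]
cell = tf.contrib.rnn.MultiRNNCell(cells)
states = get_state_variables(batch_size, cell)

print(len(states))   # 2: one LSTMStateTuple per layer
print(states[0].c)   # layer 1 cell state:   Variable of shape (batch_size, 256)
print(states[0].h)   # layer 1 hidden state: Variable of shape (batch_size, 256)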

When using a trained model for prediction/decoding, you might want to reset the state to zero. Then you can make use of this function:

def get_state_reset_op(state_variables, cell, batch_size):
    # Return an operation to set each variable in a list of LSTMStateTuples to zero
    zero_states = cell.zero_state(batch_size, tf.float32)
    return get_state_update_op(state_variables, zero_states)

For example, continuing from the code above:

reset_state_op = get_state_reset_op(states, cell, batch_size)
# Reset the state to zero before feeding input
sess.run([reset_state_op])
sess.run([outputs, update_op], {data: ...})
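
Putting the pieces together, a typical loop might look like the following sketch. Note that train_op, num_epochs, get_batches, and test_batch are placeholders for whatever your model and input pipeline define:

for epoch in range(num_epochs):
    # Start each epoch from a zero state.
    sess.run(reset_state_op)
    for batch in get_batches():
        # Fetching update_op carries the last hidden state into the next batch.
        sess.run([train_op, update_op], {data: batch})

# Reset once more before decoding with the trained model.
sess.run(reset_state_op)
output_values, _ = sess.run([outputs, update_op], {data: test_batch})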
