How to deal with batches with variable-length sequences in TensorFlow?


Problem description

I am trying to use an RNN (specifically, an LSTM) for sequence prediction, but I ran into an issue with variable sequence lengths. For example:

sent_1 = "I am flying to Dubai"
sent_2 = "I was traveling from US to Dubai"

I am trying to predict the next word after the current one with a simple RNN, based on this benchmark for building a PTB LSTM model.

However, the num_steps parameter (used for unrolling to the previous hidden states) has to stay the same throughout each TensorFlow epoch. Basically, batching sentences is not possible, because the sentences vary in length.

 # inputs = [tf.squeeze(input_, [1])
 #           for input_ in tf.split(1, num_steps, inputs)]
 # outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)

Here, num_steps needs to change for every sentence in my case. I have tried several hacks, but nothing seems to work.

Recommended answer

You can use the ideas of bucketing and padding, which are described in:

    Sequence-to-Sequence Models
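To make the bucketing idea concrete, here is a minimal pure-Python sketch. It is not taken from the linked tutorial: the bucket boundaries in BUCKET_SIZES and the helper names assign_bucket and bucket_sentences are hypothetical, and whitespace splitting stands in for a real tokenizer.

```python
# Hypothetical bucket boundaries; tune them to the length distribution of the data.
BUCKET_SIZES = [5, 10, 20]


def assign_bucket(length, bucket_sizes=BUCKET_SIZES):
    """Return the smallest bucket size that fits `length`, or None if none fits."""
    for size in bucket_sizes:
        if length <= size:
            return size
    return None


def bucket_sentences(sentences, bucket_sizes=BUCKET_SIZES):
    """Group whitespace-tokenized sentences by the bucket they fall into."""
    buckets = {size: [] for size in bucket_sizes}
    for sentence in sentences:
        tokens = sentence.split()
        size = assign_bucket(len(tokens), bucket_sizes)
        if size is not None:  # sentences longer than the largest bucket are skipped
            buckets[size].append(tokens)
    return buckets


sentences = ["I am flying to Dubai", "I was traveling from US to Dubai"]
buckets = bucket_sentences(sentences)
```

Each bucket can then be batched on its own with num_steps set to the bucket size, so only a small amount of padding is needed per sentence.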

Also, the rnn function, which creates the RNN network, accepts a sequence_length parameter.
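To illustrate what such a length argument does, here is a toy pure-Python sketch, not TensorFlow code. It imitates the behaviour documented for TensorFlow's RNN functions, where steps past a sequence's length produce zero outputs and the state is copied through unchanged. The function run_toy_rnn, its running-sum "cell" and the padding value -1 are all made up for illustration.

```python
def run_toy_rnn(inputs, sequence_length, initial_state=0):
    """Run a toy 'RNN' whose state is the running sum of its inputs.

    Steps at or beyond `sequence_length` emit a zero output and leave the
    state unchanged, imitating how padded steps are skipped.
    """
    state = initial_state
    outputs = []
    for t, x in enumerate(inputs):
        if t < sequence_length:
            state = state + x      # the "cell": a running sum
            outputs.append(state)
        else:
            outputs.append(0)      # padded step: zero output, state carried through
    return outputs, state


# Real sequence [3, 1, 4] padded to num_steps = 5 with a hypothetical pad value -1.
padded = [3, 1, 4, -1, -1]
outputs, final_state = run_toy_rnn(padded, sequence_length=3)
```

Because the padded steps never reach the cell, the final state depends only on the real part of the sequence, whatever value is used for padding.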

As an example, you can create buckets of sentences of the same size, pad them with the necessary number of zeros (or with a placeholder token that stands for an empty "zero" word), and afterwards feed them along with seq_length = len(zero_words). Note that sequence_length is meant to carry the actual, unpadded length of each sequence in the batch, so that the RNN can skip the padded steps.

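# Legacy (pre-1.0) TensorFlow API, as used in the question.
# `cell`, `inputs` and `initial_state` are assumed to be defined elsewhere.
seq_length = tf.placeholder(tf.int32)
outputs, states = rnn.rnn(cell, inputs, initial_state=initial_state, sequence_length=seq_length)

sess = tf.Session()
feed = {
    seq_length: 20,
    # other feeds
}
sess.run(outputs, feed_dict=feed)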
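The padding step and the lengths that go with it can be sketched in plain Python as follows. This is only an illustration under stated assumptions: pad_batch is a hypothetical helper, id 0 is assumed to be reserved for the padding word, and the word ids for the two example sentences are invented.

```python
PAD_ID = 0  # assumption: id 0 is reserved for the padding ("zero") word


def pad_batch(batch, num_steps, pad_id=PAD_ID):
    """Pad each id sequence to `num_steps` and return (padded, lengths)."""
    padded, lengths = [], []
    for seq in batch:
        if len(seq) > num_steps:
            raise ValueError("sequence longer than num_steps")
        lengths.append(len(seq))  # true length, before padding
        padded.append(list(seq) + [pad_id] * (num_steps - len(seq)))
    return padded, lengths


# Hypothetical word ids for the two example sentences.
batch = [
    [1, 2, 3, 4, 5],          # "I am flying to Dubai"
    [1, 6, 7, 8, 9, 4, 5],    # "I was traveling from US to Dubai"
]
padded, lengths = pad_batch(batch, num_steps=7)
```

The lengths list, with one entry per sentence in the batch, is the kind of value that would be fed for the sequence-length placeholder alongside the padded inputs.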

Take a look at this reddit thread as well:

   Tensorflow basic RNN example with 'variable length' sequences
