How to correctly use TensorFlow tensorflow.contrib.seq2seq
Question
I'm misusing TensorFlow's tf.contrib.seq2seq module in some manner, but no errors are produced, so I'm having trouble finding the bug. My problem is that my decoder outputs the same value (in my case, a categorical label between 0 and 3, inclusive) for every position in the output sequence. In the example below, my output sequence has 8 labels.
My code:
from tensorflow.contrib.seq2seq import (BahdanauAttention, AttentionWrapper,
                                        TrainingHelper, BasicDecoder, dynamic_decode)
from tensorflow.contrib.rnn import LSTMStateTuple
import tensorflow as tf

attention_mechanism = BahdanauAttention(num_units=ATTENTION_UNITS,
                                        memory=encoder_outputs,
                                        normalize=True)
attention_wrapper = AttentionWrapper(cell=self._create_lstm_cell(DECODER_SIZE),
                                     attention_mechanism=attention_mechanism,
                                     attention_layer_size=None)
attention_zero = attention_wrapper.zero_state(batch_size=self.x.shape[0], dtype=tf.float32)

# concatenate c1 and c2 from the encoder's final states
new_c = tf.concat([encoder_final_states[0].c, encoder_final_states[1].c], axis=1)
# concatenate h1 and h2 from the encoder's final states
new_h = tf.concat([encoder_final_states[0].h, encoder_final_states[1].h], axis=1)
# define the initial state using the concatenated c and h states
init_state = attention_zero.clone(cell_state=LSTMStateTuple(c=new_c, h=new_h))

training_helper = TrainingHelper(inputs=self.y_actual,           # feed in ground truth
                                 sequence_length=output_length)  # feed in sequence lengths
decoder = BasicDecoder(cell=attention_wrapper,
                       helper=training_helper,
                       initial_state=init_state)
decoder_outputs, decoder_final_state, decoder_final_sequence_lengths = dynamic_decode(
    decoder=decoder, impute_finished=True)
I need to create the LSTMStateTuple because my encoder uses a bidirectional RNN.
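The state concatenation above is purely about shapes: a bidirectional encoder produces one final state per direction, and stacking their c and h tensors along the feature axis yields a single state whose width matches a decoder cell of twice the encoder size. A minimal NumPy sketch of that shape arithmetic (the names and sizes here are illustrative, not from the question):

```python
import numpy as np

batch_size, encoder_units = 4, 8

# stand-ins for the forward/backward final cell states of a bidirectional encoder
c_fw = np.zeros((batch_size, encoder_units), dtype=np.float32)
c_bw = np.zeros((batch_size, encoder_units), dtype=np.float32)

# the same merge the question performs with tf.concat(..., axis=1)
new_c = np.concatenate([c_fw, c_bw], axis=1)

# the merged state is twice as wide, so the decoder LSTM must use 2 * encoder_units
print(new_c.shape)  # (4, 16)
```

This is why the decoder cell's size must equal twice the per-direction encoder size, or the clone of the zero state will fail with a shape mismatch.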
I suspect that the error is in the decoder, because the outputs of my encoder do not show any such uniformity. However, I could be wrong.
Answer
The problem was indeed that I needed to set output_attention=False on the AttentionWrapper, because I am using Bahdanau attention.
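Concretely, the fix is a one-keyword change to the AttentionWrapper construction from the question (tf.contrib.seq2seq, so TensorFlow 1.x only; everything else is unchanged). With output_attention=False the wrapper emits the cell output at each step rather than the attention vector, which is the Bahdanau-style configuration; the default, True, is the Luong style:

```python
from tensorflow.contrib.seq2seq import AttentionWrapper

attention_wrapper = AttentionWrapper(cell=self._create_lstm_cell(DECODER_SIZE),
                                     attention_mechanism=attention_mechanism,
                                     attention_layer_size=None,
                                     output_attention=False)  # Bahdanau style: emit the cell output
```

With the default (Luong-style) setting and attention_layer_size=None, the decoder's per-step output is the attention context itself, which can look nearly identical across time steps, matching the constant-label symptom described above.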