神经机器翻译模型的预测是一对一的 [英] Neural Machine Translation model predictions are off-by-one

查看:56
本文介绍了神经机器翻译模型的预测是一对一的的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

问题摘要

在下面的示例中,我的NMT模型损失很大,因为它可以正确预测target_input而不是target_output.

In the following example, my NMT model has high loss because it correctly predicts target_input instead of target_output.

Targetin   :  1  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  9 10 10 10  3  3 10 10  3 10  3  3 10 10  3  9  9  4  4  4  4  4  3 10  3  3  9  9  3  6  6  6  6  6  6 10  9  9 10 10  4  4  4  4  4  4  4  4  4  4  4  4  9  9  9  9  3  3  3  6  6  6  6  6  9  9 10  3  4  4  4  4  4  4  4  4  4  4  4  4  9  9 10  3 10  9  9  3  4  4  4  4  4  4  4  4  4 10 10  4  4  4  4  4  4  4  4  4  4  9  9 10  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  9  3  3 10  6  6  6  6  6  3  9  9  3  3  3  3  3  3  3 10 10  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  5  3  3  3  3 10 10 10  3  9  9  5 10  3  3  3  3  9  9  9  5 10 10 10 10 10  4  4  4  4  3 10  6  6  6  6  6  6  3  5 10 10 10 10  3  9  9  6  6  6  6  6  6  6  6  6  9  9  9  3  3  3  6  6  6  6  6  6  6  6  3  9  9  9  3  3  6  6  6  3  3  3  3  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
Targetout  :  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  9 10 10 10  3  3 10 10  3 10  3  3 10 10  3  9  9  4  4  4  4  4  3 10  3  3  9  9  3  6  6  6  6  6  6 10  9  9 10 10  4  4  4  4  4  4  4  4  4  4  4  4  9  9  9  9  3  3  3  6  6  6  6  6  9  9 10  3  4  4  4  4  4  4  4  4  4  4  4  4  9  9 10  3 10  9  9  3  4  4  4  4  4  4  4  4  4 10 10  4  4  4  4  4  4  4  4  4  4  9  9 10  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  9  3  3 10  6  6  6  6  6  3  9  9  3  3  3  3  3  3  3 10 10  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  5  3  3  3  3 10 10 10  3  9  9  5 10  3  3  3  3  9  9  9  5 10 10 10 10 10  4  4  4  4  3 10  6  6  6  6  6  6  3  5 10 10 10 10  3  9  9  6  6  6  6  6  6  6  6  6  9  9  9  3  3  3  6  6  6  6  6  6  6  6  3  9  9  9  3  3  6  6  6  3  3  3  3  3  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
Prediction :  3  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  3  3  3  3  3  3 10  3  3 10  3  3 10  3  3  9  3  4  4  4  4  4  3 10  3  3  9  3  3  6  6  6  6  6  6 10  9  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3  3  3  3  6  6  6  6  6  9  6  3  3  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3 10  9  3  3  4  4  4  4  4  4  4  4  4  3 10  4  4  4  4  4  4  4  4  4  4  9  3  3  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3 10  6  6  6  6  6  3  9  3  3  3  3  3  3  3  3  3  3  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  3  3  3  3  3 10  3  3  3  9  3  3 10  3  3  3  3  9  3  9  3 10  3  3  3  3  4  4  4  4  3 10  6  6  6  6  6  6  3  3 10  3  3  3  3  9  3  6  6  6  6  6  6  6  6  6  9  6  9  3  3  3  6  6  6  6  6  6  6  6  3  9  3  9  3  3  6  6  6  3  3  3  3  3  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6
Source     :  9 16  4  7 22 22 19  1 12 19 12 18  5 18  9 18  5  8 12 19 19  5  5 19 22  7 12 12  6 19  7  3 20  7  9 14  4 11 20 12  7  1 18  7  7  5 22  9 13 22 20 19  7 19  7 13  7 11 19 20  6 22 18 17 17  1 12 17 23  7 20  1 13  7 11 11 22  7 12  1 13 12  5  5 19 22  5  5 20  1  5  4 12  9  7 12  8 14 18 22 18 12 18 17 19  4 19 12 11 18  5  9  9  5 14  7 11  6  4 17 23  6  4  5 12  6  7 14  4 20  6  8 12 25  4 19  6  1  5  1  5 20  4 18 12 12  1 11 12  1 25 13 18 19  7 12  7  3  4 22  9  9 12  4  8  9 19  9 22 22 19  1 19  7  5 19  4  5 18 11 13  9  4 14 12 13 20 11 12 11  7  6  1 11 19 20  7 22 22 12 22 22  9  3  8 12 11 14 16  4 11  7 11  1  8  5  5  7 18 16 22 19  9 20  4 12 18  7 19  7  1 12 18 17 12 19  4 20  9  9  1 12  5 18 14 17 17  7  4 13 16 14 12 22 12 22 18  9 12 11  3 18  6 20  7  4 20  7  9  1  7 25 13  5 25 14 11  5 20  7 23 12  5 16 19 19 25 19  7 -1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0

显而易见,该预测与target_input而不是target_output几乎匹配了100%(应为1).损耗和梯度是使用target_output计算的,因此预测与target_input匹配的情况很奇怪.

As is evident, the prediction matches up almost 100% with target_input instead of target_output, as it should (off-by-one). Loss and gradients are being calculated using target_output, so it is strange that predictions are matching up to target_input.

模型概述

NMT模型使用源语言中的主要单词序列来预测目标语言中的单词序列.这是Google Translate的框架.由于NMT使用耦合RNN,因此需要对其进行监督并标记目标输入和输出.

An NMT model predicts a sequence of words in a target language using a primary sequence of words in a source language. This is the framework behind Google Translate. Since NMT uses coupled-RNNs, it is supervised and required labelled target input and output.

NMT使用source序列,target_input序列和target_output序列.在下面的示例中,编码器RNN(蓝色)使用源输入字生成含义向量,然后将其传递给解码器RNN(红色),解码器RNN(红色)使用含义向量生成输出.

NMT uses a source sequence, a target_input sequence, and a target_output sequence. In the example below, the encoder RNN (blue) uses the source input words to produce a meaning vector, which it passes to the decoder RNN (red), which uses the meaning vector to produce output.

在进行新的预测(推理)时,解码器RNN使用其自己的先前输出在时间步中播种下一个预测.但是,为了提高训练效果,可以在每个新的时间步长使用正确的先前预测来播种自己.这就是为什么target_input对于训练必不可少的原因.

When doing new predictions (inference), the decoder RNN uses its own previous output to seed the next prediction in the timestep. However, to improve training, it is allowed to seed itself with the correct previous prediction at each new timestep. This is why target_input is necessary for training.

获取带有源,target_in,target_out的迭代器的代码

def get_batched_iterator(hparams, src_loc, tgt_loc):
    if not (os.path.exists('primary.csv') and os.path.exists('secondary.csv')):
        utils.integerize_raw_data()

    source_dataset = tf.data.TextLineDataset(src_loc)
    target_dataset = tf.data.TextLineDataset(tgt_loc)
    dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
    dataset = dataset.shuffle(hparams.shuffle_buffer_size, seed=hparams.shuffle_seed)

    dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                                  tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
    dataset = dataset.map(lambda source, target: (source, tf.concat(([hparams.sos], target), axis=0), tf.concat((target, [hparams.eos]), axis=0)))
    dataset = dataset.map(lambda source, target_in, target_out: (source, target_in, target_out, tf.size(source), tf.size(target_in)))
    # Proceed to batch and return iterator

NMT模型核心代码

def __init__(self, hparams, iterator, mode):
        source, target_in, target_out, source_lengths, target_lengths = iterator.get_next()

        # Lookup embeddings
        embedding_encoder = tf.get_variable("embedding_encoder", [hparams.src_vsize, hparams.src_emsize])
        encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)
        embedding_decoder = tf.get_variable("embedding_decoder", [hparams.tgt_vsize, hparams.tgt_emsize])
        decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, target_in)

        # Build and run Encoder LSTM
        encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, sequence_length=source_lengths, dtype=tf.float32)

        # Build and run Decoder LSTM with TrainingHelper and output projection layer
        decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
        projection_layer = layers_core.Dense(hparams.tgt_vsize, use_bias=False)
        helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, sequence_length=target_lengths)
        decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=projection_layer)
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
        logits = outputs.rnn_output

        if mode is 'TRAIN' or mode is 'EVAL':  # then calculate loss
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_out, logits=logits)
            target_weights = tf.sequence_mask(target_lengths, maxlen=tf.shape(target_out)[1], dtype=logits.dtype)
            self.loss = tf.reduce_sum((crossent * target_weights) / hparams.batch_size)

        if mode is 'TRAIN':  # then calculate/clip gradients, then optimize model
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, hparams.max_gradient_norm)

            optimizer = tf.train.AdamOptimizer(hparams.l_rate)
            self.update_step = optimizer.apply_gradients(zip(clipped_gradients, params))

        if mode is 'EVAL':  # then allow access to input/output tensors to printout
            self.src = source
            self.tgt_in = target_in
            self.tgt_out = target_out
            self.logits = logits

完整代码(不必解决此问题) https://github.com/nave01314/tf-nmt

推荐答案

用于预测具有重复结构的类语言语法的NMT模型的核心问题是,激励人们去简单地预测过去的预测是什么.由于TrainingHelper在每个步骤中都向其提供了正确的先前预测,以加快训练速度,因此这会人为地产生模型无法摆脱的局部最小值.

The core issue with the NMT model used to predict a language-like syntax with a repetitive structure is that it becomes incentivized to simply predict whatever the past prediction was. Since it is fed the correct previous prediction at each step by TrainingHelper to speed up training, this artificially produces a local minimum that the model is unable to get out of.

我发现的最佳选择是加权损失函数,以使输出序列中输出不重复的关键点的加权更大.这将激励模型获得正确的信息,而不仅仅是重复 过去的预测.

The best option I have found is to weight the loss functions such the key points in the output sequence where the output is not repetitive are weighted more heavily. This will incentivize the model to get those correct, and not just repeat the past prediction.

这篇关于神经机器翻译模型的预测是一对一的的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆