使用语言模型张量流示例预测下一个单词 [英] Predicting next word using the language model tensorflow example

查看:92
本文介绍了使用语言模型张量流示例预测下一个单词的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

关于语言模型的 tensorflow教程允许计算句子的概率:

The tensorflow tutorial on language model allows to compute the probability of sentences :

probabilities = tf.nn.softmax(logits)

在下面的注释中,它还指定了预测下一个单词的方式,而不是概率,但未指定如何做到.那么如何使用此示例输出单词而不是概率呢?

in the comments below it also specifies a way of predicting the next word instead of probabilities but does not specify how this can be done. So how to output a word instead of probability using this example?

lstm = rnn_cell.BasicLSTMCell(lstm_size)
# Initial state of the LSTM memory.
state = tf.zeros([batch_size, lstm.state_size])

loss = 0.0
for current_batch_of_words in words_in_dataset:
    # The value of state is updated after processing each batch of words.
    output, state = lstm(current_batch_of_words, state)

    # The LSTM output can be used to make next word predictions
    logits = tf.matmul(output, softmax_w) + softmax_b
    probabilities = tf.nn.softmax(logits)
    loss += loss_function(probabilities, target_words)

推荐答案

您需要找到概率的argmax,然后通过反转word_to_id映射将索引转换回单词.为了使它起作用,您必须将概率保存在模型中,然后从run_epoch函数获取它们(您也可以只保存argmax本身).这是一个代码段:

You need to find the argmax of the probabilities, and translate the index back to a word by reversing the word_to_id map. To get this to work, you must save the probabilities in the model and then fetch them from the run_epoch function (you could also save just the argmax itself). Here's a snippet:

inverseDictionary = dict(zip(word_to_id.values(), word_to_id.keys()))

def run_epoch(...):
  decodedWordId = int(np.argmax(logits))
  print (" ".join([inverseDictionary[int(x1)] for x1 in np.nditer(x)])  
    + " got" + inverseDictionary[decodedWordId] + 
    + " expected:" + inverseDictionary[int(y)])

在此处查看完整的实现: https://github.com/nelken/tf

See full implementation here: https://github.com/nelken/tf

这篇关于使用语言模型张量流示例预测下一个单词的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆