使用 Tensorflow 构建一个适用于可变批量大小的图表 [英] Build a graph that works with variable batch size using Tensorflow

查看:26
本文介绍了使用 Tensorflow 构建一个适用于可变批量大小的图表的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我使用 tf.placeholders() 操作来提供可变批量大小的输入,这些输入是 2D 张量,并在我调用 run() 时使用馈送机制为这些张量提供不同的值.我得到了

I use tf.placeholders() ops for feeding variable batch size of input which are 2D tensors and use the feed mechanism to supply different values for those tensors when I call run(). I got

TypeError: 'Tensor' 对象不可迭代.

TypeError: 'Tensor' object is not iterable.

以下是我的代码:

with graph.as_default():
    train_index_input = tf.placeholder(tf.int32, shape=(None, window_size))
    train_embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_dimension], -1.0, 1.0))
    embedding_input = [tf.nn.embedding_lookup(train_embeddings, x) for x in train_index_input]
    ......
    ......

由于我在不运行图形的情况下无法看到张量train_index_input"的内容,因此代码引发了'Tensor' object is not iterable"的错误:

As I can not see the content of the tensor "train_index_input" without running the graph, so the error of "'Tensor' object is not iterable" raises for the code:

embedding_input = [tf.nn.embedding_lookup(train_embeddings, x) for x in train_index_input]

我想要获得的是一个形状为 [batch_size, embedding_dimension] 的嵌入矩阵embedding_input",其中 batch_size 不固定.我是否必须在 Tensorflow 中定义一个新操作来嵌入 2D 张量查找?或者有其他方法吗?谢谢

What I want to obtain is an embedding matrix "embedding_input" that is of shape [batch_size, embedding_dimension] where batch_size is not fixed. Do I have to define a new operation in Tensorflow for embedding lookup for 2D tensors? Or any other ways to do it?Thanks

推荐答案

您正在尝试对 Tensorflow 占位符进行 Python 级别的列表推导(for x in train_index_input).那是行不通的——Python 不知道 tf 对象里面有什么.

You're trying to do a python-level list comprehension (for x in train_index_input) over a Tensorflow placeholder. That won't work - Python has no idea what's inside a tf object.

要完成批量嵌入查找,您可以做的只是展平您的批次:

To accomplish a batch embedding lookup, what you can do is just flatten your batch:

train_indexes_flat = tf.reshape(train_index_input, [-1])

通过嵌入查找运行它:

looked_up_embeddings = tf.nn.embedding_lookup(train_embeddings, train_indexes_flat)

然后将其重塑回正确的组:

and then reshape it back into the right groups:

embedding_input = tf.reshape(looked_up_embeddings, [-1, window_size])

这篇关于使用 Tensorflow 构建一个适用于可变批量大小的图表的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆