在张量流中展平批次 [英] Flatten batch in tensorflow

查看:32
本文介绍了在张量流中展平批次的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个形状为 [None, 9, 2] 的 tensorflow 输入(其中 None 是批处理).

I have an input to tensorflow of shape [None, 9, 2] (where the None is batch).

要对其执行进一步的操作(例如 matmul),我需要将其转换为 [None, 18] 形状.怎么做?

To perform further actions (e.g. matmul) on it I need to transform it to [None, 18] shape. How to do it?

推荐答案

您可以使用 tf.reshape() 轻松完成,而无需知道批量大小.

You can do it easily with tf.reshape() without knowing the batch size.

x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list()        # a list: [None, 9, 2]
dim = numpy.prod(shape[1:])            # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim])           # -1 means "all"

最后一行中的 -1 表示整列,无论运行时的批大小如何.您可以在 tf.reshape().

The -1 in the last line means the whole column no matter what the batchsize is in the runtime. You can see it in tf.reshape().

谢谢@kbrose.对于超过 1 个维度未定义的情况,我们可以使用 tf.shape() 或者 tf.reduce_prod().

Thanks @kbrose. For the cases where more than 1 dimension are undefined, we can use tf.shape() with tf.reduce_prod() alternatively.

x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])

tf.shape() 返回一个可以在运行时计算的形状张量.tf.get_shape() 和 tf.shape() 之间的区别可以在在文档中看到.

tf.shape() returns a shape Tensor which can be evaluated in runtime. The difference between tf.get_shape() and tf.shape() can be seen in the doc.

我也在另一个 .contrib.layers.flatten() 中尝试过.第一种情况最简单,但不能处理第二种情况.

I also tried tf.contrib.layers.flatten() in another . It is simplest for the first case, but it can't handle the second.

这篇关于在张量流中展平批次的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆