如何显式广播张量以匹配张量流中的另一个形状? [英] How to explicitly broadcast a tensor to match another's shape in tensorflow?
问题描述
我在tensorflow中有三个张量,A、B和C
,A
和B
都是(m,n, r)
, C
是形状为 (m, n, 1)
的二进制张量.
I have three tensors, A, B and C
in tensorflow, A
and B
are both of shape (m, n, r)
, C
is a binary tensor of shape (m, n, 1)
.
我想根据 C
的值从 A 或 B 中选择元素.显而易见的工具是 tf.select
,但是它没有广播语义,所以我需要首先将 C
显式广播到与 A 和 B 相同的形状.
I want to select elements from either A or B based on the value of C
. The obvious tool is tf.select
, however that does not have broadcasting semantics, so I need to first explicitly broadcast C
to the same shape as A and B.
这将是我第一次尝试如何做到这一点,但它不喜欢我将张量 (tf.shape(A)[2]
) 混合到形状列表中.
This would be my first attempt at how to do this, but it doesn't like me mixing a tensor (tf.shape(A)[2]
) into the shape list.
import tensorflow as tf
A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))
C = tf.tile(C, [1,1,tf.shape(A)[2]])
D = tf.select(C, A, B)
这里的正确方法是什么?
What's the correct approach here?
推荐答案
在自 0.12rc0 以来的所有 TensorFlow 版本中,问题中的代码直接有效.TensorFlow 会自动将张量和 Python 数字叠加到张量参数中.以下使用 tf.pack()
的解决方案仅在 0.12rc0 之前的版本中需要.请注意,tf.pack()
已重命名为 tf.stack()
在 TensorFlow 1.0 中.
In all versions of TensorFlow since 0.12rc0, the code in the question works directly. TensorFlow will automatically stack tensors and Python numbers into a tensor argument. The solution below using tf.pack()
is only needed in versions prior to 0.12rc0. Note that tf.pack()
was renamed to tf.stack()
in TensorFlow 1.0.
您的解决方案非常接近工作.您应该替换该行:
Your solution is very close to working. You should replace the line:
C = tf.tile(C, [1,1,tf.shape(C)[2]])
...具有以下内容:
C = tf.tile(C, tf.pack([1, 1, tf.shape(A)[2]]))
(问题的原因是 TensorFlow 不会将张量列表和 Python 文字隐式转换为张量.tf.pack()
接受一个张量列表,因此它将转换其输入中的每个元素(1
, 1
, and tf.shape(C)[2]
) 到张量.由于每个元素都是标量,结果将是一个向量.)
(The reason for the issue is that TensorFlow won't implicitly convert a list of tensors and Python literals into a tensor. tf.pack()
takes a list of tensors, so it will convert each of the elements in its input (1
, 1
, and tf.shape(C)[2]
) to a tensor. Since each element is a scalar, the result will be a vector.)
这篇关于如何显式广播张量以匹配张量流中的另一个形状?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!