如何显式广播张量以匹配张量流中的另一个形状? [英] How to explicitly broadcast a tensor to match another's shape in tensorflow?

查看:32
本文介绍了如何显式广播张量以匹配张量流中的另一个形状?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在tensorflow中有三个张量,A、B和CAB都是(m,n, r), C 是形状为 (m, n, 1) 的二进制张量.

I have three tensors, A, B and C in tensorflow, A and B are both of shape (m, n, r), C is a binary tensor of shape (m, n, 1).

我想根据 C 的值从 A 或 B 中选择元素.显而易见的工具是 tf.select,但是它没有广播语义,所以我需要首先将 C 显式广播到与 A 和 B 相同的形状.

I want to select elements from either A or B based on the value of C. The obvious tool is tf.select, however that does not have broadcasting semantics, so I need to first explicitly broadcast C to the same shape as A and B.

这将是我第一次尝试如何做到这一点,但它不喜欢我将张量 (tf.shape(A)[2]) 混合到形状列表中.

This would be my first attempt at how to do this, but it doesn't like me mixing a tensor (tf.shape(A)[2]) into the shape list.

import tensorflow as tf
A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))

C = tf.tile(C, [1,1,tf.shape(A)[2]])
D = tf.select(C, A, B)

这里的正确方法是什么?

What's the correct approach here?

推荐答案

在自 0.12rc0 以来的所有 TensorFlow 版本中,问题中的代码直接有效.TensorFlow 会自动将张量和 Python 数字叠加到张量参数中.以下使用 tf.pack() 的解决方案仅在 0.12rc0 之前的版本中需要.请注意,tf.pack() 已重命名为 tf.stack() 在 TensorFlow 1.0 中.

In all versions of TensorFlow since 0.12rc0, the code in the question works directly. TensorFlow will automatically stack tensors and Python numbers into a tensor argument. The solution below using tf.pack() is only needed in versions prior to 0.12rc0. Note that tf.pack() was renamed to tf.stack() in TensorFlow 1.0.

您的解决方案非常接近工作.您应该替换该行:

Your solution is very close to working. You should replace the line:

C = tf.tile(C, [1,1,tf.shape(C)[2]])

...具有以下内容:

C = tf.tile(C, tf.pack([1, 1, tf.shape(A)[2]]))

(问题的原因是 TensorFlow 不会将张量列表和 Python 文字隐式转换为张量.tf.pack() 接受一个张量列表,因此它将转换其输入中的每个元素(1, 1, and tf.shape(C)[2]) 到张量.由于每个元素都是标量,结果将是一个向量.)

(The reason for the issue is that TensorFlow won't implicitly convert a list of tensors and Python literals into a tensor. tf.pack() takes a list of tensors, so it will convert each of the elements in its input (1, 1, and tf.shape(C)[2]) to a tensor. Since each element is a scalar, the result will be a vector.)

这篇关于如何显式广播张量以匹配张量流中的另一个形状?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆