如何基于带索引的张量过滤张量流的张量? [英] How to filter tensorflow's Tensor based on tensor with indices?

查看:70
本文介绍了如何基于带索引的张量过滤张量流的张量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有一个大小为[batch_size, 5, 10]的张量称为my_tensor. 我还有一个大小为[batch_size, 1]的张量,其中包含索引为selecter的张量.

我想相对于selecter过滤my_tensor以产生大小为[batch_size, 10]的新张量,即仅选择selecter包含的值.基本上,这是在减小中间尺寸(尺寸为5).

我觉得tf.where是正确的选择,但不确定. 非常感谢您的帮助!

解决方案

解决方案是与tf.gather_nd一起使用.

tf.gather_nd(
    my_tensor,
    tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1))

如果从头开始将selecter构造为一维,则可以摆脱squeeze.

Let's say I have a tensor of size [batch_size, 5, 10] called my_tensor. I also have an another tensor of size [batch_size, 1] holding indices called selecter.

I want to filter my_tensor with respect to selecter to produce new tensor of size [batch_size, 10], i.e. select only values that selecter contains. Basically, it's kinda reducing the middle dimension(which has size 5).

I feel like tf.where is the right choice, but not sure about it. I would really appreciate your help!

解决方案

The solution is to go with tf.gather_nd.

tf.gather_nd(
    my_tensor,
    tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1))

You can get rid of the squeeze if you construct selecter to be 1-D from the beginning.

这篇关于如何基于带索引的张量过滤张量流的张量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆