访问张量中单个元素的更好方法 [英] Better way to access individual elements in a tensor

查看:34
本文介绍了访问张量中单个元素的更好方法的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试使用张量 a 中定义的索引访问张量 a 的元素.

I am trying to access the elements of a tensor a, with the indexes defined in tensor b.

a=tf.constant([[1,2,3,4],[5,6,7,8]])
b=tf.constant([0,1,1,0])

我希望输出是

out = [1 6 7 4]

我尝试过的:

out=[]
for i in range(a.shape[1]):
    out.append(a[b[i],i])

out=tf.stack(out) #[1 6 7 4]

这给出了正确的输出,但我正在寻找一种更好、更紧凑的方法来做到这一点.

This is giving the correct output, but I'm looking for a better and a compact way to do it.

a 的形状类似于 (2,None) 时,我的逻辑也不起作用,因为我无法使用 range(a.shape[1]),如果答案也包括这种情况,那对我有帮助

Also my logic doesnt work when the shape of a is something like (2,None) since I cannot iterate with range(a.shape[1]), it would help me if the answer included this case too

谢谢

推荐答案

您可以使用 tf.one_hot()tf.boolean_mask().

import tensorflow as tf
import numpy as np

a_tf = tf.placeholder(shape=(2,None),dtype=tf.int32)
b_tf = tf.placeholder(shape=(None,),dtype=tf.int32)

index = tf.one_hot(b_tf,a_tf.shape[0])
out = tf.boolean_mask(tf.transpose(a_tf),index)

a=np.array([[1,2,3,4],[5,6,7,8]])
b=np.array([0,1,1,0])
with tf.Session() as sess:
    print(sess.run(out,feed_dict={a_tf:a,b_tf:b}))

# print
[1 6 7 4]

这篇关于访问张量中单个元素的更好方法的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆