Keras张量-使用来自另一个张量的索引获取值 [英] Keras tensors - Get values with indices coming from another tensor
本文介绍了Keras张量-使用来自另一个张量的索引获取值的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
假设我有两个张量:
-
valueMatrix
,形状为(?, 3)
,其中?
是批次大小 -
indexMatrix
,形状为(?, 1)
valueMatrix
, shaped as(?, 3)
, where?
is the batch sizeindexMatrix
, shaped as(?, 1)
我想从indexMatrix
中包含的索引处的valueMatrix
中检索值.
I want to retrieve values from valueMatrix
at the indices contained in indexMatrix
.
示例(伪代码):
valueMatrix = [[7,15,5],[4,6,8]] -- shape=(2,3) -- type=float
indexMatrix = [[1],[0]] -- shape = (2,1) -- type=int
在此示例中,我希望执行以下操作:
I want from this example to do something like:
valueMatrix[indexMatrix] --> returns --> [[15],[4]]
与其他后端相比,我更喜欢Tensorflow,但答案必须与使用Lambda层或其他适合该任务的层的Keras模型兼容.
I prefer Tensorflow over other backends, but the answer must be compatible with a Keras model using Lambda layers or other suitable layers for the task.
推荐答案
import tensorflow as tf
valueMatrix = tf.constant([[7,15,5],[4,6,8]])
indexMatrix = tf.constant([[1],[0]])
# create the row index with tf.range
row_idx = tf.reshape(tf.range(indexMatrix.shape[0]), (-1,1))
# stack with column index
idx = tf.stack([row_idx, indexMatrix], axis=-1)
# extract the elements with gather_nd
values = tf.gather_nd(valueMatrix, idx)
with tf.Session() as sess:
print(sess.run(values))
#[[15]
# [ 4]]
这篇关于Keras张量-使用来自另一个张量的索引获取值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文