Keras张量-使用来自另一个张量的索引获取值 [英] Keras tensors - Get values with indices coming from another tensor

查看:796
本文介绍了Keras张量-使用来自另一个张量的索引获取值的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有两个张量:

  • valueMatrix,形状为(?, 3),其中?是批次大小
  • indexMatrix,形状为(?, 1)
  • valueMatrix, shaped as (?, 3), where ? is the batch size
  • indexMatrix, shaped as (?, 1)

我想从indexMatrix中包含的索引处的valueMatrix中检索值.

I want to retrieve values from valueMatrix at the indices contained in indexMatrix.

示例(伪代码):

valueMatrix = [[7,15,5],[4,6,8]] -- shape=(2,3) -- type=float 
indexMatrix = [[1],[0]] -- shape = (2,1) -- type=int

在此示例中,我希望执行以下操作:

I want from this example to do something like:

valueMatrix[indexMatrix] --> returns --> [[15],[4]]


与其他后端相比,我更喜欢Tensorflow,但答案必须与使用Lambda层或其他适合该任务的层的Keras模型兼容.


I prefer Tensorflow over other backends, but the answer must be compatible with a Keras model using Lambda layers or other suitable layers for the task.

推荐答案

import tensorflow as tf
valueMatrix = tf.constant([[7,15,5],[4,6,8]])
indexMatrix = tf.constant([[1],[0]])

# create the row index with tf.range
row_idx = tf.reshape(tf.range(indexMatrix.shape[0]), (-1,1))
# stack with column index
idx = tf.stack([row_idx, indexMatrix], axis=-1)
# extract the elements with gather_nd
values = tf.gather_nd(valueMatrix, idx)

with tf.Session() as sess:
    print(sess.run(values))
#[[15]
# [ 4]]

这篇关于Keras张量-使用来自另一个张量的索引获取值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆