TensorFlow,批量索引(第一维)和排序 [英] TensorFlow, batchwise indexing (first dimension) and sorting

查看:2150
本文介绍了TensorFlow,批量索引(第一维)和排序的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个形状为(?,368,5)的参数张量,以及一个形状为的查询张量(?, 368)。查询张量存储用于对第一张量进行排序的索引。

I've got a params tensor with shape (?,368,5), as well as a query tensor with shape (?,368). The query tensor stores indices for sorting the first tensor.

所需的输出形状为:(?,368,5)。由于我需要它用于神经网络中的损失函数,因此使用的操作应该保持可微。此外,在运行时,第一个轴的大小对应于batchsize。

The required output has shape: (?,368,5). Since I need it for a loss function in a neural network, the used operations should stay differentiable. Also, at runtime the size of the first axis ? corresponds to the batchsize.

到目前为止,我试验了 tf.gather tf.gather_nd ,但是
tf.gather(params,查询)产生一个形状(?,368,368,5)的张量。

So far I experimented with tf.gather and tf.gather_nd, however tf.gather(params,query) results in a tensor with shape (?,368,368,5).

查询张量是通过以下方式实现的:

The query tensor is achieved by performing:

query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices

总的来说,我尝试通过第三轴上的第一个元素对params张量进行排序(对于倒角距离的种类)。最后要提到的是,我使用 Keras 框架。

Overall, I try to sort the params tensor by the first element on the third axis (for kind of a chamfer distance). At last to mention is, that I work with the Keras framework.

推荐答案

您需要将第一个维度的索引添加到 query ,以便将其与 tf.gather_nd 一起使用。这是一种方法:

You need to add the indices of the first dimension to query in order to use it with tf.gather_nd. Here is a way to do it:

import tensorflow as tf
import numpy as np

np.random.seed(100)

with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.placeholder(tf.float32, [None, 368, 5])
    query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
    n = tf.shape(params)[0]
    # Make tensor of indices for the first dimension
    ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, params.shape[1]))
    # Stack indices
    idx = tf.stack([ii, query], axis=-1)
    # Gather reordered tensor
    result = tf.gather_nd(params, idx)
    # Test
    out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
    # Check the order is correct
    print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
    # True

这篇关于TensorFlow,批量索引(第一维)和排序的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆