如何使用 tf.nn.top_k 的返回索引对多维张量进行排序? [英] How to sort a multi-dimensional tensor using the returned indices of tf.nn.top_k?

查看:41
本文介绍了如何使用 tf.nn.top_k 的返回索引对多维张量进行排序?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有两个多维张量 ab.我想按 a 的值对它们进行排序.

I have two multi-dimensional tensors a and b. And I want to sort them by the values of a.

我找到了 tf.nn.top_k 能够对张量进行排序并返回用于对输入进行排序的索引.如何使用 tf.nn.top_k(a, k=2) 返回的索引对 b 进行排序?

I found tf.nn.top_k is able to sort a tensor and return the indices which is used to sort the input. How can I use the returned indices from tf.nn.top_k(a, k=2) to sort b?

例如

import tensorflow as tf

a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2
sorted_a, indices = tf.nn.top_k(a, k)

# How to sort b into
# sorted_b[0, 0, 0, :] = b[0, 0, indices[0, 0, 0], :]
# sorted_b[0, 0, 1, :] = b[0, 0, indices[0, 0, 1], :]
# sorted_b[0, 1, 0, :] = b[0, 1, indices[0, 1, 0], :]
# ...

<小时>

更新

tf.gather_ndtf.meshgrid 可以是一种解决方案.例如,以下代码在 python 3.5 上使用 tensorflow 1.0.0-rc0 进行测试:

a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2

sorted_a, indices = tf.nn.top_k(a, k)

shape_a = tf.shape(a)
auxiliary_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(shape_a[:(a.get_shape().ndims - 1)]) + [k])], indexing='ij')

sorted_b = tf.gather_nd(b, tf.stack(auxiliary_indices[:-1] + [indices], axis=-1))

但是,我想知道是否有一个更易读且不需要在上面创建auxiliary_indices的解决方案.

However, I wonder if there is a solution which is more readable and doesn't need to create auxiliary_indices above.

推荐答案

您的代码有问题.

b = tf.reshape(tf.range(60), (2, 5, 3, 7))

因为 TensorFlow 无法重塑具有 60 个元素的张量以塑造 [2,5,3,7](210 个元素).并且您无法使用 3 阶张量的索引对 4 阶张量 (b) 进行排序.

Because TensorFlow Cannot reshape a tensor with 60 elements to shape [2,5,3,7] (210 elements). And you can't sort a rank 4 tensor (b) using indices of rank 3 tensors.

这篇关于如何使用 tf.nn.top_k 的返回索引对多维张量进行排序?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆