从张量中获取值的随机索引 [英] Get a random index of a value from a tensor

查看:31
本文介绍了从张量中获取值的随机索引的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个包含一些数值的张量:

I have a tensor of tensors that contain some numeric values:

[ [0,0,0,1,1,1], [1,2,1,0,1,0] ... ] 

对于每个张量,我想获得一个零值的随机索引.因此,对于第一个张量,可能的输出值为 0,1,2,对于第二个张量,可能的输出值为 3,5.(我只想要从这些可能的结果中随机抽取一个,例如 [0,5])

for each of the tensors i would like to get a random index of a zero value. so for the first tensor possible output values are 0,1,2 for the second tensor possible values are 3,5. (i only want one from each of these possible outcomes at random so something like [0,5])

在 tensorflow 中实现此目的的最佳方法是什么?

What is the best way to accomplish this in tensorflow?

推荐答案

这是一种可能的解决方案:

This is one possible solution:

import tensorflow as tf

# Input data
nums = tf.placeholder(tf.int32, [None, None])
rows = tf.shape(nums)[0]
# Number of zeros on each row
zero_mask = tf.cast(tf.equal(nums, 0), tf.int32)
num_zeros = tf.reduce_sum(zero_mask, axis=1)
# Random values
r = tf.random_uniform([rows], 0, 1, dtype=tf.float32)
# Multiply by the number of zeros to decide which of the zeros you pick
zero_idx = tf.cast(tf.floor(r * tf.cast(num_zeros, r.dtype)), tf.int32)
# Find the indices of the smallest values, which should be the zeros
_, zero_pos = tf.nn.top_k(-nums, k=tf.maximum(tf.reduce_max(num_zeros), 1))
# Select the corresponding position of each row
result = tf.gather_nd(zero_pos, tf.stack([tf.range(rows), zero_idx], axis=1))
# Test
with tf.Session() as sess:
    x = [[0,0,0,1,1,1],
         [1,2,1,0,1,0]]
    print(sess.run(result, feed_dict={nums: x}))
    print(sess.run(result, feed_dict={nums: x}))
    print(sess.run(result, feed_dict={nums: x}))

示例输出:

[1 3]
[2 5]
[0 3]

如果某行没有任何零,那么它会选择索引 0,尽管您可以制作一个掩码来过滤那些:

If some row does not have any zero then it will pick the index 0, although you can make a mask to filter those with something like:

has_zeros = tf.reduce_any(tf.equal(nums, 0), axis=1)

这篇关于从张量中获取值的随机索引的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆