Tensorflow numpy重复 [英] Tensorflow numpy repeat

查看:143
本文介绍了Tensorflow numpy重复的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我希望将特定数字重复不同的次数,如下所示:

I wish to repeat a particular number different number of times as shown below:

x = np.array([0,1,2])
np.repeat(x,[3,4,5])
>>> array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])

(0重复3次,1、4次,等等.)

(The 0 is repeated 3 times, 1, 4 times etc.).

此答案( https://stackoverflow.com/a/35367161/2530674 )似乎暗示我可以结合使用tf.tiletf.reshape可获得相同的效果.但是,我相信只有在重复次数恒定的情况下才会如此.

This answer (https://stackoverflow.com/a/35367161/2530674) seems to suggest that I can use a combination of tf.tile and tf.reshape to get the same effect. However, I believe this is only the case if the repetitions are a constant amount.

如何在Tensorflow中获得相同的效果?

How can I get the same effect in Tensorflow?

edit1:很遗憾,没有tf.repeat.

edit1: there is no tf.repeat unfortunately.

推荐答案

这是解决问题的一种蛮力"解决方案,只需将每个值重复最多重复的次数即可,然后选择正确的元素:

This is a kind of "brute force" solution to the problem, simply tiling every value as many times as the largest number of repetitions and then picking the right elements:

import tensorflow as tf

# Repeats across the first dimension
def tf_repeat(arr, repeats):
    arr = tf.expand_dims(arr, 1)
    max_repeats = tf.reduce_max(repeats)
    tile_repeats = tf.concat(
        [[1], [max_repeats], tf.ones([tf.rank(arr) - 2], dtype=tf.int32)], axis=0)
    arr_tiled = tf.tile(arr, tile_repeats)
    mask = tf.less(tf.range(max_repeats), tf.expand_dims(repeats, 1))
    result = tf.boolean_mask(arr_tiled, mask)
    return result

with tf.Graph().as_default(), tf.Session() as sess:
    print(sess.run(tf_repeat([0, 1, 2], [3, 4, 5])))

输出:

[0 0 0 1 1 1 1 2 2 2 2 2]

这篇关于Tensorflow numpy重复的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆