Tensorflow unsorted_segment_sum dimension


Problem description

I'm using TensorFlow's tf.unsorted_segment_sum method and it works fine when the tensor I pass as data has only one row. For example:

tf.unsorted_segment_sum(tf.constant([0.2, 0.1, 0.5, 0.7, 0.8]),
                        tf.constant([0, 0, 1, 2, 2]), 3)

This gives the right result:

array([ 0.3,  0.5 , 1.5 ], dtype=float32)
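(Segment 0 collects 0.2 + 0.1 = 0.3, segment 1 gets 0.5, and segment 2 gets 0.7 + 0.8 = 1.5.)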

The question is, if I use a tensor with several rows, how can I get the result for each row? For instance, if I try a tensor with two rows:

tf.unsorted_segment_sum(tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                                     [0.2, 0.2, 0.5, 0.7, 0.8]]),
                        tf.constant([[0, 0, 1, 2, 2],
                                     [0, 0, 1, 2, 2]]), 3)

The result I would expect is:

array([ [ 0.3,  0.5 , 1.5 ], [ 0.4, 0.5, 1.5 ] ], dtype=float32)

But what I get is:

array([ 0.7,  1. ,  3. ], dtype=float32)
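(Both rows get flattened into the same three segments: segment 0 becomes 0.2 + 0.1 + 0.2 + 0.2 = 0.7, segment 1 becomes 0.5 + 0.5 = 1.0, and segment 2 becomes 0.7 + 0.8 + 0.7 + 0.8 = 3.0.)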

I want to know if someone knows how to obtain the result for each row without using a for loop.

Thanks in advance

Recommended answer

While the solution below may cover some additional exotic uses, this problem can be solved much more easily just by transposing the data. It turns out that, even though tf.unsorted_segment_sum does not have an axis parameter, it can only work along one axis, as long as it is the first one. So you can do the following:

import tensorflow as tf

with tf.Session() as sess:
    data = tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                        [0.2, 0.2, 0.5, 0.7, 0.8]])
    idx = tf.constant([0, 0, 1, 2, 2])
    # Transpose so the segmented axis comes first, sum, then transpose back
    result = tf.transpose(tf.unsorted_segment_sum(tf.transpose(data), idx, 3))
    print(sess.run(result))

Output:

[[ 0.30000001  0.5         1.5       ]
 [ 0.40000001  0.5         1.5       ]]
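
Note that this transpose trick assumes every row uses the same segment ids (idx above is a single 1-D vector). If you are on a newer TensorFlow release, the same approach should work eagerly; a minimal sketch, assuming TF 2.x, where the op is exposed as tf.math.unsorted_segment_sum:

import tensorflow as tf  # assumes TF 2.x (eager execution by default)

data = tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                    [0.2, 0.2, 0.5, 0.7, 0.8]])
idx = tf.constant([0, 0, 1, 2, 2])
# Same transpose trick; in TF 2.x the op lives under tf.math
result = tf.transpose(tf.math.unsorted_segment_sum(tf.transpose(data), idx, 3))
print(result.numpy())  # [[0.3 0.5 1.5] [0.4 0.5 1.5]] (up to float precision)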

---

Original post:

tf.unsorted_segment_sum does not support working along a single axis. The simplest solution would be to apply the operation to each row and then stack the results back together:

data = tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                    [0.2, 0.2, 0.5, 0.7, 0.8]])
segment_ids = tf.constant([[0, 0, 1, 2, 2],
                           [0, 0, 1, 2, 2]])
num_segments = 3
rows = []
# Unstack into per-row tensors (needs a statically-known number of rows)
for data_i, ids_i in zip(tf.unstack(data), tf.unstack(segment_ids)):
    rows.append(tf.unsorted_segment_sum(data_i, ids_i, num_segments))
result = tf.stack(rows, axis=0)

However, this has drawbacks: 1) it only works for statically-shaped tensors (that is, you need a fixed number of rows), and 2) it may not be as efficient. The first one could be circumvented with a tf.while_loop, but that would be complicated, and it would also require you to concatenate the rows one by one, which is very inefficient. Besides, you already stated you want to avoid loops.

A better option is to use a different set of ids for each row. For example, you could add something like num_segments * row_index to each value in segment_ids, so you guarantee that each row has its own set of ids:

num_rows = tf.shape(segment_ids)[0]
rows_idx = tf.range(num_rows)
segment_ids_per_row = segment_ids + num_segments * tf.expand_dims(rows_idx, axis=1)
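
For the two-row example above (num_segments = 3), this produces segment ids [[0, 0, 1, 2, 2], [3, 3, 4, 5, 5]], so the segments of different rows never overlap.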

Then you can apply the operation and reshape to get the tensor that you want:

seg_sums = tf.unsorted_segment_sum(data, segment_ids_per_row,
                                   num_segments * num_rows)
result = tf.reshape(seg_sums, [-1, num_segments])
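
Here seg_sums has flat shape [num_rows * num_segments] (6 in this example), and the reshape turns it back into [num_rows, num_segments] = [2, 3].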

Output:

array([[ 0.3, 0.5, 1.5 ],
       [ 0.4, 0.5, 1.5 ]], dtype=float32)
