Merge duplicate indices in a sparse tensor


Question

Let's say I have a sparse tensor with duplicate indices, and wherever an index is duplicated I want to merge the values (sum them up). What is the best way to do this?

Example:

indices = [[1, 1], [1, 2], [1, 2], [1, 3]]
values = [1, 2, 3, 4]

sparse = tf.SparseTensor(indices, values, dense_shape=[10, 10])

result = tf.MAGIC(sparse)  # placeholder for the op I am looking for

result should be a sparse tensor (or a concrete one!) with the following values:

indices = [[1, 1], [1, 2], [1, 3]]
values = [1, 5, 4]

The only thing I have thought of is to string-concatenate the indices to create an index hash, add that hash as a third dimension, and then reduce-sum over that third dimension.

indices = [[1, 1, 11], [1, 2, 12], [1, 2, 12], [1, 3, 13]]
sparse_result = tf.sparse_reduce_sum(sparseTensor, reduction_axes=2, keep_dims=True)

But that feels very, very ugly.

Recommended answer

Here is a solution using tf.segment_sum. The idea is to linearize the indices into a 1-D space, get the unique indices with tf.unique, run tf.segment_sum, and convert the indices back to N-D space.
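To make the four steps concrete without TensorFlow, here is a minimal plain-Python sketch on the question's data. It only mirrors the idea (linearize, find unique keys, sum per segment, unravel); the variable names are illustrative assumptions, and the 10 x 10 shape is taken from the question.

```python
# Plain-Python walk-through of the four steps on the question's data.
# The 10 x 10 shape comes from the question; all names are illustrative.
indices = [[1, 1], [1, 2], [1, 2], [1, 3]]
values = [1, 2, 3, 4]
num_cols = 10

# 1. Linearize: (row, col) -> row * num_cols + col
linearized = [row * num_cols + col for row, col in indices]

# 2. Unique keys in order of first appearance, plus each entry's segment id
unique_keys = []
segment_ids = []
for key in linearized:
    if key not in unique_keys:
        unique_keys.append(key)
    segment_ids.append(unique_keys.index(key))

# 3. Sum the values that share a segment id
merged_values = [0] * len(unique_keys)
for seg, val in zip(segment_ids, values):
    merged_values[seg] += val

# 4. Convert the unique keys back to (row, col)
merged_indices = [[key // num_cols, key % num_cols] for key in unique_keys]

print(merged_indices)  # [[1, 1], [1, 2], [1, 3]]
print(merged_values)   # [1, 5, 4]
```

The segment ids here play the same role as the second output of tf.unique: they say which merged slot each original value is added to.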

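# Written against the TensorFlow 1.x graph-mode API
# (tf.segment_sum, tf.InteractiveSession, Tensor.eval()).
import tensorflow as tf

indices = tf.constant([[1, 1], [1, 2], [1, 2], [1, 3]])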
values = tf.constant([1, 2, 3, 4])

# Linearize the indices. If the dimensions of original array are
# [N_{k}, N_{k-1}, ... N_0], then simply matrix multiply the indices
# by [..., N_1 * N_0, N_0, 1]^T. For example, if the sparse tensor
# has dimensions [10, 6, 4, 5], then multiply by [120, 20, 5, 1]^T
# In your case, the dimensions are [10, 10], so multiply by [10, 1]^T

linearized = tf.matmul(indices, [[10], [1]])

# Get the unique indices, and their positions in the array
y, idx = tf.unique(tf.squeeze(linearized))

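# Use the positions of the unique values as the segment ids and sum the
# values within each segment. tf.segment_sum requires sorted segment ids,
# which holds here because duplicate indices are adjacent; otherwise use
# tf.unsorted_segment_sum(values, idx, tf.size(y)).
values = tf.segment_sum(values, idx)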

# Go back to N-D indices
y = tf.expand_dims(y, 1)
indices = tf.concat([y//10, y%10], axis=1)

tf.InteractiveSession()
print(indices.eval())
print(values.eval())
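The code above hard-codes the multiplier 10 for a [10, 10] tensor. As a rough sketch of how the same linearize-and-unravel step might generalize to any dense shape, the plain-Python version below computes the multipliers [..., N_1 * N_0, N_0, 1] from the shape. Both row_major_strides and merge_duplicates are hypothetical helper names for illustration, not TensorFlow functions, and a dict stands in for the unique/segment-sum pair.

```python
def row_major_strides(shape):
    """Return the multipliers [..., N_1 * N_0, N_0, 1] for a dense shape."""
    strides = [1] * len(shape)
    for axis in range(len(shape) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    return strides


def merge_duplicates(indices, values, shape):
    """Sum values that share an index (hypothetical helper, not a TF API)."""
    strides = row_major_strides(shape)
    totals = {}  # linear key -> summed value, kept in first-seen order
    for index, value in zip(indices, values):
        key = sum(i * s for i, s in zip(index, strides))
        totals[key] = totals.get(key, 0) + value
    merged_indices = []
    for key in totals:
        # Unravel: peel off one axis at a time with integer division.
        index = []
        for stride in strides:
            index.append(key // stride)
            key %= stride
        merged_indices.append(index)
    return merged_indices, list(totals.values())


print(row_major_strides([10, 6, 4, 5]))  # [120, 20, 5, 1]
# Duplicates do not need to be adjacent for the dict-based sum.
print(merge_duplicates([[1, 2], [1, 1], [1, 2]], [2, 1, 3], [10, 10]))
# ([[1, 2], [1, 1]], [5, 1])
```

Because the dict accumulates by key, this sketch does not depend on the input order; the merged entries come out in order of first appearance.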
