如何在嵌套循环中有效地计算上三角的logsexpexp? [英] How to efficiently compute logsumexp of upper triangle in a nested loop?

查看:79
本文介绍了如何在嵌套循环中有效地计算上三角的logsexpexp?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个嵌套的for循环,该循环在权重矩阵的行上进行迭代,并从这些权重行将logumexp应用于外部加法矩阵的上三角部分.它非常慢,因此我试图找出如何通过向量化或取出循环来代替矩阵运算来加快速度的方法.

I have a nested for loop that iterates over rows of the weight matrix and applies logsumexp to the upper triangular portion of the outer addition matrix from these weights rows. It is very slow so I'm trying to figure out how to speed this up by either vectorizing or taking out the loops in lieu of matrix operations.

'''
Wm: weights matrix, nxk
W: updated weights matrix, nxn
triu_inds: upper triangular indices of Wxy outer matrix
'''

for x in range(n-1):
    wx = Wm[x, :]
    for y in range(x+1, n):
        wy = Wm[y, :]
        Wxy = np.add.outer(wx, wy)
        Wxy = Wxy[triu_inds]
        W[x, y] = logsumexp(Wxy)

logsumexp:计算输入数组的指数和的对数

logsumexp: computes the log of the sum of exponentials of an input array

a: [1, 2, 3]
logsumexp(a) = log( exp(1) + exp(2) + exp(3) )

输入数据Wm是nxk维的权重矩阵. K代表患者传感器位置,n代表所有此类可能的传感器位置. Wm中的值基本上是患者传感器与已知传感器的距离.

The input data Wm is a weights matrix of nxk dimensions. K represents a patients sensor locations and n represents all such possible sensor locations. The values in Wm are basically how close a patients sensor is to a known sensor.

示例:

Wm  = [1   2   3]
      [4   5   6]
      [7   8   9]
      [10 11  12]

wx  = [1   2   3]
wy  = [4   5   6]

Wxy = [5   6   7]
      [6   7   8]
      [7   8   9]

triu_indices = ([0, 0, 1], [1, 2, 2])
Wxy[triu_inds] = [6, 7, 8]
logsumexp(Wxy[triu_inds]) = log(exp(6) + exp(7) + exp(8))

推荐答案

您可以在完整矩阵Wm上执行外积,然后交换与操作数1中的列和操作数2中的行相对应的轴以进行应用列的三角形索引.对于所有行组合,结果矩阵都会被填充,因此您需要选择上面的三角形部分.

You can perform the outer product on the full matrix Wm and then swap the axes corresponding to columns in operand 1 and rows in operand 2 in order to apply the triangle indices to the columns. The resulting matrix is filled for all combinations of rows, so you need to select the upper triangle part.

W = logsumexp(
    np.add.outer(Wm, Wm).swapaxes(1, 2)[(slice(None),)*2 + triu_inds],
    axis=-1  # Perform summation over last axis.
)
W = np.triu(W, k=1)

这篇关于如何在嵌套循环中有效地计算上三角的logsexpexp?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆