Subtracting multiple vectors from each row of an array (super broadcasting)

Problem description

I have a data set X that is m x 2, and three vectors stored in a matrix C = [c1'; c2'; c3'] that is 3 x 2. I am trying to vectorize my code that finds, for each data point in X, which vector in C is closest (in squared distance). I would like to subtract each vector (row) in C from each vector (row) in X, resulting in an m x 6 or 3m x 2 matrix of differences between the elements of X and the elements of C. My current implementation does this one row of X at a time:

idx = zeros(size(X, 1), 1);                 % preallocate one index per row of X
for i = 1:size(X, 1)
    diffs = bsxfun(@minus, X(i,:), C);      % 3 x 2: X(i,:) minus each row of C
    [~, idx(i)] = min(sumsq(diffs, 2));     % index of the row of C closest to
                                            % the ith row of X
end

I want to get rid of this for loop and vectorize the whole thing, but bsxfun(@minus, X, C) gives me an error in Octave:

error: bsxfun: nonconformant dimensions: 300x2 and 3x2

Any ideas how I can "super-broadcast" my subtraction operation between these two matrices?
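One literal way to "super-broadcast" (a sketch, not part of the recommended answer) is to move C into the third dimension with permute. X is m x 2 (x 1) and permute(C, [3 2 1]) is 1 x 2 x 3, so bsxfun expands both to an m x 2 x 3 array of differences. The random X and C below are hypothetical stand-ins sized like the error message.

```octave
m = 300;
X = rand(m, 2);                     % hypothetical data: m points in 2-D
C = rand(3, 2);                     % hypothetical centroids: 3 points in 2-D

% permute(C, [3 2 1]) has size 1 x 2 x 3, so bsxfun pairs every row of X
% with every row of C: diffs(i, :, j) = X(i, :) - C(j, :).
diffs = bsxfun(@minus, X, permute(C, [3 2 1]));   % m x 2 x 3

% Sum squared differences over the coordinate dimension (m x 1 x 3),
% then reshape into an m x 3 matrix of squared distances.
D3 = reshape(sum(diffs .^ 2, 2), size(X, 1), size(C, 1));

% Index (1..3) of the closest row of C for every row of X.
[~, idx] = min(D3, [], 2);
```

With automatic broadcasting (Octave 3.6+, MATLAB R2016b+), X - permute(C, [3 2 1]) should give the same array. The catch is the m x d x k intermediate, which grows quickly with the number of dimensions and centroids; the dist2 approach in the answer avoids it by using one matrix product.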

Recommended answer

The core of this problem is to compute a distance matrix D of size m x 3 that contains the pairwise squared distances between all data points in X and all data points in C. The squared Euclidean distance between the i-th vector x_i in X and the j-th vector c_j in C can be expanded as:

||x_i - c_j||^2 = ||x_i||^2 - 2<x_i, c_j> + ||c_j||^2
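A quick numeric check of this expansion for a single pair, using two hypothetical row vectors chosen so every term is an exact integer:

```octave
x = [3 4];                                     % hypothetical x_i
c = [1 -2];                                    % hypothetical c_j
lhs = sum((x - c) .^ 2);                       % ||x - c||^2         -> 40
rhs = sum(x .^ 2) - 2 * (x * c.') + sum(c .^ 2); % 25 - 2*(-5) + 5   -> 40
```

Here x * c.' is the inner product <x, c>; doing the same for all pairs at once is exactly what X * C' computes.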

where <,> denotes the inner product. The right-hand side of this equation is easy to vectorize, because the inner products of all pairs are just X * C', a single BLAS level-3 (matrix-matrix) operation. This way of computing the distance matrix is known as the dist2 function in Christopher Bishop's book Pattern Recognition and Machine Learning. I reproduce the function below with a small modification.

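function D = dist2(X, C)
    % Squared Euclidean distances between the rows of X (m x d) and C (k x d).
    tempx = full(sum(X.^2, 2));      % m x 1: ||x_i||^2
    tempc = full(sum(C.^2, 2).');    % 1 x k: ||c_j||^2
    D = -2*(X * C.');                % m x k: -2<x_i, c_j> in one matrix product
    D = bsxfun(@plus, D, tempx);     % add ||x_i||^2 to row i
    D = bsxfun(@plus, D, tempc);     % add ||c_j||^2 to column j
end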

The call to full handles the case where X or C is a sparse matrix.
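A small worked example on a hypothetical input whose distances can be checked by hand. To keep the snippet runnable as a standalone script, dist2 is restated here as an anonymous function (an assumption for illustration; it performs the same steps as the function above).

```octave
% Anonymous-function form of dist2: -2*X*C' plus ||x_i||^2 per row
% plus ||c_j||^2 per column.
dist2 = @(X, C) bsxfun(@plus, ...
                       bsxfun(@plus, -2 * (X * C.'), full(sum(X .^ 2, 2))), ...
                       full(sum(C .^ 2, 2).'));

X = [0 0; 3 4];             % hypothetical points
C = [0 0; 1 0; 0 1];        % hypothetical centroids
D = dist2(X, C);            % [0 1 1; 25 20 18]
[~, idx] = min(D, [], 2);   % [1; 3]
```

For the point (3, 4), the squared distances are 25 to (0, 0), 4 + 16 = 20 to (1, 0) and 9 + 9 = 18 to (0, 1), so the third centroid is closest.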

Note: The distance matrix D computed this way might have tiny negative entries due to numerical rounding error. To guard against this case, use

D = max(D, 0);

The indices of the closest vector in C can be retrieved from D:

[~, idx] = min(D, [], 2);
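Putting the pieces together, here is a sketch that checks the vectorized result against a corrected form of the question's row-by-row loop. The random data is hypothetical (sized like the 300 x 2 and 3 x 2 matrices in the error message), and dist2 is again written as an anonymous function so the block is self-contained.

```octave
X = rand(300, 2);           % hypothetical data points
C = rand(3, 2);             % hypothetical centroids

% Same steps as dist2 above, as an anonymous function.
dist2 = @(X, C) bsxfun(@plus, bsxfun(@plus, -2 * (X * C.'), sum(X .^ 2, 2)), ...
                       sum(C .^ 2, 2).');

D = max(dist2(X, C), 0);    % clamp tiny negative values caused by rounding
[~, idx] = min(D, [], 2);   % closest row of C for every row of X

% Reference: loop over the rows of X, as in the question.
D_loop = zeros(size(X, 1), size(C, 1));
idx_loop = zeros(size(X, 1), 1);
for i = 1:size(X, 1)
  diffs = bsxfun(@minus, X(i, :), C);        % 3 x 2 differences
  D_loop(i, :) = sum(diffs .^ 2, 2).';       % squared distances to each row of C
  [~, idx_loop(i)] = min(D_loop(i, :));
end
```

The two distance matrices agree up to floating-point rounding, and the vectorized version replaces m small bsxfun calls with one matrix product and two broadcasts.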
