Subtracting multiple vectors from each row of an array (super broadcasting)
Question
I have a data set X that is m x 2, and three vectors stored in a matrix C = [c1'; c2'; c3'] that is 3 x 2. I am trying to vectorize my code that finds, for each data point in X, which vector in C is closest (by squared distance). I would like to subtract each vector (row) in C from each vector (row) in X, resulting in an m x 6 or 3m x 2 matrix of differences between the elements of X and the elements of C. My current implementation does this one row of X at a time:
idx = zeros(size(X, 1), 1);            % preallocate the result
for i = 1:size(X, 1)
  diffs = bsxfun(@minus, X(i,:), C);   % gives a 3 x 2 matrix result
  [~, idx(i)] = min(sumsq(diffs, 2));  % returns the index of the closest vector
                                       % in C to the ith vector in X
end
I want to get rid of this for loop and vectorize the whole thing, but bsxfun(@minus, X, C) gives me an error in Octave:
error: bsxfun: nonconformant dimensions: 300x2 and 3x2
Any ideas how I can "super-broadcast" my subtraction operation between these two matrices?
Recommended answer
The core of this problem is to compute a distance matrix D of size m x 3 that contains the pairwise squared distances between all data points in X and all data points in C. The squared Euclidean distance between the i-th vector x_i in X and the j-th vector c_j in C can be rewritten as:
|x_i-c_j|^2 = |x_i|^2 - 2<x_i, c_j> + |c_j|^2
where <,> denotes the inner product. The right-hand side of this equation is easy to vectorize, because the inner products of all pairs are just X * C', which is a BLAS level 3 operation. This way of computing the distance matrix is known as the dist2 function in the book Pattern Recognition and Machine Learning by Christopher Bishop. I copy the function below with a small modification.
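function D = dist2(X, C)
  % Squared Euclidean distances between the rows of X (m x d) and C (k x d).
  tempx = full(sum(X.^2, 2));      % m x 1, holds |x_i|^2
  tempc = full(sum(C.^2, 2).');    % 1 x k, holds |c_j|^2
  D = -2*(X * C.');                % m x k, holds -2<x_i, c_j>
  D = bsxfun(@plus, D, tempx);     % add |x_i|^2 across row i
  D = bsxfun(@plus, D, tempc);     % add |c_j|^2 down column j
end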
The full here is used in case X or C is a sparse matrix.
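In matrix form, the function evaluates the expression below. The symbols s_X, s_C and the all-ones vectors are notation introduced here for illustration only; they do not appear in the original function.

```latex
D = s_X \mathbf{1}_3^{\top} - 2\, X C^{\top} + \mathbf{1}_m s_C^{\top},
\qquad (s_X)_i = \lVert x_i \rVert^2, \quad (s_C)_j = \lVert c_j \rVert^2,
\qquad\text{so}\qquad
D_{ij} = \lVert x_i \rVert^2 - 2\langle x_i, c_j \rangle + \lVert c_j \rVert^2
       = \lVert x_i - c_j \rVert^2 .
```

The two bsxfun calls play the role of the outer products with the all-ones vectors: they replicate tempx across columns and tempc across rows without forming those matrices explicitly.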
Note: The distance matrix D computed this way might have tiny negative entries due to numerical rounding error. To guard against this case, use
D = max(D, 0);
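As a quick sanity check, the expansion can be compared against an explicit double loop. This is a minimal sketch: the data values and the names Dref and maxErr are made up for illustration, and the formula is written inline rather than calling dist2 so that the snippet runs on its own.

```octave
% Made-up example data: 4 points in X, 3 candidate vectors in C.
X = [0 0; 1 1; 4 4; 5 0];
C = [0 1; 4 3; 5 1];

% Squared distances via |x|^2 - 2<x,c> + |c|^2, clamped at zero.
D = bsxfun(@plus, sum(X.^2, 2), sum(C.^2, 2).') - 2 * (X * C.');
D = max(D, 0);

% Reference values from an explicit double loop.
Dref = zeros(size(X, 1), size(C, 1));
for i = 1:size(X, 1)
  for j = 1:size(C, 1)
    Dref(i, j) = sum((X(i, :) - C(j, :)).^2);
  end
end

% Largest absolute disagreement between the two computations.
maxErr = max(abs(D(:) - Dref(:)));
```

With real-valued data, maxErr should be on the order of rounding error rather than exactly zero.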
The indices of the closest vectors in C can be retrieved from D:
[~, idx] = min(D, [], 2);
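For comparison, the literal "super-broadcast" subtraction from the question is also possible by moving the coordinate dimension out of the way with permute. This is an alternative sketch, not part of the answer above; the data values and the names Xp and Cp are assumptions chosen for illustration.

```octave
X = [0 0; 1 1; 4 4; 5 0];   % m x 2 (made-up data)
C = [0 1; 4 3; 5 1];        % 3 x 2 (made-up data)

% Move the coordinates to the third dimension so the first two can broadcast.
Xp = permute(X, [1 3 2]);   % m x 1 x 2
Cp = permute(C, [3 1 2]);   % 1 x 3 x 2

diffs = bsxfun(@minus, Xp, Cp);   % m x 3 x 2, diffs(i,j,:) = X(i,:) - C(j,:)
D = sum(diffs.^2, 3);             % m x 3 squared distances
[~, idx] = min(D, [], 2);         % idx(i) = row of C closest to X(i,:)
% For this data, idx is [1; 1; 2; 3].
```

This version materializes every pairwise difference, so its intermediate array grows with m times the number of rows in C times the dimension; the inner-product expansion avoids that intermediate, which is likely to matter for large inputs.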