memory error while performing matrix multiplication


Question

As part of a project I'm working on, I need to calculate the mean squared error between 2m vectors.

Basically, I have two matrices x and xhat, both of size m by n, and the vectors I'm interested in are the rows of these matrices.

I calculate the MSE with this code:

import numpy as np

def cost(x, xhat):  # mean squared error between x (the data) and xhat (the output of the machine)
    # m is the number of rows, defined elsewhere as a global
    return (1.0 / (2 * m)) * np.trace(np.dot(x - xhat, (x - xhat).T))

It's working correctly; the formula is right.

The problem is that in my specific case, my m and n are very large. Specifically, m = 60000 and n = 785. So when I run my code and it enters this function, I get a memory error.

Is there a better way to calculate the MSE? I'd rather avoid for loops, and I lean heavily towards matrix multiplication, but matrix multiplication seems extremely wasteful here. Maybe there's something in numpy I'm not aware of?

Answer

The expression np.dot(x-xhat,(x-xhat).T) creates an array with shape (m, m). You say m is 60000, so that array is almost 29 gigabytes.
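
A quick back-of-the-envelope check of that figure (a sketch, assuming the default float64 dtype):

```python
m = 60000
bytes_needed = m * m * 8   # an (m, m) array of float64, 8 bytes per entry
print(bytes_needed / 1e9)  # 28.8 (decimal gigabytes)
```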

You take the trace of the array, which is just the sum of the diagonal elements, so most of that huge array is unused. If you look carefully at np.trace(np.dot(x-xhat,(x-xhat).T)), you'll see that it is just the sum of the squares of all the elements of x - xhat. So a simpler way to compute np.trace(np.dot(x-xhat,(x-xhat).T)) that doesn't require the huge intermediate array is ((x - xhat)**2).sum(). For example,

In [44]: x
Out[44]: 
array([[ 0.87167186,  0.96838389,  0.72545457],
       [ 0.05803253,  0.57355625,  0.12732163],
       [ 0.00874702,  0.01555692,  0.76742386],
       [ 0.4130838 ,  0.89307633,  0.49532327],
       [ 0.15929044,  0.27025289,  0.75999848]])

In [45]: xhat
Out[45]: 
array([[ 0.20825392,  0.63991699,  0.28896932],
       [ 0.67658621,  0.64919721,  0.31624655],
       [ 0.39460861,  0.33057769,  0.24542263],
       [ 0.10694332,  0.28030777,  0.53177585],
       [ 0.21066692,  0.53096774,  0.65551612]])

In [46]: np.trace(np.dot(x-xhat,(x-xhat).T))
Out[46]: 2.2352330441581061

In [47]: ((x - xhat)**2).sum()
Out[47]: 2.2352330441581061
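
Putting it together, a memory-efficient version of the cost function might look like this (a sketch; here m is taken from the array's shape rather than a global):

```python
import numpy as np

def cost(x, xhat):
    # mean squared error without the (m, m) intermediate array
    m = x.shape[0]
    return ((x - xhat) ** 2).sum() / (2.0 * m)
```

For the m = 60000, n = 785 case, the only temporary this creates is the (m, n) difference array, which is well under a gigabyte.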

For more ideas about computing the MSE, see the link provided by user1984065 in a comment.
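
As an aside (not part of the original answer), NumPy has a couple of other idioms that compute the same sum of squares; np.einsum even avoids materializing the squared array:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((5, 3))
xhat = rng.random((5, 3))
d = x - xhat

a = (d ** 2).sum()              # the approach from the answer
b = np.einsum('ij,ij->', d, d)  # same sum, without an intermediate d**2 array
c = np.linalg.norm(d) ** 2      # squared Frobenius norm

print(np.allclose(a, b), np.allclose(a, c))  # True True
```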
