numpy.linalg.solve的右侧超过三个维度 [英] numpy.linalg.solve with right-hand side of more than three dimensions

查看:447
本文介绍了numpy.linalg.solve的右侧超过三个维度的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试求解一个3x3矩阵a和任意形状(3, ...)的右侧b的方程组.如果b具有一维或二维,则numpy.linalg.solve可以解决问题.不过,它可以分解为更多尺寸:

I'm trying to solve an equation system with a 3x3 matrix a and a right hand side b of arbitrary shape (3, ...). If b has one or two dimensions, numpy.linalg.solve does the trick. It breaks down for more dimensions though:

import numpy

a = numpy.random.rand(3, 3)

b = numpy.random.rand(3)
numpy.linalg.solve(a, b)  # okay

b = numpy.random.rand(3, 4)
numpy.linalg.solve(a, b)  # okay

b = numpy.random.rand(3, 4, 5)
numpy.linalg.solve(a, b)  # ERR

ValueError: solve: Input operand 1 has a mismatch in its core 
dimension 0, with gufunc signature (m,m),(m,n)->(m,n) (size 5 is 
different from 3)

我希望形状为(3, 4, 5)的输出数组sol的解决方案对应于右侧b[:, i, j]sol[:, i, j].

I would have expected an output array sol of shape (3, 4, 5) with the solution corresponding to the right-hand side b[:, i, j] is sol[:, i, j].

关于如何最好地解决此问题的任何提示?

Any hint on how to best work around this?

推荐答案

暂时将b整形为(3, 20),求解线性系统,然后将结果数组整形为b的原始形状(3,4, 5):

Temporarily reshape b to (3, 20), solve the linear system, and then reshape the resultant array to original shape of b (3, 4, 5):

In [34]: a = numpy.random.rand(3, 3)
In [35]: b = numpy.random.rand(3, 4, 5)

In [36]: x = numpy.linalg.solve(a, b.reshape(b.shape[0], -1)).reshape(b.shape)

OR

使用b的第一个轴与第二个轴交换. "nofollow noreferrer"> np.swapaxes ,求解线性系统,然后还原轴:

Swap the first axis of b with the second using np.swapaxes, solve the linear system, and then restore the axes:

In [58]: x = np.swapaxes(np.linalg.solve(a, np.swapaxes(b, 0, 1)), 0, 1)


健全性检查:


Sanity Check:

In [38]: np.einsum('ij,jkl', a, x)
Out[38]: 
array([[[ 0.44859955,  0.22967928,  0.74336067,  0.47440575,  0.53798895],
        [ 0.80045696,  0.54138958,  0.89870834,  0.56862419,  0.28217437],
        [ 0.02093982,  0.78534718,  0.77208236,  0.41568151,  0.95100661],
        [ 0.03820421,  0.47067312,  0.71928294,  0.30852615,  0.64454321]],

       [[ 0.31757072,  0.30527186,  0.36768759,  0.95869289,  0.86601996],
        [ 0.60616508,  0.69927063,  0.53470332,  0.88906606,  0.76066344],
        [ 0.95411847,  0.51116677,  0.29338398,  0.04418815,  0.96210206],
        [ 0.23449429,  0.64159963,  0.7732404 ,  0.4314741 ,  0.81279619]],

       [[ 0.6399571 ,  0.57640652,  0.0186913 ,  0.66304489,  0.83372239],
        [ 0.28426522,  0.62367363,  0.37163699,  0.78217433,  0.90573787],
        [ 0.91066088,  0.06699638,  0.43079394,  0.00263537,  0.399102  ],
        [ 0.17711441,  0.48724858,  0.05526752,  0.34251648,  0.94059739]]])

In [39]: b
Out[39]: 
array([[[ 0.44859955,  0.22967928,  0.74336067,  0.47440575,  0.53798895],
        [ 0.80045696,  0.54138958,  0.89870834,  0.56862419,  0.28217437],
        [ 0.02093982,  0.78534718,  0.77208236,  0.41568151,  0.95100661],
        [ 0.03820421,  0.47067312,  0.71928294,  0.30852615,  0.64454321]],

       [[ 0.31757072,  0.30527186,  0.36768759,  0.95869289,  0.86601996],
        [ 0.60616508,  0.69927063,  0.53470332,  0.88906606,  0.76066344],
        [ 0.95411847,  0.51116677,  0.29338398,  0.04418815,  0.96210206],
        [ 0.23449429,  0.64159963,  0.7732404 ,  0.4314741 ,  0.81279619]],

       [[ 0.6399571 ,  0.57640652,  0.0186913 ,  0.66304489,  0.83372239],
        [ 0.28426522,  0.62367363,  0.37163699,  0.78217433,  0.90573787],
        [ 0.91066088,  0.06699638,  0.43079394,  0.00263537,  0.399102  ],
        [ 0.17711441,  0.48724858,  0.05526752,  0.34251648,  0.94059739]]])

使用 np.allclose() ,这样您就不必手动检查数字并检查,尤其是对于大型数组:

Use np.allclose() so that you don't have to manually going through the numbers and check, particularly for large arrays:

In [32]: b_ = np.einsum('ij,jkl', a, x)

In [33]: np.allclose(b, b_)
Out[33]: True

这篇关于numpy.linalg.solve的右侧超过三个维度的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆