Understanding PyTorch einsum


Question

I'm familiar with how einsum works in NumPy. A similar functionality is also offered by PyTorch: torch.einsum(). What are the similarities and differences, either in terms of functionality or performance? The information available at PyTorch documentation is rather scanty and doesn't provide any insights regarding this.

Answer

Since the description of einsum is skimpy in the torch documentation, I decided to write this post to document, compare, and contrast how torch.einsum() behaves compared to numpy.einsum().

Differences:

  • NumPy allows both lowercase and uppercase letters [a-zA-Z] for the "subscript string", whereas PyTorch allows only lowercase letters [a-z].

  • NumPy accepts nd-arrays, plain Python lists (or tuples), lists of lists (or tuples of tuples, lists of tuples, tuples of lists), and even PyTorch tensors as operands (i.e. inputs). This is because the operands only have to be array_like, not strictly NumPy nd-arrays. On the contrary, PyTorch expects the operands (i.e. inputs) to be strictly PyTorch tensors; it will throw a TypeError if you pass plain Python lists/tuples (or combinations thereof) or NumPy nd-arrays.

  • NumPy supports many keyword arguments (e.g. optimize) in addition to nd-arrays, while PyTorch doesn't offer such flexibility yet. The second and third points are illustrated by the sketch just after this list.
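A minimal sketch of the operand and keyword-argument differences (assuming recent NumPy and PyTorch installs; the names a, b, ta, tb, x, y are placeholders introduced only for this illustration):

import numpy as np
import torch

# NumPy's einsum accepts any array_like operands, e.g. plain Python lists
a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(np.einsum('ij, jk -> ik', a, b))                         # [[19 22], [43 50]]

# ... and extra keyword arguments such as optimize
x = np.random.rand(10, 20)
y = np.random.rand(20, 30)
print(np.einsum('ij, jk -> ik', x, y, optimize=True).shape)    # (10, 30)

# torch.einsum expects torch tensors; passing the plain lists above
# would raise a TypeError, so convert them first
ta, tb = torch.tensor(a), torch.tensor(b)
print(torch.einsum('ij, jk -> ik', ta, tb))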

Here are implementations of some examples in both PyTorch and NumPy:

# input tensors to work with

In [16]: vec
Out[16]: tensor([0, 1, 2, 3])

In [17]: aten
Out[17]: 
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])

In [18]: bten
Out[18]: 
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
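
The answer doesn't show how these inputs were created; one possible construction, consistent with the printed values (an assumption made only for reproducibility), is:

import torch

# assumed construction of the inputs, matching the values printed above
vec  = torch.arange(4)
aten = torch.tensor([[11, 12, 13, 14],
                     [21, 22, 23, 24],
                     [31, 32, 33, 34],
                     [41, 42, 43, 44]])
bten = torch.tensor([[1, 1, 1, 1],
                     [2, 2, 2, 2],
                     [3, 3, 3, 3],
                     [4, 4, 4, 4]])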


1) Matrix multiplication
PyTorch: torch.matmul(aten, bten); aten.mm(bten)
NumPy : np.einsum("ij, jk -> ik", arr1, arr2)

In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]: 
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

2) Extract elements along the main diagonal
PyTorch: torch.diag(aten)
NumPy : np.einsum("ii -> i", arr)

In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])

3) Hadamard product (i.e. element-wise product of two tensors)
PyTorch: aten * bten
NumPy : np.einsum("ij, ij -> ij", arr1, arr2)

In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]: 
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])

4) Element-wise squaring
PyTorch: aten ** 2
NumPy : np.einsum("ij, ij -> ij", arr, arr)

In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]: 
tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

General: An element-wise nth power can be implemented by repeating the subscript string and the tensor n times. For example, the element-wise 4th power of a tensor can be computed using:

# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]: 
tensor([[  14641,   20736,   28561,   38416],
        [ 194481,  234256,  279841,  331776],
        [ 923521, 1048576, 1185921, 1336336],
        [2825761, 3111696, 3418801, 3748096]])

5) Trace (i.e. sum of the main-diagonal elements)
PyTorch: torch.trace(aten)
NumPy einsum: np.einsum("ii -> ", arr)

In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)

6) Matrix transpose
PyTorch: torch.transpose(aten, 1, 0)
NumPy einsum: np.einsum("ij -> ji", arr)

In [58]: torch.einsum('ij -> ji', aten)
Out[58]: 
tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])

7) Outer product (of vectors)
PyTorch: torch.ger(vec, vec)
NumPy einsum: np.einsum("i, j -> ij", vec, vec)

In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]: 
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])
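
Note: in more recent PyTorch releases torch.ger is deprecated in favor of torch.outer; a minimal equivalent sketch (assuming a PyTorch version that provides torch.outer):

# outer product via the newer torch.outer, checked against einsum
assert torch.equal(torch.outer(vec, vec),
                   torch.einsum('i, j -> ij', vec, vec))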

8) Inner product (of vectors)
PyTorch: torch.dot(vec1, vec2)
NumPy einsum: np.einsum("i, i -> ", vec1, vec2)

In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)

9) Sum along axis 0
PyTorch: torch.sum(aten, 0)
NumPy einsum: np.einsum("ij -> j", arr)

In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])

10) Sum along axis 1
PyTorch: torch.sum(aten, 1)
NumPy einsum: np.einsum("ij -> i", arr)

In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50,  90, 130, 170])

11) Batch matrix multiplication
PyTorch: torch.bmm(batch_tensor_1, batch_tensor_2)
NumPy : np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4) 

In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)  
Out[15]: 
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape 
Out[16]: torch.Size([2, 4, 4])

# batch matrix multiply using einsum
In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
Out[17]: 
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape
Out[18]: torch.Size([2, 4, 4])

12) Sum along axis 2
PyTorch: torch.sum(batch_ten, 2)
NumPy einsum: np.einsum("ijk -> ij", arr3D)
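
batch_ten is not defined in this excerpt; a definition consistent with the outputs below (an assumption) is to stack the two matrices introduced earlier:

# assumed definition: stacking aten and bten gives a (2, 4, 4) batch whose
# row sums and total match Out[99] and Out[101] below
batch_ten = torch.stack((aten, bten))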

In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]: 
tensor([[ 50,  90, 130, 170],
        [  4,   8,  12,  16]])

13) Sum over all elements in an nD tensor
PyTorch: torch.sum(batch_ten)
NumPy einsum: np.einsum("ijk -> ", arr3D)

In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)

14) Sum over multiple axes (i.e. marginalization)
PyTorch: torch.sum(arr, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))
NumPy: np.einsum("ijklmnop -> n", nDarr)

# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])

# marginalize out dimension 5 (i.e. "n" here)
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([  98.6921, -206.0575])

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))

In [115]: torch.allclose(tsum, esum)
Out[115]: True

15) Double dot product / Frobenius inner product (same as torch.sum of the Hadamard product; cf. 3)
PyTorch: torch.sum(aten * bten)
NumPy : np.einsum("ij, ij -> ", arr1, arr2)

In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)
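
As a sketch (assuming aten and bten from above), the same value also falls out of summing the Hadamard product from example 3:

# Frobenius inner product two ways; both evaluate to 1300 here
assert torch.equal(torch.einsum('ij, ij -> ', aten, bten),
                   torch.sum(aten * bten))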
