使用 NumPy 进行快速张量旋转 [英] Fast tensor rotation with NumPy

查看:30
本文介绍了使用 NumPy 进行快速张量旋转的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在应用程序的核心(用 Python 编写并使用 NumPy)我需要旋转一个四阶张量.实际上,我需要多次旋转很多张量,这是我的瓶颈.我涉及八个嵌套循环的幼稚实现(如下)似乎很慢,但我看不出有什么方法可以利用 NumPy 的矩阵运算,并希望能加快速度.我觉得我应该使用 np.tensordot,但我不知道如何使用.

At the heart of an application (written in Python and using NumPy) I need to rotate a 4th order tensor. Actually, I need to rotate a lot of tensors many times and this is my bottleneck. My naive implementation (below) involving eight nested loops seems to be quite slow, but I cannot see a way to leverage NumPy's matrix operations and, hopefully, speed things up. I've a feeling I should be using np.tensordot, but I don't see how.

在数学上,旋转张量 T' 的元素由下式给出: T'ijkl = Σgia gjb gkc gld Tabcd 总和结束右侧的重复索引.T 和 Tprime 是 3*3*3*3 NumPy 数组,旋转矩阵 g 是 3*3 NumPy 数组.我的缓慢实现(每次调用大约需要 0.04 秒)如下.

Mathematically, elements of the rotated tensor, T', are given by: T'ijkl = Σ gia gjb gkc gld Tabcd with the sum being over the repeated indices on the right hand side. T and Tprime are 3*3*3*3 NumPy arrays and the rotation matrix g is a 3*3 NumPy array. My slow implementation (taking ~0.04 seconds per call) is below.

#!/usr/bin/env python

import numpy as np

def rotT(T, g):
    Tprime = np.zeros((3,3,3,3))
    for i in range(3):
        for j in range(3):
            for k in range(3):
                for l in range(3):
                    for ii in range(3):
                        for jj in range(3):
                            for kk in range(3):
                                for ll in range(3):
                                    gg = g[ii,i]*g[jj,j]*g[kk,k]*g[ll,l]
                                    Tprime[i,j,k,l] = Tprime[i,j,k,l] + 
                                         gg*T[ii,jj,kk,ll]
    return Tprime

if __name__ == "__main__":

    T = np.array([[[[  4.66533067e+01,  5.84985000e-02, -5.37671310e-01],
                    [  5.84985000e-02,  1.56722231e+01,  2.32831900e-02],
                    [ -5.37671310e-01,  2.32831900e-02,  1.33399259e+01]],
                   [[  4.60051700e-02,  1.54658176e+01,  2.19568200e-02],
                    [  1.54658176e+01, -5.18223500e-02, -1.52814920e-01],
                    [  2.19568200e-02, -1.52814920e-01, -2.43874100e-02]],
                   [[ -5.35577630e-01,  1.95558600e-02,  1.31108757e+01],
                    [  1.95558600e-02, -1.51342210e-01, -6.67615000e-03],
                    [  1.31108757e+01, -6.67615000e-03,  6.90486240e-01]]],
                  [[[  4.60051700e-02,  1.54658176e+01,  2.19568200e-02],
                    [  1.54658176e+01, -5.18223500e-02, -1.52814920e-01],
                    [  2.19568200e-02, -1.52814920e-01, -2.43874100e-02]],
                   [[  1.57414726e+01, -3.86167500e-02, -1.55971950e-01],
                    [ -3.86167500e-02,  4.65601977e+01, -3.57741000e-02],
                    [ -1.55971950e-01, -3.57741000e-02,  1.34215636e+01]],
                   [[  2.58256300e-02, -1.49072770e-01, -7.38843000e-03],
                    [ -1.49072770e-01, -3.63410500e-02,  1.32039847e+01],
                    [ -7.38843000e-03,  1.32039847e+01,  1.38172700e-02]]],
                  [[[ -5.35577630e-01,  1.95558600e-02,  1.31108757e+01],
                    [  1.95558600e-02, -1.51342210e-01, -6.67615000e-03],
                    [  1.31108757e+01, -6.67615000e-03,  6.90486240e-01]],
                   [[  2.58256300e-02, -1.49072770e-01, -7.38843000e-03],
                    [ -1.49072770e-01, -3.63410500e-02,  1.32039847e+01],
                    [ -7.38843000e-03,  1.32039847e+01,  1.38172700e-02]],
                   [[  1.33639532e+01, -1.26331100e-02,  6.84650400e-01],
                    [ -1.26331100e-02,  1.34222177e+01,  1.67851800e-02],
                    [  6.84650400e-01,  1.67851800e-02,  4.89151396e+01]]]])

    g = np.array([[ 0.79389393,  0.54184237,  0.27593346],
                  [-0.59925749,  0.62028664,  0.50609776],
                  [ 0.10306737, -0.56714313,  0.8171449 ]])

    for i in range(100):
        Tprime = rotT(T,g)

有没有办法让它更快?将代码推广到其他张量等级会很有用,但不太重要.

Is there a way to make this go faster? Making the code generalise to other ranks of tensor would be useful, but is less important.

推荐答案

要使用 tensordot,计算 g 张量的外积:

To use tensordot, compute the outer product of the g tensors:

def rotT(T, g):
    gg = np.outer(g, g)
    gggg = np.outer(gg, gg).reshape(4 * g.shape)
    axes = ((0, 2, 4, 6), (0, 1, 2, 3))
    return np.tensordot(gggg, T, axes)

在我的系统上,这大约比 Sven 的解决方案快七倍.如果 g 张量不经常变化,您还可以缓存 gggg 张量.如果您这样做并打开一些微优化(内联 tensordot 代码,没有检查,没有通用形状),您仍然可以使其速度提高两倍:

On my system, this is around seven times faster than Sven's solution. If the g tensor doesn't change often, you can also cache the gggg tensor. If you do this and turn on some micro-optimizations (inlining the tensordot code, no checks, no generic shapes), you can still make it two times faster:

def rotT(T, gggg):
    return np.dot(gggg.transpose((1, 3, 5, 7, 0, 2, 4, 6)).reshape((81, 81)),
                  T.reshape(81, 1)).reshape((3, 3, 3, 3))

timeit 在我家用笔记本电脑上的结果(500 次迭代):

Results of timeit on my home laptop (500 iterations):

Your original code: 19.471129179
Sven's code: 0.718412876129
My first code: 0.118047952652
My second code: 0.0690279006958

我的工作机器上的数字是:

The numbers on my work machine are:

Your original code: 9.77922987938
Sven's code: 0.137110948563
My first code: 0.0569641590118
My second code: 0.0308079719543

这篇关于使用 NumPy 进行快速张量旋转的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆