numpy.matmul in Theano

Question

TL;DR
I want to replicate the functionality of numpy.matmul in Theano. What's the best way to do this?

Too Short; Didn't Understand
Looking at theano.tensor.dot and theano.tensor.tensordot, I'm not seeing an easy way to do a straightforward batch matrix multiplication, i.e. treating the last two dimensions of N-dimensional tensors as matrices and multiplying them. Do I need to resort to some goofy usage of theano.tensor.batched_dot? Or *shudder* loop over them myself without broadcasting!?
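
For reference, this is the behaviour being asked for. The sketch below is plain NumPy, not Theano: the helper loop_matmul is hypothetical and only spells out what numpy.matmul computes when both inputs have at least two dimensions (NumPy's special handling of 1-D inputs is left out). The leading dimensions are broadcast against each other, and each pair of trailing 2-D slices is multiplied with an ordinary matrix product.

```python
import itertools

import numpy as np


def loop_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Explicit-loop reference for np.matmul, assuming a.ndim >= 2 and b.ndim >= 2."""
    # Broadcast only the leading ("batch") dimensions; the last two are matrices.
    batch = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a_b = np.broadcast_to(a, batch + a.shape[-2:])
    b_b = np.broadcast_to(b, batch + b.shape[-2:])
    out = np.empty(batch + (a.shape[-2], b.shape[-1]), dtype=np.result_type(a, b))
    # Multiply each pair of trailing 2-D slices with an ordinary matrix product.
    for idx in itertools.product(*(range(n) for n in batch)):
        out[idx] = np.dot(a_b[idx], b_b[idx])
    return out


rng = np.random.default_rng(0)
a = rng.random((2, 1, 3, 4))  # batch shape (2, 1), matrices 3x4
b = rng.random((5, 4, 6))     # batch shape (5,), matrices 4x6
print(np.matmul(a, b).shape)                            # (2, 5, 3, 6)
print(np.allclose(loop_matmul(a, b), np.matmul(a, b)))  # True
```

Here a and b differ both in their number of dimensions and in the sizes of their leading axes; that is the broadcasting a Theano replacement has to reproduce.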

Answer

The current pull requests don't support broadcasting, so I came up with this for now. I may clean it up, add a little more functionality, and submit my own PR as a temporary solution. Until then, I hope this helps someone! I included a test to show that it replicates numpy.matmul, provided the inputs satisfy my stricter (temporary) assertions: both tensors must have the same number of dimensions, and at least two.
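
The same-ndim restriction is less limiting than it looks, because NumPy's broadcasting rule starts by left-padding the lower-rank operand's shape with size-1 axes. A minimal NumPy sketch of that padding step follows; pad_to_same_ndim is a hypothetical helper, not part of either library. In Theano the analogous step would presumably be a dimshuffle with 'x' entries, e.g. b.dimshuffle('x', 0, 1) for a matrix b, which adds a broadcastable leading axis that the matmul below can hold steady.

```python
import numpy as np


def pad_to_same_ndim(a: np.ndarray, b: np.ndarray) -> tuple:
    """Hypothetical helper: prepend size-1 axes to the lower-rank operand,
    mirroring the first step of NumPy's broadcasting rule."""
    nd = max(a.ndim, b.ndim)
    a = a.reshape((1,) * (nd - a.ndim) + a.shape)
    b = b.reshape((1,) * (nd - b.ndim) + b.shape)
    return a, b


a = np.arange(24.0).reshape(2, 3, 4)
b = np.arange(20.0).reshape(4, 5)
a2, b2 = pad_to_same_ndim(a, b)
print(a2.shape, b2.shape)                               # (2, 3, 4) (1, 4, 5)
print(np.allclose(np.matmul(a2, b2), np.matmul(a, b)))  # True
```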

Also, th.scan stops iterating the sequences after min(*sequence_lengths) iterations, i.e. at the length of the shortest sequence. So I believe mismatched array shapes won't raise any exceptions; the longer input is silently truncated instead.
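
To make that stopping rule concrete, here is a NumPy imitation of it. This is not Theano code: lockstep_dot is a hypothetical stand-in for th.scan(fn, sequences=[a, b]), written under the assumption (as stated above) that scan runs for as many steps as the shortest sequence has elements.

```python
import numpy as np


def lockstep_dot(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for th.scan(fn, sequences=[a, b]): walk both inputs
    along axis 0 and stop after min(len(a), len(b)) steps."""
    n_steps = min(a.shape[0], b.shape[0])
    return np.stack([np.dot(a[i], b[i]) for i in range(n_steps)])


a = np.ones((4, 2, 3))
b = np.ones((3, 3, 5))
print(lockstep_dot(a, b).shape)  # (3, 2, 5): the fourth slice of a is silently dropped

try:
    np.matmul(a, b)
except ValueError:
    print("np.matmul rejects batch sizes 4 and 3")
```

np.matmul refuses these shapes because leading sizes 4 and 3 do not broadcast, whereas the lockstep loop quietly returns a shorter result; that difference is why mismatched shapes deserve an explicit check.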

import theano as th
import theano.tensor as tt
import numpy as np


def matmul(a: tt.TensorVariable, b: tt.TensorVariable) -> tt.TensorVariable:
    """Replicates the functionality of numpy.matmul, except that
    the two tensors must have the same number of dimensions, and their ndim must be at least 2."""

    # TODO ensure that broadcastability is maintained if both a and b are broadcastable on a dim.

    assert a.ndim == b.ndim  # TODO support broadcasting for differing ndims.
    ndim = a.ndim
    assert ndim >= 2

    # If a and b are 2 dimensional, compute their matrix product.
    if ndim == 2:
        return tt.dot(a, b)

    # Otherwise, peel off the leading dimension with scan and recurse.
    if a.broadcastable[0] and not b.broadcastable[0]:
        # a is broadcastable but b is not: scan over b, holding a[0] steady.
        # scan passes the sequence element first and the non-sequence second,
        # so the lambda restores the (a, b) order to keep the matrix orientation.
        output, _ = th.scan(lambda b_i, a_0: matmul(a_0, b_i),
                            sequences=[b], non_sequences=[a[0]])
    elif b.broadcastable[0] and not a.broadcastable[0]:
        # b is broadcastable but a is not: scan over a, holding b[0] steady.
        output, _ = th.scan(matmul, sequences=[a], non_sequences=[b[0]])
    else:
        # Neither dimension is broadcastable, or both are: scan through both in
        # lockstep, assuming the shape for this dimension is equal.
        output, _ = th.scan(matmul, sequences=[a, b])
    return output


def matmul_test() -> bool:
    flist = []
    ndlist = []
    for _ in range(28):
        dims = int(np.random.random() * 4 + 2)  # 2 to 5 dimensions

        # Create a pair of symbolic tensors with random broadcastability on the
        # leading dimensions; the last two dimensions are always full matrices.
        vs = tuple(
            tt.TensorType('float64',
                          tuple(bool(p < .3) for p in np.random.random(dims - 2))
                          + (False, False))()
            for _ in range(2)
        )

        # Calling the TensorType creates fresh root variables usable as function inputs.
        f = th.function(list(vs), matmul(*vs))

        # Create the default shape for the test ndarrays.
        defshape = tuple(int(np.random.random() * 5 + 1) for _ in range(dims))
        # Create a test array matching the broadcastability of each v, for each v.
        nds = tuple(
            np.random.random(
                tuple(1 if v.broadcastable[j] else s for j, s in enumerate(defshape))
            )
            for v in vs
        )
        # Transpose the last two axes of the first array so the inner dimensions
        # match: (..., n, m) @ (..., m, n).
        nds = (np.swapaxes(nds[0], -2, -1), nds[1])

        ndlist.append(nds)
        flist.append(f)

    for f, nds in zip(flist, ndlist):
        assert np.allclose(f(*nds), np.matmul(*nds))

    return True


if __name__ == "__main__":
    print("matmul_test -> " + str(matmul_test()))
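
If Theano is not at hand, the control flow of matmul can still be checked against np.matmul with the NumPy transliteration below. It is a sketch, not the answer's code: recursive_matmul is a hypothetical name, a size-1 axis stands in for a broadcastable flag (in Theano broadcastability is static type information, whereas here it is read off the runtime shape), and a Python loop stands in for th.scan.

```python
import numpy as np


def recursive_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """NumPy transliteration of the recursive matmul above (same-ndim inputs only)."""
    assert a.ndim == b.ndim and a.ndim >= 2
    if a.ndim == 2:
        return np.dot(a, b)
    if a.shape[0] == 1 and b.shape[0] != 1:
        # Hold a[0] steady while iterating over b.
        return np.stack([recursive_matmul(a[0], b_i) for b_i in b])
    if b.shape[0] == 1 and a.shape[0] != 1:
        # Hold b[0] steady while iterating over a.
        return np.stack([recursive_matmul(a_i, b[0]) for a_i in a])
    # Iterate over both in lockstep; like scan, zip stops at the shorter input.
    return np.stack([recursive_matmul(a_i, b_i) for a_i, b_i in zip(a, b)])


rng = np.random.default_rng(1)
for a_shape, b_shape in [((1, 3, 2, 4), (5, 3, 4, 6)),
                         ((2, 1, 2, 4), (2, 7, 4, 3)),
                         ((3, 4), (4, 2))]:
    x, y = rng.random(a_shape), rng.random(b_shape)
    print(a_shape, b_shape, np.allclose(recursive_matmul(x, y), np.matmul(x, y)))
```

Because size-1 axes are detected at run time here, this version accepts any same-ndim inputs whose leading axes are broadcast-compatible; the Theano version can only hold an operand steady when its type declares that axis broadcastable.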
