Theano广播与numpy的广播不同 [英] Theano broadcasting different to numpy's

查看:79
本文介绍了Theano广播与numpy的广播不同的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

请考虑以下numpy广播示例:

Consider the following example of numpy broadcasting:

import numpy as np
import theano
from theano import tensor as T

xval = np.array([[1, 2, 3], [4, 5, 6]])
bval = np.array([[10, 20, 30]])
print xval + bval

如预期的那样,向量bval被添加到矩阵xval的每一行中,输出为:

As expected, the vector bval is added to each rows of the matrix xval and the output is:

[[11 22 33]
 [14 25 36]]

尝试在theano的git版本中复制相同的行为:

Trying to replicate the same behaviour in the git version of theano:

x = T.dmatrix('x')
b = theano.shared(bval)
z = x + b
f = theano.function([x], z)

print f(xval)

我收到以下错误:

ValueError: Input dimension mis-match. (input[0].shape[0] = 2, input[1].shape[0] = 1)
Apply node that caused the error: Elemwise{add,no_inplace}(x, <TensorType(int64, matrix)>)
Inputs types: [TensorType(float64, matrix), TensorType(int64, matrix)]
Inputs shapes: [(2, 3), (1, 3)]
Inputs strides: [(24, 8), (24, 8)]
Inputs scalar values: ['not scalar', 'not scalar']

我了解Tensor对象(例如x)具有broadcastable属性,但是我找不到找到以下方法的方法:1)为shared对象正确设置此属性,或2)正确推断出该属性.如何在theano中重新实现numpy的行为?

I understand Tensor objects such as x have a broadcastable attribute, but I can't find a way to 1) set this correctly for the shared object or 2) have it correctly inferred. How can I re-implement numpy's behaviour in theano?

推荐答案

Theano需要在编译之前在图中声明所有可广播的维. NumPy使用运行时形状信息.

Theano need all broadcastable dimensions to be declared in the graph before compilation. NumPy use the run time shape information.

默认情况下,所有共享变量的尺寸均不可广播,因为它们的形状可能会发生变化.

By default, all shared variable dimsions aren't broadcastable, as their shape could change.

要创建示例中所需的具有可广播维度的共享变量,请执行以下操作:

To create the shared variable with the broadcastable dimension that you need in your example:

b = theano.shared(bval, broadcastable=(True,False))

我会将这些信息添加到文档中.

I'll add this information to the documentation.

这篇关于Theano广播与numpy的广播不同的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆