如何在lambda层中获取批处理大小 [英] How to get the batch size inside lambda layer

查看:84
本文介绍了如何在lambda层中获取批处理大小的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试实现一个通过lambda层执行以下numpy过程的层:

I'm trying to implement a layer (via lambda layer) which is doing the following numpy procedure:

def func(x, n):
    return np.concatenate((x[:, :n], np.tile(x[:, n:].mean(axis = 0), (x.shape[0], 1))), axis = 1)

我被困住了,因为我不知道如何获得x的第一维的大小(即批处理大小).后端函数int_shape(x)返回(None, ...).

I'm stuck because I don't know how to get the size of the first dimension of x (which is the batch size). The backend function int_shape(x) returns (None, ...).

因此,如果我知道batch_size,则对应的Keras过程将是:

So, if I know the batch_size, the corresponding Keras procedure would be:

def func(x, n):
    return K.concatenate([x[:, :n], K.tile(K.mean(x[:, n:], axis=0), [batch_size, 1])], axis = 1)

推荐答案

正如@pitfall所说,K.tile的第二个参数应该是张量. 根据 keras后端文档K.shape返回张量,K.int_shape返回一个元组int或无条目.因此正确的方法是使用K.shape.以下是MWE:

Just as @pitfall says, the second argument of K.tile should be a tensor. And according to the doc of keras backend, K.shape returns a tensor and K.int_shape returns a tuple of int or None entries. So the correct way is to use K.shape. Following is the MWE:

import keras.backend as K
from keras.layers import Input, Lambda
from keras.models import Model
import numpy as np

batch_size = 8
op_len = ip_len = 10

def func(X):
    return K.tile(K.mean(X, axis=0, keepdims=True), (K.shape(X)[0], 1))

ip = Input((ip_len,))
lbd = Lambda(lambda x:func(x))(ip)

model = Model(ip, lbd)
model.summary()

model.compile('adam', loss='mse')

X = np.random.randn(batch_size*100, ip_len)
Y = np.random.randn(batch_size*100, op_len)
#no parameters to train!
#model.fit(X,Y,batch_size=batch_size)

#prediction
np_result = np.tile(np.mean(X[:batch_size], axis=0, keepdims=True), 
                    (batch_size,1))
pred_result = model.predict(X[:batch_size])
print(np.allclose(np_result, pred_result))

这篇关于如何在lambda层中获取批处理大小的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆