Keras中具有附加动态参数的自定义自适应损失函数 [英] Custom adaptive loss function with additional dynamic argument in Keras

查看:72
本文介绍了Keras中具有附加动态参数的自定义自适应损失函数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我必须使用自适应自定义损失函数,该函数在keras中需要一个附加的动态参数( eps ).参数 eps 是一个标量,但从一个批次更改为另一个批次:因此,在训练过程中应调整损失函数.我使用生成器,并且可以在训练过程中通过生成器的每个调用传递此参数( generator_train [2] ).基于对类似问题的回答,我尝试编写以下包装:

I have to use an adaptive custom loss function that takes an additional dynamic argument (eps) in keras. The argument eps is a scalar but changes from one batch to the other : the loss function should be therefore adapted during training. I use a generator and I can pass this argument through every call of the generator during training (generator_train[2]). Based on answers to similar questions I tried to write the following wrapping:

def custom_loss(eps):
    def square_err(y_true, y_pred):
        nom = K.sum(K.square(y_pred - y_true), axis=-1)
        denom = eps**2
        loss = nom/denom
        return loss
    return square_err

但是我很难实现它,因为 eps 是一个动态变量:我不知道在训练过程中应如何将此参数传递给损失函数( model.fit ).这是我模型的一个简单版本:

But I am struggling with implementing it since eps is a dynamic variable: I don't know how I should pass this argument to the loss function during training (model.fit). Here is a simple version of my model:

model = keras.Sequential()
model.add(layers.LSTM(units=32, input_shape=(32, 4))
model.add(layers.Dense(units=1))
model.add_loss(custom_loss)
opt = keras.optimizers.Adam()
model.compile(optimizer=opt)
history = model.fit(x=generator_train[0], y=generator_train[1],
                    steps_per_epoch=100
                    epochs=50,
                    validation_data=gen_vl,
                    validation_steps=n_vl)

非常感谢您的帮助.

推荐答案

简单地传递样本权重",每个样本的权重将为 1/(eps ** 2).

Simply pass "sample weights", which will be 1/(eps**2) for each sample.

您的生成器应该只输出 x,y,sample_weights ,仅此而已.

Your generator should just output x, y, sample_weights and that's all.

您的损失可能是:

def loss(y_true, y_pred):
    return K.sum(K.square(y_pred - y_true), axis=-1)

fit 中,您不能在生成器中使用索引,您将只传递 generator_train ,不传递 x ,不传递 y ,只需 generator_train .

In fit, you cannot use indexing in the generator, you will pass just generator_train, no x, no y, just generator_train.

这篇关于Keras中具有附加动态参数的自定义自适应损失函数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆