Keras模型-在自定义损失函数中获取输入 [英] Keras Model - Get input in custom loss function

查看:276
本文介绍了Keras模型-在自定义损失函数中获取输入的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在使用Keras Custom丢失功能时遇到了麻烦.我希望能够以numpy数组访问 truth . 因为它是一个回调函数,所以我认为我不在急切执行中,这意味着我无法使用backend.get_value()函数进行访问.我也尝试了不同的方法,但总会回到这个'Tensor'对象不存在的事实.

I am having trouble with Keras Custom loss function. I want to be able to access truth as a numpy array. Because it is a callback function, I think I am not in eager execution, which means I can't access it using the backend.get_value() function. i also tried different methods, but it always comes back to the fact that this 'Tensor' object doesn't exist.

我需要在自定义损失函数中创建一个会话吗?

Do I need to create a session inside the custom loss function ?

我正在使用最新的Tensorflow 2.2.

I am using Tensorflow 2.2, which is up to date.

def custom_loss(y_true, y_pred):

    # 4D array that has the label (0) and a multiplier input dependant
    truth = backend.get_value(y_true)

    loss = backend.square((y_pred - truth[:,:,0]) * truth[:,:,1])
    loss = backend.mean(loss, axis=-1)  

    return loss

 model.compile(loss=custom_loss, optimizer='Adam')
 model.fit(X, np.stack(labels, X[:, 0], axis=3), batch_size = 16)


我希望能够访问真相.它有两个组件(标签,乘数,每个项目都不相同.我看到了一个依赖于输入的解决方案,但是我不确定如何访问该值.

I want to be able to access truth. It has two components (Label, Multiplier that his different for each item. I saw a solution that is input dependant, but I am not sure how to access the value. Custom loss function in Keras based on the input data

推荐答案

我认为您可以通过如下所示在model.compile中启用run_eagerly=True来做到这一点.

I think you can do this by enabling run_eagerly=True in model.compile as shown below.

model.compile(loss=custom_loss(weight_building, weight_space),optimizer=keras.optimizers.Adam(), metrics=['accuracy'],run_eagerly=True)

我认为您还需要更新custom_loss,如下所示.

I think you also need to update custom_loss as shown below.

def custom_loss(weight_building, weight_space):
  def loss(y_true, y_pred):
    truth = backend.get_value(y_true)
    error = backend.square((y_pred - y_true))
    mse_error = backend.mean(error, axis=-1) 
    return mse_error
  return loss

我用一个简单的mnist数据演示了这个想法.请在此处查看代码.

I am demonstrating the idea with a simple mnist data. Please take a look at the code here.

这篇关于Keras模型-在自定义损失函数中获取输入的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆