如何在TensorFlow和Keras的损失函数中打印中间变量? [英] How can I print the intermediate variables in the loss function in TensorFlow and Keras?

查看:772
本文介绍了如何在TensorFlow和Keras的损失函数中打印中间变量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在编写一个自定义目标来训练Keras(带有TensorFlow后端)模型,但是我需要调试一些中间计算.为了简单起见,假设我有:

def custom_loss(y_pred, y_true):
    diff = y_pred - y_true
    return K.square(diff)

我找不到一种简单的方法,例如在训练过程中访问中间变量diff或它的形状.在这个简单的示例中,我知道我可以返回diff来打印其值,但是我的实际损失更为复杂,并且在没有得到编译错误的情况下我无法返回中间值.

有没有一种简单的方法可以在Keras中调试中间变量?

解决方案

据我所知,这不是在Keras中解决的问题,因此您必须诉诸于特定于后端的功能. Theano Lambda 层中.

I'm writing a custom objective to train a Keras (with TensorFlow backend) model but I need to debug some intermediate computation. For simplicity, let's say I have:

def custom_loss(y_pred, y_true):
    diff = y_pred - y_true
    return K.square(diff)

I could not find an easy way to access, for example, the intermediate variable diff or its shape during training. In this simple example, I know that I could return diff to print its values, but my actual loss is more complex and I can't return intermediate values without getting compiling errors.

Is there an easy way to debug intermediate variables in Keras?

解决方案

This is not something that is solved in Keras as far as I know, so you have to resort to backend-specific functionality. Both Theano and TensorFlow have Print nodes that are identity nodes (i.e., they return the input node) and have the side-effect of printing the input (or some tensor of the input).

Example for Theano:

diff = y_pred - y_true
diff = theano.printing.Print('shape of diff', attrs=['shape'])(diff)
return K.square(diff)

Example for TensorFlow:

diff = y_pred - y_true
diff = tf.Print(diff, [tf.shape(diff)])
return K.square(diff)

Note that this only works for intermediate values. Keras expects tensors that are passed to other layers to have specific attributes such as _keras_shape. Values processed by the backend, i.e. through Print, usually do not have that attribute. To solve this, you can wrap debug statements in a Lambda layer for example.

这篇关于如何在TensorFlow和Keras的损失函数中打印中间变量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆