Tensorflow Keras:评估时如何在自定义层中设置断点(调试)? [英] Keras, Tensorflow: How to set breakpoint (debug) in custom layer when evaluating?

查看:641
本文介绍了Tensorflow Keras:评估时如何在自定义层中设置断点(调试)?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我只想在自定义层中进行一些数值验证.

I just want to do some numerical validation inside the custom layer.

假设我们有一个非常简单的自定义层:

Suppose we have a very simple custom layer:

class test_layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = K.variable(1.)
        self._trainable_weights.append(self.w)
        super(test_layer, self).build(input_shape)

    def call(self, x, **kwargs):
        m = x * x            # Set break point here
        n = self.w * K.sqrt(x)
        return m + n

和主程序:

import tensorflow as tf
import keras
import keras.backend as K

input = keras.layers.Input((100,1))
y = test_layer()(input)

model = keras.Model(input,y)
model.predict(np.ones((100,1)))

如果在行m = x * x上设置断点调试,则执行y = test_layer()(input)时程序将在此处暂停,这是因为生成了图形,因此调用了call()方法.

If I set a breakpoint debug at the line m = x * x, the program will pause here when executing y = test_layer()(input), this is because the graph is built, the call() method is called.

但是当我使用model.predict()赋予它真实的价值,并且想要在图层内部正常工作时,它不会在m = x * x

But when I use model.predict() to give it real value, and wanna look inside the layer if it work properly, it doesn't pause at the line m = x * x

我的问题是:

  1. 仅在构建计算图时才调用call()方法吗? (提供实际价值时不会调用它吗?)

  1. Is call() method only called when the computational graph is being built? (it won't be called when feeding real value?)

如何在层内调试(或在何处插入断点)以在输入实值时查看变量的值?

How to debug (or where to insert break point) inside a layer to see the value of variables when give it real value input?

推荐答案

  1. 是的. call()方法仅用于构建计算图.

  1. Yes. The call() method is only used to build the computational graph.

关于调试.我更喜欢使用TFDBG,这是张量流的推荐调试工具,尽管它不提供断点功能.

As to the debug. I prefer using TFDBG, which is a recommended debugging tool for tensorflow, although it doesn't provide break point functions.

对于Keras,您可以将这些行添加到脚本中以使用TFDBG

For Keras, you can add these line to your script to use TFDBG

import tf.keras.backend as K
from tensorflow.python import debug as tf_debug
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)

这篇关于Tensorflow Keras:评估时如何在自定义层中设置断点(调试)?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆