Keras,Tensorflow:评估时如何在自定义层中设置断点(调试)? [英] Keras, Tensorflow: How to set breakpoint (debug) in custom layer when evaluating?

查看:64
本文介绍了Keras,Tensorflow:评估时如何在自定义层中设置断点(调试)?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我只想在自定义层内做一些数值验证.

I just want to do some numerical validation inside the custom layer.

假设我们有一个非常简单的自定义层:

Suppose we have a very simple custom layer:

class test_layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = K.variable(1.)
        self._trainable_weights.append(self.w)
        super(test_layer, self).build(input_shape)

    def call(self, x, **kwargs):
        m = x * x            # Set break point here
        n = self.w * K.sqrt(x)
        return m + n

还有主程序:

import tensorflow as tf
import keras
import keras.backend as K

input = keras.layers.Input((100,1))
y = test_layer()(input)

model = keras.Model(input,y)
model.predict(np.ones((100,1)))

如果我在m = x * x这一行设置断点调试,程序在执行y = test_layer()(input)的时候会在这里暂停,这是因为图是建好的,所以调用了call()方法.

If I set a breakpoint debug at the line m = x * x, the program will pause here when executing y = test_layer()(input), this is because the graph is built, the call() method is called.

但是当我使用 model.predict() 来赋予它真正的价值,并想看看它是否正常工作时,它不会在 m = x 行处暂停* x

But when I use model.predict() to give it real value, and wanna look inside the layer if it work properly, it doesn't pause at the line m = x * x

我的问题是:

  1. call() 方法是否仅在构建计算图时调用?(喂真值时不会调用?)

  1. Is call() method only called when the computational graph is being built? (it won't be called when feeding real value?)

如何在层内调试(或在哪里插入断点)以在输入实值时查看变量的值?

How to debug (or where to insert break point) inside a layer to see the value of variables when give it real value input?

推荐答案

  1. 是的.call() 方法仅用于构建计算图.

  1. Yes. The call() method is only used to build the computational graph.

至于调试.我更喜欢使用 TFDBG,这是 tensorflow 的推荐调试工具,虽然它不提供断点功能.

As to the debug. I prefer using TFDBG, which is a recommended debugging tool for tensorflow, although it doesn't provide break point functions.

对于 Keras,您可以将这些行添加到您的脚本中以使用 TFDBG

For Keras, you can add these line to your script to use TFDBG

import tf.keras.backend as K
from tensorflow.python import debug as tf_debug
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)

这篇关于Keras,Tensorflow:评估时如何在自定义层中设置断点(调试)?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆