How many times the loss function is triggered from the .fit() method in Keras


Problem description

I am trying to do some custom calculations in a custom loss function. But when I log statements from the custom loss function, it seems that it is only called once (at the beginning of the .fit() method).

Example loss function:

import tensorflow as tf

def loss(y_true, y_pred):
    print("--- Starting of the loss function ---")
    print(y_true)
    loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
    print("--- Ending of the loss function ---")
    return loss

Using a callback to check when each batch starts and ends:

from tensorflow.keras.callbacks import Callback

class monitor(Callback):
    def on_batch_begin(self, batch, logs=None):
        print("\n >> Starting a new batch (batch index) :: ", batch)

    def on_batch_end(self, batch, logs=None):
        print(">> Ending a batch (batch index) :: ", batch)

The .fit() method is used as:

history = model.fit(
    x=[inputs],
    y=[outputs],
    shuffle=False,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCH,
    verbose=1,
    callbacks=[monitor()]
)

Parameters used:

BATCH_SIZE = 128
NUM_EPOCH = 3
inputs.shape = (512, 8)
outputs.shape = (512, 2)
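
(The model definition isn't included in the question. For reproduction, a minimal stand-in matching these shapes could look like the following; the random data, layer sizes, and optimizer here are assumptions, not the asker's actual setup.)

import numpy as np
import tensorflow as tf

BATCH_SIZE = 128
NUM_EPOCH = 3

# Random stand-in data matching the shapes above.
inputs = np.random.rand(512, 8).astype("float32")
outputs = np.random.rand(512, 2).astype("float32")

# Assumed minimal model: two Dense layers mapping 8 features to 2 outputs.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(2),
])
model.compile(optimizer="adam", loss=loss)  # `loss` is the custom function above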

And the output:

Epoch 1/3
 >> Starting a new batch (batch index) ::  0
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
1/4 [======>.......................] - ETA: 0s - loss: 0.5551
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 5ms/step - loss: 0.5307

Epoch 2/3
 >> Starting a new batch (batch index) ::  0
1/4 [======>.......................] - ETA: 0s - loss: 0.5443
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 5ms/step - loss: 0.5246

Epoch 3/3
 >> Starting a new batch (batch index) ::  0
1/4 [======>.......................] - ETA: 0s - loss: 0.5433
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 4ms/step - loss: 0.5219

Why is the custom loss function only called at the start, and not for every batch computation? And I would also like to know when the loss function is called/triggered?

Answer

The loss function debug messages were printed only at the beginning of training.

This is because, for the sake of performance, your loss function is internally converted into a TensorFlow graph, and the Python print function only runs while your function is being traced. That is, it printed only at the beginning of training, which implies your loss function was being traced at that time. Please refer to the following page for more information: https://www.tensorflow.org/guide/function
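
To see the difference, here is a minimal sketch (independent of the question's model; the function name is just for illustration) of how tracing affects the two kinds of print:

import tensorflow as tf

@tf.function
def traced_square(x):
    # Python print: runs only while TensorFlow traces the function into a graph.
    print("tracing traced_square")
    # tf.print: compiled into the graph, so it runs on every call.
    tf.print("executing traced_square, x =", x)
    return x * x

traced_square(tf.constant(2.0))  # first call: both messages (tracing happens here)
traced_square(tf.constant(3.0))  # later calls: only the tf.print message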

Short answer: to print properly, use tf.print() instead of print().
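
For example, the loss function from the question could be rewritten with tf.print() as in the sketch below (the summarize argument is only there to keep the printed label tensor short):

import tensorflow as tf

def loss(y_true, y_pred):
    # tf.print is a graph op, so it fires on every batch, not only during tracing.
    tf.print("--- Starting of the loss function ---")
    tf.print("y_true:", y_true, summarize=3)
    loss_value = tf.keras.losses.mean_squared_error(y_true, y_pred)
    tf.print("--- Ending of the loss function ---")
    return loss_value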

And I would also like to know when the loss function is called/triggered?

After you use tf.print(), the debug messages will be printed correctly. You will see that your loss function is called at least once per step to get the loss value, and from it the gradient.
