Loss function with derivative in TensorFlow 2

Problem description

I am using a TF2 (2.3.0) neural network to approximate the function y that solves the ODE y' + 3y = 0.
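For reference, this ODE has a well-known closed-form solution, which gives a ground truth to compare a trained network against. The constant C below is not part of the original post; it stands for the free integration constant, which is only fixed once an initial condition is supplied.

```latex
y' + 3y = 0 \;\Longrightarrow\; y(x) = C e^{-3x},
\qquad\text{since}\qquad
\frac{d}{dx}\left(C e^{-3x}\right) + 3\,C e^{-3x}
  = -3C e^{-3x} + 3C e^{-3x} = 0 .
```

Note that C = 0, i.e. y identically zero, also satisfies the equation, so a loss built only from the residual y' + 3y does not by itself single out a non-trivial solution.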

I have defined a custom loss class and function in which I try to differentiate the single output with respect to the single input, so that the equation holds provided that y_true is zero:

from tensorflow.keras.losses import Loss
import tensorflow as tf

class CustomLossOde(Loss):
    def __init__(self, x, model, name='ode_loss'):
        super().__init__(name=name)
        self.x = x
        self.model = model

    def call(self, y_true, y_pred):

        with tf.GradientTape() as tape:
            tape.watch(self.x)
            y_p = self.model(self.x)


        dy_dx = tape.gradient(y_p, self.x)
        loss = tf.math.reduce_mean(tf.square(dy_dx + 3 * y_pred - y_true))
        return loss

But running the following neural network:

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input
from custom_loss_ode import CustomLossOde


num_samples = 1024
x_train = 4 * (tf.random.uniform((num_samples, )) - 0.5)
y_train = tf.zeros((num_samples, ))
inputs = Input(shape=(1,))
x = Dense(16, 'tanh')(inputs)
x = Dense(8, 'tanh')(x)
x = Dense(4)(x)
y = Dense(1)(x)
model = Model(inputs=inputs, outputs=y)
loss = CustomLossOde(model.input, model)
model.compile(optimizer=Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.99),loss=loss)
model.run_eagerly = True
model.fit(x_train, y_train, batch_size=16, epochs=30)

For now I am getting a loss of 0 from the first epoch, which doesn't make any sense.

I have printed both y_true and y_pred from within the function and they seem OK, so I suspect that the problem is in the gradient, which I did not succeed in printing. I appreciate any help.

Recommended answer

Defining a custom loss with the high-level Keras API is a bit difficult in this case. I would instead write the training loop from scratch, as it allows finer-grained control over what you can do.

I took inspiration from these two guides:

Basically, I used the fact that multiple tapes can interact seamlessly. I use one to compute the loss function and the other to calculate the gradients to be propagated by the optimizer.

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input

num_samples = 1024
x_train = 4 * (tf.random.uniform((num_samples, )) - 0.5)
y_train = tf.zeros((num_samples, ))
inputs = Input(shape=(1,))
x = Dense(16, 'tanh')(inputs)
x = Dense(8, 'tanh')(x)
x = Dense(4)(x)
y = Dense(1)(x)
model = Model(inputs=inputs, outputs=y)

# using the high level tf.data API for data handling
x_train = tf.reshape(x_train,(-1,1))
dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(1)

opt = Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.99)
for step, (x,y_true) in enumerate(dataset):
    # wrap x in a tf.Variable so that the tape is able to compute
    # the gradient of the model output with respect to x
    x_variable = tf.Variable(x)
    with tf.GradientTape() as model_tape:
        with tf.GradientTape() as loss_tape:
            loss_tape.watch(x_variable)
            y_pred = model(x_variable)
        dy_dx = loss_tape.gradient(y_pred, x_variable)
        loss = tf.math.reduce_mean(tf.square(dy_dx + 3 * y_pred - y_true))
    grad = model_tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grad, model.trainable_variables))
    if step%20==0:
        print(f"Step {step}: loss={loss.numpy()}")
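The nested-tape pattern used in the loop above can be sanity-checked in isolation on a function whose derivatives are known by hand. The sketch below is only an illustration under stated assumptions: it replaces the network with y = exp(a·x), where the single parameter `a` plays the role of the model weights, and the helper name `residual_and_param_grad` is hypothetical and not part of the answer. For this choice the ODE residual is (a + 3)·exp(a·x), so it vanishes exactly when a = -3.

```python
import tensorflow as tf


def residual_and_param_grad(a_value, x_value):
    """Return (loss, d loss / d a) for the toy model y = exp(a * x).

    Hypothetical helper for illustration only: `a` stands in for the
    trainable weights of the network in the answer.
    """
    a = tf.Variable(a_value)   # trainable parameter, watched automatically
    x = tf.constant(x_value)   # input; a plain tensor must be watched explicitly

    with tf.GradientTape() as outer_tape:
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(x)
            y = tf.exp(a * x)
        # dy/dx is computed while the outer tape is still recording,
        # so it remains differentiable with respect to `a`.
        dy_dx = inner_tape.gradient(y, x)
        residual = dy_dx + 3.0 * y          # equals (a + 3) * exp(a * x)
        loss = tf.square(residual)

    # Gradient of the ODE loss with respect to the parameter.
    dloss_da = outer_tape.gradient(loss, a)
    return float(loss), float(dloss_da)


if __name__ == "__main__":
    print(residual_and_param_grad(-2.0, 0.0))
    print(residual_and_param_grad(-3.0, 0.5))
```

The key design point is that `inner_tape.gradient` is called inside the `outer_tape` context; if it were moved outside, the outer tape would not record the derivative computation and the loss could not be differentiated through dy/dx.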
