在Tensorflow 2.0中实现自定义损失功能 [英] Implement custom loss function in Tensorflow 2.0

查看:728
本文介绍了在Tensorflow 2.0中实现自定义损失功能的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在建立时间序列分类模型.数据非常不平衡,所以我决定使用加权交叉熵函数作为损失.

Tensorflow提供 tf.nn.weighted_cross_entropy_with_logits ,但我不确定如何在TF 2.0中使用它.因为我的模型是使用tf.keras API构建的,所以我正在考虑创建这样的自定义损失函数:

pos_weight=10
def weighted_cross_entropy_with_logits(y_true,y_pred):
  return tf.nn.weighted_cross_entropy_with_logits(y_true,y_pred,pos_weight)

# .....
model.compile(loss=weighted_cross_entropy_with_logits,optimizer="adam",metrics=["acc"])

我的问题是:有没有一种方法可以直接将tf.nn.weighted_cross_entropy_with_logits与tf.keras API一起使用?

解决方案

您可以将类权重直接传递给model.fit函数.

class_weight:可选字典将类索引(整数)映射到 权重(浮动)值,用于对损失函数进行加权(在 仅培训).这可能有助于告诉模型支付更多 注意"来自代表性不足的班级的样本.

例如:

{
    0: 0.31, 
    1: 0.33, 
    2: 0.36, 
    3: 0.42, 
    4: 0.48
}

来源


修改: JL Meunier答案很好地解释了如何乘法具有类权重的logits.

I'm building a model for Time series classification. The data is very unbalanced so I've decided to use a weighted cross entropy function as my loss.

Tensorflow provides tf.nn.weighted_cross_entropy_with_logits but I'm not sure how to use it in TF 2.0. Because my model is build using tf.keras API I was thinking about creating my custom loss function like this:

pos_weight=10
def weighted_cross_entropy_with_logits(y_true,y_pred):
  return tf.nn.weighted_cross_entropy_with_logits(y_true,y_pred,pos_weight)

# .....
model.compile(loss=weighted_cross_entropy_with_logits,optimizer="adam",metrics=["acc"])

My question is: is there a way to use tf.nn.weighted_cross_entropy_with_logits with tf.keras API directly?

解决方案

You can pass the class weights directly to the model.fit function.

class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

Such as:

{
    0: 0.31, 
    1: 0.33, 
    2: 0.36, 
    3: 0.42, 
    4: 0.48
}

Source


Edit: JL Meunier answer explains well how to multiply the logits with class weights.

这篇关于在Tensorflow 2.0中实现自定义损失功能的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆