Is there a built-in KL divergence loss function in TensorFlow?

Question

I have two tensors, prob_a and prob_b with shape [None, 1000], and I want to compute the KL divergence from prob_a to prob_b. Is there a built-in function for this in TensorFlow? I tried using tf.contrib.distributions.kl(prob_a, prob_b), but it gives:

NotImplementedError: No KL(dist_a || dist_b) registered for dist_a type Tensor and dist_b type Tensor

If there is no built-in function, what would be a good workaround?

Answer

Assuming that your input tensors prob_a and prob_b are probability tensors that sum to 1 along the last axis, you could do it like this:

import tensorflow as tf

def kl(x, y):
    # Wrap the probability tensors in Categorical distributions and use
    # the KL divergence registered for that distribution pair.
    X = tf.distributions.Categorical(probs=x)
    Y = tf.distributions.Categorical(probs=y)
    return tf.distributions.kl_divergence(X, Y)

result = kl(prob_a, prob_b)

A simple example:

import numpy as np
import tensorflow as tf
a = np.array([[0.25, 0.1, 0.65], [0.8, 0.15, 0.05]])
b = np.array([[0.7, 0.2, 0.1], [0.15, 0.8, 0.05]])
sess = tf.Session()
print(kl(a, b).eval(session=sess))  # [0.88995184 1.08808468]

You would get the same result with:

np.sum(a * np.log(a / b), axis=1) 

However, this implementation is a bit buggy (checked in TensorFlow 1.8.0).

If you have zero probabilities in a, e.g. if you try [0.8, 0.2, 0.0] instead of [0.8, 0.15, 0.05], you will get nan, even though by the Kullback-Leibler definition 0 * log(0 / b) should contribute zero.
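
For instance, reusing the kl helper, the b array, and the session from the snippets above (a_zero is an illustrative name, not from the original answer):

a_zero = np.array([[0.8, 0.2, 0.0]])
print(kl(a_zero, b[:1]).eval(session=sess))  # reportedly nan in TensorFlow 1.8.0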

To mitigate this, one should add a small numerical constant to the probabilities. It is also prudent to use tf.distributions.kl_divergence(X, Y, allow_nan_stats=False) so that such situations raise a runtime error instead of silently producing nan.

Also, if there are zeros in b, you will get inf values, which won't be caught by the allow_nan_stats=False option, so those have to be handled as well.
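
Here is a minimal sketch of one possible workaround that addresses both caveats, assuming it is acceptable to smooth the inputs; the helper name smoothed_kl and the constant 1e-16 are illustrative choices, not part of the original answer:

import numpy as np
import tensorflow as tf

def smoothed_kl(x, y, eps=1e-16):
    # Add a tiny constant so neither distribution contains exact zeros,
    # then renormalize so each row still sums to 1.
    x = tf.convert_to_tensor(x)
    y = tf.convert_to_tensor(y)
    x = (x + eps) / tf.reduce_sum(x + eps, axis=-1, keepdims=True)
    y = (y + eps) / tf.reduce_sum(y + eps, axis=-1, keepdims=True)
    X = tf.distributions.Categorical(probs=x)
    Y = tf.distributions.Categorical(probs=y)
    # allow_nan_stats=False turns silent nans into runtime errors.
    return tf.distributions.kl_divergence(X, Y, allow_nan_stats=False)

a = np.array([[0.8, 0.2, 0.0]])  # zero entry that would otherwise produce nan
b = np.array([[0.7, 0.2, 0.1]])
sess = tf.Session()
print(smoothed_kl(a, b).eval(session=sess))  # finite value instead of nan

Note that the smoothing slightly perturbs the distributions, so eps should be small relative to the probabilities you care about.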
