Implementing contrastive loss and triplet loss in TensorFlow


Question

I started playing with TensorFlow two days ago, and I'm wondering whether the triplet loss and the contrastive loss are already implemented.

I've been looking through the documentation, but I haven't found any example or description of them.

Recommended answer

Update (2018/03/19): I wrote a blog post detailing how to implement triplet loss in TensorFlow.

You need to implement the contrastive loss or the triplet loss yourself, but once you have the pairs or triplets, this is quite easy.

Suppose your input consists of pairs of data and their labels (positive or negative, i.e. same class or different class). For instance, the inputs are images of size 28x28x1:

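import tensorflow as tf  # TensorFlow 1.x graph API; in TF 2.x, tf.placeholder lives under tf.compat.v1

left = tf.placeholder(tf.float32, [None, 28, 28, 1])
right = tf.placeholder(tf.float32, [None, 28, 28, 1])
# float32 with shape [None] so that it multiplies element-wise with d below
label = tf.placeholder(tf.float32, [None])  # 0 if same, 1 if different
margin = 0.2

# `model` is your own embedding network (not defined here)
left_output = model(left)  # shape [None, 128]
right_output = model(right)  # shape [None, 128]

# Squared Euclidean distance between the two embeddings, shape [None]
d = tf.reduce_sum(tf.square(left_output - right_output), 1)
# The small epsilon keeps the gradient of sqrt finite when d == 0
d_sqrt = tf.sqrt(d + 1e-8)

# Different pairs are pushed at least `margin` apart, same pairs are pulled together
loss = label * tf.square(tf.maximum(0., margin - d_sqrt)) + (1 - label) * d

loss = 0.5 * tf.reduce_mean(loss)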
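
To sanity-check the graph above, it can help to evaluate the same formula on a couple of hand-made pairs. The following is only a minimal NumPy sketch under the same label convention (0 if same, 1 if different); the function name `contrastive_loss` and the toy embeddings are made up for illustration and are not part of TensorFlow.

```python
import numpy as np

def contrastive_loss(left_emb, right_emb, label, margin=0.2):
    """NumPy reference for the contrastive loss formula used above.

    label: 0 for a pair from the same class, 1 for a pair from different classes.
    """
    left_emb = np.asarray(left_emb, dtype=np.float64)
    right_emb = np.asarray(right_emb, dtype=np.float64)
    label = np.asarray(label, dtype=np.float64)

    d = np.sum((left_emb - right_emb) ** 2, axis=1)  # squared distance per pair
    d_sqrt = np.sqrt(d)                              # Euclidean distance per pair
    per_pair = label * np.maximum(0.0, margin - d_sqrt) ** 2 + (1.0 - label) * d
    return 0.5 * per_pair.mean()

# Two toy pairs with 2-D embeddings (hypothetical values).
left = [[0.0, 0.0], [0.0, 0.0]]
right = [[0.3, 0.4], [0.06, 0.08]]
label = [0, 1]  # first pair: same class, second pair: different classes

loss = contrastive_loss(left, right, label)
print(loss)
```

The first pair contributes its squared distance (0.25), the second contributes the squared margin violation ((0.2 - 0.1)^2 = 0.01), so the result is 0.5 * (0.25 + 0.01) / 2 = 0.065 up to floating-point rounding.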

Triplet Loss

Same as with contrastive loss, but with triplets (anchor, positive, negative). You don't need labels here.

anchor_output = ...  # shape [None, 128]
positive_output = ...  # shape [None, 128]
negative_output = ...  # shape [None, 128]

d_pos = tf.reduce_sum(tf.square(anchor_output - positive_output), 1)
d_neg = tf.reduce_sum(tf.square(anchor_output - negative_output), 1)

loss = tf.maximum(0., margin + d_pos - d_neg)
loss = tf.reduce_mean(loss)
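
As a rough check of the triplet formula, here is a small NumPy sketch that computes the same per-triplet loss on toy embeddings. The helper name `triplet_losses` and the example values are assumptions for illustration only; the margin of 0.2 is the one used earlier.

```python
import numpy as np

def triplet_losses(anchor, positive, negative, margin=0.2):
    """Per-triplet loss max(0, margin + d_pos - d_neg) with squared distances."""
    anchor = np.asarray(anchor, dtype=np.float64)
    positive = np.asarray(positive, dtype=np.float64)
    negative = np.asarray(negative, dtype=np.float64)

    d_pos = np.sum((anchor - positive) ** 2, axis=1)
    d_neg = np.sum((anchor - negative) ** 2, axis=1)
    return np.maximum(0.0, margin + d_pos - d_neg)

# Two toy triplets with 2-D embeddings (hypothetical values).
anchor = [[0.0, 0.0], [0.0, 0.0]]
positive = [[0.1, 0.0], [0.5, 0.0]]
negative = [[1.0, 0.0], [0.4, 0.0]]

per_triplet = triplet_losses(anchor, positive, negative)
mean_loss = per_triplet.mean()
print(per_triplet, mean_loss)
```

The first triplet already satisfies the margin (the negative is far away), so its loss is 0. The second has the negative closer than the positive, giving 0.2 + 0.25 - 0.16 = 0.29, and the mean is about 0.145.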

The real trouble when implementing triplet loss or contrastive loss in TensorFlow is how to sample the triplets or pairs. I will focus on generating triplets because it is harder than generating pairs.

The easiest way is to generate them outside of the TensorFlow graph, i.e. in Python, and feed them to the network through the placeholders. Basically, you select images three at a time, with the first two from the same class and the third from another class. You then run a forward pass on these triplets and compute the triplet loss.
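
The selection step described above might look like the following in plain Python. This is only a sketch: `sample_triplets` is a hypothetical helper, it works on integer class labels, and it returns indices into the dataset rather than the images themselves.

```python
import random
from collections import defaultdict

def sample_triplets(labels, num_triplets, seed=None):
    """Randomly draw (anchor, positive, negative) index triplets.

    The anchor and positive share a class; the negative comes from another class.
    """
    rng = random.Random(seed)

    # Group example indices by class label.
    by_class = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_class[lab].append(idx)

    # Only classes with at least two examples can supply an anchor/positive pair.
    anchor_classes = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    if not anchor_classes or len(by_class) < 2:
        raise ValueError("need at least two classes, one with two or more examples")

    triplets = []
    for _ in range(num_triplets):
        pos_class = rng.choice(anchor_classes)
        neg_class = rng.choice([c for c in by_class if c != pos_class])
        anchor, positive = rng.sample(by_class[pos_class], 2)  # two distinct indices
        negative = rng.choice(by_class[neg_class])
        triplets.append((anchor, positive, negative))
    return triplets

labels = [0, 0, 1, 1, 2, 2]  # toy labels for six examples
triplets = sample_triplets(labels, num_triplets=4, seed=0)
print(triplets)
```

The three index columns can then be used to gather three image batches, one each for the anchor, positive, and negative placeholders.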

The issue here is that generating triplets is complicated. We want them to be valid triplets, i.e. triplets with a positive loss (otherwise the loss is 0 and the network doesn't learn).
To know whether a triplet is good or not, you need to compute its loss, which already requires one forward pass through the network...
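
To make the cost concrete: once embeddings have been computed by a forward pass, filtering candidate triplets down to those with a positive loss could be sketched as below. The function `keep_active_triplets` and the toy embeddings are hypothetical; in practice the embeddings would come from running the network on the candidate images.

```python
import numpy as np

def keep_active_triplets(embeddings, triplets, margin=0.2):
    """Return the triplets whose loss is positive, plus all per-triplet losses.

    embeddings: array of shape [N, D], already computed by the network.
    triplets:   index triplets (anchor, positive, negative) into embeddings.
    """
    embeddings = np.asarray(embeddings, dtype=np.float64)
    triplets = np.asarray(triplets, dtype=np.int64)

    a = embeddings[triplets[:, 0]]
    p = embeddings[triplets[:, 1]]
    n = embeddings[triplets[:, 2]]

    d_pos = np.sum((a - p) ** 2, axis=1)
    d_neg = np.sum((a - n) ** 2, axis=1)
    losses = np.maximum(0.0, margin + d_pos - d_neg)
    return triplets[losses > 0.0], losses

# Toy 2-D embeddings; examples 0 and 1 share a class, 2 and 3 share another.
embeddings = [[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [0.3, 0.0]]
candidates = [(0, 1, 2), (0, 1, 3)]

active, losses = keep_active_triplets(embeddings, candidates)
print(active, losses)
```

Here the first candidate is discarded because its negative is already far enough from the anchor, while the second is kept with a loss of about 0.12. The embeddings used for filtering go stale as the weights change, which is part of why this approach is expensive.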

Clearly, implementing triplet loss in TensorFlow is hard. There are ways to make it more efficient than sampling in Python, but explaining them would require a whole blog post!
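
For orientation only, one widely used alternative to sampling in Python is to form every valid triplet inside a batch from a single pairwise distance matrix (often called "batch all" online mining). Whether this is the approach the author has in mind is an assumption; the NumPy sketch below, with the hypothetical name `batch_all_triplet_loss`, only illustrates the idea and would need to be ported to TensorFlow ops to be trainable.

```python
import numpy as np

def batch_all_triplet_loss(embeddings, labels, margin=0.2):
    """Average triplet loss over all valid triplets with a positive loss in a batch."""
    embeddings = np.asarray(embeddings, dtype=np.float64)
    labels = np.asarray(labels)

    # Pairwise squared Euclidean distances, shape [B, B].
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sum(diff ** 2, axis=2)

    # loss[a, p, n] = margin + dist[a, p] - dist[a, n], shape [B, B, B].
    loss = margin + dist[:, :, None] - dist[:, None, :]

    # valid[a, p, n]: a != p, label[a] == label[p], label[a] != label[n].
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    valid = (same & not_self)[:, :, None] & (~same)[:, None, :]

    loss = np.maximum(0.0, loss) * valid
    num_active = int(np.count_nonzero(loss > 0.0))
    # Average over active triplets only; guard against division by zero.
    return loss.sum() / max(num_active, 1), num_active

# Toy batch: four 2-D embeddings, two classes (hypothetical values).
embeddings = [[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [0.3, 0.0]]
labels = [0, 0, 1, 1]

batch_loss, num_active = batch_all_triplet_loss(embeddings, labels)
print(batch_loss, num_active)
```

With this toy batch, 4 of the 8 valid triplets have a positive loss, and their average is about 0.385. The appeal of this design is that one forward pass over the batch yields all triplets at once, so no separate pass is needed just to decide which triplets are useful.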
