计算网络两个输出之间的cosine_proximity损耗 [英] Computing cosine_proximity loss between two outputs of the network

查看:356
本文介绍了计算网络两个输出之间的cosine_proximity损耗的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用Keras 2.0.2功能API(Tensorflow 1.0.1)来实现一个网络,该网络需要多个输入并产生两个输出ab.我需要使用cosine_proximity损失来训练网络,以使ba的标签.我该怎么做?

I'm using Keras 2.0.2 Functional API (Tensorflow 1.0.1) to implement a network that takes several inputs and produces two outputs a and b. I need to train the network using the cosine_proximity loss, such that b is the label for a. How do I do this?

在这里共享我的代码.最后一行model.fit(..)是有问题的部分,因为我本身没有标签数据.标签是由模型本身产生的.

Sharing my code here. The last line model.fit(..) is the problematic part because I don't have labeled data per se. The label is produced by the model itself.

from keras.models import Model
from keras.layers import Input, LSTM
from keras import losses

shared_lstm = LSTM(dim)

q1 = Input(shape=(..,.. ), name='q1')
q2 = Input(shape=(..,.. ), name='q2')
a = shared_lstm(q1)
b = shared_lstm(q2)
model = Model(inputs=[q1,q2], outputs=[a, b])
model.compile(optimizer='adam', loss=losses.cosine_proximity)

model.fit([testq1, testq2], [?????])

推荐答案

您可以先定义一个假的true标签.例如,将其定义为输入数据大小的一维数组.

You can define a fake true label first. For example, define it as a 1-D array of ones of the size of your input data.

现在是损失函数.您可以编写如下.

Now comes the loss function. You can write it as follows.

def my_cosine_proximity(y_true, y_pred):
    a = y_pred[0]
    b = y_pred[1]
    # depends on whether you want to normalize
    a = K.l2_normalize(a, axis=-1)
    b = K.l2_normalize(b, axis=-1)        
    return -K.mean(a * b, axis=-1) + 0 * y_true

我已将y_true乘以零,并将其相加,以使Theano不会丢失输入警告/错误.

I have multiplied y_true by zero and added it just so that Theano does give not missing input warning/error.

您应正常调用fit函数,即通过添加伪造的地面真相标签.

You should call your fit function normally i.e. by including your fake ground-truth labels.

model.compile('adam', my_cosine_proximity) # 'adam' used as an example optimizer 
model.fit([testq1, testq2], fake_y_true)

这篇关于计算网络两个输出之间的cosine_proximity损耗的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆