使用Keras的快速渐变符号方法 [英] Fast gradient sign method with keras

查看:96
本文介绍了使用Keras的快速渐变符号方法的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我目前正在研究本文 .用异方差神经网络实现快速梯度符号方法.

I'm currently working on this paper . To implement the Fast gradient sign method with a heteroscedastic neural network.

如果将损失函数定义为l(\theta,x,y),其中x是特征,y是标签,而\theta是参数. 我们的目标不是最小化l(\theta,x,y),而是最小化l(\theta,x,y)+l(\theta,x',y)其中

If we define the loss function as l(\theta,x,y) where x is the feature, y the label and \theta the parameters. Instead of minimizing l(\theta,x,y), the goal is to minimize l(\theta,x,y)+l(\theta,x',y) where

x'=x+\eps*\sign(\nabla_x l(\theta,x,y))

这是我的尝试(没有成功):

Here my attempt(without any success):

def customLoss(x):
    def neg_log_likelihood(y_true, y_pred):

        def neg_log(y_t,y_p):
            inter=(y_p[...,0,None]-y_t)/K.clip(y_p[...,1,None],K.epsilon(),None)
            val=K.log(K.clip(K.square(y_p[...,1,None]),K.epsilon(),None))+K.square(inter)
            return val

        val=neg_log(y_true,y_pred)

        deriv=K.gradients(val,x)
        xb=x+0.01*K.sign(deriv)
        out=model.predict(xb)
        valb=neg_log(y_true,out)

        return K.mean(val+valb,axis=-1)
    return neg_log_likelihood

然后调用此损失函数

model.compile(loss=customLoss(model.inputs),...)

您有任何想法该如何实施吗?

Do you have any ideas how can I implement this?

推荐答案

正确的损失函数是:

def customLoss(x):
    def neg_log_likelihood(y_true, y_pred):

        def neg_log(y_t,y_p):
            inter=(y_p[...,0,None]-y_t)/K.clip(y_p[...,1,None],K.epsilon(),None)
            val=K.log(K.clip(K.square(y_p[...,1,None]),K.epsilon(),None))+K.square(inter)
            return val

        val=neg_log(y_true,y_pred)

        deriv=K.gradients(val,x)
        xb=x+0.01*K.sign(deriv)
        out=model(xb)
        valb=neg_log(y_true,out)

        return K.mean(val+valb,axis=-1)
    return neg_log_likelihood

区别在于model(xb)返回张量,而model.predict(xb)不返回张量.

The difference is that model(xb) returns a tensor while model.predict(xb) doesn't.

这篇关于使用Keras的快速渐变符号方法的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆