是否有可能返回一个数组(或张量)而不是数字的度量? [英] Is it possible to have a metric that returns an array (or tensor) rather than a number?

查看:61
本文介绍了是否有可能返回一个数组(或张量)而不是数字的度量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个神经网络,输出为NxM,其中N是批处理大小,M是网络需要进行预测的输出数量.我想为网络的每个M输出计算一个指标,即跨批次的所有实例,但分别为每个M输出计算一个指标,以便该指标有M个值.我试图按如下方法创建自定义指标.

I have a neural network with an output NxM, where N is the batch size and M are the number of outputs where the network needs to make a prediction. I would like to compute a metric for each of the M outputs of the network, i.e. across all instances of the batch but separately for each of the M outputs, so that there would be M values of this metric. I tried to create a custom metric as follows.

def my_metric(y_true, y_pred):
    return [3.1, 5.2] # a list of dummy values

,然后将此度量传递到模型的compile方法的度量列表,然后Keras输出一个数字,该数字是3.15.2的平均值(在本例中为(3.1 + 5.2)/2 = 4.15),而不是而不是打印实际列表.那么,是否有一种方法可以返回并打印列表(或numpy数组)作为度量?当然,在我的特定情况下,我不会在上面的示例中返回虚拟列表,但是我的自定义指标更加复杂.

and then pass this metric to the list of metrics of the compile method of the model, then Keras outputs a number that is the average of 3.1 and 5.2 (in this case, (3.1 + 5.2)/2 = 4.15) rather than printing the actual list. So, is there a way of returning and printing a list (or numpy array) as the metric? Of course, in my specific case, I will not return the dummy list in the example above, but my custom metric is more complex.

推荐答案

每M设置一个指标.

一个输出的工作代码:

from keras.layers import Dense, Input
from keras.models import Model
import keras.backend as K
import numpy as np

inputs = Input((5,))
outputs = Dense(3)(inputs)
model = Model(inputs, outputs)

def metricWrapper(m):
    def meanMetric(true, pred):
        return pred[:, m]
    meanMetric.__name__ = 'meanMetric_' + str(m)
    return meanMetric
metrics = [metricWrapper(m) for m in range(3)]

model.compile(loss='mse', metrics=metrics, optimizer='adam')
model.fit(np.random.rand(10,5), np.zeros((10,3)))

这篇关于是否有可能返回一个数组(或张量)而不是数字的度量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆