如何在线性分类器中应用类权重进行二进制分类? [英] How to apply class weights in linear classifier for binary classification?

查看:106
本文介绍了如何在线性分类器中应用类权重进行二进制分类?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

这是我用来执行二进制分类的线性分类器,这是代码片段:

This is the linear classifier that I am using to perform binary classification, here is code snippet:

my_optimizer = tf.train.AdagradOptimizer(learning_rate = learning_rate)
my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(my_optimizer,5.0)
# Create a linear classifier object
linear_classifier = tf.estimator.LinearClassifier(
          feature_columns = feature_columns, 
          optimizer = my_optimizer 
          )
linear_classifier.train(input_fn = training_input_fn, steps = steps)

数据集不平衡,只有两个类是/否.否类示例的数目是36548,而YES类示例的数目是4640.

The dataset is imbalanced, there are only two classes yes/no. The number of NO class examples are 36548 while number of YES class examples are 4640.

如何对这些数据进行平衡?我一直在搜索,可以找到与类权重相关的内容,但是找不到如何创建类权重以及如何将其应用于张量流的训练方法.

How can I apply balancing to this data? I have been searching around and I could find stuff related to class weights etc but I couldn't find how can I create class weights and how to apply to the train method of tensor flow.

这是我计算损失的方式:

Here is how I am calculating losses:

training_probabilities = linear_classifier.predict(input_fn = training_predict_input_fn)
training_probabilities = np.array([item['probabilities'] for item in training_probabilities])

validation_probabilities = linear_classifier.predict(input_fn=validation_predict_input_fn)
validation_probabilities = np.array([item['probabilities'] for item in validation_probabilities])

training_log_loss = metrics.log_loss(training_targets, training_probabilities)
validation_log_loss = metrics.log_loss(validation_targets, validation_probabilities)

推荐答案

我假设您正在使用

I assume that you are using the log_loss function from sklearn for computing your loss. If that is the case you can add class weights by using the argument sample_weight and pass on an array containing the weight to be given for each data point. sample_weight is an rolled out version of class_weights. You can compute sample_weight array by passing on the sample weights as given here.

在代码中添加以下几行:

Add the following lines to your code:

sample_wts = compute_sample_weight("balanced", training_targets)
training_log_loss = metrics.log_loss(training_targets, training_probabilities, sample_weight= sample_wts)

希望这会有所帮助!

这篇关于如何在线性分类器中应用类权重进行二进制分类?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆