Sklearn implementation of batch gradient descent

Question

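What is the way to implement batch gradient descent for classification using sklearn? We have SGDClassifier for stochastic GD, which takes a single instance at a time, and LinearRegression/LogisticRegression, which use closed-form or specialized solvers (least squares for LinearRegression, lbfgs by default for LogisticRegression) rather than batch GD.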

Solution

A possible answer, pointed out in another similar question and also in the sklearn docs:

SGD allows minibatch (online/out-of-core) learning, see the partial_fit method.

But is partial_fit really batch gradient descent?

SGD: the gradient of the cost function is calculated, and the weights are updated with a gradient descent step, for each sample.

Batch/mini-batch GD: the gradient of the cost function is calculated, and the weights are updated with a gradient descent step, once per batch.

So mini-batch GD with a batch size of 1 is the same as SGD.
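To make the definitions concrete, here is a minimal NumPy sketch (not sklearn code) of gradient descent on a linear model with squared loss, where one parameter switches between the variants: `batch_size=1` is SGD, `batch_size=len(X)` is full-batch GD, and anything in between is mini-batch GD. The function name `minibatch_gd`, the squared loss and the learning rate are assumptions chosen for illustration.

```python
import numpy as np


def minibatch_gd(X, y, batch_size, lr=0.1, epochs=100):
    """Hypothetical helper: gradient descent on 0.5 * mean squared error.

    batch_size=1 -> SGD, batch_size=len(X) -> full-batch GD,
    anything in between -> mini-batch GD.
    Returns the weights, the bias and the number of weight updates performed.
    """
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    b = 0.0
    n_updates = 0
    for _ in range(epochs):
        for start in range(0, n_samples, batch_size):
            Xb = X[start:start + batch_size]
            yb = y[start:start + batch_size]
            error = Xb @ w + b - yb          # residuals on this batch only
            grad_w = Xb.T @ error / len(yb)  # gradient averaged over the batch
            grad_b = error.mean()
            w -= lr * grad_w                 # exactly one update per batch
            b -= lr * grad_b
            n_updates += 1
    return w, b, n_updates


# Noise-free toy data: y = 2x + 1
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 1))
y = 2.0 * X[:, 0] + 1.0

for bs in (1, 10, 100):
    w, b, n_updates = minibatch_gd(X, y, batch_size=bs)
    print(bs, n_updates, w.round(3), round(b, 3))
```

With 100 samples and 100 epochs, the loop performs 10000, 1000 and 100 weight updates for batch sizes 1, 10 and 100 respectively; the gradient formula stays the same and only the number of updates per epoch changes.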

Now that the definitions are clear, let's look at the code of sklearn's SGDClassifier.

The docstring of partial_fit says:

Perform one epoch of stochastic gradient descent on given samples.

But this is not batch GD; it looks more like a helper that runs the fit method with max_iter=1 (in fact, the docstrings say as much).

partial_fit calls _partial_fit with max_iter=1. Reference link

The fit method calls _fit, which calls _partial_fit with max_iter set to the assigned/default maximum number of iterations. Reference link
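One way to check this empirically is the `t_` attribute of SGDClassifier, which sklearn documents as the number of weight updates performed during training (n_iter_ * n_samples + 1). The sketch below, on toy data with an arbitrarily chosen batch size of 32, feeds two mini-batches through partial_fit and records how much `t_` grows per call. If partial_fit did batch GD, `t_` would grow by 1 per call; with per-sample SGD it grows by the batch size. This assumes a recent sklearn version in which `t_` behaves as documented.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

# Toy binary data; the sizes here are arbitrary choices for the demo.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

clf = SGDClassifier(random_state=0)
batch_size = 32
updates_per_batch = []
for start in range(0, len(X), batch_size):
    X_batch = X[start:start + batch_size]
    y_batch = y[start:start + batch_size]
    # t_ does not exist before the first call; it starts at 1.0 (the "+ 1" in the docs)
    before = getattr(clf, "t_", 1.0)
    # classes must be passed on the first call to partial_fit
    clf.partial_fit(X_batch, y_batch, classes=np.array([0, 1]))
    updates_per_batch.append(clf.t_ - before)

# Per-sample SGD: t_ grows by the batch size, not by 1, on every call
print(updates_per_batch)
print(clf.n_iter_)  # partial_fit runs a single epoch over the given batch
```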

Conclusion:

partial_fit does not really do batch GD: it does not compute the gradient and update the weights once per batch, but once for each sample.

There seems to be no mechanism in sklearn to do batch gradient descent.
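If you need true batch GD for classification, one option is to write the loop yourself. The following is a minimal sketch, not an sklearn API: a hypothetical `BatchGDLogisticRegression` class that trains binary logistic regression by computing the gradient of the mean log-loss over the whole training set and updating the weights once per epoch. The learning rate, epoch count and 0/1 label encoding are assumptions; it has no regularization, convergence check or multiclass support.

```python
import numpy as np


class BatchGDLogisticRegression:
    """Hypothetical class: binary logistic regression trained with
    full-batch gradient descent (one weight update per pass over the data)."""

    def __init__(self, lr=0.5, n_epochs=500):
        self.lr = lr
        self.n_epochs = n_epochs

    @staticmethod
    def _sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)  # labels are assumed to be 0/1
        n_samples, n_features = X.shape
        self.coef_ = np.zeros(n_features)
        self.intercept_ = 0.0
        for _ in range(self.n_epochs):
            p = self._sigmoid(X @ self.coef_ + self.intercept_)
            error = p - y
            # Gradient of the mean log-loss over the WHOLE training set,
            # followed by a single update: this is batch GD.
            self.coef_ -= self.lr * (X.T @ error) / n_samples
            self.intercept_ -= self.lr * error.mean()
        return self

    def predict_proba(self, X):
        p1 = self._sigmoid(np.asarray(X, dtype=float) @ self.coef_ + self.intercept_)
        return np.column_stack([1.0 - p1, p1])

    def predict(self, X):
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)


# Usage on toy, linearly separable data
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
clf = BatchGDLogisticRegression().fit(X, y)
print((clf.predict(X) == y).mean())  # training accuracy
```

For mini-batch GD, the same update would be applied to slices of X and y inside each epoch instead of to the full arrays.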
