sklearn partial_fit() not giving results as accurate as fit()

Problem description

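I am training on three lists of data: L1, L2 and L3. First I train on all of them at once with SGDClassifier fit(), and then instance by instance with partial_fit(). I test with L4 and L5. (The lists contain image data, and the L4 and L5 images are the same as L2, so their vectors are close to L2.)

The predictions from fit() are correct, and that is what I expect from partial_fit() as well. However, the output of the code below shows that the two behave differently, even after 10,000 partial_fit() iterations.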

Output:

fit
[1] #Tested L1. Predicts label as 1 correctly
[2] #Tested L2. Predicts label as 2 correctly
[3] #Tested L3. Predicts label as 3 correctly
[2] #Tested L4. Predicts label as 2 correctly [Data close to L2]
[2] #Tested L5. Predicts label as 2 correctly [Data close to L2]
partial_fit
[3] #Tested L1. Predicts label as 3 incorrectly
[3] #Tested L2. Predicts label as 3 incorrectly
[3] #Tested L3. Predicts label as 3 incorrectly
[3] #Tested L4. Predicts label as 3 incorrectly 
[3] #Tested L5. Predicts label as 3 incorrectly 

Code:

from sklearn import linear_model
import numpy as np

L1 = [-1.98257446e-01,  1.02612168e-01,  1.06458694e-01, -4.44016755e-02,
       -1.25126377e-01, -1.03119195e-01, -1.89867821e-02, -5.70720285e-02,
        1.65993825e-01, -4.91751768e-02,  1.35020703e-01,  5.58929071e-02,
       -1.79934561e-01, -1.61055699e-02, -3.67883481e-02,  7.28202313e-02,
       -8.59514326e-02, -1.19364798e-01, -6.03461489e-02, -9.60081592e-02,
        9.60884690e-02,  7.37309158e-02, -4.95407730e-02, -2.30211094e-02,
       -1.59170195e-01, -3.23998809e-01, -8.31042454e-02, -7.68149048e-02,
        3.26708518e-03, -5.57898730e-02,  3.65743786e-02,  3.37894261e-02,
       -1.61165833e-01, -9.21991318e-02,  3.83259654e-02,  1.30853474e-01,
        2.16114409e-02,  1.56024918e-02,  1.63483590e-01,  3.55638564e-04,
       -1.01068482e-01,  3.11988778e-02,  2.79297493e-02,  3.43645960e-01,
        7.68225491e-02,  7.39665255e-02,  9.03626233e-02, -4.77984771e-02,
        1.46613032e-01, -2.24640951e-01,  9.37603638e-02,  1.30618230e-01,
        5.41394278e-02,  3.57956365e-02,  9.59608406e-02, -1.01410612e-01,
        1.15592867e-01,  7.47590065e-02, -2.77784020e-01,  1.61038041e-01,
        2.08325848e-01, -1.48789823e-01, -9.12107825e-02, -2.09741015e-02,
        2.12046385e-01,  4.47734147e-02, -8.59520137e-02, -8.20810571e-02,
        1.37491941e-01, -1.57671914e-01, -1.28236525e-02, -2.89905779e-02,
       -9.23343226e-02, -1.41179219e-01, -2.73343533e-01,  8.64235312e-02,
        4.51376319e-01,  2.13798493e-01, -1.68360874e-01,  7.94294775e-02,
       -1.16615891e-01,  4.44242992e-02,  1.32415727e-01, -1.00808069e-02,
       -7.62857720e-02,  4.50578667e-02, -1.62037611e-01,  8.80152583e-02,
        2.10405558e-01,  5.48043177e-02, -2.42764503e-03,  2.23779172e-01,
        1.04215354e-01,  6.21869229e-03,  4.02947590e-02,  1.28729194e-02,
       -1.31998569e-01, -8.53061676e-02, -7.21085370e-02,  3.05483658e-02,
        7.17334375e-02, -1.21093884e-01,  4.04045768e-02,  8.53371918e-02,
       -1.82588950e-01,  1.95098877e-01, -3.77971642e-02,  2.39514187e-02,
       -6.40425161e-02,  2.60147993e-02, -1.23514839e-01, -5.75782135e-02,
        1.23560801e-01, -1.81436151e-01,  1.73729539e-01,  1.55140847e-01,
        9.45670251e-03,  1.76663831e-01,  4.24060002e-02,  5.23296222e-02,
       -2.61488743e-02, -1.90883875e-04, -1.07142523e-01, -1.19456224e-01,
       -4.72589768e-03, -1.22928023e-02,  1.22105561e-01,  1.08871996e-01]

L2 = [-0.13126934,  0.04299157,  0.03283413, -0.07268133, -0.0575216 ,
       -0.05970731, -0.04122763, -0.12341423,  0.23687837, -0.19369504,
        0.18289158, -0.02773106, -0.17346333, -0.03682114, -0.01798879,
        0.12592959, -0.13210742, -0.14877586, -0.03237661, -0.08512233,
        0.03863079, -0.0244094 ,  0.03298262,  0.07976148, -0.14883795,
       -0.41100848, -0.17795764, -0.08934171,  0.00651174, -0.0744134 ,
        0.0313075 ,  0.08470915, -0.18205762, -0.01133199, -0.0155912 ,
        0.11513804,  0.00782543, -0.05359597,  0.18193047, -0.00212595,
       -0.20811354, -0.16053183,  0.05181924,  0.23603486,  0.10422225,
        0.02778829,  0.05380247, -0.04042226,  0.0341601 , -0.17557909,
        0.05018872,  0.11027649,  0.05657898,  0.02233699,  0.08839077,
       -0.15501094,  0.01485735,  0.04386368, -0.11386063, -0.01646214,
        0.00378657, -0.10775882, -0.12292566, -0.02450235,  0.25261074,
        0.14213347, -0.09663931, -0.11174012,  0.22364001, -0.17145677,
       -0.00569641,  0.02280853, -0.12527066, -0.18559724, -0.29374081,
       -0.00162096,  0.42862758,  0.12023295, -0.12319036,  0.10102081,
       -0.05752999, -0.02222615,  0.04897028,  0.1726429 , -0.09291326,
        0.12992594, -0.05943635,  0.1127295 ,  0.13184965, -0.02819252,
       -0.02569888,  0.13797338, -0.05463714,  0.07084383,  0.03620753,
        0.02154547, -0.09113872, -0.00730729, -0.11946794, -0.00743609,
        0.13593611,  0.01564942, -0.02297226,  0.11888021, -0.18092889,
        0.11661324,  0.02172676, -0.09794122,  0.01236411,  0.0558071 ,
       -0.1001874 , -0.1216456 ,  0.13321149, -0.22005031,  0.08024856,
        0.19123463, -0.06378062,  0.2226923 ,  0.07309284,  0.11730921,
        0.0262427 , -0.03699137, -0.1887596 , -0.02048384,  0.04079603,
       -0.02144403,  0.00859149, -0.01283618]


L3 = [-1.39073551e-01,  5.75132817e-02,  1.06875971e-01, -4.47942242e-02,
        6.49299771e-02, -8.30453411e-02,  3.50628048e-02, -4.86568436e-02,
        1.11577645e-01, -9.53562111e-02,  2.84853131e-01, -5.57231307e-02,
       -2.10671812e-01, -1.03007048e-01,  1.96518339e-02,  7.77831525e-02,
       -7.90358335e-02, -3.00030578e-02, -7.82457143e-02, -1.04805976e-01,
        8.18016306e-02,  6.47072643e-02,  1.21586584e-02,  8.08022916e-04,
       -8.00280571e-02, -3.14502358e-01, -1.17208570e-01, -9.81831551e-02,
        2.68037282e-02, -1.33987337e-01,  1.33101437e-02,  2.91747972e-02,
       -1.87404498e-01, -5.92408441e-02, -7.84080178e-02,  1.05799856e-02,
       -6.32970333e-02, -2.37192065e-02,  1.31071255e-01,  5.25641590e-02,
       -8.04402679e-02, -9.32691842e-02, -2.31102034e-02,  2.82592803e-01,
        1.47951603e-01,  8.49031657e-03, -6.55979887e-02, -1.86005980e-03,
        2.86830403e-03, -2.48319194e-01, -5.38104884e-02,  1.02639243e-01,
        5.23314849e-02,  7.83263296e-02,  7.35125244e-02, -5.58062941e-02,
        3.26449387e-02, -2.09478531e-02, -1.95044577e-01,  9.34160873e-03,
       -2.26898044e-02, -8.78838003e-02, -6.57741576e-02, -2.00360566e-02,
        1.71352893e-01,  6.89927936e-02, -7.95211121e-02, -8.00146461e-02,
        1.32486463e-01, -1.35504007e-01,  2.61258446e-02,  1.05848603e-01,
       -9.21048969e-02, -1.80963904e-01, -1.98812112e-01,  7.26982281e-02,
        3.29640329e-01,  1.04015507e-01, -1.24389552e-01,  2.69887168e-02,
       -1.54598460e-01, -5.56088090e-02,  1.01781934e-01, -3.85247841e-02,
       -3.20458487e-02,  3.86849903e-02, -8.98609757e-02,  8.27674717e-02,
        1.06020764e-01, -7.34615028e-02, -4.03962284e-02,  1.98970288e-01,
       -5.60568720e-02,  5.78189567e-02,  4.93795872e-02, -2.47523189e-04,
       -6.07730448e-02,  2.19929889e-02, -1.10751927e-01,  6.69334084e-04,
        8.69397819e-02, -1.09967209e-01,  1.43145397e-03,  8.74901861e-02,
       -1.14516295e-01,  1.38158470e-01,  7.43495077e-02, -3.98697220e-02,
        3.39040905e-02,  2.46684682e-02, -1.51388928e-01, -7.87943155e-02,
        1.09218210e-01, -2.05471277e-01,  1.49658069e-01,  1.86885983e-01,
       -3.31082232e-02,  1.01324990e-01,  3.32798958e-02,  5.33202365e-02,
       -6.65426776e-02, -2.35776380e-02, -1.32266074e-01, -2.31741816e-02,
        3.98471728e-02,  4.69821505e-02, -2.74340808e-02, -5.45420833e-02]

L4 = [-9.80433971e-02, -7.03648664e-03, -8.67843628e-04, -1.18527517e-01,
       -5.99347353e-02, -3.52256261e-02, -4.00453769e-02, -9.58476141e-02,
        2.23521233e-01, -1.88561112e-01,  1.72594860e-01, -4.11576033e-02,
       -1.52830154e-01, -5.84353730e-02, -4.33000550e-03,  1.20912530e-01,
       -1.34689406e-01, -1.79964483e-01, -3.15833911e-02, -9.25036967e-02,
       -1.05666816e-02, -4.42105718e-03,  2.60549188e-02,  9.88835841e-02,
       -1.62467003e-01, -4.19883490e-01, -1.71131760e-01, -9.64985639e-02,
       -1.19223613e-02, -9.55987573e-02,  2.25513764e-02,  1.07761353e-01,
       -2.36451998e-01, -1.74359381e-02,  5.71147725e-03,  1.24660656e-01,
        6.69890456e-03, -1.86523274e-02,  1.85175732e-01,  2.91687660e-02,
       -2.09594339e-01, -1.34366542e-01,  4.75538447e-02,  2.49922469e-01,
        1.22993328e-01,  2.24278457e-02,  1.52391801e-02, -1.24563389e-02,
        4.96755280e-02, -1.92227215e-01,  9.83141586e-02,  1.23155341e-01,
        3.48911509e-02,  1.25203300e-02,  6.06377572e-02, -1.32613182e-01,
       -5.22616133e-03,  7.46049434e-02, -1.53830111e-01,  4.96822223e-03,
       -6.75934367e-03, -9.12150443e-02, -1.03079259e-01, -2.60316133e-02,
        2.52563179e-01,  1.48371726e-01, -9.73276347e-02, -1.42138824e-01,
        2.50091761e-01, -1.66190103e-01,  1.91132445e-02,  3.98359001e-02,
       -1.27865523e-01, -1.90915748e-01, -2.90090829e-01,  2.87051760e-02,
        4.39558297e-01,  1.14880979e-01, -1.23038329e-01,  1.02565333e-01,
       -6.96414784e-02, -4.86778058e-02,  3.95676941e-02,  1.31223276e-01,
       -7.37062097e-02,  1.40905678e-01, -4.61848751e-02,  1.32415891e-01,
        1.50173992e-01,  1.56789012e-02, -6.01302609e-02,  1.37784094e-01,
       -8.30642357e-02,  7.05572739e-02,  8.34304839e-02,  4.12208587e-02,
       -8.44793320e-02, -2.76077650e-02, -1.74217999e-01, -7.80004263e-03,
        7.51234069e-02, -2.18363479e-04, -4.15662788e-02,  1.44352645e-01,
       -1.46695063e-01,  1.61359623e-01,  2.00959761e-02, -1.15739897e-01,
       -4.57503423e-02,  8.08721706e-02, -1.02865808e-01, -1.25917166e-01,
        1.34963557e-01, -2.33383894e-01,  1.03095181e-01,  1.53916180e-01,
       -2.00787671e-02,  2.26398230e-01,  5.59305362e-02,  9.53603685e-02,
        1.47923566e-02, -5.58686256e-02, -2.01987177e-01, -2.75421105e-02,
        4.75574993e-02, -1.08102616e-02,  5.95078953e-02,  1.26588587e-02]#Close to L2

L5 = [-0.09945749, -0.00729111,  0.0092897 , -0.13243762, -0.06422047,
       -0.02094417, -0.04948308, -0.12064691,  0.25643739, -0.19205171,
        0.15657693, -0.03121898, -0.15308823, -0.02828152, -0.00710347,
        0.11809425, -0.14299625, -0.16806611, -0.03130123, -0.08865803,
       -0.0071869 , -0.00937061,  0.06185013,  0.10348818, -0.18077886,
       -0.43158019, -0.17442586, -0.08369756,  0.00713679, -0.08146362,
       -0.00203652,  0.09452251, -0.24805595, -0.02332739, -0.00440642,
        0.13737108,  0.00089538, -0.04461086,  0.17354517,  0.02099614,
       -0.22964232, -0.14414147,  0.07377731,  0.21512158,  0.12966961,
        0.03000744,  0.01046804, -0.0051102 ,  0.04499209, -0.1823051 ,
        0.07896246,  0.11629909,  0.02137423,  0.02415319,  0.06205415,
       -0.12419473,  0.01515957,  0.06340452, -0.1500473 , -0.01087676,
        0.02246305, -0.0924818 , -0.09429674, -0.01974701,  0.25166726,
        0.16988155, -0.09064031, -0.15273461,  0.21510246, -0.17729256,
        0.00261592,  0.02652721, -0.13491498, -0.17640282, -0.31118405,
       -0.00512062,  0.41723928,  0.13354909, -0.09930452,  0.10033775,
       -0.06307391, -0.02699157,  0.04080637,  0.13098213, -0.08033849,
        0.16044492, -0.04734115,  0.12942326,  0.14534265,  0.0249849 ,
       -0.06554834,  0.13151604, -0.07915305,  0.08410332,  0.07018198,
        0.06627715, -0.11851253, -0.02576792, -0.18880717, -0.00411349,
        0.08233207,  0.04832725, -0.01709246,  0.15401676, -0.15097997,
        0.16647491,  0.01185772, -0.11977788, -0.02823763,  0.08750527,
       -0.10837749, -0.12731393,  0.11664411, -0.22722226,  0.09817819,
        0.16637388, -0.01940754,  0.21179773,  0.06896579,  0.0847318 ,
        0.00796246, -0.01696757, -0.19169487, -0.03898101,  0.0400917 ,
       -0.03423833,  0.08150289,  0.0139573 ]#Close to L2


sgd_clf = linear_model.SGDClassifier(loss="modified_huber", max_iter=100)
classes = np.arange(5)  # every label partial_fit() may see; required on the first call

# Reference model: train on all three samples at once with fit()
sgd_clf_fit = linear_model.SGDClassifier(loss="modified_huber", max_iter=100)
sgd_clf_fit.fit([L1, L2, L3], [1, 2, 3])

print("fit")
print(sgd_clf_fit.predict([L1]))
print(sgd_clf_fit.predict([L2]))
print(sgd_clf_fit.predict([L3]))
print(sgd_clf_fit.predict([L4]))
print(sgd_clf_fit.predict([L5]))

# Incremental model: 10,000 single-sample updates on L1, then on L2, then on L3
idx1 = 1
for i in range(10000):
    sgd_clf.partial_fit([L1], [idx1], classes=classes)

idx2 = 2
for i in range(10000):
    sgd_clf.partial_fit([L2], [idx2])

idx3 = 3
for i in range(10000):
    sgd_clf.partial_fit([L3], [idx3])

print("partial_fit")
print(sgd_clf.predict([L1]))
print(sgd_clf.predict([L2]))
print(sgd_clf.predict([L3]))
print(sgd_clf.predict([L4]))
print(sgd_clf.predict([L5]))

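How can I improve my partial_fit() predictions so that they match fit()? I want to learn instance by instance and still predict accurately. I tried different numbers of iterations, but it didn't help.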

Solution

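When instantiating the classifier, you may need to set warm_start=True and combine fit and partial_fit. (Strictly, warm_start only changes repeated calls to fit(), which then start from the previous coefficients instead of from scratch; partial_fit() always continues from the current model.)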
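To see what warm_start actually changes, here is a small sketch on synthetic data (the random X, the sizes and max_iter=5 are assumptions for illustration, not taken from the question). It reads t_, the counter of weight updates that SGDClassifier uses for its learning-rate schedule: partial_fit() keeps counting from where the model left off, whereas fit() restarts the counter even when warm_start=True makes it reuse the previous coefficients.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 4))        # synthetic stand-in for the image vectors
y = np.repeat([1, 2, 3], 10)        # 10 samples per class

# tol=None disables early stopping, so every fit() runs exactly max_iter epochs
clf = SGDClassifier(loss="modified_huber", max_iter=5, tol=None,
                    warm_start=True, random_state=0)

clf.fit(X, y)
t_after_fit = clf.t_                # update counter after 5 epochs over 30 samples

clf.partial_fit(X, y)               # one more epoch, continuing from the current state
t_after_partial = clf.t_            # the counter keeps growing

clf.fit(X, y)                       # warm_start reuses coef_, but the counter restarts
t_after_refit = clf.t_

print(t_after_fit, t_after_partial, t_after_refit)
```

Because the learning rate depends on this counter, a model trained with partial_fit() does not follow the same schedule as one trained with a single fit() call, which is one more reason the two can end up with different coefficients.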

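You should also collect the incoming data and labels (classes) in a buffer and train on that. A buffer improves the fit (instead of fitting one sample per call) and prevents class imbalance (every batch adds the same number of samples from each class). In the question's code each loop trains on a single class only, so the final 10,000 updates, all on L3, most likely pull every prediction towards class 3. Like this: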


from sklearn import linear_model
import numpy as np

L1 = [-1.98257446e-01,  1.02612168e-01,  1.06458694e-01, -4.44016755e-02,
       -1.25126377e-01, -1.03119195e-01, -1.89867821e-02, -5.70720285e-02,
        1.65993825e-01, -4.91751768e-02,  1.35020703e-01,  5.58929071e-02,
       -1.79934561e-01, -1.61055699e-02, -3.67883481e-02,  7.28202313e-02,
       -8.59514326e-02, -1.19364798e-01, -6.03461489e-02, -9.60081592e-02,
        9.60884690e-02,  7.37309158e-02, -4.95407730e-02, -2.30211094e-02,
       -1.59170195e-01, -3.23998809e-01, -8.31042454e-02, -7.68149048e-02,
        3.26708518e-03, -5.57898730e-02,  3.65743786e-02,  3.37894261e-02,
       -1.61165833e-01, -9.21991318e-02,  3.83259654e-02,  1.30853474e-01,
        2.16114409e-02,  1.56024918e-02,  1.63483590e-01,  3.55638564e-04,
       -1.01068482e-01,  3.11988778e-02,  2.79297493e-02,  3.43645960e-01,
        7.68225491e-02,  7.39665255e-02,  9.03626233e-02, -4.77984771e-02,
        1.46613032e-01, -2.24640951e-01,  9.37603638e-02,  1.30618230e-01,
        5.41394278e-02,  3.57956365e-02,  9.59608406e-02, -1.01410612e-01,
        1.15592867e-01,  7.47590065e-02, -2.77784020e-01,  1.61038041e-01,
        2.08325848e-01, -1.48789823e-01, -9.12107825e-02, -2.09741015e-02,
        2.12046385e-01,  4.47734147e-02, -8.59520137e-02, -8.20810571e-02,
        1.37491941e-01, -1.57671914e-01, -1.28236525e-02, -2.89905779e-02,
       -9.23343226e-02, -1.41179219e-01, -2.73343533e-01,  8.64235312e-02,
        4.51376319e-01,  2.13798493e-01, -1.68360874e-01,  7.94294775e-02,
       -1.16615891e-01,  4.44242992e-02,  1.32415727e-01, -1.00808069e-02,
       -7.62857720e-02,  4.50578667e-02, -1.62037611e-01,  8.80152583e-02,
        2.10405558e-01,  5.48043177e-02, -2.42764503e-03,  2.23779172e-01,
        1.04215354e-01,  6.21869229e-03,  4.02947590e-02,  1.28729194e-02,
       -1.31998569e-01, -8.53061676e-02, -7.21085370e-02,  3.05483658e-02,
        7.17334375e-02, -1.21093884e-01,  4.04045768e-02,  8.53371918e-02,
       -1.82588950e-01,  1.95098877e-01, -3.77971642e-02,  2.39514187e-02,
       -6.40425161e-02,  2.60147993e-02, -1.23514839e-01, -5.75782135e-02,
        1.23560801e-01, -1.81436151e-01,  1.73729539e-01,  1.55140847e-01,
        9.45670251e-03,  1.76663831e-01,  4.24060002e-02,  5.23296222e-02,
       -2.61488743e-02, -1.90883875e-04, -1.07142523e-01, -1.19456224e-01,
       -4.72589768e-03, -1.22928023e-02,  1.22105561e-01,  1.08871996e-01]

L2 = [-0.13126934,  0.04299157,  0.03283413, -0.07268133, -0.0575216 ,
       -0.05970731, -0.04122763, -0.12341423,  0.23687837, -0.19369504,
        0.18289158, -0.02773106, -0.17346333, -0.03682114, -0.01798879,
        0.12592959, -0.13210742, -0.14877586, -0.03237661, -0.08512233,
        0.03863079, -0.0244094 ,  0.03298262,  0.07976148, -0.14883795,
       -0.41100848, -0.17795764, -0.08934171,  0.00651174, -0.0744134 ,
        0.0313075 ,  0.08470915, -0.18205762, -0.01133199, -0.0155912 ,
        0.11513804,  0.00782543, -0.05359597,  0.18193047, -0.00212595,
       -0.20811354, -0.16053183,  0.05181924,  0.23603486,  0.10422225,
        0.02778829,  0.05380247, -0.04042226,  0.0341601 , -0.17557909,
        0.05018872,  0.11027649,  0.05657898,  0.02233699,  0.08839077,
       -0.15501094,  0.01485735,  0.04386368, -0.11386063, -0.01646214,
        0.00378657, -0.10775882, -0.12292566, -0.02450235,  0.25261074,
        0.14213347, -0.09663931, -0.11174012,  0.22364001, -0.17145677,
       -0.00569641,  0.02280853, -0.12527066, -0.18559724, -0.29374081,
       -0.00162096,  0.42862758,  0.12023295, -0.12319036,  0.10102081,
       -0.05752999, -0.02222615,  0.04897028,  0.1726429 , -0.09291326,
        0.12992594, -0.05943635,  0.1127295 ,  0.13184965, -0.02819252,
       -0.02569888,  0.13797338, -0.05463714,  0.07084383,  0.03620753,
        0.02154547, -0.09113872, -0.00730729, -0.11946794, -0.00743609,
        0.13593611,  0.01564942, -0.02297226,  0.11888021, -0.18092889,
        0.11661324,  0.02172676, -0.09794122,  0.01236411,  0.0558071 ,
       -0.1001874 , -0.1216456 ,  0.13321149, -0.22005031,  0.08024856,
        0.19123463, -0.06378062,  0.2226923 ,  0.07309284,  0.11730921,
        0.0262427 , -0.03699137, -0.1887596 , -0.02048384,  0.04079603,
       -0.02144403,  0.00859149, -0.01283618]


L3 = [-1.39073551e-01,  5.75132817e-02,  1.06875971e-01, -4.47942242e-02,
        6.49299771e-02, -8.30453411e-02,  3.50628048e-02, -4.86568436e-02,
        1.11577645e-01, -9.53562111e-02,  2.84853131e-01, -5.57231307e-02,
       -2.10671812e-01, -1.03007048e-01,  1.96518339e-02,  7.77831525e-02,
       -7.90358335e-02, -3.00030578e-02, -7.82457143e-02, -1.04805976e-01,
        8.18016306e-02,  6.47072643e-02,  1.21586584e-02,  8.08022916e-04,
       -8.00280571e-02, -3.14502358e-01, -1.17208570e-01, -9.81831551e-02,
        2.68037282e-02, -1.33987337e-01,  1.33101437e-02,  2.91747972e-02,
       -1.87404498e-01, -5.92408441e-02, -7.84080178e-02,  1.05799856e-02,
       -6.32970333e-02, -2.37192065e-02,  1.31071255e-01,  5.25641590e-02,
       -8.04402679e-02, -9.32691842e-02, -2.31102034e-02,  2.82592803e-01,
        1.47951603e-01,  8.49031657e-03, -6.55979887e-02, -1.86005980e-03,
        2.86830403e-03, -2.48319194e-01, -5.38104884e-02,  1.02639243e-01,
        5.23314849e-02,  7.83263296e-02,  7.35125244e-02, -5.58062941e-02,
        3.26449387e-02, -2.09478531e-02, -1.95044577e-01,  9.34160873e-03,
       -2.26898044e-02, -8.78838003e-02, -6.57741576e-02, -2.00360566e-02,
        1.71352893e-01,  6.89927936e-02, -7.95211121e-02, -8.00146461e-02,
        1.32486463e-01, -1.35504007e-01,  2.61258446e-02,  1.05848603e-01,
       -9.21048969e-02, -1.80963904e-01, -1.98812112e-01,  7.26982281e-02,
        3.29640329e-01,  1.04015507e-01, -1.24389552e-01,  2.69887168e-02,
       -1.54598460e-01, -5.56088090e-02,  1.01781934e-01, -3.85247841e-02,
       -3.20458487e-02,  3.86849903e-02, -8.98609757e-02,  8.27674717e-02,
        1.06020764e-01, -7.34615028e-02, -4.03962284e-02,  1.98970288e-01,
       -5.60568720e-02,  5.78189567e-02,  4.93795872e-02, -2.47523189e-04,
       -6.07730448e-02,  2.19929889e-02, -1.10751927e-01,  6.69334084e-04,
        8.69397819e-02, -1.09967209e-01,  1.43145397e-03,  8.74901861e-02,
       -1.14516295e-01,  1.38158470e-01,  7.43495077e-02, -3.98697220e-02,
        3.39040905e-02,  2.46684682e-02, -1.51388928e-01, -7.87943155e-02,
        1.09218210e-01, -2.05471277e-01,  1.49658069e-01,  1.86885983e-01,
       -3.31082232e-02,  1.01324990e-01,  3.32798958e-02,  5.33202365e-02,
       -6.65426776e-02, -2.35776380e-02, -1.32266074e-01, -2.31741816e-02,
        3.98471728e-02,  4.69821505e-02, -2.74340808e-02, -5.45420833e-02]

L4 = [-9.80433971e-02, -7.03648664e-03, -8.67843628e-04, -1.18527517e-01,
       -5.99347353e-02, -3.52256261e-02, -4.00453769e-02, -9.58476141e-02,
        2.23521233e-01, -1.88561112e-01,  1.72594860e-01, -4.11576033e-02,
       -1.52830154e-01, -5.84353730e-02, -4.33000550e-03,  1.20912530e-01,
       -1.34689406e-01, -1.79964483e-01, -3.15833911e-02, -9.25036967e-02,
       -1.05666816e-02, -4.42105718e-03,  2.60549188e-02,  9.88835841e-02,
       -1.62467003e-01, -4.19883490e-01, -1.71131760e-01, -9.64985639e-02,
       -1.19223613e-02, -9.55987573e-02,  2.25513764e-02,  1.07761353e-01,
       -2.36451998e-01, -1.74359381e-02,  5.71147725e-03,  1.24660656e-01,
        6.69890456e-03, -1.86523274e-02,  1.85175732e-01,  2.91687660e-02,
       -2.09594339e-01, -1.34366542e-01,  4.75538447e-02,  2.49922469e-01,
        1.22993328e-01,  2.24278457e-02,  1.52391801e-02, -1.24563389e-02,
        4.96755280e-02, -1.92227215e-01,  9.83141586e-02,  1.23155341e-01,
        3.48911509e-02,  1.25203300e-02,  6.06377572e-02, -1.32613182e-01,
       -5.22616133e-03,  7.46049434e-02, -1.53830111e-01,  4.96822223e-03,
       -6.75934367e-03, -9.12150443e-02, -1.03079259e-01, -2.60316133e-02,
        2.52563179e-01,  1.48371726e-01, -9.73276347e-02, -1.42138824e-01,
        2.50091761e-01, -1.66190103e-01,  1.91132445e-02,  3.98359001e-02,
       -1.27865523e-01, -1.90915748e-01, -2.90090829e-01,  2.87051760e-02,
        4.39558297e-01,  1.14880979e-01, -1.23038329e-01,  1.02565333e-01,
       -6.96414784e-02, -4.86778058e-02,  3.95676941e-02,  1.31223276e-01,
       -7.37062097e-02,  1.40905678e-01, -4.61848751e-02,  1.32415891e-01,
        1.50173992e-01,  1.56789012e-02, -6.01302609e-02,  1.37784094e-01,
       -8.30642357e-02,  7.05572739e-02,  8.34304839e-02,  4.12208587e-02,
       -8.44793320e-02, -2.76077650e-02, -1.74217999e-01, -7.80004263e-03,
        7.51234069e-02, -2.18363479e-04, -4.15662788e-02,  1.44352645e-01,
       -1.46695063e-01,  1.61359623e-01,  2.00959761e-02, -1.15739897e-01,
       -4.57503423e-02,  8.08721706e-02, -1.02865808e-01, -1.25917166e-01,
        1.34963557e-01, -2.33383894e-01,  1.03095181e-01,  1.53916180e-01,
       -2.00787671e-02,  2.26398230e-01,  5.59305362e-02,  9.53603685e-02,
        1.47923566e-02, -5.58686256e-02, -2.01987177e-01, -2.75421105e-02,
        4.75574993e-02, -1.08102616e-02,  5.95078953e-02,  1.26588587e-02] # Close to L2

L5 = [-0.09945749, -0.00729111,  0.0092897 , -0.13243762, -0.06422047,
       -0.02094417, -0.04948308, -0.12064691,  0.25643739, -0.19205171,
        0.15657693, -0.03121898, -0.15308823, -0.02828152, -0.00710347,
        0.11809425, -0.14299625, -0.16806611, -0.03130123, -0.08865803,
       -0.0071869 , -0.00937061,  0.06185013,  0.10348818, -0.18077886,
       -0.43158019, -0.17442586, -0.08369756,  0.00713679, -0.08146362,
       -0.00203652,  0.09452251, -0.24805595, -0.02332739, -0.00440642,
        0.13737108,  0.00089538, -0.04461086,  0.17354517,  0.02099614,
       -0.22964232, -0.14414147,  0.07377731,  0.21512158,  0.12966961,
        0.03000744,  0.01046804, -0.0051102 ,  0.04499209, -0.1823051 ,
        0.07896246,  0.11629909,  0.02137423,  0.02415319,  0.06205415,
       -0.12419473,  0.01515957,  0.06340452, -0.1500473 , -0.01087676,
        0.02246305, -0.0924818 , -0.09429674, -0.01974701,  0.25166726,
        0.16988155, -0.09064031, -0.15273461,  0.21510246, -0.17729256,
        0.00261592,  0.02652721, -0.13491498, -0.17640282, -0.31118405,
       -0.00512062,  0.41723928,  0.13354909, -0.09930452,  0.10033775,
       -0.06307391, -0.02699157,  0.04080637,  0.13098213, -0.08033849,
        0.16044492, -0.04734115,  0.12942326,  0.14534265,  0.0249849 ,
       -0.06554834,  0.13151604, -0.07915305,  0.08410332,  0.07018198,
        0.06627715, -0.11851253, -0.02576792, -0.18880717, -0.00411349,
        0.08233207,  0.04832725, -0.01709246,  0.15401676, -0.15097997,
        0.16647491,  0.01185772, -0.11977788, -0.02823763,  0.08750527,
       -0.10837749, -0.12731393,  0.11664411, -0.22722226,  0.09817819,
        0.16637388, -0.01940754,  0.21179773,  0.06896579,  0.0847318 ,
        0.00796246, -0.01696757, -0.19169487, -0.03898101,  0.0400917 ,
       -0.03423833,  0.08150289,  0.0139573] # Close to L2


init_data = np.array([L1, L2, L3])
init_classes = np.array([1, 2, 3], dtype=np.uint8)

# warm_start=True: a later fit() call starts from the current coefficients
sgd_clf = linear_model.SGDClassifier(loss="modified_huber", max_iter=100, warm_start=True)

sgd_clf.fit(init_data, init_classes)

print("fit")
print(sgd_clf.predict([L1]))
print(sgd_clf.predict([L2]))
print(sgd_clf.predict([L3]))
print(sgd_clf.predict([L4]))
print(sgd_clf.predict([L5]))

## Create buffer of incoming samples ##
# The same number of copies (1000) of every class, so no class dominates the updates.
# np.repeat keeps the order: 1000 x L1, then 1000 x L2, then 1000 x L3.
data_buffer = np.repeat(init_data, 1000, axis=0)   # shape (3000, n_features)
class_buffer = np.repeat(init_classes, 1000)       # shape (3000,)

## Fit the buffered data in a single partial_fit call
print(data_buffer.shape)
print(class_buffer.shape)
sgd_clf.partial_fit(data_buffer, class_buffer)

print("partial_fit")
print(sgd_clf.predict([L1]))
print(sgd_clf.predict([L2]))
print(sgd_clf.predict([L3]))
print(sgd_clf.predict([L4]))
print(sgd_clf.predict([L5]))
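The code above fills the buffer in one go from data that is already available. The question, however, asks about learning instance by instance. Below is a minimal sketch of the same buffering idea for a stream, under these assumptions: BalancedBuffer is a hypothetical helper (not a scikit-learn class), per_class=10 is an arbitrary batch size, and the three-cluster data are synthetic, not the question's vectors.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier


class BalancedBuffer:
    """Hypothetical helper (not part of scikit-learn): stores incoming samples
    per class and releases a batch once every class has `per_class` samples."""

    def __init__(self, classes, per_class):
        self.classes = [int(c) for c in classes]
        self.per_class = per_class
        self.store = {c: [] for c in self.classes}

    def add(self, x, label):
        self.store[int(label)].append(np.asarray(x, dtype=float))

    def ready(self):
        return all(len(rows) >= self.per_class for rows in self.store.values())

    def pop_batch(self, rng):
        X, y = [], []
        for c in self.classes:
            X.extend(self.store[c][:self.per_class])
            y.extend([c] * self.per_class)
            self.store[c] = self.store[c][self.per_class:]
        X, y = np.array(X), np.array(y)
        order = rng.permutation(len(y))  # mix the classes inside the batch
        return X[order], y[order]


rng = np.random.default_rng(0)

# Synthetic stand-in for the image vectors: three well-separated clusters
centers = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]])
classes = np.array([1, 2, 3])

clf = SGDClassifier(loss="modified_huber", random_state=0)
buffer = BalancedBuffer(classes, per_class=10)
n_batches = 0

for _ in range(600):  # samples arrive one at a time, in random class order
    label = int(rng.choice(classes))
    x = centers[label - 1] + rng.normal(scale=0.5, size=3)
    buffer.add(x, label)
    if buffer.ready():
        X_batch, y_batch = buffer.pop_batch(rng)
        # passing the same `classes` on every call is allowed after the first one
        clf.partial_fit(X_batch, y_batch, classes=classes)
        n_batches += 1

print(n_batches, clf.predict(centers))
```

The trade-off of this design: a balanced batch can only be released once every class has enough samples, so if the stream delivers all samples of one class before any other (as the three loops in the question do), the buffer has to hold them until the other classes arrive.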

        4.39558297e-01,  1.14880979e-01, -1.23038329e-01,  1.02565333e-01,
       -6.96414784e-02, -4.86778058e-02,  3.95676941e-02,  1.31223276e-01,
       -7.37062097e-02,  1.40905678e-01, -4.61848751e-02,  1.32415891e-01,
        1.50173992e-01,  1.56789012e-02, -6.01302609e-02,  1.37784094e-01,
       -8.30642357e-02,  7.05572739e-02,  8.34304839e-02,  4.12208587e-02,
       -8.44793320e-02, -2.76077650e-02, -1.74217999e-01, -7.80004263e-03,
        7.51234069e-02, -2.18363479e-04, -4.15662788e-02,  1.44352645e-01,
       -1.46695063e-01,  1.61359623e-01,  2.00959761e-02, -1.15739897e-01,
       -4.57503423e-02,  8.08721706e-02, -1.02865808e-01, -1.25917166e-01,
        1.34963557e-01, -2.33383894e-01,  1.03095181e-01,  1.53916180e-01,
       -2.00787671e-02,  2.26398230e-01,  5.59305362e-02,  9.53603685e-02,
        1.47923566e-02, -5.58686256e-02, -2.01987177e-01, -2.75421105e-02,
        4.75574993e-02, -1.08102616e-02,  5.95078953e-02,  1.26588587e-02]#Close to L2

L5 = [-0.09945749, -0.00729111,  0.0092897 , -0.13243762, -0.06422047,
       -0.02094417, -0.04948308, -0.12064691,  0.25643739, -0.19205171,
        0.15657693, -0.03121898, -0.15308823, -0.02828152, -0.00710347,
        0.11809425, -0.14299625, -0.16806611, -0.03130123, -0.08865803,
       -0.0071869 , -0.00937061,  0.06185013,  0.10348818, -0.18077886,
       -0.43158019, -0.17442586, -0.08369756,  0.00713679, -0.08146362,
       -0.00203652,  0.09452251, -0.24805595, -0.02332739, -0.00440642,
        0.13737108,  0.00089538, -0.04461086,  0.17354517,  0.02099614,
       -0.22964232, -0.14414147,  0.07377731,  0.21512158,  0.12966961,
        0.03000744,  0.01046804, -0.0051102 ,  0.04499209, -0.1823051 ,
        0.07896246,  0.11629909,  0.02137423,  0.02415319,  0.06205415,
       -0.12419473,  0.01515957,  0.06340452, -0.1500473 , -0.01087676,
        0.02246305, -0.0924818 , -0.09429674, -0.01974701,  0.25166726,
        0.16988155, -0.09064031, -0.15273461,  0.21510246, -0.17729256,
        0.00261592,  0.02652721, -0.13491498, -0.17640282, -0.31118405,
       -0.00512062,  0.41723928,  0.13354909, -0.09930452,  0.10033775,
       -0.06307391, -0.02699157,  0.04080637,  0.13098213, -0.08033849,
        0.16044492, -0.04734115,  0.12942326,  0.14534265,  0.0249849 ,
       -0.06554834,  0.13151604, -0.07915305,  0.08410332,  0.07018198,
        0.06627715, -0.11851253, -0.02576792, -0.18880717, -0.00411349,
        0.08233207,  0.04832725, -0.01709246,  0.15401676, -0.15097997,
        0.16647491,  0.01185772, -0.11977788, -0.02823763,  0.08750527,
       -0.10837749, -0.12731393,  0.11664411, -0.22722226,  0.09817819,
        0.16637388, -0.01940754,  0.21179773,  0.06896579,  0.0847318 ,
        0.00796246, -0.01696757, -0.19169487, -0.03898101,  0.0400917 ,
       -0.03423833,  0.08150289,  0.0139573 ]#Close to L2


sgd_clf = linear_model.SGDClassifier(loss="modified_huber",max_iter =100)
classes = np.arange(5)

sgd_clf_fit = linear_model.SGDClassifier(loss="modified_huber",max_iter =100)
sgd_clf_fit.fit([L1,L2,L3],[1,2,3])

print("fit")
print(sgd_clf_fit.predict([L1]))
print(sgd_clf_fit.predict([L2]))
print(sgd_clf_fit.predict([L3]))
print(sgd_clf_fit.predict([L4]))
print(sgd_clf_fit.predict([L5]))

idx1 = 1
for i in range(10000):
    sgd_clf.partial_fit([L1], [idx1], classes=classes)

idx2 = 2
for i in range(10000):
    sgd_clf.partial_fit([L2],[idx2])

idx3 = 3
for i in range(10000):
    sgd_clf.partial_fit([L3],[idx3])

print("partial_fit")
print(sgd_clf.predict([L1]))
print(sgd_clf.predict([L2]))
print(sgd_clf.predict([L3]))
print(sgd_clf.predict([L4]))
print(sgd_clf.predict([L5]))

How can I improve the predictions of partial_fit() so that they match fit()? I want to learn instance by instance and still predict accurately. I tried different numbers of iterations, but it did not work.

Solution

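You could combine fit() and partial_fit(): train once with fit() on an initial set that contains every class, then keep updating the model with partial_fit(). Setting warm_start=True when instantiating the classifier makes any later fit() call start from the current coefficients instead of discarding them. partial_fit() always continues from the current coefficients, whatever warm_start is set to.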
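A minimal sketch of that fit-then-partial_fit pattern, using synthetic 2-D clusters instead of the 128-value vectors from the question (the centres, noise level and batch sizes are arbitrary assumptions). After fit(), classes_ is already set, so partial_fit() needs no classes= argument and it updates the existing coef_ rather than re-initialising it.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

rng = np.random.default_rng(0)
# Hypothetical data: three tight clusters, one per label 1, 2, 3.
centers = np.array([[5.0, 0.0], [0.0, 5.0], [-5.0, -5.0]])
labels = np.array([1, 2, 3])


def sample(n_per_class):
    """Draw n_per_class noisy points around each centre, returned in shuffled order."""
    X = np.repeat(centers, n_per_class, axis=0)
    X = X + rng.normal(scale=0.1, size=X.shape)
    y = np.repeat(labels, n_per_class)
    order = rng.permutation(len(y))
    return X[order], y[order]


clf = SGDClassifier(loss="modified_huber", max_iter=100, warm_start=True, random_state=0)

# Initial batch training: fit() sets classes_ and the starting coefficients.
X0, y0 = sample(20)
clf.fit(X0, y0)
coef_after_fit = clf.coef_.copy()

# Later data: partial_fit() continues from the coefficients learned by fit().
X1, y1 = sample(20)
clf.partial_fit(X1, y1)

print(clf.classes_)
print(clf.predict(centers))
```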

You should also buffer the incoming data and labels (classes) instead of calling partial_fit() with one sample at a time. Training on a buffer gives each update several samples rather than one, and it lets you prevent class imbalance by adding the same number of samples of each class to every batch. For example:


from sklearn import linear_model
import numpy as np

L1 = [-1.98257446e-01,  1.02612168e-01,  1.06458694e-01, -4.44016755e-02,
       -1.25126377e-01, -1.03119195e-01, -1.89867821e-02, -5.70720285e-02,
        1.65993825e-01, -4.91751768e-02,  1.35020703e-01,  5.58929071e-02,
       -1.79934561e-01, -1.61055699e-02, -3.67883481e-02,  7.28202313e-02,
       -8.59514326e-02, -1.19364798e-01, -6.03461489e-02, -9.60081592e-02,
        9.60884690e-02,  7.37309158e-02, -4.95407730e-02, -2.30211094e-02,
       -1.59170195e-01, -3.23998809e-01, -8.31042454e-02, -7.68149048e-02,
        3.26708518e-03, -5.57898730e-02,  3.65743786e-02,  3.37894261e-02,
       -1.61165833e-01, -9.21991318e-02,  3.83259654e-02,  1.30853474e-01,
        2.16114409e-02,  1.56024918e-02,  1.63483590e-01,  3.55638564e-04,
       -1.01068482e-01,  3.11988778e-02,  2.79297493e-02,  3.43645960e-01,
        7.68225491e-02,  7.39665255e-02,  9.03626233e-02, -4.77984771e-02,
        1.46613032e-01, -2.24640951e-01,  9.37603638e-02,  1.30618230e-01,
        5.41394278e-02,  3.57956365e-02,  9.59608406e-02, -1.01410612e-01,
        1.15592867e-01,  7.47590065e-02, -2.77784020e-01,  1.61038041e-01,
        2.08325848e-01, -1.48789823e-01, -9.12107825e-02, -2.09741015e-02,
        2.12046385e-01,  4.47734147e-02, -8.59520137e-02, -8.20810571e-02,
        1.37491941e-01, -1.57671914e-01, -1.28236525e-02, -2.89905779e-02,
       -9.23343226e-02, -1.41179219e-01, -2.73343533e-01,  8.64235312e-02,
        4.51376319e-01,  2.13798493e-01, -1.68360874e-01,  7.94294775e-02,
       -1.16615891e-01,  4.44242992e-02,  1.32415727e-01, -1.00808069e-02,
       -7.62857720e-02,  4.50578667e-02, -1.62037611e-01,  8.80152583e-02,
        2.10405558e-01,  5.48043177e-02, -2.42764503e-03,  2.23779172e-01,
        1.04215354e-01,  6.21869229e-03,  4.02947590e-02,  1.28729194e-02,
       -1.31998569e-01, -8.53061676e-02, -7.21085370e-02,  3.05483658e-02,
        7.17334375e-02, -1.21093884e-01,  4.04045768e-02,  8.53371918e-02,
       -1.82588950e-01,  1.95098877e-01, -3.77971642e-02,  2.39514187e-02,
       -6.40425161e-02,  2.60147993e-02, -1.23514839e-01, -5.75782135e-02,
        1.23560801e-01, -1.81436151e-01,  1.73729539e-01,  1.55140847e-01,
        9.45670251e-03,  1.76663831e-01,  4.24060002e-02,  5.23296222e-02,
       -2.61488743e-02, -1.90883875e-04, -1.07142523e-01, -1.19456224e-01,
       -4.72589768e-03, -1.22928023e-02,  1.22105561e-01,  1.08871996e-01]

L2 = [-0.13126934,  0.04299157,  0.03283413, -0.07268133, -0.0575216 ,
       -0.05970731, -0.04122763, -0.12341423,  0.23687837, -0.19369504,
        0.18289158, -0.02773106, -0.17346333, -0.03682114, -0.01798879,
        0.12592959, -0.13210742, -0.14877586, -0.03237661, -0.08512233,
        0.03863079, -0.0244094 ,  0.03298262,  0.07976148, -0.14883795,
       -0.41100848, -0.17795764, -0.08934171,  0.00651174, -0.0744134 ,
        0.0313075 ,  0.08470915, -0.18205762, -0.01133199, -0.0155912 ,
        0.11513804,  0.00782543, -0.05359597,  0.18193047, -0.00212595,
       -0.20811354, -0.16053183,  0.05181924,  0.23603486,  0.10422225,
        0.02778829,  0.05380247, -0.04042226,  0.0341601 , -0.17557909,
        0.05018872,  0.11027649,  0.05657898,  0.02233699,  0.08839077,
       -0.15501094,  0.01485735,  0.04386368, -0.11386063, -0.01646214,
        0.00378657, -0.10775882, -0.12292566, -0.02450235,  0.25261074,
        0.14213347, -0.09663931, -0.11174012,  0.22364001, -0.17145677,
       -0.00569641,  0.02280853, -0.12527066, -0.18559724, -0.29374081,
       -0.00162096,  0.42862758,  0.12023295, -0.12319036,  0.10102081,
       -0.05752999, -0.02222615,  0.04897028,  0.1726429 , -0.09291326,
        0.12992594, -0.05943635,  0.1127295 ,  0.13184965, -0.02819252,
       -0.02569888,  0.13797338, -0.05463714,  0.07084383,  0.03620753,
        0.02154547, -0.09113872, -0.00730729, -0.11946794, -0.00743609,
        0.13593611,  0.01564942, -0.02297226,  0.11888021, -0.18092889,
        0.11661324,  0.02172676, -0.09794122,  0.01236411,  0.0558071 ,
       -0.1001874 , -0.1216456 ,  0.13321149, -0.22005031,  0.08024856,
        0.19123463, -0.06378062,  0.2226923 ,  0.07309284,  0.11730921,
        0.0262427 , -0.03699137, -0.1887596 , -0.02048384,  0.04079603,
       -0.02144403,  0.00859149, -0.01283618]


L3 = [-1.39073551e-01,  5.75132817e-02,  1.06875971e-01, -4.47942242e-02,
        6.49299771e-02, -8.30453411e-02,  3.50628048e-02, -4.86568436e-02,
        1.11577645e-01, -9.53562111e-02,  2.84853131e-01, -5.57231307e-02,
       -2.10671812e-01, -1.03007048e-01,  1.96518339e-02,  7.77831525e-02,
       -7.90358335e-02, -3.00030578e-02, -7.82457143e-02, -1.04805976e-01,
        8.18016306e-02,  6.47072643e-02,  1.21586584e-02,  8.08022916e-04,
       -8.00280571e-02, -3.14502358e-01, -1.17208570e-01, -9.81831551e-02,
        2.68037282e-02, -1.33987337e-01,  1.33101437e-02,  2.91747972e-02,
       -1.87404498e-01, -5.92408441e-02, -7.84080178e-02,  1.05799856e-02,
       -6.32970333e-02, -2.37192065e-02,  1.31071255e-01,  5.25641590e-02,
       -8.04402679e-02, -9.32691842e-02, -2.31102034e-02,  2.82592803e-01,
        1.47951603e-01,  8.49031657e-03, -6.55979887e-02, -1.86005980e-03,
        2.86830403e-03, -2.48319194e-01, -5.38104884e-02,  1.02639243e-01,
        5.23314849e-02,  7.83263296e-02,  7.35125244e-02, -5.58062941e-02,
        3.26449387e-02, -2.09478531e-02, -1.95044577e-01,  9.34160873e-03,
       -2.26898044e-02, -8.78838003e-02, -6.57741576e-02, -2.00360566e-02,
        1.71352893e-01,  6.89927936e-02, -7.95211121e-02, -8.00146461e-02,
        1.32486463e-01, -1.35504007e-01,  2.61258446e-02,  1.05848603e-01,
       -9.21048969e-02, -1.80963904e-01, -1.98812112e-01,  7.26982281e-02,
        3.29640329e-01,  1.04015507e-01, -1.24389552e-01,  2.69887168e-02,
       -1.54598460e-01, -5.56088090e-02,  1.01781934e-01, -3.85247841e-02,
       -3.20458487e-02,  3.86849903e-02, -8.98609757e-02,  8.27674717e-02,
        1.06020764e-01, -7.34615028e-02, -4.03962284e-02,  1.98970288e-01,
       -5.60568720e-02,  5.78189567e-02,  4.93795872e-02, -2.47523189e-04,
       -6.07730448e-02,  2.19929889e-02, -1.10751927e-01,  6.69334084e-04,
        8.69397819e-02, -1.09967209e-01,  1.43145397e-03,  8.74901861e-02,
       -1.14516295e-01,  1.38158470e-01,  7.43495077e-02, -3.98697220e-02,
        3.39040905e-02,  2.46684682e-02, -1.51388928e-01, -7.87943155e-02,
        1.09218210e-01, -2.05471277e-01,  1.49658069e-01,  1.86885983e-01,
       -3.31082232e-02,  1.01324990e-01,  3.32798958e-02,  5.33202365e-02,
       -6.65426776e-02, -2.35776380e-02, -1.32266074e-01, -2.31741816e-02,
        3.98471728e-02,  4.69821505e-02, -2.74340808e-02, -5.45420833e-02]

L4 = [-9.80433971e-02, -7.03648664e-03, -8.67843628e-04, -1.18527517e-01,
       -5.99347353e-02, -3.52256261e-02, -4.00453769e-02, -9.58476141e-02,
        2.23521233e-01, -1.88561112e-01,  1.72594860e-01, -4.11576033e-02,
       -1.52830154e-01, -5.84353730e-02, -4.33000550e-03,  1.20912530e-01,
       -1.34689406e-01, -1.79964483e-01, -3.15833911e-02, -9.25036967e-02,
       -1.05666816e-02, -4.42105718e-03,  2.60549188e-02,  9.88835841e-02,
       -1.62467003e-01, -4.19883490e-01, -1.71131760e-01, -9.64985639e-02,
       -1.19223613e-02, -9.55987573e-02,  2.25513764e-02,  1.07761353e-01,
       -2.36451998e-01, -1.74359381e-02,  5.71147725e-03,  1.24660656e-01,
        6.69890456e-03, -1.86523274e-02,  1.85175732e-01,  2.91687660e-02,
       -2.09594339e-01, -1.34366542e-01,  4.75538447e-02,  2.49922469e-01,
        1.22993328e-01,  2.24278457e-02,  1.52391801e-02, -1.24563389e-02,
        4.96755280e-02, -1.92227215e-01,  9.83141586e-02,  1.23155341e-01,
        3.48911509e-02,  1.25203300e-02,  6.06377572e-02, -1.32613182e-01,
       -5.22616133e-03,  7.46049434e-02, -1.53830111e-01,  4.96822223e-03,
       -6.75934367e-03, -9.12150443e-02, -1.03079259e-01, -2.60316133e-02,
        2.52563179e-01,  1.48371726e-01, -9.73276347e-02, -1.42138824e-01,
        2.50091761e-01, -1.66190103e-01,  1.91132445e-02,  3.98359001e-02,
       -1.27865523e-01, -1.90915748e-01, -2.90090829e-01,  2.87051760e-02,
        4.39558297e-01,  1.14880979e-01, -1.23038329e-01,  1.02565333e-01,
       -6.96414784e-02, -4.86778058e-02,  3.95676941e-02,  1.31223276e-01,
       -7.37062097e-02,  1.40905678e-01, -4.61848751e-02,  1.32415891e-01,
        1.50173992e-01,  1.56789012e-02, -6.01302609e-02,  1.37784094e-01,
       -8.30642357e-02,  7.05572739e-02,  8.34304839e-02,  4.12208587e-02,
       -8.44793320e-02, -2.76077650e-02, -1.74217999e-01, -7.80004263e-03,
        7.51234069e-02, -2.18363479e-04, -4.15662788e-02,  1.44352645e-01,
       -1.46695063e-01,  1.61359623e-01,  2.00959761e-02, -1.15739897e-01,
       -4.57503423e-02,  8.08721706e-02, -1.02865808e-01, -1.25917166e-01,
        1.34963557e-01, -2.33383894e-01,  1.03095181e-01,  1.53916180e-01,
       -2.00787671e-02,  2.26398230e-01,  5.59305362e-02,  9.53603685e-02,
        1.47923566e-02, -5.58686256e-02, -2.01987177e-01, -2.75421105e-02,
        4.75574993e-02, -1.08102616e-02,  5.95078953e-02,  1.26588587e-02] # Close to L2

L5 = [-0.09945749, -0.00729111,  0.0092897 , -0.13243762, -0.06422047,
       -0.02094417, -0.04948308, -0.12064691,  0.25643739, -0.19205171,
        0.15657693, -0.03121898, -0.15308823, -0.02828152, -0.00710347,
        0.11809425, -0.14299625, -0.16806611, -0.03130123, -0.08865803,
       -0.0071869 , -0.00937061,  0.06185013,  0.10348818, -0.18077886,
       -0.43158019, -0.17442586, -0.08369756,  0.00713679, -0.08146362,
       -0.00203652,  0.09452251, -0.24805595, -0.02332739, -0.00440642,
        0.13737108,  0.00089538, -0.04461086,  0.17354517,  0.02099614,
       -0.22964232, -0.14414147,  0.07377731,  0.21512158,  0.12966961,
        0.03000744,  0.01046804, -0.0051102 ,  0.04499209, -0.1823051 ,
        0.07896246,  0.11629909,  0.02137423,  0.02415319,  0.06205415,
       -0.12419473,  0.01515957,  0.06340452, -0.1500473 , -0.01087676,
        0.02246305, -0.0924818 , -0.09429674, -0.01974701,  0.25166726,
        0.16988155, -0.09064031, -0.15273461,  0.21510246, -0.17729256,
        0.00261592,  0.02652721, -0.13491498, -0.17640282, -0.31118405,
       -0.00512062,  0.41723928,  0.13354909, -0.09930452,  0.10033775,
       -0.06307391, -0.02699157,  0.04080637,  0.13098213, -0.08033849,
        0.16044492, -0.04734115,  0.12942326,  0.14534265,  0.0249849 ,
       -0.06554834,  0.13151604, -0.07915305,  0.08410332,  0.07018198,
        0.06627715, -0.11851253, -0.02576792, -0.18880717, -0.00411349,
        0.08233207,  0.04832725, -0.01709246,  0.15401676, -0.15097997,
        0.16647491,  0.01185772, -0.11977788, -0.02823763,  0.08750527,
       -0.10837749, -0.12731393,  0.11664411, -0.22722226,  0.09817819,
        0.16637388, -0.01940754,  0.21179773,  0.06896579,  0.0847318 ,
        0.00796246, -0.01696757, -0.19169487, -0.03898101,  0.0400917 ,
       -0.03423833,  0.08150289,  0.0139573] # Close to L2


init_data = np.array([L1, L2, L3])
init_classes = np.array([1, 2, 3], dtype=np.uint8)

sgd_clf = linear_model.SGDClassifier(loss="modified_huber", max_iter=100, warm_start=True)

sgd_clf.fit(init_data, init_classes)

print("fit")
print(sgd_clf.predict([L1]))
print(sgd_clf.predict([L2]))
print(sgd_clf.predict([L3]))
print(sgd_clf.predict([L4]))
print(sgd_clf.predict([L5]))

## Create buffer of incoming samples: 1000 copies of each initial sample, grouped by class ##
data_buffer = np.repeat(init_data, 1000, axis=0)  # shape (3000, 128)
class_buffer = np.repeat(init_classes, 1000)      # shape (3000,), dtype uint8


## fit buffered data
print(data_buffer)
print(data_buffer.shape)
print(class_buffer)
print(class_buffer.shape)
sgd_clf.partial_fit(data_buffer, class_buffer)

print("partial_fit")
print(sgd_clf.predict([L1]))
print(sgd_clf.predict([L2]))
print(sgd_clf.predict([L3]))
print(sgd_clf.predict([L4]))
print(sgd_clf.predict([L5]))
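If samples really do arrive one at a time, and possibly grouped by class as in the question, one option is to hold them back and call partial_fit() only with balanced, shuffled mini-batches. The sketch below illustrates that idea under stated assumptions and is not part of the answer's code: BalancedBuffer, per_class and the synthetic 2-D clusters are all hypothetical, and the only scikit-learn call it relies on is SGDClassifier.partial_fit(X, y, classes=...).

```python
from collections import defaultdict

import numpy as np
from sklearn.linear_model import SGDClassifier


class BalancedBuffer:
    """Hypothetical helper: collect streamed samples per class and call
    partial_fit() only once every class has `per_class` samples waiting."""

    def __init__(self, clf, classes, per_class, seed=0):
        self.clf = clf
        self.classes = list(classes)
        self.per_class = per_class
        self.rng = np.random.default_rng(seed)
        self.pending = defaultdict(list)  # label -> list of feature vectors
        self.n_updates = 0

    def add(self, x, label):
        self.pending[label].append(np.asarray(x, dtype=float))
        if all(len(self.pending[c]) >= self.per_class for c in self.classes):
            self._flush()

    def _flush(self):
        # Take the same number of samples from every class, then shuffle them
        # together so a single update never sees a long run of one label.
        X, y = [], []
        for c in self.classes:
            X.extend(self.pending[c][: self.per_class])
            y.extend([c] * self.per_class)
            self.pending[c] = self.pending[c][self.per_class:]
        order = self.rng.permutation(len(y))
        X = np.array(X)[order]
        y = np.array(y)[order]
        # classes= is required on the first partial_fit() call; passing the
        # same classes again on later calls is accepted.
        self.clf.partial_fit(X, y, classes=self.classes)
        self.n_updates += 1


# Hypothetical stream: all of class 1, then all of class 2, then all of class 3.
rng = np.random.default_rng(1)
centers = np.array([[5.0, 0.0], [0.0, 5.0], [-5.0, -5.0]])
labels = [1, 2, 3]

clf = SGDClassifier(loss="modified_huber", random_state=0)
buf = BalancedBuffer(clf, classes=labels, per_class=10)

for label, center in zip(labels, centers):
    for _ in range(200):
        buf.add(center + rng.normal(scale=0.1, size=2), label)

print(buf.n_updates)
print(clf.predict(centers))
```

Here the stream arrives class by class, the same ordering as the three loops in the question, yet every partial_fit() call still sees a shuffled batch with 10 samples of each label. The cost is latency: a class that is rare in the stream holds back every update, so per_class trades update frequency against balance.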
