scikit-learn中用于聚类的超参数评估的网格搜索 [英] Grid search for hyperparameter evaluation of clustering in scikit-learn

查看:419
本文介绍了scikit-learn中用于聚类的超参数评估的网格搜索的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在对大约100条记录(未标记)的样本进行聚类,并尝试使用grid_search评估具有各种超参数的聚类算法。我正在使用 silhouette_score 得分,效果很好。

I'm clustering a sample of about 100 records (unlabelled) and trying to use grid_search to evaluate the clustering algorithm with various hyperparameters. I'm scoring using silhouette_score which works fine.

我的问题是我不需要使用 GridSearchCV / RandomizedSearchCV 的交叉验证方面,但是我找不到简单的 GridSearch / RandomizedSearch 。我可以编写自己的对象,但是 ParameterSampler ParameterGrid 对象非常有用。

My problem here is that I don't need to use the cross-validation aspect of the GridSearchCV/RandomizedSearchCV, but I can't find a simple GridSearch/RandomizedSearch. I can write my own but the ParameterSampler and ParameterGrid objects are very useful.

下一步是继承 BaseSearchCV 并实现我自己的 _fit()方法,但是认为值得一问的是,是否有更简单的方法来做到这一点,例如通过将某些内容传递给 cv 参数?

My next step will be to subclass BaseSearchCV and implement my own _fit() method, but thought it was worth asking is there a simpler way to do this, for example by passing something to the cv parameter?

def silhouette_score(estimator, X):
    clusters = estimator.fit_predict(X)
    score = metrics.silhouette_score(distance_matrix, clusters, metric='precomputed')
    return score

ca = KMeans()
param_grid = {"n_clusters": range(2, 11)}

# run randomized search
search = GridSearchCV(
    ca,
    param_distributions=param_dist,
    n_iter=n_iter_search,
    scoring=silhouette_score,
    cv= # can I pass something here to only use a single fold?
    )
search.fit(distance_matrix)


解决方案

clusteval 库将帮助您评估数据并找到最佳的簇数。该库包含五种可用于评估聚类的方法。 剪影 dbindex 衍生物,* dbscan *和 hdbscan

The clusteval library will help you to evaluate the data and find the optimal number of clusters. This library contains five methods that can be used to evaluate clusterings; silhouette, dbindex, derivative, *dbscan *and hdbscan.

pip install clusteval

取决于数据,可以选择评估方法。

Depending on your data, the evaluation method can be chosen.

# Import library
from clusteval import clusteval

# Set parameters, as an example dbscan
ce = clusteval(method='dbscan')

# Fit to find optimal number of clusters using dbscan
results= ce.fit(X)

# Make plot of the cluster evaluation
ce.plot()

# Make scatter plot. Note that the first two coordinates are used for plotting.
ce.scatter(X)

# results is a dict with various output statistics. One of them are the labels.
cluster_labels = results['labx']

这篇关于scikit-learn中用于聚类的超参数评估的网格搜索的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆