kmeans python_使用python+sklearn实现Kmeans聚类

  • Post author:
  • Post category:python


注意:单击此处https://urlify.cn/zUfeai下载完整的示例代码,或通过Binder在浏览器中运行此示例

以下图首先显示了使用K-means算法产生三个聚类的结果,然后显示了不好的初始化对分类过程的影响:通过将n_init设置为1(默认值为10),可以减少使用不同质心种子运行算法的次数。下一张图显示了使用八个聚类可以传递的内容,最后显示了真实类的情况。

883ebe19c91c901729bc96c478411e8e.png
acb66f68eee3e9d4815406a501ba1333.png
b70445ebb2116e66383c270a5aca131f.png
e0cc7e868766ae63cf505ca8c8b04509.png



print(__doc__)






# 源代码: Gaël Varoquaux


# 由Jaques Grobler修改过文档


# 许可证: BSD 3 clause




import numpy as np


import matplotlib.pyplot as plt


# 尽管不直接使用以下导入的库,但它对于


# 3D投影来说是必需的


from mpl_toolkits.mplot3d import Axes3D




from sklearn.cluster import KMeans


from sklearn import datasets




np.random.seed(5)




iris = datasets.load_iris()


X = iris.data


y = iris.target




estimators = [('k_means_iris_8', KMeans(n_clusters=8)),


('k_means_iris_3', KMeans(n_clusters=3)),


('k_means_iris_bad_init', KMeans(n_clusters=3, n_init=1,


init='random'))]




fignum = 1


titles = ['8 clusters', '3 clusters', '3 clusters, bad initialization']


for name, est in estimators:


fig = plt.figure(fignum, figsize=(4, 3))


ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)


est.fit(X)


labels = est.labels_




ax.scatter(X[:, 3], X[:, 0], X[:, 2],


c=labels.astype(np.float), edgecolor='k')




ax.w_xaxis.set_ticklabels([])


ax.w_yaxis.set_ticklabels([])


ax.w_zaxis.set_ticklabels([])


ax.set_xlabel('Petal width')


ax.set_ylabel('Sepal length')


ax.set_zlabel('Petal length')


ax.set_title(titles[fignum - 1])


ax.dist = 12


fignum = fignum + 1




# 绘制真实类


fig = plt.figure(fignum, figsize=(4, 3))


ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)




for name, label in [('Setosa', 0),


('Versicolour', 1),


('Virginica', 2)]:


ax.text3D(X[y == label, 3].mean(),


X[y == label, 0].mean(),


X[y == label, 2].mean() + 2, name,


horizontalalignment='center',


bbox=dict(alpha=.2, edgecolor='w', facecolor='w'))


# 重新排列标签以使其颜色与聚类结果匹配


y = np.choose(y, [1, 2, 0]).astype(np.float)


ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=y, edgecolor='k')




ax.w_xaxis.set_ticklabels([])


ax.w_yaxis.set_ticklabels([])


ax.w_zaxis.set_ticklabels([])


ax.set_xlabel('Petal width')


ax.set_ylabel('Sepal length')


ax.set_zlabel('Petal length')


ax.set_title('Ground Truth')


ax.dist = 12




fig.show()



脚本的总运行时间:

(0分钟0.629秒)


估计的内存使用量:

8 MB

ae1e233286d28f2eb1d7cd924a2aae27.png


下载Python源代码: plot_cluster_iris.py


下载Jupyter notebook源代码: plot_cluster_iris.ipynb

由Sphinx-Gallery生成的画廊

文壹由“伴编辑器”提供技术支持


☆☆☆为方便大家查阅,小编已将scikit-learn学习路线专栏


文章统一整理到公众号底部菜单栏,同步更新中,关注公众号,点击左下方“系列文章”,如图:


fc5531d6a41fd4347e579bb919ccffc1.png

欢迎大家和我一起沿着scikit-learn文档这条路线,一起巩固机器学习算法基础。(添加微信:

mthler




备注:sklearn学习,一起进【sklearn机器学习进步群】开启打怪升级的学习之旅。)

7fde6e2d3cc29fd938d35bb5735daa94.png