聚类分析(K-means) python代码实现
聚类分析(K-means) python代码实现
今天在课程上再一次系统接触到了聚类分析,当然,以前在使用spss时也解决过此类问题(这还要追溯到18年),但不得不说,软件的便利化使得我对于原理又疏忽了很多。所以,空想不如实干,今天抽出一些时间用python对聚类分析的最基础的一种(K-means)进行了实现,下面,我将就原理、案例及代码进行分析。
首先,提到聚类分析,我们会想到:群体划分、客户分类、基因聚类等应用,它的原理简单来说归纳为一句话“物以类聚,人以群分”,确实,我们在运用此方法时,所需做到的最终结果就是将具有相同特征的归为一类,不同则分组。
那么,作为一种无监督学习方法(无标签、数据驱动),它不需要类别标柱,我们可以直接从数据中探索数据间的结构联系。在解决此类问题时,我们会有基于
分割
和基于
层次
的两类方法,再细化来说,可以有KMeans、Sequential Leader、Density Based Methods等方法。对于K-means方法来说,可以说是聚类分析中最经典的算法,同时运用也最为普遍。
其中,我们在利用K-means解决聚类问题时,需要明确
最重要的两点
:
-
相似性度量
(可分为
数值
型、
二值属性
和包含
分类属性
和
数值属性
的混合属性三种情况来进行分析),- 对于数值型,可使用距离度量(欧氏距离、曼哈顿距离);
- 对于二值属性,可利用冗余矩阵来进行不变相似性的测量**(对于不对称的二值变量,如果取值1比0重要,那么这样的二值变量就只有一种状态。)**
-
对于混合属性,若属性f为二元属性或标称属性,二者相同距离为0,不同距离为1;若属性f为序数型属性,那么d(p,q)=|p-q|/(n-1);若属性f为数值型属性,那么d(p,q)=distance(p,q)
(注意距离和相似性的倒数关系)
;
-
对于
中心点的选取
,我们领域专家知识,这也是此类方法的缺点之一。整个算法的详细步骤呢,我从课程ppt中截取了一张图,供大家参考:
(所有权归老师)
以上就是我对于原理的简要阐述,那么回到
案例
:
对表中二维数据使用k-means算法进行聚类,划分为2个簇,假设初始簇中心选择P7(4,5)和P10(5,5)。
P1 | P2 | P3 | P4 | P5 | P6 | P7 | P8 | P9 | P10 | |
---|---|---|---|---|---|---|---|---|---|---|
x | 3 | 3 | 7 | 4 | 3 | 8 | 4 | 4 | 7 | 5 |
y | 4 | 6 | 3 | 7 | 8 | 5 | 5 | 1 | 4 | 5 |
针对上述这个题目呢,我的
编程思路
如下所述:
- 初始化数据,并进行初始图像的绘制(未分类的散点图);
- 进行K-means步骤的编程(其中距离度量采用欧氏距离)
- 更新图像,进行两次聚类后结果已趋于不变
下面就是整体的代码实现过程(由于时间紧迫,
没有采用线程进行动态模拟,知识利用循环进行了迭代
):
import numpy as np
import matplotlib.pyplot as plt
import math
# 数据初始化
train = np.array([[3, 4], [3, 6], [7, 3], [4, 7], [3, 8], [8, 5], [4, 5], [4, 1], [7, 4], [5, 5]])
# 初始簇中心选择(4,5),(5,5)
center = np.array([[4, 5], [5, 5]], dtype=float)
# 初始化图像
def initgraph(train):
plt.scatter(train[:, 0], train[:, 1], color='b')
plt.scatter(center[:, 0], center[:, 1], color='r')
plt.grid()
plt.xlabel("x")
plt.title("0时刻的初始图像", fontproperties='KaiTi', fontsize=25)
plt.xticks(np.arange(0, 10, 1))
plt.ylabel("y")
plt.yticks(np.arange(0, 10, 1))
plt.xlim(0, 10)
plt.ylim(0, 10)
plt.show()
# 利用欧几里得距离
def kmeans(train, center, title):
# 存储每次迭代的聚类情况
list_d1 = []
list_d2 = []
for i in range(train.shape[0]):
# 欧几里得距离
d1 = math.sqrt(math.pow((train[i, 0] - center[0, 0]), 2) + math.pow((train[i, 1] - center[0, 1]), 2))
d2 = math.sqrt(math.pow((train[i, 0] - center[1, 0]), 2) + math.pow((train[i, 1] - center[1, 1]), 2))
# 曼哈顿距离
# d1 = (abs(train[i, 0] - center[0, 0])) + abs((train[i, 1] - center[0, 1]))
# d2 = (abs(train[i, 0] - center[1, 0])) + abs((train[i, 1] - center[1, 1]))
if d1 < d2:
list_d1.append(i)
else:
list_d2.append(i)
update(list_d1, list_d2, train, title)
# 迭代更新图像
def update(list_d1, list_d2, train,title):
center1 = np.empty([1, 2])
center2 = np.empty([1, 2])
for index in range(len(list_d1)):
plt.scatter(train[list_d1[index], 0], train[list_d1[index], 1], color='g')
train1 = np.array([train[list_d1[index], 0], train[list_d1[index], 1]])
if index == 0:
center1[index] = train1
else:
center1 = np.insert(center1, index, train1, axis=0)
for index in range(len(list_d2)):
plt.scatter(train[list_d2[index], 0], train[list_d2[index], 1], color='y')
train2 = np.array([[train[list_d2[index], 0], train[list_d2[index], 1]]])
if index == 0:
center2[index] = train2
else:
center2 = np.insert(center2, index, train2, axis=0)
# 确定新的中心点
plt.scatter(np.average(center1[:, 0]), np.average(center1[:, 1]), color='r')
plt.scatter(np.average(center2[:, 0]), np.average(center2[:, 1]), color='r')
center[0] = np.array([[np.average(center1[:, 0]), np.average(center1[:, 1])]])
center[1] = np.array([[np.average(center2[:, 0]), np.average(center2[:, 1])]])
plt.grid()
plt.xlabel("x")
plt.title("第"+str(title)+"次聚类结果示意图", fontproperties='KaiTi', fontsize=25)
title += 1
plt.xticks(np.arange(0, 10, 1))
plt.ylabel("y")
plt.yticks(np.arange(0, 10, 1))
plt.xlim(0, 10)
plt.ylim(0, 10)
plt.show()
if __name__ == '__main__':
initgraph(train)
# 进行两次聚类,结果如图
title = 1
for i in range(2):
kmeans(train, center,title)
title += 1
结果图像见下:
以上总共进行了两次聚类,图中用不同点来表示不同分类(红点为中心点,绿点为第一类,黄点为第二类),因此问题得以解决!
本篇博文到这里就结束了,聚类分析在我们的日常数据分析的运用中极其广泛,所以希望大家可以多翻阅相应资料进行学习,
文中有错误的地方还希望大家指正
,谢谢大家!
(ps:昨日重装四次MySQL,对于MySQL的使用,远程连接云服务器的要点掌握了很多,近日将写一篇相关文章,大家敬请关注!)