聚类分析(K-means) python代码实现

  • Post author:
  • Post category:python





聚类分析(K-means) python代码实现

​ 今天在课程上再一次系统接触到了聚类分析,当然,以前在使用spss时也解决过此类问题(这还要追溯到18年),但不得不说,软件的便利化使得我对于原理又疏忽了很多。所以,空想不如实干,今天抽出一些时间用python对聚类分析的最基础的一种(K-means)进行了实现,下面,我将就原理、案例及代码进行分析。

​ 首先,提到聚类分析,我们会想到:群体划分、客户分类、基因聚类等应用,它的原理简单来说归纳为一句话“物以类聚,人以群分”,确实,我们在运用此方法时,所需做到的最终结果就是将具有相同特征的归为一类,不同则分组。

​ 那么,作为一种无监督学习方法(无标签、数据驱动),它不需要类别标柱,我们可以直接从数据中探索数据间的结构联系。在解决此类问题时,我们会有基于

分割

和基于

层次

的两类方法,再细化来说,可以有KMeans、Sequential Leader、Density Based Methods等方法。对于K-means方法来说,可以说是聚类分析中最经典的算法,同时运用也最为普遍。

​ 其中,我们在利用K-means解决聚类问题时,需要明确

最重要的两点


  • 相似性度量

    (可分为

    数值

    型、

    二值属性

    和包含

    分类属性



    数值属性

    的混合属性三种情况来进行分析),

    • 对于数值型,可使用距离度量(欧氏距离、曼哈顿距离);
    • 对于二值属性,可利用冗余矩阵来进行不变相似性的测量**(对于不对称的二值变量,如果取值1比0重要,那么这样的二值变量就只有一种状态。)**
    • 对于混合属性,若属性f为二元属性或标称属性,二者相同距离为0,不同距离为1;若属性f为序数型属性,那么d(p,q)=|p-q|/(n-1);若属性f为数值型属性,那么d(p,q)=distance(p,q)

      (注意距离和相似性的倒数关系)

  • 对于

    中心点的选取

    ,我们领域专家知识,这也是此类方法的缺点之一。

    整个算法的详细步骤呢,我从课程ppt中截取了一张图,供大家参考:


    (所有权归老师)


    在这里插入图片描述

以上就是我对于原理的简要阐述,那么回到

案例

​ 对表中二维数据使用k-means算法进行聚类,划分为2个簇,假设初始簇中心选择P7(4,5)和P10(5,5)。

P1 P2 P3 P4 P5 P6 P7 P8 P9 P10
x 3 3 7 4 3 8 4 4 7 5
y 4 6 3 7 8 5 5 1 4 5

​ 针对上述这个题目呢,我的

编程思路

如下所述:

  1. 初始化数据,并进行初始图像的绘制(未分类的散点图);
  2. 进行K-means步骤的编程(其中距离度量采用欧氏距离)
  3. 更新图像,进行两次聚类后结果已趋于不变

​ 下面就是整体的代码实现过程(由于时间紧迫,

没有采用线程进行动态模拟,知识利用循环进行了迭代

):

import numpy as np
import matplotlib.pyplot as plt
import math

# 数据初始化
train = np.array([[3, 4], [3, 6], [7, 3], [4, 7], [3, 8], [8, 5], [4, 5], [4, 1], [7, 4], [5, 5]])

# 初始簇中心选择(4,5),(5,5)
center = np.array([[4, 5], [5, 5]], dtype=float)

# 初始化图像
def initgraph(train):
    plt.scatter(train[:, 0], train[:, 1], color='b')
    plt.scatter(center[:, 0], center[:, 1], color='r')
    plt.grid()
    plt.xlabel("x")
    plt.title("0时刻的初始图像", fontproperties='KaiTi', fontsize=25)
    plt.xticks(np.arange(0, 10, 1))
    plt.ylabel("y")
    plt.yticks(np.arange(0, 10, 1))
    plt.xlim(0, 10)
    plt.ylim(0, 10)
    plt.show()

# 利用欧几里得距离
def kmeans(train, center, title):
    # 存储每次迭代的聚类情况
    list_d1 = []
    list_d2 = []
    for i in range(train.shape[0]):
        # 欧几里得距离
        d1 = math.sqrt(math.pow((train[i, 0] - center[0, 0]), 2) + math.pow((train[i, 1] - center[0, 1]), 2))
        d2 = math.sqrt(math.pow((train[i, 0] - center[1, 0]), 2) + math.pow((train[i, 1] - center[1, 1]), 2))
        # 曼哈顿距离
        # d1 = (abs(train[i, 0] - center[0, 0])) + abs((train[i, 1] - center[0, 1]))
        # d2 = (abs(train[i, 0] - center[1, 0])) + abs((train[i, 1] - center[1, 1]))
        if d1 < d2:
            list_d1.append(i)
        else:
            list_d2.append(i)
    update(list_d1, list_d2, train, title)

# 迭代更新图像
def update(list_d1, list_d2, train,title):
    center1 = np.empty([1, 2])
    center2 = np.empty([1, 2])
    for index in range(len(list_d1)):
        plt.scatter(train[list_d1[index], 0], train[list_d1[index], 1], color='g')
        train1 = np.array([train[list_d1[index], 0], train[list_d1[index], 1]])
        if index == 0:
            center1[index] = train1
        else:
            center1 = np.insert(center1, index, train1, axis=0)
    for index in range(len(list_d2)):
        plt.scatter(train[list_d2[index], 0], train[list_d2[index], 1], color='y')
        train2 = np.array([[train[list_d2[index], 0], train[list_d2[index], 1]]])
        if index == 0:
            center2[index] = train2
        else:
            center2 = np.insert(center2, index, train2, axis=0)
    # 确定新的中心点
    plt.scatter(np.average(center1[:, 0]), np.average(center1[:, 1]), color='r')
    plt.scatter(np.average(center2[:, 0]), np.average(center2[:, 1]), color='r')
    center[0] = np.array([[np.average(center1[:, 0]), np.average(center1[:, 1])]])
    center[1] = np.array([[np.average(center2[:, 0]), np.average(center2[:, 1])]])
    plt.grid()
    plt.xlabel("x")
    plt.title("第"+str(title)+"次聚类结果示意图", fontproperties='KaiTi', fontsize=25)
    title += 1
    plt.xticks(np.arange(0, 10, 1))
    plt.ylabel("y")
    plt.yticks(np.arange(0, 10, 1))
    plt.xlim(0, 10)
    plt.ylim(0, 10)
    plt.show()


if __name__ == '__main__':
    initgraph(train)
    # 进行两次聚类,结果如图
    title = 1
    for i in range(2):
        kmeans(train, center,title)
        title += 1

​ 结果图像见下:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

​ 以上总共进行了两次聚类,图中用不同点来表示不同分类(红点为中心点,绿点为第一类,黄点为第二类),因此问题得以解决!

​ 本篇博文到这里就结束了,聚类分析在我们的日常数据分析的运用中极其广泛,所以希望大家可以多翻阅相应资料进行学习,

文中有错误的地方还希望大家指正

,谢谢大家!

(ps:昨日重装四次MySQL,对于MySQL的使用,远程连接云服务器的要点掌握了很多,近日将写一篇相关文章,大家敬请关注!)



版权声明:本文为qq_42956790原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。