用Python实现KNN算法对鸢尾花数据集进行分类

  • Post author:
  • Post category:python


使用KNN算法对鸢尾花数据集进行分类

import pandas as pd
import numpy as np


# Compute the distance from one test sample to every training sample; returns a list.
def distance(h, train=None, test=None):
    """Return Euclidean distances from test row ``h`` to every training row.

    Features are columns 0-3 (column 4 holds the class label elsewhere in
    this script), so all four feature columns take part in the distance.

    Args:
        h: row index of the test sample in ``test``.
        train: 2-D array of training rows; defaults to the module-level
            ``train__value`` when omitted.
        test: 2-D array of test rows; defaults to the module-level
            ``test__value`` when omitted.

    Returns:
        A list with one float per training row, in training-row order, so
        position ``i`` lines up with the label of training row ``i``.
    """
    if train is None:
        train = train__value
    if test is None:
        test = test__value
    # Iterate over the TRAINING rows (the original looped over the test-set
    # length, which truncated or overran the training set).
    features = np.asarray(train[:, 0:4], dtype=float)
    query = np.asarray(test[h, 0:4], dtype=float)
    # Broadcasting subtracts the single query row from every training row.
    return np.sqrt(np.sum((features - query) ** 2, axis=1)).tolist()


# Attach the matching training label to each distance.
def creat_biaoqian(list=None, labels=None):
    """Pair each distance with the label of the training row at the same index.

    Args:
        list: sequence of distances, one per training row (as returned by
            ``distance``). The name shadows the builtin but is kept for
            backward compatibility with keyword callers.
        labels: indexable sequence of labels; defaults to column 4 of the
            module-level ``train__value``.

    Returns:
        A new list of ``(distance, label)`` tuples in input order.

    Raises:
        IndexError: if there are more distances than labels.
    """
    # None instead of a mutable [] default, so no object is shared across calls.
    group = [] if list is None else list
    if labels is None:
        labels = train__value[:, 4]
    return [(dist, labels[i]) for i, dist in enumerate(group)]


# Decide which label dominates among the nearest neighbours, i.e. the predicted class.
def max__lable(a=None):
    """Return the majority label among ``(distance, label)`` pairs.

    Only labels equal to 1, 2 or 3 are counted; any other label is ignored.

    Args:
        a: non-empty sequence of pairs whose element at index 1 is the label
            (for example the K nearest entries of ``creat_biaoqian``'s output).

    Returns:
        The winning label as a string key: "1", "2" or "3". On a tie, the
        earliest key in the order "1", "2", "3" wins.

    Raises:
        ValueError: if ``a`` is None or empty.
    """
    if a is None or len(a) == 0:
        raise ValueError("max__lable() requires at least one (distance, label) pair")
    lables = {"1": 0, "2": 0, "3": 0}
    for item in a:
        for key in lables:
            if item[1] == int(key):
                lables[key] += 1
    # max() returns the first maximal key in insertion order, matching the
    # tie-breaking of a stable descending sort over the same dict.
    return max(lables, key=lables.get)


test__data = pd.read_csv("iristest.csv", header=None)
train__data = pd.read_csv("iristrain.csv", header=None)
# Raw arrays used by the helper functions above.
test__value = test__data.values
# Must come from the TRAINING file; reading test__data here made every test
# sample vote against the test set itself (including its own row).
train__value = train__data.values
lable_test = test__value[:, 4]
la = lable_test.tolist()
# Number of nearest neighbours that vote on each test sample's class.
K = 7
count__True = 0
for i in range(len(test__value)):
    # Pair each training row's distance with its label. Sorting the tuples
    # orders them by distance first (ties broken by label), so the first K
    # entries are the K nearest neighbours.
    neighbours = creat_biaoqian(distance(i))
    neighbours.sort()
    prediction = float(max__lable(neighbours[:K]))
    if prediction == la[i]:
        count__True += 1
    else:
        print("第{}数据测出来是{},测试集的标签是{}".format(i + 1, prediction, la[i]))
Ture__sb = count__True / len(test__value)
print("当k值取{}时精度是:	{:.2%}".format(K, Ture__sb))



版权声明:本文为qq_45691937原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。