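Introduction
Over the past couple of days I used the KNN method to classify the CIFAR-10 dataset, but the results were disappointing: only about 30% accuracy.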
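Training a KNN classifier only means storing the training set, so training takes almost no time. Testing, however, is expensive, because every test image has to be compared against every stored training image.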
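To make the test-time cost concrete: with the standard CIFAR-10 split (50,000 training and 10,000 test images, each 32×32×3 = 3,072 values), classifying the test set means 500 million image-to-image distances. The sketch below is a common vectorized alternative to looping over test images, not the author's code. It computes the full num_test × num_train Euclidean distance matrix at once using the identity ||a − b||² = ||a||² + ||b||² − 2a·b. The function name `pairwise_euclidean` is hypothetical, and the inputs are assumed to be flattened float arrays.

```python
import numpy as np


def pairwise_euclidean(X_test, X_train):
    """Return a (num_test, num_train) matrix of Euclidean distances.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so no Python loop is needed.
    Both inputs are 2-D float arrays with one flattened image per row.
    """
    test_sq = np.sum(X_test ** 2, axis=1)[:, np.newaxis]    # shape (num_test, 1)
    train_sq = np.sum(X_train ** 2, axis=1)[np.newaxis, :]  # shape (1, num_train)
    cross = X_test @ X_train.T                              # shape (num_test, num_train)
    sq = np.maximum(test_sq + train_sq - 2.0 * cross, 0.0)  # clip tiny negative round-off
    return np.sqrt(sq)


# Toy data: 3 stored "training images" and 2 "test images", 4 pixels each.
X_train = np.array([[0., 0., 0., 0.],
                    [1., 1., 1., 1.],
                    [3., 0., 4., 0.]])
X_test = np.array([[0., 0., 0., 0.],
                   [1., 1., 1., 1.]])
D = pairwise_euclidean(X_test, X_train)
print(D.shape)  # (2, 3): one distance per (test, train) pair
```

On the full dataset, the 10,000 × 50,000 float64 matrix alone takes about 4 GB, so in practice you would probably process the test images in batches.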
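On the MNIST dataset, by contrast, this classifier works well. I think the main reason is that MNIST images are essentially black-and-white (white strokes on a black background). Because KNN fundamentally compares images by their pixel-wise differences, those differences carry relatively more information for MNIST: a pixel that differs usually means a stroke that differs.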
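For binary images, the pixel-difference distance has a simple interpretation: the L1 (Manhattan) distance is exactly the number of pixels where the two images disagree. The two 5×5 "digits" below are made up for illustration. Real MNIST images are 28×28 grayscale with values from 0 to 255, so this holds only approximately there. The helper `l1_distance` is hypothetical.

```python
import numpy as np


def l1_distance(a, b):
    """Manhattan (L1) distance: the sum of absolute pixel differences."""
    # Cast to a signed type first: subtracting uint8 arrays would wrap around.
    return int(np.sum(np.abs(a.astype(np.int64) - b.astype(np.int64))))


# Two toy 5x5 binary "digits" (1 = white stroke, 0 = black background).
one = np.zeros((5, 5), dtype=np.uint8)
one[:, 2] = 1            # a vertical stroke in the middle column

seven = np.zeros((5, 5), dtype=np.uint8)
seven[0, :] = 1          # top bar
seven[1:, 4] = 1         # right-hand vertical stroke

print(l1_distance(one, seven))          # 12
print(np.count_nonzero(one != seven))   # 12: the same number, i.e. the differing pixels
```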
Code
my_utils.py
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : my_utils.py
# @Author: Fly_dragon
# @Date : 2019/11/29
# @Desc :
import numpy as np


def getXmean(x_train):
    x_train = np.reshape(x_train, (x_train.shape[0], -1))  # flatten each image to 1-D
    mean_image = np.mean(x_train, axis=0)  # column-wise mean: the mean of each pixel position over all images
    return mean_image


def centralized(x_test, mean_image):
    x_test = np.reshape(x_test, (x_test.shape[0], -1))  # flatten each image to 1-D
    x_test = x_test.astype(np.float64)  # np.float was removed in NumPy 1.24
    x_test -= mean_image  # subtract the mean image to get zero-mean data
    return x_test


#%% KNN class
class Knn:
    def __init__(self):
        pass

    def fit(self, X_train, y_train):
        # KNN "training" only stores the training data
        self.Xtr = X_train
        self.ytr = y_train

    def predict(self, k, dis, X_test):
        """
        k: number of nearest neighbours; dis: distance metric, 'E' (Euclidean) or 'M'
        """
        assert dis == 'E' or dis == 'M'
        num_test = X_test.shape[0]
        label_list = []
        # use the Euclidean distance as the distance measure
        if dis == 'E':
            for i in range(num_test):
                distances = np.sqrt(np.sum(((self.Xtr - np.tile(X_test[i],
                                                                (self.Xtr.shape
```
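The original listing stops inside the Euclidean branch of `predict`, so the Manhattan branch and the voting step are not shown. The following is a minimal self-contained sketch of a complete `predict`. It rests on two assumptions: `'M'` means Manhattan (L1) distance, and the predicted label is the majority vote among the `k` nearest training samples. The class name `KnnSketch` is hypothetical, and broadcasting replaces `np.tile`.

```python
import numpy as np


class KnnSketch:
    """Minimal k-nearest-neighbour classifier (hypothetical sketch)."""

    def fit(self, X_train, y_train):
        # "Training" only memorises the data, flattened to one row per sample.
        X_train = np.asarray(X_train, dtype=np.float64)
        self.Xtr = X_train.reshape(X_train.shape[0], -1)
        self.ytr = np.asarray(y_train, dtype=np.int64)

    def predict(self, k, dis, X_test):
        """Return one predicted label per row of X_test.

        k   -- number of nearest neighbours that vote
        dis -- 'E' for Euclidean (L2), 'M' for Manhattan (L1) distance (assumed)
        """
        if dis not in ('E', 'M'):
            raise ValueError("dis must be 'E' or 'M'")
        X_test = np.asarray(X_test, dtype=np.float64)
        X_test = X_test.reshape(X_test.shape[0], -1)
        label_list = []
        for x in X_test:
            diff = self.Xtr - x                  # broadcasting instead of np.tile
            if dis == 'E':
                distances = np.sqrt(np.sum(diff ** 2, axis=1))
            else:
                distances = np.sum(np.abs(diff), axis=1)
            nearest = np.argsort(distances)[:k]  # indices of the k closest samples
            votes = np.bincount(self.ytr[nearest])  # label counts among the neighbours
            label_list.append(int(np.argmax(votes)))  # ties go to the smallest label
        return np.array(label_list)


# Usage on toy 2-D data: two well-separated clusters labelled 0 and 1.
X_train = np.array([[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]])
y_train = np.array([0, 0, 0, 1, 1, 1])
X_test = np.array([[0.2, 0.3], [5.5, 5.2]])
y_test = np.array([0, 1])

knn = KnnSketch()
knn.fit(X_train, y_train)
pred_E = knn.predict(3, 'E', X_test)
pred_M = knn.predict(3, 'M', X_test)
accuracy = np.mean(pred_E == y_test)
print(pred_E, pred_M, accuracy)  # [0 1] [0 1] 1.0
```

For CIFAR-10, you would presumably pass in the flattened, mean-centred arrays produced by `getXmean` and `centralized`, and the accuracy figure is `np.mean(pred == y_test)` as above. Because `np.argmax` returns the first maximum, ties in the vote resolve to the smallest label. This is one simple tie-breaking rule among several.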
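Copyright notice: This is an original article by qq_44761480, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.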