目录
1
MNIST数据集
MNIST数据集
是一组由美国高中生和人口调查员手写的70 000个数字的图片,每张图像都用其代表的数字标记。
获取
MNIST数据集:
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
mnist["data"],mnist["target"]
运行结果如下:
(array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]),
array(['5', '0', '4', ..., '4', '5', '6'], dtype=object))
Scikit-Learm加载的数据集通常具有类似的字典结构,包括:
-
DESCR键
:描述数据集 -
data键
:包含一个数组,每个实例为一行,每个特征为一列 -
target键
:包含一个带有标记的数组
我们来看看这些数组:
X, y = mnist["data"], mnist["target"]
X.shape, y.shape
运行结果如下:
((70000, 784), (70000,))
共有7万张图片,每张图片有784个特征。图片是28×28像素,每个特征代表一个像素点的强度。
我们随便抓取一个实例的特征向量:
import matplotlib as mpl
import matplotlib.pyplot as plt
# 抓取一个实例的特征向量
# 提取行数据
#方法一:
import numpy as np
some_digit = np.array(X[36000,])
#方法二:
# some_digit = X[36000]
some_digit_image = some_digit.reshape(28, 28)
#可视化
plt.imshow(some_digit_image, cmap = mpl.cm.binary,
interpolation="nearest")
plt.axis("off")
plt.show()
#查看手写字标签
y[36000]
运行结果如下:
'9'
此处运行结果与书上的“5”不一样是由于fetch_openml()返回给我们的是未排序的数据。
有关
plt.imshow()
用法如下:(参考博客:
plt.imshow()_小小程序猩的博客-CSDN博客_plt.imshow
)
plt.imshow(
X,
cmap=None,
norm=None,
aspect=None,
interpolation=None,
alpha=None,
vmin=None,
vmax=None,
origin=None,
extent=None,
shape=None,
filternorm=1,
filterrad=4.0,
imlim=None,
resample=None,
url=None,
*,
data=None,
**kwargs,
)
**X:**
图像数据。支持的数组形状是:
(M,N) :带有标量数据的图像。数据可视化使用色彩图。
(M,N,3) :具有RGB值的图像(float或uint8)。
(M,N,4) :具有RGBA值的图像(float或uint8),即包括透明度。
前两个维度(M,N)定义了行和列图片,即图片的高和宽;
RGB(A)值应该在浮点数[0, ..., 1]的范围内,或者
整数[0, ... ,255]。超出范围的值将被剪切为这些界限。
**cmap:**
将标量数据映射到色彩图
颜色默认为:rc:image.cmap。
**norm :**
~matplotlib.colors.Normalize
如果使用scalar data ,则Normalize会对其进行缩放[0,1]的数据值内。
默认情况下,数据范围使用线性缩放映射到颜色条范围。 RGB(A)数据忽略该参数。
**aspect:**
{'equal','auto'}或float,可选
控制轴的纵横比。该参数可能使图像失真,即像素不是方形的。
equal:确保宽高比为1,像素将为正方形。(除非像素大小明确地在数据中变为非正方形,坐标使用 extent )。
auto: 更改图像宽高比以匹配轴的宽高比。通常,这将导致非方形像素。
**interpolation:**
str
使用的插值方法
支持的值有:'none', 'nearest', 'bilinear', 'bicubic','spline16', 'spline36', 'hanning', 'hamming', 'hermite', 'kaiser',
'quadric', 'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc','lanczos'.
如果interpolation = 'none',则不执行插值
**alpha:**
alpha值,介于0(透明)和1(不透明)之间。RGBA输入数据忽略此参数。
**vmin, vmax : scalar,**
如果使用* norm 参数,则忽略 vmin , vmax *。
vmin,vmax与norm结合使用以标准化亮度数据。
**origin : {'upper', 'lower'}**
将数组的[0,0]索引放在轴的左上角或左下角。
'upper'通常用于矩阵和图像。
请注意,垂直轴向上指向“下”但向下指向“上”。
**extent:(left, right, bottom, top)**
数据坐标中左下角和右上角的位置。 如果为“无”,则定位图像使得像素中心落在基于零的(行,列)索引上。
这里我们不需要创建测试集,因为
MNIST数据集
已经分成训练集(前6万张图像)和测试集(后1万张图像)了:
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
同样,我们需要先将训练数据洗牌,这样能保证交叉验证时所有的折叠都差不多(避免机器学习算法对训练实例的顺序敏感)。
import numpy as np
# 训练集随机重新排列
shuffle_index = np.random.permutation(60000) #生成一个随机排列的数组
X_train,y_train = X_train[shuffle_index], y_train[shuffle_index]
2 训练一个二元分类器
现在,我们先简化问题,只尝试识别一个数字——比如数字9。那么这个“数字9检测器” 就是一个二元分类器的例子,它只能区分两个类别:9和非9。先为此分类任务创建目标向量:
# 先尝试识别一个数字
y_train_9 = (y_train == '9')
y_test_9 = (y_test == '9')
接着挑选一个分类器并开始训练。一个好的初始选择是随机梯度下降(SGD(stochastic gradient descend))分类器, 使用 Scikit-Learn的SGDClassifier类即