手写数字数据集——MINST的读取及预处理

  • Post author:
  • Post category:其他




MNIST是一个手写数字集合,每张图片分辨率为28×28,像素点数值取值范围0~255。


mnist.pkl文件的内容info是一个元组,包括训练集、验证集、测试集,它







没有




直接使用原始图像,而是将其转换成









向量的形式







它已经把手写识别的图片(28*28)转化成了一个向量(1,784),


向量中的每一维分别代表原始图像中对应像素点的灰度值







然后给出了这个图片的标识0-9。

首先要打开pkl文件,需要用到py包_pickle,这里的open要用rb,因为是要以二进制的方式读取文件。(我这里的pkl文件还是个压缩包,所以使用gzip打开)

import _pickle as cPickle
import gzip

f = gzip.open("MNIST\mnist.pkl.gz",'rb')
training_data, validation_data, test_data = cPickle.load(f, encoding='bytes')

可以到训练集、验证集、测试集分别有50000,10000,10000张

其中训练集分为两部分,第一维存储的是图像对应的50000个1*784的向量,第二维存储的50000个是图像对应的数字标签。

读取训练集的第一张图片看看


数据的预处理

将图像从行向量(1*784)转换成列向量(784*1),并将图像对应的数字也转换成numpy类型的列向量(10*1),数字对应的索引置1,其余位置则为0。

training_inputs = [np.reshape(x, (784, 1)) for x in training_data[0]]
training_results = [vectorized_result(y) for y in training_data[1]]
training_data = list(zip(training_inputs, training_results))

validation_inputs = [np.reshape(x, (784, 1)) for x in validation_data[0]]
validation_data = list(zip(validation_inputs, validation_data[1]))

test_inputs = [np.reshape(x, (784, 1)) for x in test_data[0]]
test_data = list(zip(test_inputs, test_data[1]))
def vectorized_result(j):
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e


完整代码:

import matplotlib.pyplot as plt
import numpy as np
import _pickle as cPickle
import gzip


def vectorized_result(j):
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e

f = gzip.open("MNIST\mnist.pkl.gz",'rb')
training_data, validation_data, test_data = cPickle.load(f, encoding='bytes')

#print(type(training_data[0])) #50000张784*1的图像
#print(training_data[0][0].shape)
#print(type(training_data[1])) #50000个数字标签


training_inputs = [np.reshape(x, (784, 1)) for x in training_data[0]]
training_results = [vectorized_result(y) for y in training_data[1]]
training_data = list(zip(training_inputs, training_results))

validation_inputs = [np.reshape(x, (784, 1)) for x in validation_data[0]]
validation_data = list(zip(validation_inputs, validation_data[1]))

test_inputs = [np.reshape(x, (784, 1)) for x in test_data[0]]
test_data = list(zip(test_inputs, test_data[1]))
#print(len(training_data))
#print(training_data)


img = training_inputs[0]
img = img.reshape(28,-1)
print(type(img))
plt.imshow(img)
plt.show()


使用 plt.imshow() 方法进行画图,它将灰度图像按照灰度值的高低映射成彩色图像。



版权声明:本文为whxhy666原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。