前面说了tensorflow的基础知识,本篇利用前面提到的知识来进行鸢尾花数据集实战
话不多说,开整!
首先导入相关库
# 导入相关库
import tensorflow as tf
# 调用sklearn中自带iris数据集
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import numpy as np
# 用于切分数据集
from sklearn.model_selection import train_test_split
导入库之后将鸢尾花数据集获取出来
# 导入数据集,分为x_data和y_data
x_data = load_iris().data
y_data = load_iris().target
查看鸢尾花信息
print(len(x_data))
print(len(y_data))
print(x_data[:10])
print(y_data[:10])
可以看到鸢尾花数据集有150条数据,每条数据有4个维度,分类标签有0,1,2(因为0-50标签为0,51-100标签为1,101-150标签为2),这里只看了前10条,所以y值只有0
之后对数据集进行切分
# 生成训练集和测试集,训练集120个数据,测试集30个数据
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.2)
此时训练集和测试集已经切分完毕,
我们需要进行数据类型转换,把数据集中每条数据变成tensorflow中支持的类型
# 转换x的数据类型,否则后面矩阵相乘会因为数据类型不一致报错
x_train = tf.cast(x_train,tf.float32)
x_test = tf.cast(x_test,tf.float32)
因为数据集有点大,所以我们切分成batch
# 对于一个有 2000 个训练样本的数据集。将 2000 个样本分成大小为 500 的 batch,那么完成一个 epoch 需要 4 个 iteration。
# 如果把准备训练数据比喻成一块准备打火锅的牛肉,那么epoch就是整块牛肉,batch就是切片后的牛肉片,iteration就是涮一块牛肉片
train_db = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(32)
鸢尾花数据集中有四个特征,三种鸢尾花,所以构建一个神经网络(图不好看,汗)
输入层为feature1,feature2,feature3,feature4
输出层为标签0,1,2
# 生成神经网络的参数,4个输入特征,所以输入层有4个输入节点,因为三分类,所以输出层为3个神经元
w1 = tf.Variable(tf.random.truncated_normal([4,3],stddev=0.1,seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3],stddev=0.1,seed=1))
之后初始化一些参数,便可以正式训练
lr = 0.1 #学习率0.1
train_loss_results = [] #将每轮的loss记录在此列表中,为后续画图提供数据
test_acc = [] #将每轮的acc记录在列表中,为后续画acc曲线提供数据
epoch=500 #循环500轮
loss_all = 0 #每轮分为4个step,loss_all记录四个step生成的4个loss的和
完整代码:
# 导入相关库
import tensorflow as tf
# 调用sklearn中自带iris数据集
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import numpy as np
# 用于切分数据集
from sklearn.model_selection import train_test_split
# 导入数据集,分为x_data和y_data
x_data = load_iris().data
y_data = load_iris().target
# 生成训练集和测试集,训练集120个数据,测试集30个数据
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.2)
# 转换x的数据类型,否则后面矩阵相乘会因为数据类型不一致报错
x_train = tf.cast(x_train,tf.float32)
x_test = tf.cast(x_test,tf.float32)
# 对于一个有 2000 个训练样本的数据集。将 2000 个样本分成大小为 500 的 batch,那么完成一个 epoch 需要 4 个 iteration。
# 如果把准备训练数据比喻成一块准备打火锅的牛肉,那么epoch就是整块牛肉,batch就是切片后的牛肉片,iteration就是涮一块牛肉片
train_db = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(32)
# 生成神经网络的参数,4个输入特征,所以输入层有4个输入节点,因为三分类,所以输出层为3个神经元
w1 = tf.Variable(tf.random.truncated_normal([4,3],stddev=0.1,seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3],stddev=0.1,seed=1))
lr = 0.1 #学习率0.1
train_loss_results = [] #将每轮的loss记录在此列表中,为后续画图提供数据
test_acc = [] #将每轮的acc记录在列表中,为后续画acc曲线提供数据
epoch=500 #循环500轮
loss_all = 0 #每轮分为4个step,loss_all记录四个step生成的4个loss的和
for epoch in range(epoch): #数据集级别循环,每个epoch循环一次数据集 等同于for i in range(500)
for step,(x_train,y_train) in enumerate(train_db): #batch级别的循环
# tf.GradientTape with结构记录计算过程,gradient求出张量的梯度
with tf.GradientTape() as tape:
# 矩阵相乘y=w*x+b
y = tf.matmul(x_train,w1)+b1
# 当一个样本经过softmax层并输出一个向量,会取这个向量中值最大的那个数的index作为这个样本的预测标签
y = tf.nn.softmax(y)
y_ = tf.one_hot(y_train,depth=3)
# 计算loss的均值
loss = tf.reduce_mean(tf.square(y_-y))
loss_all +=loss.numpy()
grads=tape.gradient(loss,[w1,b1])
#参数w自更新
w1.assign_sub(lr*grads[0])
# 参数b自更新
b1.assign_sub(lr*grads[1])
print(f'epoch {epoch},loss:{loss_all/4}')
train_loss_results.append(loss_all/4)
loss_all=0
# 测试部分
total_correct,total_number=0,0
for x_test,y_test in test_db:
y=tf.matmul(x_test,w1)+b1
y=tf.nn.softmax(y)
pred=tf.argmax(y,axis=1)
pred=tf.cast(pred,dtype=y_test.dtype)
correct = tf.cast(tf.equal(pred,y_test),dtype=tf.int32)
correct=tf.reduce_sum(correct)
total_correct += int(correct)
total_number += x_test.shape[0]
acc = total_correct/total_number
test_acc.append(acc)
print('Test acc:',acc)
print('-'*20)
plt.title('loss function curve')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.plot(train_loss_results,label='$loss$')
plt.legend()
plt.show()
plt.title('acc curve')
plt.xlabel('epoch')
plt.ylabel('acc')
plt.plot(test_acc,label='$accuracy$')
plt.legend()
plt.show()
版权声明:本文为bester_man原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。