深度学习实战——利用Pytorch实现iris数据集的分类

  • Post author:
  • Post category:其他


# iris_multi-classfication.py
"""
Pytorch in action
Pytorch实现iris数据集的分类
"""

from tkinter import HIDDEN
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.datasets import load_iris
from torch.autograd import Variable
from torch.optim import SGD

# 动态的判断GPU是否可用,方便在不同类型的处理器上迁移
use_cuda = torch.cuda.is_available()
print("use_cuda:",use_cuda)

# 加载数据集
iris = load_iris()
print(iris.keys())

# 数据预处理,包括从数据集里区分输入输出,最后把输入输出数据封装成Pytorch期望的Variable格式
x = iris['data']        # 特征信息
y = iris['target']      # 目标分类
print(x.shape)          # (150,4)
print(y.shape)          # (150,)

print(y)
x = torch.FloatTensor(x)
y = torch.LongTensor(y)
x,y = Variable(x),Variable(y)

# Pytorch中自定义的模型都需要继承Module,并重写forward方法完成前向计算过程
class Net(torch.nn.Module):
    # 初始化函数,接受自定义输入特征维数,隐含层特征维数,输出层特征维数
    def __init__(self,n_feature,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden = torch.nn.Linear(n_feature,n_hidden)   # 一个线性隐藏层
        self.predict = torch.nn.Linear(n_hidden,n_output)   # 一个线性输出层
    # 前向传播过程
    def forward(self,x):
        x = F.sigmoid(self.hidden(x))
        x = self.predict(x)
        out = F.log_softmax(x,dim=1)
        return out

# 网络实例化并打印查看网络结构
# iris中输入特征4维,隐藏层和输出层可以自己选择
net = Net(n_feature=4,n_hidden=5,n_output=4)
print(net)

# 如果GPU可用,就将训练数据和模型都放到GPU上
# 调用cuda()函数就可以将相应模块放到GPU上
if use_cuda:
    x = x.cuda()
    y = y.cuda()
    net = net.cuda()

# 定义神经网络训练的优化器,并且设置学习率为0.5
optimizer = SGD(net.parameters(),lr=0.5)

# 训练过程
px,py = [],[]       # 记录要绘制的数据
for i in range(1000):
    # 数据集传入网络前向计算
    prediction = net(x)
    # 计算Loss
    loss = F.nll_loss(prediction,y)
    # 清除网络状态
    optimizer.zero_grad()
    # loss反向传播
    loss.backward()
    # 更新参数
    optimizer.step()

    # 打印并记录当前的index和loss
    print(i,"loss:",loss.item())
    px.append(i)
    py.append(loss.item())
    
    # 每十次迭代绘制训练动态
    if i % 10 == 0:
        plt.cla()
        plt.plot(px,py,'r-',lw=1)
        plt.text(0,0,'Loss=%.4f'%loss.item(),fontdict={'size':20,'color':'red'})
        plt.pause(0.1)



版权声明:本文为ShirasawaHao原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。