利用pytorch训练网络—垃圾分类,(resnet18)

  • Post author:
  • Post category:其他


数据集包含6种垃圾,分别为cardboard(纸箱),glass(玻璃)、metal(金属)、paper(纸)、plastic(塑料)、其他废品(trash),数据数量较小,仅供学习。

数据集标准备工作,包括将数据集分为训练集和测试集,制作标签文件。代码utils.py

import os
import shutil
import json
path="e://dataset//Garbage_classification"#此路径为上图中六类的目录,可根据自己数据集路径修改
classes=[garbage for garbage in os.listdir(path)]

if os.path.exists(os.path.join(os.getcwd(),'train'))==False:
    os.makedirs(os.path.join(os.getcwd(),'train'))
if os.path.exists(os.path.join(os.getcwd(),'val'))==False:
    os.makedirs(os.path.join(os.getcwd(),'val'))
f = open("garbage_train.json", 'w')
g = open("garbage_val.json", 'w')
for garbage in classes:
    s = 0
    for imgname in os.listdir(os.path.join(path,garbage)):



版权声明:本文为YANGJINCHAHAHA原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。