from glob import glob
path = glob('./data/potato_data/*/*') # 所有的图片路径
label = [i.split('\\')[1] for i in path] # 所有图片对应的标签
label_dict = {'Early_blight':0, 'Late_blight':1, 'healthy':2}
label = [label_dict[i] for i in label ] # 把字符标签转成数值
# 通过构建DataSET的方式读取数据
train = tf.data.Dataset.from_tensor_slices( (path, label) )
for i,j in train:
print(i,j)
break
from tensorflow.io import read_file
def process_image(fpath, label):
img = read_file(fpath)#编码后的数据
img = tf.image.decode_png(img)/255 # 解码成图像数组
img = tf.image.resize(img, [256,256]) # 所有图片大小统一
label = tf.one_hot(label, depth=3) # 独热编码
return img,label
# 通过映射,对x,y做处理
train = train.map(process_image)
for i,j in train:
print(i,j)
break
tarin = train.shuffle(10000) # 打乱数据 train = tarin.batch(32) # 给每个数据加批次 train.cache() # 数据缓存 train.prefetch(buffer_size=tf.data.AUTOTUNE) # 预取数,增加资源使用效率 # 划分数据 num = tf.data.experimental.cardinality(train) # 所有批次 valdb = train.take(num//20) # 取3个批次,3 traindb = train.skip(num//20) # 跳过3个批次.65
版权声明:本文为weixin_50663922原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。