MindSpore中用户使用GeneratorDataset来进行自定义数据集的加载,并支持多类输入源,用户可以定义yield生成器,迭代器以及使用魔术方法__getitem__的随机访问类。
详细可以参考:
mindspore.dataset.GeneratorDataset
用户使用GeneratorDataset接口的大部分错误,主要集中在数据加载迭代器以及随机访问类中,下面会列举典型的几类错误:
用户脚本:
1 import mindspore.dataset as ds
2 import matplotlib.pyplot as plt
3 import mindspore.dataset.vision.py_transforms as py_trans
4 from mindspore.dataset.transforms.py_transforms import Compose
5 import mindspore.ops as ops
6 # from mindspore import Tensor
7
8 import numpy as np
9 from PIL import Image
10 import os
11 import random
12
13 from option import opt
14
15
16 class DatasetGenerator:
17 def __init__(self, path, train, format='.png'):
18 # self.size = size
19 self.format = format
20 self.train = train
21 if train:
22 self.haze_imgs_dir = os.listdir(os.path.join(path, 'train', 'hazy'))
23 self.haze_imgs = [os.path.join(path, 'train', 'hazy', img) for img in self.haze_imgs_dir]
24 self.clear_dir = os.path.join(path, 'train', 'gt')
25 else:
26 self.haze_imgs_dir = os.listdir(os.path.join(path, 'test', 'hazy'))
27 self.haze_imgs = [os.path.join(path, 'test', 'hazy', img) for img in self.haze_imgs_dir]
28 self.clear_dir = os.path.join(path, 'test', 'gt')
29 # print(self.haze_imgs_dir, self.clear_dir)
30
31 np.random.seed(58)
32 self.__random_seed = []
33 for _ in range(len(self.haze_imgs)):
34 self.__random_seed.append(random.randint(0, 1000000))
35 self.__index = -1
36
37 def __getitem__(self, index):
38 self.__index += 1
39
40 haze = Image.open(self.haze_imgs[index])
41 # if isinstance(self.size,int):
42 # while haze.size[0]<self.size or haze.size[1]<self.size :
43 # index=random.randint(0,20000)
44 # haze=Image.open(self.haze_imgs[index])
45 img = self.haze_imgs[index].split('\\')[-1]
46 # img=self.haze_imgs[index].split('/')[-1]
47 img_name = img.split('_')
48 # img_name=img.split('\\')[-1].split('_')
49 # print(img_name)
50 clear_name=f"{img_name[0]}_gt_{img_name[2]}"
51 # print(self.clear_dir, clear_name, os.path.join(self.clear_dir,clear_name))
52 clear=Image.open(os.path.join(self.clear_dir,clear_name))
53
54 w, h = clear.size
55 nw, nh = haze.size
56 left = (w - nw)/2
57 top = (h - nh)/2
58 right = (w + nw)/2
59 bottom = (h + nh)/2
60 clear = clear.crop((left, top, right, bottom))
61
62 return (haze, clear, index)
63
64 def __len__(self):
65 print(len(self.haze_imgs))
66 return len(self.haze_imgs)
67
68 def get_seed(self):
69 seed = self.__random_seed[self.__index]
70 return seed
71
72 def decode(img):
73 return Image.fromarray(img)
74
75 def set_random_seed(img_name, seed):
76 random.seed(seed)
77 return img_name
78
79 ds.config.set_seed(8)
80 # DATA_DIR = opt.data_url
81 DATA_DIR = 'C:\\Users\\44753\\Desktop\\NTIRE2021'
82
83 train_dataset_generator = DatasetGenerator(DATA_DIR, train=True)
84 train_dataset = ds.GeneratorDataset(train_dataset_generator, ["hazy", "gt", "img_name"], shuffle=True)
85 test_dataset_generator = DatasetGenerator(DATA_DIR, train=False)
86 test_dataset = ds.GeneratorDataset(test_dataset_generator, ["hazy", "gt", "img_name"], shuffle=False)
87
88 transforms_list = [
89 decode,
90 (lambda img_name: set_random_seed(img_name, train_dataset_generator.get_seed())),
91 py_trans.RandomCrop(opt.crop_size),
92 py_trans.ToTensor(),
93 ]
94 compose_trans = Compose(transforms_list)
95 train_dataset = train_dataset.map(operations=compose_trans, input_columns=["hazy"])
96 train_dataset = train_dataset.map(operations=compose_trans, input_columns=["gt"])
97 train_dataset = train_dataset.batch(opt.bs, drop_remainder=True)
98
99 if __name__ == '__main__':
100 for i in range(2):
101 print(i)
102 for batch in train_dataset.create_dict_iterator():
103 # print(batch)
104 # hazy = Tensor(batch["hazy"], dtype=mindspore.float32)
105 # clear = Tensor(batch["gt"], dtype=mindspore.float32)
106
107 print(batch["hazy"].shape, batch["gt"].shape)
报错信息如下,可以看出数据第一轮可以正常迭代,第二轮出现IndexError: list index out of range 错误。
原因分析:
用户脚本中第37行利用魔术方法
def __getitem__(self, index)定义了随机访问的类,其中index参数作为样本\元素索引,外部可以直接使用index来进行数据的访问,外部不需要对其进行修改(如:每个迭代开始前的复位操作),而用代码第35行户另外自定义了__index,需要用户在每个迭代前进行复位操作,来保证不出现数组越界访问。
解决方法:
移除非必要的__index成员变量,或者在每次迭代前对__index赋值为0进行复位操作。
版权声明:本文为skytttttt9394原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。