MindSpore数据集加载-【IndexError: list index out of range】错误

  • Post author:
  • Post category:其他


MindSpore中用户使用GeneratorDataset来进行自定义数据集的加载,并支持多类输入源,用户可以定义yield生成器,迭代器以及使用魔术方法__getitem__的随机访问类。

详细可以参考:

mindspore.dataset.GeneratorDataset

用户使用GeneratorDataset接口的大部分错误,主要集中在数据加载迭代器以及随机访问类中,下面会列举典型的几类错误:

用户脚本:

  1 import mindspore.dataset as ds
  2 import matplotlib.pyplot as plt
  3 import mindspore.dataset.vision.py_transforms as py_trans
  4 from mindspore.dataset.transforms.py_transforms import Compose
  5 import mindspore.ops as ops
  6 # from mindspore import Tensor
  7
  8 import numpy as np
  9 from PIL import Image
 10 import os
 11 import random
 12
 13 from option import opt
 14
 15
 16 class DatasetGenerator:
 17     def __init__(self, path, train, format='.png'):
 18         # self.size = size
 19         self.format = format
 20         self.train = train
 21         if train:
 22             self.haze_imgs_dir = os.listdir(os.path.join(path, 'train', 'hazy'))
 23             self.haze_imgs = [os.path.join(path, 'train', 'hazy', img) for img in self.haze_imgs_dir]
 24             self.clear_dir = os.path.join(path, 'train', 'gt')
 25         else:
 26             self.haze_imgs_dir = os.listdir(os.path.join(path, 'test', 'hazy'))
 27             self.haze_imgs = [os.path.join(path, 'test', 'hazy', img) for img in self.haze_imgs_dir]
 28             self.clear_dir = os.path.join(path, 'test', 'gt')
 29         # print(self.haze_imgs_dir, self.clear_dir)
 30
 31         np.random.seed(58)
 32         self.__random_seed = []
 33         for _ in range(len(self.haze_imgs)):
 34             self.__random_seed.append(random.randint(0, 1000000))
 35         self.__index = -1
 36
 37     def __getitem__(self, index):
 38         self.__index += 1
 39
 40         haze = Image.open(self.haze_imgs[index])
 41         # if isinstance(self.size,int):
 42         #     while haze.size[0]<self.size or haze.size[1]<self.size :
 43         #         index=random.randint(0,20000)
 44         #         haze=Image.open(self.haze_imgs[index])
 45         img = self.haze_imgs[index].split('\\')[-1]
 46         # img=self.haze_imgs[index].split('/')[-1]
 47         img_name = img.split('_')
 48         # img_name=img.split('\\')[-1].split('_')
 49         # print(img_name)
 50         clear_name=f"{img_name[0]}_gt_{img_name[2]}"
 51         # print(self.clear_dir, clear_name, os.path.join(self.clear_dir,clear_name))
 52         clear=Image.open(os.path.join(self.clear_dir,clear_name))
 53
 54         w, h = clear.size
 55         nw, nh = haze.size
 56         left = (w - nw)/2
 57         top = (h - nh)/2
 58         right = (w + nw)/2
 59         bottom = (h + nh)/2
 60         clear = clear.crop((left, top, right, bottom))
 61
 62         return (haze, clear, index)
 63
 64     def __len__(self):
 65         print(len(self.haze_imgs))
 66         return len(self.haze_imgs)
 67
 68     def get_seed(self):
 69         seed = self.__random_seed[self.__index]
 70         return seed
 71
 72 def decode(img):
 73     return Image.fromarray(img)
 74
 75 def set_random_seed(img_name, seed):
 76     random.seed(seed)
 77     return img_name
 78
 79 ds.config.set_seed(8)
 80 # DATA_DIR = opt.data_url
 81 DATA_DIR = 'C:\\Users\\44753\\Desktop\\NTIRE2021'
 82
 83 train_dataset_generator = DatasetGenerator(DATA_DIR, train=True)
 84 train_dataset = ds.GeneratorDataset(train_dataset_generator, ["hazy", "gt", "img_name"], shuffle=True)
 85 test_dataset_generator = DatasetGenerator(DATA_DIR, train=False)
 86 test_dataset = ds.GeneratorDataset(test_dataset_generator, ["hazy", "gt", "img_name"], shuffle=False)
 87
 88 transforms_list = [
 89     decode,
 90     (lambda img_name: set_random_seed(img_name, train_dataset_generator.get_seed())),
 91     py_trans.RandomCrop(opt.crop_size),
 92     py_trans.ToTensor(),
 93 ]
 94 compose_trans = Compose(transforms_list)
 95 train_dataset = train_dataset.map(operations=compose_trans, input_columns=["hazy"])
 96 train_dataset = train_dataset.map(operations=compose_trans, input_columns=["gt"])
 97 train_dataset = train_dataset.batch(opt.bs, drop_remainder=True)
 98
 99 if __name__ == '__main__':
100     for i in range(2):
101         print(i)
102         for batch in train_dataset.create_dict_iterator():
103             # print(batch)
104             # hazy = Tensor(batch["hazy"], dtype=mindspore.float32)
105             # clear = Tensor(batch["gt"], dtype=mindspore.float32)
106
107             print(batch["hazy"].shape, batch["gt"].shape)

报错信息如下,可以看出数据第一轮可以正常迭代,第二轮出现IndexError: list index out of range 错误。


原因分析:

用户脚本中第37行利用魔术方法


def __getitem__(self, index)定义了随机访问的类,其中index参数作为样本\元素索引,外部可以直接使用index来进行数据的访问,外部不需要对其进行修改(如:每个迭代开始前的复位操作),而用代码第35行户另外自定义了__index,需要用户在每个迭代前进行复位操作,来保证不出现数组越界访问。


解决方法:

移除非必要的__index成员变量,或者在每次迭代前对__index赋值为0进行复位操作。



版权声明:本文为skytttttt9394原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。