Making two PyTorch DataLoaders shuffle in the same corresponding order



When loading data for a dual-branch network in deep learning, we need two dataloaders, and we want both of them to shuffle the data while keeping their order in correspondence. The following article can be used as a reference.



Reference: "How to make two dataloaders in PyTorch shuffle in the same order" (huhuan4480's blog, CSDN)



https://blog.csdn.net/huhuan4480/article/details/113246342



The idea of that article is to combine the two datasets into a single dataset whose items are index-aligned, and then draw batches from it with one DataLoader:

import torch
import numpy as np
from torch.utils.data import Dataset


class MyDataset_v1(Dataset):
    def __init__(self):
        self.data = np.array([1, 2, 3, 4])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item], 1    # in image processing, treat self.data[item] as an image and 1 as its label


class MyDataset_v2(Dataset):
    def __init__(self):
        self.data = np.array([1.1, 2.2, 3.3, 4.4])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item], 2    


# Wrapper dataset that pairs the two datasets by index,
# so a single DataLoader shuffles both in lockstep
class MyDataset(Dataset):
    def __init__(self, dataset1, dataset2):
        self.dataset1 = dataset1
        self.dataset2 = dataset2

    def __getitem__(self, index):
        x1 = self.dataset1[index]
        x2 = self.dataset2[index]
        print('x1:', x1)
        print('x2:', x2)
        return x1, x2    

    def __len__(self):
        return len(self.dataset1)


if __name__ == "__main__":
    myDataset1 = MyDataset_v1()
    myDataset2 = MyDataset_v2()

    myDataset = MyDataset(dataset1=myDataset1, dataset2=myDataset2)
    print(myDataset)
    print(type(myDataset))
    dataloader = torch.utils.data.DataLoader(dataset=myDataset, batch_size=2, shuffle=True, pin_memory=True)
    epoch = 2
    for i in range(epoch):
        for batch_ind, data in enumerate(dataloader):
            data1, data2 = data[0], data[1]
            print("Epoch: {} Batch_ind: {} data in Dataset1: {} data in Dataset2: {}".format(i, batch_ind, data1, data2))

Looking at the output, within the first epoch data is a list containing two inner lists, data1 and data2.

print(data)
[[tensor([3, 1], dtype=torch.int32), tensor([1, 1])], [tensor([3.3000, 1.1000], dtype=torch.float64), tensor([2, 2])]]
print(data1)
[tensor([3, 1], dtype=torch.int32), tensor([1, 1])]

data1 is what MyDataset_v1 returns for one batch: two tensors (the image tensor and the label tensor). Each tensor holds two elements, namely the two samples gathered when batch_size = 2.

When batch_size = 3, data is still a list holding the two lists data1 and data2, but each tensor inside now contains three elements:

print(data)
[[tensor([3, 4, 1], dtype=torch.int32), tensor([1, 1, 1])], [tensor([3.3000, 4.4000, 1.1000], dtype=torch.float64), tensor([2, 2, 2])]]
print(data[0])
[tensor([3, 4, 1], dtype=torch.int32), tensor([1, 1, 1])]
print(data[1])
[tensor([3.3000, 4.4000, 1.1000], dtype=torch.float64), tensor([2, 2, 2])]
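
Since each element of data is itself an [images, labels] pair produced by the default collate function, the batch can be unpacked directly inside the loop. A minimal sketch (the variable names img1/label1/img2/label2 are mine, not from the original article):

for batch_ind, data in enumerate(dataloader):
    (img1, label1), (img2, label2) = data
    # img1[k] and img2[k] were drawn from the same index k of the two
    # underlying datasets, so the pairing survives shuffling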

In addition, here is my own implementation:

dataloader.py

# Converts custom images and labels into the format that DataLoader expects
from PIL import Image
import os
import json
from torch.utils.data import Dataset


path_to_data = os.path.join('.', 'data')
path_to_json = os.path.join(path_to_data, 'preprocessjson.json')
path_to_preprocess = os.path.join(path_to_data, 'preprocess')
path_to_train = os.path.join(path_to_preprocess, 'train')
path_to_verify = os.path.join(path_to_preprocess, 'verify')

# Image reading helper
def Load_Image_Information(path):
    # Open the image in RGB format.
    # PIL images are the format the PyTorch DataLoader pipeline works with,
    # so this way of reading images is recommended; use convert('L') for grayscale.
    return Image.open(path).convert('RGB')


"""
__init__ stores the image paths and the corresponding labels in two parallel lists;
entries at the same index come from the same patient.
__getitem__'s index is the position of an image/label pair in those lists; the
DataLoader calls it with shuffled indices, so the data within a batch is randomly ordered.
It returns one actual image together with its label.
The DataLoader calls __getitem__ repeatedly and packs batch_size samples into one batch.
"""

class my_Data_Set_BModeH(Dataset):
    def __init__(self, path_to_json, option, transform=None, loader=None):
        with open(path_to_json, 'r') as open_json:
            label_json = json.load(open_json)

        self.image = []
        self.label = []

        if option == 'train':
            path_to_root = path_to_train    # ./data/preprocess/train
        else:
            path_to_root = path_to_verify   # ./data/preprocess/verify

        filelist_of_root = os.listdir(path_to_root)     # [A1, A2, A3...]
        for filename in filelist_of_root:
            path_to_patient = os.path.join(path_to_root, filename)                # ./data/preprocess/train/Ai
            image_name = sorted(os.listdir(path_to_patient))                      # the patient's image names [Ai_1, Ai_2]; sorted, since os.listdir gives no guaranteed order
            path_to_BModeH = os.path.join(path_to_root, filename, image_name[0])  # ./data/preprocess/train/Ai/Ai_1.
            self.image.append(path_to_BModeH)
            self.label.append(label_json[filename])
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.image)

    def __getitem__(self, index):
        img = self.loader(self.image[index])
        if self.transform:
            img = self.transform(img)
        label = self.label[index]
        return img, label


class my_Data_Set_CDFI(Dataset):
    def __init__(self, path_to_json, option, transform=None, loader=None):
        with open(path_to_json, 'r') as open_json:
            label_json = json.load(open_json)

        self.image = []
        self.label = []

        if option == 'train':
            path_to_root = path_to_train    # ./data/preprocess/train
        else:
            path_to_root = path_to_verify   # ./data/preprocess/verify

        filelist_of_root = os.listdir(path_to_root)
        for filename in filelist_of_root:
            path_to_patient = os.path.join(path_to_root, filename)              # ./data/preprocess/train/Ai
            image_name = sorted(os.listdir(path_to_patient))                    # the patient's image names [Ai_1, Ai_2]; sorted for a deterministic order
            path_to_CDFI = os.path.join(path_to_root, filename, image_name[1])  # ./data/preprocess/train/Ai/Ai_2.
            self.image.append(path_to_CDFI)
            self.label.append(label_json[filename])
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.image)

    def __getitem__(self, index):
        img = self.loader(self.image[index])
        if self.transform:
            img = self.transform(img)
        label = self.label[index]
        return img, label


# Data-reading class that combines the two datasets above
class my_Data_Set(Dataset):
    def __init__(self, dataset_BModeH, dataset_CDFI):
        self.dataset_BModeH = dataset_BModeH
        self.dataset_CDFI = dataset_CDFI

    # Override this function to read one pair of image samples
    def __getitem__(self, index):
        BModeH = self.dataset_BModeH[index]
        CDFI = self.dataset_CDFI[index]
        return BModeH, CDFI

    # Override this function to report how many samples the dataset contains
    def __len__(self):
        return len(self.dataset_BModeH)
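
For context, the classes above assume a JSON file mapping each patient folder name to its label, and one folder per patient containing two images (B-mode first, CDFI second once sorted). A sketch of the assumed layout; the concrete file names and labels are illustrative, not taken from the original post:

data/
    preprocessjson.json          # e.g. {"A1": 0, "A2": 1, ...}
    preprocess/
        train/
            A1/
                A1_1.png         # B-mode image -> image_name[0]
                A1_2.png         # CDFI image  -> image_name[1]
            A2/
                ...
        verify/
            ...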

multymodel.py

import os
from torchvision import transforms
import torch
from torch import nn
import dataloader

path_to_preprocess = os.path.join('.', 'data', 'preprocess')
path_to_preprocessjson = os.path.join('.', 'data', 'preprocessjson.json')


def resnet_start():
    data_transforms = {
        "train": transforms.Compose([
            transforms.ToTensor(),
            # normalize with the ImageNet mean and std
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "verify": transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    batch_size = 8  # 8 samples per batch

    # Build the datasets in a form the model can consume.
    # The train and verify images get their respective transforms applied.
    # Arguments: label JSON path, split name, transform, and the image loader.
    train_BModeH = dataloader.my_Data_Set_BModeH(path_to_json=path_to_preprocessjson, option='train',
                                                 transform=data_transforms['train'], loader=dataloader.Load_Image_Information)
    train_CDFI = dataloader.my_Data_Set_CDFI(path_to_json=path_to_preprocessjson, option='train',
                                             transform=data_transforms['train'], loader=dataloader.Load_Image_Information)

    verify_BModeH = dataloader.my_Data_Set_BModeH(path_to_json=path_to_preprocessjson, option='verify',
                                                 transform=data_transforms['verify'], loader=dataloader.Load_Image_Information)
    verify_CDFI = dataloader.my_Data_Set_CDFI(path_to_json=path_to_preprocessjson, option='verify',
                                             transform=data_transforms['verify'], loader=dataloader.Load_Image_Information)

    # print the labels of each dataset
    print(train_BModeH.label)
    print(train_CDFI.label)
    print(verify_BModeH.label)
    print(verify_CDFI.label)

    train_data = dataloader.my_Data_Set(train_BModeH, train_CDFI)
    verify_data = dataloader.my_Data_Set(verify_BModeH, verify_CDFI)

    # Fetch images in batches of batch_size = 8.
    # Everything here is already a tensor (via the Compose above); shuffle=True
    # reshuffles the dataset at the start of every epoch.
    train_dataloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
    verify_dataloader = torch.utils.data.DataLoader(dataset=verify_data, batch_size=batch_size, shuffle=True)

    dataloader_dict = {'train': train_dataloader, 'verify': verify_dataloader}
    return dataloader_dict
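
With the combined dataset in place, every batch from the train loader is a pair of (image, label) pairs drawn from the same shuffled patient indices. A minimal sketch of how a training loop might consume it; the return statement above and this loop body are my additions, not part of the original post:

dataloader_dict = resnet_start()
for (bmode_img, bmode_label), (cdfi_img, cdfi_label) in dataloader_dict['train']:
    # bmode_img[k] and cdfi_img[k] belong to the same patient for every k,
    # so the two branches of the network always see matching samples
    pass  # the forward pass of the two-branch model would go here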



Copyright notice: This is an original article by fwkgeass, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source together with this notice.