分层采样划分数据集

  • Post author:
  • Post category:其他


import glob

cats_fenbu = {}
for i in range(18):
    cats_fenbu[str(i)] = 0

# 统计每个类别的个数
for txt in glob.glob(r"D:\dataset\TianChi_clothes\guangdong1_round1_train2_20190828\train\labels\*"):
    with open(txt, "r") as f:
        labels = f.read().splitlines()
        for line in labels:
            cat = line.split(" ")[0]
            cats_fenbu[str(cat)] += 1
print(cats_fenbu)

# 按8:2划分训练测试数据集,直接产生测试集,剩余的就是训练集
cats_test = {}      # 存放测试集每个类别应有的个数
for cats_fenbu_key in cats_fenbu:
    cats_test[cats_fenbu_key] = round(cats_fenbu[cats_fenbu_key] / 10 * 2)
print(cats_test)

cats_test_real = {}     # 存放测试集中当前已加入的各类别数
for i in range(18):
    cats_test_real[str(i)] = 0
dataset_test = []          # 存放测试集的图片名
for txt in glob.glob(r"D:\dataset\TianChi_clothes\guangdong1_round1_train2_20190828\train\labels\*"):
    with open(txt, "r") as f:
        labels = f.read().splitlines()
        txt_cats = []   # 存放当前txt文件中的所有类别
        for line in labels:
            cat = line.split(" ")[0]
            txt_cats.append(cat)
        # 判断类别是否已满足需求
        isAdd = True
        for cat in txt_cats:
            if cats_test_real[str(cat)] >= cats_test[str(cat)]:  # 已满足需求
                isAdd = False
                break
        if isAdd:
            for cat in txt_cats:
                cats_test_real[str(cat)] += 1
            dataset_test.append(txt)
print("------------------------------------------------")
print(cats_test_real)

import tqdm
import shutil

for item in tqdm.tqdm(dataset_test):
    shutil.move(item, r"D:\dataset\TianChi_clothes\guangdong1_round1_train2_20190828\val\labels")       # move TXT
    # move Img
    img_path = r"D:\dataset\TianChi_clothes\guangdong1_round1_train2_20190828\train\images" + "\\" +\
               item.split("\\")[-1].split(".")[0] + ".jpg"
    shutil.move(img_path, r"D:\dataset\TianChi_clothes\guangdong1_round1_train2_20190828\val\images")

print("Successful!!!")



版权声明:本文为qq_41825891原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。