DatasetGenerator(2).py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. import numpy as np
  3. from PIL import Image
  4. import torch
  5. from torch.utils.data import Dataset
  6. # --------------------------------------------------------------------------------
  7. # 定义一个数据集类 DatasetGenerator,继承自 PyTorch 的 Dataset 类
  8. class DatasetGenerator(Dataset):
  9. # --------------------------------------------------------------------------------
  10. # 初始化函数,传入图像目录路径、数据集文件路径和图像预处理变换
  11. def __init__(self, pathImageDirectory, pathDatasetFile, transform):
  12. self.listImagePaths = [] # 用于存储图像路径的列表
  13. self.listImageLabels = [] # 用于存储标签的列表
  14. self.transform = transform # 图像的预处理方法
  15. # ---- 打开文件,获取图像路径和标签
  16. with open(pathDatasetFile, "r") as fileDescriptor:
  17. for line in fileDescriptor:
  18. lineItems = line.strip().split()
  19. imagePath = os.path.join(pathImageDirectory, lineItems[0]) # 获取图像文件的完整路径
  20. imageLabel = lineItems[1:] # 获取对应的标签(位于行的第二部分)
  21. # 将标签转换为整数列表,并确保每个标签是整数(有可能是浮点数的情况需要处理)
  22. imageLabel = [int(float(i)) for i in imageLabel]
  23. # 如果标签数组中至少有一个值为1(即图片至少有一个分类标签)
  24. if np.array(imageLabel).sum() >= 1:
  25. self.listImagePaths.append(imagePath) # 将图像路径加入列表
  26. self.listImageLabels.append(imageLabel) # 将图像标签加入列表
  27. # --------------------------------------------------------------------------------
  28. # 获取数据集中特定索引的图像及其标签
  29. def __getitem__(self, index):
  30. # 根据索引获取图像路径
  31. imagePath = self.listImagePaths[index]
  32. # 打开图像文件,并将图像转换为 RGB 模式
  33. imageData = Image.open(imagePath).convert('RGB')
  34. # 将对应的标签转换为 PyTorch 的 FloatTensor(浮点数张量)
  35. imageLabel = torch.FloatTensor(self.listImageLabels[index])
  36. # 如果有图像预处理操作,应用预处理
  37. if self.transform is not None:
  38. imageData = self.transform(imageData)
  39. # 返回图像数据和其对应的标签
  40. return imageData, imageLabel
  41. # --------------------------------------------------------------------------------
  42. # 返回数据集的样本数量
  43. def __len__(self):
  44. return len(self.listImagePaths)