123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- import os
- import numpy as np
- from PIL import Image
- import torch
- from torch.utils.data import Dataset
- # --------------------------------------------------------------------------------
- # 定义一个数据集类 DatasetGenerator,继承自 PyTorch 的 Dataset 类
- class DatasetGenerator(Dataset):
- # --------------------------------------------------------------------------------
- # 初始化函数,传入图像目录路径、数据集文件路径和图像预处理变换
- def __init__(self, pathImageDirectory, pathDatasetFile, transform):
- self.listImagePaths = [] # 用于存储图像路径的列表
- self.listImageLabels = [] # 用于存储标签的列表
- self.transform = transform # 图像的预处理方法
- # ---- 打开文件,获取图像路径和标签
- with open(pathDatasetFile, "r") as fileDescriptor:
- for line in fileDescriptor:
- lineItems = line.strip().split()
- imagePath = os.path.join(pathImageDirectory, lineItems[0]) # 获取图像文件的完整路径
- imageLabel = lineItems[1:] # 获取对应的标签(位于行的第二部分)
- # 将标签转换为整数列表,并确保每个标签是整数(有可能是浮点数的情况需要处理)
- imageLabel = [int(float(i)) for i in imageLabel]
- # 如果标签数组中至少有一个值为1(即图片至少有一个分类标签)
- if np.array(imageLabel).sum() >= 1:
- self.listImagePaths.append(imagePath) # 将图像路径加入列表
- self.listImageLabels.append(imageLabel) # 将图像标签加入列表
- # --------------------------------------------------------------------------------
- # 获取数据集中特定索引的图像及其标签
- def __getitem__(self, index):
- # 根据索引获取图像路径
- imagePath = self.listImagePaths[index]
- # 打开图像文件,并将图像转换为 RGB 模式
- imageData = Image.open(imagePath).convert('RGB')
- # 将对应的标签转换为 PyTorch 的 FloatTensor(浮点数张量)
- imageLabel = torch.FloatTensor(self.listImageLabels[index])
- # 如果有图像预处理操作,应用预处理
- if self.transform is not None:
- imageData = self.transform(imageData)
- # 返回图像数据和其对应的标签
- return imageData, imageLabel
- # --------------------------------------------------------------------------------
- # 返回数据集的样本数量
- def __len__(self):
- return len(self.listImagePaths)
|