DatasetGenerator.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import os
  2. import numpy as np
  3. from PIL import Image
  4. import torch
  5. from torch.utils.data import Dataset
  6. class DatasetGenerator(Dataset):
  7. def __init__(self, pathImageDirectory, pathDatasetFile, transform, model=None, nnClassCount=14):
  8. self.listImagePaths = []
  9. self.listImageLabels = []
  10. self.transform = transform
  11. self.model = model
  12. # 检查设备
  13. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  14. if self.model:
  15. self.model.to(self.device)
  16. # 读取数据集文件
  17. with open(pathDatasetFile, "r") as fileDescriptor:
  18. lines = fileDescriptor.readlines()
  19. # 遍历文件,筛选有效路径
  20. for line in lines:
  21. lineItems = line.split()
  22. # 使用 os.path.join 来确保路径的正确拼接
  23. imagePath = os.path.normpath(os.path.join(pathImageDirectory, lineItems[0]))
  24. # 检查路径是否存在,并且是一个有效文件
  25. if not os.path.isfile(imagePath):
  26. print(f"Warning: Path {imagePath} does not exist or is not a file, skipping this file.")
  27. continue # 跳过不存在的文件
  28. imageLabel = [int(float(i)) for i in lineItems[1:]]
  29. if np.array(imageLabel).sum() >= 1: # 确保至少有一个标签为正
  30. self.listImagePaths.append(imagePath)
  31. self.listImageLabels.append(imageLabel)
  32. # 如果没有有效样本,抛出异常
  33. if len(self.listImagePaths) == 0:
  34. raise ValueError("No valid samples found. Please check your dataset file and image paths.")
  35. def __getitem__(self, index):
  36. imageData = Image.open(self.listImagePaths[index]).convert('RGB')
  37. imageData = self.transform(imageData)
  38. return imageData, torch.FloatTensor(self.listImageLabels[index])
  39. def __len__(self):
  40. return len(self.listImagePaths)
  41. def get_features_and_labels(self):
  42. if not self.features:
  43. raise ValueError("No features extracted. Ensure `model` is provided during initialization.")
  44. return self.features, self.labels