12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- import os
- import numpy as np
- from PIL import Image
- import torch
- from torch.utils.data import Dataset
class DatasetGenerator(Dataset):
    """Multi-label image dataset built from a whitespace-separated listing file.

    Each non-blank line of ``pathDatasetFile`` has the form::

        relative/path/to/image.png 0 1 0 ...

    i.e. an image path (joined onto ``pathImageDirectory``) followed by label
    values, each parsed with ``int(float(value))``.  Lines whose image file
    does not exist are skipped with a printed warning, and only samples whose
    labels sum to at least 1 are kept.

    Args:
        pathImageDirectory: Directory that the listed image paths are joined to.
        pathDatasetFile: Path to the listing file.
        transform: Callable applied to each RGB PIL image in ``__getitem__``;
            ``None`` returns the PIL image unchanged.
        model: Optional model.  If given, it is moved to ``self.device``; it is
            not otherwise used by this class.
        nnClassCount: Stored as ``self.nnClassCount``.  It is NOT used to
            validate the number of labels per line.

    Raises:
        ValueError: If no valid samples remain after filtering.
    """

    def __init__(self, pathImageDirectory, pathDatasetFile, transform, model=None, nnClassCount=14):
        self.listImagePaths = []
        self.listImageLabels = []
        self.transform = transform
        self.model = model
        self.nnClassCount = nnClassCount
        # Nothing in this class fills these in; they exist so that
        # get_features_and_labels() raises its intended ValueError instead of
        # an AttributeError until some caller assigns them.
        self.features = None
        self.labels = None

        # Select the compute device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # `is not None` rather than truthiness: container modules such as an
        # empty nn.Sequential define __len__ and would evaluate as False.
        if self.model is not None:
            self.model.to(self.device)

        # Read the listing file and keep only valid entries.
        with open(pathDatasetFile, "r") as fileDescriptor:
            for line in fileDescriptor:
                lineItems = line.split()
                # Blank / whitespace-only lines (e.g. a trailing empty line)
                # would otherwise raise IndexError on lineItems[0].
                if not lineItems:
                    continue
                # os.path.join + normpath to build a clean path.
                imagePath = os.path.normpath(os.path.join(pathImageDirectory, lineItems[0]))
                # Skip entries whose image is missing or is not a regular file.
                if not os.path.isfile(imagePath):
                    print(f"Warning: Path {imagePath} does not exist or is not a file, skipping this file.")
                    continue
                imageLabel = [int(float(i)) for i in lineItems[1:]]
                # Keep only samples with at least one positive label.
                if sum(imageLabel) >= 1:
                    self.listImagePaths.append(imagePath)
                    self.listImageLabels.append(imageLabel)

        # Refuse to build an empty dataset.
        if not self.listImagePaths:
            raise ValueError("No valid samples found. Please check your dataset file and image paths.")

    def __getitem__(self, index):
        """Return ``(image, labels)`` for sample ``index``.

        The image is opened with PIL, converted to RGB and passed through
        ``self.transform`` when one is set.  Labels are returned as a float32
        tensor.
        """
        # Context manager closes the underlying file handle.  convert()
        # returns a new image with its pixel data loaded, so it stays usable
        # after the source file is closed.
        with Image.open(self.listImagePaths[index]) as rawImage:
            imageData = rawImage.convert('RGB')
        if self.transform is not None:
            imageData = self.transform(imageData)
        return imageData, torch.tensor(self.listImageLabels[index], dtype=torch.float32)

    def __len__(self):
        """Return the number of kept samples."""
        return len(self.listImagePaths)

    def get_features_and_labels(self):
        """Return the ``(features, labels)`` pair stored on the instance.

        Raises:
            ValueError: If ``self.features`` has not been set.
        """
        # getattr tolerates instances that bypassed __init__; an `is None`
        # check avoids ambiguous truthiness if features is a tensor/array.
        if getattr(self, "features", None) is None:
            raise ValueError("No features extracted. Ensure `model` is provided during initialization.")
        return self.features, getattr(self, "labels", None)
|