dataset.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import numpy as np
  2. import pandas as pd
  3. import cv2
  4. from PIL import Image
  5. from sklearn.model_selection import train_test_split
  6. from torch.utils.data import Dataset
  7. import os
  8. from torchvision import transforms
  9. # 定义图像预处理的函数,参数包括是否为训练模式以及自定义参数(args)
  10. def build_transform(train, args):
  11. if train:
  12. # 如果是训练模式,进行一系列数据增强和归一化处理
  13. transform = transforms.Compose((
  14. transforms.RandomResizedCrop(int(args.img_size / 0.875),
  15. scale=(0.8, 1.0)), # 随机裁剪图像
  16. transforms.RandomRotation(7), # 随机旋转图像
  17. transforms.RandomHorizontalFlip(), # 随机水平翻转
  18. transforms.CenterCrop(args.img_size), # 中心裁剪
  19. transforms.ToTensor(), # 转换为张量
  20. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  21. # 使用Imagenet的均值和方差进行归一化
  22. ))
  23. else:
  24. # 如果是验证或测试模式,只进行裁剪和归一化处理
  25. transform = transforms.Compose((
  26. transforms.Resize(int(args.img_size / 0.875)), # 调整大小
  27. transforms.CenterCrop(args.img_size), # 中心裁剪
  28. transforms.ToTensor(), # 转换为张量
  29. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  30. # 归一化
  31. ))
  32. return transform
  33. # 定义数据集类 photodatatest,继承自 PyTorch 的 Dataset 类
  34. class ChestXray14Dataset(Dataset):
  35. def __init__(self,
  36. data_root, # 数据集路径
  37. classes=['Atelectasis', 'Cardiomegaly', 'Effusion',
  38. 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
  39. 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
  40. 'Fibrosis', 'Pleural_Thickening', 'Hernia'],
  41. mode='train', # 模式:'train', 'valid', 或 'test'
  42. split='official', # 数据划分方式:'official' 或 'non-official'
  43. has_val_set=True, # 是否包含验证集
  44. transform=None): # 图像预处理方法
  45. super().__init__()
  46. self.data_root = data_root # 数据集根目录
  47. self.classes = classes # 多标签类别
  48. self.num_classes = len(self.classes) # 类别数
  49. self.mode = mode # 当前数据集模式
  50. # 根据split的类型选择加载方式
  51. if split == 'official':
  52. # 使用官方划分方式加载数据
  53. self.dataframe, self.num_patients = self.load_split_file(
  54. self.data_root, self.mode, has_val_set)
  55. else:
  56. # 使用非官方划分方式加载数据
  57. self.dataframe, self.num_patients = self.load_split_file_non_official(
  58. self.data_root, self.mode)
  59. self.transform = transform # 图像预处理方法
  60. self.num_samples = len(self.dataframe) # 样本数量
  61. # 加载官方划分文件的方法,根据mode加载不同的数据
  62. def load_split_file(self, folder, mode, has_val=True):
  63. df = pd.read_csv(
  64. os.path.join(
  65. 'F:\chexnet\chexnet-master\photodatatest',
  66. 'Data_Entry_2017.csv')) # 使用绝对路径
  67. # 如果模式为训练或验证
  68. if mode in ['train', 'valid']:
  69. # 加载训练和验证数据文件
  70. file_name = os.path.join(folder, 'train_val_list.txt')
  71. with open(file_name, 'r') as f:
  72. lines = f.read().splitlines() # 读取所有图像文件名
  73. df_train_val = df[df['Image Index'].isin(lines)] # 过滤出对应的图像
  74. # 如果需要验证集,将患者ID拆分为训练和验证集
  75. if has_val:
  76. patient_ids = df_train_val['Patient ID'].unique() # 获取所有患者ID
  77. train_ids, val_ids = train_test_split(patient_ids,
  78. test_size=1 - 0.7 / 0.8,
  79. random_state=0,
  80. shuffle=True)
  81. target_ids = train_ids if mode == 'train' else val_ids # 根据模式选择训练或验证集的患者ID
  82. df = df_train_val[
  83. df_train_val['Patient ID'].isin(target_ids)] # 根据ID过滤数据
  84. else:
  85. df = df_train_val
  86. elif mode == 'test':
  87. # 如果模式为测试,加载测试数据文件
  88. file_name = os.path.join(folder, 'test_list.txt')
  89. with open(file_name, 'r') as f:
  90. target_files = f.read().splitlines() # 读取测试集文件名
  91. df = df[df['Image Index'].isin(target_files)] # 过滤测试数据
  92. else:
  93. raise NotImplementedError(f'Unidentified split: {mode}') # 未识别的模式报错
  94. num_patients = len(df['Patient ID'].unique()) # 统计患者数
  95. return df, num_patients
  96. # 非官方划分文件的加载方法,根据比例拆分数据集
  97. def load_split_file_non_official(self, folder, mode):
  98. train_rt, val_rt, test_rt = 0.7, 0.1, 0.2 # 定义训练、验证和测试的比例
  99. df = pd.read_csv(
  100. os.path.join(folder, 'Data_Entry_2017.csv')) # 加载数据标签文件
  101. patient_ids = df['Patient ID'].unique() # 获取所有患者ID
  102. # 先划分出测试集,然后在剩余数据中划分出验证集和训练集
  103. train_val_ids, test_ids = train_test_split(patient_ids,
  104. test_size=test_rt,
  105. random_state=0, shuffle=True)
  106. train_ids, val_ids = train_test_split(train_val_ids,
  107. test_size=val_rt / (
  108. train_rt + val_rt),
  109. random_state=0, shuffle=True)
  110. # 根据模式选择目标ID
  111. target_ids = {'train': train_ids, 'valid': val_ids, 'test': test_ids}[
  112. mode]
  113. df = df[df['Patient ID'].isin(target_ids)] # 根据ID过滤数据
  114. num_patients = len(target_ids) # 统计患者数
  115. return df, num_patients
  116. # 将疾病标签转换为多标签编码
  117. def encode_label(self, label):
  118. encoded_label = np.zeros(self.num_classes,
  119. dtype=np.float32) # 初始化全0的标签数组
  120. if label != 'No Finding': # 如果标签不为"No Finding"
  121. for l in label.split('|'): # 对每个疾病标签进行处理
  122. encoded_label[self.classes.index(l)] = 1 # 将对应疾病的索引位置置1
  123. return encoded_label
  124. # 图像预处理函数,调整图像尺寸
  125. def pre_process(self, img):
  126. h, w = img.shape
  127. img = cv2.resize(img, dsize=(max(h, w), max(h, w))) # 将图像调整为正方形
  128. return img
  129. # 统计每个类别的样本数量
  130. def count_class_dist(self):
  131. class_counts = np.zeros(self.num_classes) # 初始化类别计数
  132. for index, row in self.dataframe.iterrows(): # 遍历数据集的每一行
  133. class_counts += self.encode_label(
  134. row['Finding Labels']) # 将标签编码加到计数器中
  135. return self.num_samples, class_counts
  136. # 返回数据集的样本数量
  137. def __len__(self):
  138. return self.num_samples
  139. # 获取指定索引的数据
  140. def __getitem__(self, idx):
  141. row = self.dataframe.iloc[idx] # 获取对应行的数据
  142. img_file, label = row['Image Index'], row[
  143. 'Finding Labels'] # 获取图像文件名和标签
  144. img = cv2.imread(
  145. os.path.join(self.data_root, 'images', img_file)) # 读取图像
  146. img = Image.fromarray(img) # 转换为PIL图像
  147. if self.transform is not None:
  148. img = self.transform(img) # 应用预处理
  149. label = self.encode_label(label) # 编码标签
  150. return img, label # 返回图像和标签