123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- import numpy as np
- import pandas as pd
- import cv2
- from PIL import Image
- from sklearn.model_selection import train_test_split
- from torch.utils.data import Dataset
- import os
- from torchvision import transforms
- # 定义图像预处理的函数,参数包括是否为训练模式以及自定义参数(args)
- def build_transform(train, args):
- if train:
- # 如果是训练模式,进行一系列数据增强和归一化处理
- transform = transforms.Compose((
- transforms.RandomResizedCrop(int(args.img_size / 0.875),
- scale=(0.8, 1.0)), # 随机裁剪图像
- transforms.RandomRotation(7), # 随机旋转图像
- transforms.RandomHorizontalFlip(), # 随机水平翻转
- transforms.CenterCrop(args.img_size), # 中心裁剪
- transforms.ToTensor(), # 转换为张量
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- # 使用Imagenet的均值和方差进行归一化
- ))
- else:
- # 如果是验证或测试模式,只进行裁剪和归一化处理
- transform = transforms.Compose((
- transforms.Resize(int(args.img_size / 0.875)), # 调整大小
- transforms.CenterCrop(args.img_size), # 中心裁剪
- transforms.ToTensor(), # 转换为张量
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- # 归一化
- ))
- return transform
- # 定义数据集类 photodatatest,继承自 PyTorch 的 Dataset 类
- class ChestXray14Dataset(Dataset):
- def __init__(self,
- data_root, # 数据集路径
- classes=['Atelectasis', 'Cardiomegaly', 'Effusion',
- 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
- 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
- 'Fibrosis', 'Pleural_Thickening', 'Hernia'],
- mode='train', # 模式:'train', 'valid', 或 'test'
- split='official', # 数据划分方式:'official' 或 'non-official'
- has_val_set=True, # 是否包含验证集
- transform=None): # 图像预处理方法
- super().__init__()
- self.data_root = data_root # 数据集根目录
- self.classes = classes # 多标签类别
- self.num_classes = len(self.classes) # 类别数
- self.mode = mode # 当前数据集模式
- # 根据split的类型选择加载方式
- if split == 'official':
- # 使用官方划分方式加载数据
- self.dataframe, self.num_patients = self.load_split_file(
- self.data_root, self.mode, has_val_set)
- else:
- # 使用非官方划分方式加载数据
- self.dataframe, self.num_patients = self.load_split_file_non_official(
- self.data_root, self.mode)
- self.transform = transform # 图像预处理方法
- self.num_samples = len(self.dataframe) # 样本数量
- # 加载官方划分文件的方法,根据mode加载不同的数据
- def load_split_file(self, folder, mode, has_val=True):
- df = pd.read_csv(
- os.path.join(
- 'F:\chexnet\chexnet-master\photodatatest',
- 'Data_Entry_2017.csv')) # 使用绝对路径
- # 如果模式为训练或验证
- if mode in ['train', 'valid']:
- # 加载训练和验证数据文件
- file_name = os.path.join(folder, 'train_val_list.txt')
- with open(file_name, 'r') as f:
- lines = f.read().splitlines() # 读取所有图像文件名
- df_train_val = df[df['Image Index'].isin(lines)] # 过滤出对应的图像
- # 如果需要验证集,将患者ID拆分为训练和验证集
- if has_val:
- patient_ids = df_train_val['Patient ID'].unique() # 获取所有患者ID
- train_ids, val_ids = train_test_split(patient_ids,
- test_size=1 - 0.7 / 0.8,
- random_state=0,
- shuffle=True)
- target_ids = train_ids if mode == 'train' else val_ids # 根据模式选择训练或验证集的患者ID
- df = df_train_val[
- df_train_val['Patient ID'].isin(target_ids)] # 根据ID过滤数据
- else:
- df = df_train_val
- elif mode == 'test':
- # 如果模式为测试,加载测试数据文件
- file_name = os.path.join(folder, 'test_list.txt')
- with open(file_name, 'r') as f:
- target_files = f.read().splitlines() # 读取测试集文件名
- df = df[df['Image Index'].isin(target_files)] # 过滤测试数据
- else:
- raise NotImplementedError(f'Unidentified split: {mode}') # 未识别的模式报错
- num_patients = len(df['Patient ID'].unique()) # 统计患者数
- return df, num_patients
- # 非官方划分文件的加载方法,根据比例拆分数据集
- def load_split_file_non_official(self, folder, mode):
- train_rt, val_rt, test_rt = 0.7, 0.1, 0.2 # 定义训练、验证和测试的比例
- df = pd.read_csv(
- os.path.join(folder, 'Data_Entry_2017.csv')) # 加载数据标签文件
- patient_ids = df['Patient ID'].unique() # 获取所有患者ID
- # 先划分出测试集,然后在剩余数据中划分出验证集和训练集
- train_val_ids, test_ids = train_test_split(patient_ids,
- test_size=test_rt,
- random_state=0, shuffle=True)
- train_ids, val_ids = train_test_split(train_val_ids,
- test_size=val_rt / (
- train_rt + val_rt),
- random_state=0, shuffle=True)
- # 根据模式选择目标ID
- target_ids = {'train': train_ids, 'valid': val_ids, 'test': test_ids}[
- mode]
- df = df[df['Patient ID'].isin(target_ids)] # 根据ID过滤数据
- num_patients = len(target_ids) # 统计患者数
- return df, num_patients
- # 将疾病标签转换为多标签编码
- def encode_label(self, label):
- encoded_label = np.zeros(self.num_classes,
- dtype=np.float32) # 初始化全0的标签数组
- if label != 'No Finding': # 如果标签不为"No Finding"
- for l in label.split('|'): # 对每个疾病标签进行处理
- encoded_label[self.classes.index(l)] = 1 # 将对应疾病的索引位置置1
- return encoded_label
- # 图像预处理函数,调整图像尺寸
- def pre_process(self, img):
- h, w = img.shape
- img = cv2.resize(img, dsize=(max(h, w), max(h, w))) # 将图像调整为正方形
- return img
- # 统计每个类别的样本数量
- def count_class_dist(self):
- class_counts = np.zeros(self.num_classes) # 初始化类别计数
- for index, row in self.dataframe.iterrows(): # 遍历数据集的每一行
- class_counts += self.encode_label(
- row['Finding Labels']) # 将标签编码加到计数器中
- return self.num_samples, class_counts
- # 返回数据集的样本数量
- def __len__(self):
- return self.num_samples
- # 获取指定索引的数据
- def __getitem__(self, idx):
- row = self.dataframe.iloc[idx] # 获取对应行的数据
- img_file, label = row['Image Index'], row[
- 'Finding Labels'] # 获取图像文件名和标签
- img = cv2.imread(
- os.path.join(self.data_root, 'images', img_file)) # 读取图像
- img = Image.fromarray(img) # 转换为PIL图像
- if self.transform is not None:
- img = self.transform(img) # 应用预处理
- label = self.encode_label(label) # 编码标签
- return img, label # 返回图像和标签
|