# ChexnetTrainer.py — CheXNet-style training, validation, and random-forest feature pipeline.
import os

import joblib  # used to persist the trained random-forest model
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, auc, f1_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from DatasetGenerator import DatasetGenerator
from DensenetModels import DenseNet121, DenseNet169, DenseNet201, ResNet50, FocalLoss
  18. class ChexnetTrainer:
  19. # 在 ChexnetTrainer 类中添加以下方法
  20. @staticmethod
  21. def train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture,
  22. nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize,
  23. transCrop, launchTimestamp, checkpoint=None):
  24. """训练函数"""
  25. # 自动选择设备
  26. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  27. # 获取模型并迁移到设备
  28. model = ChexnetTrainer._get_model(nnArchitecture, nnClassCount, nnIsTrained).to(device)
  29. # 设置数据预处理和数据加载器
  30. transformSequence = ChexnetTrainer._get_transform(transCrop)
  31. datasetTrain = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTrain,
  32. transform=transformSequence, model=model)
  33. datasetVal = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileVal,
  34. transform=transformSequence, model=model)
  35. dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True, num_workers=4,
  36. pin_memory=True)
  37. dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False, num_workers=4,
  38. pin_memory=True)
  39. optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
  40. # 替换 ReduceLROnPlateau 为 StepLR
  41. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
  42. # 迁移class_weights到对应设备
  43. class_weights = torch.tensor(
  44. [1.1762, 0.6735, 0.9410, 1.6680, 0.9699, 1.1950, 2.2584, 0.6859, 1.6683, 0.7744, 0.4625, 0.7385, 1.3764,
  45. 0.1758], dtype=torch.float32).to(device)
  46. loss = FocalLoss(alpha=class_weights, gamma=2, logits=True)
  47. lossMIN = 100000
  48. train_f1_scores, val_f1_scores = [], []
  49. all_val_targets = []
  50. all_val_scores = []
  51. # 训练循环
  52. for epochID in range(trMaxEpoch):
  53. ChexnetTrainer.epochTrain(model, dataLoaderTrain, optimizer, scheduler, trMaxEpoch, nnClassCount, loss,
  54. train_f1_scores, device)
  55. lossVal, losstensor, val_f1, val_targets, val_scores = ChexnetTrainer.epochVal(
  56. model, dataLoaderVal, optimizer, scheduler,
  57. trMaxEpoch, nnClassCount, loss, val_f1_scores, device)
  58. all_val_targets.append(val_targets)
  59. all_val_scores.append(val_scores)
  60. # 更新学习率
  61. scheduler.step()
  62. # 保存最佳模型
  63. if lossVal < lossMIN:
  64. lossMIN = lossVal
  65. torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
  66. 'optimizer': optimizer.state_dict()},
  67. 'm-' + launchTimestamp + '.pth.tar')
  68. # 合并所有验证集的标签和预测
  69. all_val_targets = np.vstack(all_val_targets)
  70. all_val_scores = np.vstack(all_val_scores)
  71. # 绘制F1分数图
  72. plt.figure()
  73. plt.plot(train_f1_scores, label="Train F1-Score")
  74. plt.plot(val_f1_scores, label="Val F1-Score")
  75. plt.xlabel("Epoch")
  76. plt.ylabel("F1 Score")
  77. plt.title("F1 Score per Epoch")
  78. plt.legend()
  79. plt.savefig(os.path.join('images', 'f1_scores.png'))
  80. plt.close()
  81. # 计算并绘制AUC-ROC曲线
  82. CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
  83. 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
  84. 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
  85. 'Fibrosis', 'Pleural_Thickening',
  86. 'Hernia']
  87. ChexnetTrainer.plot_auc_roc(all_val_targets, all_val_scores, CLASS_NAMES,
  88. os.path.join('images', 'auc_roc_curve.png'))
  89. print("AUC-ROC曲线已保存到 images/auc_roc_curve.png")
  90. # 提取特征并训练随机森林
  91. rf_classifier = ChexnetTrainer.train_random_forest(datasetTrain, datasetVal, model, device)
  92. # 保存随机森林模型
  93. joblib.dump(rf_classifier, os.path.join('images', 'random_forest_model.pkl'))
  94. # 保存随机森林分类器
  95. @staticmethod
  96. def epochTrain(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, f1_scores, device):
  97. model.train()
  98. all_targets = []
  99. all_preds = []
  100. # 创建进度条
  101. pbar = tqdm(total=len(dataLoader), desc="Training", leave=True)
  102. for batchID, (input, target) in enumerate(dataLoader):
  103. target = target.to(device)
  104. input = input.to(device)
  105. varInput = torch.autograd.Variable(input)
  106. varTarget = torch.autograd.Variable(target)
  107. varOutput = model(varInput)
  108. lossvalue = loss(varOutput, varTarget)
  109. optimizer.zero_grad()
  110. lossvalue.backward()
  111. optimizer.step()
  112. pred = torch.sigmoid(varOutput).cpu().data.numpy() > 0.5
  113. all_targets.extend(target.cpu().numpy())
  114. all_preds.extend(pred)
  115. # 更新进度条
  116. pbar.update(1)
  117. pbar.close()
  118. f1 = f1_score(np.array(all_targets), np.array(all_preds), average="macro")
  119. f1_scores.append(f1)
  120. # 在每个epoch结束时打印当前的F1-score
  121. print(f"Epoch completed. F1 Score (Macro): {f1:.4f}")
  122. @staticmethod
  123. def epochVal(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, f1_scores, device):
  124. model.eval()
  125. lossVal = 0
  126. lossValNorm = 0
  127. losstensorMean = 0
  128. all_targets = []
  129. all_preds = []
  130. all_scores = [] # 收集预测概率
  131. # 创建进度条
  132. pbar = tqdm(total=len(dataLoader), desc="Validation", leave=True)
  133. for i, (input, target) in enumerate(dataLoader):
  134. target = target.to(device)
  135. input = input.to(device)
  136. with torch.no_grad():
  137. varInput = torch.autograd.Variable(input)
  138. varTarget = torch.autograd.Variable(target)
  139. varOutput = model(varInput)
  140. losstensor = loss(varOutput, varTarget)
  141. losstensorMean += losstensor
  142. lossVal += losstensor.item()
  143. lossValNorm += 1
  144. scores = torch.sigmoid(varOutput).cpu().data.numpy()
  145. all_scores.extend(scores)
  146. pred = scores > 0.5
  147. all_targets.extend(target.cpu().numpy())
  148. all_preds.extend(pred)
  149. # 更新进度条
  150. pbar.update(1)
  151. pbar.close()
  152. f1 = f1_score(np.array(all_targets), np.array(all_preds), average="macro")
  153. f1_scores.append(f1)
  154. # 在每个epoch结束时打印当前的F1-score
  155. print(f"Epoch completed. F1 Score (Macro): {f1:.4f}")
  156. return lossVal / lossValNorm, losstensorMean / lossValNorm, f1, np.array(all_targets), np.array(all_scores)
  157. @staticmethod
  158. def train_random_forest(datasetTrain, datasetVal, model, device):
  159. """训练随机森林分类器,适应多标签分类"""
  160. print("正在提取训练数据的特征。")
  161. # 提取训练数据的特征
  162. train_features, train_labels = ChexnetTrainer.extract_features(datasetTrain, model, device)
  163. val_features, val_labels = ChexnetTrainer.extract_features(datasetVal, model, device)
  164. print("正在训练随机森林分类器。")
  165. # 使用MultiOutputClassifier来处理多标签问题
  166. rf_classifier = MultiOutputClassifier(RandomForestClassifier(n_estimators=100))
  167. rf_classifier.fit(train_features, train_labels)
  168. print("正在评估随机森林分类器。")
  169. # 在验证集上评估随机森林
  170. val_preds = rf_classifier.predict(val_features)
  171. val_f1 = f1_score(val_labels, val_preds, average='macro')
  172. print(f"Random Forest F1 Score on validation set: {val_f1}")
  173. return rf_classifier
  174. @staticmethod
  175. def extract_features(dataset, model, device):
  176. """提取数据集的特征"""
  177. features = []
  178. labels = []
  179. print("正在提取数据集的特征。")
  180. model.eval()
  181. dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
  182. for input, target in tqdm(dataLoader, desc="Extracting Features"):
  183. input = input.to(device)
  184. target = target.to(device)
  185. with torch.no_grad():
  186. varInput = torch.autograd.Variable(input)
  187. varOutput = model(varInput)
  188. # 这里假设 varOutput 是一个二维张量,我们需要把它展平成一维向量
  189. features.append(varOutput.view(varOutput.size(0), -1).cpu().data.numpy()) # 展平特征
  190. labels.append(target.cpu().data.numpy()) # 假设 target 是多标签格式的
  191. # 使用 np.vstack 将特征和标签拼接成数组
  192. return np.vstack(features), np.vstack(labels) # 确保标签是二维矩阵,每一行对应一个样本的标签
  193. @staticmethod
  194. def plot_auc_roc(y_true, y_scores, class_names, save_path):
  195. """绘制AUC-ROC曲线并保存"""
  196. plt.figure(figsize=(20, 15))
  197. for i, class_name in enumerate(class_names):
  198. fpr, tpr, _ = roc_curve(y_true[:, i], y_scores[:, i])
  199. roc_auc = auc(fpr, tpr)
  200. plt.plot(fpr, tpr, lw=2, label=f'ROC curve of class {class_name} (area = {roc_auc:.2f})')
  201. plt.plot([0, 1], [0, 1], 'k--', lw=2)
  202. plt.xlim([0.0, 1.0])
  203. plt.ylim([0.0, 1.05])
  204. plt.xlabel('False Positive Rate')
  205. plt.ylabel('True Positive Rate')
  206. plt.title('Receiver Operating Characteristic (ROC) Curves')
  207. plt.legend(loc="lower right")
  208. plt.savefig(save_path)
  209. plt.close()
  210. @staticmethod
  211. def test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture, nnClassCount, nnIsTrained,
  212. trBatchSize, imgtransResize, imgtransCrop, timestampLaunch):
  213. """测试函数,支持深度学习模型和随机森林模型"""
  214. # 加载深度学习模型
  215. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  216. model = ChexnetTrainer._get_model(nnArchitecture, nnClassCount, nnIsTrained).to(device)
  217. checkpoint = torch.load(pathModel, map_location=device)
  218. model.load_state_dict(checkpoint['state_dict'])
  219. model.eval() # 确保模型处于评估模式
  220. # 加载随机森林模型
  221. rf_classifier = joblib.load(pathRfModel)
  222. # 获取 transform
  223. transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
  224. datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
  225. transform=transformSequence, model=model)
  226. # 使用 extract_features 提取特征和标签
  227. features, labels = ChexnetTrainer.extract_features(datasetTest, model, device)
  228. # 使用随机森林模型进行预测
  229. rf_preds = rf_classifier.predict(features)
  230. # 计算多标签分类的 F1 分数
  231. rf_f1 = f1_score(labels, rf_preds, average="macro")
  232. print(f"Random Forest Multi-label F1 Score on test set: {rf_f1}")
  233. # 计算AUC-ROC
  234. rf_scores = []
  235. if hasattr(rf_classifier, "predict_proba"):
  236. rf_scores = rf_classifier.predict_proba(features)
  237. # 将列表转换为数组,并选择正类的概率
  238. rf_scores = np.array([rf_scores[i][:, 1] for i in range(len(rf_scores))]).reshape(labels.shape)
  239. else:
  240. # 如果没有predict_proba方法,则无法计算AUC
  241. print("随机森林模型不支持predict_proba方法,无法计算AUC-ROC。")
  242. rf_scores = np.zeros_like(labels)
  243. # 绘制AUC-ROC曲线
  244. CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
  245. 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
  246. 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
  247. 'Fibrosis', 'Pleural_Thickening',
  248. 'Hernia']
  249. ChexnetTrainer.plot_auc_roc(labels, rf_scores, CLASS_NAMES,
  250. os.path.join('images', 'test_auc_roc_curve.png'))
  251. print("测试集的AUC-ROC曲线已保存到 images/test_auc_roc_curve.png")
  252. # 输出深度学习模型和随机森林模型的 F1 分数
  253. print(f"Testing completed. Random Forest F1 Score: {rf_f1}")
  254. return labels, rf_preds # 返回真实标签和预测结果
  255. # 其他方法保持不变...
  256. @staticmethod
  257. def _get_model(nnArchitecture, nnClassCount, nnIsTrained):
  258. """根据选择的模型架构返回对应的模型"""
  259. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  260. # 根据模型架构选择模型
  261. if nnArchitecture == 'DENSE-NET-121':
  262. model = DenseNet121(nnClassCount, nnIsTrained).to(device)
  263. elif nnArchitecture == 'DENSE-NET-169':
  264. model = DenseNet169(nnClassCount, nnIsTrained).to(device)
  265. elif nnArchitecture == 'DENSE-NET-201':
  266. model = DenseNet201(nnClassCount, nnIsTrained).to(device)
  267. elif nnArchitecture == 'RESNET-50':
  268. model = ResNet50(nnClassCount, nnIsTrained).to(device)
  269. else:
  270. raise ValueError(
  271. f"Unknown architecture: {nnArchitecture}. Please choose from 'DENSE-NET-121', 'DENSE-NET-169', 'DENSE-NET-201', 'RESNET-50'.")
  272. return model
  273. @staticmethod
  274. def _get_transform(transCrop):
  275. """返回图像预处理的转换"""
  276. normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  277. transformList = [
  278. transforms.RandomResizedCrop(transCrop),
  279. transforms.RandomHorizontalFlip(),
  280. transforms.ToTensor(),
  281. normalize
  282. ]
  283. return transforms.Compose(transformList)