# ChexnetTrainer.py
  1. import os
  2. import numpy as np
  3. import time
  4. import sys
  5. import matplotlib.pyplot as plt
  6. import torch
  7. import torch.nn as nn
  8. import torch.backends.cudnn as cudnn
  9. import torchvision
  10. import torchvision.transforms as transforms
  11. import torch.optim as optim
  12. import torch.nn.functional as tfunc
  13. import tqdm
  14. from torch.utils.data import DataLoader
  15. from torch.optim.lr_scheduler import ReduceLROnPlateau
  16. import torch.nn.functional as func
  17. from tqdm import tqdm
  18. from sklearn.metrics import roc_auc_score,roc_curve,auc,f1_score
  19. from DensenetModels import DenseNet121, DenseNet169, \
  20. DenseNet201,ResNet50,FocalLoss # 引入不同版本的DenseNet
  21. from DatasetGenerator import DatasetGenerator # 引入自定义的数据集生成器
  22. # 定义一个用于训练、验证、测试DenseNet模型的类
  23. class ChexnetTrainer():
  24. # 训练网络的主函数
  25. def train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture,
  26. nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize,
  27. transCrop, launchTimestamp, checkpoint):
  28. # 根据选择的模型架构来初始化不同的DenseNet模型
  29. if nnArchitecture == 'DENSE-NET-121':
  30. model = DenseNet121(nnClassCount,
  31. nnIsTrained).cuda() # 初始化DenseNet121
  32. elif nnArchitecture == 'DENSE-NET-169':
  33. model = DenseNet169(nnClassCount,
  34. nnIsTrained).cuda() # 初始化DenseNet169
  35. elif nnArchitecture == 'DENSE-NET-201':
  36. model = DenseNet201(nnClassCount,
  37. nnIsTrained).cuda() # 初始化DenseNet201
  38. elif nnArchitecture == 'RESNET-50':
  39. model = ResNet50(nnClassCount, nnIsTrained).cuda()
  40. else:
  41. raise ValueError(
  42. f"Unknown architecture: {nnArchitecture}. Please choose from 'DENSE-NET-121', 'DENSE-NET-169', 'DENSE-NET-201', 'RESNET-50'.")
  43. model = model.cuda() # 将模型加载到GPU上
  44. # 数据预处理,包含随机裁剪、水平翻转、归一化等
  45. normalize = transforms.Normalize([0.485, 0.456, 0.406],
  46. [0.229, 0.224, 0.225])
  47. transformList = [
  48. transforms.RandomResizedCrop(transCrop), # 随机裁剪到指定大小
  49. transforms.RandomHorizontalFlip(), # 随机水平翻转
  50. transforms.ToTensor(), # 转换为张量
  51. normalize # 归一化
  52. ]
  53. transformSequence = transforms.Compose(transformList) # 将这些变换组成序列
  54. # 创建训练和验证集的数据加载器
  55. datasetTrain = DatasetGenerator(pathImageDirectory=pathDirData,
  56. pathDatasetFile=pathFileTrain,
  57. transform=transformSequence)
  58. #pos_weight = datasetTrain.calculate_pos_weights() # 计算 pos_weight
  59. datasetVal = DatasetGenerator(pathImageDirectory=pathDirData,
  60. pathDatasetFile=pathFileVal,
  61. transform=transformSequence)
  62. dataLoaderTrain = DataLoader(dataset=datasetTrain,
  63. batch_size=trBatchSize, shuffle=True,
  64. num_workers=8, pin_memory=True) # 训练集
  65. dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize,
  66. shuffle=False, num_workers=8,
  67. pin_memory=True) # 验证集
  68. # 设置优化器和学习率调度器
  69. optimizer = optim.Adam(model.parameters(), lr=0.0001,
  70. betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
  71. scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5,
  72. mode='min') # 当损失不再下降时,减少学习率
  73. # 使用多标签分类的损失函数
  74. # 设置每个类别的权重,基于提供的AUC-ROC曲线
  75. # class_weights = torch.tensor(
  76. # [1.39, 0.73, 1.33, 1.62, 1.32, 1.41, 1.54, 1.27, 1.59, 1.25, 1.28,
  77. # 1.27, 1.48, 1.18], dtype=torch.float).cuda()
  78. # 使用加权的 BCEWithLogitsLoss 作为损失函数
  79. #loss = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights)
  80. #loss = torch.nn.MultiLabelSoftMarginLoss()
  81. # 使用Focal Loss作为损失函数
  82. # loss = FocalLoss(alpha=1, gamma=2, logits=True)
  83. # launchTimestamp += str(loss)
  84. # launchTimestamp += 'delete'
  85. class_weights = torch.tensor([
  86. 1.1762, # Atelectasis
  87. 0.6735, # Cardiomegaly
  88. 0.9410, # Effusion
  89. 1.6680, # Infiltration
  90. 0.9699, # Mass
  91. 1.1950, # Nodule
  92. 1.7584, # Pneumonia
  93. 0.6859, # Pneumothorax
  94. 1.6683, # Consolidation
  95. 0.7744, # Edema
  96. 0.4625, # Emphysema
  97. 0.7385, # Fibrosis
  98. 1.3764, # Pleural_Thickening
  99. 0.1758 # Hernia
  100. ], dtype=torch.float32).cuda()
  101. loss = FocalLoss(alpha=class_weights, gamma=2, logits=True)
  102. # 加载检查点文件(如果存在),继续训练
  103. #checkpoint = 'm-29102024-093913MultiLabelSoftMarginLoss()delete.pth.tar' # 测试时写死的文件名
  104. # if checkpoint != None:
  105. # modelCheckpoint = torch.load(checkpoint)
  106. # model.load_state_dict(modelCheckpoint['state_dict'])
  107. # optimizer.load_state_dict(modelCheckpoint['optimizer'])
  108. lossMIN = 100000 # 记录最小损失值
  109. train_f1_scores, val_f1_scores = [], []
  110. # 训练循环,遍历指定的epoch数
  111. for epochID in range(trMaxEpoch):
  112. # 获取当前时间戳,记录每个epoch的开始时间
  113. timestampTime = time.strftime("%H%M%S")
  114. timestampDate = time.strftime("%d%m%Y")
  115. timestampSTART = timestampDate + '-' + timestampTime
  116. # 训练一个epoch
  117. ChexnetTrainer.epochTrain(model, dataLoaderTrain, optimizer,
  118. scheduler, trMaxEpoch, nnClassCount, loss,train_f1_scores)
  119. # 验证一个epoch
  120. lossVal, losstensor,val_f1 = ChexnetTrainer.epochVal(model, dataLoaderVal,optimizer, scheduler,
  121. trMaxEpoch,nnClassCount, loss)
  122. val_f1_scores.append(val_f1)
  123. # 获取每个epoch结束时的时间戳
  124. timestampTime = time.strftime("%H%M%S")
  125. timestampDate = time.strftime("%d%m%Y")
  126. timestampEND = timestampDate + '-' + timestampTime
  127. # 使用调度器调整学习率
  128. scheduler.step(losstensor.item())
  129. # 保存当前最佳模型
  130. if lossVal < lossMIN:
  131. lossMIN = lossVal
  132. torch.save(
  133. {'epoch': epochID + 1, 'state_dict': model.state_dict(),
  134. 'best_loss': lossMIN, 'optimizer': optimizer.state_dict()},
  135. 'm-' + launchTimestamp + '.pth.tar')
  136. print('Epoch [' + str(
  137. epochID + 1) + '] [save] [' + timestampEND + '] loss= ' + str(
  138. lossVal))
  139. else:
  140. print('Epoch [' + str(
  141. epochID + 1) + '] [----] [' + timestampEND + '] loss= ' + str(
  142. lossVal))
  143. plt.plot(train_f1_scores, label="Train F1-Score")
  144. plt.plot(val_f1_scores, label="Val F1-Score")
  145. plt.xlabel("Epoch")
  146. plt.ylabel("F1 Score")
  147. plt.title("F1 Score per Epoch")
  148. plt.legend()
  149. #plt.savefig("f1score.png")
  150. #plt.show()
  151. # 训练每个epoch的具体过程
  152. def epochTrain(model, dataLoader, optimizer, scheduler, epochMax,
  153. classCount, loss,f1_scores):
  154. model.train()
  155. all_targets = []
  156. all_preds = []
  157. for batchID, (input, target) in enumerate(tqdm(dataLoader)):
  158. target = target.cuda()
  159. input = input.cuda()
  160. varInput = torch.autograd.Variable(input)
  161. varTarget = torch.autograd.Variable(target)
  162. varOutput = model(varInput)
  163. lossvalue = loss(varOutput, varTarget)
  164. optimizer.zero_grad()
  165. lossvalue.backward()
  166. optimizer.step()
  167. pred = torch.sigmoid(varOutput).cpu().data.numpy() > 0.5
  168. all_targets.extend(target.cpu().numpy())
  169. all_preds.extend(pred)
  170. f1 = f1_score(np.array(all_targets), np.array(all_preds), average="macro")
  171. f1_scores.append(f1)
  172. # 验证每个epoch的具体过程
  173. def epochVal(model, dataLoader, optimizer, scheduler, epochMax, classCount,
  174. loss):
  175. model.eval() # 设置模型为评估模式
  176. lossVal = 0
  177. lossValNorm = 0
  178. losstensorMean = 0
  179. all_targets = []
  180. all_preds = []
  181. for i, (input, target) in enumerate(dataLoader): # 遍历每个批次
  182. target = target.cuda()
  183. input = input.cuda()
  184. with torch.no_grad(): # 禁用梯度计算,节省内存
  185. varInput = torch.autograd.Variable(input)
  186. varTarget = torch.autograd.Variable(target)
  187. varOutput = model(varInput) # 前向传播
  188. losstensor = loss(varOutput, varTarget) # 计算损失
  189. losstensorMean += losstensor # 累积损失
  190. lossVal += losstensor.item() # 累积损失值
  191. lossValNorm += 1 # 记录批次数量
  192. pred = torch.sigmoid(varOutput).cpu().data.numpy() > 0.5
  193. all_targets.extend(target.cpu().numpy())
  194. all_preds.extend(pred)
  195. f1 = f1_score(np.array(all_targets), np.array(all_preds),
  196. average="macro")
  197. outLoss = lossVal / lossValNorm # 计算平均损失
  198. losstensorMean = losstensorMean / lossValNorm # 平均损失张量
  199. return outLoss, losstensorMean,f1 # 返回验证损失和损失张量均值
  200. # 计算AUROC(AUC-ROC曲线下的面积)
  201. def computeAUROC(dataGT, dataPRED, classCount, plot_roc_curve=False,
  202. class_names=None):
  203. outAUROC = []
  204. datanpGT = dataGT.cpu().numpy() # 将ground truth转换为numpy格式
  205. datanpPRED = dataPRED.cpu().numpy() # 将预测结果转换为numpy格式
  206. if plot_roc_curve and class_names is None:
  207. class_names = [f"Class {i + 1}" for i in range(classCount)]
  208. # 针对每个类别计算ROC AUC分数
  209. plt.figure(figsize=(12, 8))
  210. for i in range(classCount):
  211. # 计算当前类别的ROC AUC分数
  212. outAUROC.append(roc_auc_score(datanpGT[:, i], datanpPRED[:, i]))
  213. if plot_roc_curve:
  214. # 计算 ROC 曲线的点
  215. fpr, tpr, _ = roc_curve(datanpGT[:, i], datanpPRED[:, i])
  216. roc_auc = auc(fpr, tpr)
  217. plt.plot(fpr, tpr, lw=2,
  218. label=f'{class_names[i]} (area = {roc_auc:.2f})')
  219. if plot_roc_curve:
  220. plt.plot([0, 1], [0, 1], color='navy', linestyle='--') # 绘制随机猜测线
  221. plt.xlim([0.0, 1.0])
  222. plt.ylim([0.0, 1.05])
  223. plt.xlabel('False Positive Rate')
  224. plt.ylabel('True Positive Rate')
  225. plt.title('Receiver Operating Characteristic (ROC) Curves')
  226. plt.legend(loc="lower right")
  227. plt.savefig("aucroc.png")
  228. #plt.show()
  229. return outAUROC # 返回每个类别的AUROC值
  230. # 测试模型
  231. def test(pathDirData, pathFileTest, pathModel, nnArchitecture, nnClassCount,
  232. nnIsTrained, trBatchSize, transResize, transCrop, launchTimeStamp):
  233. CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
  234. 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
  235. 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
  236. 'Fibrosis', 'Pleural_Thickening', 'Hernia']
  237. cudnn.benchmark = True # 加速卷积操作
  238. # 根据架构选择相应的DenseNet模型
  239. if nnArchitecture == 'DENSE-NET-121':
  240. model = DenseNet121(nnClassCount, nnIsTrained).cuda()
  241. elif nnArchitecture == 'DENSE-NET-169':
  242. model = DenseNet169(nnClassCount, nnIsTrained).cuda()
  243. elif nnArchitecture == 'DENSE-NET-201':
  244. model = DenseNet201(nnClassCount, nnIsTrained).cuda()
  245. elif nnArchitecture == 'RESNET-50':
  246. model = ResNet50(nnClassCount, nnIsTrained).cuda()
  247. modelCheckpoint = torch.load(pathModel) # 加载模型
  248. model.load_state_dict(modelCheckpoint['state_dict']) # 载入训练好的参数
  249. model = model.cuda() # 将模型加载到GPU上
  250. model.eval() # 设置为评估模式
  251. # 定义数据预处理
  252. normalize = transforms.Normalize([0.485, 0.456, 0.406],
  253. [0.229, 0.224, 0.225])
  254. transformList = [
  255. transforms.Resize(transResize), # 调整大小
  256. transforms.CenterCrop(transCrop), # 中心裁剪
  257. transforms.ToTensor(), # 转换为张量
  258. normalize # 归一化
  259. ]
  260. transformSequence = transforms.Compose(transformList)
  261. # 创建测试集的数据加载器
  262. datasetTest = DatasetGenerator(pathImageDirectory=pathDirData,
  263. pathDatasetFile=pathFileTest,
  264. transform=transformSequence)
  265. dataLoaderTest = DataLoader(dataset=datasetTest, batch_size=trBatchSize,
  266. shuffle=False, num_workers=8,
  267. pin_memory=True)
  268. # 初始化张量来存储ground truth和预测结果
  269. outGT = torch.FloatTensor().cuda()
  270. outPRED = torch.FloatTensor().cuda()
  271. # 遍历测试集
  272. for i, (input, target) in enumerate(dataLoaderTest):
  273. target = target.cuda()
  274. input = input.cuda()
  275. with torch.no_grad():
  276. varInput = torch.autograd.Variable(input)
  277. out = model(varInput) # 前向传播
  278. outPRED = torch.cat((outPRED, out), 0) # 将输出结果连接起来
  279. outGT = torch.cat((outGT, target), 0) # 将ground truth连接起来
  280. # 计算AUROC值
  281. aurocIndividual = ChexnetTrainer.computeAUROC(outGT, outPRED,
  282. nnClassCount,
  283. plot_roc_curve=True,
  284. class_names=CLASS_NAMES)
  285. aurocMean = np.array(aurocIndividual).mean() # 计算平均AUROC值
  286. # 输出每个类别的AUROC
  287. for i in range(len(aurocIndividual)):
  288. print(f'{CLASS_NAMES[i]}: {aurocIndividual[i]}')
  289. print(f'MEAN: {aurocMean}')