Main(2)(2).py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import os
  2. import numpy as np
  3. import time
  4. import sys
  5. from ChexnetTrainer import ChexnetTrainer # 从ChexnetTrainer模块中导入训练相关的类
  6. # --------------------------------------------------------------------------------
  7. # 主函数,负责启动不同的功能(训练、测试或运行演示)
  8. def main():
  9. #runDemo() # 运行演示模式
  10. # runTest() # 测试模式(注释掉,可以通过解除注释来运行)
  11. runTrain() # 训练模式(注释掉,可以通过解除注释来运行)
  12. # --------------------------------------------------------------------------------
  13. # 训练函数,定义训练所需的参数并启动模型训练
  14. def runTrain():
  15. DENSENET121 = 'DENSE-NET-121' # 定义DenseNet121模型名称
  16. DENSENET169 = 'DENSE-NET-169' # 定义DenseNet169模型名称
  17. DENSENET201 = 'DENSE-NET-201' # 定义DenseNet201模型名称
  18. Resnet50='RESNET-50' # 定义Resnet50模型名称
  19. # 获取当前的时间戳,作为训练过程的标记
  20. timestampTime = time.strftime("%H%M%S")
  21. timestampDate = time.strftime("%d%m%Y")
  22. timestampLaunch = timestampDate + '-' + timestampTime
  23. print("Launching " + timestampLaunch)
  24. # 图像数据所在的路径
  25. pathDirData = 'chest xray14'
  26. # 训练、验证和测试数据集文件路径
  27. # 每个文件中包含图像路径及其对应的标签
  28. pathFileTrain = './dataset/train_2.txt'
  29. pathFileVal = './dataset/valid_2.txt'
  30. pathFileTest = './dataset/test_2.txt'
  31. # 神经网络参数:模型架构、是否加载预训练模型、分类的类别数量
  32. nnArchitecture = DENSENET121
  33. nnIsTrained = True # 使用预训练的权重
  34. nnClassCount = 14 # 数据集包含14个分类
  35. # 训练参数:批量大小和最大迭代次数(epochs)
  36. trBatchSize = 16
  37. trMaxEpoch = 48
  38. # 图像预处理相关参数:图像缩放的大小和裁剪后的大小
  39. imgtransResize = 256
  40. imgtransCrop = 224
  41. # 保存模型的路径,包含时间戳
  42. pathModel = 'm-' + timestampLaunch + '.pth.tar'
  43. print('Training NN architecture = ', nnArchitecture)
  44. ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal,
  45. nnArchitecture, nnIsTrained, nnClassCount, trBatchSize,
  46. trMaxEpoch, imgtransResize, imgtransCrop,
  47. timestampLaunch, None)
  48. print('Testing the trained model')
  49. ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, nnArchitecture,
  50. nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
  51. imgtransCrop, timestampLaunch)
  52. # --------------------------------------------------------------------------------
  53. # 测试函数,加载预训练模型并在测试数据集上进行测试
  54. def runTest():
  55. pathDirData = '/chest xray14' # 数据路径
  56. pathFileTest = './dataset/test.txt' # 测试集路径
  57. nnArchitecture = 'DENSE-NET-121' # 使用DenseNet121架构
  58. nnIsTrained = True # 使用预训练模型
  59. nnClassCount = 14 # 分类数
  60. trBatchSize = 4 # 批量大小
  61. imgtransResize = 256 # 图像缩放大小
  62. imgtransCrop = 224 # 图像裁剪大小
  63. # 预训练模型路径
  64. pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'
  65. timestampLaunch = '' # 时间戳
  66. # 调用测试函数,使用上述参数
  67. ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, nnArchitecture,
  68. nnClassCount, nnIsTrained,
  69. trBatchSize, imgtransResize, imgtransCrop,
  70. timestampLaunch)
  71. # --------------------------------------------------------------------------------
  72. # 演示函数,展示模型在测试集上的推理过程
  73. def runDemo():
  74. pathDirData = '/media/sunjc0306/未命名/CODE/chest xray14' # 数据路径
  75. pathFileTest = './dataset/test.txt' # 测试集路径
  76. nnArchitecture = 'DENSE-NET-121' # 使用DenseNet121架构
  77. nnIsTrained = True # 使用预训练模型
  78. nnClassCount = 14 # 分类数
  79. trBatchSize = 4 # 批量大小
  80. imgtransResize = 256 # 图像缩放大小
  81. imgtransCrop = 224 # 图像裁剪大小
  82. pathModel = 'm-06102024-235412BCELoss()delete.pth.tar' # 预训练模型路径
  83. # 定义分类名称
  84. CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
  85. 'Mass', 'Nodule', 'Pneumonia',
  86. 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
  87. 'Fibrosis', 'Pleural_Thickening', 'Hernia']
  88. import torch
  89. import torch.backends.cudnn as cudnn
  90. import torchvision.transforms as transforms
  91. from torch.utils.data import DataLoader
  92. from tqdm import tqdm # 用于进度条显示
  93. from DensenetModels import DenseNet121 # 导入模型
  94. from DensenetModels import DenseNet169
  95. from DensenetModels import DenseNet201
  96. from DatasetGenerator import DatasetGenerator # 导入数据集生成器
  97. cudnn.benchmark = True # 加速模型的推理过程
  98. # 设置网络架构并加载模型
  99. if nnArchitecture == 'DENSE-NET-121':
  100. model = DenseNet121(nnClassCount, nnIsTrained).cuda()
  101. elif nnArchitecture == 'DENSE-NET-169':
  102. model = DenseNet169(nnClassCount, nnIsTrained).cuda()
  103. elif nnArchitecture == 'DENSE-NET-201':
  104. model = DenseNet201(nnClassCount, nnIsTrained).cuda()
  105. model = model.cuda()
  106. # 加载预训练的模型权重
  107. modelCheckpoint = torch.load(pathModel)
  108. model.load_state_dict(modelCheckpoint['state_dict'])
  109. # 图像变换:先缩放,再进行裁剪,最后进行归一化
  110. normalize = transforms.Normalize([0.485, 0.456, 0.406],
  111. [0.229, 0.224, 0.225])
  112. # 数据集变换和数据加载器
  113. transformList = []
  114. transformList.append(transforms.Resize(imgtransResize))
  115. transformList.append(transforms.TenCrop(imgtransCrop)) # 对图像进行十裁剪
  116. transformList.append(transforms.Lambda(lambda crops: torch.stack(
  117. [transforms.ToTensor()(crop) for crop in crops]))) # 转换为张量
  118. transformList.append(transforms.Lambda(
  119. lambda crops: torch.stack([normalize(crop) for crop in crops]))) # 归一化
  120. transformSequence = transforms.Compose(transformList)
  121. # 构建测试集数据集生成器
  122. datasetTest = DatasetGenerator(pathImageDirectory=pathDirData,
  123. pathDatasetFile=pathFileTest,
  124. transform=transformSequence)
  125. model.eval() # 设置模型为评估模式
  126. i = 0 # 初始索引
  127. # 定义演示函数,展示模型在某个样本上的推理结果
  128. def demo(i):
  129. (input, target) = datasetTest[i] # 获取输入和目标
  130. n_crops, c, h, w = input.size() # 获取输入的尺寸
  131. # 将输入放入模型进行推理
  132. varInput = torch.autograd.Variable(input.view(-1, c, h, w).cuda(),
  133. volatile=True)
  134. with torch.no_grad():
  135. out = model(varInput)
  136. outMean = out.view(1, n_crops, -1).mean(1) # 对裁剪后的多个输出取平均
  137. print('-------------------------------')
  138. pd = torch.sigmoid(outMean).ge_(0.5).long().detach().cpu().numpy()[
  139. 0].tolist() # 预测结果
  140. gt = target.long().detach().cpu().numpy().tolist() # 真实标签
  141. print(f"PD_{i}", pd) # 输出预测结果
  142. print(f"GT_{i}", gt) # 输出真实标签
  143. return pd, gt
  144. # 逐个演示测试集中的样本
  145. for i in range(len(datasetTest)):
  146. pd, gt = demo(i) # 调用演示函数
  147. if type(gt) == list:
  148. pd_name = []
  149. gt_name = []
  150. for i in range(len(CLASS_NAMES)):
  151. if pd[i] == 1:
  152. pd_name.append(CLASS_NAMES[i]) # 将预测为1的标签存入pd_name
  153. if gt[i] == 1:
  154. gt_name.append(CLASS_NAMES[i]) # 将真实为1的标签存入gt_name
  155. print(datasetTest.listImagePaths[i]) # 打印图像路径
  156. print(pd_name) # 打印预测的疾病名称
  157. print(gt_name) # 打印真实的疾病名称
  158. if __name__ == '__main__':
  159. main() # 启动主函数