import os
import numpy as np
import time
import sys
from ChexnetTrainer import ChexnetTrainer  # training / testing routines

# --------------------------------------------------------------------------------
# Entry point: choose which mode to run (training, testing, or demo)
def main():
    # runDemo()   # demo mode (uncomment to run)
    # runTest()   # test mode (uncomment to run)
    runTrain()    # training mode
# --------------------------------------------------------------------------------
# Training: define the hyper-parameters and launch model training
def runTrain():
    DENSENET121 = 'DENSE-NET-121'   # DenseNet121 architecture name
    DENSENET169 = 'DENSE-NET-169'   # DenseNet169 architecture name
    DENSENET201 = 'DENSE-NET-201'   # DenseNet201 architecture name
    RESNET50 = 'RESNET-50'          # ResNet50 architecture name

    # Timestamp used to tag this training run
    timestampTime = time.strftime("%H%M%S")
    timestampDate = time.strftime("%d%m%Y")
    timestampLaunch = timestampDate + '-' + timestampTime
    print("Launching " + timestampLaunch)

    # Root directory of the image data
    pathDirData = 'chest xray14'

    # List files for the training, validation and test splits;
    # each line holds an image path and its labels
    pathFileTrain = './dataset/train_2.txt'
    pathFileVal = './dataset/valid_2.txt'
    pathFileTest = './dataset/test_2.txt'
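    # Assumed list-file format (based on the common ChestX-ray14 split files;
    # verify against DatasetGenerator): one image per line, a relative path
    # followed by 14 space-separated 0/1 labels, e.g.
    #   images/00000001_000.png 0 1 0 0 0 0 0 0 0 0 0 0 0 0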
    # Network parameters: architecture, whether to load ImageNet-pretrained
    # weights, and the number of output classes
    nnArchitecture = DENSENET121
    nnIsTrained = True   # use pretrained weights
    nnClassCount = 14    # ChestX-ray14 has 14 disease classes

    # Training parameters: batch size and maximum number of epochs
    trBatchSize = 16
    trMaxEpoch = 48

    # Image preprocessing parameters: resize size and crop size
    imgtransResize = 256
    imgtransCrop = 224

    # Path under which the trained model is saved, tagged with the timestamp
    pathModel = 'm-' + timestampLaunch + '.pth.tar'

    print('Training NN architecture = ', nnArchitecture)
    ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal,
                         nnArchitecture, nnIsTrained, nnClassCount, trBatchSize,
                         trMaxEpoch, imgtransResize, imgtransCrop,
                         timestampLaunch, None)
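    # Note: the trailing None is assumed to be ChexnetTrainer.train's
    # resume-from-checkpoint argument; passing an existing .pth.tar path there
    # should resume training from that checkpoint instead of starting from
    # scratch (verify against ChexnetTrainer.train's actual signature).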
    print('Testing the trained model')
    ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, nnArchitecture,
                        nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
                        imgtransCrop, timestampLaunch)
# --------------------------------------------------------------------------------
# Testing: load a saved model and evaluate it on the test set
def runTest():
    pathDirData = '/chest xray14'        # data root
    pathFileTest = './dataset/test.txt'  # test list file
    nnArchitecture = 'DENSE-NET-121'     # DenseNet121 architecture
    nnIsTrained = True                   # use pretrained weights
    nnClassCount = 14                    # number of classes
    trBatchSize = 4                      # batch size
    imgtransResize = 256                 # resize size
    imgtransCrop = 224                   # crop size

    # Path of the trained model checkpoint
    pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'
    timestampLaunch = ''  # timestamp (unused here)

    # Run evaluation with the parameters above
    ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, nnArchitecture,
                        nnClassCount, nnIsTrained,
                        trBatchSize, imgtransResize, imgtransCrop,
                        timestampLaunch)
# --------------------------------------------------------------------------------
# Demo: run inference sample by sample on the test set and print the results
def runDemo():
    pathDirData = '/media/sunjc0306/未命名/CODE/chest xray14'  # data root
    pathFileTest = './dataset/test.txt'  # test list file
    nnArchitecture = 'DENSE-NET-121'     # DenseNet121 architecture
    nnIsTrained = True                   # use pretrained weights
    nnClassCount = 14                    # number of classes
    trBatchSize = 4                      # batch size
    imgtransResize = 256                 # resize size
    imgtransCrop = 224                   # crop size
    pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'  # trained model checkpoint

    # Names of the 14 ChestX-ray14 findings
    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                   'Mass', 'Nodule', 'Pneumonia',
                   'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
                   'Fibrosis', 'Pleural_Thickening', 'Hernia']

    import torch
    import torch.backends.cudnn as cudnn
    import torchvision.transforms as transforms
    from DensenetModels import DenseNet121  # model definitions
    from DensenetModels import DenseNet169
    from DensenetModels import DenseNet201
    from DatasetGenerator import DatasetGenerator  # dataset generator

    cudnn.benchmark = True  # let cuDNN pick the fastest convolution kernels

    # Build the network and move it to the GPU
    if nnArchitecture == 'DENSE-NET-121':
        model = DenseNet121(nnClassCount, nnIsTrained).cuda()
    elif nnArchitecture == 'DENSE-NET-169':
        model = DenseNet169(nnClassCount, nnIsTrained).cuda()
    elif nnArchitecture == 'DENSE-NET-201':
        model = DenseNet201(nnClassCount, nnIsTrained).cuda()

    # Load the trained weights
    modelCheckpoint = torch.load(pathModel)
    model.load_state_dict(modelCheckpoint['state_dict'])
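    # If the checkpoint was saved from a model wrapped in nn.DataParallel, its
    # keys carry a 'module.' prefix and the strict load above will fail. A
    # possible workaround in that case (an assumption, not part of the
    # original script):
    # stateDict = {k.replace('module.', '', 1): v
    #              for k, v in modelCheckpoint['state_dict'].items()}
    # model.load_state_dict(stateDict)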
    # Image transforms: resize, ten-crop, convert each crop to a tensor,
    # then normalize with the ImageNet statistics
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    transformList = []
    transformList.append(transforms.Resize(imgtransResize))
    transformList.append(transforms.TenCrop(imgtransCrop))  # ten crops per image
    transformList.append(transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])))  # stack crops into one tensor
    transformList.append(transforms.Lambda(
        lambda crops: torch.stack([normalize(crop) for crop in crops])))  # normalize each crop
    transformSequence = transforms.Compose(transformList)
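    # Sanity-check sketch (assumes PIL is available; the image path is
    # hypothetical): the pipeline above maps one image to a
    # (10, 3, 224, 224) tensor, which is why demo() below flattens the crop
    # dimension with view(-1, c, h, w) and averages the ten outputs.
    # from PIL import Image
    # img = Image.open('images/00000001_000.png').convert('RGB')
    # crops = transformSequence(img)
    # print(crops.shape)  # torch.Size([10, 3, 224, 224])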
    # Test-set dataset generator
    datasetTest = DatasetGenerator(pathImageDirectory=pathDirData,
                                   pathDatasetFile=pathFileTest,
                                   transform=transformSequence)

    model.eval()  # evaluation mode

    # Run inference on one sample and print predicted vs. ground-truth labels
    def demo(i):
        (input, target) = datasetTest[i]  # ten-crop tensor and label vector
        n_crops, c, h, w = input.size()
        varInput = input.view(-1, c, h, w).cuda()  # fold the crops into the batch dimension
        with torch.no_grad():
            out = model(varInput)
            outMean = out.view(1, n_crops, -1).mean(1)  # average the ten crop predictions
        print('-------------------------------')
        pd = torch.sigmoid(outMean).ge_(0.5).long().detach().cpu().numpy()[0].tolist()  # thresholded predictions
        gt = target.long().detach().cpu().numpy().tolist()  # ground-truth labels
        print(f"PD_{i}", pd)
        print(f"GT_{i}", gt)
        return pd, gt

    # Walk through the test set sample by sample
    for i in range(len(datasetTest)):
        pd, gt = demo(i)
        if type(gt) == list:
            pd_name = []
            gt_name = []
            # use a separate index so the outer loop variable i is not clobbered
            for j in range(len(CLASS_NAMES)):
                if pd[j] == 1:
                    pd_name.append(CLASS_NAMES[j])  # findings predicted positive
                if gt[j] == 1:
                    gt_name.append(CLASS_NAMES[j])  # findings labelled positive
            print(datasetTest.listImagePaths[i])  # image path
            print(pd_name)  # predicted finding names
            print(gt_name)  # ground-truth finding names
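    # A possible extension (a sketch, not in the original script): collect pd
    # and gt across the loop above and report a simple per-class accuracy.
    # preds.append(pd) and targets.append(gt) inside the loop, then:
    # acc = (np.array(preds) == np.array(targets)).mean(axis=0)
    # for name, a in zip(CLASS_NAMES, acc):
    #     print(f"{name}: {a:.3f}")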
if __name__ == '__main__':
    main()  # entry point