|  | @@ -0,0 +1,224 @@
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import os
 | 
	
		
			
				|  |  | +import time
 | 
	
		
			
				|  |  | +import torch.multiprocessing as mp
 | 
	
		
			
				|  |  | +import matplotlib.pyplot as plt
 | 
	
		
			
				|  |  | +import numpy as np
 | 
	
		
			
				|  |  | +from ChexnetTrainer import ChexnetTrainer
 | 
	
		
			
				|  |  | +from DatasetGenerator import DatasetGenerator
 | 
	
		
			
				|  |  | +from visual import HeatmapGenerator
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# 确保 images 目录存在
 | 
	
		
			
				|  |  | +if not os.path.exists('images'):
 | 
	
		
			
				|  |  | +    os.makedirs('images')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# 主函数,负责启动不同的功能(训练、测试或运行演示)
 | 
	
		
			
				|  |  | +def main():
 | 
	
		
			
				|  |  | +    # runDemo()  # 运行演示模式
 | 
	
		
			
				|  |  | +    # runTest()  # 测试模式(注释掉,可以通过解除注释来运行)
 | 
	
		
			
				|  |  | +    runTrain()  # 训练模式(注释掉,可以通过解除注释来运行)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# --------------------------------------------------------------------------------
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# 训练函数,定义训练所需的参数并启动模型训练
 | 
	
		
			
				|  |  | +def runTrain():
 | 
	
		
			
				|  |  | +    DENSENET121 = 'DENSE-NET-121'  # 定义DenseNet121模型名称
 | 
	
		
			
				|  |  | +    DENSENET169 = 'DENSE-NET-169'  # 定义DenseNet169模型名称
 | 
	
		
			
				|  |  | +    DENSENET201 = 'DENSE-NET-201'  # 定义DenseNet201模型名称
 | 
	
		
			
				|  |  | +    Resnet50 = 'RESNET-50'  # 定义Resnet50模型名称
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 获取当前的时间戳,作为训练过程的标记
 | 
	
		
			
				|  |  | +    timestampTime = time.strftime("%H%M%S")
 | 
	
		
			
				|  |  | +    timestampDate = time.strftime("%d%m%Y")
 | 
	
		
			
				|  |  | +    timestampLaunch = timestampDate + '-' + timestampTime
 | 
	
		
			
				|  |  | +    print("Launching " + timestampLaunch)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 图像数据所在的路径
 | 
	
		
			
				|  |  | +    pathDirData = './chest xray14'
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 训练、验证和测试数据集文件路径
 | 
	
		
			
				|  |  | +    # 每个文件中包含图像路径及其对应的标签
 | 
	
		
			
				|  |  | +    pathFileTrain = './dataset/train_2.txt'
 | 
	
		
			
				|  |  | +    pathFileVal = './dataset/valid_2.txt'
 | 
	
		
			
				|  |  | +    pathFileTest = './dataset/test_2.txt'
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 神经网络参数:模型架构、是否加载预训练模型、分类的类别数量
 | 
	
		
			
				|  |  | +    nnArchitecture = DENSENET121
 | 
	
		
			
				|  |  | +    nnIsTrained = True  # 使用预训练的权重
 | 
	
		
			
				|  |  | +    nnClassCount = 14  # 数据集包含14个分类
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 训练参数:批量大小和最大迭代次数(epochs)
 | 
	
		
			
				|  |  | +    trBatchSize = 2
 | 
	
		
			
				|  |  | +    trMaxEpoch = 2
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 图像预处理相关参数:图像缩放的大小和裁剪后的大小
 | 
	
		
			
				|  |  | +    imgtransResize = 256
 | 
	
		
			
				|  |  | +    imgtransCrop = 224
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 保存模型的路径,包含时间戳
 | 
	
		
			
				|  |  | +    pathModel = 'm-' + timestampLaunch + '.pth.tar'
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    print('Training NN architecture = ', nnArchitecture)
 | 
	
		
			
				|  |  | +    ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal,
 | 
	
		
			
				|  |  | +                         nnArchitecture, nnIsTrained, nnClassCount, trBatchSize,
 | 
	
		
			
				|  |  | +                         trMaxEpoch, imgtransResize, imgtransCrop,
 | 
	
		
			
				|  |  | +                         timestampLaunch, None)
 | 
	
		
			
				|  |  | +    print('Testing the trained model')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pathRfModel = os.path.join('images', 'random_forest_model.pkl')  # 更新路径
 | 
	
		
			
				|  |  | +    labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture,
 | 
	
		
			
				|  |  | +                                           nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
 | 
	
		
			
				|  |  | +                                           imgtransCrop, timestampLaunch)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 生成并保存热力图
 | 
	
		
			
				|  |  | +    # 选择一些测试集中的样本来生成热力图
 | 
	
		
			
				|  |  | +    transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
 | 
	
		
			
				|  |  | +    datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
 | 
	
		
			
				|  |  | +                                   transform=transformSequence, model=None)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 确保测试集中有足够的样本
 | 
	
		
			
				|  |  | +    num_samples = min(8, len(datasetTest))
 | 
	
		
			
				|  |  | +    sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False)  # 随机选择8个样本
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 定义分类名称
 | 
	
		
			
				|  |  | +    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
 | 
	
		
			
				|  |  | +                   'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
 | 
	
		
			
				|  |  | +                   'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
 | 
	
		
			
				|  |  | +                   'Fibrosis', 'Pleural_Thickening',
 | 
	
		
			
				|  |  | +                   'Hernia']
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 创建一个图形窗口
 | 
	
		
			
				|  |  | +    plt.figure(figsize=(20, 10))
 | 
	
		
			
				|  |  | +    n_cols = 4
 | 
	
		
			
				|  |  | +    n_rows = 2
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for idx, sample_idx in enumerate(sample_indices):
 | 
	
		
			
				|  |  | +        image_path = datasetTest.listImagePaths[sample_idx]
 | 
	
		
			
				|  |  | +        true_labels = labels[sample_idx]
 | 
	
		
			
				|  |  | +        pred_labels = rf_preds[sample_idx]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # 加载图像
 | 
	
		
			
				|  |  | +        img = plt.imread(image_path)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # 创建子图
 | 
	
		
			
				|  |  | +        ax = plt.subplot(n_rows, n_cols, idx + 1)
 | 
	
		
			
				|  |  | +        ax.imshow(img)
 | 
	
		
			
				|  |  | +        ax.axis('off')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # 获取真实标签和预测标签的名称
 | 
	
		
			
				|  |  | +        true_label_names = [CLASS_NAMES[i] for i in range(len(true_labels)) if true_labels[i] == 1]
 | 
	
		
			
				|  |  | +        pred_label_names = [CLASS_NAMES[i] for i in range(len(pred_labels)) if pred_labels[i] == 1]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # 设置标题
 | 
	
		
			
				|  |  | +        title = f"Predicted: {', '.join(pred_label_names)}\nTrue: {', '.join(true_label_names)}"
 | 
	
		
			
				|  |  | +        ax.set_title(title, fontsize=10)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    plt.tight_layout()
 | 
	
		
			
				|  |  | +    output_plot_path = os.path.join('images', 'test_predictions.png')
 | 
	
		
			
				|  |  | +    plt.savefig(output_plot_path)
 | 
	
		
			
				|  |  | +    plt.show()
 | 
	
		
			
				|  |  | +    print(f"预测结果图已保存到 {output_plot_path}")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 生成热力图(可选)
 | 
	
		
			
				|  |  | +    for idx in sample_indices:
 | 
	
		
			
				|  |  | +        image_path = datasetTest.listImagePaths[idx]
 | 
	
		
			
				|  |  | +        output_heatmap_path = os.path.join('images', f'heatmap_test_{idx}.png')
 | 
	
		
			
				|  |  | +        h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence)
 | 
	
		
			
				|  |  | +        h.generate(image_path, output_heatmap_path, imgtransCrop)
 | 
	
		
			
				|  |  | +        print(f"热力图已保存到 {output_heatmap_path}")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# --------------------------------------------------------------------------------
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# 测试函数,加载预训练模型并在测试数据集上进行测试
 | 
	
		
			
				|  |  | +def runTest():
 | 
	
		
			
				|  |  | +    pathDirData = '/chest xray14'
 | 
	
		
			
				|  |  | +    pathFileTest = './dataset/test.txt'
 | 
	
		
			
				|  |  | +    nnArchitecture = 'DENSE-NET-121'
 | 
	
		
			
				|  |  | +    nnIsTrained = True
 | 
	
		
			
				|  |  | +    nnClassCount = 14
 | 
	
		
			
				|  |  | +    trBatchSize = 4
 | 
	
		
			
				|  |  | +    imgtransResize = 256
 | 
	
		
			
				|  |  | +    imgtransCrop = 224
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    timestampLaunch = ''
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 获取统一的 transformSequence
 | 
	
		
			
				|  |  | +    transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
 | 
	
		
			
				|  |  | +    pathRfModel = 'images/random_forest_model.pkl'  # 确保路径正确
 | 
	
		
			
				|  |  | +    labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture,
 | 
	
		
			
				|  |  | +                                           nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
 | 
	
		
			
				|  |  | +                                           imgtransCrop, timestampLaunch)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 生成并保存热力图
 | 
	
		
			
				|  |  | +    datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
 | 
	
		
			
				|  |  | +                                   transform=transformSequence, model=None)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 确保测试集中有足够的样本
 | 
	
		
			
				|  |  | +    num_samples = min(8, len(datasetTest))
 | 
	
		
			
				|  |  | +    sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False)  # 随机选择8个样本
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 定义分类名称
 | 
	
		
			
				|  |  | +    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
 | 
	
		
			
				|  |  | +                   'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
 | 
	
		
			
				|  |  | +                   'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
 | 
	
		
			
				|  |  | +                   'Fibrosis', 'Pleural_Thickening',
 | 
	
		
			
				|  |  | +                   'Hernia']
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 创建一个图形窗口
 | 
	
		
			
				|  |  | +    plt.figure(figsize=(20, 10))
 | 
	
		
			
				|  |  | +    n_cols = 4
 | 
	
		
			
				|  |  | +    n_rows = 2
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for idx, sample_idx in enumerate(sample_indices):
 | 
	
		
			
				|  |  | +        image_path = datasetTest.listImagePaths[sample_idx]
 | 
	
		
			
				|  |  | +        true_labels = labels[sample_idx]
 | 
	
		
			
				|  |  | +        pred_labels = rf_preds[sample_idx]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # 加载图像
 | 
	
		
			
				|  |  | +        img = plt.imread(image_path)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # 创建子图
 | 
	
		
			
				|  |  | +        ax = plt.subplot(n_rows, n_cols, idx + 1)
 | 
	
		
			
				|  |  | +        ax.imshow(img)
 | 
	
		
			
				|  |  | +        ax.axis('off')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # 获取真实标签和预测标签的名称
 | 
	
		
			
				|  |  | +        true_label_names = [CLASS_NAMES[i] for i in range(len(true_labels)) if true_labels[i] == 1]
 | 
	
		
			
				|  |  | +        pred_label_names = [CLASS_NAMES[i] for i in range(len(pred_labels)) if pred_labels[i] == 1]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # 设置标题
 | 
	
		
			
				|  |  | +        title = f"Predicted: {', '.join(pred_label_names)}\nTrue: {', '.join(true_label_names)}"
 | 
	
		
			
				|  |  | +        ax.set_title(title, fontsize=10)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    plt.tight_layout()
 | 
	
		
			
				|  |  | +    output_plot_path = os.path.join('images', 'test_predictions.png')
 | 
	
		
			
				|  |  | +    plt.savefig(output_plot_path)
 | 
	
		
			
				|  |  | +    plt.show()
 | 
	
		
			
				|  |  | +    print(f"预测结果图已保存到 {output_plot_path}")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # 生成热力图(可选)
 | 
	
		
			
				|  |  | +    for idx in sample_indices:
 | 
	
		
			
				|  |  | +        image_path = datasetTest.listImagePaths[idx]
 | 
	
		
			
				|  |  | +        output_heatmap_path = os.path.join('images', f'heatmap_test_{idx}.png')
 | 
	
		
			
				|  |  | +        h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence)
 | 
	
		
			
				|  |  | +        h.generate(image_path, output_heatmap_path, imgtransCrop)
 | 
	
		
			
				|  |  | +        print(f"热力图已保存到 {output_heatmap_path}")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# --------------------------------------------------------------------------------
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# 演示函数,展示模型在测试集上的推理过程
 | 
	
		
			
				|  |  | +def runDemo():
 | 
	
		
			
				|  |  | +    # 原有代码保持不变
 | 
	
		
			
				|  |  | +    pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +# 确保代码在主进程中运行
 | 
	
		
			
				|  |  | +if __name__ == '__main__':
 | 
	
		
			
				|  |  | +    mp.set_start_method('spawn', force=True)
 | 
	
		
			
				|  |  | +    main()  # 启动主函数
 |