123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- import os
- import time
- import torch.multiprocessing as mp
- import matplotlib.pyplot as plt
- import numpy as np
- from ChexnetTrainer import ChexnetTrainer
- from DatasetGenerator import DatasetGenerator
- from visual import HeatmapGenerator
# Make sure the 'images' output directory exists before anything writes to it.
# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs('images', exist_ok=True)
# Main entry point: dispatches to one of the script's modes (train, test or demo).
def main(mode='train'):
    """Run the selected mode of the script.

    Args:
        mode: 'train' (the default, matching the previous hard-coded behaviour)
            calls runTrain(); 'test' calls runTest(); 'demo' calls runDemo().

    Raises:
        ValueError: if ``mode`` is not one of 'train', 'test' or 'demo'.
    """
    if mode == 'train':
        runTrain()
    elif mode == 'test':
        runTest()
    elif mode == 'demo':
        runDemo()
    else:
        raise ValueError(f"Unknown mode: {mode!r} (expected 'train', 'test' or 'demo')")
- # --------------------------------------------------------------------------------
# Display names used in plot titles; entry i names position i of a label vector.
_CLASS_NAMES = (
    'Atelectasis', 'Cardiomegaly', 'Effusion',
    'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
    'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
    'Fibrosis', 'Pleural_Thickening',
    'Hernia',
)

# At most this many test samples are previewed (exactly fills the 2 x 4 grid).
_MAX_PREVIEW_SAMPLES = 8


def _active_class_names(label_vector):
    """Return the class names whose entry in ``label_vector`` equals 1."""
    return [_CLASS_NAMES[i] for i, flag in enumerate(label_vector) if flag == 1]


def _plot_sample_predictions(dataset, sample_indices, labels, rf_preds):
    """Plot the chosen samples with predicted/true labels and save the figure.

    Args:
        dataset: object exposing ``listImagePaths`` (indexable by sample index).
        sample_indices: indices of the samples to draw, one subplot each.
        labels: true label vectors, indexable by sample index.
        rf_preds: predicted label vectors, indexable by sample index.

    The figure is written to ``images/test_predictions.png`` and then shown.
    """
    n_rows, n_cols = 2, 4
    fig = plt.figure(figsize=(20, 10))
    for position, sample_idx in enumerate(sample_indices, start=1):
        img = plt.imread(dataset.listImagePaths[sample_idx])
        ax = fig.add_subplot(n_rows, n_cols, position)
        # cmap only affects 2-D (single-channel) arrays; matplotlib ignores it
        # for RGB(A) input, so colour images are rendered exactly as before.
        ax.imshow(img, cmap='gray')
        ax.axis('off')
        predicted = ', '.join(_active_class_names(rf_preds[sample_idx]))
        actual = ', '.join(_active_class_names(labels[sample_idx]))
        ax.set_title(f"Predicted: {predicted}\nTrue: {actual}", fontsize=10)
    fig.tight_layout()
    output_plot_path = os.path.join('images', 'test_predictions.png')
    fig.savefig(output_plot_path)
    plt.show()
    # Release the figure once it has been saved and shown.
    plt.close(fig)
    print(f"预测结果图已保存到 {output_plot_path}")


def _save_heatmaps(dataset, sample_indices, pathModel, nnArchitecture,
                   nnClassCount, imgtransCrop, transformSequence):
    """Write one heatmap image per chosen sample into the ``images`` directory."""
    for sample_idx in sample_indices:
        image_path = dataset.listImagePaths[sample_idx]
        output_heatmap_path = os.path.join('images', f'heatmap_test_{sample_idx}.png')
        # NOTE(review): a new generator is built for every image, as before.
        # All constructor arguments are loop-invariant, so it could be created
        # once if HeatmapGenerator keeps no per-image state -- confirm first.
        h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence)
        h.generate(image_path, output_heatmap_path, imgtransCrop)
        print(f"热力图已保存到 {output_heatmap_path}")


def _visualize_test_results(pathDirData, pathFileTest, pathModel, nnArchitecture,
                            nnClassCount, imgtransCrop, transformSequence,
                            labels, rf_preds):
    """Pick random test samples, plot their predictions and save their heatmaps.

    ``labels`` and ``rf_preds`` are the values returned by ChexnetTrainer.test.
    They are indexed with dataset indices, so they are assumed to be in the same
    order as the dataset file -- TODO confirm against ChexnetTrainer.test.
    """
    datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
                                   transform=transformSequence, model=None)
    # Never request more samples than the dataset holds.
    num_samples = min(_MAX_PREVIEW_SAMPLES, len(datasetTest))
    sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False)
    _plot_sample_predictions(datasetTest, sample_indices, labels, rf_preds)
    _save_heatmaps(datasetTest, sample_indices, pathModel, nnArchitecture,
                   nnClassCount, imgtransCrop, transformSequence)


# Training entry point: defines the training parameters, trains the model,
# tests it and then visualizes a few test predictions.
def runTrain():
    """Train the network, test it, and save prediction plots and heatmaps."""
    # Architecture identifier passed to the trainer. Other identifiers that were
    # listed here but unused: 'DENSE-NET-169', 'DENSE-NET-201', 'RESNET-50'.
    DENSENET121 = 'DENSE-NET-121'

    # Launch timestamp (ddmmYYYY-HHMMSS) tagging this run. A single strftime
    # call keeps the date and time parts consistent with each other.
    timestampLaunch = time.strftime("%d%m%Y-%H%M%S")
    print("Launching " + timestampLaunch)

    # Directory containing the image data.
    pathDirData = './chest xray14'

    # Train / validation / test list files (image paths with their labels).
    pathFileTrain = './dataset/train_2.txt'
    pathFileVal = './dataset/valid_2.txt'
    pathFileTest = './dataset/test_2.txt'

    # Network parameters: architecture, pretrained weights, number of classes.
    nnArchitecture = DENSENET121
    nnIsTrained = True
    nnClassCount = 14

    # Training parameters: batch size and maximum number of epochs.
    trBatchSize = 2
    trMaxEpoch = 2

    # Image preprocessing parameters: resize size and crop size.
    imgtransResize = 256
    imgtransCrop = 224

    # Checkpoint file name derived from the launch timestamp.
    # NOTE(review): presumably the name ChexnetTrainer.train saves under -- confirm.
    pathModel = 'm-' + timestampLaunch + '.pth.tar'

    print('Training NN architecture = ', nnArchitecture)
    ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal,
                         nnArchitecture, nnIsTrained, nnClassCount, trBatchSize,
                         trMaxEpoch, imgtransResize, imgtransCrop,
                         timestampLaunch, None)

    print('Testing the trained model')
    pathRfModel = os.path.join('images', 'random_forest_model.pkl')
    labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture,
                                           nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
                                           imgtransCrop, timestampLaunch)

    # Plot a few random test samples and generate their heatmaps.
    transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
    _visualize_test_results(pathDirData, pathFileTest, pathModel, nnArchitecture,
                            nnClassCount, imgtransCrop, transformSequence,
                            labels, rf_preds)


# --------------------------------------------------------------------------------
# Test entry point: loads an existing checkpoint and evaluates it on the test set.
def runTest():
    """Test a saved checkpoint and save prediction plots and heatmaps."""
    # NOTE(review): this is an absolute path ('/chest xray14'), whereas runTrain
    # uses './chest xray14' -- confirm which one is intended.
    pathDirData = '/chest xray14'
    pathFileTest = './dataset/test.txt'
    nnArchitecture = 'DENSE-NET-121'
    nnIsTrained = True
    nnClassCount = 14
    trBatchSize = 4
    imgtransResize = 256
    imgtransCrop = 224

    pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'
    timestampLaunch = ''

    # Transform sequence obtained from the trainer, reused for the preview
    # dataset and the heatmaps.
    transformSequence = ChexnetTrainer._get_transform(imgtransCrop)

    pathRfModel = 'images/random_forest_model.pkl'
    labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture,
                                           nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
                                           imgtransCrop, timestampLaunch)

    # Plot a few random test samples and generate their heatmaps.
    _visualize_test_results(pathDirData, pathFileTest, pathModel, nnArchitecture,
                            nnClassCount, imgtransCrop, transformSequence,
                            labels, rf_preds)
- # --------------------------------------------------------------------------------
# Demo entry point, intended to show model inference on the test set.
def runDemo():
    """Placeholder for the demo mode.

    The body is currently a stub (the original implementation is not present
    here), so calling it does nothing and returns None.
    """
    return None
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    # Select the 'spawn' start method (force=True overrides one already set)
    # before the main function starts any work.
    mp.set_start_method("spawn", force=True)
    main()
|