|
@@ -0,0 +1,224 @@
|
|
|
+
|
|
|
+import os
|
|
|
+import time
|
|
|
+import torch.multiprocessing as mp
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import numpy as np
|
|
|
+from ChexnetTrainer import ChexnetTrainer
|
|
|
+from DatasetGenerator import DatasetGenerator
|
|
|
+from visual import HeatmapGenerator
|
|
|
+
|
|
|
+# 确保 images 目录存在
|
|
|
+if not os.path.exists('images'):
|
|
|
+ os.makedirs('images')
|
|
|
+
|
|
|
+
|
|
|
+# 主函数,负责启动不同的功能(训练、测试或运行演示)
|
|
|
+def main():
|
|
|
+ # runDemo() # 运行演示模式
|
|
|
+ # runTest() # 测试模式(注释掉,可以通过解除注释来运行)
|
|
|
+ runTrain() # 训练模式(注释掉,可以通过解除注释来运行)
|
|
|
+
|
|
|
+
|
|
|
+# --------------------------------------------------------------------------------
|
|
|
+
|
|
|
+# 训练函数,定义训练所需的参数并启动模型训练
|
|
|
+def runTrain():
|
|
|
+ DENSENET121 = 'DENSE-NET-121' # 定义DenseNet121模型名称
|
|
|
+ DENSENET169 = 'DENSE-NET-169' # 定义DenseNet169模型名称
|
|
|
+ DENSENET201 = 'DENSE-NET-201' # 定义DenseNet201模型名称
|
|
|
+ Resnet50 = 'RESNET-50' # 定义Resnet50模型名称
|
|
|
+
|
|
|
+ # 获取当前的时间戳,作为训练过程的标记
|
|
|
+ timestampTime = time.strftime("%H%M%S")
|
|
|
+ timestampDate = time.strftime("%d%m%Y")
|
|
|
+ timestampLaunch = timestampDate + '-' + timestampTime
|
|
|
+ print("Launching " + timestampLaunch)
|
|
|
+
|
|
|
+ # 图像数据所在的路径
|
|
|
+ pathDirData = './chest xray14'
|
|
|
+
|
|
|
+ # 训练、验证和测试数据集文件路径
|
|
|
+ # 每个文件中包含图像路径及其对应的标签
|
|
|
+ pathFileTrain = './dataset/train_2.txt'
|
|
|
+ pathFileVal = './dataset/valid_2.txt'
|
|
|
+ pathFileTest = './dataset/test_2.txt'
|
|
|
+
|
|
|
+ # 神经网络参数:模型架构、是否加载预训练模型、分类的类别数量
|
|
|
+ nnArchitecture = DENSENET121
|
|
|
+ nnIsTrained = True # 使用预训练的权重
|
|
|
+ nnClassCount = 14 # 数据集包含14个分类
|
|
|
+
|
|
|
+ # 训练参数:批量大小和最大迭代次数(epochs)
|
|
|
+ trBatchSize = 2
|
|
|
+ trMaxEpoch = 2
|
|
|
+
|
|
|
+ # 图像预处理相关参数:图像缩放的大小和裁剪后的大小
|
|
|
+ imgtransResize = 256
|
|
|
+ imgtransCrop = 224
|
|
|
+
|
|
|
+ # 保存模型的路径,包含时间戳
|
|
|
+ pathModel = 'm-' + timestampLaunch + '.pth.tar'
|
|
|
+
|
|
|
+ print('Training NN architecture = ', nnArchitecture)
|
|
|
+ ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal,
|
|
|
+ nnArchitecture, nnIsTrained, nnClassCount, trBatchSize,
|
|
|
+ trMaxEpoch, imgtransResize, imgtransCrop,
|
|
|
+ timestampLaunch, None)
|
|
|
+ print('Testing the trained model')
|
|
|
+
|
|
|
+ pathRfModel = os.path.join('images', 'random_forest_model.pkl') # 更新路径
|
|
|
+ labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture,
|
|
|
+ nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
|
|
|
+ imgtransCrop, timestampLaunch)
|
|
|
+
|
|
|
+ # 生成并保存热力图
|
|
|
+ # 选择一些测试集中的样本来生成热力图
|
|
|
+ transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
|
|
|
+ datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
|
|
|
+ transform=transformSequence, model=None)
|
|
|
+
|
|
|
+ # 确保测试集中有足够的样本
|
|
|
+ num_samples = min(8, len(datasetTest))
|
|
|
+ sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False) # 随机选择8个样本
|
|
|
+
|
|
|
+ # 定义分类名称
|
|
|
+ CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
|
|
|
+ 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
|
|
|
+ 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
|
|
|
+ 'Fibrosis', 'Pleural_Thickening',
|
|
|
+ 'Hernia']
|
|
|
+
|
|
|
+ # 创建一个图形窗口
|
|
|
+ plt.figure(figsize=(20, 10))
|
|
|
+ n_cols = 4
|
|
|
+ n_rows = 2
|
|
|
+
|
|
|
+ for idx, sample_idx in enumerate(sample_indices):
|
|
|
+ image_path = datasetTest.listImagePaths[sample_idx]
|
|
|
+ true_labels = labels[sample_idx]
|
|
|
+ pred_labels = rf_preds[sample_idx]
|
|
|
+
|
|
|
+ # 加载图像
|
|
|
+ img = plt.imread(image_path)
|
|
|
+
|
|
|
+ # 创建子图
|
|
|
+ ax = plt.subplot(n_rows, n_cols, idx + 1)
|
|
|
+ ax.imshow(img)
|
|
|
+ ax.axis('off')
|
|
|
+
|
|
|
+ # 获取真实标签和预测标签的名称
|
|
|
+ true_label_names = [CLASS_NAMES[i] for i in range(len(true_labels)) if true_labels[i] == 1]
|
|
|
+ pred_label_names = [CLASS_NAMES[i] for i in range(len(pred_labels)) if pred_labels[i] == 1]
|
|
|
+
|
|
|
+ # 设置标题
|
|
|
+ title = f"Predicted: {', '.join(pred_label_names)}\nTrue: {', '.join(true_label_names)}"
|
|
|
+ ax.set_title(title, fontsize=10)
|
|
|
+
|
|
|
+ plt.tight_layout()
|
|
|
+ output_plot_path = os.path.join('images', 'test_predictions.png')
|
|
|
+ plt.savefig(output_plot_path)
|
|
|
+ plt.show()
|
|
|
+ print(f"预测结果图已保存到 {output_plot_path}")
|
|
|
+
|
|
|
+ # 生成热力图(可选)
|
|
|
+ for idx in sample_indices:
|
|
|
+ image_path = datasetTest.listImagePaths[idx]
|
|
|
+ output_heatmap_path = os.path.join('images', f'heatmap_test_{idx}.png')
|
|
|
+ h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence)
|
|
|
+ h.generate(image_path, output_heatmap_path, imgtransCrop)
|
|
|
+ print(f"热力图已保存到 {output_heatmap_path}")
|
|
|
+
|
|
|
+
|
|
|
+# --------------------------------------------------------------------------------
|
|
|
+
|
|
|
+# 测试函数,加载预训练模型并在测试数据集上进行测试
|
|
|
+def runTest():
|
|
|
+ pathDirData = '/chest xray14'
|
|
|
+ pathFileTest = './dataset/test.txt'
|
|
|
+ nnArchitecture = 'DENSE-NET-121'
|
|
|
+ nnIsTrained = True
|
|
|
+ nnClassCount = 14
|
|
|
+ trBatchSize = 4
|
|
|
+ imgtransResize = 256
|
|
|
+ imgtransCrop = 224
|
|
|
+
|
|
|
+ pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'
|
|
|
+
|
|
|
+ timestampLaunch = ''
|
|
|
+
|
|
|
+ # 获取统一的 transformSequence
|
|
|
+ transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
|
|
|
+ pathRfModel = 'images/random_forest_model.pkl' # 确保路径正确
|
|
|
+ labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture,
|
|
|
+ nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
|
|
|
+ imgtransCrop, timestampLaunch)
|
|
|
+
|
|
|
+ # 生成并保存热力图
|
|
|
+ datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
|
|
|
+ transform=transformSequence, model=None)
|
|
|
+
|
|
|
+ # 确保测试集中有足够的样本
|
|
|
+ num_samples = min(8, len(datasetTest))
|
|
|
+ sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False) # 随机选择8个样本
|
|
|
+
|
|
|
+ # 定义分类名称
|
|
|
+ CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
|
|
|
+ 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
|
|
|
+ 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
|
|
|
+ 'Fibrosis', 'Pleural_Thickening',
|
|
|
+ 'Hernia']
|
|
|
+
|
|
|
+ # 创建一个图形窗口
|
|
|
+ plt.figure(figsize=(20, 10))
|
|
|
+ n_cols = 4
|
|
|
+ n_rows = 2
|
|
|
+
|
|
|
+ for idx, sample_idx in enumerate(sample_indices):
|
|
|
+ image_path = datasetTest.listImagePaths[sample_idx]
|
|
|
+ true_labels = labels[sample_idx]
|
|
|
+ pred_labels = rf_preds[sample_idx]
|
|
|
+
|
|
|
+ # 加载图像
|
|
|
+ img = plt.imread(image_path)
|
|
|
+
|
|
|
+ # 创建子图
|
|
|
+ ax = plt.subplot(n_rows, n_cols, idx + 1)
|
|
|
+ ax.imshow(img)
|
|
|
+ ax.axis('off')
|
|
|
+
|
|
|
+ # 获取真实标签和预测标签的名称
|
|
|
+ true_label_names = [CLASS_NAMES[i] for i in range(len(true_labels)) if true_labels[i] == 1]
|
|
|
+ pred_label_names = [CLASS_NAMES[i] for i in range(len(pred_labels)) if pred_labels[i] == 1]
|
|
|
+
|
|
|
+ # 设置标题
|
|
|
+ title = f"Predicted: {', '.join(pred_label_names)}\nTrue: {', '.join(true_label_names)}"
|
|
|
+ ax.set_title(title, fontsize=10)
|
|
|
+
|
|
|
+ plt.tight_layout()
|
|
|
+ output_plot_path = os.path.join('images', 'test_predictions.png')
|
|
|
+ plt.savefig(output_plot_path)
|
|
|
+ plt.show()
|
|
|
+ print(f"预测结果图已保存到 {output_plot_path}")
|
|
|
+
|
|
|
+ # 生成热力图(可选)
|
|
|
+ for idx in sample_indices:
|
|
|
+ image_path = datasetTest.listImagePaths[idx]
|
|
|
+ output_heatmap_path = os.path.join('images', f'heatmap_test_{idx}.png')
|
|
|
+ h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence)
|
|
|
+ h.generate(image_path, output_heatmap_path, imgtransCrop)
|
|
|
+ print(f"热力图已保存到 {output_heatmap_path}")
|
|
|
+
|
|
|
+
|
|
|
+# --------------------------------------------------------------------------------
|
|
|
+
|
|
|
+# 演示函数,展示模型在测试集上的推理过程
|
|
|
+def runDemo():
|
|
|
+ # 原有代码保持不变
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+# 确保代码在主进程中运行
|
|
|
+if __name__ == '__main__':
|
|
|
+ mp.set_start_method('spawn', force=True)
|
|
|
+ main() # 启动主函数
|