@@ -0,0 +1,354 @@
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import transforms
+import matplotlib.pyplot as plt
+from torch.utils.data import DataLoader
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import accuracy_score, f1_score
+from tqdm import tqdm
+import joblib  # joblib is used to persist the trained random forest model
+from DatasetGenerator import DatasetGenerator
+from DensenetModels import DenseNet121, DenseNet169, DenseNet201, ResNet50, FocalLoss
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.metrics import roc_curve, auc
+
+
+class ChexnetTrainer:
+
+    @staticmethod
+    def train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture,
+              nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize,
+              transCrop, launchTimestamp, checkpoint=None):
+        """Train the network and the downstream random forest classifier."""
+
+        # Select the device automatically
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Build the model and move it to the device
+        model = ChexnetTrainer._get_model(nnArchitecture, nnClassCount, nnIsTrained).to(device)
+
+        # Set up the image transforms and the data loaders
+        transformSequence = ChexnetTrainer._get_transform(transCrop)
+        datasetTrain = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTrain,
+                                        transform=transformSequence, model=model)
+        datasetVal = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileVal,
+                                      transform=transformSequence, model=model)
+
+        dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True, num_workers=4,
+                                     pin_memory=True)
+        dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False, num_workers=4,
+                                   pin_memory=True)
+
+        optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
+
+        # StepLR is used instead of ReduceLROnPlateau: decay the learning rate by 10x every 10 epochs
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
+
+        # Per-class weights for the focal loss, moved to the same device as the model
+        class_weights = torch.tensor(
+            [1.1762, 0.6735, 0.9410, 1.6680, 0.9699, 1.1950, 2.2584, 0.6859, 1.6683, 0.7744, 0.4625, 0.7385, 1.3764,
+             0.1758], dtype=torch.float32).to(device)
+
+        loss = FocalLoss(alpha=class_weights, gamma=2, logits=True)
+
+        lossMIN = 100000
+        train_f1_scores, val_f1_scores = [], []
+        all_val_targets = []
+        all_val_scores = []
+
+        # Training loop
+        for epochID in range(trMaxEpoch):
+            ChexnetTrainer.epochTrain(model, dataLoaderTrain, optimizer, scheduler, trMaxEpoch, nnClassCount, loss,
+                                      train_f1_scores, device)
+
+            lossVal, losstensor, val_f1, val_targets, val_scores = ChexnetTrainer.epochVal(
+                model, dataLoaderVal, optimizer, scheduler,
+                trMaxEpoch, nnClassCount, loss, val_f1_scores, device)
+
+            all_val_targets.append(val_targets)
+            all_val_scores.append(val_scores)
+
+            # Update the learning rate
+            scheduler.step()
+
+            # Save the checkpoint with the lowest validation loss
+            if lossVal < lossMIN:
+                lossMIN = lossVal
+                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
+                            'optimizer': optimizer.state_dict()},
+                           'm-' + launchTimestamp + '.pth.tar')
+
+        # Stack the validation labels and scores collected over all epochs
+        all_val_targets = np.vstack(all_val_targets)
+        all_val_scores = np.vstack(all_val_scores)
+
+        # Plot the F1 scores per epoch
+        os.makedirs('images', exist_ok=True)  # make sure the output directory exists
+        plt.figure()
+        plt.plot(train_f1_scores, label="Train F1-Score")
+        plt.plot(val_f1_scores, label="Val F1-Score")
+        plt.xlabel("Epoch")
+        plt.ylabel("F1 Score")
+        plt.title("F1 Score per Epoch")
+        plt.legend()
+        plt.savefig(os.path.join('images', 'f1_scores.png'))
+        plt.close()
+
+        # Compute and plot the per-class ROC curves
+        CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
+                       'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
+                       'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
+                       'Fibrosis', 'Pleural_Thickening',
+                       'Hernia']
+
+        ChexnetTrainer.plot_auc_roc(all_val_targets, all_val_scores, CLASS_NAMES,
+                                    os.path.join('images', 'auc_roc_curve.png'))
+        print("ROC curves saved to images/auc_roc_curve.png")
+
+        # Extract features and train the random forest classifier
+        rf_classifier = ChexnetTrainer.train_random_forest(datasetTrain, datasetVal, model, device)
+
+        # Save the random forest classifier
+        joblib.dump(rf_classifier, os.path.join('images', 'random_forest_model.pkl'))
+
+    @staticmethod
+    def epochTrain(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, f1_scores, device):
+        model.train()
+        all_targets = []
+        all_preds = []
+
+        # Progress bar over the training batches
+        pbar = tqdm(total=len(dataLoader), desc="Training", leave=True)
+
+        for batchID, (input, target) in enumerate(dataLoader):
+            target = target.to(device)
+            input = input.to(device)
+
+            varOutput = model(input)
+
+            lossvalue = loss(varOutput, target)
+            optimizer.zero_grad()
+            lossvalue.backward()
+            optimizer.step()
+
+            # Binarize the sigmoid outputs at 0.5 to compute the F1 score
+            pred = torch.sigmoid(varOutput).detach().cpu().numpy() > 0.5
+            all_targets.extend(target.cpu().numpy())
+            all_preds.extend(pred)
+
+            # Update the progress bar
+            pbar.update(1)
+        pbar.close()
+
+        f1 = f1_score(np.array(all_targets), np.array(all_preds), average="macro")
+
+        f1_scores.append(f1)
+        # Print the macro F1 score at the end of each epoch
+        print(f"Epoch completed. F1 Score (Macro): {f1:.4f}")
+
+    @staticmethod
+    def epochVal(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, f1_scores, device):
+        model.eval()
+        lossVal = 0
+        lossValNorm = 0
+        losstensorMean = 0
+        all_targets = []
+        all_preds = []
+        all_scores = []  # predicted probabilities, used later for the ROC curves
+
+        # Progress bar over the validation batches
+        pbar = tqdm(total=len(dataLoader), desc="Validation", leave=True)
+
+        for i, (input, target) in enumerate(dataLoader):
+            target = target.to(device)
+            input = input.to(device)
+            with torch.no_grad():
+                varOutput = model(input)
+
+                losstensor = loss(varOutput, target)
+                losstensorMean += losstensor
+                lossVal += losstensor.item()
+                lossValNorm += 1
+
+                scores = torch.sigmoid(varOutput).cpu().numpy()
+                all_scores.extend(scores)
+                pred = scores > 0.5
+                all_targets.extend(target.cpu().numpy())
+                all_preds.extend(pred)
+
+            # Update the progress bar
+            pbar.update(1)
+        pbar.close()
+
+        f1 = f1_score(np.array(all_targets), np.array(all_preds), average="macro")
+        f1_scores.append(f1)
+        # Print the macro F1 score at the end of each epoch
+        print(f"Epoch completed. F1 Score (Macro): {f1:.4f}")
+
+        return lossVal / lossValNorm, losstensorMean / lossValNorm, f1, np.array(all_targets), np.array(all_scores)
+
+    @staticmethod
+    def train_random_forest(datasetTrain, datasetVal, model, device):
+        """Train a random forest classifier on the network outputs, adapted to multi-label classification."""
+
+        print("Extracting features from the training data.")
+        # Extract features for the training and validation sets
+        train_features, train_labels = ChexnetTrainer.extract_features(datasetTrain, model, device)
+        val_features, val_labels = ChexnetTrainer.extract_features(datasetVal, model, device)
+
+        print("Training the random forest classifier.")
+        # MultiOutputClassifier fits one random forest per label to handle the multi-label problem
+        rf_classifier = MultiOutputClassifier(RandomForestClassifier(n_estimators=100))
+        rf_classifier.fit(train_features, train_labels)
+
+        print("Evaluating the random forest classifier.")
+        # Evaluate the random forest on the validation set
+        val_preds = rf_classifier.predict(val_features)
+        val_f1 = f1_score(val_labels, val_preds, average='macro')
+
+        print(f"Random Forest F1 Score on validation set: {val_f1}")
+
+        return rf_classifier
+
+    @staticmethod
+    def extract_features(dataset, model, device):
+        """Run the network over a dataset and return its flattened outputs as features, together with the labels."""
+        features = []
+        labels = []
+
+        print("Extracting features from the dataset.")
+        model.eval()
+        dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
+        for input, target in tqdm(dataLoader, desc="Extracting Features"):
+            input = input.to(device)
+            target = target.to(device)
+
+            with torch.no_grad():
+                varOutput = model(input)
+
+            # varOutput is assumed to be 2-D (batch, features); flatten each sample into a single vector
+            features.append(varOutput.view(varOutput.size(0), -1).cpu().numpy())
+            labels.append(target.cpu().numpy())  # targets are in multi-label format (one row per sample)
+
+        # Stack features and labels into 2-D arrays with one row per sample
+        return np.vstack(features), np.vstack(labels)
+
+    @staticmethod
+    def plot_auc_roc(y_true, y_scores, class_names, save_path):
+        """Plot the per-class ROC curves and save the figure."""
+        plt.figure(figsize=(20, 15))
+        for i, class_name in enumerate(class_names):
+            fpr, tpr, _ = roc_curve(y_true[:, i], y_scores[:, i])
+            roc_auc = auc(fpr, tpr)
+            plt.plot(fpr, tpr, lw=2, label=f'ROC curve of class {class_name} (area = {roc_auc:.2f})')
+
+        plt.plot([0, 1], [0, 1], 'k--', lw=2)
+        plt.xlim([0.0, 1.0])
+        plt.ylim([0.0, 1.05])
+        plt.xlabel('False Positive Rate')
+        plt.ylabel('True Positive Rate')
+        plt.title('Receiver Operating Characteristic (ROC) Curves')
+        plt.legend(loc="lower right")
+        plt.savefig(save_path)
+        plt.close()
+
+    @staticmethod
+    def test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture, nnClassCount, nnIsTrained,
+             trBatchSize, imgtransResize, imgtransCrop, timestampLaunch):
+        """Test on the held-out set, using the deep model as feature extractor and the random forest as classifier."""
+        # Load the deep learning model
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = ChexnetTrainer._get_model(nnArchitecture, nnClassCount, nnIsTrained).to(device)
+        checkpoint = torch.load(pathModel, map_location=device)
+        model.load_state_dict(checkpoint['state_dict'])
+        model.eval()  # make sure the model is in evaluation mode
+
+        # Load the random forest model
+        rf_classifier = joblib.load(pathRfModel)
+
+        # Build the transform and the test dataset
+        transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
+        datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
+                                       transform=transformSequence, model=model)
+
+        # Extract features and labels with extract_features
+        features, labels = ChexnetTrainer.extract_features(datasetTest, model, device)
+
+        # Predict with the random forest model
+        rf_preds = rf_classifier.predict(features)
+
+        # Multi-label F1 score
+        rf_f1 = f1_score(labels, rf_preds, average="macro")
+        print(f"Random Forest Multi-label F1 Score on test set: {rf_f1}")
+
+        # Compute the scores needed for the ROC curves
+        if hasattr(rf_classifier, "predict_proba"):
+            # predict_proba returns one (n_samples, 2) array per label; keep the positive-class column
+            # of each and stack them into an (n_samples, n_classes) matrix
+            rf_scores = rf_classifier.predict_proba(features)
+            rf_scores = np.stack([class_proba[:, 1] for class_proba in rf_scores], axis=1)
+        else:
+            # Without predict_proba the ROC curves cannot be computed
+            print("The random forest model does not support predict_proba; ROC curves cannot be computed.")
+            rf_scores = np.zeros_like(labels)
+
+        # Plot the ROC curves
+        CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
+                       'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
+                       'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
+                       'Fibrosis', 'Pleural_Thickening',
+                       'Hernia']
+
+        os.makedirs('images', exist_ok=True)  # make sure the output directory exists
+        ChexnetTrainer.plot_auc_roc(labels, rf_scores, CLASS_NAMES,
+                                    os.path.join('images', 'test_auc_roc_curve.png'))
+        print("Test-set ROC curves saved to images/test_auc_roc_curve.png")
+
+        # Report the random forest F1 score
+        print(f"Testing completed. Random Forest F1 Score: {rf_f1}")
+
+        return labels, rf_preds  # return the ground-truth labels and the predictions
+
+    # Helper methods
+
+    @staticmethod
+    def _get_model(nnArchitecture, nnClassCount, nnIsTrained):
+        """Return the model corresponding to the selected architecture."""
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Pick the model according to the architecture name
+        if nnArchitecture == 'DENSE-NET-121':
+            model = DenseNet121(nnClassCount, nnIsTrained).to(device)
+        elif nnArchitecture == 'DENSE-NET-169':
+            model = DenseNet169(nnClassCount, nnIsTrained).to(device)
+        elif nnArchitecture == 'DENSE-NET-201':
+            model = DenseNet201(nnClassCount, nnIsTrained).to(device)
+        elif nnArchitecture == 'RESNET-50':
+            model = ResNet50(nnClassCount, nnIsTrained).to(device)
+        else:
+            raise ValueError(
+                f"Unknown architecture: {nnArchitecture}. Please choose from 'DENSE-NET-121', 'DENSE-NET-169', 'DENSE-NET-201', 'RESNET-50'.")
+
+        return model
+
+    @staticmethod
+    def _get_transform(transCrop):
+        """Return the image preprocessing transform (note: the same augmenting transform is used for training, validation and test)."""
+        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        transformList = [
+            transforms.RandomResizedCrop(transCrop),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize
+        ]
+        return transforms.Compose(transformList)
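
For reference, below is a minimal sketch of how the two entry points of this class might be wired together. It assumes the file above is saved as ChexnetTrainer.py; the data paths, architecture name, and hyperparameter values are illustrative assumptions, not part of the original script.

# Illustrative driver for ChexnetTrainer (paths and hyperparameters are assumptions).
import time

from ChexnetTrainer import ChexnetTrainer

if __name__ == '__main__':
    timestamp = time.strftime("%d%m%Y-%H%M%S")

    # Assumed dataset layout: an image directory plus text files listing image paths and labels.
    ChexnetTrainer.train(pathDirData='./database', pathFileTrain='./dataset/train_1.txt',
                         pathFileVal='./dataset/val_1.txt', nnArchitecture='DENSE-NET-121',
                         nnIsTrained=True, nnClassCount=14, trBatchSize=16, trMaxEpoch=30,
                         transResize=256, transCrop=224, launchTimestamp=timestamp)

    # Test with the checkpoint written by train() and the saved random forest model.
    ChexnetTrainer.test(pathDirData='./database', pathFileTest='./dataset/test_1.txt',
                        pathModel='m-' + timestamp + '.pth.tar',
                        pathRfModel='images/random_forest_model.pkl',
                        nnArchitecture='DENSE-NET-121', nnClassCount=14, nnIsTrained=True,
                        trBatchSize=16, imgtransResize=256, imgtransCrop=224,
                        timestampLaunch=timestamp)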