+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import transforms
+import matplotlib.pyplot as plt
+from torch.utils.data import DataLoader
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import accuracy_score, f1_score
+from tqdm import tqdm
+import joblib
+from DatasetGenerator import DatasetGenerator
+from DensenetModels import DenseNet121, DenseNet169, DenseNet201, ResNet50, FocalLoss
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.metrics import roc_curve, auc
class ChexnetTrainer:
    """Train, validate and test a multi-label chest X-ray classifier.

    A CNN backbone from ``DensenetModels`` is trained with focal loss; a
    per-label random forest is then fitted on the network outputs. Every
    method is static, so the class only serves as a namespace.
    """

    # Label names, in the column order used when plotting ROC curves.
    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
                   'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                   'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
                   'Fibrosis', 'Pleural_Thickening',
                   'Hernia']

    # ``alpha`` values handed to FocalLoss (14 entries).
    # NOTE(review): presumably aligned with CLASS_NAMES and only meaningful
    # when nnClassCount == 14 -- confirm against FocalLoss.
    _CLASS_WEIGHTS = [1.1762, 0.6735, 0.9410, 1.6680, 0.9699, 1.1950, 2.2584,
                      0.6859, 1.6683, 0.7744, 0.4625, 0.7385, 1.3764, 0.1758]

    # Architecture identifiers accepted by ``_get_model``.
    _SUPPORTED_ARCHITECTURES = ('DENSE-NET-121', 'DENSE-NET-169',
                                'DENSE-NET-201', 'RESNET-50')

    @staticmethod
    def train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture,
              nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize,
              transCrop, launchTimestamp, checkpoint=None):
        """Train the network, then fit and save a random forest on its outputs.

        Args:
            pathDirData: Image directory passed to ``DatasetGenerator``.
            pathFileTrain: Dataset file describing the training split.
            pathFileVal: Dataset file describing the validation split.
            nnArchitecture: One of the names accepted by ``_get_model``.
            nnIsTrained: Forwarded to the model constructor.
            nnClassCount: Number of output labels of the network.
            trBatchSize: Batch size of the training/validation loaders.
            trMaxEpoch: Number of epochs to run.
            transResize: Size given to ``transforms.Resize`` for the
                validation images (``None`` falls back to ``transCrop``).
            transCrop: Crop size of the network input.
            launchTimestamp: String embedded in the checkpoint file name
                ``m-<launchTimestamp>.pth.tar``.
            checkpoint: Optional checkpoint to resume from; it must contain
                the ``'state_dict'`` and ``'optimizer'`` entries that this
                method itself saves.

        Side effects:
            Writes the best checkpoint to the working directory and the F1
            plot, ROC plot and random-forest pickle into ``images/``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = ChexnetTrainer._get_model(nnArchitecture, nnClassCount, nnIsTrained).to(device)

        # Augment only the training split; validation must be deterministic
        # so that the loss used for checkpoint selection is comparable
        # between epochs.
        train_transform = ChexnetTrainer._get_transform(transCrop)
        val_transform = ChexnetTrainer._get_transform(transCrop, transResize, augment=False)
        datasetTrain = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTrain,
                                        transform=train_transform, model=model)
        datasetVal = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileVal,
                                      transform=val_transform, model=model)

        # Pinned host memory only helps (and only avoids a warning) on CUDA.
        pin_memory = device.type == "cuda"
        dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True,
                                     num_workers=4, pin_memory=pin_memory)
        dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False,
                                   num_workers=4, pin_memory=pin_memory)

        optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08,
                               weight_decay=1e-5)
        if checkpoint is not None:
            # Resume from a previously saved training state.
            saved_state = torch.load(checkpoint, map_location=device)
            model.load_state_dict(saved_state['state_dict'])
            optimizer.load_state_dict(saved_state['optimizer'])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

        class_weights = torch.tensor(ChexnetTrainer._CLASS_WEIGHTS, dtype=torch.float32).to(device)
        loss = FocalLoss(alpha=class_weights, gamma=2, logits=True)

        lossMIN = float("inf")
        train_f1_scores, val_f1_scores = [], []
        # Validation outputs of the epoch that produced the saved checkpoint.
        best_val_targets = None
        best_val_scores = None

        for epochID in range(trMaxEpoch):
            ChexnetTrainer.epochTrain(model, dataLoaderTrain, optimizer, scheduler, trMaxEpoch,
                                      nnClassCount, loss, train_f1_scores, device)
            lossVal, _, _, val_targets, val_scores = ChexnetTrainer.epochVal(
                model, dataLoaderVal, optimizer, scheduler,
                trMaxEpoch, nnClassCount, loss, val_f1_scores, device)
            scheduler.step()

            if lossVal < lossMIN:
                lossMIN = lossVal
                best_val_targets, best_val_scores = val_targets, val_scores
                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
                            'optimizer': optimizer.state_dict()},
                           'm-' + launchTimestamp + '.pth.tar')

        # The plots and the pickle below are written into 'images'.
        os.makedirs('images', exist_ok=True)

        plt.figure()
        plt.plot(train_f1_scores, label="Train F1-Score")
        plt.plot(val_f1_scores, label="Val F1-Score")
        plt.xlabel("Epoch")
        plt.ylabel("F1 Score")
        plt.title("F1 Score per Epoch")
        plt.legend()
        plt.savefig(os.path.join('images', 'f1_scores.png'))
        plt.close()

        if best_val_targets is not None:
            # Plot one model state only: pooling the predictions of every
            # epoch would mix outputs of different weights into one curve.
            ChexnetTrainer.plot_auc_roc(best_val_targets, best_val_scores, ChexnetTrainer.CLASS_NAMES,
                                        os.path.join('images', 'auc_roc_curve.png'))
            print("AUC-ROC曲线已保存到 images/auc_roc_curve.png")
        else:
            print("No validation result available; skipping the AUC-ROC plot.")

        # The forest is fitted on the in-memory model, i.e. the weights after
        # the last epoch, which need not be the saved best checkpoint.
        rf_classifier = ChexnetTrainer.train_random_forest(datasetTrain, datasetVal, model, device)
        joblib.dump(rf_classifier, os.path.join('images', 'random_forest_model.pkl'))

    @staticmethod
    def epochTrain(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, f1_scores, device):
        """Run one optimisation pass over ``dataLoader``.

        The macro F1 score of the thresholded (> 0.5) sigmoid outputs is
        printed and appended to ``f1_scores``. ``scheduler``, ``epochMax``
        and ``classCount`` are not used; they are kept so existing callers
        keep working.

        Raises:
            ValueError: If ``dataLoader`` yields no batch.
        """
        model.train()
        target_batches = []
        pred_batches = []

        # Iterating through tqdm closes the bar even if a batch raises.
        for batch_input, batch_target in tqdm(dataLoader, desc="Training", leave=True):
            batch_target = batch_target.to(device)
            batch_input = batch_input.to(device)

            output = model(batch_input)
            lossvalue = loss(output, batch_target)

            optimizer.zero_grad()
            lossvalue.backward()
            optimizer.step()

            pred_batches.append(torch.sigmoid(output.detach()).cpu().numpy() > 0.5)
            target_batches.append(batch_target.cpu().numpy())

        if not target_batches:
            raise ValueError("epochTrain received an empty data loader.")

        f1 = f1_score(np.concatenate(target_batches), np.concatenate(pred_batches), average="macro")
        f1_scores.append(f1)
        print(f"Epoch completed. F1 Score (Macro): {f1:.4f}")

    @staticmethod
    def epochVal(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, f1_scores, device):
        """Evaluate ``model`` on ``dataLoader`` without updating weights.

        ``optimizer``, ``scheduler``, ``epochMax`` and ``classCount`` are not
        used; they are kept so existing callers keep working.

        Returns:
            A tuple ``(mean_loss, mean_loss_tensor, f1, targets, scores)``:
            the per-batch mean loss as a float and as a tensor, the macro F1
            of the thresholded (> 0.5) scores, and the stacked targets and
            sigmoid scores as NumPy arrays. ``f1`` is also appended to
            ``f1_scores``.

        Raises:
            ValueError: If ``dataLoader`` yields no batch.
        """
        model.eval()
        loss_sum = 0
        loss_tensor_sum = 0
        batch_count = 0
        target_batches = []
        score_batches = []

        with torch.no_grad():
            for batch_input, batch_target in tqdm(dataLoader, desc="Validation", leave=True):
                batch_target = batch_target.to(device)
                batch_input = batch_input.to(device)

                output = model(batch_input)
                losstensor = loss(output, batch_target)
                loss_tensor_sum += losstensor
                loss_sum += losstensor.item()
                batch_count += 1

                score_batches.append(torch.sigmoid(output).cpu().numpy())
                target_batches.append(batch_target.cpu().numpy())

        if batch_count == 0:
            raise ValueError("epochVal received an empty data loader.")

        all_targets = np.concatenate(target_batches)
        all_scores = np.concatenate(score_batches)
        f1 = f1_score(all_targets, all_scores > 0.5, average="macro")
        f1_scores.append(f1)
        print(f"Epoch completed. F1 Score (Macro): {f1:.4f}")
        return loss_sum / batch_count, loss_tensor_sum / batch_count, f1, all_targets, all_scores

    @staticmethod
    def train_random_forest(datasetTrain, datasetVal, model, device):
        """Fit a multi-label random forest on the outputs of ``model``.

        One ``RandomForestClassifier`` per label is fitted on the features
        extracted from ``datasetTrain``; the macro F1 on ``datasetVal`` is
        printed.

        Returns:
            The fitted ``MultiOutputClassifier``.
        """
        print("正在提取训练数据的特征。")
        train_features, train_labels = ChexnetTrainer.extract_features(datasetTrain, model, device)
        val_features, val_labels = ChexnetTrainer.extract_features(datasetVal, model, device)

        print("正在训练随机森林分类器。")
        rf_classifier = MultiOutputClassifier(RandomForestClassifier(n_estimators=100))
        rf_classifier.fit(train_features, train_labels)

        print("正在评估随机森林分类器。")
        val_preds = rf_classifier.predict(val_features)
        val_f1 = f1_score(val_labels, val_preds, average='macro')
        print(f"Random Forest F1 Score on validation set: {val_f1}")
        return rf_classifier

    @staticmethod
    def extract_features(dataset, model, device, batch_size=32):
        """Run ``model`` over ``dataset`` and collect its outputs and labels.

        Args:
            dataset: Dataset yielding ``(input, target)`` pairs.
            model: Network to evaluate; it is switched to eval mode.
            device: Device the inputs are moved to.
            batch_size: Batch size of the internal loader (default 32).

        Returns:
            ``(features, labels)`` as NumPy arrays with one row per sample;
            each feature row is the flattened model output for that sample.
        """
        print("正在提取数据集的特征。")
        model.eval()
        pin_memory = torch.device(device).type == "cuda"
        dataLoader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=4,
                                pin_memory=pin_memory)

        features = []
        labels = []
        with torch.no_grad():
            for batch_input, batch_target in tqdm(dataLoader, desc="Extracting Features"):
                output = model(batch_input.to(device))
                features.append(output.reshape(output.size(0), -1).cpu().numpy())
                # Labels are only needed on the host; no device round trip.
                labels.append(batch_target.cpu().numpy())
        return np.vstack(features), np.vstack(labels)

    @staticmethod
    def _positive_class_scores(probabilities, classes=None):
        """Turn per-label ``predict_proba`` output into a score matrix.

        Args:
            probabilities: Sequence with one ``(n_samples, n_classes)`` array
                per label, as returned by a multi-output classifier.
            classes: Optional sequence with the class labels of each
                per-label estimator, used to locate the column of class 1.
                Ignored when its length does not match ``probabilities``.

        Returns:
            Array of shape ``(n_samples, n_labels)`` holding the probability
            of class 1 for every sample and label. A label whose estimator
            never saw class 1 gets a column of zeros.

        Raises:
            ValueError: If ``classes`` is unavailable and an estimator
                reports a single probability column, because its meaning
                cannot be determined.
        """
        if classes is not None and len(classes) != len(probabilities):
            classes = None

        columns = []
        for index, proba in enumerate(probabilities):
            proba = np.asarray(proba)
            if classes is not None:
                hits = np.flatnonzero(np.asarray(classes[index]) == 1)
                if hits.size:
                    column = proba[:, hits[0]]
                else:
                    column = np.zeros(proba.shape[0])
            elif proba.shape[1] > 1:
                column = proba[:, 1]
            else:
                raise ValueError(
                    f"Output {index} has a single probability column and no class labels were given.")
            columns.append(column)
        # Stack as columns: a reshape of the (n_labels, n_samples) layout
        # would pair scores with the wrong samples.
        return np.column_stack(columns)

    @staticmethod
    def plot_auc_roc(y_true, y_scores, class_names, save_path):
        """Plot one ROC curve per label and save the figure to ``save_path``.

        Column ``i`` of ``y_true`` / ``y_scores`` belongs to
        ``class_names[i]``. Names beyond the available columns are ignored,
        and labels whose ground truth holds a single value are skipped
        because their ROC curve is undefined. The directory of
        ``save_path`` is created if it does not exist.
        """
        plt.figure(figsize=(20, 15))
        column_count = min(y_true.shape[1], y_scores.shape[1])
        for i, class_name in enumerate(class_names[:column_count]):
            if np.unique(y_true[:, i]).size < 2:
                print(f"Skipping ROC curve of class {class_name}: only one label value present.")
                continue
            fpr, tpr, _ = roc_curve(y_true[:, i], y_scores[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, lw=2, label=f'ROC curve of class {class_name} (area = {roc_auc:.2f})')
        # Diagonal reference line of a random classifier.
        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curves')
        plt.legend(loc="lower right")

        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        plt.savefig(save_path)
        plt.close()

    @staticmethod
    def test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture, nnClassCount, nnIsTrained,
             trBatchSize, imgtransResize, imgtransCrop, timestampLaunch):
        """Evaluate the saved network + random forest on the test split.

        Args:
            pathDirData: Image directory passed to ``DatasetGenerator``.
            pathFileTest: Dataset file describing the test split.
            pathModel: Checkpoint file containing a ``'state_dict'`` entry.
            pathRfModel: Random-forest file written with ``joblib.dump``.
            nnArchitecture: One of the names accepted by ``_get_model``.
            nnClassCount: Number of output labels of the network.
            nnIsTrained: Forwarded to the model constructor.
            trBatchSize: Batch size used while extracting features.
            imgtransResize: Size given to ``transforms.Resize``.
            imgtransCrop: Center-crop size of the network input.
            timestampLaunch: Unused; kept for interface compatibility.

        Returns:
            ``(labels, rf_preds)``: ground-truth labels and the random-forest
            predictions, one row per test sample.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = ChexnetTrainer._get_model(nnArchitecture, nnClassCount, nnIsTrained).to(device)
        checkpoint = torch.load(pathModel, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()

        # NOTE(security): joblib.load unpickles arbitrary objects; only load
        # model files from a trusted source.
        rf_classifier = joblib.load(pathRfModel)

        # Deterministic preprocessing keeps test results reproducible.
        transformSequence = ChexnetTrainer._get_transform(imgtransCrop, imgtransResize, augment=False)
        datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
                                       transform=transformSequence, model=model)
        features, labels = ChexnetTrainer.extract_features(datasetTest, model, device,
                                                           batch_size=trBatchSize)

        rf_preds = rf_classifier.predict(features)
        rf_f1 = f1_score(labels, rf_preds, average="macro")
        print(f"Random Forest Multi-label F1 Score on test set: {rf_f1}")

        if hasattr(rf_classifier, "predict_proba"):
            per_label_classes = None
            if isinstance(rf_classifier, MultiOutputClassifier):
                per_label_classes = [estimator.classes_ for estimator in rf_classifier.estimators_]
            rf_scores = ChexnetTrainer._positive_class_scores(
                rf_classifier.predict_proba(features), per_label_classes)
        else:
            print("随机森林模型不支持predict_proba方法,无法计算AUC-ROC。")
            rf_scores = np.zeros_like(labels)

        ChexnetTrainer.plot_auc_roc(labels, rf_scores, ChexnetTrainer.CLASS_NAMES,
                                    os.path.join('images', 'test_auc_roc_curve.png'))
        print("测试集的AUC-ROC曲线已保存到 images/test_auc_roc_curve.png")
        print(f"Testing completed. Random Forest F1 Score: {rf_f1}")
        return labels, rf_preds

    @staticmethod
    def _get_model(nnArchitecture, nnClassCount, nnIsTrained):
        """Build the network named by ``nnArchitecture`` on the default device.

        Raises:
            ValueError: If ``nnArchitecture`` is not a supported name.
        """
        # Validate first so a typo fails before anything is constructed.
        if nnArchitecture not in ChexnetTrainer._SUPPORTED_ARCHITECTURES:
            raise ValueError(
                f"Unknown architecture: {nnArchitecture}. Please choose from 'DENSE-NET-121', 'DENSE-NET-169', 'DENSE-NET-201', 'RESNET-50'.")

        builders = {
            'DENSE-NET-121': DenseNet121,
            'DENSE-NET-169': DenseNet169,
            'DENSE-NET-201': DenseNet201,
            'RESNET-50': ResNet50,
        }
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return builders[nnArchitecture](nnClassCount, nnIsTrained).to(device)

    @staticmethod
    def _get_transform(transCrop, transResize=None, augment=True):
        """Return the image preprocessing pipeline.

        Args:
            transCrop: Output crop size.
            transResize: Size given to ``transforms.Resize`` when
                ``augment`` is false; ``None`` falls back to ``transCrop``.
            augment: When true (the default) use the random crop / flip
                augmentation meant for training; when false use a
                deterministic resize + center crop for evaluation.

        Returns:
            A ``transforms.Compose`` ending in ``ToTensor`` and ``Normalize``.
        """
        # Channel mean / std commonly used for ImageNet-pretrained backbones.
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        if augment:
            spatial_steps = [
                transforms.RandomResizedCrop(transCrop),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            resize_to = transCrop if transResize is None else transResize
            spatial_steps = [
                transforms.Resize(resize_to),
                transforms.CenterCrop(transCrop),
            ]
        return transforms.Compose(spatial_steps + [transforms.ToTensor(), normalize])