import os

import joblib  # used to persist the trained random-forest model
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, auc, f1_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from DatasetGenerator import DatasetGenerator
from DensenetModels import DenseNet121, DenseNet169, DenseNet201, ResNet50, FocalLoss
class ChexnetTrainer:
    """Trains and evaluates a CheXNet-style multi-label chest X-ray classifier,
    plus an auxiliary random-forest classifier fitted on CNN-extracted features."""

    @staticmethod
    def train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture,
              nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize,
              transCrop, launchTimestamp, checkpoint=None):
        """Run the full training loop.

        Args:
            pathDirData: root directory containing the image files.
            pathFileTrain / pathFileVal: dataset list files for train / validation.
            nnArchitecture: one of 'DENSE-NET-121', 'DENSE-NET-169',
                'DENSE-NET-201', 'RESNET-50'.
            nnIsTrained: start from pretrained weights when True.
            nnClassCount: number of output classes (14 here).
            trBatchSize: mini-batch size.
            trMaxEpoch: number of training epochs.
            transResize: unused; kept for interface compatibility.
            transCrop: crop size for the input transform.
            launchTimestamp: suffix for the saved checkpoint file name.
            checkpoint: optional path of a checkpoint to resume from.

        Side effects: saves the best model to 'm-<timestamp>.pth.tar', an F1
        plot, an AUC-ROC plot, and a random-forest model under ./images.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = ChexnetTrainer._get_model(nnArchitecture, nnClassCount, nnIsTrained).to(device)

        # FIX: the `checkpoint` argument was accepted but never used; honor it.
        if checkpoint is not None:
            ckpt = torch.load(checkpoint, map_location=device)
            model.load_state_dict(ckpt['state_dict'])

        # Data pipeline: same transform for both splits (original behavior).
        transformSequence = ChexnetTrainer._get_transform(transCrop)
        datasetTrain = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTrain,
                                        transform=transformSequence, model=model)
        datasetVal = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileVal,
                                      transform=transformSequence, model=model)
        dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True,
                                     num_workers=4, pin_memory=True)
        dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False,
                                   num_workers=4, pin_memory=True)

        optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999),
                               eps=1e-08, weight_decay=1e-5)
        # StepLR: decay LR by 10x every 10 epochs (replaces ReduceLROnPlateau).
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

        # Per-class weights for the focal loss — presumably derived from label
        # frequencies; TODO confirm against the dataset statistics.
        class_weights = torch.tensor(
            [1.1762, 0.6735, 0.9410, 1.6680, 0.9699, 1.1950, 2.2584, 0.6859, 1.6683,
             0.7744, 0.4625, 0.7385, 1.3764, 0.1758], dtype=torch.float32).to(device)
        loss = FocalLoss(alpha=class_weights, gamma=2, logits=True)

        lossMIN = float('inf')  # FIX: true infinity instead of the magic 100000
        train_f1_scores, val_f1_scores = [], []
        all_val_targets, all_val_scores = [], []

        # FIX: every artifact below is written under ./images — ensure it exists.
        os.makedirs('images', exist_ok=True)

        for epochID in range(trMaxEpoch):
            ChexnetTrainer.epochTrain(model, dataLoaderTrain, optimizer, scheduler,
                                      trMaxEpoch, nnClassCount, loss, train_f1_scores, device)
            lossVal, losstensor, val_f1, val_targets, val_scores = ChexnetTrainer.epochVal(
                model, dataLoaderVal, optimizer, scheduler,
                trMaxEpoch, nnClassCount, loss, val_f1_scores, device)
            all_val_targets.append(val_targets)
            all_val_scores.append(val_scores)

            scheduler.step()  # advance the LR schedule once per epoch

            # Keep only the checkpoint with the lowest validation loss so far.
            if lossVal < lossMIN:
                lossMIN = lossVal
                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(),
                            'best_loss': lossMIN, 'optimizer': optimizer.state_dict()},
                           'm-' + launchTimestamp + '.pth.tar')

        # NOTE(review): predictions from *all* epochs are pooled here, so the ROC
        # below mixes outputs of differently-trained model states — confirm
        # whether only the final epoch was intended.
        all_val_targets = np.vstack(all_val_targets)
        all_val_scores = np.vstack(all_val_scores)

        # Per-epoch F1 curves.
        plt.figure()
        plt.plot(train_f1_scores, label="Train F1-Score")
        plt.plot(val_f1_scores, label="Val F1-Score")
        plt.xlabel("Epoch")
        plt.ylabel("F1 Score")
        plt.title("F1 Score per Epoch")
        plt.legend()
        plt.savefig(os.path.join('images', 'f1_scores.png'))
        plt.close()

        CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
                       'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                       'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
                       'Fibrosis', 'Pleural_Thickening',
                       'Hernia']
        ChexnetTrainer.plot_auc_roc(all_val_targets, all_val_scores, CLASS_NAMES,
                                    os.path.join('images', 'auc_roc_curve.png'))
        print("AUC-ROC曲线已保存到 images/auc_roc_curve.png")

        # Fit and persist the auxiliary random-forest classifier.
        rf_classifier = ChexnetTrainer.train_random_forest(datasetTrain, datasetVal, model, device)
        joblib.dump(rf_classifier, os.path.join('images', 'random_forest_model.pkl'))
- @staticmethod
- def epochTrain(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, f1_scores, device):
- model.train()
- all_targets = []
- all_preds = []
- # 创建进度条
- pbar = tqdm(total=len(dataLoader), desc="Training", leave=True)
- for batchID, (input, target) in enumerate(dataLoader):
- target = target.to(device)
- input = input.to(device)
- varInput = torch.autograd.Variable(input)
- varTarget = torch.autograd.Variable(target)
- varOutput = model(varInput)
- lossvalue = loss(varOutput, varTarget)
- optimizer.zero_grad()
- lossvalue.backward()
- optimizer.step()
- pred = torch.sigmoid(varOutput).cpu().data.numpy() > 0.5
- all_targets.extend(target.cpu().numpy())
- all_preds.extend(pred)
- # 更新进度条
- pbar.update(1)
- pbar.close()
- f1 = f1_score(np.array(all_targets), np.array(all_preds), average="macro")
- f1_scores.append(f1)
- # 在每个epoch结束时打印当前的F1-score
- print(f"Epoch completed. F1 Score (Macro): {f1:.4f}")
- @staticmethod
- def epochVal(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss, f1_scores, device):
- model.eval()
- lossVal = 0
- lossValNorm = 0
- losstensorMean = 0
- all_targets = []
- all_preds = []
- all_scores = [] # 收集预测概率
- # 创建进度条
- pbar = tqdm(total=len(dataLoader), desc="Validation", leave=True)
- for i, (input, target) in enumerate(dataLoader):
- target = target.to(device)
- input = input.to(device)
- with torch.no_grad():
- varInput = torch.autograd.Variable(input)
- varTarget = torch.autograd.Variable(target)
- varOutput = model(varInput)
- losstensor = loss(varOutput, varTarget)
- losstensorMean += losstensor
- lossVal += losstensor.item()
- lossValNorm += 1
- scores = torch.sigmoid(varOutput).cpu().data.numpy()
- all_scores.extend(scores)
- pred = scores > 0.5
- all_targets.extend(target.cpu().numpy())
- all_preds.extend(pred)
- # 更新进度条
- pbar.update(1)
- pbar.close()
- f1 = f1_score(np.array(all_targets), np.array(all_preds), average="macro")
- f1_scores.append(f1)
- # 在每个epoch结束时打印当前的F1-score
- print(f"Epoch completed. F1 Score (Macro): {f1:.4f}")
- return lossVal / lossValNorm, losstensorMean / lossValNorm, f1, np.array(all_targets), np.array(all_scores)
- @staticmethod
- def train_random_forest(datasetTrain, datasetVal, model, device):
- """训练随机森林分类器,适应多标签分类"""
- print("正在提取训练数据的特征。")
- # 提取训练数据的特征
- train_features, train_labels = ChexnetTrainer.extract_features(datasetTrain, model, device)
- val_features, val_labels = ChexnetTrainer.extract_features(datasetVal, model, device)
- print("正在训练随机森林分类器。")
- # 使用MultiOutputClassifier来处理多标签问题
- rf_classifier = MultiOutputClassifier(RandomForestClassifier(n_estimators=100))
- rf_classifier.fit(train_features, train_labels)
- print("正在评估随机森林分类器。")
- # 在验证集上评估随机森林
- val_preds = rf_classifier.predict(val_features)
- val_f1 = f1_score(val_labels, val_preds, average='macro')
- print(f"Random Forest F1 Score on validation set: {val_f1}")
- return rf_classifier
- @staticmethod
- def extract_features(dataset, model, device):
- """提取数据集的特征"""
- features = []
- labels = []
- print("正在提取数据集的特征。")
- model.eval()
- dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
- for input, target in tqdm(dataLoader, desc="Extracting Features"):
- input = input.to(device)
- target = target.to(device)
- with torch.no_grad():
- varInput = torch.autograd.Variable(input)
- varOutput = model(varInput)
- # 这里假设 varOutput 是一个二维张量,我们需要把它展平成一维向量
- features.append(varOutput.view(varOutput.size(0), -1).cpu().data.numpy()) # 展平特征
- labels.append(target.cpu().data.numpy()) # 假设 target 是多标签格式的
- # 使用 np.vstack 将特征和标签拼接成数组
- return np.vstack(features), np.vstack(labels) # 确保标签是二维矩阵,每一行对应一个样本的标签
- @staticmethod
- def plot_auc_roc(y_true, y_scores, class_names, save_path):
- """绘制AUC-ROC曲线并保存"""
- plt.figure(figsize=(20, 15))
- for i, class_name in enumerate(class_names):
- fpr, tpr, _ = roc_curve(y_true[:, i], y_scores[:, i])
- roc_auc = auc(fpr, tpr)
- plt.plot(fpr, tpr, lw=2, label=f'ROC curve of class {class_name} (area = {roc_auc:.2f})')
- plt.plot([0, 1], [0, 1], 'k--', lw=2)
- plt.xlim([0.0, 1.0])
- plt.ylim([0.0, 1.05])
- plt.xlabel('False Positive Rate')
- plt.ylabel('True Positive Rate')
- plt.title('Receiver Operating Characteristic (ROC) Curves')
- plt.legend(loc="lower right")
- plt.savefig(save_path)
- plt.close()
- @staticmethod
- def test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture, nnClassCount, nnIsTrained,
- trBatchSize, imgtransResize, imgtransCrop, timestampLaunch):
- """测试函数,支持深度学习模型和随机森林模型"""
- # 加载深度学习模型
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = ChexnetTrainer._get_model(nnArchitecture, nnClassCount, nnIsTrained).to(device)
- checkpoint = torch.load(pathModel, map_location=device)
- model.load_state_dict(checkpoint['state_dict'])
- model.eval() # 确保模型处于评估模式
- # 加载随机森林模型
- rf_classifier = joblib.load(pathRfModel)
- # 获取 transform
- transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
- datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
- transform=transformSequence, model=model)
- # 使用 extract_features 提取特征和标签
- features, labels = ChexnetTrainer.extract_features(datasetTest, model, device)
- # 使用随机森林模型进行预测
- rf_preds = rf_classifier.predict(features)
- # 计算多标签分类的 F1 分数
- rf_f1 = f1_score(labels, rf_preds, average="macro")
- print(f"Random Forest Multi-label F1 Score on test set: {rf_f1}")
- # 计算AUC-ROC
- rf_scores = []
- if hasattr(rf_classifier, "predict_proba"):
- rf_scores = rf_classifier.predict_proba(features)
- # 将列表转换为数组,并选择正类的概率
- rf_scores = np.array([rf_scores[i][:, 1] for i in range(len(rf_scores))]).reshape(labels.shape)
- else:
- # 如果没有predict_proba方法,则无法计算AUC
- print("随机森林模型不支持predict_proba方法,无法计算AUC-ROC。")
- rf_scores = np.zeros_like(labels)
- # 绘制AUC-ROC曲线
- CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
- 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
- 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
- 'Fibrosis', 'Pleural_Thickening',
- 'Hernia']
- ChexnetTrainer.plot_auc_roc(labels, rf_scores, CLASS_NAMES,
- os.path.join('images', 'test_auc_roc_curve.png'))
- print("测试集的AUC-ROC曲线已保存到 images/test_auc_roc_curve.png")
- # 输出深度学习模型和随机森林模型的 F1 分数
- print(f"Testing completed. Random Forest F1 Score: {rf_f1}")
- return labels, rf_preds # 返回真实标签和预测结果
- # 其他方法保持不变...
- @staticmethod
- def _get_model(nnArchitecture, nnClassCount, nnIsTrained):
- """根据选择的模型架构返回对应的模型"""
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # 根据模型架构选择模型
- if nnArchitecture == 'DENSE-NET-121':
- model = DenseNet121(nnClassCount, nnIsTrained).to(device)
- elif nnArchitecture == 'DENSE-NET-169':
- model = DenseNet169(nnClassCount, nnIsTrained).to(device)
- elif nnArchitecture == 'DENSE-NET-201':
- model = DenseNet201(nnClassCount, nnIsTrained).to(device)
- elif nnArchitecture == 'RESNET-50':
- model = ResNet50(nnClassCount, nnIsTrained).to(device)
- else:
- raise ValueError(
- f"Unknown architecture: {nnArchitecture}. Please choose from 'DENSE-NET-121', 'DENSE-NET-169', 'DENSE-NET-201', 'RESNET-50'.")
- return model
- @staticmethod
- def _get_transform(transCrop):
- """返回图像预处理的转换"""
- normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- transformList = [
- transforms.RandomResizedCrop(transCrop),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- normalize
- ]
- return transforms.Compose(transformList)
|