
Upload files to ''

Gorilla 2 months ago
parent
commit
c6ee4e26d1
5 changed files with 877 additions and 0 deletions
  1. DatasetGenerator.py (+55 −0)
  2. DensenetModels(2)(4).py (+200 −0)
  3. DensenetModels.py (+200 −0)
  4. Main(2)(2).py (+198 −0)
  5. Main.py (+224 −0)

+ 55 - 0
DatasetGenerator.py

@@ -0,0 +1,55 @@
+import os
+import numpy as np
+from PIL import Image
+import torch
+from torch.utils.data import Dataset
+
+
+class DatasetGenerator(Dataset):
+    def __init__(self, pathImageDirectory, pathDatasetFile, transform, model=None, nnClassCount=14):
+        self.listImagePaths = []
+        self.listImageLabels = []
+        self.features = []   # populated externally if feature extraction is used
+        self.labels = []
+        self.transform = transform
+        self.model = model
+
+        # Pick the device and move the optional feature-extraction model onto it
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if self.model:
+            self.model.to(self.device)
+
+        # Read the dataset list file
+        with open(pathDatasetFile, "r") as fileDescriptor:
+            lines = fileDescriptor.readlines()
+
+        # Walk through the lines and keep only entries with a valid image path
+        for line in lines:
+            lineItems = line.split()
+            # Use os.path.join so the path is assembled correctly
+            imagePath = os.path.normpath(os.path.join(pathImageDirectory, lineItems[0]))
+
+            # Skip entries whose path does not exist or is not a regular file
+            if not os.path.isfile(imagePath):
+                print(f"Warning: Path {imagePath} does not exist or is not a file, skipping this file.")
+                continue  # skip missing files
+
+            imageLabel = [int(float(i)) for i in lineItems[1:]]
+            if np.array(imageLabel).sum() >= 1:  # keep only samples with at least one positive label
+                self.listImagePaths.append(imagePath)
+                self.listImageLabels.append(imageLabel)
+
+        # Raise if no valid samples were found
+        if len(self.listImagePaths) == 0:
+            raise ValueError("No valid samples found. Please check your dataset file and image paths.")
+
+    def __getitem__(self, index):
+        imageData = Image.open(self.listImagePaths[index]).convert('RGB')
+        imageData = self.transform(imageData)
+        return imageData, torch.FloatTensor(self.listImageLabels[index])
+
+    def __len__(self):
+        return len(self.listImagePaths)
+
+    def get_features_and_labels(self):
+        # self.features / self.labels are not filled by this class itself; they are
+        # expected to be populated externally (e.g. by a feature-extraction pass with `model`).
+        if not self.features:
+            raise ValueError("No features extracted. Populate features (e.g. via the provided `model`) before calling this.")
+        return self.features, self.labels
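
A minimal usage sketch (not part of this commit): wiring DatasetGenerator into a DataLoader. The data paths below are the ones used elsewhere in this commit but are otherwise placeholders, and the single-crop transform and ImageNet normalization constants are assumptions, not taken from this file.

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from DatasetGenerator import DatasetGenerator

# Placeholder transform: resize, center-crop, tensor conversion, ImageNet normalization.
transformSequence = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
dataset = DatasetGenerator(pathImageDirectory='./chest xray14',
                           pathDatasetFile='./dataset/train_2.txt',
                           transform=transformSequence)
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
images, labels = next(iter(loader))  # images: (B, 3, 224, 224); labels: (B, 14) multi-hot floats
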

+ 200 - 0
DensenetModels(2)(4).py

@@ -0,0 +1,200 @@
+import os
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from CBAM import CBAM
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+from sklearn.metrics import roc_auc_score
+
+import torchvision
+
+# class FocalLoss(nn.Module):
+#     def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):
+#         super(FocalLoss, self).__init__()
+#         self.alpha = alpha
+#         self.gamma = gamma
+#         self.logits = logits
+#         self.reduce = reduce
+
+#     def forward(self, inputs, targets):
+#         if self.logits:
+#             BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+#         else:
+#             BCE_loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
+#         pt = torch.exp(-BCE_loss)
+#         F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
+#         if self.reduce:
+#             return torch.mean(F_loss)
+#         else:
+#             return F_loss
+class FocalLoss(nn.Module):
+    def __init__(self, alpha=None, gamma=2, logits=True, reduce=True):
+        super(FocalLoss, self).__init__()
+        if alpha is None:
+            self.alpha = 1.0
+        else:
+            self.alpha = alpha
+        self.gamma = gamma
+        self.logits = logits
+        self.reduce = reduce
+
+    def forward(self, inputs, targets):
+        if self.logits:
+            BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+        else:
+            BCE_loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
+
+        pt = torch.exp(-BCE_loss)
+
+        # Apply the per-class weight alpha
+        if isinstance(self.alpha, torch.Tensor):
+            self.alpha = self.alpha.to(inputs.device)
+            alpha_factor = self.alpha.unsqueeze(0)  # reshape to broadcast over (batch_size, num_classes)
+        else:
+            alpha_factor = self.alpha
+
+        F_loss = alpha_factor * (1 - pt) ** self.gamma * BCE_loss
+
+        if self.reduce:
+            return torch.mean(F_loss)
+        else:
+            return F_loss
+# class DenseNet121(nn.Module):
+#
+#     def __init__(self, classCount, isTrained):
+#         super(DenseNet121, self).__init__()
+#
+#         self.densenet121 = torchvision.models.densenet121(pretrained=isTrained)
+#
+#         kernelCount = self.densenet121.classifier.in_features
+#
+#         self.densenet121.classifier = nn.Sequential(
+#             nn.Linear(kernelCount, classCount))
+#
+#     def forward(self, x):
+#         x = self.densenet121(x)
+#         return x
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+# Modified DenseNet121 with an SE block added after the feature extractor
+class DenseNet121(nn.Module):
+    def __init__(self, num_classes, pretrained=True):
+        super(DenseNet121, self).__init__()
+        # Load a DenseNet121 backbone, optionally with ImageNet-pretrained weights
+        self.densenet = torchvision.models.densenet121(weights="IMAGENET1K_V1" if pretrained else None)
+        self.se = SELayer(self.densenet.features[-1].num_features)
+        kernel_count = self.densenet.classifier.in_features
+        self.densenet.classifier = nn.Linear(kernel_count, num_classes)
+
+    def forward(self, x):
+        # Feature extraction
+        x = self.densenet.features(x)
+        # Apply the SE block
+        x = self.se(x)
+        # Global average pooling so the shape matches the classifier
+        x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
+        # Flatten the features
+        x = torch.flatten(x, 1)
+        # Classification layer (raw logits, no sigmoid)
+        x = self.densenet.classifier(x)
+        return x
+
+# class DenseNet121(nn.Module):
+#     def __init__(self, num_classes, pretrained=True):
+#         super(DenseNet121, self).__init__()
+#         self.densenet121 = torchvision.models.densenet121(
+#             weights="IMAGENET1K_V1" if pretrained else None)
+#         kernel_count = self.densenet121.classifier.in_features
+
+#         self.cbam = CBAM(kernel_count)
+
+#         # Remove the Sigmoid activation (return raw logits)
+#         self.densenet121.classifier = nn.Linear(kernel_count, num_classes)
+
+#     def forward(self, x):
+#         x = self.densenet121.features(x)
+#         x = self.cbam(x)
+#         x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
+#         x = torch.flatten(x, 1)
+#         x = self.densenet121.classifier(x)
+#         return x
+
+class DenseNet169(nn.Module):
+
+    def __init__(self, classCount, isTrained):
+        super(DenseNet169, self).__init__()
+
+        self.densenet169 = torchvision.models.densenet169(weights="IMAGENET1K_V1" if isTrained else None)
+
+        kernelCount = self.densenet169.classifier.in_features
+
+        self.densenet169.classifier = nn.Sequential(
+            nn.Linear(kernelCount, classCount), nn.Sigmoid())
+
+    def forward(self, x):
+        x = self.densenet169(x)
+        return x
+
+
+class DenseNet201(nn.Module):
+
+    def __init__(self, classCount, isTrained):
+        super(DenseNet201, self).__init__()
+
+        self.densenet201 = torchvision.models.densenet201(weights="IMAGENET1K_V1" if isTrained else None)
+
+        kernelCount = self.densenet201.classifier.in_features
+
+        self.densenet201.classifier = nn.Sequential(
+            nn.Linear(kernelCount, classCount), nn.Sigmoid())
+
+    def forward(self, x):
+        x = self.densenet201(x)
+        return x
+
+class ResNet50(nn.Module):
+    def __init__(self, classCount, isTrained=True):
+        super(ResNet50, self).__init__()
+        self.resnet50 = torchvision.models.resnet50(weights="IMAGENET1K_V1" if isTrained else None)
+        self.se = SELayer(self.resnet50.layer4[-1].conv3.out_channels)
+        kernelCount = self.resnet50.fc.in_features
+        self.resnet50.fc = nn.Linear(kernelCount, classCount)
+
+    def forward(self, x):
+        # Feature extraction through the ResNet stem and stages
+        x = self.resnet50.conv1(x)
+        x = self.resnet50.bn1(x)
+        x = self.resnet50.relu(x)
+        x = self.resnet50.maxpool(x)
+
+        x = self.resnet50.layer1(x)
+        x = self.resnet50.layer2(x)
+        x = self.resnet50.layer3(x)
+        x = self.resnet50.layer4(x)
+        # Apply the SE block
+        x = self.se(x)
+        # Global average pooling so the shape matches the classifier
+        x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
+        # Flatten the features
+        x = torch.flatten(x, 1)
+        # Classification layer (raw logits, no sigmoid)
+        x = self.resnet50.fc(x)
+        return x
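
A training-step sketch (not part of the commit), assuming the FocalLoss and the SE-augmented DenseNet121 defined above are in scope (the parenthesised filename would need renaming before it can be imported as a module). The per-class alpha weights and the dummy batch are illustrative placeholders.

import torch

model = DenseNet121(num_classes=14, pretrained=False)      # pretrained=False here just avoids a weight download
criterion = FocalLoss(alpha=torch.full((14,), 0.25), gamma=2, logits=True)

images = torch.randn(4, 3, 224, 224)                        # dummy batch
targets = torch.randint(0, 2, (4, 14)).float()              # dummy multi-hot labels
logits = model(images)                                      # raw logits; the model applies no sigmoid
loss = criterion(logits, targets)                           # focal loss computed on logits
loss.backward()
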

+ 200 - 0
DensenetModels.py

@@ -0,0 +1,200 @@
+import os
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from CBAM import CBAM
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+from sklearn.metrics import roc_auc_score
+
+import torchvision
+
+# class FocalLoss(nn.Module):
+#     def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):
+#         super(FocalLoss, self).__init__()
+#         self.alpha = alpha
+#         self.gamma = gamma
+#         self.logits = logits
+#         self.reduce = reduce
+
+#     def forward(self, inputs, targets):
+#         if self.logits:
+#             BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+#         else:
+#             BCE_loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
+#         pt = torch.exp(-BCE_loss)
+#         F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
+#         if self.reduce:
+#             return torch.mean(F_loss)
+#         else:
+#             return F_loss
+class FocalLoss(nn.Module):
+    def __init__(self, alpha=None, gamma=2, logits=True, reduce=True):
+        super(FocalLoss, self).__init__()
+        if alpha is None:
+            self.alpha = 1.0
+        else:
+            self.alpha = alpha
+        self.gamma = gamma
+        self.logits = logits
+        self.reduce = reduce
+
+    def forward(self, inputs, targets):
+        if self.logits:
+            BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+        else:
+            BCE_loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
+
+        pt = torch.exp(-BCE_loss)
+
+        # Apply the per-class weight alpha
+        if isinstance(self.alpha, torch.Tensor):
+            self.alpha = self.alpha.to(inputs.device)
+            alpha_factor = self.alpha.unsqueeze(0)  # reshape to broadcast over (batch_size, num_classes)
+        else:
+            alpha_factor = self.alpha
+
+        F_loss = alpha_factor * (1 - pt) ** self.gamma * BCE_loss
+
+        if self.reduce:
+            return torch.mean(F_loss)
+        else:
+            return F_loss
+class DenseNet121(nn.Module):
+
+    def __init__(self, classCount, isTrained):
+        super(DenseNet121, self).__init__()
+
+        self.densenet121 = torchvision.models.densenet121(weights="IMAGENET1K_V1" if isTrained else None)
+
+        kernelCount = self.densenet121.classifier.in_features
+
+        self.densenet121.classifier = nn.Sequential(
+            nn.Linear(kernelCount, classCount))
+
+    def forward(self, x):
+        x = self.densenet121(x)
+        return x
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+# Modified DenseNet121 with an SE block (kept here commented out)
+# class DenseNet121(nn.Module):
+#     def __init__(self, num_classes, pretrained=True):
+#         super(DenseNet121, self).__init__()
+#         # Load a DenseNet121 backbone, optionally with ImageNet-pretrained weights
+#         self.densenet = torchvision.models.densenet121(weights="IMAGENET1K_V1" if pretrained else None)
+#         self.se = SELayer(self.densenet.features[-1].num_features)
+#         kernel_count = self.densenet.classifier.in_features
+#         self.densenet.classifier = nn.Linear(kernel_count, num_classes)
+#
+#     def forward(self, x):
+#         # Feature extraction
+#         x = self.densenet.features(x)
+#         # Apply the SE block
+#         x = self.se(x)
+#         # Global average pooling so the shape matches the classifier
+#         x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
+#         # Flatten the features
+#         x = torch.flatten(x, 1)
+#         # Classification layer
+#         x = self.densenet.classifier(x)
+#         return x
+
+# class DenseNet121(nn.Module):
+#     def __init__(self, num_classes, pretrained=True):
+#         super(DenseNet121, self).__init__()
+#         self.densenet121 = torchvision.models.densenet121(
+#             weights="IMAGENET1K_V1" if pretrained else None)
+#         kernel_count = self.densenet121.classifier.in_features
+
+#         self.cbam = CBAM(kernel_count)
+
+#         # Remove the Sigmoid activation (return raw logits)
+#         self.densenet121.classifier = nn.Linear(kernel_count, num_classes)
+
+#     def forward(self, x):
+#         x = self.densenet121.features(x)
+#         x = self.cbam(x)
+#         x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
+#         x = torch.flatten(x, 1)
+#         x = self.densenet121.classifier(x)
+#         return x
+
+class DenseNet169(nn.Module):
+
+    def __init__(self, classCount, isTrained):
+        super(DenseNet169, self).__init__()
+
+        self.densenet169 = torchvision.models.densenet169(weights="IMAGENET1K_V1" if isTrained else None)
+
+        kernelCount = self.densenet169.classifier.in_features
+
+        self.densenet169.classifier = nn.Sequential(
+            nn.Linear(kernelCount, classCount), nn.Sigmoid())
+
+    def forward(self, x):
+        x = self.densenet169(x)
+        return x
+
+
+class DenseNet201(nn.Module):
+
+    def __init__(self, classCount, isTrained):
+        super(DenseNet201, self).__init__()
+
+        self.densenet201 = torchvision.models.densenet201(weights="IMAGENET1K_V1" if isTrained else None)
+
+        kernelCount = self.densenet201.classifier.in_features
+
+        self.densenet201.classifier = nn.Sequential(
+            nn.Linear(kernelCount, classCount), nn.Sigmoid())
+
+    def forward(self, x):
+        x = self.densenet201(x)
+        return x
+
+class ResNet50(nn.Module):
+    def __init__(self, classCount, isTrained=True):
+        super(ResNet50, self).__init__()
+        self.resnet50 = torchvision.models.resnet50(weights="IMAGENET1K_V1" if isTrained else None)
+        self.se = SELayer(self.resnet50.layer4[-1].conv3.out_channels)
+        kernelCount = self.resnet50.fc.in_features
+        self.resnet50.fc = nn.Linear(kernelCount, classCount)
+
+    def forward(self, x):
+        # Feature extraction through the ResNet stem and stages
+        x = self.resnet50.conv1(x)
+        x = self.resnet50.bn1(x)
+        x = self.resnet50.relu(x)
+        x = self.resnet50.maxpool(x)
+
+        x = self.resnet50.layer1(x)
+        x = self.resnet50.layer2(x)
+        x = self.resnet50.layer3(x)
+        x = self.resnet50.layer4(x)
+        # Apply the SE block
+        x = self.se(x)
+        # Global average pooling so the shape matches the classifier
+        x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
+        # Flatten the features
+        x = torch.flatten(x, 1)
+        # Classification layer (raw logits, no sigmoid)
+        x = self.resnet50.fc(x)
+        return x
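
A quick shape smoke test (not part of the commit) for the SELayer and the SE-augmented ResNet50 defined above; isTrained=False is used here only to skip the weight download, and the input sizes are arbitrary.

import torch

se = SELayer(channel=2048, reduction=16)
feat = torch.randn(2, 2048, 7, 7)
assert se(feat).shape == feat.shape          # SE only re-weights channels; spatial size is unchanged

net = ResNet50(classCount=14, isTrained=False)
out = net(torch.randn(2, 3, 224, 224))
assert out.shape == (2, 14)                  # raw logits; apply torch.sigmoid for probabilities
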

+ 198 - 0
Main(2)(2).py

@@ -0,0 +1,198 @@
+import os
+import numpy as np
+import time
+import sys
+
+from ChexnetTrainer import ChexnetTrainer  # import the trainer class from the ChexnetTrainer module
+
+
+# --------------------------------------------------------------------------------
+
+# Main entry point: dispatches to training, testing, or the demo
+def main():
+    # runDemo()  # demo mode (uncomment to run)
+    # runTest()  # test mode (uncomment to run)
+    runTrain()  # training mode
+
+
+# --------------------------------------------------------------------------------
+
+# Training entry point: defines the training parameters and starts model training
+def runTrain():
+    DENSENET121 = 'DENSE-NET-121'  # DenseNet121 architecture name
+    DENSENET169 = 'DENSE-NET-169'  # DenseNet169 architecture name
+    DENSENET201 = 'DENSE-NET-201'  # DenseNet201 architecture name
+    Resnet50 = 'RESNET-50'  # ResNet50 architecture name
+
+    # Current timestamp, used to tag this training run
+    timestampTime = time.strftime("%H%M%S")
+    timestampDate = time.strftime("%d%m%Y")
+    timestampLaunch = timestampDate + '-' + timestampTime
+    print("Launching " + timestampLaunch)
+
+    # Directory that holds the image data
+    pathDirData = 'chest xray14'
+
+    # Train/validation/test list files;
+    # each line contains an image path followed by its labels
+    pathFileTrain = './dataset/train_2.txt'
+    pathFileVal = './dataset/valid_2.txt'
+    pathFileTest = './dataset/test_2.txt'
+
+    # Network parameters: architecture, whether to load pretrained weights, number of classes
+    nnArchitecture = DENSENET121
+    nnIsTrained = True  # use pretrained weights
+    nnClassCount = 14  # the dataset has 14 classes
+
+    # Training parameters: batch size and maximum number of epochs
+    trBatchSize = 16
+    trMaxEpoch = 48
+
+    # Image preprocessing: resize target and crop size
+    imgtransResize = 256
+    imgtransCrop = 224
+
+    # Path of the saved model, tagged with the timestamp
+
+    pathModel = 'm-' + timestampLaunch + '.pth.tar'
+
+    print('Training NN architecture = ', nnArchitecture)
+    ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal,
+                         nnArchitecture, nnIsTrained, nnClassCount, trBatchSize,
+                         trMaxEpoch, imgtransResize, imgtransCrop,
+                         timestampLaunch, None)
+    print('Testing the trained model')
+    ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, nnArchitecture,
+                        nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
+                        imgtransCrop, timestampLaunch)
+
+# --------------------------------------------------------------------------------
+
+# Test entry point: loads a trained model and evaluates it on the test set
+def runTest():
+    pathDirData = '/chest xray14'  # data directory
+    pathFileTest = './dataset/test.txt'  # test list file
+    nnArchitecture = 'DENSE-NET-121'  # use the DenseNet121 architecture
+    nnIsTrained = True  # use pretrained weights
+    nnClassCount = 14  # number of classes
+    trBatchSize = 4  # batch size
+    imgtransResize = 256  # resize target
+    imgtransCrop = 224  # crop size
+
+    # Path of the trained model checkpoint
+    pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'
+
+    timestampLaunch = ''  # timestamp
+
+    # Run the test with the parameters above
+    ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, nnArchitecture,
+                        nnClassCount, nnIsTrained,
+                        trBatchSize, imgtransResize, imgtransCrop,
+                        timestampLaunch)
+
+
+# --------------------------------------------------------------------------------
+
+# Demo: runs the model on the test set and prints its predictions
+def runDemo():
+    pathDirData = '/media/sunjc0306/未命名/CODE/chest xray14'  # data directory
+    pathFileTest = './dataset/test.txt'  # test list file
+    nnArchitecture = 'DENSE-NET-121'  # use the DenseNet121 architecture
+    nnIsTrained = True  # use pretrained weights
+    nnClassCount = 14  # number of classes
+    trBatchSize = 4  # batch size
+    imgtransResize = 256  # resize target
+    imgtransCrop = 224  # crop size
+
+    pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'  # trained model checkpoint
+
+    # Class names
+    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
+                   'Mass', 'Nodule', 'Pneumonia',
+                   'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
+                   'Fibrosis', 'Pleural_Thickening', 'Hernia']
+
+    import torch
+    import torch.backends.cudnn as cudnn
+    import torchvision.transforms as transforms
+    from torch.utils.data import DataLoader
+    from tqdm import tqdm  # progress bar
+    from DensenetModels import DenseNet121  # model definitions
+    from DensenetModels import DenseNet169
+    from DensenetModels import DenseNet201
+    from DatasetGenerator import DatasetGenerator  # dataset generator
+    cudnn.benchmark = True  # enable cuDNN autotuning to speed up inference
+
+    # Build the chosen architecture and move it to the GPU
+    if nnArchitecture == 'DENSE-NET-121':
+        model = DenseNet121(nnClassCount, nnIsTrained).cuda()
+    elif nnArchitecture == 'DENSE-NET-169':
+        model = DenseNet169(nnClassCount, nnIsTrained).cuda()
+    elif nnArchitecture == 'DENSE-NET-201':
+        model = DenseNet201(nnClassCount, nnIsTrained).cuda()
+
+    model = model.cuda()
+
+    # Load the trained checkpoint weights
+    modelCheckpoint = torch.load(pathModel)
+    model.load_state_dict(modelCheckpoint['state_dict'])
+
+    # Image transforms: resize, ten-crop, convert to tensors, then normalize
+    normalize = transforms.Normalize([0.485, 0.456, 0.406],
+                                     [0.229, 0.224, 0.225])
+
+    # Dataset transforms
+    transformList = []
+    transformList.append(transforms.Resize(imgtransResize))
+    transformList.append(transforms.TenCrop(imgtransCrop))  # ten-crop the image
+    transformList.append(transforms.Lambda(lambda crops: torch.stack(
+        [transforms.ToTensor()(crop) for crop in crops])))  # convert each crop to a tensor
+    transformList.append(transforms.Lambda(
+        lambda crops: torch.stack([normalize(crop) for crop in crops])))  # normalize each crop
+    transformSequence = transforms.Compose(transformList)
+
+    # Build the test dataset generator
+    datasetTest = DatasetGenerator(pathImageDirectory=pathDirData,
+                                   pathDatasetFile=pathFileTest,
+                                   transform=transformSequence)
+
+    model.eval()  # evaluation mode
+    i = 0  # initial index
+
+    # Run inference on a single sample and print the result
+    def demo(i):
+        (input, target) = datasetTest[i]  # fetch the input crops and target labels
+        n_crops, c, h, w = input.size()  # crop-batch dimensions
+
+        # Feed the crops to the model without building a computation graph;
+        # the deprecated Variable(..., volatile=True) wrapper is unnecessary with torch.no_grad()
+        varInput = input.view(-1, c, h, w).cuda()
+        with torch.no_grad():
+            out = model(varInput)
+            outMean = out.view(1, n_crops, -1).mean(1)  # average the outputs over the ten crops
+        print('-------------------------------')
+        pd = torch.sigmoid(outMean).ge_(0.5).long().detach().cpu().numpy()[
+            0].tolist()  # thresholded predictions
+        gt = target.long().detach().cpu().numpy().tolist()  # ground-truth labels
+        print(f"PD_{i}", pd)  # print the predictions
+        print(f"GT_{i}", gt)  # print the ground truth
+        return pd, gt
+
+    # Walk through the test set sample by sample
+    for i in range(len(datasetTest)):
+        pd, gt = demo(i)  # run inference on this sample
+        if type(gt) == list:
+            pd_name = []
+            gt_name = []
+            # Use a separate index so the outer loop variable i is not clobbered
+            for j in range(len(CLASS_NAMES)):
+                if pd[j] == 1:
+                    pd_name.append(CLASS_NAMES[j])  # names of predicted-positive classes
+                if gt[j] == 1:
+                    gt_name.append(CLASS_NAMES[j])  # names of ground-truth-positive classes
+            print(datasetTest.listImagePaths[i])  # image path
+            print(pd_name)  # predicted disease names
+            print(gt_name)  # ground-truth disease names
+
+
+if __name__ == '__main__':
+    main()  # run the main entry point
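
An evaluation sketch (not part of the commit): roc_auc_score is imported in DensenetModels.py but unused there, and one plausible use is computing per-class AUROC from accumulated ground-truth labels and sigmoid scores. The arrays below are random stand-ins, and CLASS_NAMES refers to the list defined in runDemo() above.

import numpy as np
from sklearn.metrics import roc_auc_score

gt = np.random.randint(0, 2, size=(100, 14))     # stand-in for collected ground-truth labels
probs = np.random.rand(100, 14)                  # stand-in for collected sigmoid(outMean) scores

for c, name in enumerate(CLASS_NAMES):           # CLASS_NAMES as defined in runDemo()
    if gt[:, c].min() == gt[:, c].max():         # AUROC is undefined when only one class is present
        continue
    print(f"{name}: AUROC = {roc_auc_score(gt[:, c], probs[:, c]):.3f}")
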

+ 224 - 0
Main.py

@@ -0,0 +1,224 @@
+
+import os
+import time
+import torch.multiprocessing as mp
+import matplotlib.pyplot as plt
+import numpy as np
+from ChexnetTrainer import ChexnetTrainer
+from DatasetGenerator import DatasetGenerator
+from visual import HeatmapGenerator
+
+# Make sure the images directory exists
+if not os.path.exists('images'):
+    os.makedirs('images')
+
+
+# Main entry point: dispatches to training, testing, or the demo
+def main():
+    # runDemo()  # demo mode (uncomment to run)
+    # runTest()  # test mode (uncomment to run)
+    runTrain()  # training mode
+
+
+# --------------------------------------------------------------------------------
+
+# Training entry point: defines the training parameters and starts model training
+def runTrain():
+    DENSENET121 = 'DENSE-NET-121'  # DenseNet121 architecture name
+    DENSENET169 = 'DENSE-NET-169'  # DenseNet169 architecture name
+    DENSENET201 = 'DENSE-NET-201'  # DenseNet201 architecture name
+    Resnet50 = 'RESNET-50'  # ResNet50 architecture name
+
+    # Current timestamp, used to tag this training run
+    timestampTime = time.strftime("%H%M%S")
+    timestampDate = time.strftime("%d%m%Y")
+    timestampLaunch = timestampDate + '-' + timestampTime
+    print("Launching " + timestampLaunch)
+
+    # Directory that holds the image data
+    pathDirData = './chest xray14'
+
+    # Train/validation/test list files;
+    # each line contains an image path followed by its labels
+    pathFileTrain = './dataset/train_2.txt'
+    pathFileVal = './dataset/valid_2.txt'
+    pathFileTest = './dataset/test_2.txt'
+
+    # Network parameters: architecture, whether to load pretrained weights, number of classes
+    nnArchitecture = DENSENET121
+    nnIsTrained = True  # use pretrained weights
+    nnClassCount = 14  # the dataset has 14 classes
+
+    # Training parameters: batch size and maximum number of epochs
+    trBatchSize = 2
+    trMaxEpoch = 2
+
+    # Image preprocessing: resize target and crop size
+    imgtransResize = 256
+    imgtransCrop = 224
+
+    # Path of the saved model, tagged with the timestamp
+    pathModel = 'm-' + timestampLaunch + '.pth.tar'
+
+    print('Training NN architecture = ', nnArchitecture)
+    ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal,
+                         nnArchitecture, nnIsTrained, nnClassCount, trBatchSize,
+                         trMaxEpoch, imgtransResize, imgtransCrop,
+                         timestampLaunch, None)
+    print('Testing the trained model')
+
+    pathRfModel = os.path.join('images', 'random_forest_model.pkl')  # path of the saved random-forest model
+    labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture,
+                                           nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
+                                           imgtransCrop, timestampLaunch)
+
+    # Generate and save heatmaps
+    # Pick a few samples from the test set for visualization
+    transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
+    datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
+                                   transform=transformSequence, model=None)
+
+    # Do not request more samples than the test set holds
+    num_samples = min(8, len(datasetTest))
+    sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False)  # randomly pick up to 8 samples
+
+    # Class names
+    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
+                   'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
+                   'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
+                   'Fibrosis', 'Pleural_Thickening',
+                   'Hernia']
+
+    # Create a figure
+    plt.figure(figsize=(20, 10))
+    n_cols = 4
+    n_rows = 2
+
+    for idx, sample_idx in enumerate(sample_indices):
+        image_path = datasetTest.listImagePaths[sample_idx]
+        true_labels = labels[sample_idx]
+        pred_labels = rf_preds[sample_idx]
+
+        # Load the image
+        img = plt.imread(image_path)
+
+        # Create a subplot
+        ax = plt.subplot(n_rows, n_cols, idx + 1)
+        ax.imshow(img)
+        ax.axis('off')
+
+        # Collect the names of the true and predicted labels
+        true_label_names = [CLASS_NAMES[i] for i in range(len(true_labels)) if true_labels[i] == 1]
+        pred_label_names = [CLASS_NAMES[i] for i in range(len(pred_labels)) if pred_labels[i] == 1]
+
+        # Set the title
+        title = f"Predicted: {', '.join(pred_label_names)}\nTrue: {', '.join(true_label_names)}"
+        ax.set_title(title, fontsize=10)
+
+    plt.tight_layout()
+    output_plot_path = os.path.join('images', 'test_predictions.png')
+    plt.savefig(output_plot_path)
+    plt.show()
+    print(f"预测结果图已保存到 {output_plot_path}")
+
+    # Generate heatmaps (optional)
+    for idx in sample_indices:
+        image_path = datasetTest.listImagePaths[idx]
+        output_heatmap_path = os.path.join('images', f'heatmap_test_{idx}.png')
+        h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence)
+        h.generate(image_path, output_heatmap_path, imgtransCrop)
+        print(f"热力图已保存到 {output_heatmap_path}")
+
+
+# --------------------------------------------------------------------------------
+
+# Test entry point: loads a trained model and evaluates it on the test set
+def runTest():
+    pathDirData = '/chest xray14'
+    pathFileTest = './dataset/test.txt'
+    nnArchitecture = 'DENSE-NET-121'
+    nnIsTrained = True
+    nnClassCount = 14
+    trBatchSize = 4
+    imgtransResize = 256
+    imgtransCrop = 224
+
+    pathModel = 'm-06102024-235412BCELoss()delete.pth.tar'
+
+    timestampLaunch = ''
+
+    # Get the shared transformSequence
+    transformSequence = ChexnetTrainer._get_transform(imgtransCrop)
+    pathRfModel = 'images/random_forest_model.pkl'  # make sure this path is correct
+    labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture,
+                                           nnClassCount, nnIsTrained, trBatchSize, imgtransResize,
+                                           imgtransCrop, timestampLaunch)
+
+    # Generate and save heatmaps
+    datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
+                                   transform=transformSequence, model=None)
+
+    # Do not request more samples than the test set holds
+    num_samples = min(8, len(datasetTest))
+    sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False)  # randomly pick up to 8 samples
+
+    # Class names
+    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion',
+                   'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
+                   'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
+                   'Fibrosis', 'Pleural_Thickening',
+                   'Hernia']
+
+    # Create a figure
+    plt.figure(figsize=(20, 10))
+    n_cols = 4
+    n_rows = 2
+
+    for idx, sample_idx in enumerate(sample_indices):
+        image_path = datasetTest.listImagePaths[sample_idx]
+        true_labels = labels[sample_idx]
+        pred_labels = rf_preds[sample_idx]
+
+        # Load the image
+        img = plt.imread(image_path)
+
+        # Create a subplot
+        ax = plt.subplot(n_rows, n_cols, idx + 1)
+        ax.imshow(img)
+        ax.axis('off')
+
+        # Collect the names of the true and predicted labels
+        true_label_names = [CLASS_NAMES[i] for i in range(len(true_labels)) if true_labels[i] == 1]
+        pred_label_names = [CLASS_NAMES[i] for i in range(len(pred_labels)) if pred_labels[i] == 1]
+
+        # Set the title
+        title = f"Predicted: {', '.join(pred_label_names)}\nTrue: {', '.join(true_label_names)}"
+        ax.set_title(title, fontsize=10)
+
+    plt.tight_layout()
+    output_plot_path = os.path.join('images', 'test_predictions.png')
+    plt.savefig(output_plot_path)
+    plt.show()
+    print(f"预测结果图已保存到 {output_plot_path}")
+
+    # Generate heatmaps (optional)
+    for idx in sample_indices:
+        image_path = datasetTest.listImagePaths[idx]
+        output_heatmap_path = os.path.join('images', f'heatmap_test_{idx}.png')
+        h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence)
+        h.generate(image_path, output_heatmap_path, imgtransCrop)
+        print(f"热力图已保存到 {output_heatmap_path}")
+
+
+# --------------------------------------------------------------------------------
+
+# Demo: runs the model on the test set and prints its predictions
+def runDemo():
+    # original code kept unchanged
+    pass
+
+
+# Make sure the code runs in the main process
+if __name__ == '__main__':
+    mp.set_start_method('spawn', force=True)
+    main()  # run the main entry point