123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- import os
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.backends.cudnn as cudnn
- from CBAM import CBAM
- import torchvision.transforms as transforms
- from torch.utils.data import DataLoader
- from sklearn.metrics import roc_auc_score
- import torchvision
- # class FocalLoss(nn.Module):
- # def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):
- # super(FocalLoss, self).__init__()
- # self.alpha = alpha
- # self.gamma = gamma
- # self.logits = logits
- # self.reduce = reduce
- # def forward(self, inputs, targets):
- # if self.logits:
- # BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
- # else:
- # BCE_loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
- # pt = torch.exp(-BCE_loss)
- # F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
- # if self.reduce:
- # return torch.mean(F_loss)
- # else:
- # return F_loss
class FocalLoss(nn.Module):
    """Binary focal loss for (multi-label) classification.

    Computes ``alpha * (1 - p_t) ** gamma * BCE`` element-wise, where ``BCE``
    is the unreduced binary cross-entropy and ``p_t = exp(-BCE)``.

    Args:
        alpha: Weighting factor. ``None`` (default) means ``1.0``. May be a
            Python number, a tensor, or any sequence (e.g. a list of
            per-class weights), which is converted to a float32 tensor.
            Tensor weights get a leading batch dimension (``unsqueeze(0)``)
            so a 1-D per-class weight broadcasts over the batch.
        gamma: Focusing exponent; ``gamma=0`` reduces to (weighted) BCE.
        logits: If True, ``inputs`` are raw logits; otherwise they are
            probabilities passed to ``binary_cross_entropy``.
        reduce: If True, return the mean over all elements; otherwise
            return the unreduced element-wise loss.
    """

    def __init__(self, alpha=None, gamma=2, logits=True, reduce=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = 1.0
        elif isinstance(alpha, (int, float, torch.Tensor)):
            self.alpha = alpha
        else:
            # Lists/tuples/arrays cannot be multiplied with tensors directly.
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        """Return the focal loss for ``inputs`` against same-shaped ``targets``."""
        if self.logits:
            bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                inputs, targets, reduction='none')
        else:
            bce_loss = torch.nn.functional.binary_cross_entropy(
                inputs, targets, reduction='none')
        pt = torch.exp(-bce_loss)
        # Apply the class weight alpha.
        if isinstance(self.alpha, torch.Tensor):
            # Move/cast a local copy rather than reassigning self.alpha on
            # every call; the leading dim lets per-class weights broadcast
            # against (batch_size, num_classes).
            alpha_factor = self.alpha.to(
                device=inputs.device, dtype=bce_loss.dtype).unsqueeze(0)
        else:
            alpha_factor = self.alpha
        f_loss = alpha_factor * (1 - pt) ** self.gamma * bce_loss
        if self.reduce:
            return torch.mean(f_loss)
        return f_loss
class DenseNet121(nn.Module):
    """DenseNet-121 backbone whose classifier is replaced by a linear head.

    The head has no sigmoid, so the model returns raw logits; pair it with a
    loss that expects logits (e.g. ``FocalLoss(logits=True)``).

    Args:
        classCount: Number of output classes.
        isTrained: If truthy, load torchvision's ``IMAGENET1K_V1`` weights
            for the backbone; otherwise use random initialisation.
    """

    def __init__(self, classCount, isTrained):
        super(DenseNet121, self).__init__()
        # ``weights=`` replaces the deprecated ``pretrained=`` flag (same
        # ImageNet weights) and matches the ResNet50 builder in this module.
        self.densenet121 = torchvision.models.densenet121(
            weights="IMAGENET1K_V1" if isTrained else None)
        kernelCount = self.densenet121.classifier.in_features
        # Keep the Sequential wrapper so state_dict keys remain
        # ``densenet121.classifier.0.*`` and existing checkpoints still load.
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(kernelCount, classCount))

    def forward(self, x):
        """Return the raw class scores (logits) produced by the linear head."""
        return self.densenet121(x)
class SELayer(nn.Module):
    """Squeeze-and-Excitation block that rescales each channel of its input.

    A global average pool summarises every channel. A bias-free bottleneck
    MLP (Linear -> ReLU -> Linear -> Sigmoid) maps that summary to a
    per-channel gate in (0, 1), and the input is multiplied by the gate.

    Args:
        channel: Number of input channels C.
        reduction: Bottleneck ratio; the hidden width is ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Linear layers sit at indices 0 and 2; this layout fixes the
        # state_dict keys (fc.0.weight, fc.2.weight).
        layers = [
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Gate a 4-D input (batch, C, H, W) channel-wise; output keeps its shape."""
        batch, channels, _height, _width = x.size()
        summary = self.avg_pool(x).view(batch, channels)
        gate = self.fc(summary).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
- # 修改后的DenseNet121,增加SE模块
- # class DenseNet121(nn.Module):
- # def __init__(self, num_classes, pretrained=True):
- # super(DenseNet121, self).__init__()
- # # 加载预训练的DenseNet121模型
- # self.densenet = torchvision.models.densenet121(weights="IMAGENET1K_V1" if pretrained else None)
- # self.se = SELayer(self.densenet.features[-1].num_features)
- # kernel_count = self.densenet.classifier.in_features
- # self.densenet.classifier = nn.Linear(kernel_count, num_classes)
- #
- # def forward(self, x):
- # # 特征提取
- # x = self.densenet.features(x)
- # # 使用SE模块
- # x = self.se(x)
- # # 全局平均池化以确保形状匹配
- # x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
- # # 展平特征
- # x = torch.flatten(x, 1)
- # # 分类层
- # x = self.densenet.classifier(x)
- # return x
- # class DenseNet121(nn.Module):
- # def __init__(self, num_classes, pretrained=True):
- # super(DenseNet121, self).__init__()
- # self.densenet121 = torchvision.models.densenet121(
- # weights="IMAGENET1K_V1" if pretrained else None)
- # kernel_count = self.densenet121.classifier.in_features
- # self.cbam = CBAM(kernel_count)
- # # 移除 Sigmoid 激活
- # self.densenet121.classifier = nn.Linear(kernel_count, num_classes)
- # def forward(self, x):
- # x = self.densenet121.features(x)
- # x = self.cbam(x)
- # x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
- # x = torch.flatten(x, 1)
- # x = self.densenet121.classifier(x)
- # return x
class DenseNet169(nn.Module):
    """DenseNet-169 backbone with a linear + sigmoid classification head.

    Unlike a logits head, this model returns per-class probabilities in
    (0, 1); pair it with ``BCELoss`` or ``FocalLoss(logits=False)``.

    Args:
        classCount: Number of output classes.
        isTrained: If truthy, load torchvision's ``IMAGENET1K_V1`` weights
            for the backbone; otherwise use random initialisation.
    """

    def __init__(self, classCount, isTrained):
        super(DenseNet169, self).__init__()
        # ``weights=`` replaces the deprecated ``pretrained=`` flag (same
        # ImageNet weights) and matches the ResNet50 builder in this module.
        self.densenet169 = torchvision.models.densenet169(
            weights="IMAGENET1K_V1" if isTrained else None)
        kernelCount = self.densenet169.classifier.in_features
        self.densenet169.classifier = nn.Sequential(
            nn.Linear(kernelCount, classCount), nn.Sigmoid())

    def forward(self, x):
        """Return sigmoid class probabilities from the replaced head."""
        return self.densenet169(x)
class DenseNet201(nn.Module):
    """DenseNet-201 backbone with a linear + sigmoid classification head.

    Unlike a logits head, this model returns per-class probabilities in
    (0, 1); pair it with ``BCELoss`` or ``FocalLoss(logits=False)``.

    Args:
        classCount: Number of output classes.
        isTrained: If truthy, load torchvision's ``IMAGENET1K_V1`` weights
            for the backbone; otherwise use random initialisation.
    """

    def __init__(self, classCount, isTrained):
        super(DenseNet201, self).__init__()
        # ``weights=`` replaces the deprecated ``pretrained=`` flag (same
        # ImageNet weights) and matches the ResNet50 builder in this module.
        self.densenet201 = torchvision.models.densenet201(
            weights="IMAGENET1K_V1" if isTrained else None)
        kernelCount = self.densenet201.classifier.in_features
        self.densenet201.classifier = nn.Sequential(
            nn.Linear(kernelCount, classCount), nn.Sigmoid())

    def forward(self, x):
        """Return sigmoid class probabilities from the replaced head."""
        return self.densenet201(x)
class ResNet50(nn.Module):
    """ResNet-50 backbone with an SE block on the last stage and a linear head.

    The forward pass runs the torchvision stem and the four residual stages
    by hand so an ``SELayer`` can gate the final feature map before global
    average pooling. The head has no sigmoid, so outputs are raw logits.

    Args:
        classCount: Number of output classes.
        isTrained: If truthy (default), load ``IMAGENET1K_V1`` weights for
            the backbone; otherwise use random initialisation.
    """

    def __init__(self, classCount, isTrained=True):
        super(ResNet50, self).__init__()
        weights = "IMAGENET1K_V1" if isTrained else None
        self.resnet50 = torchvision.models.resnet50(weights=weights)
        # The SE block is sized to the channel count of the last bottleneck.
        final_channels = self.resnet50.layer4[-1].conv3.out_channels
        self.se = SELayer(final_channels)
        head_in = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Linear(head_in, classCount)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        net = self.resnet50
        # Feature extraction: stem followed by the four residual stages.
        stages = (
            net.conv1, net.bn1, net.relu, net.maxpool,
            net.layer1, net.layer2, net.layer3, net.layer4,
        )
        features = x
        for stage in stages:
            features = stage(features)
        # Channel re-weighting with the SE block.
        features = self.se(features)
        # Global average pooling (in place of net.avgpool), then flatten.
        pooled = torch.nn.functional.adaptive_avg_pool2d(features, (1, 1))
        flat = torch.flatten(pooled, 1)
        # Classification head.
        return net.fc(flat)
|