DensenetModels.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import os
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.backends.cudnn as cudnn
  6. from CBAM import CBAM
  7. import torchvision.transforms as transforms
  8. from torch.utils.data import DataLoader
  9. from sklearn.metrics import roc_auc_score
  10. import torchvision
  11. # class FocalLoss(nn.Module):
  12. # def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):
  13. # super(FocalLoss, self).__init__()
  14. # self.alpha = alpha
  15. # self.gamma = gamma
  16. # self.logits = logits
  17. # self.reduce = reduce
  18. # def forward(self, inputs, targets):
  19. # if self.logits:
  20. # BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
  21. # else:
  22. # BCE_loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
  23. # pt = torch.exp(-BCE_loss)
  24. # F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
  25. # if self.reduce:
  26. # return torch.mean(F_loss)
  27. # else:
  28. # return F_loss
  29. class FocalLoss(nn.Module):
  30. def __init__(self, alpha=None, gamma=2, logits=True, reduce=True):
  31. super(FocalLoss, self).__init__()
  32. if alpha is None:
  33. self.alpha = 1.0
  34. else:
  35. self.alpha = alpha
  36. self.gamma = gamma
  37. self.logits = logits
  38. self.reduce = reduce
  39. def forward(self, inputs, targets):
  40. if self.logits:
  41. BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
  42. else:
  43. BCE_loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
  44. pt = torch.exp(-BCE_loss)
  45. # 应用类别权重 alpha
  46. if isinstance(self.alpha, torch.Tensor):
  47. self.alpha = self.alpha.to(inputs.device)
  48. alpha_factor = self.alpha.unsqueeze(0) # 调整形状以匹配输入 (batch_size, num_classes)
  49. else:
  50. alpha_factor = self.alpha
  51. F_loss = alpha_factor * (1 - pt) ** self.gamma * BCE_loss
  52. if self.reduce:
  53. return torch.mean(F_loss)
  54. else:
  55. return F_loss
  56. class DenseNet121(nn.Module):
  57. def __init__(self, classCount, isTrained):
  58. super(DenseNet121, self).__init__()
  59. self.densenet121 = torchvision.models.densenet121(pretrained=isTrained)
  60. kernelCount = self.densenet121.classifier.in_features
  61. self.densenet121.classifier = nn.Sequential(
  62. nn.Linear(kernelCount, classCount))
  63. def forward(self, x):
  64. x = self.densenet121(x)
  65. return x
  66. class SELayer(nn.Module):
  67. def __init__(self, channel, reduction=16):
  68. super(SELayer, self).__init__()
  69. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  70. self.fc = nn.Sequential(
  71. nn.Linear(channel, channel // reduction, bias=False),
  72. nn.ReLU(inplace=True),
  73. nn.Linear(channel // reduction, channel, bias=False),
  74. nn.Sigmoid()
  75. )
  76. def forward(self, x):
  77. b, c, _, _ = x.size()
  78. y = self.avg_pool(x).view(b, c)
  79. y = self.fc(y).view(b, c, 1, 1)
  80. return x * y.expand_as(x)
  81. # 修改后的DenseNet121,增加SE模块
  82. # class DenseNet121(nn.Module):
  83. # def __init__(self, num_classes, pretrained=True):
  84. # super(DenseNet121, self).__init__()
  85. # # 加载预训练的DenseNet121模型
  86. # self.densenet = torchvision.models.densenet121(weights="IMAGENET1K_V1" if pretrained else None)
  87. # self.se = SELayer(self.densenet.features[-1].num_features)
  88. # kernel_count = self.densenet.classifier.in_features
  89. # self.densenet.classifier = nn.Linear(kernel_count, num_classes)
  90. #
  91. # def forward(self, x):
  92. # # 特征提取
  93. # x = self.densenet.features(x)
  94. # # 使用SE模块
  95. # x = self.se(x)
  96. # # 全局平均池化以确保形状匹配
  97. # x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
  98. # # 展平特征
  99. # x = torch.flatten(x, 1)
  100. # # 分类层
  101. # x = self.densenet.classifier(x)
  102. # return x
  103. # class DenseNet121(nn.Module):
  104. # def __init__(self, num_classes, pretrained=True):
  105. # super(DenseNet121, self).__init__()
  106. # self.densenet121 = torchvision.models.densenet121(
  107. # weights="IMAGENET1K_V1" if pretrained else None)
  108. # kernel_count = self.densenet121.classifier.in_features
  109. # self.cbam = CBAM(kernel_count)
  110. # # 移除 Sigmoid 激活
  111. # self.densenet121.classifier = nn.Linear(kernel_count, num_classes)
  112. # def forward(self, x):
  113. # x = self.densenet121.features(x)
  114. # x = self.cbam(x)
  115. # x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
  116. # x = torch.flatten(x, 1)
  117. # x = self.densenet121.classifier(x)
  118. # return x
  119. class DenseNet169(nn.Module):
  120. def __init__(self, classCount, isTrained):
  121. super(DenseNet169, self).__init__()
  122. self.densenet169 = torchvision.models.densenet169(pretrained=isTrained)
  123. kernelCount = self.densenet169.classifier.in_features
  124. self.densenet169.classifier = nn.Sequential(
  125. nn.Linear(kernelCount, classCount), nn.Sigmoid())
  126. def forward(self, x):
  127. x = self.densenet169(x)
  128. return x
  129. class DenseNet201(nn.Module):
  130. def __init__(self, classCount, isTrained):
  131. super(DenseNet201, self).__init__()
  132. self.densenet201 = torchvision.models.densenet201(pretrained=isTrained)
  133. kernelCount = self.densenet201.classifier.in_features
  134. self.densenet201.classifier = nn.Sequential(
  135. nn.Linear(kernelCount, classCount), nn.Sigmoid())
  136. def forward(self, x):
  137. x = self.densenet201(x)
  138. return x
  139. class ResNet50(nn.Module):
  140. def __init__(self, classCount, isTrained=True):
  141. super(ResNet50, self).__init__()
  142. self.resnet50 = torchvision.models.resnet50(weights="IMAGENET1K_V1" if isTrained else None)
  143. self.se = SELayer(self.resnet50.layer4[-1].conv3.out_channels)
  144. kernelCount = self.resnet50.fc.in_features
  145. self.resnet50.fc = nn.Linear(kernelCount, classCount)
  146. def forward(self, x):
  147. # 特征提取
  148. x = self.resnet50.conv1(x)
  149. x = self.resnet50.bn1(x)
  150. x = self.resnet50.relu(x)
  151. x = self.resnet50.maxpool(x)
  152. x = self.resnet50.layer1(x)
  153. x = self.resnet50.layer2(x)
  154. x = self.resnet50.layer3(x)
  155. x = self.resnet50.layer4(x)
  156. # 使用SE模块
  157. x = self.se(x)
  158. # 全局平均池化以确保形状匹配
  159. x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
  160. # 展平特征
  161. x = torch.flatten(x, 1)
  162. # 分类层
  163. x = self.resnet50.fc(x)
  164. return x