CBAM.py 4.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # CBAM 模块
  2. import torch
  3. import torch.nn as nn
  4. # 通道注意力模块,用于学习输入特征在通道维度上的重要性
  5. class ChannelAttention(nn.Module):
  6. def __init__(self, in_channels, reduction=16):
  7. super(ChannelAttention, self).__init__()
  8. # in_channels:输入的通道数,表示特征图有多少个通道
  9. # reduction:缩减率,用于减少通道的维度,通常设置为16,表示将通道数缩小 16 倍后再扩展回来
  10. self.avg_pool = nn.AdaptiveAvgPool2d(1) # 自适应平均池化到 (1, 1),用于生成全局通道信息
  11. self.max_pool = nn.AdaptiveMaxPool2d(1) # 自适应最大池化到 (1, 1),与平均池化结合使用
  12. # 全连接层使用1x1卷积替代,全连接层的作用是通过线性变换来学习不同通道的重要性
  13. # Conv2d(in_channels, in_channels // reduction, 1, bias=False):
  14. # 将输入通道数减小到原来的 1/reduction(即 1/16),用来降低计算复杂度
  15. self.fc = nn.Sequential(
  16. nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
  17. nn.ReLU(inplace=True), # 激活函数,用于增加网络的非线性特性
  18. nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False) # 将通道数还原为原始大小
  19. )
  20. self.sigmoid = nn.Sigmoid() # Sigmoid 激活函数,用于将输出压缩到 (0, 1) 之间
  21. def forward(self, x):
  22. # 使用平均池化和最大池化得到两个特征图
  23. avg_out = self.fc(self.avg_pool(x)) # 对平均池化后的特征图应用全连接层
  24. max_out = self.fc(self.max_pool(x)) # 对最大池化后的特征图应用全连接层
  25. out = avg_out + max_out # 将两个特征图相加,融合两种不同池化方式的信息
  26. return self.sigmoid(out) * x # 使用 sigmoid 将结果压缩到 (0, 1) 之间,并乘以输入特征图得到加权后的输出
  27. # 空间注意力模块,用于学习输入特征在空间维度(H 和 W)上的重要性
  28. class SpatialAttention(nn.Module):
  29. def __init__(self, kernel_size=7):
  30. super(SpatialAttention, self).__init__()
  31. # kernel_size:卷积核大小,通常为 3 或 7,用于控制注意力机制的感受野
  32. assert kernel_size in (3, 7), 'kernel size must be 3 or 7' # 检查 kernel_size 的合法性
  33. padding = (kernel_size - 1) // 2 # 计算 padding 大小,以保持卷积前后特征图的大小一致
  34. self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) # 卷积层,输入通道为2,输出通道为1
  35. self.sigmoid = nn.Sigmoid() # Sigmoid 激活函数,用于将输出压缩到 (0, 1) 之间
  36. def forward(self, x):
  37. # 平均池化和最大池化在通道维度上进行,得到两个单通道特征图
  38. avg_out = torch.mean(x, dim=1, keepdim=True) # 对输入的特征图在通道维度取平均值
  39. max_out, _ = torch.max(x, dim=1, keepdim=True) # 对输入的特征图在通道维度取最大值
  40. # 将平均池化和最大池化的结果拼接在一起,得到形状为 (batch_size, 2, H, W) 的张量
  41. mask = torch.cat([avg_out, max_out], dim=1)
  42. # 通过卷积层生成空间注意力权重
  43. mask = self.conv(mask)
  44. mask = self.sigmoid(mask) # 使用 sigmoid 将结果压缩到 (0, 1) 之间
  45. return mask * x # 使用注意力权重与输入特征图相乘,得到加权后的输出
  46. # CBAM 模块,结合通道注意力和空间注意力
  47. class CBAM(nn.Module):
  48. def __init__(self, in_channels, reduction=16, kernel_size=7):
  49. super(CBAM, self).__init__()
  50. # 通道注意力模块,首先对通道维度进行加权
  51. self.channel_attention = ChannelAttention(in_channels, reduction)
  52. # 空间注意力模块,然后对空间维度进行加权
  53. self.spatial_attention = SpatialAttention(kernel_size)
  54. def forward(self, x):
  55. x = self.channel_attention(x) # 先通过通道注意力模块
  56. x = self.spatial_attention(x) # 再通过空间注意力模块
  57. return x # 返回经过 CBAM 处理的特征图