12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- # CBAM 模块
- import torch
- import torch.nn as nn
- # 通道注意力模块,用于学习输入特征在通道维度上的重要性
- class ChannelAttention(nn.Module):
- def __init__(self, in_channels, reduction=16):
- super(ChannelAttention, self).__init__()
- # in_channels:输入的通道数,表示特征图有多少个通道
- # reduction:缩减率,用于减少通道的维度,通常设置为16,表示将通道数缩小 16 倍后再扩展回来
- self.avg_pool = nn.AdaptiveAvgPool2d(1) # 自适应平均池化到 (1, 1),用于生成全局通道信息
- self.max_pool = nn.AdaptiveMaxPool2d(1) # 自适应最大池化到 (1, 1),与平均池化结合使用
- # 全连接层使用1x1卷积替代,全连接层的作用是通过线性变换来学习不同通道的重要性
- # Conv2d(in_channels, in_channels // reduction, 1, bias=False):
- # 将输入通道数减小到原来的 1/reduction(即 1/16),用来降低计算复杂度
- self.fc = nn.Sequential(
- nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
- nn.ReLU(inplace=True), # 激活函数,用于增加网络的非线性特性
- nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False) # 将通道数还原为原始大小
- )
- self.sigmoid = nn.Sigmoid() # Sigmoid 激活函数,用于将输出压缩到 (0, 1) 之间
- def forward(self, x):
- # 使用平均池化和最大池化得到两个特征图
- avg_out = self.fc(self.avg_pool(x)) # 对平均池化后的特征图应用全连接层
- max_out = self.fc(self.max_pool(x)) # 对最大池化后的特征图应用全连接层
- out = avg_out + max_out # 将两个特征图相加,融合两种不同池化方式的信息
- return self.sigmoid(out) * x # 使用 sigmoid 将结果压缩到 (0, 1) 之间,并乘以输入特征图得到加权后的输出
- # 空间注意力模块,用于学习输入特征在空间维度(H 和 W)上的重要性
- class SpatialAttention(nn.Module):
- def __init__(self, kernel_size=7):
- super(SpatialAttention, self).__init__()
- # kernel_size:卷积核大小,通常为 3 或 7,用于控制注意力机制的感受野
- assert kernel_size in (3, 7), 'kernel size must be 3 or 7' # 检查 kernel_size 的合法性
- padding = (kernel_size - 1) // 2 # 计算 padding 大小,以保持卷积前后特征图的大小一致
- self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) # 卷积层,输入通道为2,输出通道为1
- self.sigmoid = nn.Sigmoid() # Sigmoid 激活函数,用于将输出压缩到 (0, 1) 之间
- def forward(self, x):
- # 平均池化和最大池化在通道维度上进行,得到两个单通道特征图
- avg_out = torch.mean(x, dim=1, keepdim=True) # 对输入的特征图在通道维度取平均值
- max_out, _ = torch.max(x, dim=1, keepdim=True) # 对输入的特征图在通道维度取最大值
- # 将平均池化和最大池化的结果拼接在一起,得到形状为 (batch_size, 2, H, W) 的张量
- mask = torch.cat([avg_out, max_out], dim=1)
- # 通过卷积层生成空间注意力权重
- mask = self.conv(mask)
- mask = self.sigmoid(mask) # 使用 sigmoid 将结果压缩到 (0, 1) 之间
- return mask * x # 使用注意力权重与输入特征图相乘,得到加权后的输出
- # CBAM 模块,结合通道注意力和空间注意力
- class CBAM(nn.Module):
- def __init__(self, in_channels, reduction=16, kernel_size=7):
- super(CBAM, self).__init__()
- # 通道注意力模块,首先对通道维度进行加权
- self.channel_attention = ChannelAttention(in_channels, reduction)
- # 空间注意力模块,然后对空间维度进行加权
- self.spatial_attention = SpatialAttention(kernel_size)
- def forward(self, x):
- x = self.channel_attention(x) # 先通过通道注意力模块
- x = self.spatial_attention(x) # 再通过空间注意力模块
- return x # 返回经过 CBAM 处理的特征图
|