# CBAM-style attention: channel attention + spatial attention
class DualAttention(nn.Module):
    """Dual attention block: channel attention followed by spatial attention.

    The channel branch gates each channel using global average pooling, a
    1x1 bottleneck convolution, ReLU, a 1x1 expansion convolution and a
    sigmoid. The spatial branch concatenates the mean and the max taken over
    dim 1 and feeds the resulting 2-channel map through a 7x7 convolution
    and a sigmoid.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Bottleneck ratio of the channel branch. The hidden width
            is ``in_channels // reduction``, clamped to at least 1.

    Raises:
        ValueError: If ``in_channels`` or ``reduction`` is not positive.
    """

    def __init__(self, in_channels: int, reduction: int = 16):
        super().__init__()
        if in_channels <= 0 or reduction <= 0:
            raise ValueError(
                "in_channels and reduction must be positive, got "
                f"in_channels={in_channels}, reduction={reduction}"
            )

        # Without the clamp, in_channels < reduction yields a zero-width
        # bottleneck and the gate degenerates to sigmoid(bias), i.e. it no
        # longer depends on the input at all.
        hidden_channels = max(in_channels // reduction, 1)

        # Channel attention: squeeze to 1x1, bottleneck, expand, gate.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, 1),
            nn.Sigmoid(),
        )
        # Spatial attention: 2 input maps (mean and max over dim 1) -> 1 gate
        # map; padding=3 keeps the spatial size unchanged for a 7x7 kernel.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply channel attention, then spatial attention, to ``x``.

        Dim 1 is treated as the channel axis, so a batched
        ``(N, in_channels, H, W)`` tensor is expected.

        Returns:
            The re-weighted tensor, with the same shape as ``x``.
        """
        # Channel attention: one gate value per channel, broadcast over the
        # spatial dimensions.
        x = x * self.channel_attention(x)

        # Spatial attention: summarise the channel axis per location with
        # both mean and max, then derive one gate value per location.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        sa = self.spatial_attention(torch.cat([avg_out, max_out], dim=1))
        return x * sa
Comments NOTHING