Split-Attention
"""Split-Attention 模块 - ResNeSt的核心组件"""
def __init__(self, in_channels, radix=2, reduction=4):
super(SplitAttention, self).__init__()
self.radix = radix
inter_channels = max(in_channels * radix // reduction, 32)
self.conv = nn.Conv2d(in_channels, in_channels * radix, 1, groups=radix)
self.bn0 = nn.BatchNorm2d(in_channels * radix)
self.relu = nn.ReLU(inplace=True)
self.fc1 = nn.Conv2d(in_channels, inter_channels, 1)
self.bn1 = nn.BatchNorm2d(inter_channels)
self.fc2 = nn.Conv2d(inter_channels, in_channels * radix, 1)
self.rsoftmax = nn.Softmax(dim=1)
def forward(self, x):
batch, channels = x.shape[:2]
# Split
x = self.conv(x)
x = self.bn0(x)
x = self.relu(x)
# 重新整形为 [batch, radix, channels//radix, H, W]
x = x.view(batch, self.radix, channels, x.size(2), x.size(3))
# 聚合
x_gap = x.sum(dim=1) # [batch, channels, H, W]
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
# Attention
x_attn = self.fc1(x_gap)
x_attn = self.bn1(x_attn)
x_attn = self.relu(x_attn)
x_attn = self.fc2(x_attn)
# 重新整形为 [batch, radix, channels, 1, 1]
x_attn = x_attn.view(batch, self.radix, channels, 1, 1)
x_attn = self.rsoftmax(x_attn)
# 加权融合
out = (x * x_attn).sum(dim=1)
return out
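A minimal usage sketch (the batch size, channel count, and spatial size below are illustrative assumptions, not values from the original post): the block preserves both the channel count and the spatial resolution of its input.

# Hypothetical smoke test: a 64-channel feature map keeps its shape after the block.
block = SplitAttention(in_channels=64, radix=2, reduction=4)
x = torch.randn(8, 64, 32, 32)   # assumed batch of 8, 32x32 feature maps
out = block(x)
print(out.shape)                 # torch.Size([8, 64, 32, 32])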