Split-Attention

AI Summary

The Split-Attention module is the core component of ResNeSt. It splits the input feature map into several groups along the channel dimension and fuses them with learned attention weights, letting the network adaptively emphasize the most informative groups. How does this mechanism work, and what improvements does it bring?

Split-Attention
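In formula form, the mechanism is compact. A minimal sketch, assuming radix $R$ splits $U_1,\dots,U_R$ of the input features (notation mine; $\delta$ is ReLU, GAP is global average pooling, and the softmax normalizes over the radix index $r$):

$$\hat{U} = \sum_{r=1}^{R} U_r, \qquad s = \mathrm{GAP}(\hat{U})$$

$$a = \mathrm{softmax}_r\big(W_2\,\delta(W_1 s)\big), \qquad V = \sum_{r=1}^{R} a_r \odot U_r$$

The code below follows these steps: split with a grouped convolution, aggregate, compute attention through a two-layer bottleneck, and fuse.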

    """Split-Attention 模块 - ResNeSt的核心组件"""
    def __init__(self, in_channels, radix=2, reduction=4):
        super(SplitAttention, self).__init__()
        self.radix = radix
        inter_channels = max(in_channels * radix // reduction, 32)  # bottleneck width, floored at 32
        
        self.conv = nn.Conv2d(in_channels, in_channels * radix, 1, groups=radix)  # grouped 1x1 conv produces the radix splits
        self.bn0 = nn.BatchNorm2d(in_channels * radix)
        self.relu = nn.ReLU(inplace=True)
        
        self.fc1 = nn.Conv2d(in_channels, inter_channels, 1)
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.fc2 = nn.Conv2d(inter_channels, in_channels * radix, 1)
        
        # Softmax over the radix dimension (dim=1 after the reshape in forward)
        self.rsoftmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch, channels = x.shape[:2]
        
        # Split
        x = self.conv(x)
        x = self.bn0(x)
        x = self.relu(x)
        
        # Reshape to [batch, radix, channels, H, W]; each slice along dim 1 is one split
        x = x.view(batch, self.radix, channels, x.size(2), x.size(3))
        
        # Aggregate: sum the splits, then global average pool for channel statistics
        x_gap = x.sum(dim=1)  # [batch, channels, H, W]
        x_gap = F.adaptive_avg_pool2d(x_gap, 1)  # [batch, channels, 1, 1]
        
        # Attention: two-layer bottleneck producing per-split channel weights
        x_attn = self.fc1(x_gap)
        x_attn = self.bn1(x_attn)
        x_attn = self.relu(x_attn)
        x_attn = self.fc2(x_attn)
        
        # Reshape to [batch, radix, channels, 1, 1] and normalize across the splits
        x_attn = x_attn.view(batch, self.radix, channels, 1, 1)
        x_attn = self.rsoftmax(x_attn)
        
        # Weighted fusion: attention-weighted sum over the radix splits
        out = (x * x_attn).sum(dim=1)  # [batch, channels, H, W]
        return out
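
A quick usage sketch (the batch size, channel count, and spatial size are arbitrary choices for illustration; note that in_channels must be divisible by radix for the grouped convolution to work):

block = SplitAttention(in_channels=64, radix=2)
dummy = torch.randn(2, 64, 32, 32)  # [batch, channels, H, W]
out = block(dummy)
print(out.shape)  # torch.Size([2, 64, 32, 32]) - input shape is preserved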