FPN-特征金字塔网络

Ivan 发布于 25 天前 41 次阅读 AIの灾难 244 字 最后更新于 25 天前


AI 摘要

在深度神经网络中，如何进行多尺度目标检测？本文设计了一种增强的 FPN 结构，通过融合不同层次的特征，形成对目标对象的丰富表示，从而更好地表示不同尺度的目标。

FPN-特征金字塔网络

class EnhancedFPN(nn.Module):
    """Enhanced feature pyramid network (FPN).

    Runs a standard FPN (1x1 lateral convs, top-down additive merge, 3x3
    output convs) over a list of backbone feature maps. It then resizes every
    pyramid level to the spatial size of the first level, concatenates them
    along the channel axis, and fuses them with a 1x1 conv + BatchNorm + ReLU.

    Args:
        in_channels_list: Channel count of each input feature map, in the same
            order the maps are passed to ``forward``.
        out_channels: Channel count of every pyramid level and of the fused
            output.

    Raises:
        ValueError: If ``in_channels_list`` is empty.
    """

    def __init__(self, in_channels_list, out_channels):
        super().__init__()
        # Materialize once: the list is iterated twice and measured below, so
        # a one-shot iterable (e.g. a generator) must not be consumed early.
        in_channels_list = list(in_channels_list)
        if not in_channels_list:
            raise ValueError("in_channels_list must contain at least one entry")
        num_levels = len(in_channels_list)

        # Lateral connections: 1x1 convs projecting each level to out_channels.
        self.lateral_convs = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, 1)
            for in_channels in in_channels_list
        ])

        # Output convs: one 3x3 conv per level, applied after the top-down merge.
        self.fpn_convs = nn.ModuleList([
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
            for _ in range(num_levels)
        ])

        # Fusion: collapse the concatenated levels back to out_channels.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(out_channels * num_levels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, features):
        """Build the pyramid and fuse all levels into one feature map.

        Args:
            features: Iterable of 4-D (N, C, H, W) tensors, one per entry of
                ``in_channels_list`` and in the same order. Presumably ordered
                from highest to lowest resolution (the top-down pass starts at
                the last level); the fused output takes the spatial size of
                ``features[0]``.

        Returns:
            Tensor with ``out_channels`` channels at the spatial size of
            ``features[0]``.

        Raises:
            ValueError: If the number of feature maps differs from the number
                of levels this module was built for. Without this check,
                ``zip`` would silently drop extra levels or fail later with an
                opaque channel mismatch in ``fusion_conv``.
        """
        features = list(features)
        num_levels = len(self.lateral_convs)
        if len(features) != num_levels:
            raise ValueError(
                f"expected {num_levels} feature maps, got {len(features)}"
            )

        # Lateral connections.
        laterals = [conv(f) for f, conv in zip(features, self.lateral_convs)]

        # Top-down pathway: starting from the last level, resize each level to
        # the previous level's spatial size and add it in.
        for i in range(num_levels - 1, 0, -1):
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=laterals[i - 1].shape[-2:],
                mode='bilinear', align_corners=False,
            )

        # Output convs.
        outputs = [conv(lateral) for lateral, conv in zip(laterals, self.fpn_convs)]

        # Bring every level to the first level's spatial size so they can be
        # concatenated along the channel axis.
        target_size = outputs[0].shape[-2:]
        resized_outputs = [
            out if out.shape[-2:] == target_size
            else F.interpolate(out, size=target_size, mode='bilinear', align_corners=False)
            for out in outputs
        ]

        # Fuse.
        return self.fusion_conv(torch.cat(resized_outputs, dim=1))