ASPP-空洞空间金字塔池化模块

Ivan 发布于 25 天前 46 次阅读 AIの灾难 225 字 最后更新于 25 天前


AI 摘要

ASPP模块作为空间金字塔池化的革新设计,通过并联多尺度特征提取路径,突破传统卷积核的感受野有限制约,以空洞卷积规避下采样引发的信息缺失,实现更全面的上下文感知。

ASPP-空洞空间金字塔池化模块

class ASPP(nn.Module):
    """空洞空间金字塔池化模块"""
    def __init__(self, in_channels, out_channels=256, rates=[6, 12, 18, 24]):
        super(ASPP, self).__init__()
        modules = []
        
        # 1x1 卷积
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        
        # 多个不同空洞率的空洞卷积
        for rate in rates:
            modules.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ))
        
        # 图像级特征
        modules.append(nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        
        self.convs = nn.ModuleList(modules)
        
        # 投影层
        self.project = nn.Sequential(
            nn.Conv2d((len(rates) + 2) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
    
    def forward(self, x):
        size = x.shape[2:]
        res = []
        for conv in self.convs:
            y = conv(x)
            if y.shape[2:] != size:
                y = F.interpolate(y, size=size, mode='bilinear', align_corners=False)
            res.append(y)
        return self.project(torch.cat(res, dim=1))