# ASPP — 空洞空间金字塔池化模块 (Atrous Spatial Pyramid Pooling)
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) module.

    Runs several parallel branches over the same input feature map and fuses
    them:

    * one 1x1 convolution branch;
    * one 3x3 dilated ("atrous") convolution branch per entry in ``rates``;
    * one image-level branch (global average pool + 1x1 convolution) whose
      1x1 output is bilinearly upsampled back to the input's spatial size.

    Every branch is Conv2d (no bias) -> BatchNorm2d -> ReLU and yields
    ``out_channels`` maps. The branch outputs are concatenated along the
    channel dimension and projected back to ``out_channels`` by a
    1x1 Conv2d -> BatchNorm2d -> ReLU -> Dropout head.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Channels produced by every branch and by the module.
        rates: Dilation rates of the 3x3 branches. Any iterable of ints is
            accepted (it is materialized once).
        dropout: Drop probability of the final ``nn.Dropout``.

    The input is a 4-D ``(N, in_channels, H, W)`` tensor; the output has
    shape ``(N, out_channels, H, W)``.

    Note:
        In training mode the image-level branch applies BatchNorm2d to a
        1x1 map, so PyTorch requires more than one sample per batch there.
    """

    def __init__(self, in_channels, out_channels=256, rates=(6, 12, 18, 24), dropout=0.5):
        super().__init__()
        # Materialize once: `rates` is iterated to build the branches, so a
        # one-shot iterator would otherwise be exhausted before its length is
        # needed to size the projection layer.
        rates = tuple(rates)

        # Branch creation order matches the original layout so that module
        # indices (state_dict keys) and seeded initialisation are unchanged.
        # 1x1 convolution branch.
        branches = [nn.Sequential(*self._conv_bn_relu(in_channels, out_channels, 1))]
        # 3x3 dilated branches; padding == dilation keeps H and W unchanged.
        for rate in rates:
            branches.append(nn.Sequential(
                *self._conv_bn_relu(in_channels, out_channels, 3, padding=rate, dilation=rate)
            ))
        # Image-level branch: global context pooled down to 1x1.
        branches.append(nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            *self._conv_bn_relu(in_channels, out_channels, 1),
        ))
        self.convs = nn.ModuleList(branches)

        # Projection head fusing the concatenated branch outputs.
        self.project = nn.Sequential(
            nn.Conv2d(len(branches) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size, **conv_kwargs):
        """Return the [Conv2d, BatchNorm2d, ReLU] layers shared by every branch."""
        return [
            # No conv bias: the following BatchNorm2d has its own shift term.
            nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x):
        """Apply all branches to ``x``, concatenate them and project the result."""
        size = x.shape[2:]
        outputs = []
        for branch in self.convs:
            y = branch(x)
            # Only the pooled branch changes spatial size (to 1x1); resize it
            # back so every branch output can be concatenated channel-wise.
            if y.shape[2:] != size:
                y = F.interpolate(y, size=size, mode='bilinear', align_corners=False)
            outputs.append(y)
        return self.project(torch.cat(outputs, dim=1))
# Comments: none