# FPN - 特征金字塔网络 (Feature Pyramid Network)
class EnhancedFPN(nn.Module):
    """Enhanced feature pyramid network (FPN).

    Builds a standard FPN (1x1 lateral projections plus a top-down pathway
    with 3x3 output convs). It then adds a fusion step: every pyramid level
    is resampled to the spatial size of the first level, the levels are
    concatenated along the channel axis, and the result is mixed by a
    1x1 conv -> BatchNorm -> ReLU head.

    Args:
        in_channels_list: Channel count of each input feature map, in the
            same order as the ``features`` passed to :meth:`forward`. Any
            iterable of ints is accepted; it is materialized once.
        out_channels: Channel count used for every pyramid level and for the
            fused output.

    Raises:
        ValueError: If ``in_channels_list`` is empty.
    """

    def __init__(self, in_channels_list, out_channels):
        super(EnhancedFPN, self).__init__()
        # Materialize once: the channel list is iterated several times below,
        # which would silently misbehave for a one-shot iterator.
        in_channels_list = list(in_channels_list)
        if not in_channels_list:
            raise ValueError("in_channels_list must contain at least one entry")
        num_levels = len(in_channels_list)

        # Lateral connections: 1x1 convs projecting each level to out_channels.
        self.lateral_convs = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, 1)
            for in_channels in in_channels_list
        ])
        # Per-level output convs; 3x3 with padding=1 keeps the spatial size.
        self.fpn_convs = nn.ModuleList([
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
            for _ in range(num_levels)
        ])
        # Fusion head: squeezes the concatenation of all levels
        # (num_levels * out_channels channels) back down to out_channels.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(out_channels * num_levels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, features):
        """Run the FPN and return a single fused feature map.

        Args:
            features: Sequence of 4-D tensors ``(N, C_i, H_i, W_i)``, one per
                entry of ``in_channels_list`` and in the same order. The
                top-down pass runs from the last element towards the first,
                so presumably index 0 is the finest level. TODO confirm
                with callers.

        Returns:
            Tensor of shape ``(N, out_channels, H_0, W_0)``, where
            ``(H_0, W_0)`` is the spatial size of ``features[0]``.

        Raises:
            ValueError: If the number of feature maps differs from the
                number of levels the module was built with.
        """
        features = list(features)
        if len(features) != len(self.lateral_convs):
            # zip() would otherwise truncate silently and fail later with an
            # opaque channel mismatch inside fusion_conv.
            raise ValueError(
                f"expected {len(self.lateral_convs)} feature maps, "
                f"got {len(features)}"
            )

        # Lateral connections: project every level to out_channels.
        laterals = [conv(f) for f, conv in zip(features, self.lateral_convs)]

        # Top-down pathway: upsample each level to the size of its
        # predecessor and add it in. The addition is out-of-place, so the
        # conv outputs are not overwritten.
        for i in range(len(laterals) - 1, 0, -1):
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=laterals[i - 1].shape[-2:],
                mode='bilinear', align_corners=False,
            )

        # Per-level output convs.
        outputs = [conv(lateral) for lateral, conv in zip(laterals, self.fpn_convs)]

        # Bring every level to the spatial size of the first level; levels
        # that already match are passed through unchanged.
        target_size = outputs[0].shape[-2:]
        resized_outputs = [
            output if output.shape[-2:] == target_size
            else F.interpolate(output, size=target_size,
                               mode='bilinear', align_corners=False)
            for output in outputs
        ]

        # Concatenate along channels and fuse.
        fused = torch.cat(resized_outputs, dim=1)
        return self.fusion_conv(fused)
Comments NOTHING