问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

CBAM注意力机制原理详解及源码解析

创作时间:
作者:
@小白创作中心

CBAM注意力机制原理详解及源码解析

引用
CSDN
1.
https://m.blog.csdn.net/qq_47233366/article/details/137103025

CBAM(Convolutional Block Attention Module)是一种常见的注意力机制,广泛应用于深度学习和计算机视觉领域。本文将详细介绍CBAM的原理和源码实现,帮助读者深入理解这一技术。

CBAM原理详解

CBAM的整体结构如下图所示:

首先,输入特征图(Input Feature)会分别送入Channel Attention Module(通道注意力模块)和Spatial Attention Module(空间注意力模块)。这两个模块会分别生成通道注意力权重和空间注意力权重,最终得到精制特征图(Refined Feature)。

Channel Attention Module

通道注意力模块的结构如下:

具体步骤如下:

  1. 对输入特征图F进行全局最大池化和全局平均池化,得到两个1 × 1 × C的特征图
  2. 将这两个特征图送入两个全连接层(MLP)
  3. 将两个特征图相加并经过sigmoid激活函数,得到通道注意力权重M_c

Spatial Attention Module

空间注意力模块的结构如下:

具体步骤如下:

  1. 对通道注意力模块的输出特征图F'进行全局最大池化和全局平均池化,得到两个H × W × 1的特征图
  2. 将这两个特征图在channel维度拼接,得到H × W × 2的特征图
  3. 进行一次卷积操作,将特征图大小变为H × W × 1
  4. 经过sigmoid激活函数,得到空间注意力权重M_s

CBAM代码详解

下面是CBAM的完整代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv1(x))
        return out

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=1, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_att = ChannelAttention(in_planes, ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_att(x) * x
        print(self.channel_att(x).shape)
        print(f"channel Attention Module:{out.shape}")
        out = self.spatial_att(out) * out
        print(self.spatial_att(out).shape)
        #print(f"Spatial Attention Module:{out.shape}")
        return out

if __name__ == '__main__':
    # Testing
    model = CBAM(3)
    input_tensor = torch.ones((1, 3, 224, 224))
    output_tensor = model(input_tensor)
    print(f'Input shape: {input_tensor.shape})')
    print(f'Output shape: {output_tensor.shape}')

运行结果如下:

第一行表示理论部分的特征图M_c,第二行表示F',第三行表示特征图M_s,第四行表示输入,第五行表示CBAM的最终输出。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号