问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

【Block总结】ASSA,自适应稀疏自注意力,减少无关区域的噪声干扰

创作时间:
作者:
@小白创作中心

【Block总结】ASSA,自适应稀疏自注意力,减少无关区域的噪声干扰

引用
CSDN
1.
https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/145476176

本文介绍了一篇发表在CVPR 2024上的论文《Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration》,该论文提出了一种新的图像恢复方法——自适应稀疏变换器(Adaptive Sparse Transformer, AST)。AST通过自适应稀疏自注意力(ASSA)和特征精炼前馈网络(FRFN)两个关键组件,显著提升了图像恢复的性能。

一、论文信息

二、创新点

该论文提出了一种新的图像恢复方法——自适应稀疏变换器(Adaptive Sparse Transformer, AST),其主要创新点包括:

  • 自适应稀疏自注意力(ASSA): 通过双分支结构,稀疏分支过滤低查询-键匹配分数的影响,减少无关区域的噪声干扰;密集分支则确保信息流的充分传递,帮助网络学习判别性表示。
  • 特征精炼前馈网络(FRFN): 采用增强和简化机制消除通道上的特征冗余,从而提升图像恢复的效果。

三、方法

AST模型的设计包括以下几个关键步骤:

  1. 特征提取: 通过变换器架构提取图像特征,重点关注重要区域。
  2. 自适应机制: 通过ASSA模块,动态调整对不同区域的关注程度,减少无关特征的干扰。
  3. 特征精炼: FRFN模块通过增强和简化方案,消除冗余特征,提升恢复质量。
  4. 注意力机制: 引入注意力机制,增强对关键特征的关注,提升图像恢复的精度。

FRFN模块

特征精炼前馈网络(Feature Refinement Feedforward Network, FRFN)是自适应稀疏变换器(AST)模型中的一个关键组件,旨在通过增强和简化机制来消除通道中的特征冗余,从而提升图像恢复的效果。FRFN通过对特征进行优化处理,增强了模型在图像恢复任务中的表现。

FRFN的设计理念主要集中在以下几个方面:

  • 增强和简化机制: FRFN通过对特征进行变换和优化,旨在提升通道维度上的信息表达能力,减少冗余信息的干扰。
  • 部分卷积(PConv): FRFN采用部分卷积操作来强化特征中的有用元素,有助于提取关键信息。这种卷积方式能够有效处理缺失数据,提升特征提取的准确性。
  • 门控机制: 通过门控机制,FRFN能够限制无用信息的传播,减少不相关特征对图像恢复的干扰,从而提高恢复效果。

特征精炼前馈网络(FRFN)通过其增强和简化机制,有效地消除了通道中的特征冗余,提升了图像恢复的效果。通过结合部分卷积和深度卷积,FRFN能够在保留重要信息的同时,减少冗余特征的干扰,从而增强模型的图像恢复能力。该模块在自适应稀疏变换器(AST)模型中发挥了重要作用,为图像恢复任务提供了强有力的支持。

四、效果

实验结果表明,AST在多个图像恢复任务中表现优异,尤其是在去噪、去模糊和去雨等任务中,显著提高了恢复图像的质量。与传统方法相比,AST在处理复杂场景时能够更好地保留细节和结构。

五、实验结果

论文中进行了大量实验,验证了AST的有效性。实验结果显示:

  • 在标准数据集上,AST在PSNR(峰值信噪比)和SSIM(结构相似性指数)等指标上均优于现有的最先进方法。
  • 具体实验包括雨痕去除、雨滴去除和真实雾霾去除等任务,AST在这些任务中均展现出更好的恢复效果。

六、总结

总的来说,论文《Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration》提出了一种创新的自适应稀疏变换器模型,显著提升了图像恢复的性能。通过有效的特征提取和冗余消除,AST为图像恢复领域提供了新的思路和方法,具有广泛的应用潜力。

代码

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from einops import repeat
import math
from einops import rearrange

class FRFN(nn.Module):
    def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0., use_eca=False):
        super().__init__()
        self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim * 2),
                                     act_layer())
        self.dwconv = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
            act_layer())
        self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.dim_conv = self.dim // 4
        self.dim_untouched = self.dim - self.dim_conv
        self.partial_conv3 = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)

    def forward(self, x):
        # bs x hw x c
        bs, hw, c = x.size()
        hh = int(math.sqrt(hw))
        # spatial restore
        x = rearrange(x, ' b (h w) (c) -> b c h w ', h=hh, w=hh)
        x1, x2, = torch.split(x, [self.dim_conv, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        # flaten
        x = rearrange(x, ' b c h w -> b (h w) c', h=hh, w=hh)
        x = self.linear1(x)
        # gate mechanism
        x_1, x_2 = x.chunk(2, dim=-1)
        x_1 = rearrange(x_1, ' b (h w) (c) -> b c h w ', h=hh, w=hh)
        x_1 = self.dwconv(x_1)
        x_1 = rearrange(x_1, ' b c h w -> b (h w) c', h=hh, w=hh)
        x = x_1 * x_2
        x = self.linear2(x)
        # x = self.eca(x)
        return x

if __name__ == "__main__":
    dim = 64
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, channels,height, width)
    x = torch.randn(2, 40*40,dim).to(device)
    # 初始化 ESSAttn 模块
    win_size = (40,40)
    block = FRFN(dim,dim*4)
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

本文原文来自CSDN

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号