【Block总结】ASSA,自适应稀疏自注意力,减少无关区域的噪声干扰
【Block总结】ASSA,自适应稀疏自注意力,减少无关区域的噪声干扰
本文介绍了一篇发表在CVPR 2024上的论文《Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration》,该论文提出了一种新的图像恢复方法——自适应稀疏变换器(Adaptive Sparse Transformer, AST)。AST通过自适应稀疏自注意力(ASSA)和特征精炼前馈网络(FRFN)两个关键组件,显著提升了图像恢复的性能。
一、论文信息
- 标题: Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration
- 论文链接: https://openaccess.thecvf.com/content/CVPR2024/papers/Zhou_Adapt_or_Perish_Adaptive_Sparse_Transformer_with_Attentive_Feature_Refinement_CVPR_2024_paper.pdf
- GitHub链接: https://github.com/joshyZhou/AST
二、创新点
该论文提出了一种新的图像恢复方法——自适应稀疏变换器(Adaptive Sparse Transformer, AST),其主要创新点包括:
- 自适应稀疏自注意力(ASSA): 通过双分支结构,稀疏分支过滤低查询-键匹配分数的影响,减少无关区域的噪声干扰;密集分支则确保信息流的充分传递,帮助网络学习判别性表示。
- 特征精炼前馈网络(FRFN): 采用增强和简化机制消除通道上的特征冗余,从而提升图像恢复的效果。
三、方法
AST模型的设计包括以下几个关键步骤:
- 特征提取: 通过变换器架构提取图像特征,重点关注重要区域。
- 自适应机制: 通过ASSA模块,动态调整对不同区域的关注程度,减少无关特征的干扰。
- 特征精炼: FRFN模块通过增强和简化方案,消除冗余特征,提升恢复质量。
- 注意力机制: 引入注意力机制,增强对关键特征的关注,提升图像恢复的精度。
FRFN模块
特征精炼前馈网络(Feature Refinement Feedforward Network, FRFN)是自适应稀疏变换器(AST)模型中的一个关键组件,旨在通过增强和简化机制来消除通道中的特征冗余,从而提升图像恢复的效果。FRFN通过对特征进行优化处理,增强了模型在图像恢复任务中的表现。
FRFN的设计理念主要集中在以下几个方面:
- 增强和简化机制: FRFN通过对特征进行变换和优化,旨在提升通道维度上的信息表达能力,减少冗余信息的干扰。
- 部分卷积(PConv): FRFN采用部分卷积操作来强化特征中的有用元素,有助于提取关键信息。这种卷积方式能够有效处理缺失数据,提升特征提取的准确性。
- 门控机制: 通过门控机制,FRFN能够限制无用信息的传播,减少不相关特征对图像恢复的干扰,从而提高恢复效果。
特征精炼前馈网络(FRFN)通过其增强和简化机制,有效地消除了通道中的特征冗余,提升了图像恢复的效果。通过结合部分卷积和深度卷积,FRFN能够在保留重要信息的同时,减少冗余特征的干扰,从而增强模型的图像恢复能力。该模块在自适应稀疏变换器(AST)模型中发挥了重要作用,为图像恢复任务提供了强有力的支持。
四、效果
实验结果表明,AST在多个图像恢复任务中表现优异,尤其是在去噪、去模糊和去雨等任务中,显著提高了恢复图像的质量。与传统方法相比,AST在处理复杂场景时能够更好地保留细节和结构。
五、实验结果
论文中进行了大量实验,验证了AST的有效性。实验结果显示:
- 在标准数据集上,AST在PSNR(峰值信噪比)和SSIM(结构相似性指数)等指标上均优于现有的最先进方法。
- 具体实验包括雨痕去除、雨滴去除和真实雾霾去除等任务,AST在这些任务中均展现出更好的恢复效果。
六、总结
总的来说,论文《Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration》提出了一种创新的自适应稀疏变换器模型,显著提升了图像恢复的性能。通过有效的特征提取和冗余消除,AST为图像恢复领域提供了新的思路和方法,具有广泛的应用潜力。
代码
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from einops import repeat
import math
from einops import rearrange
class FRFN(nn.Module):
def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0., use_eca=False):
super().__init__()
self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim * 2),
act_layer())
self.dwconv = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
act_layer())
self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
self.dim = dim
self.hidden_dim = hidden_dim
self.dim_conv = self.dim // 4
self.dim_untouched = self.dim - self.dim_conv
self.partial_conv3 = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)
def forward(self, x):
# bs x hw x c
bs, hw, c = x.size()
hh = int(math.sqrt(hw))
# spatial restore
x = rearrange(x, ' b (h w) (c) -> b c h w ', h=hh, w=hh)
x1, x2, = torch.split(x, [self.dim_conv, self.dim_untouched], dim=1)
x1 = self.partial_conv3(x1)
x = torch.cat((x1, x2), 1)
# flaten
x = rearrange(x, ' b c h w -> b (h w) c', h=hh, w=hh)
x = self.linear1(x)
# gate mechanism
x_1, x_2 = x.chunk(2, dim=-1)
x_1 = rearrange(x_1, ' b (h w) (c) -> b c h w ', h=hh, w=hh)
x_1 = self.dwconv(x_1)
x_1 = rearrange(x_1, ' b c h w -> b (h w) c', h=hh, w=hh)
x = x_1 * x_2
x = self.linear2(x)
# x = self.eca(x)
return x
if __name__ == "__main__":
dim = 64
# 如果GPU可用,将模块移动到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 输入张量 (batch_size, channels,height, width)
x = torch.randn(2, 40*40,dim).to(device)
# 初始化 ESSAttn 模块
win_size = (40,40)
block = FRFN(dim,dim*4)
print(block)
block = block.to(device)
# 前向传播
output = block(x)
print("输入:", x.shape)
print("输出:", output.shape)
本文原文来自CSDN