问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

多模态:MLLM模态对齐方法

创作时间:
作者:
@小白创作中心

多模态:MLLM模态对齐方法

引用
CSDN
1.
https://blog.csdn.net/WiSirius/article/details/142861112

多模态模型在AI领域正掀起一股热潮,其中模态对齐是实现多模态理解的关键技术。本文将介绍三种主流的模态对齐方法,通过分析LLaVA、Flamingo和BLIP-2等代表性模型的架构和实现细节,帮助读者深入了解这一领域的最新进展。

前言

目前多模态模型席卷AI领域,最近也在做一些对齐的工作,记录一下目前主流的模态对齐方法。想详细了解的也可以看看下面的综述论文。

paper:https://arxiv.org/pdf/2311.07594

一、介绍

最近的代表性MLLM分为四类:

(1)将LLM作为多模态特征的直接处理器;

(2)利用多模态感知器的MLLM来处理多模态特征;

(3)将LLM作为处理多模态特征的工具;

(4)在特定格式的数据上学习,赋予LLM适应额外模态的能力

本文主要介绍目前几个完成交互的经典方法

1、LLaVA(多模态特征组合的简单样例)

LLaVA 的对齐方式相对来说比较简单,只有简单的线性层。LLaVA 的模型架构如下图所示,LLM 选择的是 Vicuna,图像编码器选择的是 CLIP 的 ViT-L/14,中间增加了一个线性层 W 将图像特征转换为跟文本 Embedding 相同维度,再一起输入到 LLM 中。

简单来说,文本内容经过embedding后输出为(1,n,c),n为文本token数量,c表示每个token的长度。图像特征经过编码器后输出为(1,n1,c1),对其特征进行重映射输出为(1,n1,c)的特征,进行concat后送入LLM。模型结构如下:

LlavaLlamaForCausalLM(
  (model): LlavaLlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
    ......
    )
    (norm): LlamaRMSNorm()
    (vision_tower): CLIPVisionTower(
    ......
    )
    (mm_projector): Sequential(
      (0): Linear(in_features=1024, out_features=4096, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=4096, out_features=4096, bias=True)
    )
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

可以发现图像特征token长度为1024,经过mm_projector后长度变为4096。

2、Flamingo(cross-attention的模态交互样例)

Flamingo 主要做的是 Caption 任务,即输入一张图片,Flamingo 可以生成图片的标题。不同的是,Flamingo 可以输入多张图片,实现上下文学习的 Few-Shot 效果

Flamingo 的模型架构如下图所示,首先通过冻结的视觉编码器对图像进行编码,然后通过一个可训练的感知重采样器(Perceiver Resampler)重新提取特征,输出一个固定数量的视觉 tokens,这些视觉 tokens 再通过交叉注意力层被用于预训练的语言模型的每一层(LM block)

Flamingo 中插入的 Perceiver Resampler 和 GATED XATTN-DENSE 都是重新初始化的,GATED XATTN-DENSE 主要是为了根据视觉输入调整 LM,在冻结的 LM 层之间插入新的交叉注意力层。这些交叉注意力层的 keys 和 values 是从视觉特征中获得的而 queries 则是从语言输入中获得的。交叉注意力层后面跟的是 FFW,这些层都经过了门控(gated)。(门控这个概念可以追溯到LSTM,这里采用tanh函数作为门控,tanh在LSTM中作为输入门用于保留重要信息,sigmod通常作为遗忘门,LSTM学习推荐)

Flamingo(
  (vision_encoder): VisionTransformer(
    ......
  )
  (perceiver): PerceiverResampler(
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): PerceiverAttention(
          (norm_media): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (norm_latents): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (to_q): Linear(in_features=1024, out_features=512, bias=False)
          (to_kv): Linear(in_features=1024, out_features=1024, bias=False)
          (to_out): Linear(in_features=512, out_features=1024, bias=False)
        )
        (1): Sequential(
          (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1024, out_features=4096, bias=False)
          (2): GELU(approximate='none')
          (3): Linear(in_features=4096, out_features=1024, bias=False)
        )
      )
    )
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lang_encoder): MosaicGPT(
    ......
    (gated_cross_attn_layers): ModuleList(
      (0-23): 24 x GatedCrossAttentionBlock(
        (attn): MaskedCrossAttention(
          (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (to_q): Linear(in_features=2048, out_features=512, bias=False)
          (to_kv): Linear(in_features=1024, out_features=1024, bias=False)
          (to_out): Linear(in_features=512, out_features=2048, bias=False)
        )
        (ff): Sequential(
          (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=2048, out_features=8192, bias=False)
          (2): GELU(approximate='none')
          (3): Linear(in_features=8192, out_features=2048, bias=False)
        )
      )
    )
  )
)

可以发现,重采样器也是大量线性变换组合而成,对图像特征进行转化。

3、BLIP-2(结构与策略的交互样例)

BLIP-2 的论文中提出了一种新的视觉-语言模型预训练的方法—— Q-Former,主要分为两个阶段:① 基于冻结的图像编码器进行视觉-语言表征学习;② 基于冻结的 LLM 进行视觉-语言生成学习。Q-Former 是一个可训练的模块,通过 BERT Base 来初始化权重,用来连接冻结的图像编码器和冻结的 LLM。对于不同分辨率的图像,Q-Former 都可以通过图像编码器提取固定数量的输出特征。Q-Former 主要包括两个 Transformer 子模块,① 图像 Transformer 用于跟冻结的图像编码器交互,提取视觉特征;② 文本 Transformer 可以既作为文本编码器和文本解码器。

视觉-语言表征学习通过三种任务进行训练

图文对比学习:对齐图像表征和文本表征(Image-Text contrastive learning)。

图文匹配:判断图文对是否匹配的二分类任务(image-text matching)。

基于图像的文本生成:基于图像生成标题(image-grounded text generation)。

视觉-语言生成学习:基于训练好的 Q-Former 模块和可学习的 query embeddings 提取图像特征,然后用全连接层将 Q-Former 的输出维度跟 LLM 的输入维度进行对齐,最后再输入到 LLM 中。

BLIP2结构如下

Blip2OPT(
  (visual_encoder): VisionTransformer(
    ......
  )
  (ln_vision): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
  (Qformer): BertLMHeadModel(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): None
        (position_embeddings): None
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x: BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (crossattention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=1408, out_features=768, bias=True)
                (value): Linear(in_features=1408, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): None
            (output): None
            (intermediate_query): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
              (intermediate_act_fn): GELUActivation()
            )
            (output_query): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
    )
    (cls): None
  )
  (opt_model): OPTForCausalLM(
    ......
  )
  (opt_proj): Linear(in_features=768, out_features=2560, bias=True)
)

总结

个人觉得目前模态对齐的方法其实还是集中于第一种和第二种方法,即合并特征或使用cross-attention的方式。但是现在decoder的构建通常只使用self-attention进行完成,因此第二种方式也很少用了。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号