问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

小白视角:利用 vllm serve 新的 Embedding Model

创作时间:
作者:
@小白创作中心

小白视角:利用 vllm serve 新的 Embedding Model

引用
1
来源
1.
https://aijishu.com/a/1060000000491766

本文详细介绍了如何使用vllm框架部署和使用embedding模型,特别是针对gte-7b模型的部署。文章通过对比vllm和SGLang两个框架的实现方式,说明了如何修改vllm以支持gte模型的embedding功能。

vllm如何处理embedding/completion请求?

在vllm框架中,处理embedding和completion请求的函数分别位于/vllm/engine/async_llm_engine.py文件中:

async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:

        async for output in await self.add_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:

        async for output in await self.add_request(
                request_id,
                inputs,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
        ):
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

当调用OpenAI的embedding或completion接口时,会分别调用上述的encode函数和generate函数,以获取embedding或completion结果。值得注意的是,在vllm中,任何一个模型都可以接受embedding与completion请求。

如何修改Qwen2ForCausalLM以支持embedding请求?

尝试直接使用vllm部署gte-7b模型:

CUDA_VISIBLE_DEVICES=0 vllm serve 7embed --dtype auto --api-key \
sk-1dwqsdv4r3wef3rvefg34ef1dwRv --tensor-parallel-size 1  \
 --max-model-len 32768 --enforce-eager \
 --disable-custom-all-reduce --port 7777 --served-model-name e5_7b

发送embedding请求时会报错(pooler not implemented)。进一步观察vllm中支持的qwen2模型实现(位于vllm/model_executor/models/qwen2.py):

class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        # ...
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = Qwen2Model(config, cache_config, quant_config)

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

对比SGLang中的qwen2模型实现(位于python/sglang/srt/models/qwen2.py):

class Qwen2ForCausalLM(nn.Module):
    def __init__(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2BaseModel(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

发现vllm中每个模型的forward函数仅返回hidden_states,而logits_processor是在compute_logits函数中实现的。SGLang的forward函数则将vllm中的forwardlogits_processor合并在了一起,直接返回logits。基于此设计,vllm的generate请求调用的是compute_logits函数,而SGLang的generate请求调用的是forward函数。

为了解决embedding请求的问题,可以在vllm已实现的Qwen2ForCausalLM类中添加pooler函数:

class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
    # ...

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        # ...
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.sampler = Sampler()

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    # ...

pooler相关的代码可以从vllm/model_executor/models/llama_embedding.py中复制。这样修改后,任何架构为Qwen2ForCausalLM的模型都可以支持embedding请求了。这是因为:

  • 对于任何架构为Qwen2ForCausalLM的模型(如Qwen/Qwen2-72B-InstructAlibaba-NLP/gte-Qwen2-7B-instruct),这个模型会被映射到Qwen2ForCausalLM类上。
  • 当用户调用completion请求时,engine会调用compute_logits函数;当用户调用embedding请求时,engine会调用pooler函数。这样即便是一个模型同时作为embedding model和completion model使用,也可以通过不同的函数调用避免冲突。

对于SGLang的实现,由于embedding和completion请求都调用相同的forward函数,因此无法通过类似的方法进行更改。

完成上述修改后,需要在vllm的vllm/model_executor/models/__init__.py文件中将gte映射到Qwen2ForCausalLM类:

_EMBEDDING_MODELS = {
    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
    "Qwen2ForCausalLM": ("Qwen2ForCausalLM"),
}

回顾与问题

通过对比vllm和SGLang的实现方式,可以看出vllm将embedding请求与completion请求分设接口的设计极大地帮助了扩展接口。而SGLang由于两种请求没有分设接口,因此出现了同一个架构无法映射到两个类的冲突。

在尝试将得到的embedding与sentence_transformer的embedding进行对比时,发现vllm返回的embedding存在以下问题:

  1. hidden state的维度是sentence_transformer的两倍;
  2. hidden state的偶数维全是0;
  3. hidden state的数值远大于sentence_transformer的数值,怀疑是归一化问题,但目前尚未解决。

总之,通过本文的分析和实现,可以更好地理解vllm框架中embedding和completion请求的处理机制,并为其他类似框架的实现提供参考。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号