问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

【评估指标】困惑度 Perplexity

创作时间:
作者:
@小白创作中心

【评估指标】困惑度 Perplexity

引用
CSDN
1.
https://m.blog.csdn.net/a13545564067/article/details/144540723

困惑度(Perplexity)是机器学习和自然语言处理中常用的一个概念,尤其在语言模型的评估中。它用于衡量一个模型对文本数据的预测能力,直观上表示模型对数据“不确定”的程度。

1. 定义

对于语言模型而言,Perplexity 的公式如下:

  • N:文本中单词的数量。
  • wi:文本中的第 ( i ) 个单词。
  • P(wi∣w1,w2,…,w i−1):模型在给定前面单词的情况下,预测当前单词的概率。

2. 直观理解

  • 低 Perplexity:意味着模型对数据的预测更加准确,即模型更“确信”它的预测。
  • 高 Perplexity:表示模型对数据的预测更加不确定,表现更差。

示例:

  • 如果一个模型对句子有 4 个单词的概率估计为 0.5, 0.25, 0.125, 0.125,Perplexity 反映了模型的“平均猜测”不确定程度。结果越低,模型效果越好。

取值范围

  • PPL 的取值范围是 [1, +∞)。
  • PPL = 1:模型完全确定地正确预测所有单词(理论最优)。
  • PPL > 1:模型有一定的不确定性或错误。
  • PPL 越大:表示模型对序列的预测不确定性越高,质量越差。

3. 为什么重要?

  • Perplexity 是一种常见的语言模型评价指标。在训练语言模型时,目标是最小化 Perplexity。
  • 一个较低的 Perplexity 表示模型能更好地预测文本的概率分布,也就意味着它对语言的建模更加精准。

4. 关系到实际应用

  • 语言模型:如 GPT、BERT 等模型的训练过程中使用 Perplexity 来衡量性能。
  • 语音识别和机器翻译:Perplexity 也是这些领域中衡量模型质量的重要指标之一。

简单来说,Perplexity 就是衡量模型对文本“预测难度”的一种量化指标

Python代码实现

import transformers, torch, os
from math import exp

MODEL_PATH = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"

class PerplexityCalculator:
    def __init__(self,):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.float32,)
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        self.model.eval()
    
    def get_perplexity(self, text: str) -> float:
        with torch.no_grad():
            text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
            model_inputs = self.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False,)
            logits = self.model(**model_inputs, use_cache=True)['logits']
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = model_inputs['input_ids'][..., 1:].contiguous()
            loss = self.loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
            sequence_loss = loss.sum() / len(loss)
            loss_list = sequence_loss.cpu().item()
        return exp(loss_list)
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号