LoRA Fine-Tuning Practice with BERT

Author: @小白创作中心

Reference:
CSDN: https://m.blog.csdn.net/weixin_42924890/article/details/142962906

LoRA (Low-Rank Adaptation) is an efficient technique for fine-tuning pretrained models, especially large ones. Using a 10-class classification task on BERT, this article walks through fine-tuning with LoRA and compares the result against full fine-tuning.

Introduction

Earlier articles in this series covered the basics of LoRA fine-tuning: the core idea, which parameters in the model become trainable, and how to configure it. This article works through a concrete example of fine-tuning a BERT model with LoRA.

LoRA Fine-Tuning Core Code

The LoRA code is compact: it only adds extra configuration on top of selected layers (Linear layers) of the original model. The core configuration:

from peft import LoraConfig, TaskType, get_peft_model

config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules=["query", "key", "value"],  # adapt the attention projections
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)
model = get_peft_model(model, config)

Key hyperparameters:

  • task_type: the task type, e.g. sequence classification (SEQ_CLS).
  • target_modules: names of the Linear sub-modules to adapt; here the attention query, key, and value projections.
  • r: the rank of the low-rank weight matrices that LoRA adds.
  • lora_alpha: controls the scaling of the LoRA update (the effective scale is lora_alpha / r); the default is 8 in peft==0.5.0.
  • lora_dropout: dropout applied inside the LoRA modules.
  • inference_mode: False means the model is in training mode and the LoRA layers are updated; True freezes the LoRA layers for inference.
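Together, r and lora_alpha determine both the size and the scale of the update: LoRA freezes the original weight W and trains two small matrices A and B whose product, scaled by lora_alpha / r, forms the weight update. A minimal numpy sketch of the parameter savings for a BERT-base-sized attention projection (the 768x768 shape is illustrative, not taken from the article's model):

```python
import numpy as np

d, k = 768, 768          # BERT-base attention projection shape (illustrative)
r, lora_alpha = 8, 32    # values from the LoraConfig above

# Full fine-tuning updates every entry of W.
full_params = d * k

# LoRA trains only A (r x k) and B (d x r); W stays frozen.
A = np.zeros((r, k))
B = np.zeros((d, r))
lora_params = A.size + B.size

# The effective update applied at forward time: W + (lora_alpha / r) * B @ A
scaling = lora_alpha / r
delta_W = scaling * (B @ A)
assert delta_W.shape == (d, k)

print(f"full: {full_params}, lora: {lora_params}, "
      f"ratio: {lora_params / full_params:.3%}")
# → full: 589824, lora: 12288, ratio: 2.083%
```

With r=8, each adapted projection trains roughly 2% of the parameters a full update would touch, which is where LoRA's memory savings come from.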

Complete Training Code

The full training script:

import time
import pandas as pd
from transformers import BertTokenizerFast, BertForSequenceClassification, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset
import model_config
from peft import get_peft_model, LoraConfig, TaskType

# Load the training and dev data
train_data_path = model_config.train_data_path
dev_data_path = model_config.dev_data_path
train_data = pd.read_csv(train_data_path)
train_texts = train_data["0"].tolist()
train_labels = train_data["1"].tolist()
dev_data = pd.read_csv(dev_data_path)
eval_texts = dev_data["0"].tolist()
eval_labels = dev_data["1"].tolist()

# Number of classification labels
num_labels = len(set(train_labels))

# Load the pretrained model
model_name = model_config.model_name_tokenizer_path
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

# Tokenize the texts
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=64)
eval_encodings = tokenizer(eval_texts, truncation=True, padding=True, max_length=64)

# Dataset definition
class TextDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = TextDataset(train_encodings, train_labels)
eval_dataset = TextDataset(eval_encodings, eval_labels)

# Training arguments
training_args = TrainingArguments(
    output_dir="./lora_results",
    logging_dir="./lora_logs",
    save_strategy="steps",
    save_total_limit=1,
    evaluation_strategy="steps",
    save_steps=250,
    eval_steps=125,
    load_best_model_at_end=True,
    num_train_epochs=5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=1250,
    weight_decay=0.001,
    dataloader_drop_last=True,
)

# LoRA configuration
config = LoraConfig(
    task_type=TaskType.SEQ_CLS, 
    target_modules=["query", "key", "value"],
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)
model = get_peft_model(model, config)

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train and time the run
t1 = time.time()
trainer.train()
t2 = time.time()
delta_time = t2 - t1
print(f"Train cost: {delta_time:.4f} s")
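What get_peft_model does to each targeted Linear layer can be mimicked in a few lines: the pretrained weight is frozen, and a zero-initialized low-rank path is added in parallel. The following is a mechanism-only numpy sketch, not peft's actual implementation; the class and variable names are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

class LoRALinearSketch:
    """Illustrative LoRA-adapted linear layer (mechanism only)."""

    def __init__(self, d_in, d_out, r=8, lora_alpha=32):
        self.W = rng.standard_normal((d_out, d_in))      # frozen pretrained weight
        self.A = rng.standard_normal((r, d_in)) * 0.01   # trainable, small random init
        self.B = np.zeros((d_out, r))                    # trainable, zero init
        self.scaling = lora_alpha / r

    def forward(self, x):
        # Frozen path plus scaled low-rank path. Because B starts at zero,
        # the layer initially behaves exactly like the pretrained one.
        return x @ self.W.T + self.scaling * (x @ self.A.T @ self.B.T)

layer = LoRALinearSketch(16, 16)
x = rng.standard_normal((2, 16))
assert np.allclose(layer.forward(x), x @ layer.W.T)  # B = 0 → no change yet
```

During training only A and B receive gradients, which is why the checkpoint produced by LoRA training contains just the small adapter weights rather than a full copy of the model.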

Batch Testing

After LoRA training finishes, the LoRA weights must be loaded on top of the base model for inference. The core inference code:

config = LoraConfig(
    task_type=TaskType.SEQ_CLS, 
    target_modules=["query", "key", "value"],
    inference_mode=True,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)
# Load the LoRA adapter weights
model = PeftModel.from_pretrained(model, model_id=model_path_lora, config=config)

The complete test script:

import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
from transformers import BertTokenizerFast, BertForSequenceClassification
import torch
import seaborn as sns
from torch.utils.data import DataLoader
from datasets import Dataset
import matplotlib.pyplot as plt
import model_config
from peft import get_peft_model, PeftModel, LoraConfig, TaskType

# Load the test set
test_data_path = model_config.test_data_path
test_data = pd.read_csv(test_data_path) 
texts = test_data["0"].tolist()
labels = test_data["1"].tolist()

# Load the base model and the LoRA adapter
model_path = model_config.model_name_tokenizer_path
model_path_lora = model_config.model_path_lora
tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=10)

config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules=["query", "key", "value"],
    inference_mode=True,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)
model = PeftModel.from_pretrained(model, model_id=model_path_lora, config=config)
model.eval()  # disable dropout for inference

# Batching and tokenization
def collate_fn(batch):
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]
    encoding = tokenizer(texts, padding=True, truncation=True, max_length=64, return_tensors="pt")
    encoding["labels"] = torch.tensor(labels)
    return encoding

batch_size = model_config.test_batch_size
dataset = Dataset.from_dict({"text": texts, "label": labels})
data_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

# Batched inference
predictions = []
for batch in data_loader:
    inputs = {k: v for k, v in batch.items() if k != "labels"}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # A SEQ_CLS head emits one logit per class (including the binary case),
    # so argmax over the class dimension is correct for any num_labels >= 2
    batch_predictions = torch.argmax(logits, dim=1).tolist()
    predictions.extend(batch_predictions)

# Compute evaluation metrics
accuracy = accuracy_score(labels, predictions)
precision = precision_score(labels, predictions, average="weighted", zero_division=0)
recall = recall_score(labels, predictions, average="weighted", zero_division=0)

# Print results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

# Plot the confusion matrix
cm = confusion_matrix(labels, predictions)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.savefig("img/confusion_matrix_lora.png")
plt.show()
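For reference, average="weighted" in the metric calls above weights each class's score by its support (its count in the true labels). A tiny self-contained check of what precision_score(..., average="weighted") computes, using made-up labels:

```python
from collections import Counter

y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2, 2]

def weighted_precision(y_true, y_pred):
    support = Counter(y_true)      # per-class counts in the true labels
    total = len(y_true)
    score = 0.0
    for c in support:
        # Precision for class c: correct / all samples predicted as c
        predicted_c = [t for t, p in zip(y_true, y_pred) if p == c]
        prec = sum(t == c for t in predicted_c) / len(predicted_c) if predicted_c else 0.0
        score += support[c] / total * prec
    return score

print(weighted_precision(y_true, y_pred))
# ≈ 0.75 (class precisions 1.0, 2/3, 0.5 weighted by supports 2, 3, 1)
```

Weighted averaging is a sensible default for this 10-class task because it keeps the metric comparable to accuracy even when the classes are imbalanced.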

Training Visualization

The training-process visualizations are included in the full code repository. Key comparisons:

  • Full fine-tuning evaluation:

  • The training loss keeps falling, but the eval-set loss traces a parabola-like curve, indicating overfitting.

  • Training time: 1746.2688 s.

  • LoRA fine-tuning evaluation:

  • The training loss keeps falling and the eval-set loss converges; no overfitting occurs.

  • Training time: 3140.1472 s.

Evaluation Results Visualization

Comparing the confusion matrices of LoRA fine-tuning and full fine-tuning:

  • Full fine-tuning:

  • Accuracy: 0.9370

  • Precision: 0.9369

  • Recall: 0.9370

  • LoRA fine-tuning:

  • Accuracy: 0.9325

  • Precision: 0.9327

  • Recall: 0.9325

LoRA fine-tuning trails full fine-tuning slightly, but the gap is small. For large pretrained models, LoRA is often what makes fine-tuning feasible at all, and its performance on specific tasks is acceptable.


