Lag-Llama:首个开源时序预测基础模型,零样本预测能力超越传统模型
Lag-Llama:首个开源时序预测基础模型,零样本预测能力超越传统模型
什么是Lag-Llama?
Lag-Llama是首个开源的时间序列预测基础模型,由Morgan Stanley、ServiceNow、Université de Montréal、Mila-Quebec和McGill University等多个机构联合开发。该模型基于Transformer的纯解码器架构,灵感来自LLaMA,用于单变量概率预测。
Lag-Llama的核心技术特点
Lag-Llama的核心创新在于其独特的数据标记策略和模型架构:
滞后特征标记:Lag-Llama使用一组指定的滞后来构造序列的滞后特征。它会根据数据频率(如季度、月、周、天、小时、秒)自动选择合适的滞后项。这种策略使得模型能够很好地推广到不可见的频率。
Transformer解码器架构:模型采用基于Transformer的纯解码器结构,输入标记包括滞后时间步长和静态协变量(如秒/分、小时/天等)。输入序列通过线性投影层映射到解码器内部的隐藏维度。
分布头输出概率分布:Lag-Llama的输出层使用Student's t分布来构造不确定性区间,能够生成预测区间。这种概率预测方式提供了更丰富的信息,帮助用户理解预测的置信度。
训练与性能
Lag-Llama在来自不同领域的27个时间序列数据集上进行了训练,涵盖能源、交通、经济等多个领域。训练数据包含7965个单变量时间序列,总计约3.52亿个令牌。所有数据集都是开源的,包括ethth、Exchange和Weather等。
在零样本预测任务中,Lag-Llama展现出卓越的性能。例如,在澳大利亚电力需求数据集上,Lag-Llama的预测精度显著优于Temporal Fusion Transformer (TFT) 和DeepAR等深度学习模型。这种零样本预测能力使得模型能够快速适应新场景,无需针对特定数据集进行训练。
应用场景
Lag-Llama适用于多个领域的时序预测任务,包括但不限于:
- 能源领域:电力需求预测、能源消耗预测
- 金融领域:股票价格预测、汇率预测
- 交通领域:交通流量预测、航班延误预测
- 经济领域:GDP预测、失业率预测
使用方法
使用Lag-Llama进行预测的基本步骤如下:
环境搭建:
- 克隆GitHub仓库:
git clone https://github.com/time-series-foundation-models/lag-llama/
- 安装依赖:
pip install -r requirements.txt
- 下载预训练模型权重:
huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt
- 克隆GitHub仓库:
加载数据集:
- 使用GluonTS加载数据集:
dataset = get_dataset("australian_electricity_demand")
- 使用GluonTS加载数据集:
预测流程:
- 初始化LagLlamaEstimator对象
- 设置预测参数(如预测长度、上下文长度)
- 运行预测并生成结果
未来发展
目前Lag-Llama的实现还处于初期阶段,微调功能正在积极开发中。未来版本将支持针对特定领域数据的微调,进一步提升模型性能。此外,模型还在持续优化输入令牌的长度问题,以更好地处理长序列数据。
总结
Lag-Llama作为首个开源的时间序列预测基础模型,凭借其创新的架构和强大的零样本预测能力,在多个领域展现出广阔的应用前景。随着微调功能的完善和更多应用场景的探索,Lag-Llama有望成为大数据分析领域的重要工具。