因果链,让大模型学会推理
因果链,让大模型学会推理
AI已经在帮助数学家和科学家进行研究,而强大的因果推理能力是AI在这些领域大展拳脚的关键。最近,微软、MIT和印度理工学院海得拉巴分校的研究团队提出了一种新的因果推理学习方法——公理训练(axiomatic training)。这种方法通过展示因果链的符号演示,让Transformer模型学会因果推理,甚至能够泛化到更复杂的场景中。
引言
因果推理(causal reasoning)可以定义为一组遵循特定因果性预定义公理或规则的推理流程。例如,d-separation(有向分离)和do-calculus规则可被视为公理,而collider set或backdoor set的规范则可被看作是由公理推导出的规则。
通常,因果推理使用的数据对应于系统中的变量。通过正则化、模型架构或特定的变量选择,可以将公理或规则以归纳偏置的形式集成到机器学习模型中。根据可用数据种类的差异(观察数据、干预数据、反事实数据),Judea Pearl提出了“因果阶梯”,定义了不同类型的因果推理。
公理训练
研究团队假设,因果公理可以表示为符号元组⟨premise, hypothesis, result⟩。其中,hypothesis是指假设,即因果陈述;premise是前提,用于确定该陈述是否为“真”的任意相关信息;result自然是结果,可以是简单的“是”或“否”。
例如,来自论文《Can large language models infer causation from correlation?》的collider公理可以表示为:
基于这个模板,可以通过修改变量名称、变量数量和变量顺序等来生成大量合成元组。
公理训练:数据集、损失函数和位置编码
训练数据
基于特定公理,可以根据“前提”将“假设”映射成合适的标签(Yes或No)。要创建训练数据集,研究团队的做法是在特定的变量设置X、Y、Z、A下枚举所有可能的元组{(P, H, L)}_N,其中P是前提,H是假设,L是标签(Yes或No)。
给定一个基于某个因果图谱的前提P,如果可通过使用特定的公理(一次或多次)推导出假设P,那么标签L就为Yes;否则为No。
损失函数
给定一个数据集,损失函数的定义基于每个元组的基本真值标签,表示为:
分析表明,相比于下一token预测,使用该损失能得到很有希望的结果。
位置编码
位置编码的选择是另一个重要因素。研究团队使用了不同的位置编码来理解其对因果任务中的泛化的影响,包括可学习位置编码(LPE)、正弦位置编码(SPE)、无位置编码(NoPE)。
实验
研究团队使用因果无关型公理的符号演示从头开始训练了一个Transformer模型。为了评估其泛化性能,他们在简单的大小为3-6个节点的因果无关公理链上进行了训练,然后测试了泛化性能的多个不同方面,包括长度泛化性能(大小7-15的链)、名称泛化性能(更长的变量名)、顺序泛化性能(带有反向的边或混洗节点的链)、结构泛化性能(带有分支的图谱)。
研究团队基于GPT-2架构训练了一个基于解码器的有6700万参数的模型。该模型有12个注意力层、8个注意力头和512嵌入维度。他们在每个训练数据集上从头开始训练了该模型。为了理解位置嵌入的影响,他们还研究了三种位置嵌入设置:正弦位置编码(SPE)、可学习位置编码(LPE)和无位置编码(NoPE)。
结果
表1给出了在训练时未曾见过的更大因果链上评估时不同模型的准确度。可以看到,新模型TS2 (NoPE)的表现能与万亿参数规模的GPT-4相媲美。
图3是在有更长节点名称(长于训练集的)的因果序列上的泛化能力评估结果以及不同位置嵌入的影响。
图4评估的是在更长的未见过的因果序列上的泛化能力。
研究团队发现,在简单链上训练的模型可以泛化到在更大的链上多次应用公理,但却无法泛化到顺序或结构泛化等更复杂的场景。但是,如果在简单链以及带有随机逆向边的链组成的混合数据集上训练模型,则模型可以很好地泛化到各种评估场景。
通过扩展在NLP任务上的长度泛化研究结果,他们发现了位置嵌入在确保在长度和其它方面实现因果泛化的重要性。他们表现最佳的模型没有位置编码,但他们也发现正弦编码在某些情况下也很好用。
这种公理训练方法还能泛化用于一个更困难的问题,如图5所示。即以包含统计独立性陈述的前提为基础,任务目标是根据因果关系分辨相关性。解决该任务需要多个公理的知识,包括d-separation和马尔可夫性质。
研究团队使用与上面一样的方法生成了合成训练数据,然后训练了一个模型,结果发现在包含3-4个变量的任务演示上训练得到的Transformer能学会解决包含5个变量的图谱任务。并且在该任务上,该模型的准确度高于GPT-4和Gemini Pro等更大型的LLM。
研究团队表示:“我们的研究提供了一种通过公理的符号演示教模型学习因果推理的新范式,我们称之为公理训练(axiomatic training)。”该方法的数据生成和训练流程是普适的:只要一个公理能被表示成符号元组的格式,就可使用此方法学习它。
总结
这项研究展示了通过公理训练让AI模型学会因果推理的潜力。这种方法不仅能够帮助模型在简单场景中进行因果推理,还能够泛化到更复杂的场景中。这对于提高AI在科学研究、决策支持等领域的应用具有重要意义。