🏈[论文] InstructGPT:基于人类反馈训练语言模型遵从指令的能力
2024-7-29
| 2025-4-14
字数 4887阅读时长 13 分钟
type
status
password
date
slug
summary
category
URL
tags
icon

摘要

notion image
增大模型尺寸未必就能提高它对用户意图的理解能力。 例如,一些大模型可能会生成不真实、有毒或对用户并无帮助(untruthful, toxic, or simply not helpful)的输出。 换句话说,这些模型与它们的用户没有对齐(not aligned)。
本文展示了一种基于人类反馈进行微调(fine-tuning with human feedback), 从而在各种任务上将语言模型与用户意图对齐的方法。简单来说,
  • 先收集一组“预期的模型行为应该是什么样”的数据集, 然后使用监督学习来微调 GPT-3(SFT),
  • 接着,收集一组排名形式组织的模型输出(rankings of model outputs)作为数据集, 使用人类反馈强化学习(RLHF)进一步微调上一步得到的模型。

1 引言

使用人类反馈强化学习(RLHF)主要是为了让语言模型的行为与用户意图保持一致。传统的语言模型在训练时,通常以预测互联网网页上下一个词为目标,这种目标与 “遵循用户指令并提供有益、安全帮助” 的实际需求存在差异,导致模型可能出现捏造事实、生成有偏见或有害的文本,或者干脆不遵循用户指令。而RLHF能够利用人类偏好作为奖励信号,通过以下方式优化模型:
  1. 收集例范数据(demonstration data),训练一个监督策略(supervised policy)。
    1. 对于给定的输入,标注员给出期望的行为 (详见 3.2 节)。然后,使用监督学习(supervised learning)对一个预训练的 GPT-3 模型进行微调。
  1. 收集对比数据(comparison data),训练一个奖励模型(RM)。
    1. 对给定输入,收集两个输出,标注员给出他们的偏好(which output they prefer)。然后,训练一个奖励模型来预测人类偏好输出(human-preferred output)。
  1. 针对奖励模型,使用 PPO 对策略进行优化(optimize a policy)。
    1. 将 RM 的输出作为一个标量奖励。通过 PPO 算法 (Schulman 等,2017) 对监督策略进行微调(fine-tune the supervised policy),以优化这一奖励。
通过RLHF,InstructGPT模型在多个方面取得了显著改进,如在人工评估中更受标注员青睐、在真实性和毒性方面有所改善,并且在一定程度上减轻了在公开NLP数据集上的性能下降问题。这表明RLHF是使语言模型与人类意图对齐的有效方法,为改进语言模型的行为提供了有前景的方向。
notion image
Figure 2: InstructGPT 三部曲:(1) SFT, (2) RM training, (3) RLHF via proximal policy optimization (PPO) on RM.
蓝色箭头表示相应的数据用于训练模型。Step 2 中 A-D 是模型输出的采样,然后标注员对它们进行排序。详见 Section 3。
Figure 2: InstructGPT 三部曲:(1) SFT, (2) RM training, (3) RLHF via proximal policy optimization (PPO) on RM. 蓝色箭头表示相应的数据用于训练模型。Step 2 中 A-D 是模型输出的采样,然后标注员对它们进行排序。详见 Section 3。

2 相关工作

pass

3 方法论与实验详情

3.1 高级方法

  1. 收集例范数据(demonstration data),训练一个监督策略(supervised policy)。
    1. 对于给定的输入,标注员给出期望的行为 (详见 3.2 节)。然后,使用监督学习(supervised learning)对一个预训练的 GPT-3 模型进行微调。
  1. 收集对比数据(comparison data),训练一个奖励模型(RM)。
    1. 对给定输入,收集两个输出,标注员给出他们的偏好(which output they prefer)。然后,训练一个奖励模型来预测人类偏好输出(human-preferred output)。
  1. 针对奖励模型,使用 PPO 对策略进行优化(optimize a policy)。
    1. 将 RM 的输出作为一个标量奖励。通过 PPO 算法 (Schulman 等,2017) 对监督策略进行微调(fine-tune the supervised policy),以优化这一奖励。
步骤 2 和步骤 3 可以持续迭代;基于当前最优策略收集更多比较数据,用于训练新的奖励模型,进而训练新的策略。在实践中,我们的比较数据大多来自监督策略,也有部分来自 PPO 策略。
notion image
Figure 2: InstructGPT 三部曲:(1) SFT, (2) RM training, (3) RLHF via proximal policy optimization (PPO) on RM.
蓝色箭头表示相应的数据用于训练模型。Step 2 中 A-D 是模型输出的采样,然后标注员对它们进行排序。详见 Section 3。
Figure 2: InstructGPT 三部曲:(1) SFT, (2) RM training, (3) RLHF via proximal policy optimization (PPO) on RM. 蓝色箭头表示相应的数据用于训练模型。Step 2 中 A-D 是模型输出的采样,然后标注员对它们进行排序。详见 Section 3。

3.2 数据集

我们的提示数据集主要由提交到OpenAI API的文本提示构成,具体是那些在Playground界面上使用早期版本InstructGPT模型的用户提交的提示。在本文中,我们不使用生产环境中API客户的数据。我们通过检查提示是否有长的公共前缀来启发式地去除重复提示,并将每个用户ID的提示数量限制为200条。我们还基于用户ID创建训练集、验证集和测试集,确保验证集和测试集中不包含训练集中用户的数据。为避免模型学习到潜在敏感的客户细节,我们会在训练集中过滤掉包含个人身份信息(PII)的提示。

3.2.1 冷启动

为了训练最初的 InstructGPT 模型,我们要求标注员自己编写 prompt。 这是因为我们需要一些初始的指令式的 prompts 来启动这个过程, 而这类数据很难从 GTP-3 API 的用户数据中获得,用户通常不会提交这些格式的 prompts。我们要求标注员编写三种类型的 prompt :
  1. Plain: 标注员提出任意的任务,确保任务具有足够的多样性就行。
  1. Few-shot: 标注员提出一条指令,并为该指令提供多个查询/响应对,最终数据格式为(instruction、query、response)
  1. User-based: OpenAI API 的 waitlist applications 中我们列了一些使用案例。我们要求标注员提供与这些使用案例相关的 prompts。
详见附录 A。

3.2.2 SFT、RM、PPO数据集大小

我们生成了三个不同的数据集用于微调过程:(1) 监督微调数据集(共13k),包含标注员的示例,用于训练我们的SFT模型;(2) 奖励模型数据集(共33k),包含标注员对模型输出的排名,用于训练我们的RM;(3) PPO数据集(共31k),没有任何人工标注,用于强化学习从人类反馈(RLHF)微调。
source
SFT
size
RM
size
PPO
size
labeler
SFT train
11,295
RM train
6,623
customer
SFT train
1,430
RM train
26,584
PPO train
31,144
labeler
SFT valid
1,550
RM valid
3,488
customer
SFT valid
103
RM valid
14,399
PPO valid
16,185
SFT 总计
~15k
RM 总计
~50k
PPO 总计
~47k

3.2.3 不同NLP任务数据分布

Use-case
(%)
Use-case
(%)
Use-case
(%)
Generation
45.6%
Rewrite
6.6%
Other
3.5%
Open QA
12.4%
Summarization
4.2%
Closed QA
2.6%
Brainstorming
11.2%
Classification
3.5%
Extract
1.9%
Chat
8.4%

3.2.4 prompt 示例

表 2 展示了几个 prompt 示例(由研究人员编写,提交给 InstructGPT 的格式),
Table 2: API prompt 具体例子。

3.3 任务

我们的训练任务来自两个来源:(1) 标注员编写的提示数据集;(2) 提交到早期InstructGPT模型的API提示数据集(见表6)。这些提示非常多样化,涵盖生成、问答、对话、总结、提取和其他自然语言任务(见表1)。我们的数据集超过96%是英文的,不过在4.3节中,我们也探究了模型对其他语言指令的响应能力以及完成编码任务的能力。
对于每个自然语言提示,任务通常通过自然语言指令直接指定(例如 “写一个关于一只聪明青蛙的故事”),但也可以通过少样本示例(例如给出两个青蛙故事的示例,然后促使模型生成一个新故事)或implicit continuation(例如提供一个关于青蛙的故事开头)间接指定。在每种情况下,我们要求标注员尽力推断编写提示的用户的意图,并跳过任务非常不明确的输入。此外,在每种情况下,我们都要求标注员尽力推断每个 prompt 背后的用户意图, 并要求他们跳过那些任务非常模糊的 prompt。

3.4 标注员聘用标准

为生成演示数据和比较数据,并进行主要评估,我们通过Upwork和ScaleAI聘请了约40名合同工。与早期在文本摘要任务中收集人类偏好数据的工作(齐格勒等人,2019;施蒂农等人,2020;吴等人,2021)相比,我们的输入涵盖了更广泛的任务,偶尔还会涉及有争议和敏感的话题。我们的目标是挑选出对不同人口群体的偏好敏感,并且擅长识别潜在有害输出的标注员。因此,我们进行了一项筛选测试,以衡量标注员在这些方面的表现。我们选择在测试中表现出色的标注员;有关我们的选择程序和标注员人口统计信息的更多信息,请参见附录B.1。

3.4.1 对齐冲突的处理

在训练和评估过程中,我们的校准标准可能会发生冲突:例如,当用户请求可能有害的响应时。在训练过程中,我们优先考虑对用户的帮助程度(不这样做需要做出一些困难的设计决策,我们将其留作未来的工作;详见5.4节的更多讨论)。然而,在最终评估中,我们要求标注员优先考虑真实性和无害性(因为这是我们真正关心的)。

3.5 Models(模型)

我们从 GPT-3 预训练模型开始微调。GPT-3 在大量互联网数据上进行了训练,适用于各种下游任务, 但其行为尚未充分符合人类需求。基于 GPT-3,我们使用三种不同技术进行了模型微调。

3.5.1 Supervised fine-tuning (SFT)

使用监督学习的方式,在我们的示范数据上对 GPT-3 进行微调。
  • 16 epoch
  • a cosine learning rate decay
  • residual dropout 0.2
得到很多个 SFT 模型。最后根据 validation set 上的 RM 分数选择最终的 SFT 模型。
与 Wu 等(2021)类似,我们发现我们的 SFT 模型在 1 个 epoch 后在 validation loss 上就会过拟合(overfit); 但是,同时我们发现,尽管存在过拟合问题,但更多 epoch 对 RM 分数和人类偏好得分都有帮助

3.5.2 Reward modeling (RM)

将 SFT 模型去掉最后的 unembedding 层,然后从这样的模型开始训练,
  • 输入:prompt 和 response
  • 输出:一个标量奖励。
最后得到的就是一个 RM 模型。
在 Stiennon 等(2020)中,给两个模型相同的输入,然后得到两份输出作为对比数据, RM 是在这个对比数据集(dataset of comparisons)上进行训练的
为了快速收集数据,我们向标注员展示4到9个不等的响应(即一个 input/prompt 喂给模型,得到 K 个 output))供其排序。 这样每个 prompt 就对应 个对比数据。由于每个标注任务中的比较具有很强的相关性,我们发现如果简单地将比较结果打乱成一个数据集,模型在单次遍历数据集时就会过拟合。
因此,我们将每个 prompt 的所有 个对比作为单个 batch进行训练。 这样做在计算上更加高效,因为只需要一次正向遍历(forward pass of the RM for each completion, 而不是 forward passes for K completions), 并且由于不再过拟合,它实现了更好的 validation accuracy and log loss。
具体来说,奖励模型的损失函数为:
其中,
  • :prompt(输入的提示词)
  • :completion(模型的返回)
  • :模型返回的 结果中,人类更喜欢
  • :标注员给出的排序数据集
  • :是奖励模型(参数为)对于提示和响应的奖励。
最后,由于RM损失对奖励的偏移是不变的,我们在进行强化学习之前,使用bias对奖励模型进行归一化,使标注员的示例获得的平均分数为0。
💡
这里的奖励模型代替环境为大模型实时奖励,以便在PPO过程中指导大模型那颗响应效果好
代码示例

3.5.3 Reinforcement learning (RL)

notion image
我们使用PPO(舒尔曼等人,2017)在我们的环境中对SFT模型进行微调。该环境是一个多臂老虎机环境,1、它会随机给出一个客户提示,并期望得到对该提示的响应;2、根据提示和响应,它会产生一个由奖励模型确定的奖励,并结束该回合。此外,我们在每个token处添加来自SFT模型的每个token的KL散度惩罚,以减轻对奖励模型的过度优化。价值函数由RM初始化。我们将这些模型称为 PPO
我们还尝试将预训练梯度混合到PPO梯度中,以解决在公共NLP数据集上的性能下降问题。我们将这些模型称为PPO-ptx。在RL训练中,我们最大化以下组合目标函数:
其中
  • 是奖励模型的输出结果。在响应结尾给出奖励,非结尾处奖励为0。
  • 是学习到的 RL 策略,
  • 是 SFT 模型,
  • 是预训练分布(pretraining distribution)。
  • KL 奖励系数 和预训练损失系数 分别控制 KL 惩罚和预训练梯度的强度(strength)。
  • 对于 PPO模型, 设置为 0。在本文中InstructGPT指的是PPO-ptx模型,
💡
PPO算法的损失函数包括奖励以及模型的偏移两项。其中,代表在状态下采取动作得到的相对奖励(比其他动作地的奖励);

3.5.4 性能比较基线

我们将PPO模型的性能与SFT模型和GPT-3进行比较。我们还将GPT-3在提供少样本前缀以 “促使” 其进入指令遵循模式(GPT-3-prompted)时的情况进行比较,在实现上,这个前缀会被添加到用户指定的指令之前。
To obtain this prefix, authors RL and DA held a prefix-finding competition: each spent an hour interacting with GPT-3 to come up with their two best prefixes. The winning prefix was the one that led GPT-3 to attain the highest RM score on the prompt validation set. DA won.

3.6 性能评估

为评估我们的模型的 “对齐” 程度,我们首先需要明确对齐的含义,我们认为对齐的含义是有用的、诚实的和无害的。
因此,我们使用了一系列更具体的替代标准,试图捕捉部署模型中可能产生危害的不同行为:让标注员评估输出在客户助手场景中是否合适、是否贬低受保护群体、是否包含色情或暴力内容。我们还在旨在衡量偏差和毒性的数据集上对模型进行基准测试,如RealToxicityPrompts(Gehman等人,2020)和CrowS-Pairs(Nangia等人,2020)。
综上所述,我们的定量评估可分为两个部分:
  • 基于API分布的评估:我们的主要指标是在与训练分布来源相同的保留提示集上获得的人类偏好评分。在使用API提示进行评估时,我们仅选择未包含在训练集中的客户提示。但考虑到我们的训练提示是为InstructGPT模型设计的,这可能会使GPT-3基线模型处于劣势。因此,我们也在提交给GPT-3模型的API提示上进行评估;这些提示通常不是“指令跟随”风格,而是专门为GPT-3设计的。在这两种情况下,对于每个模型,我们都会计算其输出比基线策略更受偏好的频率;我们选择175B的监督微调(SFT)模型作为基线,因为其性能处于中等水平。此外,我们让标注员根据1 - 7分的李克特量表对每个回复的整体质量进行评分,并为每个模型输出收集一系列元数据(见表3)。
  • 基于公共NLP数据集的评估:我们在两类公共数据集上进行评估:一类是衡量语言模型安全性(特别是真实性、毒性和偏差)的数据集;另一类是衡量传统NLP任务(如问答、阅读理解和总结)零样本性能的数据集。我们还对RealToxicityPrompts数据集(Gehman等人,2020)进行了人工毒性评估。我们将在所有基于采样的NLP任务中发布模型的输出样本。

3.6.1 指标

helpful
要做到有帮助,模型不仅要能遵循指令,还应该能从一个 few-shot prompt 或其他可解释的模式(例如 Q: {question}\nA:)中推断意图(infer intention)。由于给定提示的意图可能不明确或存在歧义,我们依赖标注员的判断,主要指标是标注员的偏好评分。但标注员并非提出提示的用户,所以他们从提示中推断出的意图,可能与用户的实际意图存在差异。
honest / truthfulness
在纯生成模型中,衡量诚实性颇具挑战,这需要将模型的实际输出与其对正确输出的“认知”进行比较,可模型就像一个黑箱,我们无法推断其“认知”。因此,我们通过两个指标来衡量真实性,即模型对世界的陈述是否真实:一是评估模型在封闭领域任务中编造信息的倾向(“幻觉”);二是使用TruthfulQA数据集(Lin等人,2021)。不过,这些方法也只是涉及了真实性的一小部分。
harmless
与诚实度类似,衡量语言模型的有害性也很难。 在大多数情况下,语言模型的危害取决于其输出在现实世界中是如何被使用的。 例如,对于一个生成有毒输出的模型,如果部署在聊天机器人环境中,可能就是有害的;如果用于数据增强以训练更准确的毒性检测模型,则可能是有益的。
在项目早期,我们让标注员评估输出是否“可能有害”。 但由于这需要大量猜测输出的最终用途,尤其是我们的数据来自与Playground API接口交互的客户(而非实际生产用例),所以后来放弃了这种做法。

4 结果

暂略。 见原文。

5 问题讨论

暂略。 见原文。
 
  • LLM
  • 强化学习
  • GPT、GPT-2、GPT-3[论文] LLaMA:开放和高效的基础语言模型集
    Loading...