实现奖励函数以供数据集使用

最后更新:06/02/2025。

对于每个数据集,我们需要实现一个奖励函数,或利用奖励模型来为生成的响应计算奖励。 我们在 reward_score directory 中已预实现了部分奖励函数。 您也可以使用定制的奖励函数。

目前,我们支持 GSM8k 和 MATH 数据集的奖励函数。对于 RLHF 数据集 (例如, full_hh_rlhf) 和代码生成 (例如,APPS),我们分别利用奖励模型 和 SandBox (即将开源) 进行评估。

RewardManager

在 PPO 后训练脚本的入口点 main_ppo.py 中, 我们实现了一个 RewardManager,它利用预实现的奖励函数来为每个响应计算分数。

RewardManager 中,我们实现了一个 __call__ 函数来 为每个响应计算分数。 所有奖励函数都通过 compute_score_fn 来执行。 输入是一个 DataProto,其中包括:

  • input_idsattention_mask:应用聊天模板后的 input_idsattention_mask,包括提示和响应

  • responses:响应标记

  • ground_truth:当前提示的真实答案字符串。 存储在 non_tensor_batch 中的 DataProto 内,应在 parquet 文件中 预处理。

  • data_source:当前提示的数据集名称。存储在 non_tensor_batch 中的 DataProto 内,应在 parquet 文件中 预处理。

在对响应进行去标记化后,响应字符串和真实答案字符串将被输入到 compute_score_fn 中,以为每个响应计算分数。

奖励函数

预实现

我们在 reward_score directory 中已预实现了部分奖励函数。

  • GSM8k 示例 中,我们 强制响应在四个 #### 后输出最终答案,然后 使用字符串匹配来与真实答案比较。如果完全正确,得分 1;如果格式正确,得分 0.1;如果 格式不正确,得分 0。

  • MATH 示例 中,我们遵循 lm-evaluation-harness repository 中的实现。

定制化

您可以在单独文件中实现定制奖励函数,并使用 custom_reward_function.pathcustom_reward_function.name 来指定它们。有关它们的设置,请参考 config-explain-page

您的奖励函数的参数应为 data_sourcesolution_strground_truthextra_info。 例如:

def my_reward_fn(data_source, solution_str, ground_truth, extra_info=None):
  return len(solution_str)/100

如果您只测试单个定制奖励函数,可以简单地将其命名为“compute_score”,并留下 custom_reward_function.name 未设置。

要使用不同的定制奖励函数运行多次测试,您可以为每次试验修改 custom_reward_function.pathcustom_reward_function.name 。 例如,您可能创建一个单个 my_reward.py 文件,并在其中实现多个奖励函数。这样,对于不同的试验,您只需调整 custom_reward_function.name,这样在脚本中进行多次测试会更方便。