实现奖励函数以供数据集使用
最后更新:06/02/2025。
对于每个数据集,我们需要实现一个奖励函数,或利用奖励模型来为生成的响应计算奖励。 我们在 reward_score directory 中已预实现了部分奖励函数。 您也可以使用定制的奖励函数。
目前,我们支持 GSM8k 和 MATH 数据集的奖励函数。对于 RLHF 数据集 (例如, full_hh_rlhf) 和代码生成 (例如,APPS),我们分别利用奖励模型 和 SandBox (即将开源) 进行评估。
RewardManager
在 PPO 后训练脚本的入口点 main_ppo.py 中,
我们实现了一个 RewardManager,它利用预实现的奖励函数来为每个响应计算分数。
在 RewardManager 中,我们实现了一个 __call__ 函数来
为每个响应计算分数。
所有奖励函数都通过 compute_score_fn 来执行。
输入是一个 DataProto,其中包括:
input_ids、attention_mask:应用聊天模板后的input_ids和attention_mask,包括提示和响应responses:响应标记ground_truth:当前提示的真实答案字符串。 存储在non_tensor_batch中的DataProto内,应在 parquet 文件中 预处理。data_source:当前提示的数据集名称。存储在non_tensor_batch中的DataProto内,应在 parquet 文件中 预处理。
在对响应进行去标记化后,响应字符串和真实答案字符串将被输入到 compute_score_fn 中,以为每个响应计算分数。
奖励函数
预实现
我们在 reward_score directory 中已预实现了部分奖励函数。
在 GSM8k 示例 中,我们 强制响应在四个 #### 后输出最终答案,然后 使用字符串匹配来与真实答案比较。如果完全正确,得分 1;如果格式正确,得分 0.1;如果 格式不正确,得分 0。
在 MATH 示例 中,我们遵循 lm-evaluation-harness repository 中的实现。
定制化
您可以在单独文件中实现定制奖励函数,并使用 custom_reward_function.path 和 custom_reward_function.name 来指定它们。有关它们的设置,请参考 config-explain-page。
您的奖励函数的参数应为 data_source、solution_str、ground_truth 和 extra_info。
例如:
def my_reward_fn(data_source, solution_str, ground_truth, extra_info=None):
return len(solution_str)/100
如果您只测试单个定制奖励函数,可以简单地将其命名为“compute_score”,并留下 custom_reward_function.name 未设置。
要使用不同的定制奖励函数运行多次测试,您可以为每次试验修改 custom_reward_function.path 和 custom_reward_function.name 。
例如,您可能创建一个单个 my_reward.py 文件,并在其中实现多个奖励函数。这样,对于不同的试验,您只需调整 custom_reward_function.name,这样在脚本中进行多次测试会更方便。