Prepare Data for Post-Training
Last updated: 02/09/2025.
在启动后训练任务之前,我们需要为策略训练准备数据。数据应以 parquet 格式存储。
我们为不同的数据集提供了几个数据预处理脚本,包括 GSM8K、MATH、HelloSwag、Full_hh_rlhf。要准备其他数据集,我们需要遵循以下步骤:数据预处理脚本可以分为两个部分:
第一部分是通用部分,它从 huggingface 的
datasets包加载数据集。然后使用make_map_fn对数据集进行预处理,然后以 parquet 格式存储。
import re
import os
import datasets
from verl.utils.hdfs_io import copy, makedirs
import argparse
# To extract the solution for each prompts in the dataset
# def extract_solution(solution_str):
# ...
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')
parser.add_argument('--hdfs_dir', default=None)
args = parser.parse_args()
num_few_shot = 5
data_source = 'openai/gsm8k'
dataset = datasets.load_dataset(data_source, 'main')
train_dataset = dataset['train']
test_dataset = dataset['test']
# Construct a `def make_map_fn(split)` for the corresponding datasets.
# ...
train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
用户需要自行实现
make_map_fn()函数(以及extract_solution),以支持不同的数据集或任务。
我们已经实现了 GSM8k、MATH、Hellaswag 和 Full_hh_rlhf 数据集的数据预处理。以 GSM8k 数据集为例:
GSM8K
在 make_map_fn 中,每个数据字段应包含以下 5 个字段:
data_source: 数据集的名称,用于在RewardModel中索引相应的奖励函数。prompt: 此字段应按照 huggingface chat_template 格式构造。RLHFDataset中的分词器将应用聊天模板并对提示进行分词。ability: 定义任务类别。reward_model: 目前,我们在评估期间仅使用ground_truth字段。ground_truth由extract_solution函数计算。注意,相应奖励函数的实现应与此提取的ground_truth保持一致。extra_info: 记录当前提示的一些信息。目前未使用。
def extract_solution(solution_str):
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) # extract the solution after ####
assert solution is not None
final_solution = solution.group(0)
final_solution = final_solution.split('#### ')[1].replace(',', '')
return final_solution
instruction_following = "Let's think step by step and output the final answer after \"####\"."
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
question = example.pop('question')
question = question + ' ' + instruction_following
answer = example.pop('answer')
solution = extract_solution(answer)
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": question
}],
"ability": "math",
"reward_model": {
"style": "rule",
"ground_truth": solution
},
"extra_info": {
'split': split,
'index': idx
}
}
return data
return process_fn