Prepare Data for Post-Training ======================================== Last updated: 02/09/2025. 在启动后训练任务之前,我们需要为策略训练准备数据。数据应以 parquet 格式存储。 我们为不同的数据集提供了几个数据预处理脚本,包括 GSM8K、MATH、HelloSwag、Full_hh_rlhf。要准备其他数据集,我们需要遵循以下步骤:数据预处理脚本可以分为两个部分: 1. 第一部分是通用部分,它从 huggingface 的 ``datasets`` 包加载数据集。然后使用 ``make_map_fn`` 对数据集进行预处理,然后以 parquet 格式存储。 .. code:: python import re import os import datasets from verl.utils.hdfs_io import copy, makedirs import argparse # To extract the solution for each prompts in the dataset # def extract_solution(solution_str): # ... if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--local_dir', default='/opt/tiger/gsm8k') parser.add_argument('--hdfs_dir', default=None) args = parser.parse_args() num_few_shot = 5 data_source = 'openai/gsm8k' dataset = datasets.load_dataset(data_source, 'main') train_dataset = dataset['train'] test_dataset = dataset['test'] # Construct a `def make_map_fn(split)` for the corresponding datasets. # ... train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True) test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True) local_dir = args.local_dir hdfs_dir = args.hdfs_dir train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) makedirs(hdfs_dir) copy(src=local_dir, dst=hdfs_dir) 2. 用户需要自行实现 ``make_map_fn()`` 函数(以及 ``extract_solution``),以支持不同的数据集或任务。 我们已经实现了 GSM8k、MATH、Hellaswag 和 Full_hh_rlhf 数据集的数据预处理。以 GSM8k 数据集为例: **GSM8K** 在 ``make_map_fn`` 中,每个数据字段应包含以下 5 个字段: 1. ``data_source``: 数据集的名称,用于在 ``RewardModel`` 中索引相应的奖励函数。 2. ``prompt``: 此字段应按照 huggingface chat_template 格式构造。``RLHFDataset`` 中的分词器将应用聊天模板并对提示进行分词。 3. ``ability``: 定义任务类别。 4. ``reward_model``: 目前,我们在评估期间仅使用 ``ground_truth`` 字段。``ground_truth`` 由 ``extract_solution`` 函数计算。**注意**,相应奖励函数的实现应与此提取的 ``ground_truth`` 保持一致。 5. ``extra_info``: 记录当前提示的一些信息。目前未使用。 .. code:: python def extract_solution(solution_str): solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) # extract the solution after #### assert solution is not None final_solution = solution.group(0) final_solution = final_solution.split('#### ')[1].replace(',', '') return final_solution instruction_following = "Let's think step by step and output the final answer after \"####\"." # add a row to each data item that represents a unique id def make_map_fn(split): def process_fn(example, idx): question = example.pop('question') question = question + ' ' + instruction_following answer = example.pop('answer') solution = extract_solution(answer) data = { "data_source": data_source, "prompt": [{ "role": "user", "content": question }], "ability": "math", "reward_model": { "style": "rule", "ground_truth": solution }, "extra_info": { 'split': split, 'index': idx } } return data return process_fn