PPO Ray 训练器
最后更新:02/12/2025。
我们实现了 RayPPOTrainer,这是一种运行在驱动进程(driver process)上的训练器,位于单个 CPU/GPU 节点上(默认为 CPU)。
该 PPORayTrainer 包含 3 个核心函数,分别用于数据准备、WorkerGroup 初始化和 PPO 训练循环。
数据准备
PPORayTrainer 作为一个单进程,负责从数据集中加载一整批样本(提示词),然后分发到运行在不同 GPU 上的不同 worker_groups。
为了通用化数据加载,我们实现了 RLHFDataset 类,用于加载预处理的 parquet 文件,对提示词应用对话模板(chat templates)、添加填充、截断超出最大提示长度的提示,然后进行分词(tokenize)。
self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
tokenizer=self.tokenizer,
config=self.config.data)
然后,数据加载器将在 PPO 小批量大小下迭代该数据集。
WorkerGroup 初始化
我们首先介绍在给定 GPU 集上初始化演员模型(actor model)的 WorkerGroup 的基本实现。
# max_colocate_count 表示每个 RayResourcePool 中的 WorkerGroups(即进程)数量
# 对于 FSDP 后端,我们推荐使用 max_colocate_count=1,将所有 WorkerGroups 合并为一个,以避免冗余。
# 对于 Megatron 后端,我们推荐使用 max_colocate_count>1,可以为不同模型利用不同 WorkerGroup
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,
use_gpu=True,
max_colocate_count=1)
# 定义要在远程初始化的演员 rollout 类
actor_rollout_cls = RayClassWithInitArgs(cls=ActorRolloutWorker)
# 定义演员 rollout worker group
actor_rollout_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool,
ray_cls_with_init=actor_rollout_cls,
default_megatron_kwargs=config.actor_rollout.megatron)
不同的 WorkerGroups,例如 actor_rollout_worker_group、critic_worker_group 和 ref_worker_group,在上方实现中位于单独的进程中。
驱动进程随后可以通过调用 actor_rollout_worker_group 及其他角色的分布式计算函数来构建 RL 训练循环。
对于位于同一组 GPU 上的共置模型,我们进一步提供了一种细粒度优化,将不同角色的 worker_group 合并到同一进程中。这种优化可以节省不同进程中的冗余 CUDA/分布式上下文。
# 初始化 WorkerGroup
# 注意:如果你希望为每个角色使用不同的资源池(以支持不同的并行大小),请不要使用 `create_colocated_worker_cls`。
# 相反,直接为不同 worker groups 传递不同的资源池。
# 有关更多信息,请参阅 TODO(url)。
all_wg = {}
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
if self.use_critic:
self.critic_wg = all_wg['critic']
self.critic_wg.init_model()
if self.use_reference_policy:
self.ref_policy_wg = all_wg['ref']
self.ref_policy_wg.init_model()
if self.use_rm:
self.rm_wg = all_wg['rm']
self.rm_wg.init_model()
# 我们应该在最后创建 rollout,以便 vllm 能更好地估算 KV 缓存内存
self.actor_rollout_wg = all_wg['actor_rollout']
self.actor_rollout_wg.init_model()
Note
对于 Megatron 后端,如果将 worker_groups 合并到同一进程中,所有角色将使用相同的 3D 并行大小。为了优化这一点,我们可能需要在同一分布式上下文中为每个角色维护多个 3D 进程组。如果你希望为不同角色使用不同的 3D 并行大小,请遵循第一个代码块的类似架构来初始化每个角色的 worker_group。
PPO 训练循环
我们通过调用每个角色的 worker_group 中的函数来实现 PPO 训练循环。每个函数的输入和输出数据都是 `protocol.py`_ 中实现的 DataProto 对象。在训练循环中,训练器将遵循封装在 worker 函数中的传输协议,向不同 GPU 分发/收集数据。PPO 微批量的计算在 update_actor 和 update_critic 函数中处理。
要扩展到其他 RLHF 算法,如 DPO、GRPO,请参考 Extend to other RL(HF) algorithms。
def fit(self):
"""
PPO 的训练循环。
驱动进程只需通过 RPC 调用 worker group 的计算函数,即可构建 PPO 数据流。
轻量级的优势计算在驱动进程上完成。
"""
from verl.utils.tracking import Tracking
from omegaconf import OmegaConf
logger = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))
global_steps = 0
# 在训练前执行验证
# 目前,我们仅支持使用 reward_function 进行验证。
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# batch = batch.to('cuda')
# 为生成弹出这些键
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
# 生成一批数据
with Timer(name='gen', logger=None) as timer:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
metrics['timing/gen'] = timer.last
batch = batch.union(gen_batch_output)
if self.use_reference_policy:
# 计算参考对数概率
with Timer(name='ref', logger=None) as timer:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
metrics['timing/ref'] = timer.last
# 计算价值
with Timer(name='values', logger=None) as timer:
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
metrics['timing/values'] = timer.last
with Timer(name='adv', logger=None) as timer:
# 计算分数。支持基于模型和基于函数的混合方式。
# 我们首先使用奖励模型计算分数,然后调用 reward_fn 组合奖励模型和基于规则的结果。
if self.use_rm:
# 我们首先计算奖励模型分数
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
# 我们与基于规则的 rm 组合
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor
# 计算奖励。如果可用,则应用 KL 惩罚
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl_in_reward,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
# 计算优势,在驱动进程上执行
batch = compute_advantage(batch,
self.config.algorithm.gamma,
self.config.algorithm.lam,
adv_estimator=self.config.algorithm.adv_estimator)
metrics['timing/adv'] = timer.last
# 更新批评者模型
if self.use_critic:
with Timer(name='update_critic', logger=None) as timer:
critic_output = self.critic_wg.update_critic(batch)
metrics['timing/update_critic'] = timer.last
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)
# 实现批评者预热
if self.config.trainer.critic_warmup <= global_steps:
# 更新演员模型
with Timer(name='update_actor', logger=None) as timer:
actor_output = self.actor_rollout_wg.update_actor(batch)
metrics['timing/update_actor'] = timer.last
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)
# 验证
if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:
with Timer(name='testing', logger=None) as timer:
val_metrics: dict = self._validate()
val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
metrics['timing/testing'] = timer.last
metrics.update(val_metrics)
# 收集数据指标
data_metrics = compute_data_metrics(batch=batch)
metrics.update(data_metrics)
# TODO: 创建一个支持各种后端的规范日志记录器
logger.log(data=metrics, step=global_steps)
if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:
actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
f'global_step_{global_steps}')
actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)
if self.use_critic:
critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
f'global_step_{global_steps}')
critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)
global_steps += 1
# 在训练后执行验证
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f'Final validation metrics: {val_metrics}')