# 近端策略优化(PPO) 最后更新:06/19/2025。 近端策略优化(Proximal Policy Optimization, PPO)是一类用于强化学习的策略梯度方法,由 OpenAI 于 2017 年提出。PPO 在简单性、稳定性和性能之间取得了平衡,使其成为现代 RL 应用中最广泛使用的算法之一,包括大规模语言模型的微调。 传统的策略梯度方法,如 REINFORCE 或 Vanilla Policy Gradient,存在以下问题: - 高方差和样本效率低下。 - 由于大策略更新导致的不稳定性。 PPO 通过使用裁剪的替代目标函数来解决这个问题,该函数避免过度大的更新,而无需使用二阶导数。 有关 PPO 的更多技术细节,我们建议阅读 [OpenAI spinning up 教程] (https://spinningup.openai.com/en/latest/algorithms/ppo.html) 中的介绍,以及论文 [近端策略优化算法] (https://arxiv.org/abs/1707.06347)。 ## 关键组件 - 演员-评论家架构(Actor-Critic Architecture):PPO 需要一个演员模型(策略)和一个评论家模型(价值函数)。与其他算法(如 GRPO 和 RLOO)不同,它不需要评论家模型。 - 广义优势估计(Generalized Advantage Estimation, GAE):PPO 使用 GAE 来计算优势值,这有助于降低策略梯度估计的方差,同时保持较低的偏差。 - 裁剪的替代目标函数(Clipped Surrogate Objective):PPO 的核心是通过裁剪的替代目标函数实现的,该函数限制策略更新。 ## 配置 请注意,所有包含 `micro_batch_size` 的配置用于配置每次前向或后向传递的最大样本或标记计数,以避免 GPU 内存不足(OOM),其值不应改变算法/收敛行为。 大多数评论家配置与演员类似。请注意,下图中省略了评论家模型。 ![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d) - `data.train_batch_size`:用于生成一组采样轨迹/ rollout 的提示的全局批次大小。响应/轨迹的数量为 `data.train_batch_size * actor_rollout.ref.rollout.n` - `actor_rollout_ref.actor.ppo_mini_batch_size`:采样轨迹集被分割成多个小批次,其中 batch_size = ppo_mini_batch_size,用于 PPO 演员更新。ppo_mini_batch_size 是跨所有工作器的全局大小 - `critic.ppo_mini_batch_size`:采样轨迹集被分割成多个小批次,其中 batch_size = ppo_mini_batch_size,用于 PPO 评论家更新。ppo_mini_batch_size 是跨所有工作器的全局大小 - `actor_rollout_ref.actor.clip_ratio`:PPO 裁剪范围。默认为 0.2 - `actor_rollout_ref.actor.ppo_epochs`:在一组采样轨迹上进行 PPO 演员更新的 epochs 数 - `critic.ppo_epochs`:在一组采样轨迹上进行 PPO 评论家更新的 epochs 数。默认为 `actor_rollout_ref.actor.ppo_epochs` - `algorithm.gemma`:折扣因子 - `algorithm.lam`:在 GAE 估计器中权衡偏差和方差的 lambda 项 - `algorithm.adv_estimator`:支持 gae、grpo、reinforce_plus_plus、reinforce_plus_plus_baseline、rloo ## 高级扩展 ### KL 散度控制 防止策略与参考策略偏离太远的选项。有两种机制可用:KL 奖励惩罚和 KL 损失。有关更多技术细节,请参见 [使用人类反馈训练语言模型跟随指令] (https://arxiv.org/abs/2203.02155) 使用 KL 损失进行 KL 散度控制的选项: - `actor_rollout_ref.actor.use_kl_loss`:在演员中使用 KL 损失。使用时,我们不在奖励函数中应用 KL。默认为 False - `actor_rollout_ref.actor.kl_loss_coef`:KL 损失的系数。默认为 0.001 - `actor_rollout_ref.actor.kl_loss_type`:支持 kl(k1)、abs、mse(k2)、low_var_kl(k3) 和 full。在末尾附加 "+" (例如 'k1+' 和 'k3+')会应用直通以使用 k2 来进行无偏梯度估计,无论 KL 值估计如何(有关更多详情,请参见 https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848)。如何计算演员和参考策略之间的 KL 散度。请参见此博客文章以获取详细分析:http://joschu.net/blog/kl-approx.html 使用奖励中的 KL 惩罚的选项: - `algorithm.use_kl_in_reward`:是否启用奖励中的 KL 惩罚。默认为 False - `algorithm.kl_penalty`:支持 kl(k1)、abs、mse(k2)、low_var_kl(k3) 和 full。这定义了计算演员和参考策略之间 KL 散度的方式。具体选项请参考 core_algos.py 中的 `kl_penalty`。请参见此博客文章以获取详细分析:http://joschu.net/blog/kl-approx.html - `algorithm.kl_ctrl.kl_coef`:奖励中 KL 惩罚的(初始)系数。默认为 0.001 - `algorithm.kl_ctrl.type`:'fixed' 表示 FixedKLController,'adaptive' 表示 AdaptiveKLController - `algorithm.kl_ctrl.horizon`:请参见 AdaptiveKLController 的源代码了解详情 - `algorithm.kl_ctrl.target_kl`:请参见 AdaptiveKLController 的源代码了解详情 ### 双裁剪 PPO 双裁剪 PPO 通过在优势小于零时对策略比率应用下界来引入方法,当乘以大比率时,不会超过指定的下界。 ![image](https://github.com/user-attachments/assets/fc232181-d8b0-4307-8dd2-4dc0a4c1c139) - `actor_rollout_ref.actor.clip_ratio_c`:双裁剪 PPO 的值下界,默认为 3.0 ## 参考示例 Qwen2.5 训练日志和命令:[link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) ```bash bash run_gemma.sh trainer.n_gpus_per_node=1 \ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ trainer.logger=console \ critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ data.train_batch_size=256 \ actor_rollout_ref.actor.ppo_mini_batch_size=64 \ actor_rollout_ref.actor.ppo_micro_batch_size=2 \ critic.ppo_micro_batch_size=2 ``` 使用 verl v0.2 的参考性能: | 模型 | 方法 | 分数 | 链接 | |-------------------------------|------------------|-------|------------------------------------------------------------------------------------------------| | Qwen/Qwen2.5-0.5B-Instruct | 预训练模型 | 36.4 | [Qwen 博客](https://qwenlm.github.io/blog/qwen2.5-llm/) | | Qwen/Qwen2.5-0.5B-Instruct | PPO | 56.7 | [PPO 命令和日志](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) |