# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) Last updated: 06/19/2025. > Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211) 🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) > 我们提出了**解耦剪辑和动态采样策略优化**(DAPO)算法。通过公开我们的工作,我们为更广泛的研究社区和社会提供对可扩展强化学习的实际访问,使所有人能够从中受益。我们的系统基于优秀的 [verl](https://github.com/volcengine/verl) 框架。感谢他们的伟大工作!将 DAPO 训练应用于 Qwen2.5-32B 基础模型证明优于之前的最新技术 DeepSeek-R1-Zero-Qwen-32B,在 AIME 2024 上实现**50%** 准确率,仅需**50%** 训练步骤。 > > ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png) ## 快速开始 1. **在 Ray 集群上**准备数据集: ```bash bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default ``` 2. **从任何机器**向 Ray 集群提交作业: ```bash cd verl # Repo root export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster # Set the runtime environment like env vars and pip packages for the Ray cluster in yaml export RUNTIME_ENV="./recipe/dapo/runtime_env.yaml" # This sets environment variables for the Ray cluster bash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts ``` ## 复现运行 | 设置 | AIME 2024 Acc. | 硬件 | 镜像 | 提交 | 环境变量 | 训练脚本 | 训练记录 | | -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | | DAPO | 52% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | | DAPO w/o Dynamic Sampling | 50% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | | DAPO w/o Token-level Loss & Dynamic Sampling | 44% | 16x8xH20 | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | > [!IMPORTANT] > > **📢 征集贡献!** > > 欢迎提交您的复现运行和设置! ## 配置 ### 分离的剪辑 epsilon(-> Clip-Higher) 示例配置: ```yaml actor_rollout_ref: actor: clip_ratio_low: 0.2 clip_ratio_high: 0.28 ``` `clip_ratio_low` 和 `clip_ratio_high` 指定了 DAPO 目标中的 $\varepsilon_{\text {low }}$ 和 $\varepsilon_{\text {high }}$(即低和高剪辑比率)。 核心相关代码: ```python pg_losses1 = -advantages * ratio pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) pg_losses = torch.maximum(pg_losses1, pg_losses2) ``` ### 动态采样(带组过滤) 示例配置: ```yaml data: gen_batch_size: 1536 train_batch_size: 512 algorithm: filter_groups: enable: True metric: acc # score / seq_reward / seq_final_reward / ... max_num_gen_batches: 10 # Non-positive values mean no upper limit ``` 将 `filter_groups.enable` 设置为 `True` 将过滤掉输出指标(metric)完全相同的组,例如对于 `acc`,过滤掉输出的准确率全为 1 或 0 的组。 训练器将重复使用 `gen_batch_size` 采样,直到有足够的合格组达到 `train_batch_size` 或达到 `max_num_gen_batches` 指定的上限。 核心相关代码: ```python prompt_bsz = self.config.data.train_batch_size if num_prompt_in_batch < prompt_bsz: print(f'{num_prompt_in_batch=} < {prompt_bsz=}') num_gen_batches += 1 max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...') continue else: raise ValueError( f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.' ) else: # Align the batch traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n batch = batch[:traj_bsz] ``` ### 灵活的损失聚合模式(-> Token-level Loss) 示例配置: ```yaml actor_rollout_ref: actor: loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" # NOTE: "token-mean" is the default behavior ``` 将 `loss_agg_mode` 设置为 `token-mean` 将平均计算一个小批量中所有序列的所有令牌的(策略梯度)损失。 核心相关代码: ```python if loss_agg_mode == "token-mean": loss = verl_F.masked_mean(loss_mat, loss_mask) elif loss_agg_mode == "seq-mean-token-sum": seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum loss = torch.mean(seq_losses) # seq-mean elif loss_agg_mode == "seq-mean-token-mean": seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean loss = torch.mean(seq_losses) # seq-mean else: raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") ``` ### 超长奖励塑造 示例配置: ```yaml data: max_response_length: 20480 # 16384 + 4096 reward_model: overlong_buffer: enable: True len: 4096 penalty_factor: 1.0 ``` 将 `overlong_buffer.enable` 设置为 `True` 将对输出长度超长但仍在硬上下文限制内的输出进行惩罚。 具体来说,当输出长度超过 `max_response_length - overlong_buffer.len` 后,从 0 到 `overlong_buffer.len` 令牌的过程中,惩罚从 `0` 线性增加到 `overlong_buffer.penalty_factor`。 核心相关代码: ```python if self.overlong_buffer_cfg.enable: overlong_buffer_len = self.overlong_buffer_cfg.len expected_len = self.max_resp_len - overlong_buffer_len exceed_len = valid_response_length - expected_len overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) reward += overlong_reward ``` ## 常见问题解答 ### 论文中的“Overlong Filtering”在哪里? 论文中的大多数实验,包括性能最好的实验,都是在没有 Overlong Filtering 的情况下运行的,因为它与 Overlong Reward Shaping 在正确学习最长输出方面有某种重叠。所以我们在这里没有实现它。 ### [`main` 分支中的 `recipe/dapo` 目录](https://github.com/volcengine/verl/tree/main/recipe/dapo)与[`recipe/dapo` 分支中的 `recipe/dapo` 目录](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)有什么区别? [`recipe/dapo` 分支](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)用于**原样复现**,因此不会更新新功能。 [`main` 分支中的 `recipe/dapo` 目录](https://github.com/volcengine/verl/tree/main/recipe/dapo)作为如何扩展最新 `verl` 来实现算法配方的示例,将维护新功能。 ### 为什么修改后无法产生类似的结果? 当今的 RL 基础设施仍具有固有的不稳定性,我们正在努力改进这一点。 我们强烈推荐一次只修改一件事。 我们还在这里列出了一些已知问题: 1. 启用 CUDA 图(`enforce_eager=False`)可能会导致模型性能下降,其原因仍在调查中。