# Recipe: Fully Async Policy Trainer **Author:** `https://github.com/meituan-search` Last updated: 10/18/2025. 本文档介绍了完全异步的 PPO 训练系统,该系统完全解耦了 Trainer 和 Rollouter, 支持异步样本生成和训练。 在此系统中,我们在训练 Qwen2.5-7B 模型时,使用 128 张 GPU 实现了 2.35x-2.67x 的性能提升, 而不会显著影响结果。 ## 介绍 ### 背景 分离 rollout 和 train 的架构相比于 colocate 架构,可以更灵活地分配资源,并设计更灵活的训练逻辑,从而解决长尾问题(long-tail problems)导致的低 GPU 利用率和训练效率问题。 one_step_off_policy 通过设计分离架构并在 rollout 和 train 之间进行一步异步训练,缓解了长 rollout 时间的问题,并在训练效率上取得了一些收益。 然而,它强制使用一步异步训练的数据,灵活性不够,无法完全消除长尾对训练效率的影响。 在其他框架如 AReaL、Magistral、StreamRL 和 AsyncFlow 中,基于分离架构实现了异步训练和流式训练,并取得了收益。 我们借鉴了其方法,并在 VERL 中实现了该方法。fully_async_policy 支持异步、流式和部分 rollout 训练。 通过合理设置资源分配和参数同步频率等参数,fully_async_policy 可以显著提高训练效率。 > Magistral https://arxiv.org/abs/2506.10910 > > AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language > Reasoning https://arxiv.org/abs/2505.24298 > > StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream > Generation https://arxiv.org/abs/2504.15930 > > AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01653 > ### 核心贡献 * **资源隔离**:与使用 hybrid_engine 不同,Rollouter 和 Trainer 使用单独的计算资源,并需要分别指定它们占用的资源。 * **并行生成和训练**:Trainer 训练时,Rollouter 正在生成新的样本。 * **多步异步**:相比于一步 off policy,它支持从 0.x 步到多步的异步设置,使异步解决方案更加灵活。 * **NCCL 参数同步**:使用 NCCL 通信原语进行 Rollouter 和 Trainer 之间的参数通信。 * **流式推理和训练**:Rollouter 以样本为单位生成数据,数据传输以单个样本为最小传输单元。 * **异步训练和新鲜度控制**:通过设置参数 `async_training.staleness_threshold`,支持使用由旧参数生成的样本进行训练。 * **PartialRollout**:Rollouter 的推理过程支持部分 rollout 逻辑。在参数同步期间,通过添加 `sleep()` 和 `resume()` 逻辑, 保存正在进行的 rollout 的样本,并在下一次 rollout 中继续使用它们,从而减少在参数同步期间等待正在进行的任务完成的时间。 当前支持的使用模式是 megatron/fsdp+vllm。vllm 必须使用基于 AgentLoop 的 server 模式。 ## 设计 fully_async_policy 的整体架构如下图所示。fully_async_policy 主要由四个部分组成:Rollouter、MessageQueue、Trainer 和 ParameterSynchronizer。 ![fully_async_policy_structure]( https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true) 1. Rollouter 以样本为单位生成序列,并将生成的样本放入 MessageQueue,生产速度由新鲜度控制。 2. MessageQueue 用于临时存储 Rollouter 生成的样本。 3. Trainer 以样本为单位从 MessageQueue 中获取样本。获取 `require_batches*ppo_mini_batch_size` 个样本后,将执行训练。训练 `async_training.trigger_parameter_sync_step` 轮后,触发与 Rollouter 的参数同步。 4. ParameterSynchronizer 实现了 NCCL 同步参数同步能力。 与基础方案相比,收益的来源在于:在 colocate 情况下,为 rollout 分配更多资源无法解决长尾样本导致的闲置。 在我们执行资源隔离后,rollout 和 train 的时间可能比之前要长(因为使用的资源更少), 但它们的时耗重叠减少了端到端的时耗。 ![fully_async_policy_revenue]( https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true) ## 使用 ### 参数描述 | super params | implication | |-----------------------------------------------|------------------------------------------------------------------------------------------------| | `trainer.nnodes` | Trainer 的节点数 | | `trainer.n_gpus_per_node` | Trainer 每节点 GPU 数 | | `rollout.nnodes` | Rollouter 的节点数 | | `rollout.n_gpus_per_node` | Rollouter 每节点 GPU 数 | | `data.train_batch_size` | 在完全异步策略中,该值无效(默认为 0) | | `data.gen_batch_size` | 在完全异步策略中,使用流式样本生产逻辑(默认为 1) | | `rollout.total_rollout_steps` | Rollout 样本总数量 | | `rollout.test_freq` | Rollouter 更新参数前执行验证的次数 | | `actor_rollout_ref.actor.ppo_mini_batch_size` | ppo_mini_batch_size 是所有 workers/gpu 的全局数量 | | `async_training.require_batches` | FullyAsyncTrainer 一次获取的 ppo_mini_batch_size 数量 | | `async_training.trigger_parameter_sync_step` | 指示 FullyAsyncTrainer 进行参数同步前执行的本地更新次数 | | `async_training.staleness_threshold` | 新鲜度控制 | | `async_training.partial_rollout` | 是否执行 partial_rollout | | `async_training.use_rollout_log_probs` | 使用 rollout 生成的 log_probs | | `async_training.compute_prox_log_prob` | 是否在训练阶段使用训练模型的参数计算 log_prob。 | | **进一步解释:** * `rollout.total_rollout_steps` 与 colocate 相比,通过将 train_batch_size 和 step 相乘可以保持一致: `rollout.total_rollout_steps = data.train_batch_size * step`。 * `async_training.trigger_parameter_sync_step` 在完全异步策略中,它指示 Trainer 执行多少本地更新(即获取 `require_batches * ppo_mini_batch_size` 样本多少次)后与 Rollouter 进行参数同步。 在 Rollouter 和 Trainer 之间的每两次参数同步之间,Trainer 将处理 `trigger_parameter_sync_step* require_batches*ppo_mini_batch_size` 个样本。 要与 colocate 公平比较速度,trigger_parameter_sync_step 应设置为 `data.train_batch_size / (require_batches * ppo_mini_batch_size)`。 * `async_training.staleness_threshold` 在完全异步策略中,它指示允许使用的过期样本的最大比例。 * staleness_threshold=0,表示同步训练。 Rollouter 将在两次参数更新之间生成固定数量的样本,样本数量为: $$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$ * staleness_threshold>0,表示异步训练,可以设置为小于 1 的小数以实现更灵活的异步调用。 Rollout 将在两次参数更新之间生成最多以下数量的样本: $$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$ num_staleness_sample 表示上次 rollout 中多生成的过期样本数量。 由于这是一个流式系统,rollout 继续生成,trainer 继续消费。如果 rollouter 较慢,trainer 会提前触发参数同步,而 rollouter 可能不会实际产生 rollout_num 个样本。 当 rollout 足够快时,将 staleness_threshold 设置为 1 基本上相当于 one_step_off policy。 为避免过多过期样本影响训练准确性,建议将此值设置为小于 1。 * `async_training.partial_rollout` partial_rollout 仅在 staleness_threshold>0 时真正生效。 * `async_training.use_rollout_log_probs` 在强化学习算法中,log_probs 与参数版本和 token 有隐含相关性。由于 PPO/GRPO/DAPO 等算法的设置,在计算重要性采样时, old_log_prob 必须使用 rollout 参数和 token 对应的 log_probs,以确保算法正确性。在完全异步策略中,我们默认使用 rollout 计算 old_log_prob 而不是 trainer。 * `async_training.require_batches` 在流式训练中,require_batches 应设置为 1,表示生产足够的 ppo_mini_batch_size 样本后执行训练。 在实际测试中,我们发现一次发出的样本数量越少,由于数据分布的顺序,会导致训练不稳定和响应长度更长。 为此,我们额外提供了 require_batches 用于流式分布,并控制一次训练中参与的样本数量。 * `async_training.compute_prox_log_prob` (experimental) 在训练过程中,我们观察到指标和响应长度在训练后期可能变得不稳定。为了缓解这个问题,我们可以使用 [Rollout Correction](https://verl.readthedocs.io/en/latest/algo/rollout_corr.html) 技术进行重要性采样和拒绝采样。要利用 Rollout Correction,需要使用训练引擎计算 log_prob,这就需要启用此开关。 此外,当 compute_prox_log_prob 和 Rollout Correction 在模式 d(异步流管道与部分 rollout)下启用时,我们的实现遵循 `Decoupled PPO`,详见 [Mathmatics of Rollout Correction](https://verl.readthedocs.io/en/latest/algo/rollout_corr_math.html)。 ### 支持的模式 1. on policy pipeline: 1. **trigger_parameter_sync_step=1, staleness_threshold=0** 2. Rollouter 一次生成 `require_batches*ppo_mini_batch_size` 个样本,Trainer 获取这些样本进行训练,训练完成后 Trainer 和 Rollouter 执行参数同步; 3. 在 rollout 阶段,如果有长尾样本但 rollout 样本较少,较短样本无法填补闲置资源,会造成一些资源浪费。 4. 如图 a 所示; 2. stream off policy pipeline: 1. **trigger_parameter_sync_step>1, staleness_threshold=0** 2. 将执行同步流式训练。Rollouter 一次生成 `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` 个样本,Trainer 每次获取 `require_batches*ppo_mini_batch_size` 个样本后执行本地训练,训练 trigger_parameter_sync_step 次后,Trainer 和 Rollouter 执行参数同步; 3. 与 a 相比,由于一次生成更多样本,资源闲置将降低。 4. 在一步训练中,将有两个资源闲置周期:获取第一批样本时,train 等待 `require_batches*ppo_mini_batch_size` 个样本生产完毕,以及最后一个参数更新时,rollout 等待训练完成。 5. 如图 b 所示; 3. async stream pipeline with stale samples: 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False** 2. 每次参数更新后,Rollouter 将计划生成最多 rollout_num 个样本(实际生成的样本数量可能取决于 rollout 速度)。 3. 如果 rollout 过程相对较快,Rollouter 将在参数同步前生成一些额外样本 num_stale_samples,以便在同步后立即用于 Trainer。 当触发参数同步时,如果 Rollouter 有正在进行的任务,它将等待任务完成,而不是添加新任务; 4. 与 b 相比,除了第一次训练步骤外,后续训练将不会有等待第一批 rollout 完成的时间,但会有等待活跃任务完成的时间。 5. 如图 c 所示; 4. async stream pipeline with partial rollout: 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True** 2. 与 c 相比,当触发参数同步时,如果 Rollouter 有正在生成的样本,它将中断 rollout 过程并执行参数同步。被中断的样本将在同步后继续生成。这样可以减少等待活跃任务完成的时间。 3. 如图 d 所示; ![fully_async_policy_mode]( https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true) ### 关键指标 | metrics | implication | |------------------------------------------------|--------------------------------------------------------------------------------------------------------| | `trainer/idle_ratio` | Trainer 闲置率 | | `rollouter/idle_ratio` | Rollouter 闲置率 | | `fully_async/count/stale_samples_processed` | 训练中使用的过期样本总数 | | `fully_async/count/stale_trajectory_processed` | 训练中使用的过期轨迹总数(一个样本产生 rollout.n 条轨迹) | | `fully_async/partial/total_partial_num` | Trainer 在两次 trigger_parameter_sync_step 之间处理的局部样本数量 | | `fully_async/partial/partial_ratio` | Trainer 在两次 trigger_parameter_sync_step 之间处理的局部样本比例 | | `fully_async/partial/max_partial_span` | Trainer 在两次 trigger_parameter_sync_step 之间处理的局部样本的最大参数跨度 | ### 参数调优建议 * 资源分配和调整: * 合理的资源分配是实现良好训练效率的前提。理想的资源分配应该让 rollout 时间和 train 时间接近,从而最小化整个训练过程中的管道气泡, 避免资源闲置,并确保 Trainer 不使用旧样本。在真实训练场景中,可以根据实际训练中 rollout 和 train 的闲置时间调整资源分配, 这可以从 rollouter/idle_ratio 和 trainer/idle_ratio 获得。如果 rollouter/idle_ratio 高且 trainer/idle_ratio 低, 应该增加 Trainer 资源并减少 Rollouter 资源,反之亦然。 * 关键参数: * staleness_threshold:设置得太高会导致使用更多旧样本,从而影响模型性能。建议设置为小于 1。 * require_batches:越接近 1,越接近纯流式过程,训练气泡越小,可以实现的加速效果越快,但会影响样本处理的顺序; * trigger_parameter_sync_step:设置越小,越接近 on policy,但会导致频繁的参数同步。长尾样本浪费无法被短样本填补的资源,导致资源利用率低。 设置越大,计算效率越高,但准确性会受到 off policy 的影响。 * rollout.test_freq:会占用 Rollouter 资源,不推荐设置太小。 * 模式选择:通过调整不同的参数,完全异步架构支持不同级别优化加速,适用于不同场景的任务。 * 对于需要确保训练稳定性和 on-policy 性质的小规模任务,低速要求不高,可以尝试 on policy pipeline 模式(模式 1)。 * 对于需要提高训练吞吐但对 staleness 敏感的场景,可以尝试 stream off policy pipeline 模式。即, 通过设置 trigger_parameter_sync_step>1 来提高训练效率,但仍保持同步机制(staleness_threshold=0)(模式 2)。 * 对于高速训练要求高,可以容忍一定程度 off-policy 和 staleness 的大规模任务,通过设置 staleness_threshold> 0 和 partial_rollout=True,可以提高训练效率,使用 async stream pipeline 模式(模式 3 或 4)。 ### 快速开始 ```shell rollout_mode="async" rollout_name="vllm" # sglang or vllm if [ "$rollout_mode" = "async" ]; then export VLLM_USE_V1=1 return_raw_chat="True" fi train_prompt_bsz=0 gen_prompt_bsz=1 n_resp_per_prompt=16 train_prompt_mini_bsz=32 total_rollout_steps=$(((512*400))) test_freq=10 staleness_threshold=0 trigger_parameter_sync_step=16 partial_rollout=False python -m recipe.fully_async_policy.fully_async_main \ train_batch_size=${train_prompt_bsz} \ data.gen_batch_size=${gen_prompt_bsz} \ data.return_raw_chat=${return_raw_chat} \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ actor_rollout_ref.actor.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.hybrid_engine=False \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.name=${rollout_name} \ actor_rollout_ref.rollout.mode=${rollout_mode} \ actor_rollout_ref.rollout.calculate_log_probs=True \ trainer.nnodes="${NNODES_TRAIN}" \ trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ rollout.nnodes="${NNODES_ROLLOUT}" \ rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ rollout.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.partial_rollout="${partial_rollout}" ``` ## 实验 ### 7B 模型上的异步训练 我们使用 Qwen2.5-Math-7B 验证了完全异步策略在长候选和多资源下的收益。 使用 `async stream pipeline with stale samples` 策略,我们在 32 卡、64 卡和 128 卡上实现了约 2x 性能提升,而未显著影响实验结果。 * Machine: H20 * Model: Qwen2.5-Math-7B * Rollout length: max_response_length FSDP2: 28K tokens; * Algorithm: DAPO * Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet * Engine: vllm+FSDP2 * rollout.n: 16 * ppo_mini_batch_size: 32 * test_freq: 20 * colocate sync: * step: 400 * train_batch_size: 512 * fully_async_policy * total_rollout_steps: 512*400 * require_batches: 4 * trigger_parameter_sync_step: 4 * staleness_threshold: 0.5 * partial_rollout: True | training mode | resource allocation | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | |:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-------------------------------:| | colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313
last: 0.2448 | | fully_async_policy | 16:16 | 294.77 | 21.26 | \ | 269.80 | 7h 58m
(1.72x) | 16h 21m
(1.70x) | 1d 0h 53m
(2.31x) | 1d 9h 26m
(2.66x) | max: 0.3302
last: 0.2333 | | colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365
last: 0.2333 | | fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m
(2.09x) | 10h 14m
(2.03x) | 16h 58m
(1.83x) | 21h 40m
(1.92x) | max: 0.3677
last: 0.3406 | | colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573
last: 0.2958 | | fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m
(2.67x) | 6h 46m
(2.67x) | 10h 53m
(2.67x) | 17h 22m
(2.35x) | max: 0.3521
last: 0.3094 | > source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg ### 128 卡 7B 异步模式实验 我们使用 Qwen2.5-Math-7B 验证了 fully async 支持的各种模式的效果。 我们可以看到,流式带来的收益约为 1.6x,而结合 staleness 和 partial_rollout 后,收益达到 2.35x。 | mode | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | |:-------------------------------------------------------------------------------------------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:| | colocate sync | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573
last: 0.2958 | | `stream off policy pipeline`
(+fully async: trigger_parameter_sync_step= 4,
require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844
last: 0.2604 | | `async stream pipeline with stale samples`
(+staleness_threshold=0.5) | | | | | | | | | | | `async stream pipeline with partial rollout`
(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521
last: 0.3094 | > source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg ### 128 卡 Stale Ablation 实验 在 `async stream pipeline with partial rollout` 模式下,我们验证了 staleness 设置对训练效率的影响。 我们发现随着 staleness 的增大,最终收益越明显。 我们还注意到,staleness 值 0.3 和 0.5 的时间相当接近,因为随着训练步骤的增加,响应长度变化显著,导致训练不稳定。 对此需要进一步分析和优化。 | staleness_threshold | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | |:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| | 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844
last: 0.2604 | | 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542
last: 0.2979 | | 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469
last: 0.2865 | | 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521
last: 0.3094 | > source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg ### 128 卡 7B require_batches Ablation 实验 在多次测试中,我们发现流式中每次发出的样本数量影响训练期间的响应长度,从而影响训练时间。我们通过修改 `async_training.require_batches` 来验证对结果的影响。 | require_batches | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | acc/mean@1 | |:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| | 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349
last: 0.326 | | 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351
last: 0.3406 | | 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521
last: 0.3521 | > source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg ### 30B 模型模式实验 我们使用 `async stream pipeline with staleness samples` 策略在 Qwen3-30B-A3B-Base 模型上实现了 1.7x 性能提升,相比于 colocate 设置。值得注意的是,这远未达到异步可能实现的性能增益上限。首先,比较实验使用了仅 8k 的最大响应长度,这远低于之前实验中的 20k 序列长度,rollout 尾部效应不那么明显。其次,我们采用了高度倾斜的资源分配,rollout 使用 96 张 GPU,trainer 使用 32 张 GPU,这不是最优配置。在实验中,我们观察到当前 verl 实现有一定约束,如要求数据能被 GPU 数量均匀划分,使资源调整不那么灵活。此外,随着异步训练和部署的加速,性能差距逐渐缩小。因此,未来启用更灵活的资源分配和动态资源调整将是我们的下一个重点。 * Machine: H20 * Model: Qwen3-30B-A3B-Base * Rollout length: max_response_length : 8K tokens; * Algorithm: GRPO * Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet * Engine: vllm+Megatron * rollout.n: 16 * ppo_mini_batch_size: 128 * test_freq: 20 * colocate sync: * step:400 * train_batch_size: 512 * fully_async_policy * total_rollout_steps: 512*400 * trigger_parameter_sync_step: 512/128 = 4 * staleness_threshold: 0.5 * partial_rollout: True | Training Mode | Resource Allocation | Step | Gen | Old Log Prob | Ref | Update Actor | Total Time 100 Step | Total Time 200 Step | Total Time 300 Step | Total Time 400 Step | Acc/Mean@1 | |----------------------|--------------------|---------|--------|--------------|--------|--------------|---------------------|---------------------|---------------------|---------------------|-----------------------------| | Colocate Sync | 128 | 497.89 | 348.05 | 28.73 | 20.86 | 86.27 | 13h 36m | 1d 3h 48m | 1d 19h 4m | 2d 11h 39m | max: 0.3500
last: 0.3208 | | Fully Async Policy | 96:32 | 282.75 | 22.06 | \ | 50.05 | 206.63 | 6h 45m (2.01x) | 14h 48m (1.88x) | 1d 0h 9m (1.78x) | 1d 10h 41m (1.72x) | max: 0.3813
last: 0.3448 | > source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-30B?nw=nwuserhouzg ## 多轮工具调用 参考 **recipe/retool** 和 **ToolAgentLoop**,我们实现了 **AsyncPartialToolAgentLoop**,一个支持 partial_rollout 的多轮工具调用循环,用于 **fully_async_policy**。 ### 核心设计 `AsyncPartialToolAgentLoop` 继承自 `ToolAgentLoop`,并适应于 `fully_async_policy` 的异步训练模式。当 `partial_rollout=True` 时,Rollouter 在与 Trainer 同步参数前中断正在进行的生成任务。`AsyncPartialToolAgentLoop` 能够: 1. **中断任务**:响应中断信号以保存当前状态。目前,中断发生在 `GENERATING` 过程或其他状态完成后。 2. **恢复任务**:在参数同步完成后,从保存的状态恢复执行,而不是重新开始。 ### 如何使用 在 `fully_async_policy` 中进行多轮工具调用的 RL 训练类似于 `recipe/retool`。通过在配置文件中指定 `multi_turn` 配置启用。 1. **SFT 阶段**:首先,模型应该进行 SFT 以学习如何遵循工具调用格式指令。 2. **多轮配置**:在 `fully_async_policy` 训练配置中,设置以下参数: ```yaml actor_rollout_ref: rollout: multi_turn: enable: True # AsyncPartialToolAgentLoop will be used by default in fully_async_policy mode # Other multi_turn related configurations ``` 3. **异步参数**:为了提高效率,在使用多轮工具调用时启用 `partial_rollout` 和 `staleness_threshold`: ```yaml async_training: partial_rollout: True staleness_threshold: 0.5 # Other async parameters ``` 4. **示例**:参见 `recipe/fully_async_policy/shell/dapo_7b_async_retool.sh`。 ### 实验结果 为了验证 `fully_async_policy` 在多轮工具调用任务上的性能,我们将其与标准的 `colocate` 同步模式进行了比较。关键参数设置如下。 * **SFT 模型**:基于 `Qwen2.5-7B-Instruct`,在 `ReTool-SFT` 数据集上训练 6 个 epoch * **RL 算法**:DAPO * **数据集**: * Train: `DAPO-Math-17k` * Test: `aime_2025` * **资源和模式比较**: * `colocate sync`: 32 H20 gpu * `fully_async_policy`: Trainer 16 gpu + Rollouter 16 gpu * **关键配置**: 1. **工具调用配置**: * `multi_turn.enable: True` * `multi_turn.max_user_turns: 16` * `multi_turn.max_assistant_turns: 16` * `multi_turn.tool_config_path: recipe/retool/sandbox_fusion_tool_config.yaml` 2. **`colocate sync` 配置**: * `ppo_mini_batch_size: 16` * `train_batch_size: 64` 3. **`fully_async_policy` 配置**: * `ppo_mini_batch_size: 16` * `trigger_parameter_sync_step: 4` * `require_batches: 1` * `staleness_threshold: 1` * `partial_rollout: True` | training mode | Resource allocation | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | aime_2025
acc/mean@30 | |:------------------: |:-------------------: |:-------: |:-------: |:------------: |:------------: |:----------------------: |:----------------------: |:---------------------------: | | colocate | 32 | 375.47 | 228.03 | 35.19 | 111.84 | 9h 46m | 22h 28m | start:0.1078
last:0.2056 | | fully_async_policy | 16: 16 | 221.36 | 40.59 | \ | 179.58 | 6h 19m
(1.55x) | 14h 4m
(1.60x) | start:0.11
last:0.2044 | > source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-multiturn-tool?nw=nwuserhouzg ## 未来计划 * GRPO 实验 * Megatron 适配 * SGLang 集成 * Transfer queue 集成 * 异步参数同步 * AReaL 异步算法实现 * TPPO 算法实现 * 多轮和工具支持