RolloutSkip 函数使用文档

最后更新时间:08/01/2025。

适用场景

RolloutSkip 功能旨在通过缓存和重用之前生成的序列来加速强化学习训练中的 rollout 过程。此功能特别适用于以下情况:

  1. 你需要使用相同配置重复运行实验

  2. 你希望通过避免冗余的序列生成来节省时间,从而接近最优策略

API 和使用示例

2.1 Trainer 适配

RayDAPOTrainer() (位于 verl/recipe/dapo/dapo_ray_trainer.py 中) 和 RayPPOTrainer()`(位于 `verl/trainer/ppo/ray_trainer.py`) 都已完成适配。

以下是如何在 RayPPOTrainer 中打补丁以应用 rollout_skip 的示例。

#* 导入 RolloutSkip 类
from verl.utils.rollout_skip import RolloutSkip

...
class RayPPOTrainer:
    ...
    def fit(self):
        ...

        #* 添加如下代码:
        rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
        rollout_skip.wrap_generate_sequences()

        ...

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                ...

2.2 基本配置

然后,你需要在配置中添加以下参数来启用 RolloutSkip 功能:

actor_rollout_ref.rollout.skip_rollout=True \
actor_rollout_ref.rollout.skip_dump_dir="/tmp/rollout_dump" \

注意:

  1. skip_dump_dir 是存储缓存序列的目录。请确保该目录可写且训练进程可以访问。此外,请确保 skip_dump_dir 不是相对路径,因为 Ray 会将数据存储在 /tmp/ray/session_<session_id>/ 中,而在 worker 中找不到相对路径。

  2. 转储数据路径遵循以下命名模式:{experiment_name}_{project_name}_TrainGBS{train_gbs}__InferGBS{gen_gbs}__N{n},一旦你更改了 experiment_nameproject_nametrain_gbsgen_gbsn,缓存数据将被存储在新目录中。