The Design of ``verl.single_controller`` ============================================== Last updated: 05/21/2025. **Author:**\ `Wang Zhang `__ Preface ------- 我们为 ``verl`` 的开发者准备了这份文档,特别是那些希望了解或为 ``verl.single_controller`` 模块做出贡献的人。这份文档并非面向最终用户,而是为那些希望理解架构设计 rationale (原因)以及内部机制的贡献者编写的。 -------------- Origin ------ ``single_controller`` 模块源于我收到的一项请求 —— 将一个玩具级的单进程 RLHF(Reinforcement Learning from Human Feedback)脚本适配成分布式系统,并尽可能少的修改代码,同时保持调试的便利性。 常见的做法 —— 例如使用 PyTorch 的 Distributed Data Parallel (DDP) —— 通常涉及包装 ``nn.Module`` 并启动多个进程,这些进程在不同 rank 下执行相同的函数。然而,这种方法在分布式 RLHF 上下文中存在两个主要局限: - 难以表示 PPO(Proximal Policy Optimization)所需的多个 DAG(Directed Acyclic Graph,有向无环图); - 难以在训练期间检查中间张量。 为了保持调试能力,我们选择了一种不同的方法 —— 将训练循环分解为明确定义的阶段,例如 ``generate_sequences``、 ``compute_advantages`` 等。 我们选择了 Ray `Ray `__ 作为 ``verl`` 的初始后端,因为它能够将 Python 类方法公开为 RPC 端点。然而,Ray 的默认模型仅支持 “一条方法调用,一次 RPC”,而训练大语言模型通常需要跨多个进程的协调。 为了向用户隐藏单个方法的多 Ray 演员(actors)调用,我们引入了以下组件: - ``WorkerGroup`` —— 管理一组远程工作者,并为多进程分布式计算提供统一的接口; - ``ResourcePool`` —— 将计算资源绑定到工作者进程; - ``ClassWithArgs`` —— 支持延迟远程实例化,并指定初始化参数。 -------------- A Running Example: ``generate_sequences`` ----------------------------------------- 为了说明设计思路,我们将逐步讲解 ``ActorRolloutRefWorker`` 类中的 ``generate_sequences`` 方法如何在分布式工作者间进行注册和调用。 -------------- Step 1: Register with a Decorator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 第一步是定义 ``generate_sequences`` 方法,并使用 ``@register`` 装饰器进行装饰,因为它将在驱动脚本中被调用。 **Source:** `fsdp_workers.py `__ .. code:: python class ActorRolloutRefWorker(Worker): ... @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): prompts = prompts.to(torch.cuda.current_device()) ... ``@register`` 装饰器为 ``generate_sequences`` 方法添加了元数据。目前,它不会改变功能,但会通过一个魔法键(``MAGIC_ATTR``)附加属性: **Source:** `decorator.py `__ .. code:: python def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): ... def decorator(func): @wraps(func) def inner(*args, **kwargs): if materialize_futures: args, kwargs = _materialize_futures(*args, **kwargs) return func(*args, **kwargs) attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} setattr(inner, MAGIC_ATTR, attrs) return inner return decorator 如代码所示,``dispatch_mode``、``execute_mode`` 和 ``blocking`` 的值被附加到 ``generate_sequences`` 方法上。 -------------- Step 2: Binding During Initialization ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 在 ``ActorRolloutRefWorker`` 被包装在 ``RayClassWithArgs`` 中并传入 ``RayWorkerGroup`` 时,这些附加属性会被提取并使用。 **Source:** `main_generation.py `__ .. code:: python ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) 在 ``RayWorkerGroup`` `的初始化过程中 `__ ,会发生两个关键步骤: 1. 创建工作者实例(Ray 演员): `RayWorkerGroup._init_with_resource_pool `__ 2. 将使用 ``@register`` 装饰的方法绑定到 ``RayWorkerGroup``: `RayWorkerGroup._bind_worker_method `__ .. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/worker_group_init.png?raw=true :alt: initialization_and_binding_of_worker_group initialization_and_binding_of_worker_group 绑定过程是 ``verl.single_controller`` 的核心。 **Key function:** `WorkerGroup._bind_worker_method `__ .. code:: python def _bind_worker_method(self, user_defined_cls, func_generator): ... for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) assert callable(method) except Exception: continue # Skip properties <<>> 如果方法具有 ``MAGIC_ATTR``,则会提取 ``@register`` 设置的属性: .. code:: python <<>> if hasattr(method, MAGIC_ATTR): attribute = getattr(method, MAGIC_ATTR) dispatch_mode = attribute["dispatch_mode"] execute_mode = attribute["execute_mode"] blocking = attribute["blocking"] <<>> 如上方的流程图所示,这些属性会被传入 ``func_generator``。但是,``func_generator`` 需要 ``method_name``、``dispatch_fn``、``collect_fn``、``execute_fn`` 和 ``blocking``。我们需要从 `DISPATCH_MODE_FN_REGISTRY``DISPATCH_MODE_FN_REGISTRY `__ 中查找与 ``dispatch_mode`` (``DP_COMPUTE_PROTO``)对应的 ``dispatch_fn`` 和 ``collect_fn``: .. code:: python3 DISPATCH_MODE_FN_REGISTRY = { Dispatch.ONE_TO_ALL: { "dispatch_fn": dispatch_one_to_all, "collect_fn": collect_all_to_all, }, ... Dispatch.DP_COMPUTE_PROTO: { "dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute_data_proto, }, ... } 类似地,``execute_fn`` 根据 ``execute_mode`` 选择,并通过以下方式提取: .. code:: python <<>> # get execute_fn_name execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) wg_execute_fn_name = execute_mode["execute_fn_name"] # get execute_fn from string try: execute_fn = getattr(self, wg_execute_fn_name) assert callable(execute_fn), "execute_fn must be callable" except Exception: print(f"execute_fn {wg_execute_fn_name} is invalid") raise <<>> 在 ``generate_sequences`` 的情况下: - ``dispatch_mode = Dispatch.DP_COMPUTE_PROTO`` - ``dispatch_fn = dispatch_dp_compute_data_proto`` - ``collect_fn = collect_dp_compute_data_proto`` - ``execute_fn = RayWorkerGroup.execute_all`` ONE_TO_ALL v.s. DP_COMPUTE_PROTO ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ``dispatch_mode`` 会关联一个 ``dispatch_fn`` 和一个 ``collect_fn``。顾名思义,``dispatch_fn`` 在 ``WorkerGroup`` 中处理输入参数,并生成一批(列表)输入参数,每个参数都会送入绑定到 ``WorkerGroup`` 的工作者。 ``ONE_TO_ALL`` 的 ``dispatch_fn`` 是 `dispatch_one_to_all `__ ,它只是将所有输入参数复制成 N 个副本,其中 N 等于绑定到 ``worker_group`` 的工作者数量: .. code:: python def dispatch_one_to_all(worker_group, *args, **kwargs): args = tuple([arg] * worker_group.world_size for arg in args) kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} return args, kwargs ``DP_COMPUTE_PROTO`` 的 ``dispatch_fn`` 是 `dispatch_dp_compute_data_proto `__ ,它使用 ``DataProto.chunk`` 将一个大的 ``DataProto`` 拆分成 N 个较小的 ``DataProto``,其中 N 等于 ``worker_group`` 的 world_size(工作者数量): .. code:: python def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): from verl.single_controller.base.worker_group import WorkerGroup assert isinstance(worker_group, WorkerGroup) # Note: enable auto padding for dp compute DatapProto splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding( worker_group.world_size, *args, **kwargs, ) return splitted_args, splitted_kwargs ``collect_fn`` 遵循相同的模式,处理来自 ``WorkerGroup`` 所有工作者的批量(列表)返回值,并像 ``collect_all_to_all`` 那样将其合并成一个列表,或像 ``collect_dp_compute_data_proto`` 那样合并成一个大的 ``DataProto``。 最后,使用 ``func_generator`` 动态生成一个新方法,并将其添加到 ``WorkerGroup`` 实例中: .. code:: python <<>> # bind a new method to the RayWorkerGroup func = func_generator( self, method_name, dispatch_fn=dispatch_fn, collect_fn=collect_fn, execute_fn=execute_fn, blocking=blocking, ) try: setattr(self, method_name, func) method_names.append(method_name) except Exception as e: raise ValueError(f"Fail to set method_name {method_name}") from e 这使得该方法可以通过 ``WorkerGroup`` 接口调用。 -------------- Step 3: Call Chain ~~~~~~~~~~~~~~~~~~ 上述所有机制确保分布式调用与单进程调用感觉完全一致。在原始单进程脚本中,代码如下: .. code:: python rollout = Rollout() rollout.generate_sequences(batch) 使用 ``verl`` 时,多进程程序变为: .. code:: python rollout = RayWorkerGroup(resource_pool=[4], RayClassWithArgs(Rollout)) rollout.generate_sequences(batch) .. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/call_generate_sequences.png?raw=true :alt: call_chain_of_generate_sequences call_chain_of_generate_sequences 在这个简单的调用背后: - ``dispatch_fn`` 将输入拆分到工作者间 - ``execute_fn`` 执行实际的远程调用 - ``collect_fn`` 收集结果 这一切都被抽象化,使得开发者只需对现有逻辑进行最小修改即可编写分布式代码。 -------------- Beyond RL Post-Training: Generalizing ``verl.single_controller`` ---------------------------------------------------------------- ``verl.single_controller`` 模块不仅适用于强化学习,还适用于更广泛的场景。它提供了一种简洁的抽象,用于批量处理远程方法调用,并自动处理输入/输出。 通过最小化单进程脚本与多进程脚本间的差距,``verl.single_controller`` 为更广泛的分布式计算领域开辟了大门 —— 而不仅仅是 RL 后训练。 我们希望这个设计能激发社区中更多示例和扩展。