The Design of verl.single_controller
Last updated: 05/21/2025.
Author: Wang Zhang
Preface
我们为 verl 的开发者准备了这份文档,特别是那些希望了解或为 verl.single_controller 模块做出贡献的人。这份文档并非面向最终用户,而是为那些希望理解架构设计 rationale (原因)以及内部机制的贡献者编写的。
Origin
single_controller 模块源于我收到的一项请求 —— 将一个玩具级的单进程 RLHF(Reinforcement Learning from Human Feedback)脚本适配成分布式系统,并尽可能少的修改代码,同时保持调试的便利性。
常见的做法 —— 例如使用 PyTorch 的 Distributed Data Parallel (DDP) —— 通常涉及包装 nn.Module 并启动多个进程,这些进程在不同 rank 下执行相同的函数。然而,这种方法在分布式 RLHF 上下文中存在两个主要局限: - 难以表示 PPO(Proximal Policy Optimization)所需的多个 DAG(Directed Acyclic Graph,有向无环图); - 难以在训练期间检查中间张量。
为了保持调试能力,我们选择了一种不同的方法 —— 将训练循环分解为明确定义的阶段,例如 generate_sequences、 compute_advantages 等。
我们选择了 Ray Ray 作为 verl 的初始后端,因为它能够将 Python 类方法公开为 RPC 端点。然而,Ray 的默认模型仅支持 “一条方法调用,一次 RPC”,而训练大语言模型通常需要跨多个进程的协调。
为了向用户隐藏单个方法的多 Ray 演员(actors)调用,我们引入了以下组件:
WorkerGroup—— 管理一组远程工作者,并为多进程分布式计算提供统一的接口;ResourcePool—— 将计算资源绑定到工作者进程;ClassWithArgs—— 支持延迟远程实例化,并指定初始化参数。
A Running Example: generate_sequences
为了说明设计思路,我们将逐步讲解 ActorRolloutRefWorker 类中的 generate_sequences 方法如何在分布式工作者间进行注册和调用。
Step 1: Register with a Decorator
第一步是定义 generate_sequences 方法,并使用 @register 装饰器进行装饰,因为它将在驱动脚本中被调用。
Source: fsdp_workers.py
class ActorRolloutRefWorker(Worker):
...
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
prompts = prompts.to(torch.cuda.current_device())
...
@register 装饰器为 generate_sequences 方法添加了元数据。目前,它不会改变功能,但会通过一个魔法键(MAGIC_ATTR)附加属性:
Source: decorator.py
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
...
def decorator(func):
@wraps(func)
def inner(*args, **kwargs):
if materialize_futures:
args, kwargs = _materialize_futures(*args, **kwargs)
return func(*args, **kwargs)
attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
setattr(inner, MAGIC_ATTR, attrs)
return inner
return decorator
如代码所示,dispatch_mode、execute_mode 和 blocking 的值被附加到 generate_sequences 方法上。
Step 2: Binding During Initialization
在 ActorRolloutRefWorker 被包装在 RayClassWithArgs 中并传入 RayWorkerGroup 时,这些附加属性会被提取并使用。
Source: main_generation.py
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
在 RayWorkerGroup
的初始化过程中
,会发生两个关键步骤:
创建工作者实例(Ray 演员): RayWorkerGroup._init_with_resource_pool
将使用
@register装饰的方法绑定到RayWorkerGroup: RayWorkerGroup._bind_worker_method
initialization_and_binding_of_worker_group
绑定过程是 verl.single_controller 的核心。
Key function: WorkerGroup._bind_worker_method
def _bind_worker_method(self, user_defined_cls, func_generator):
...
for method_name in dir(user_defined_cls):
try:
method = getattr(user_defined_cls, method_name)
assert callable(method)
except Exception:
continue # Skip properties
<<<to be continue 1>>>
如果方法具有 MAGIC_ATTR,则会提取 @register 设置的属性:
<<<continue 1>>>
if hasattr(method, MAGIC_ATTR):
attribute = getattr(method, MAGIC_ATTR)
dispatch_mode = attribute["dispatch_mode"]
execute_mode = attribute["execute_mode"]
blocking = attribute["blocking"]
<<<to be continue 2>>>
如上方的流程图所示,这些属性会被传入 func_generator。但是,func_generator 需要 method_name、dispatch_fn、collect_fn、execute_fn 和 blocking。我们需要从 DISPATCH_MODE_FN_REGISTRY``DISPATCH_MODE_FN_REGISTRY
中查找与 dispatch_mode (DP_COMPUTE_PROTO)对应的 dispatch_fn 和 collect_fn:
DISPATCH_MODE_FN_REGISTRY = {
Dispatch.ONE_TO_ALL: {
"dispatch_fn": dispatch_one_to_all,
"collect_fn": collect_all_to_all,
},
...
Dispatch.DP_COMPUTE_PROTO: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute_data_proto,
},
...
}
类似地,execute_fn 根据 execute_mode 选择,并通过以下方式提取:
<<<continue 2>>>
# get execute_fn_name
execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
wg_execute_fn_name = execute_mode["execute_fn_name"]
# get execute_fn from string
try:
execute_fn = getattr(self, wg_execute_fn_name)
assert callable(execute_fn), "execute_fn must be callable"
except Exception:
print(f"execute_fn {wg_execute_fn_name} is invalid")
raise
<<<to be continue 3>>>
在 generate_sequences 的情况下: -
dispatch_mode = Dispatch.DP_COMPUTE_PROTO -
dispatch_fn = dispatch_dp_compute_data_proto -
collect_fn = collect_dp_compute_data_proto -
execute_fn = RayWorkerGroup.execute_all
ONE_TO_ALL v.s. DP_COMPUTE_PROTO
dispatch_mode 会关联一个 dispatch_fn 和一个 collect_fn。顾名思义,dispatch_fn 在 WorkerGroup 中处理输入参数,并生成一批(列表)输入参数,每个参数都会送入绑定到 WorkerGroup 的工作者。
ONE_TO_ALL 的 dispatch_fn 是
dispatch_one_to_all
,它只是将所有输入参数复制成 N 个副本,其中 N 等于绑定到 worker_group 的工作者数量:
def dispatch_one_to_all(worker_group, *args, **kwargs):
args = tuple([arg] * worker_group.world_size for arg in args)
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
return args, kwargs
DP_COMPUTE_PROTO 的 dispatch_fn 是
dispatch_dp_compute_data_proto
,它使用 DataProto.chunk 将一个大的 DataProto 拆分成 N 个较小的 DataProto,其中 N 等于 worker_group 的 world_size(工作者数量):
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
# Note: enable auto padding for dp compute DatapProto
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(
worker_group.world_size,
*args,
**kwargs,
)
return splitted_args, splitted_kwargs
collect_fn 遵循相同的模式,处理来自 WorkerGroup 所有工作者的批量(列表)返回值,并像 collect_all_to_all 那样将其合并成一个列表,或像 collect_dp_compute_data_proto 那样合并成一个大的 DataProto。
最后,使用 func_generator 动态生成一个新方法,并将其添加到 WorkerGroup 实例中:
<<<continue 3>>>
# bind a new method to the RayWorkerGroup
func = func_generator(
self,
method_name,
dispatch_fn=dispatch_fn,
collect_fn=collect_fn,
execute_fn=execute_fn,
blocking=blocking,
)
try:
setattr(self, method_name, func)
method_names.append(method_name)
except Exception as e:
raise ValueError(f"Fail to set method_name {method_name}") from e
这使得该方法可以通过 WorkerGroup 接口调用。
Step 3: Call Chain
上述所有机制确保分布式调用与单进程调用感觉完全一致。在原始单进程脚本中,代码如下:
rollout = Rollout()
rollout.generate_sequences(batch)
使用 verl 时,多进程程序变为:
rollout = RayWorkerGroup(resource_pool=[4], RayClassWithArgs(Rollout))
rollout.generate_sequences(batch)
call_chain_of_generate_sequences
在这个简单的调用背后: - dispatch_fn 将输入拆分到工作者间 - execute_fn 执行实际的远程调用 - collect_fn 收集结果
这一切都被抽象化,使得开发者只需对现有逻辑进行最小修改即可编写分布式代码。
Beyond RL Post-Training: Generalizing verl.single_controller
verl.single_controller 模块不仅适用于强化学习,还适用于更广泛的场景。它提供了一种简洁的抽象,用于批量处理远程方法调用,并自动处理输入/输出。
通过最小化单进程脚本与多进程脚本间的差距,verl.single_controller 为更广泛的分布式计算领域开辟了大门 —— 而不仅仅是 RL 后训练。
我们希望这个设计能激发社区中更多示例和扩展。