PyTorch FSDP 后端

最后更新:2025/12/01。

我们通过实现用于 actor、critic、reference、rollout 和 reward 模型的各种 worker 来提供对 PyTorch FSDP 后端的支持。

优势

  • 随时支持各种模型。

    • 用户只需要针对 FSDP 和 vLLM 之间的权重同步实现相应的 dtensor_weight_loader。而对于 hf_weight_loader,用户可以直接应用任何同时在 HF 和 vLLM 中受支持的模型,而无需任何代码更改。

  • 易于组织每个模型的前向和反向计算。

劣势

  • 当处理大规模模型(例如 Llama 70B 和 405B)时,可扩展性较差。

  • actor 和 rollout 之间的重新分片开销可能比 Megatron-LM 后端更大。

鉴于其简单性,我们推荐使用 FSDP 后端进行算法研究和原型验证。

FSDP Worker

ActorRolloutRefWorker

Actor/Rollout HybridEngine

  1. HybridEngine、Actor 和 Rollout 初始化 API。

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):

ONE_TO_ALL:当从驱动进程调用 init_model 函数时,每个 worker(位于 GPU 上)将执行以下模型初始化过程。

HybridEngine、Actor 和 Rollout 的初始化细节如下:

  1. DataParallelPPOActor 实现了当模型基于 FSDP 构建时简单的 PPO 计算逻辑,包括计算对数概率、模型更新。

  2. vLLMRollout 支持使用 vLLM 进行生成。我们修改了 vLLM Engine,使其在 SPMD 下执行,以适应我们的 WorkerGroup 设计。

请参考 源代码 了解更多信息。

  1. 生成序列并重新计算对数概率

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
  • Dispatch.DP_COMPUTE_PROTO:数据将沿着 DP 维度进行分发和收集

  • 在此函数中,rollout 模型将执行自回归生成,而 actor 模型将重新计算生成响应的旧对数概率。

  1. 更新 actor 模型

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
  • 使用 PPO 和熵损失更新 actor 模型权重。

ReferenceModel

  1. 参考模型初始化

参考模型使用与 actor 模型相同的函数进行初始化,但不初始化 HybridEngine 和 Optimizer。然后 actor 模型也被 DataParallelPPOActor 包装。

  1. 计算参考对数概率

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
  • 在此函数中,参考模型将调用 DataParallelPPOActor 中的计算对数概率函数来计算参考对数概率。

CriticWorker 和 RewardWorker

  1. 模型初始化

与参考模型非常相似。CriticWorker 将为 Optimizer 执行额外的初始化。

  1. 为 CriticWorker 计算值

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
  1. 更新 Critic

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
  1. 计算奖励

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):

HybridShard

我们不支持 FSDP HybridShard。要支持此功能,我们可能需要构建一个 2D 设备网格,并为每个模型测试相应的 dtensor_weight_loaderhf_weight_loader