SGLang 后端
最后更新日期:05/31/2025。
作者:SGLang RL 团队,按姓氏字母顺序排列
Jingyi Chen, Yitong Guan, Zhuobin Huang, Jiajun Li, Ji Li, Shenggui Li, Junrong Lin, Xiang Long, Rui Lu, Jin Pan, Shuai Shi, Yushen Su, Xinyuan Tong, Chendong Wang, Hanchen Zhang, Haoran Wang, Yongan Xiang, Chengxing Xie, Yuhao Yang, Jinwei Yao, Qiaolin Yu, Yuzhen Zhou, Chenyang Zhao
简介
SGLang 是一个开源的高级推理服务引擎,已被 xAI 全面采用,用于在研究和提供服务过程中支持 Grok 的所有推理需求。
目前,verl 完全支持在 rollout 阶段将 SGLang 用作推理引擎。作为 rollout 引擎,SGLang 提供了与 vLLM 相同的功能覆盖范围,包括内存节省和多节点 rollout 功能。安装 verl 和 SGLang 后,只需在启动脚本中添加 actor_rollout_ref.rollout.name=sglang 即可无缝切换这两种推理框架。
此外,SGLang 团队正积极致力于支持多轮代理强化学习、视觉语言模型 RLHF(基于人类反馈的强化学习)、基于服务器的 RLHF 以及部分 rollout 等功能。您可以在 跟踪路线图 中查看相关开发进展。
安装
请始终遵循以下命令来安装带有 verl 的 SGLang。
pip install --upgrade pip
# 当前版本 0.4.8,随时可能更新,请参考 setup.py 中指定的最新版本
pip install -e ".[sglang]"
您可以检查以下依赖项是否在您的环境中:
Note
PyTorch:2.6.0+cu124
CUDA:12.4
flashinfer-python:0.2.5+cu124torch2.6
SGLang:0.4.6.post5
sgl-kernel:0.1.4
在单机上使用 SGLang 作为 PPO 训练的推理后端
我们使用 Qwen/Qwen2-7B-Instruct 在 gsm8k 数据集上进行简单测试。
运行以下命令来准备 gsm8k 数据集:
python3 examples/data_preprocess/gsm8k.py
运行以下脚本,在单机上使用 4 个 GPU 进行 PPO 实验:
export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=4096 \
data.max_prompt_length=4096 \
data.max_response_length=4096 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=1e-5 \
critic.model.path=Qwen/Qwen2-7B-Instruct \
critic.ppo_micro_batch_size_per_gpu=4 \
critic.model.fsdp_config.param_offload=True \
critic.model.fsdp_config.optimizer_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=console \
trainer.val_before_train=False \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_epochs=15 2>&1 | tee verl_demo.log
为什么要导出 SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK?
verl在 rollout( rollout 指的是模型推理与生成过程)期间会初始化一个SGLangRollout模块,用于评估或生成样本。SGLangRollout会初始化 ``Engine``(引擎),进而初始化一个 ``torch.distributed.DeviceMesh``(PyTorch 分布式设备网格),用于支持张量并行 (TP)。DeviceMesh.init()在内部会检查所有参与设备的空闲 GPU 内存。如果差异过大(大约超过 10%),它会直接报错,以避免初始化失败或死锁。
为什么 GPU 内存可能不一致?
1. Ray 分布式 Actor 会在不同时间加载模型
verl 使用基于 Ray 的多进程、多 GPU 并发训练。每个 WorkerDict 可能在不同时间被调用:
self.rollout = SGLangRollout(...)
不同的工作者(workers)在不同时间初始化模型 → 导致不同的内存使用情况。
2. 延迟初始化导致内存偏差
某些工作者会比其他工作者更早开始模型加载或推理(例如,``generate_sequences()``(生成序列)、``compute_log_prob()``(计算对数概率))。 早期工作者已经占用了 GPU 内存 → 晚期工作者仍保留空闲内存 → 出现内存差异。
3. SGLang 的 TP 初始化使用“全设备广播”,但释放时机不统一
尽管 SGLangRollout 可能仅涉及部分 GPU,但其 Engine 初始化会调用 torch.distributed.init_process_group() 并广播权重,因此:
非 rollout GPU 也会加入通信。
后来
DeviceMesh初始化会因“内存不一致”而失败。
4. FSDP/TP 加载行为不同也会导致不匹配
如果使用:
actor.fsdp_config.param_offload=True
ref.fsdp_config.param_offload=True
那么某些工作进程会将参数保留在 CPU 上,而其他工作进程已将参数分片到 GPU → 导致不对称的内存布局。
在多机上使用 SGLang 作为 PPO 训练的推理后端
SGLang 还支持在 IPv4 和 IPv6 场景下运行 verl 的基于 RAY 的跨机推理。在下面的脚本中,我们使用 TP=16(张量并行度为 16)进行跨机推理。假设我们有两台互连的机器:node0 IP 为 10.94.16.4,node1 IP 为 10.94.16.5。
在 node0 上启动 Ray:
ray start --head --dashboard-host=0.0.0.0
您将看到以下提示:
使用情况统计已启用。要禁用此功能,请在启动集群的命令中添加 `--disable-usage-stats`,或在启动集群之前运行以下命令:`ray disable-usage-stats`。有关详细信息,请参见 https://docs.ray.io/en/master/cluster/usage-stats.html。
本地节点 IP:10.94.16.4
--------------------
Ray 运行时已启动。
--------------------
后续步骤
要向此 Ray 集群添加另一个节点,请运行
ray start --address='10.94.16.4:6379'
让 node1 加入 Ray 集群:
在 node1 上运行以下命令:
ray start --address='10.94.16.4:6379'
运行以下命令以确认 Ray 集群现在有两台机器:
ray status
您可以看到集群有两台机器,共 16 个 GPU:
======== Autoscaler 状态:2025-04-09 09:25:37.694016 ========
节点状态
---------------------------------------------------------------
活跃:
1 node_ef382ffd687d8f6b060c1b68e63ada7341b936fe5b1901dd04de1027
1 node_1eb4d7d07e793114c23a89d1a41f1f76acf6ef5b35af844a4ee8e4ba
待处理:
(no pending nodes)
最近失败:
(no failures)
资源
---------------------------------------------------------------
使用:
0.0/360.0 CPU
0.0/16.0 GPU
0B/3.39TiB 内存
0B/372.53GiB 对象存储内存
运行以下脚本,使用 TP=16 在 2 台机器上通过 16 个 GPU 训练 meta-llama/Llama-3.1-8B-Instruct:
DATA_DIR=$HOME/data/gsm8k
python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.name=sglang \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
data.train_batch_size=4096 \
data.max_prompt_length=4096 \
data.max_response_length=4096 \
actor_rollout_ref.model.path=meta-llama/Llama-3.1-8B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=16 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.free_cache_engine=True \
actor_rollout_ref.ref.log_prob_micro_batch_size=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=meta-llama/Llama-3.1-8B-Instruct \
critic.model.enable_gradient_checkpointing=True \
critic.ppo_micro_batch_size=16 \
critic.model.fsdp_config.param_offload=True \
critic.model.fsdp_config.optimizer_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=console \
trainer.val_before_train=True \
trainer.n_gpus_per_node=8 \
trainer.nnodes=2 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_epochs=15 2>&1 | tee verl_demo.log