Add models with the FSDP backend
================================

Last updated: 02/09/2025.

Model
--------------------------

In principle, our FSDP backend can support any HF model, and we can use `hf_weight_loader.py` under `third_party/vllm` to synchronize the actor model weights with vLLM. However, ``hf_weight_loader`` gathers the full state_dict of the model during synchronization, which may cause OOM. We recommend using ``dtensor_weight_loader`` instead, which gathers the full model parameters layer by layer to reduce the peak memory usage. We already support the dtensor weight loader for the following models in `dtensor_weight_loader.py` under `third_party/vllm`:

- ``GPT2LMHeadModel``
- ``LlamaForCausalLM``
- ``LLaMAForCausalLM``
- ``MistralForCausalLM``
- ``InternLMForCausalLM``
- ``AquilaModel``
- ``AquilaForCausalLM``
- ``Phi3ForCausalLM``
- ``GemmaForCausalLM``
- ``Gemma2ForCausalLM``
- ``GPTBigCodeForCausalLM``
- ``Starcoder2ForCausalLM``
- ``Qwen2ForCausalLM``
- ``DeepseekV2ForCausalLM``

To implement ``dtensor_weight_loader`` for a model supported by vLLM, follow the guide for the gemma model below:

1. Copy the ``load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]])`` function from the vLLM model class into ``dtensor_weight_loaders.py``.
2. Change the arguments to ``(actor_weights: Dict, vllm_model: nn.Module)``.
3. Replace ``self`` with ``vllm_model``.
4. Add ``local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)`` before each ``param = params_dict[name]`` and modify the subsequent weight loading to use ``local_loaded_weight``.
5. Register the implemented dtensor weight loader in ``__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__`` (see the registration sketch after the diff below).

.. code-block:: diff

   - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
   + def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
   -     params_dict = dict(self.named_parameters())
   +     params_dict = dict(vllm_model.named_parameters())
         loaded_params = set()
   -     for name, loaded_weight in weights:
   +     for name, loaded_weight in actor_weights.items():
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
                 name = name.replace(shard_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
   +             local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
   -             weight_loader(param, loaded_weight, shard_id)
   +             weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
                 break
             else:
                 # lm_head is not used in vllm as it is tied with embed_token.
                 # To prevent errors, skip loading lm_head.weight.
                 if "lm_head.weight" in name:
                     continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
   +             local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
   -             weight_loader(param, loaded_weight)
   +             weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
                 "Some weights are not initialized from checkpoints: "
                 f"{unloaded_params}")
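
For step 5, registering the loader amounts to adding one entry to the module-level registry so it can be looked up by the vLLM model class name at weight-sync time. Below is a minimal sketch, assuming the registry is a plain ``dict`` keyed by the architecture names listed above and that the loader functions (e.g. ``gemma_dtensor_weight_loader``) are defined in the same ``dtensor_weight_loaders.py``; the dispatch helper ``load_dtensor_weights`` shown here is illustrative, and its exact name and signature may differ in your version of the file.

.. code-block:: python

    from typing import Dict

    import torch.nn as nn

    # Assumed registry layout: vLLM model class name -> dtensor weight loader.
    # The loader functions are expected to be defined earlier in this file.
    __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
        "LlamaForCausalLM": llama_dtensor_weight_loader,
        "GemmaForCausalLM": gemma_dtensor_weight_loader,  # the loader implemented above
        # ... other supported architectures ...
    }


    def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
        # Illustrative dispatch: pick the registered loader by the vLLM model class name
        # and use it to copy the (redistributed) actor weights into the vLLM model.
        arch = type(vllm_model).__name__
        if arch not in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
            raise ValueError(f"No dtensor weight loader registered for {arch}")
        __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch](actor_weights, vllm_model)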