diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
index 23450d264f4..22d95411bc1 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
@@ -1397,11 +1397,7 @@ def load_weights(self,
         assert len(weights) == 1
         weights = weights[0]
 
-        self.quant_method.load_weights(
-            self,
-            weights,
-            self.weight_loading_mode,
-            allow_partial_loading=allow_partial_loading)
+        self.quant_method.load_weights(self, weights, self.weight_loading_mode)
 
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index 5e80d4840cf..97f5a22a88b 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -332,17 +332,17 @@ def load_weights(self,
                      weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode,
                      allow_partial_loading: bool = False):
 
+        additional_kargs = {}
+        if "allow_partial_loading" in inspect.getfullargspec(
+                self.load_expert_weights_to_dst).args:
+            additional_kargs["allow_partial_loading"] = allow_partial_loading
         self.load_expert_weights_to_dst(
-            module,
-            weights,
-            weight_loading_mode,
-            module.initial_local_expert_ids,
-            module.w3_w1_weight.data,
+            module, weights, weight_loading_mode,
+            module.initial_local_expert_ids, module.w3_w1_weight.data,
             module.w2_weight.data,
             module.w3_w1_bias.data if module.bias else None,
-            module.w2_bias.data if module.bias else None,
-            allow_partial_loading=allow_partial_loading)
+            module.w2_bias.data if module.bias else None, **additional_kargs)
 
         self.load_quant_scales(module, weights)
 
@@ -397,15 +397,12 @@ def load_weights(self,
             setattr(module, 'local_shared_w2_bias_tensors',
                     local_shared_w2_bias_tensors)
         self.load_expert_weights_to_dst(
-            module,
-            weights,
-            weight_loading_mode,
-            local_shared_load_expert_ids,
-            local_shared_w3_w1_tensors,
+            module, weights, weight_loading_mode,
+            local_shared_load_expert_ids, local_shared_w3_w1_tensors,
             local_shared_w2_tensors,
             local_shared_w3_w1_bias_tensors if module.bias else None,
             local_shared_w2_bias_tensors if module.bias else None,
-            allow_partial_loading=allow_partial_loading)
+            **additional_kargs)
 
     def post_load_weights(self, module: torch.nn.Module):
         if self.need_load_shared_weights(module):