diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py
index 5c2f489e9b53..94f9a1375c14 100644
--- a/python/sglang/srt/layers/quantization/unquant.py
+++ b/python/sglang/srt/layers/quantization/unquant.py
@@ -22,7 +22,7 @@
     LinearMethodBase,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.utils import MultiPlatformOp
+from sglang.srt.layers.utils import MultiPlatformOp, copy_or_rebind_param
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -233,14 +233,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # because aiter CK kernels don't support all GEMM dimensions
         _should_use_aiter_moe = _use_aiter and get_moe_runner_backend().is_auto()
         if _should_use_aiter_moe:
-            layer.w13_weight = torch.nn.Parameter(
-                shuffle_weight(layer.w13_weight.data, (16, 16)),
-                requires_grad=False,
+            copy_or_rebind_param(
+                layer, "w13_weight", shuffle_weight(layer.w13_weight.data, (16, 16))
             )
             torch.cuda.empty_cache()
-            layer.w2_weight = torch.nn.Parameter(
-                shuffle_weight(layer.w2_weight.data, (16, 16)),
-                requires_grad=False,
+            copy_or_rebind_param(
+                layer, "w2_weight", shuffle_weight(layer.w2_weight.data, (16, 16))
             )
             torch.cuda.empty_cache()