diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index a1cf43328bab..1547a016ab88 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -634,7 +634,7 @@ def row_parallel_weight_loader(param: torch.Tensor,
     return default_weight_loader(param, loaded_weight)
 
 
-LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
 
 
 def sharded_weight_loader(shard_axis: int) -> LoaderFunction: