From 98e0effd1974e5afa8851322df1176cff3b6891a Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Sun, 19 Apr 2026 20:12:10 +0900 Subject: [PATCH] Fix INC tail-shard review issues Co-authored-by: OpenAI Codex Signed-off-by: lesj0610 --- vllm/model_executor/layers/quantization/inc.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py index c15998d08bc5..d67bff572a7d 100644 --- a/vllm/model_executor/layers/quantization/inc.py +++ b/vllm/model_executor/layers/quantization/inc.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter +from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.linear import ( LinearBase, @@ -724,16 +725,12 @@ def create_weights( shard_width = getattr( layer, "input_size_per_partition", input_size_per_partition ) - shard_offset = qweight.tp_rank * shard_width - g_idx = torch.tensor( - [ - (shard_offset + i) // self.group_size - for i in range(input_size_per_partition) - ], - dtype=torch.int32, - ) + shard_offset = get_tensor_model_parallel_rank() * shard_width + g_idx = ( + torch.arange(input_size_per_partition, dtype=torch.int32) + shard_offset + ) // self.group_size layer.register_parameter("g_idx", Parameter(g_idx, requires_grad=False)) - layer.register_buffer("_inc_tail_dequant_weight", None, persistent=False) + layer._inc_tail_dequant_weight = None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.sym: