Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions vllm/model_executor/layers/quantization/inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn.functional as F
from torch.nn.parameter import Parameter

+from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
LinearBase,
Expand Down Expand Up @@ -724,16 +725,12 @@ def create_weights(
shard_width = getattr(
layer, "input_size_per_partition", input_size_per_partition
)
-shard_offset = qweight.tp_rank * shard_width
-g_idx = torch.tensor(
-[
-(shard_offset + i) // self.group_size
-for i in range(input_size_per_partition)
-],
-dtype=torch.int32,
-)
+shard_offset = get_tensor_model_parallel_rank() * shard_width
+g_idx = (
+torch.arange(input_size_per_partition, dtype=torch.int32) + shard_offset
+) // self.group_size
layer.register_parameter("g_idx", Parameter(g_idx, requires_grad=False))
-layer.register_buffer("_inc_tail_dequant_weight", None, persistent=False)
+layer._inc_tail_dequant_weight = None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.sym:
Expand Down
Loading