python/sglang/srt/layers/linear.py (17 changes: 12 additions & 5 deletions)

@@ -421,11 +421,18 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(
-            loaded_weight,
-            tp_rank=self.tp_rank,
-            use_presharded_weights=self.use_presharded_weights,
-        )
+
+        from sglang.srt.layers.parameter import _ColumnvLLMParameter
+
+        if isinstance(param, _ColumnvLLMParameter):
+            # FIXME: why would we need this special case?
+            param.load_column_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            param.load_column_parallel_weight(loaded_weight)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
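For context on the branch above: sglang's _ColumnvLLMParameter shards the checkpoint tensor while loading, so its loader needs tp_rank and use_presharded_weights, whereas a vanilla vLLM parameter exposes only load_column_parallel_weight(loaded_weight). A minimal sketch with simplified stand-ins for both classes (the narrow-based sharding mirrors the idea, not the exact sglang implementation):

import torch
from torch.nn import Parameter


class _ColumnvLLMParameter(Parameter):
    """Stand-in: shards the full checkpoint tensor itself during loading."""

    def load_column_parallel_weight(self, w, tp_rank=0, use_presharded_weights=False):
        if not use_presharded_weights:
            shard = self.shape[0]
            w = w.narrow(0, tp_rank * shard, shard)  # this rank's output slice
        self.data.copy_(w)


class PlainColumnParameter(Parameter):
    """Stand-in: a vanilla vLLM parameter whose loader takes only the tensor."""

    def load_column_parallel_weight(self, w):
        self.data.copy_(w)


full = torch.randn(8, 4)                     # full checkpoint weight
p = _ColumnvLLMParameter(torch.empty(4, 4))  # tp_size=2; rank 1 owns rows 4:8
p.load_column_parallel_weight(full, tp_rank=1)
assert torch.equal(p.data, full[4:8])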
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (2 changes: 2 additions & 0 deletions)

@@ -298,7 +298,9 @@ def __init__(
             layer=self,
             num_experts=num_experts,
             hidden_size=hidden_size,
+            # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
+            intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
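Passing both keyword spellings looks like a version-compatibility trick: whichever name the installed vLLM's create_weights does not declare should fall into its **extra_weight_attrs catch-all instead of raising a TypeError. A hedged sketch under that assumption (illustrative signature, not vLLM's actual code):

import torch


def create_weights(layer, num_experts, hidden_size,
                   intermediate_size_per_partition, params_dtype,
                   weight_loader=None, **extra_weight_attrs):
    # The legacy spelling lands here rather than breaking the call.
    return sorted(extra_weight_attrs)


leftover = create_weights(
    layer=None, num_experts=8, hidden_size=4096,
    intermediate_size=1792, intermediate_size_per_partition=1792,
    params_dtype=torch.bfloat16,
)
assert leftover == ["intermediate_size"]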
python/sglang/srt/layers/quantization/__init__.py (56 changes: 51 additions & 5 deletions)

@@ -1,10 +1,13 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-from typing import Dict, Type
+from typing import Callable, Dict, Optional, Type

 import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+from vllm.model_executor.layers.quantization.awq_marlin import (
+    AWQMarlinConfig,
+    AWQMoEMethod,
+)
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
@@ -73,21 +76,61 @@ def gptq_get_quant_method(self, layer, prefix):


 def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
     from vllm.model_executor.layers.quantization.awq_marlin import (
         AWQMarlinLinearMethod,
+        AWQMoEMethod,
     )

-    from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

-    if isinstance(layer, LinearBase):
+    if isinstance(layer, LinearBase) or (
+        isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+    ):
+        if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            return UnquantizedLinearMethod()
         return AWQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return AWQMoEMethod(self)
     return None


+original_awq_moe_method_apply = AWQMoEMethod.apply
+
+
+def awq_moe_method_apply(
+    self,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    renormalize: bool,
+    use_grouped_topk: bool = False,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    return original_awq_moe_method_apply(
+        self,
+        layer,
+        x,
+        router_logits,
+        top_k,
+        renormalize,
+        use_grouped_topk,
+        topk_group,
+        num_expert_group,
+        custom_routing_function,
+        scoring_func,
+        e_score_correction_bias,
+    )


 def patch_vllm_linear_base_isinstance():
     import builtins
@@ -107,8 +150,11 @@ def patched_isinstance(obj, classinfo):

 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
+    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
+    setattr(AWQMoEMethod, "apply", awq_moe_method_apply)


 patch_vllm_linear_base_isinstance()
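The apply shim above is a standard adapter: capture the original unbound method, then install a wrapper whose **kwargs absorbs keyword arguments the wrapped vLLM signature does not accept. A self-contained sketch of the same pattern, using a made-up stand-in class rather than vLLM's AWQMoEMethod:

class DemoMoEMethod:  # stand-in, not vLLM's class
    def apply(self, x, top_k):
        return (x, top_k)


_original_apply = DemoMoEMethod.apply


def _patched_apply(self, x, top_k, **kwargs):
    # Extra keywords from newer call sites are absorbed and dropped here,
    # so the wrapped implementation sees only what it understands.
    return _original_apply(self, x, top_k)


DemoMoEMethod.apply = _patched_apply
assert DemoMoEMethod().apply("hidden", top_k=2, activation="silu") == ("hidden", 2)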
python/sglang/srt/models/deepseek_v2.py (4 changes: 4 additions & 0 deletions)

@@ -255,6 +255,8 @@ def __init__(
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            # FIXME: quick fix for skip quantization
+            prefix=f"self_attn.kv_a_proj_with_mqa",
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -455,6 +457,8 @@ def __init__(
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            # FIXME: quick fix for skip quantization
+            prefix=f"self_attn.kv_a_proj_with_mqa",
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
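Why the hard-coded prefix helps: the patched awq_get_quant_method above calls is_layer_skipped_awq(prefix, self.modules_to_not_convert), so kv_a_proj_with_mqa is only left unquantized when its prefix matches an entry in the checkpoint's modules_to_not_convert. A sketch of that check, assuming vLLM's substring-based matching (treat this as an approximation, not the exact implementation):

def is_layer_skipped_awq(prefix, modules_to_not_convert):
    # A layer is skipped if any listed module name occurs in its prefix.
    return any(module_name in prefix for module_name in modules_to_not_convert)


assert is_layer_skipped_awq("self_attn.kv_a_proj_with_mqa", ["kv_a_proj_with_mqa"])
assert not is_layer_skipped_awq("self_attn.kv_b_proj", ["kv_a_proj_with_mqa"])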