6 changes: 5 additions & 1 deletion python/sglang/srt/compilation/piecewise_context_manager.py
@@ -17,6 +17,9 @@ def set_forward_batch(self, forward_batch: ForwardBatch):
def set_attention_layers(self, layers: List[Any]):
self.attention_layers = layers

def set_quant_config(self, quant_config: Any):
self.quant_config = quant_config


_forward_context: Optional[ForwardContext] = None

@@ -28,12 +31,13 @@ def get_forward_context() -> Optional[ForwardContext]:


@contextmanager
def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any]):
def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any], quant_config: Any):
global _forward_context
prev_forward_context = _forward_context
_forward_context = ForwardContext()
_forward_context.set_forward_batch(forward_batch)
_forward_context.set_attention_layers(attention_layers)
_forward_context.set_quant_config(quant_config)
try:
yield
finally:
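For context, a minimal sketch of how a call site adapts to the new three-argument signature; the placeholder values below stand in for the real ForwardBatch, attention layer list, and quant config supplied by the model runner, and are not part of this diff:

from sglang.srt.compilation.piecewise_context_manager import (
    get_forward_context,
    set_forward_context,
)

# Sketch with placeholder arguments; in the model runner these would be the
# real ForwardBatch, the model's attention layers, and its quant config.
forward_batch, attention_layers, quant_config = None, [], object()

with set_forward_context(forward_batch, attention_layers, quant_config):
    ctx = get_forward_context()
    assert ctx is not None and ctx.quant_config is quant_config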
22 changes: 22 additions & 0 deletions python/sglang/srt/layers/quantization/awq.py
@@ -954,3 +954,25 @@ def apply(
use_wna16=True,
)
return StandardCombineInput(hidden_states=output)


# Register fake implementations for torch.compile support
if _is_cuda:

@torch.library.register_fake("sgl_kernel::awq_dequantize")
def _(
qweight,
scales,
qzeros,
ch_axis,
group_size,
num_bits,
):
out_shape = qweight.shape[:-1] + (qweight.shape[-1] * 32 // num_bits,)
return qweight.new_empty(out_shape, dtype=scales.dtype)

@torch.library.register_fake("sgl_kernel::awq_marlin_repack")
def _(b_q_weight, size_k, size_n, num_bits):
return b_q_weight.new_empty(
(size_k // 16, size_n * (num_bits // 2)), dtype=b_q_weight.dtype
)
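For readers unfamiliar with the mechanism: torch.library.register_fake attaches a shape/dtype-only implementation to an already-registered custom op, so torch.compile can trace through it without launching the real CUDA kernel. A self-contained sketch of the pattern on a hypothetical op (not one of the sgl_kernel ops above):

import torch


# Hypothetical op, defined here only to illustrate the register_fake pattern.
@torch.library.custom_op("demo::row_pack", mutates_args=())
def row_pack(x: torch.Tensor, num_bits: int) -> torch.Tensor:
    # Eager stand-in for a real packing kernel: keep the packed number of columns.
    return x[..., : x.shape[-1] * num_bits // 32].contiguous()


@torch.library.register_fake("demo::row_pack")
def _(x, num_bits):
    # Shape-only logic, mirroring what the eager/CUDA implementation produces.
    out_shape = x.shape[:-1] + (x.shape[-1] * num_bits // 32,)
    return x.new_empty(out_shape)


@torch.compile(fullgraph=True)
def packed(x):
    return row_pack(x, 4)


print(packed(torch.randn(2, 64)).shape)  # torch.Size([2, 8])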
47 changes: 47 additions & 0 deletions python/sglang/srt/layers/quantization/gptq.py
@@ -1094,3 +1094,50 @@ def apply(
is_k_full=self.is_k_full,
).to(orig_dtype)
return StandardCombineInput(hidden_states=output)


# Register fake implementations for torch.compile support
if _is_cuda:

@torch.library.register_fake("sgl_kernel::gptq_gemm")
def _(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit):
return a.new_empty((a.shape[0], b_q_weight.shape[-1]), dtype=a.dtype)

@torch.library.register_fake("sgl_kernel::gptq_marlin_repack")
def _(b_q_weight, perm, size_k, size_n, num_bits):
return b_q_weight.new_empty(
(size_k // 16, size_n * (num_bits // 2)), dtype=b_q_weight.dtype
)

@torch.library.register_fake("sgl_kernel::gptq_shuffle")
def _(q_weight, q_perm, bit):
return

@torch.library.register_fake("sgl_kernel::moe_wna16_marlin_gemm")
def _(
a,
c,
b_q_weight,
b_scales,
b_zeros,
g_idx,
perm,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size,
top_k,
mul_topk_weights,
is_ep,
b_q_type_id,
size_m,
size_n,
size_k,
is_k_full,
use_atomic_add,
use_fp32_reduce,
is_zp_float,
):
return c
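A note on the gptq_shuffle fake above: it returns nothing because the op is used for its side effect on q_weight, so there is no output tensor to describe; the mutation itself is declared where the op is registered, and the fake only mirrors the empty return. A self-contained sketch of that pattern on a hypothetical op:

import torch


# Hypothetical in-place op: doubles `buf` and returns nothing.
@torch.library.custom_op("demo::scale_", mutates_args=("buf",))
def scale_(buf: torch.Tensor) -> None:
    buf.mul_(2.0)


@torch.library.register_fake("demo::scale_")
def _(buf):
    # Nothing to allocate; the mutation is tracked through mutates_args.
    return


x = torch.ones(4)
torch.ops.demo.scale_(x)
print(x)  # tensor([2., 2., 2., 2.])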
247 changes: 210 additions & 37 deletions python/sglang/srt/layers/quantization/marlin_utils.py
@@ -31,11 +31,15 @@
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

from sglang.srt.compilation.piecewise_context_manager import get_forward_context

try:
from vllm import _custom_ops as ops
except ImportError:
ops = None

from sglang.srt.utils import direct_register_custom_op

_is_cuda = is_cuda()

if _is_cuda:
@@ -483,25 +487,44 @@ def apply_gptq_marlin_linear(
dtype=input.dtype,
)

output = gptq_marlin_gemm(
reshaped_x,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
wtype,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
forward_context = get_forward_context()
if forward_context is None:
output = gptq_marlin_gemm(
reshaped_x,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
wtype,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
else:
output = torch.ops.sglang.unified_apply_gptq_marlin_gemm_with_wtype(
input=reshaped_x,
weight=weight,
weight_scale=weight_scale,
weight_zp=weight_zp,
g_idx=g_idx,
g_idx_sort_indices=g_idx_sort_indices,
workspace=workspace,
wtype_id=wtype.id,
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)

if bias is not None:
output.add_(bias) # In-place add
@@ -534,24 +557,42 @@ def apply_awq_marlin_linear(
dtype=input.dtype,
)

output = gptq_marlin_gemm(
reshaped_x,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
# print(f"Quant type: {quant_type}")
forward_context = get_forward_context()
if forward_context is None:
output = gptq_marlin_gemm(
reshaped_x,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
else:
output = torch.ops.sglang.unified_apply_gptq_marlin_gemm(
input=reshaped_x,
weight=weight,
weight_scale=weight_scale,
weight_zp=weight_zp,
g_idx=g_idx,
g_idx_sort_indices=g_idx_sort_indices,
workspace=workspace,
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)

if bias is not None:
output.add_(bias) # In-place add
@@ -818,3 +859,135 @@ def apply(
output.add_(bias) # In-place add

return output


def unified_apply_gptq_marlin_gemm(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
output_size_per_partition: int,
input_size_per_partition: int,
use_atomic_add: bool,
use_fp32_reduce: bool,
is_zp_float: bool,
) -> torch.Tensor:
quant_config = get_forward_context().quant_config
quant_type = quant_config.quant_type
return gptq_marlin_gemm(
input,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
quant_type,
size_m=input.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=is_zp_float,
)

def fake_unified_apply_gptq_marlin_gemm(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
output_size_per_partition: int,
input_size_per_partition: int,
use_atomic_add: bool,
use_fp32_reduce: bool,
is_zp_float: bool,
) -> torch.Tensor:
return input.new_empty((input.shape[0], output_size_per_partition), dtype=input.dtype)


direct_register_custom_op(
op_name="unified_apply_gptq_marlin_gemm",
op_func=unified_apply_gptq_marlin_gemm,
mutates_args=[],
fake_impl=fake_unified_apply_gptq_marlin_gemm,
)


def unified_apply_gptq_marlin_gemm_with_wtype(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
wtype_id: int,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
use_atomic_add: bool,
use_fp32_reduce: bool,
is_zp_float: bool,
) -> torch.Tensor:
# Reconstruct the ScalarType from its integer id: ScalarType objects cannot be
# passed through the custom-op boundary, so the caller sends wtype.id instead.
wtype = None
for attr_name in dir(scalar_types):
if not attr_name.startswith('_'):
st = getattr(scalar_types, attr_name)
if hasattr(st, 'id') and st.id == wtype_id:
wtype = st
break
return gptq_marlin_gemm(
input,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
wtype,
size_m=input.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=is_zp_float,
)


def fake_unified_apply_gptq_marlin_gemm_with_wtype(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
wtype_id: int,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
use_atomic_add: bool,
use_fp32_reduce: bool,
is_zp_float: bool,
) -> torch.Tensor:
return input.new_empty((input.shape[0], output_size_per_partition), dtype=input.dtype)


direct_register_custom_op(
op_name="unified_apply_gptq_marlin_gemm_with_wtype",
op_func=unified_apply_gptq_marlin_gemm_with_wtype,
mutates_args=[],
fake_impl=fake_unified_apply_gptq_marlin_gemm_with_wtype,
)
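One possible refinement, not part of this PR: the dir()-based scan in unified_apply_gptq_marlin_gemm_with_wtype runs on every call, and since only a handful of ScalarType ids ever reach the op it could be memoized. A sketch under that assumption, where scalar_types refers to the module marlin_utils.py already imports:

from functools import lru_cache


@lru_cache(maxsize=None)
def _scalar_type_from_id(wtype_id: int):
    # Hypothetical helper: scan the known scalar_types once per distinct id
    # and cache the result for later calls.
    for attr_name in dir(scalar_types):
        if attr_name.startswith("_"):
            continue
        st = getattr(scalar_types, attr_name)
        if hasattr(st, "id") and st.id == wtype_id:
            return st
    raise ValueError(f"Unknown ScalarType id: {wtype_id}")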