diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index e7a831399547..6b3620f12174 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -385,6 +385,14 @@ def latency_test( 8, # shorter decoding to speed up the warmup server_args.device, ) + + try: + from gemlite.core import GemLiteLinearTriton + import os, pwd + GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + except ImportError: + pass + rank_print("Benchmark ...") # Run the sweep diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 910309da973f..ced19e096da0 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -25,6 +25,7 @@ def apply_torchao_config_to_model( int8_weight_only, quantize_, ) + from torchao.quantization.quant_api import gemlite_uintx_weight_only from torchao.quantization.observer import PerRow, PerTensor if filter_fn is None: @@ -47,6 +48,20 @@ def filter_fn(module, fqn): 256, ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn) + elif "gemlite" in torchao_config: + _quant_args = torchao_config.split("-") + bit_width = int(_quant_args[-2]) + group_size = None if _quant_args[-1] == 'None' else int(_quant_args[-1]) + packing_bitwidth = int(_quant_args[-3]) + quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth), filter_fn=filter_fn) + + import pwd + import os + from gemlite.core import GemLiteLinearTriton + try: + GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + except: + pass elif "fp8wo" in torchao_config: # this requires newer hardware # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89