diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index 3f9ca3dee2b..2121056211f 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
 
+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
-    total, _ = qkv_offsets["total"]
-    orig_offset, orig_size = qkv_offsets[loaded_shard_id]
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]
 
     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
@@ -573,6 +574,8 @@ def weight_loader(
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
+
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -585,6 +588,17 @@ def weight_loader(
                         param, shard_size, shard_offset
                     )
 
+                if use_bitsandbytes_4bit:
+                    index = list(itertools.accumulate([0] + self.output_sizes))
+                    orig_offsets = {
+                        str(i): (index[i], size)
+                        for i, size in enumerate(self.output_sizes)
+                    }
+                    orig_offsets["total"] = (self.output_size, 0)
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_offsets, str(shard_id)
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 57707c3499d..008e542048a 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
     column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+        ".gate_proj": (".gate_up_proj", 0),
+        ".up_proj": (".gate_up_proj", 1),
     }
 
     def __init__(
diff --git a/test/srt/models/test_unsloth_models.py b/test/srt/models/test_unsloth_models.py
new file mode 100644
index 00000000000..24660ea34fc
--- /dev/null
+++ b/test/srt/models/test_unsloth_models.py
@@ -0,0 +1,213 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestUnslothPhi4(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "unsloth/phi-4"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.78)
+
+
+class TestUnslothPhi4Bnb4bit(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "unsloth/phi-4-bnb-4bit"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--load-format",
+                "bitsandbytes",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.75)
+
+
+class TestUnslothPhi4UnslothBnb4bit(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "unsloth/phi-4-unsloth-bnb-4bit"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--load-format",
+                "bitsandbytes",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.75)
+
+
+class TestUnslothPhi4MiniInstruct(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "unsloth/Phi-4-mini-instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.65)
+
+
+class TestUnslothPhi4MiniBnb4bit(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "unsloth/Phi-4-mini-instruct-bnb-4bit"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--load-format",
+                "bitsandbytes",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.6)
+
+
+class TestUnslothPhi4MiniUnslothBnb4bit(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--load-format",
+                "bitsandbytes",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.6)
+
+
+if __name__ == "__main__":
+    unittest.main()
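
Note on the linear.py change above: for a merged column-parallel layer whose checkpoint weight is already stored as a flattened bitsandbytes 4-bit tensor, the new branch builds a per-shard offset table with itertools.accumulate and lets adjust_bitsandbytes_4bit_shard rescale each shard's (offset, size) from the unquantized output dimension into the row space of the quantized parameter. The sketch below is illustrative only and is not part of the diff: the layer sizes and the quantized row count are made-up values, and the quantized-size formula is an assumption, since the hunk only shows the quantized-offset line.

import itertools

# Hypothetical merged gate_up_proj with two stacked projections of 11008 rows each.
output_sizes = [11008, 11008]
output_size = sum(output_sizes)

# Same construction as in the diff: running start offsets keyed by shard index,
# plus a "total" entry carrying the unquantized output size.
index = list(itertools.accumulate([0] + output_sizes))
orig_offsets = {str(i): (index[i], size) for i, size in enumerate(output_sizes)}
orig_offsets["total"] = (output_size, 0)

# Hypothetical number of rows in the flattened 4-bit parameter
# (in the diff this comes from param.data.shape[0]).
quantized_total = 2752


def remap(shard_id: str):
    total, _ = orig_offsets["total"]
    orig_offset, orig_size = orig_offsets[shard_id]
    quantized_offset = orig_offset * quantized_total // total  # shown in the diff
    quantized_size = orig_size * quantized_total // total  # assumed, not shown in the hunk
    # Return order mirrors the "shard_size, shard_offset = ..." unpacking at the call site.
    return quantized_size, quantized_offset


print(remap("0"))  # (1376, 0)
print(remap("1"))  # (1376, 1376)

Each shard keeps its relative position and proportion of the output dimension, just expressed in quantized rows, which is what loaded_weight.narrow then uses to slice the checkpoint tensor.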