From d3024f4fc8d9956e817bdef8fdbdae26c474b448 Mon Sep 17 00:00:00 2001
From: bjmsong
Date: Sat, 18 Jan 2025 11:43:22 +0800
Subject: [PATCH] support e4m3 kvcache in qwen2 & add kv scaling factor json
 (#2894)

Co-authored-by: bjmsong
---
 .../sglang/srt/model_loader/weight_utils.py | 55 +++++++++++++++++-
 python/sglang/srt/models/llama.py           |  6 +-
 python/sglang/srt/models/qwen2.py           | 36 +++++++++++-
 python/sglang/test/test_utils.py            |  1 +
 test/srt/kv_cache_scales_llama3_8b.json     | 42 ++++++++++++++
 test/srt/kv_cache_scales_qwen2_1_5b.json    | 38 +++++++++++++
 test/srt/run_suite.py                       |  1 +
 test/srt/test_fp8_kvcache.py                | 57 +++++++++++++++++--
 8 files changed, 227 insertions(+), 9 deletions(-)
 create mode 100644 test/srt/kv_cache_scales_llama3_8b.json
 create mode 100644 test/srt/kv_cache_scales_qwen2_1_5b.json

diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index 015c6514530..77c3fcbee74 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -9,7 +9,17 @@
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import filelock
 import gguf
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialized file should represent a dictionary
+    whose keys are the TP ranks and whose values are dictionaries mapping layer
+    indices to their KV cache scaling factors.
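+
+    An abbreviated example of the expected layout, mirroring the test
+    fixtures added by this patch (e.g. test/srt/kv_cache_scales_llama3_8b.json):
+
+        {"model_type": "llama",
+         "kv_cache": {"dtype": "float8_e4m3fn",
+                      "scaling_factor": {"0": {"0": 0.0408, "1": 0.0503}}}}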
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.exception("An error occurred while reading '%s'.", filename)
+    # This point is reached if and only if one of the except blocks above
+    # was hit. Return an empty iterable (list) => no KV cache scales are
+    # loaded, which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 198d53995e4..9ea80d0c05d 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -23,7 +23,6 @@
 from torch import nn
 from transformers import LlamaConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -45,7 +44,10 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index f1d37118a86..04faa8dea1b 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -22,7 +22,10 @@
 from torch import nn
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -39,7 +42,10 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import make_layers
 
 Qwen2Config = None
@@ -265,6 +271,29 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
 
 class Qwen2ForCausalLM(nn.Module):
@@ -373,5 +402,8 @@ def set_embed_and_head(self, embed, head):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
 
 EntryClass = Qwen2ForCausalLM
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 42e0b6d808a..d3c9b7cab5f 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -40,6 +40,7 @@
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
 
 
 def is_in_ci():
diff --git a/test/srt/kv_cache_scales_llama3_8b.json b/test/srt/kv_cache_scales_llama3_8b.json
new file mode 100644
index 00000000000..466b0d01a74
--- /dev/null
+++ b/test/srt/kv_cache_scales_llama3_8b.json
@@ -0,0 +1,42 @@
+{
+  "model_type": "llama",
+  "kv_cache": {
+    "dtype": "float8_e4m3fn",
+    "scaling_factor": {
+      "0": {
+        "0": 0.0408,
+        "1": 0.0503,
+        "2": 0.0667,
+        "3": 0.0909,
+        "4": 0.1135,
+        "5": 0.127,
+        "6": 0.1768,
+        "7": 0.1488,
+        "8": 0.1135,
+        "9": 0.1203,
+        "10": 0.1013,
+        "11": 0.0842,
+        "12": 0.1231,
+        "13": 0.1096,
+        "14": 0.1221,
+        "15": 0.1013,
+        "16": 0.1067,
+        "17": 0.0952,
+        "18": 0.0899,
+        "19": 0.097,
+        "20": 0.087,
+        "21": 0.0994,
+        "22": 0.0904,
+        "23": 0.1013,
+        "24": 0.1019,
+        "25": 0.1053,
+        "26": 0.1,
+        "27": 0.0894,
+        "28": 0.1013,
+        "29": 0.1488,
+        "30": 0.0766,
+        "31": 0.0821
+      }
+    }
+  }
+}
diff --git a/test/srt/kv_cache_scales_qwen2_1_5b.json b/test/srt/kv_cache_scales_qwen2_1_5b.json
new file mode 100644
index 00000000000..984747509f7
--- /dev/null
+++ b/test/srt/kv_cache_scales_qwen2_1_5b.json
@@ -0,0 +1,38 @@
+{
+  "model_type": "qwen2",
+  "kv_cache": {
+    "dtype": "float8_e4m3fn",
+    "scaling_factor": {
+      "0": {
+        "0": 0.9846,
+        "1": 0.0645,
+        "2": 0.0731,
+        "3": 0.0800,
+        "4": 0.0748,
+        "5": 0.0780,
+        "6": 0.0702,
+        "7": 0.0894,
+        "8": 0.0410,
+        "9": 0.0758,
+        "10": 0.0556,
+        "11": 0.0731,
+        "12": 0.0899,
+        "13": 0.0780,
+        "14": 0.1441,
+        "15": 0.0914,
+        "16": 0.5614,
+        "17": 0.1067,
+        "18": 0.0537,
+        "19": 0.0658,
+        "20": 0.0523,
+        "21": 0.0533,
+ "22": 0.0699, + "23": 0.0635, + "24": 0.0588, + "25": 0.0884, + "26": 0.0947, + "27": 0.1032 + } + } + } +} diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e2ecdfb3a68..fb1c6abf29b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -52,6 +52,7 @@ "test_vision_openai_server.py", "test_w8a8_quantization.py", "test_session_control.py", + "test_fp8_kvcache.py", ], "nightly": [ "test_nightly_gsm8k_eval.py", diff --git a/test/srt/test_fp8_kvcache.py b/test/srt/test_fp8_kvcache.py index 0d6602997de..4a8a2434699 100644 --- a/test/srt/test_fp8_kvcache.py +++ b/test/srt/test_fp8_kvcache.py @@ -6,19 +6,26 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) -class TestFp8Kvcache(unittest.TestCase): +class TestFp8KvcacheBase(unittest.TestCase): + model_config = None + @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + if cls.model_config is None: + raise NotImplementedError("model_config must be specified in subclass") + + cls.model = cls.model_config["model_name"] cls.base_url = DEFAULT_URL_FOR_TEST dirpath = os.path.dirname(__file__) - config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json") + config_file = os.path.join(dirpath, cls.model_config["config_filename"]) + cls.process = popen_launch_server( cls.model, cls.base_url, @@ -31,6 +38,13 @@ def setUpClass(cls): ], ) + +class TestFp8KvcacheLlama(TestFp8KvcacheBase): + model_config = { + "model_name": DEFAULT_MODEL_NAME_FOR_TEST, + "config_filename": "kv_cache_scales_llama3_8b.json", + } + @classmethod def tearDownClass(cls): kill_process_tree(cls.process.pid) @@ -45,7 +59,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.835) + self.assertGreater(metrics["score"], 0.80) def test_mmlu(self): args = SimpleNamespace( @@ -60,5 +74,40 @@ def test_mmlu(self): self.assertGreaterEqual(metrics["score"], 0.65) +class TestFp8KvcacheQwen(TestFp8KvcacheBase): + model_config = { + "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN, + "config_filename": "kv_cache_scales_qwen2_1_5b.json", + } + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.01) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.3) + + if __name__ == "__main__": unittest.main()
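
-- 
Usage sketch, placed after the format-patch signature delimiter so it is not
part of the applied diff. The flag names below are assumptions inferred from
the existing test setup in test/srt/test_fp8_kvcache.py, whose server
arguments appear only as elided context in the hunks above:

    # Launch the server with an e4m3 FP8 KV cache plus the per-layer scaling
    # factors serialized in the JSON fixture added by this patch.
    python -m sglang.launch_server \
        --model-path Qwen/Qwen2.5-1.5B-Instruct \
        --kv-cache-dtype fp8_e4m3 \
        --quantization-param-path test/srt/kv_cache_scales_qwen2_1_5b.json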