support e4m3 kvcache in qwen2 & add kv scaling factor json (#2894)
Co-authored-by: bjmsong <[email protected]>
bjmsong and mdattack authored Jan 18, 2025
1 parent 13387e6 commit d3024f4
Showing 8 changed files with 227 additions and 9 deletions.
55 changes: 54 additions & 1 deletion python/sglang/srt/model_loader/weight_utils.py
@@ -9,7 +9,17 @@
import os
import tempfile
from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)

import filelock
import gguf
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:

    # If there were no matches, return the untouched param name
    return name


def kv_cache_scales_loader(
    filename: str,
    tp_rank: int,
    tp_size: int,
    num_hidden_layers: int,
    model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
    """
    A simple utility to read in KV cache scaling factors that have been
    previously serialized to disk. Used by the model to populate the appropriate
    KV cache scaling factors. The serialization should represent a dictionary
    whose keys are the TP ranks and values are another dictionary mapping layers
    to their KV cache scaling factors.
    """
    try:
        with open(filename) as f:
            context = {
                "model_type": model_type,
                "num_hidden_layers": num_hidden_layers,
                "tp_rank": tp_rank,
                "tp_size": tp_size,
            }
            schema_dct = json.load(f)
            schema = QuantParamSchema.model_validate(schema_dct, context=context)
            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
            return layer_scales_map.items()
    except FileNotFoundError:
        logger.error("File or directory '%s' not found.", filename)
    except json.JSONDecodeError:
        logger.error("Error decoding JSON in file '%s'.", filename)
    except Exception:
        logger.exception("An error occurred while reading '%s'.", filename)
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 for all "
        "layers in TP rank %d as an error occurred during loading.",
        tp_rank,
    )
    return []
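As a quick illustration of how this loader is consumed, a caller iterates (layer_index, scale) pairs for its TP rank. The file name and layer count below are hypothetical placeholders, not values taken from this diff:

# Hedged usage sketch, not part of the diff: iterate per-layer scales
# for TP rank 0 of a single-GPU deployment. On any load error the
# loader returns [], so every layer keeps the default scale of 1.0.
for layer_idx, scale in kv_cache_scales_loader(
    "kv_cache_scales_qwen2_1_5b.json",  # hypothetical path
    tp_rank=0,
    tp_size=1,
    num_hidden_layers=28,
    model_type="qwen",
):
    print(f"layer {layer_idx}: k/v scale = {scale}")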
6 changes: 4 additions & 2 deletions python/sglang/srt/models/llama.py
@@ -23,7 +23,6 @@
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
@@ -45,7 +44,10 @@
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
from sglang.srt.utils import make_layers
from sglang.utils import get_exception_traceback

36 changes: 34 additions & 2 deletions python/sglang/srt/models/qwen2.py
@@ -22,7 +22,10 @@
from torch import nn
from vllm.model_executor.layers.rotary_embedding import get_rope

-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
@@ -39,7 +42,10 @@
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
from sglang.srt.utils import make_layers

Qwen2Config = None
@@ -265,6 +271,29 @@ def forward(
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.layers[layer_idx], nn.Identity):
                layer_self_attn = self.layers[layer_idx].self_attn
                if hasattr(layer_self_attn.attn, "k_scale"):
                    layer_self_attn.attn.k_scale = scaling_factor
                    layer_self_attn.attn.v_scale = scaling_factor
                else:
                    raise RuntimeError(
                        "Self attention has no KV cache scaling factor attribute!"
                    )
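For readers unfamiliar with e4m3 KV caching, the scales set above are what the attention backend folds in and out when the cache is stored as 8-bit floats. Here is a minimal sketch of that arithmetic, assuming PyTorch's float8_e4m3fn dtype; it is illustrative only and not the backend's actual kernel code:

import torch

E4M3_MAX = 448.0  # largest magnitude float8_e4m3fn can represent

def quantize_kv(kv: torch.Tensor, scale: float) -> torch.Tensor:
    # Divide by the calibrated scale so values fit within +/-E4M3_MAX,
    # then cast down to the 8-bit e4m3 format for cache storage.
    return (kv / scale).to(torch.float8_e4m3fn)

def dequantize_kv(kv_fp8: torch.Tensor, scale: float) -> torch.Tensor:
    # Upcast and multiply the scale back in before attention math.
    return kv_fp8.to(torch.float16) * scale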


class Qwen2ForCausalLM(nn.Module):

@@ -373,5 +402,8 @@ def set_embed_and_head(self, embed, head):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)


EntryClass = Qwen2ForCausalLM
1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
@@ -40,6 +40,7 @@
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"


def is_in_ci():
42 changes: 42 additions & 0 deletions test/srt/kv_cache_scales_llama3_8b.json
@@ -0,0 +1,42 @@
{
  "model_type": "llama",
  "kv_cache": {
    "dtype": "float8_e4m3fn",
    "scaling_factor": {
      "0": {
        "0": 0.0408,
        "1": 0.0503,
        "2": 0.0667,
        "3": 0.0909,
        "4": 0.1135,
        "5": 0.127,
        "6": 0.1768,
        "7": 0.1488,
        "8": 0.1135,
        "9": 0.1203,
        "10": 0.1013,
        "11": 0.0842,
        "12": 0.1231,
        "13": 0.1096,
        "14": 0.1221,
        "15": 0.1013,
        "16": 0.1067,
        "17": 0.0952,
        "18": 0.0899,
        "19": 0.097,
        "20": 0.087,
        "21": 0.0994,
        "22": 0.0904,
        "23": 0.1013,
        "24": 0.1019,
        "25": 0.1053,
        "26": 0.1,
        "27": 0.0894,
        "28": 0.1013,
        "29": 0.1488,
        "30": 0.0766,
        "31": 0.0821
      }
    }
  }
}
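Per-layer scales like the ones above are typically derived from a calibration pass: record the maximum absolute K/V activation per layer, then divide by the largest magnitude e4m3 can represent (448). The sketch below shows that derivation under those assumptions; the helper and its inputs are hypothetical and this is not the script that generated this file:

import json

E4M3_MAX = 448.0

def write_scale_file(amax_per_layer: dict, model_type: str, path: str) -> None:
    # amax_per_layer maps layer index -> max |K/V| value seen in calibration.
    scales = {str(i): round(a / E4M3_MAX, 4) for i, a in amax_per_layer.items()}
    schema = {
        "model_type": model_type,
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            # Outer key "0" is the TP rank, matching kv_cache_scales_loader.
            "scaling_factor": {"0": scales},
        },
    }
    with open(path, "w") as f:
        json.dump(schema, f, indent=2)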
38 changes: 38 additions & 0 deletions test/srt/kv_cache_scales_qwen2_1_5b.json
@@ -0,0 +1,38 @@
{
  "model_type": "qwen",
  "kv_cache": {
    "dtype": "float8_e4m3fn",
    "scaling_factor": {
      "0": {
        "0": 0.9846,
        "1": 0.0645,
        "2": 0.0731,
        "3": 0.0800,
        "4": 0.0748,
        "5": 0.0780,
        "6": 0.0702,
        "7": 0.0894,
        "8": 0.0410,
        "9": 0.0758,
        "10": 0.0556,
        "11": 0.0731,
        "12": 0.0899,
        "13": 0.0780,
        "14": 0.1441,
        "15": 0.0914,
        "16": 0.5614,
        "17": 0.1067,
        "18": 0.0537,
        "19": 0.0658,
        "20": 0.0523,
        "21": 0.0533,
        "22": 0.0699,
        "23": 0.0635,
        "24": 0.0588,
        "25": 0.0884,
        "26": 0.0947,
        "27": 0.1032
      }
    }
  }
}
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -52,6 +52,7 @@
"test_vision_openai_server.py",
"test_w8a8_quantization.py",
"test_session_control.py",
"test_fp8_kvcache.py",
],
"nightly": [
"test_nightly_gsm8k_eval.py",
57 changes: 53 additions & 4 deletions test/srt/test_fp8_kvcache.py
@@ -6,19 +6,26 @@
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


-class TestFp8Kvcache(unittest.TestCase):
+class TestFp8KvcacheBase(unittest.TestCase):
+    model_config = None

    @classmethod
    def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        if cls.model_config is None:
+            raise NotImplementedError("model_config must be specified in subclass")
+
+        cls.model = cls.model_config["model_name"]
        cls.base_url = DEFAULT_URL_FOR_TEST
        dirpath = os.path.dirname(__file__)
-        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
+        config_file = os.path.join(dirpath, cls.model_config["config_filename"])

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
@@ -31,6 +38,13 @@ def setUpClass(cls):
            ],
        )


class TestFp8KvcacheLlama(TestFp8KvcacheBase):
    model_config = {
        "model_name": DEFAULT_MODEL_NAME_FOR_TEST,
        "config_filename": "kv_cache_scales_llama3_8b.json",
    }

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
@@ -45,7 +59,7 @@ def test_mgsm_en(self):
        )

        metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.835)
+        self.assertGreater(metrics["score"], 0.80)

    def test_mmlu(self):
        args = SimpleNamespace(
@@ -60,5 +74,40 @@
        self.assertGreaterEqual(metrics["score"], 0.65)


class TestFp8KvcacheQwen(TestFp8KvcacheBase):
    model_config = {
        "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
        "config_filename": "kv_cache_scales_qwen2_1_5b.json",
    }

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.01)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.3)


if __name__ == "__main__":
    unittest.main()
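To exercise this end to end, a server can be launched with an fp8 KV cache pointing at one of the scale files from this PR. The sketch below mirrors how these tests launch the server; the other_args flags are assumptions based on sglang's server arguments (the diff collapses the actual list above), so verify them against your installed version:

from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

# Assumed flags; the diff truncates the real other_args list above.
process = popen_launch_server(
    "Qwen/Qwen2.5-1.5B-Instruct",
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--kv-cache-dtype",
        "fp8_e4m3",
        "--quantization-param-path",
        "kv_cache_scales_qwen2_1_5b.json",
    ],
)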
