diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index 015c6514530..77c3fcbee74 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -9,7 +9,17 @@
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import filelock
 import gguf
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.exception("An error occurred while reading '%s'.", filename)
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 198d53995e4..9ea80d0c05d 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -23,7 +23,6 @@
 from torch import nn
 from transformers import LlamaConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -45,7 +44,10 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
 
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index f1d37118a86..04faa8dea1b 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -22,7 +22,10 @@
 from torch import nn
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -39,7 +42,10 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import make_layers
 
 Qwen2Config = None
@@ -265,6 +271,29 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling " "factor attribute!"
+                )
+
 
 class Qwen2ForCausalLM(nn.Module):
 
@@ -373,5 +402,8 @@ def set_embed_and_head(self, embed, head):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
 
 EntryClass = Qwen2ForCausalLM
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 42e0b6d808a..d3c9b7cab5f 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -40,6 +40,7 @@
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
 
 
 def is_in_ci():
diff --git a/test/srt/kv_cache_scales_llama3_8b.json b/test/srt/kv_cache_scales_llama3_8b.json
new file mode 100644
index 00000000000..466b0d01a74
--- /dev/null
+++ b/test/srt/kv_cache_scales_llama3_8b.json
@@ -0,0 +1,42 @@
+{
+    "model_type": "llama",
+    "kv_cache": {
+        "dtype": "float8_e4m3fn",
+        "scaling_factor": {
+            "0": {
+                "0": 0.0408,
+                "1": 0.0503,
+                "2": 0.0667,
+                "3": 0.0909,
+                "4": 0.1135,
+                "5": 0.127,
+                "6": 0.1768,
+                "7": 0.1488,
+                "8": 0.1135,
+                "9": 0.1203,
+                "10": 0.1013,
+                "11": 0.0842,
+                "12": 0.1231,
+                "13": 0.1096,
+                "14": 0.1221,
+                "15": 0.1013,
+                "16": 0.1067,
+                "17": 0.0952,
+                "18": 0.0899,
+                "19": 0.097,
+                "20": 0.087,
+                "21": 0.0994,
+                "22": 0.0904,
+                "23": 0.1013,
+                "24": 0.1019,
+                "25": 0.1053,
+                "26": 0.1,
+                "27": 0.0894,
+                "28": 0.1013,
+                "29": 0.1488,
+                "30": 0.0766,
+                "31": 0.0821
+            }
+        }
+    }
+}
diff --git a/test/srt/kv_cache_scales_qwen2_1_5b.json b/test/srt/kv_cache_scales_qwen2_1_5b.json
new file mode 100644
index 00000000000..984747509f7
--- /dev/null
+++ b/test/srt/kv_cache_scales_qwen2_1_5b.json
@@ -0,0 +1,38 @@
+{
+    "model_type": "qwen",
+    "kv_cache": {
+        "dtype": "float8_e4m3fn",
+        "scaling_factor": {
+            "0": {
+                "0": 0.9846,
+                 "1": 0.0645,
+                 "2": 0.0731,
+                 "3": 0.0800,
+                 "4": 0.0748,
+                 "5": 0.0780,
+                 "6": 0.0702,
+                 "7": 0.0894,
+                 "8": 0.0410,
+                 "9": 0.0758,
+                 "10": 0.0556,
+                 "11": 0.0731,
+                 "12": 0.0899,
+                 "13": 0.0780,
+                 "14": 0.1441,
+                 "15": 0.0914,
+                 "16": 0.5614,
+                 "17": 0.1067,
+                 "18": 0.0537,
+                 "19": 0.0658,
+                 "20": 0.0523,
+                 "21": 0.0533,
+                 "22": 0.0699,
+                 "23": 0.0635,
+                 "24": 0.0588,
+                 "25": 0.0884,
+                 "26": 0.0947,
+                 "27": 0.1032
+            }
+        }
+    }
+}
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index e2ecdfb3a68..fb1c6abf29b 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -52,6 +52,7 @@
         "test_vision_openai_server.py",
         "test_w8a8_quantization.py",
         "test_session_control.py",
+        "test_fp8_kvcache.py",
     ],
     "nightly": [
         "test_nightly_gsm8k_eval.py",
diff --git a/test/srt/test_fp8_kvcache.py b/test/srt/test_fp8_kvcache.py
index 0d6602997de..4a8a2434699 100644
--- a/test/srt/test_fp8_kvcache.py
+++ b/test/srt/test_fp8_kvcache.py
@@ -6,19 +6,26 @@
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )
 
 
-class TestFp8Kvcache(unittest.TestCase):
+class TestFp8KvcacheBase(unittest.TestCase):
+    model_config = None
+
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        if cls.model_config is None:
+            raise NotImplementedError("model_config must be specified in subclass")
+
+        cls.model = cls.model_config["model_name"]
         cls.base_url = DEFAULT_URL_FOR_TEST
         dirpath = os.path.dirname(__file__)
-        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
+        config_file = os.path.join(dirpath, cls.model_config["config_filename"])
+
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
@@ -31,6 +38,13 @@ def setUpClass(cls):
             ],
         )
 
+
+class TestFp8KvcacheLlama(TestFp8KvcacheBase):
+    model_config = {
+        "model_name": DEFAULT_MODEL_NAME_FOR_TEST,
+        "config_filename": "kv_cache_scales_llama3_8b.json",
+    }
+
     @classmethod
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)
@@ -45,7 +59,7 @@ def test_mgsm_en(self):
         )
 
         metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.835)
+        self.assertGreater(metrics["score"], 0.80)
 
     def test_mmlu(self):
         args = SimpleNamespace(
@@ -60,5 +74,40 @@ def test_mmlu(self):
         self.assertGreaterEqual(metrics["score"], 0.65)
 
 
+class TestFp8KvcacheQwen(TestFp8KvcacheBase):
+    model_config = {
+        "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
+        "config_filename": "kv_cache_scales_qwen2_1_5b.json",
+    }
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreater(metrics["score"], 0.01)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.3)
+
+
 if __name__ == "__main__":
     unittest.main()