Skip to content

Commit 2facb5e

Browse files
committed
wip
Signed-off-by: Daniel Afrimi <[email protected]>
1 parent bb7e1e1 commit 2facb5e

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

tensorrt_llm/llmapi/llm_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ..logger import logger
2626
from ..mapping import Mapping
2727
from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM
28+
from ..models.modeling_utils import QuantAlgo # noqa: F401
2829
from ..models.modeling_utils import PretrainedConfig, QuantConfig
2930
from ..module import Module
3031
from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,

tensorrt_llm/models/modeling_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,11 @@ def _infer_kv_cache_quant_algo_from_scheme(kv_scheme: dict) -> str | None:
235235
bits = kv_scheme.get("num_bits")
236236
dynamic = bool(kv_scheme.get("dynamic", False))
237237

238-
# todo add here all options...
238+
# TODO (danielafrimi) needs to check all supported options...
239239
if kv_type == "float" and bits == 8 and not dynamic:
240-
return QuantAlgo("FP8_BLOCK_SCALES")
240+
return QuantAlgo.FP8
241241
if kv_type in ("int", "uint") and bits == 8:
242-
return QuantAlgo("INT8")
242+
return QuantAlgo.INT8
243243
return None
244244

245245
def _map_new_to_legacy_args(self, hf_quant_config: dict) -> dict:
@@ -261,8 +261,7 @@ def _map_new_to_legacy_args(self, hf_quant_config: dict) -> dict:
261261
hf_quant_config.get("ignore") or [])
262262

263263
kv_scheme = hf_quant_config.get("kv_cache_scheme") or {}
264-
kv_algo = QuantConfig._infer_kv_cache_quant_algo_from_scheme(
265-
kv_scheme) # todo check it
264+
kv_algo = QuantConfig._infer_kv_cache_quant_algo_from_scheme(kv_scheme)
266265
if kv_algo is not None:
267266
qunatization_dict["kv_cache_quant_algo"] = kv_algo
268267

@@ -273,7 +272,6 @@ def _map_new_to_legacy_args(self, hf_quant_config: dict) -> dict:
273272
if "symmetric" in hf_quant_config:
274273
qunatization_dict["zero_point"] = hf_quant_config["symmetric"]
275274

276-
# todo add here pre qunat scale and other keys....
277275
return qunatization_dict
278276

279277
def _update_from_quant_config_json(self, path, moe_backend: str,

0 commit comments

Comments
 (0)