diff --git a/vllm_spyre/compilation_utils.py b/vllm_spyre/compilation_utils.py index 8bb37946a..ac4cd182e 100644 --- a/vllm_spyre/compilation_utils.py +++ b/vllm_spyre/compilation_utils.py @@ -2,6 +2,7 @@ import os from pathlib import Path from typing import TYPE_CHECKING +import importlib.metadata # Third Party from vllm.logger import init_logger @@ -133,20 +134,28 @@ def handle_disable_compilation(vllm_config: VllmConfig, is_decoder: bool): if matching_config: # Check vllm_spyre version try: - from vllm_spyre._version import version as vllm_spyre_version + vllm_spyre_version = importlib.metadata.version("vllm_spyre") - if matching_config["vllm_spyre_version"] != vllm_spyre_version: + config_version = matching_config.get("vllm_spyre_version") + if config_version is None: + logger.warning( + "[PRECOMPILED_WARN] Pre-compiled config missing vllm_spyre_version field. " + ) + elif config_version != vllm_spyre_version: # Can be converted to ValueError if we want to be strict # with checking logger.warning( "[PRECOMPILED_WARN] " "Model was compiled on vllm-spyre " "%s but the current vllm_spyre version is %s", - matching_config["vllm_spyre_version"], + config_version, vllm_spyre_version, ) except ImportError: - logger.warning("Cannot validate vllm_spyre version against pre-compiled model config") + logger.warning( + "[PRECOMPILED_WARN] Cannot validate vllm_spyre version against " + "pre-compiled model config" + ) # Check model name model_name = matching_config["data"]["MODEL_NAME"]