diff --git a/src/engine_args.py b/src/engine_args.py
index 8105a09..0518f1b 100644
--- a/src/engine_args.py
+++ b/src/engine_args.py
@@ -13,6 +13,28 @@
     "MAX_CONTEXT_LEN_TO_CAPTURE": "max_seq_len_to_capture"
 }
 
+def parse_json_env_var(env_var_name, default=None):
+    """
+    Parse a JSON-formatted environment variable.
+
+    Args:
+        env_var_name (str): Name of the environment variable
+        default: Default value to return if the env var is not set or contains invalid JSON
+
+    Returns:
+        Parsed JSON value (dict, list, str, int, float, bool, or None)
+    """
+    env_value = os.getenv(env_var_name, None)
+    if env_value is None:
+        return default
+
+    try:
+        return json.loads(env_value)
+    except json.JSONDecodeError as e:
+        logging.warning(f"Invalid JSON in {env_var_name}: {env_value}")
+        logging.warning(f"JSON Error: {e}")
+        return default
+
 DEFAULT_ARGS = {
     "disable_log_stats": os.getenv('DISABLE_LOG_STATS', 'False').lower() == 'true',
     "disable_log_requests": os.getenv('DISABLE_LOG_REQUESTS', 'False').lower() == 'true',
@@ -46,7 +68,7 @@
     "max_logprobs": int(os.getenv('MAX_LOGPROBS', 20)), # Default value for OpenAI Chat Completions API
     "revision": os.getenv('REVISION', None),
     "code_revision": os.getenv('CODE_REVISION', None),
-    "rope_scaling": os.getenv('ROPE_SCALING', None),
+    "rope_scaling": parse_json_env_var('ROPE_SCALING'),
     "rope_theta": float(os.getenv('ROPE_THETA', 0)) or None,
     "tokenizer_revision": os.getenv('TOKENIZER_REVISION', None),
     "quantization": os.getenv('QUANTIZATION', None),
@@ -56,7 +78,7 @@
     "disable_custom_all_reduce": os.getenv('DISABLE_CUSTOM_ALL_REDUCE', 'False').lower() == 'true',
     "tokenizer_pool_size": int(os.getenv('TOKENIZER_POOL_SIZE', 0)),
     "tokenizer_pool_type": os.getenv('TOKENIZER_POOL_TYPE', 'ray'),
-    "tokenizer_pool_extra_config": os.getenv('TOKENIZER_POOL_EXTRA_CONFIG', None),
+    "tokenizer_pool_extra_config": parse_json_env_var('TOKENIZER_POOL_EXTRA_CONFIG'),
     "enable_lora": os.getenv('ENABLE_LORA', 'False').lower() == 'true',
     "max_loras": int(os.getenv('MAX_LORAS', 1)),
     "max_lora_rank": int(os.getenv('MAX_LORA_RANK', 16)),
@@ -72,7 +94,7 @@
     "ray_workers_use_nsight": os.getenv('RAY_WORKERS_USE_NSIGHT', 'False').lower() == 'true',
     "num_gpu_blocks_override": int(os.getenv('NUM_GPU_BLOCKS_OVERRIDE', 0)) or None,
     "num_lookahead_slots": int(os.getenv('NUM_LOOKAHEAD_SLOTS', 0)),
-    "model_loader_extra_config": os.getenv('MODEL_LOADER_EXTRA_CONFIG', None),
+    "model_loader_extra_config": parse_json_env_var('MODEL_LOADER_EXTRA_CONFIG'),
     "ignore_patterns": os.getenv('IGNORE_PATTERNS', None),
     "preemption_mode": os.getenv('PREEMPTION_MODE', None),
     "scheduler_delay_factor": float(os.getenv('SCHEDULER_DELAY_FACTOR', 0.0)),
@@ -136,25 +158,25 @@ def get_local_args():
 def get_engine_args():
     # Start with default args
     args = DEFAULT_ARGS
-    
+
     # Get env args that match keys in AsyncEngineArgs
     args.update(os.environ)
-    
+
     # Get local args if model is baked in and overwrite env args
     args.update(get_local_args())
-    
+
     # if args.get("TENSORIZER_URI"): TODO: add back once tensorizer is ready
     #     args["load_format"] = "tensorizer"
     #     args["model_loader_extra_config"] = TensorizerConfig(tensorizer_uri=args["TENSORIZER_URI"], num_readers=None)
     #     logging.info(f"Using tensorized model from {args['TENSORIZER_URI']}")
-    
-    
+
+
     # Rename and match to vllm args
     args = match_vllm_args(args)
 
     if args.get("load_format") == "bitsandbytes":
         args["quantization"] = args["load_format"]
-    
+
     # Set tensor parallel size and max parallel loading workers if more than 1 GPU is available
     num_gpus = device_count()
     if num_gpus > 1:
@@ -162,7 +184,7 @@ def get_engine_args():
         args["max_parallel_loading_workers"] = None
         if os.getenv("MAX_PARALLEL_LOADING_WORKERS"):
             logging.warning("Overriding MAX_PARALLEL_LOADING_WORKERS with None because more than 1 GPU is available.")
-    
+
     # Deprecated env args backwards compatibility
     if args.get("kv_cache_dtype") == "fp8_e5m2":
         args["kv_cache_dtype"] = "fp8"
@@ -170,9 +192,9 @@ def get_engine_args():
     if os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"):
         args["max_seq_len_to_capture"] = int(os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"))
         logging.warning("Using MAX_CONTEXT_LEN_TO_CAPTURE is deprecated. Please use MAX_SEQ_LEN_TO_CAPTURE instead.")
-    
+
     # if "gemma-2" in args.get("model", "").lower():
     #     os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
     #     logging.info("Using FLASHINFER for gemma-2 model.")
-    
+
     return AsyncEngineArgs(**args)
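
For review context, a minimal sketch of how the new parser behaves for the three variables that switch over to it. The import path and every example value below are assumptions for illustration, not taken from the repo:

    import os

    from engine_args import parse_json_env_var  # import path assumed for this sketch

    # Well-formed JSON is decoded into the matching Python object (illustrative value).
    os.environ["ROPE_SCALING"] = '{"rope_type": "dynamic", "factor": 2.0}'
    assert parse_json_env_var("ROPE_SCALING") == {"rope_type": "dynamic", "factor": 2.0}

    # Malformed JSON logs a warning and falls back to the caller's default.
    os.environ["TOKENIZER_POOL_EXTRA_CONFIG"] = "{not json"
    assert parse_json_env_var("TOKENIZER_POOL_EXTRA_CONFIG", default={}) == {}

    # An unset variable short-circuits to the default without invoking json.loads.
    os.environ.pop("MODEL_LOADER_EXTRA_CONFIG", None)
    assert parse_json_env_var("MODEL_LOADER_EXTRA_CONFIG") is None

These three engine arguments carry structured (dict-like) configuration rather than plain strings, which is presumably why the diff parses them instead of passing the raw env var value through to AsyncEngineArgs.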