46 changes: 34 additions & 12 deletions src/engine_args.py
@@ -13,6 +13,28 @@
"MAX_CONTEXT_LEN_TO_CAPTURE": "max_seq_len_to_capture"
}

def parse_json_env_var(env_var_name, default=None):
"""
Parse a JSON-formatted environment variable.

Args:
env_var_name (str): Name of the environment variable
default: Default value to return if env var is not set or invalid JSON

Returns:
Parsed JSON value (dict, list, str, int, float, bool, or None)
"""
env_value = os.getenv(env_var_name, None)
if env_value is None:
return default

try:
return json.loads(env_value)
except json.JSONDecodeError as e:
logging.warning(f"Warning: Invalid JSON in {env_var_name}: {env_value}")
logging.warning(f"JSON Error: {e}")
return default
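
# Illustrative usage of the helper above (hypothetical values, not part of the
# diff):
#
#   os.environ["ROPE_SCALING"] = '{"rope_type": "dynamic", "factor": 2.0}'
#   parse_json_env_var("ROPE_SCALING")           # -> {'rope_type': 'dynamic', 'factor': 2.0}
#   parse_json_env_var("UNSET_VAR", default={})  # -> {} (variable not set)
#   # A malformed value such as ROPE_SCALING='not json' logs a warning and
#   # returns the default (None).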

DEFAULT_ARGS = {
"disable_log_stats": os.getenv('DISABLE_LOG_STATS', 'False').lower() == 'true',
"disable_log_requests": os.getenv('DISABLE_LOG_REQUESTS', 'False').lower() == 'true',
@@ -46,7 +68,7 @@
"max_logprobs": int(os.getenv('MAX_LOGPROBS', 20)), # Default value for OpenAI Chat Completions API
"revision": os.getenv('REVISION', None),
"code_revision": os.getenv('CODE_REVISION', None),
"rope_scaling": os.getenv('ROPE_SCALING', None),
"rope_scaling": parse_json_env_var('ROPE_SCALING'),
"rope_theta": float(os.getenv('ROPE_THETA', 0)) or None,
"tokenizer_revision": os.getenv('TOKENIZER_REVISION', None),
"quantization": os.getenv('QUANTIZATION', None),
@@ -56,7 +78,7 @@
"disable_custom_all_reduce": os.getenv('DISABLE_CUSTOM_ALL_REDUCE', 'False').lower() == 'true',
"tokenizer_pool_size": int(os.getenv('TOKENIZER_POOL_SIZE', 0)),
"tokenizer_pool_type": os.getenv('TOKENIZER_POOL_TYPE', 'ray'),
"tokenizer_pool_extra_config": os.getenv('TOKENIZER_POOL_EXTRA_CONFIG', None),
"tokenizer_pool_extra_config": parse_json_env_var('TOKENIZER_POOL_EXTRA_CONFIG'),
"enable_lora": os.getenv('ENABLE_LORA', 'False').lower() == 'true',
"max_loras": int(os.getenv('MAX_LORAS', 1)),
"max_lora_rank": int(os.getenv('MAX_LORA_RANK', 16)),
@@ -72,7 +94,7 @@
"ray_workers_use_nsight": os.getenv('RAY_WORKERS_USE_NSIGHT', 'False').lower() == 'true',
"num_gpu_blocks_override": int(os.getenv('NUM_GPU_BLOCKS_OVERRIDE', 0)) or None,
"num_lookahead_slots": int(os.getenv('NUM_LOOKAHEAD_SLOTS', 0)),
"model_loader_extra_config": os.getenv('MODEL_LOADER_EXTRA_CONFIG', None),
"model_loader_extra_config": parse_json_env_var('MODEL_LOADER_EXTRA_CONFIG'),
"ignore_patterns": os.getenv('IGNORE_PATTERNS', None),
"preemption_mode": os.getenv('PREEMPTION_MODE', None),
"scheduler_delay_factor": float(os.getenv('SCHEDULER_DELAY_FACTOR', 0.0)),
@@ -136,43 +158,43 @@ def get_local_args():
def get_engine_args():
    # Start with a copy of the default args so the module-level dict is not mutated
    args = DEFAULT_ARGS.copy()

    # Pull in all environment variables; match_vllm_args() below renames and filters them to AsyncEngineArgs keys
args.update(os.environ)

# Get local args if model is baked in and overwrite env args
args.update(get_local_args())

# if args.get("TENSORIZER_URI"): TODO: add back once tensorizer is ready
# args["load_format"] = "tensorizer"
# args["model_loader_extra_config"] = TensorizerConfig(tensorizer_uri=args["TENSORIZER_URI"], num_readers=None)
# logging.info(f"Using tensorized model from {args['TENSORIZER_URI']}")


# Rename and match to vllm args
args = match_vllm_args(args)

if args.get("load_format") == "bitsandbytes":
args["quantization"] = args["load_format"]

# Set tensor parallel size and max parallel loading workers if more than 1 GPU is available
num_gpus = device_count()
if num_gpus > 1:
args["tensor_parallel_size"] = num_gpus
args["max_parallel_loading_workers"] = None
if os.getenv("MAX_PARALLEL_LOADING_WORKERS"):
logging.warning("Overriding MAX_PARALLEL_LOADING_WORKERS with None because more than 1 GPU is available.")

# Deprecated env args backwards compatibility
if args.get("kv_cache_dtype") == "fp8_e5m2":
args["kv_cache_dtype"] = "fp8"
logging.warning("Using fp8_e5m2 is deprecated. Please use fp8 instead.")
if os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"):
args["max_seq_len_to_capture"] = int(os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"))
logging.warning("Using MAX_CONTEXT_LEN_TO_CAPTURE is deprecated. Please use MAX_SEQ_LEN_TO_CAPTURE instead.")

# if "gemma-2" in args.get("model", "").lower():
# os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
# logging.info("Using FLASHINFER for gemma-2 model.")

return AsyncEngineArgs(**args)
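
# Illustrative call site (assumed, not shown in this diff): the returned
# AsyncEngineArgs would typically be handed to vLLM's async engine, e.g.
#
#   from vllm import AsyncLLMEngine
#   engine = AsyncLLMEngine.from_engine_args(get_engine_args())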