diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 10a2dffb5..adb9d4784 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -384,7 +384,7 @@ def get_vllm_state_dict(llm, return_state_dict = False, config = None): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules and returns HF equivalent state_dict try: - llm_engine = getattr(llm, "llm_engine", llm) + llm_engine = getattr(llm, "llm_engine", getattr(llm, "engine", llm)) vllm_internals = llm_engine.model_executor.driver_worker.model_runner.model except: raise RuntimeError("Unsloth: Failed to access llm.llm_engine.model_executor.driver_worker.model_runner.model") @@ -916,7 +916,7 @@ def load_vllm( os.environ["VLLM_USE_V1"] = "1" pass - from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs + from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs, AsyncEngineArgs # Default vLLM max_num_seqs is 256 approx_max_num_seqs = 256 @@ -1004,7 +1004,7 @@ def load_vllm( swap_space = swap_space, # Low memory devices like Colab (13GB) default 4GB device = device, ) - good_keys = inspect.signature(EngineArgs).parameters.keys() + good_keys = inspect.signature(AsyncEngineArgs if use_async else EngineArgs).parameters.keys() old_keys = engine_args.keys() for key in old_keys: if key not in good_keys: @@ -1018,7 +1018,7 @@ def load_vllm( while True: try: if use_async: - llm = AsyncLLMEngine.from_engine_args(EngineArgs(**engine_args)) + llm = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) elif use_engine: llm = LLMEngine.from_engine_args(EngineArgs(**engine_args)) else: