diff --git a/unsloth/import_fixes.py b/unsloth/import_fixes.py index f0dde256c1..bb6996a3e3 100644 --- a/unsloth/import_fixes.py +++ b/unsloth/import_fixes.py @@ -101,7 +101,9 @@ def __getattr__(self, name): # Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu128 for torchao version 0.15.0 logging.getLogger("torchao").setLevel(logging.ERROR) # SyntaxWarning: invalid escape sequence '\.' - warnings.filterwarnings("ignore", message = "invalid escape sequence", category = SyntaxWarning) + warnings.filterwarnings( + "ignore", message = "invalid escape sequence", category = SyntaxWarning + ) # Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype' @@ -549,5 +551,8 @@ def fix_diffusers_warnings(): def fix_huggingface_hub(): # huggingface_hub.is_offline_mode got removed, so add it back import huggingface_hub + if not hasattr(huggingface_hub, "is_offline_mode"): - huggingface_hub.is_offline_mode = lambda: huggingface_hub.constants.HF_HUB_OFFLINE + huggingface_hub.is_offline_mode = ( + lambda: huggingface_hub.constants.HF_HUB_OFFLINE + ) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3851e18f92..1cead3afaf 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -73,6 +73,7 @@ "verify_fp8_support_if_applicable", "_get_inference_mode_context_manager", "hf_login", + "make_fast_generate_wrapper", ] import torch @@ -2378,3 +2379,59 @@ def hf_login(token: Optional[str] = None) -> Optional[str]: except Exception as e: logger.info(f"Failed to login to huggingface using token with error: {e}") return token + + +def make_fast_generate_wrapper(original_generate): + """ + Creates a wrapper around model.generate that checks for incorrect + vLLM-style usage when fast_inference=False. + """ + + @functools.wraps(original_generate) + def _fast_generate_wrapper(*args, **kwargs): + # Check for vLLM-specific arguments + if "sampling_params" in kwargs: + raise ValueError( + "Unsloth: `sampling_params` is only supported when `fast_inference=True` (vLLM). " + "Since `fast_inference=False`, use HuggingFace generate arguments instead:\n" + " model.fast_generate(**tokens.to('cuda'), max_new_tokens=64, temperature=1.0, top_p=0.95)" + ) + + if "lora_request" in kwargs: + raise ValueError( + "Unsloth: `lora_request` is only supported when `fast_inference=True` (vLLM). " + "Since `fast_inference=False`, LoRA weights are already merged into the model." + ) + + # Check if first positional argument is a string or list of strings + if len(args) > 0: + first_arg = args[0] + is_string_input = False + + if isinstance(first_arg, str): + is_string_input = True + elif isinstance(first_arg, (list, tuple)) and len(first_arg) > 0: + if isinstance(first_arg[0], str): + is_string_input = True + + if is_string_input: + raise ValueError( + "Unsloth: Passing text strings to `fast_generate` is only supported " + "when `fast_inference=True` (vLLM). Since `fast_inference=False`, you must " + "tokenize the input first:\n\n" + " messages = tokenizer.apply_chat_template(\n" + ' [{"role": "user", "content": "Your prompt here"}],\n' + " tokenize=True, add_generation_prompt=True,\n" + ' return_tensors="pt", return_dict=True\n' + " )\n" + " output = model.fast_generate(\n" + " **messages.to('cuda'),\n" + " max_new_tokens=64,\n" + " temperature=1.0,\n" + " )" + ) + + # Call original generate + return original_generate(*args, **kwargs) + + return _fast_generate_wrapper diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 29d41f4bb1..92d51b73ad 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2326,7 +2326,7 @@ def from_pretrained( attn_implementation = "eager", **kwargs, ) - model.fast_generate = model.generate + model.fast_generate = make_fast_generate_wrapper(model.generate) model.fast_generate_batches = None else: from unsloth_zoo.vllm_utils import ( diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 1924373f67..6c5356e0b9 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -673,7 +673,7 @@ def from_pretrained( **kwargs, ) if hasattr(model, "generate"): - model.fast_generate = model.generate + model.fast_generate = make_fast_generate_wrapper(model.generate) model.fast_generate_batches = error_out_no_vllm if offload_embedding: if bool(