Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions unsloth/import_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __getattr__(self, name):
# Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu128 for torchao version 0.15.0
logging.getLogger("torchao").setLevel(logging.ERROR)
# SyntaxWarning: invalid escape sequence '\.'
warnings.filterwarnings("ignore", message = "invalid escape sequence", category = SyntaxWarning)
warnings.filterwarnings(
"ignore", message = "invalid escape sequence", category = SyntaxWarning
)


# Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
Expand Down Expand Up @@ -549,5 +551,8 @@ def fix_diffusers_warnings():
def fix_huggingface_hub():
# huggingface_hub.is_offline_mode got removed, so add it back
import huggingface_hub

if not hasattr(huggingface_hub, "is_offline_mode"):
huggingface_hub.is_offline_mode = lambda: huggingface_hub.constants.HF_HUB_OFFLINE
huggingface_hub.is_offline_mode = (
lambda: huggingface_hub.constants.HF_HUB_OFFLINE
)
57 changes: 57 additions & 0 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"verify_fp8_support_if_applicable",
"_get_inference_mode_context_manager",
"hf_login",
"make_fast_generate_wrapper",
]

import torch
Expand Down Expand Up @@ -2378,3 +2379,59 @@ def hf_login(token: Optional[str] = None) -> Optional[str]:
except Exception as e:
logger.info(f"Failed to login to huggingface using token with error: {e}")
return token


def make_fast_generate_wrapper(original_generate):
"""
Creates a wrapper around model.generate that checks for incorrect
vLLM-style usage when fast_inference=False.
"""

@functools.wraps(original_generate)
def _fast_generate_wrapper(*args, **kwargs):
# Check for vLLM-specific arguments
if "sampling_params" in kwargs:
raise ValueError(
"Unsloth: `sampling_params` is only supported when `fast_inference=True` (vLLM). "
"Since `fast_inference=False`, use HuggingFace generate arguments instead:\n"
" model.fast_generate(**tokens.to('cuda'), max_new_tokens=64, temperature=1.0, top_p=0.95)"
)

if "lora_request" in kwargs:
raise ValueError(
"Unsloth: `lora_request` is only supported when `fast_inference=True` (vLLM). "
"Since `fast_inference=False`, LoRA weights are already merged into the model."
)

# Check if first positional argument is a string or list of strings
if len(args) > 0:
first_arg = args[0]
is_string_input = False

if isinstance(first_arg, str):
is_string_input = True
elif isinstance(first_arg, (list, tuple)) and len(first_arg) > 0:
if isinstance(first_arg[0], str):
Comment on lines +2408 to +2414
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic to detect string input can be simplified into a single boolean expression for better readability and conciseness. The current implementation is correct, but a more Pythonic way would be to combine the checks using boolean short-circuiting.

            is_string_input = isinstance(first_arg, str) or (
                isinstance(first_arg, (list, tuple)) and first_arg and isinstance(first_arg[0], str)
            )

is_string_input = True

if is_string_input:
raise ValueError(
"Unsloth: Passing text strings to `fast_generate` is only supported "
"when `fast_inference=True` (vLLM). Since `fast_inference=False`, you must "
"tokenize the input first:\n\n"
" messages = tokenizer.apply_chat_template(\n"
' [{"role": "user", "content": "Your prompt here"}],\n'
" tokenize=True, add_generation_prompt=True,\n"
' return_tensors="pt", return_dict=True\n'
" )\n"
" output = model.fast_generate(\n"
" **messages.to('cuda'),\n"
" max_new_tokens=64,\n"
" temperature=1.0,\n"
" )"
)

# Call original generate
return original_generate(*args, **kwargs)

return _fast_generate_wrapper
2 changes: 1 addition & 1 deletion unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2326,7 +2326,7 @@ def from_pretrained(
attn_implementation = "eager",
**kwargs,
)
model.fast_generate = model.generate
model.fast_generate = make_fast_generate_wrapper(model.generate)
model.fast_generate_batches = None
else:
from unsloth_zoo.vllm_utils import (
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def from_pretrained(
**kwargs,
)
if hasattr(model, "generate"):
model.fast_generate = model.generate
model.fast_generate = make_fast_generate_wrapper(model.generate)
model.fast_generate_batches = error_out_no_vllm
if offload_embedding:
if bool(
Expand Down