diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c98feeca1e..99a34f0a64 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1378,6 +1378,15 @@ def _wrap_fast_inference(generate, device_type, dtype, model): @torch.inference_mode def _fast_generate(*args, **kwargs): + if hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"): + if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: + if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > model.config.max_position_embeddings: + raise ValueError( + f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\ + 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' + ) + pass + # Set a flag for generation! internal_model = model while hasattr(internal_model, "model"):