diff --git a/pyproject.toml b/pyproject.toml index 6bf403849d..1d206913a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.2", + "unsloth_zoo>=2025.3.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -354,7 +354,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.1", + "unsloth_zoo>=2025.3.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 8439ab8212..4336ec494b 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.2"): + if Version(unsloth_zoo_version) < Version("2025.3.4"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index e11cd54417..a187ee577a 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from .granite import FastGraniteModel -from .loader import FastLanguageModel, FastVisionModel from .llama import FastLlamaModel +from .loader import FastLanguageModel, FastVisionModel from .mistral import FastMistralModel from .qwen2 import FastQwen2Model +from .granite import FastGraniteModel from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported, __version__ from .rl import PatchFastRL, vLLMSamplingParams diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7d6bbfb78b..c01e0ccc82 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2025.3.5" +__version__ = "2025.3.6" __all__ = [ "SUPPORTS_BFLOAT16", @@ -1050,7 +1050,10 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): pass pass - if num_items_in_batch is None: + # Only warn when gradient accumulation is in use; `self.args` is an object (not a dict), so read it with getattr + if num_items_in_batch is None and \ + getattr(getattr(self, "args", None), "gradient_accumulation_steps", 1) != 1: + name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ logger.warning_once( f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ @@ -1245,10 +1248,11 @@ def unsloth_compile_transformers( # os.environ['UNSLOTH_RETURN_LOGITS'] = '1' LOGITS_ERROR_STRING = \ "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\ - 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n\n'\ - "import os\n"\ + 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\ + "```\nimport os\n"\ "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\ - "... trainer.train() ..." + "trainer.train()\n```\n"\ + "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!" 
def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c9ea922272..cf9c16514e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -284,6 +284,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += eval_changes pass + # Force logits to be produced if preprocess_logits_for_metrics or compute_metrics is used + if "model" in call_args: + logits_check = \ + "_output_logits = False\n"\ + "if locals().get('compute_metrics', None) is not None: _output_logits = True\n"\ + "if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\n"\ + "if _output_logits:\n"\ + " os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n" + extra_args += logits_check + pass + # Check max_seq_length if "model" in call_args: length_check = \