diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index cdc0b22c70..0eed45c5cd 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -265,6 +265,7 @@ def MistralForCausalLM_fast_forward( output_attentions = output_attentions, output_hidden_states = output_hidden_states, return_dict = return_dict, + **kwargs, ) hidden_states = outputs[0] diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 35edbb8af8..339af63f33 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -49,36 +49,28 @@ logger = logging.getLogger(__name__) -_AUTO_PACKING_ENV_DISABLED = os.environ.get( - "UNSLOTH_DISABLE_AUTO_PACKING", "" -).strip().lower() in {"1", "true", "yes", "on"} - _AUTO_PADDING_FREE_ENV_DISABLED = os.environ.get( "UNSLOTH_DISABLE_AUTO_PADDING_FREE", "" ).strip().lower() in {"1", "true", "yes", "on"} - -# [TODO] -# Below cannot work with padding-free PADDING_FREE_BLOCKLIST = { "gemma2", # - gemma2: Uses slow_attention_softcapping which has torch.compile issues "gpt_oss", # - gpt_oss: Uses Flex Attention which doesn't handle padding_free correctly - "mistral", # - mistral: Unfortunately I think sliding window attention doesn't work correctly? } -def _should_auto_pack(config) -> bool: - if config is None or _AUTO_PACKING_ENV_DISABLED: - return False - if not getattr(config, "packing", False): +def _should_pack(config) -> bool: + if config is None or not getattr(config, "packing", False): return False return not getattr(config, "_unsloth_disable_auto_packing", False) def _should_auto_padding_free(config) -> bool: - if config is None or _AUTO_PADDING_FREE_ENV_DISABLED: - return False - if getattr(config, "packing", False): + if ( + config is None + or _AUTO_PADDING_FREE_ENV_DISABLED + or getattr(config, "packing", False) + ): return False return not getattr(config, "padding_free", False) @@ -326,23 +318,31 @@ def new_init(self, *args, **kwargs): data_collator is not None or isinstance(processing_class, ProcessorMixin) or is_vlm + or is_unsupported_model ) - if blocked and _should_auto_pack(config_arg): - reason = ( - "custom data collator" - if data_collator is not None - else "processor-based model" - ) - logger.info( - "Unsloth: Auto sample packing skipped (%s detected). Use UNSLOTH_DISABLE_AUTO_PACKING=1 to silence.", - reason, - ) + requested_pack = bool(getattr(config_arg, "packing", False)) + if blocked: + if hasattr(config_arg, "packing"): + setattr(config_arg, "packing", False) + if hasattr(config_arg, "padding_free"): + setattr(config_arg, "padding_free", False) + + if blocked and requested_pack: + reason = "custom data collator" + if data_collator is None and isinstance(processing_class, ProcessorMixin): + reason = "processor-based model" + elif is_vlm: + reason = "vision-language model" + elif is_unsupported_model: + reason = f"unsupported model type(s): {', '.join(model_types)}" + message = "Unsloth: Sample packing skipped " f"({reason} detected)." + print(message) - auto_pack_active = False - if _should_auto_pack(config_arg) and not blocked: + packing_active = False + if _should_pack(config_arg) and not blocked: configure_sample_packing(config_arg) - auto_pack_active = True - logger.info("Unsloth: Sample packing auto-enabled for SFTTrainer instance.") + packing_active = True + logger.info("Unsloth: Sample packing enabled for SFTTrainer instance.") auto_padding_free_active = False padding_free_requested = getattr(config_arg, "padding_free", None) is True @@ -359,13 +359,13 @@ def new_init(self, *args, **kwargs): try: original_init(self, *args, **kwargs) except ValueError as exc: - if auto_pack_active and _should_skip_auto_packing_error(exc): + if packing_active and _should_skip_auto_packing_error(exc): logger.info( "Unsloth: Auto sample packing failed because trainer reported an incompatible setup (%s).", exc, ) _disable_sample_packing(config_arg) - auto_pack_active = False + packing_active = False original_init(self, *args, **kwargs) else: raise @@ -376,12 +376,21 @@ def new_init(self, *args, **kwargs): trainer_args and getattr(trainer_args, "padding_free", False) ) - if trainer_packing and (auto_pack_active or _should_auto_pack(trainer_args)): + if blocked and trainer_args is not None: + # Mirror the block on the trainer args to avoid re-enabling later + setattr(trainer_args, "packing", False) + setattr(trainer_args, "padding_free", False) + + if ( + not blocked + and trainer_packing + and (packing_active or _should_pack(trainer_args)) + ): enable_sample_packing(self.model, self) print( "🦥 Unsloth: Packing enabled - training is >2x faster and uses less VRAM!" ) - elif trainer_padding_free: + elif not blocked and trainer_padding_free: enable_padding_free_metadata(self.model, self) message = ( "🦥 Unsloth: Padding-free auto-enabled, enabling faster training." diff --git a/unsloth/utils/packing.py b/unsloth/utils/packing.py index 5c13c55c7f..81b721a29b 100644 --- a/unsloth/utils/packing.py +++ b/unsloth/utils/packing.py @@ -107,15 +107,12 @@ def configure_sample_packing(config): _ensure_trl_warning_filter() setattr(config, "packing", True) setattr(config, "padding_free", True) - setattr(config, "remove_unused_columns", False) def configure_padding_free(config): """Mutate an ``SFTConfig`` so TRL enables padding-free batching without packing.""" _ensure_trl_warning_filter() setattr(config, "padding_free", True) - if hasattr(config, "remove_unused_columns"): - setattr(config, "remove_unused_columns", False) def enable_sample_packing( @@ -150,48 +147,15 @@ def torch_call_with_lengths(examples: Sequence[dict]): batch = original_torch_call(examples) if examples and isinstance(examples[0], dict): seq_lengths: list[int] = [] - per_example_counts: list[int] = [] for example in examples: lengths = example.get(sequence_lengths_key) if isinstance(lengths, Iterable): - numeric_lengths = [int(length) for length in lengths] - seq_lengths.extend(numeric_lengths) - per_example_counts.append(len(numeric_lengths)) - else: - per_example_counts.append(0) + seq_lengths.extend(int(length) for length in lengths) if seq_lengths: batch["packed_seq_lengths"] = torch.tensor( seq_lengths, dtype = torch.int32 ) - - position_ids = batch.get("position_ids") - input_ids = batch.get("input_ids") - if position_ids is None and input_ids is not None: - position_ids = torch.zeros_like( - input_ids, dtype = torch.long, device = input_ids.device - ) - - if position_ids is not None and input_ids is not None: - seq_index = 0 - for row_idx, count in enumerate(per_example_counts): - cursor = 0 - for _ in range(count): - length = seq_lengths[seq_index] - if length > 0: - position_ids[row_idx, cursor : cursor + length] = ( - torch.arange( - length, - dtype = torch.long, - device = position_ids.device, - ) - ) - cursor += length - seq_index += 1 - batch["position_ids"] = position_ids - - if "attention_mask" in batch and getattr( - collator, "return_position_ids", False - ): + if "attention_mask" in batch: batch.pop("attention_mask") return batch @@ -201,23 +165,12 @@ def torch_call_with_lengths(examples: Sequence[dict]): def enable_padding_free_metadata(model, trainer): """Inject seq-length metadata when padding-free batching is enabled without packing.""" - - trainer_args = getattr(trainer, "args", None) - if ( - trainer_args is not None - and hasattr(trainer_args, "remove_unused_columns") - and trainer_args.remove_unused_columns - ): - trainer_args.remove_unused_columns = False - - _ensure_trl_warning_filter() collator = getattr(trainer, "data_collator", None) if ( collator is None or getattr(collator, "_unsloth_padding_free_lengths_wrapped", False) or not getattr(collator, "padding_free", False) ): - # Nothing to do if there's no collator, we've already wrapped it, or padding-free is off. return mark_allow_overlength(model)