Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions unsloth/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def MistralForCausalLM_fast_forward(
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
return_dict = return_dict,
**kwargs,
Comment thread
djsaunde marked this conversation as resolved.
)

hidden_states = outputs[0]
Expand Down
75 changes: 42 additions & 33 deletions unsloth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,36 +49,28 @@

logger = logging.getLogger(__name__)

_AUTO_PACKING_ENV_DISABLED = os.environ.get(
"UNSLOTH_DISABLE_AUTO_PACKING", ""
).strip().lower() in {"1", "true", "yes", "on"}

_AUTO_PADDING_FREE_ENV_DISABLED = os.environ.get(
"UNSLOTH_DISABLE_AUTO_PADDING_FREE", ""
).strip().lower() in {"1", "true", "yes", "on"}


# [TODO]
# Below cannot work with padding-free
PADDING_FREE_BLOCKLIST = {
"gemma2", # - gemma2: Uses slow_attention_softcapping which has torch.compile issues
"gpt_oss", # - gpt_oss: Uses Flex Attention which doesn't handle padding_free correctly
"mistral", # - mistral: Unfortunately I think sliding window attention doesn't work correctly?
}


def _should_auto_pack(config) -> bool:
if config is None or _AUTO_PACKING_ENV_DISABLED:
return False
if not getattr(config, "packing", False):
def _should_pack(config) -> bool:
if config is None or not getattr(config, "packing", False):
return False
return not getattr(config, "_unsloth_disable_auto_packing", False)


def _should_auto_padding_free(config) -> bool:
if config is None or _AUTO_PADDING_FREE_ENV_DISABLED:
return False
if getattr(config, "packing", False):
if (
config is None
or _AUTO_PADDING_FREE_ENV_DISABLED
or getattr(config, "packing", False)
):
return False
return not getattr(config, "padding_free", False)

Expand Down Expand Up @@ -326,23 +318,31 @@ def new_init(self, *args, **kwargs):
data_collator is not None
or isinstance(processing_class, ProcessorMixin)
or is_vlm
or is_unsupported_model
)
if blocked and _should_auto_pack(config_arg):
reason = (
"custom data collator"
if data_collator is not None
else "processor-based model"
)
logger.info(
"Unsloth: Auto sample packing skipped (%s detected). Use UNSLOTH_DISABLE_AUTO_PACKING=1 to silence.",
reason,
)
requested_pack = bool(getattr(config_arg, "packing", False))
if blocked:
if hasattr(config_arg, "packing"):
setattr(config_arg, "packing", False)
if hasattr(config_arg, "padding_free"):
setattr(config_arg, "padding_free", False)

if blocked and requested_pack:
reason = "custom data collator"
if data_collator is None and isinstance(processing_class, ProcessorMixin):
reason = "processor-based model"
elif is_vlm:
reason = "vision-language model"
elif is_unsupported_model:
reason = f"unsupported model type(s): {', '.join(model_types)}"
message = "Unsloth: Sample packing skipped " f"({reason} detected)."
print(message)

auto_pack_active = False
if _should_auto_pack(config_arg) and not blocked:
packing_active = False
if _should_pack(config_arg) and not blocked:
configure_sample_packing(config_arg)
auto_pack_active = True
logger.info("Unsloth: Sample packing auto-enabled for SFTTrainer instance.")
packing_active = True
logger.info("Unsloth: Sample packing enabled for SFTTrainer instance.")

auto_padding_free_active = False
padding_free_requested = getattr(config_arg, "padding_free", None) is True
Expand All @@ -359,13 +359,13 @@ def new_init(self, *args, **kwargs):
try:
original_init(self, *args, **kwargs)
except ValueError as exc:
if auto_pack_active and _should_skip_auto_packing_error(exc):
if packing_active and _should_skip_auto_packing_error(exc):
logger.info(
"Unsloth: Auto sample packing failed because trainer reported an incompatible setup (%s).",
exc,
)
_disable_sample_packing(config_arg)
auto_pack_active = False
packing_active = False
original_init(self, *args, **kwargs)
else:
raise
Expand All @@ -376,12 +376,21 @@ def new_init(self, *args, **kwargs):
trainer_args and getattr(trainer_args, "padding_free", False)
)

if trainer_packing and (auto_pack_active or _should_auto_pack(trainer_args)):
if blocked and trainer_args is not None:
# Mirror the block on the trainer args to avoid re-enabling later
setattr(trainer_args, "packing", False)
setattr(trainer_args, "padding_free", False)

if (
not blocked
and trainer_packing
and (packing_active or _should_pack(trainer_args))
):
enable_sample_packing(self.model, self)
print(
"🦥 Unsloth: Packing enabled - training is >2x faster and uses less VRAM!"
)
elif trainer_padding_free:
elif not blocked and trainer_padding_free:
enable_padding_free_metadata(self.model, self)
message = (
"🦥 Unsloth: Padding-free auto-enabled, enabling faster training."
Expand Down
51 changes: 2 additions & 49 deletions unsloth/utils/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,12 @@ def configure_sample_packing(config):
_ensure_trl_warning_filter()
setattr(config, "packing", True)
setattr(config, "padding_free", True)
setattr(config, "remove_unused_columns", False)


def configure_padding_free(config):
"""Mutate an ``SFTConfig`` so TRL enables padding-free batching without packing."""
_ensure_trl_warning_filter()
setattr(config, "padding_free", True)
if hasattr(config, "remove_unused_columns"):
setattr(config, "remove_unused_columns", False)


def enable_sample_packing(
Expand Down Expand Up @@ -150,48 +147,15 @@ def torch_call_with_lengths(examples: Sequence[dict]):
batch = original_torch_call(examples)
if examples and isinstance(examples[0], dict):
seq_lengths: list[int] = []
per_example_counts: list[int] = []
for example in examples:
lengths = example.get(sequence_lengths_key)
if isinstance(lengths, Iterable):
numeric_lengths = [int(length) for length in lengths]
seq_lengths.extend(numeric_lengths)
per_example_counts.append(len(numeric_lengths))
else:
per_example_counts.append(0)
seq_lengths.extend(int(length) for length in lengths)

@djsaunde djsaunde Dec 10, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified this since trl emits this position_ids metadata.

if seq_lengths:
batch["packed_seq_lengths"] = torch.tensor(
seq_lengths, dtype = torch.int32
)

position_ids = batch.get("position_ids")
input_ids = batch.get("input_ids")
if position_ids is None and input_ids is not None:
position_ids = torch.zeros_like(
input_ids, dtype = torch.long, device = input_ids.device
)

if position_ids is not None and input_ids is not None:
seq_index = 0
for row_idx, count in enumerate(per_example_counts):
cursor = 0
for _ in range(count):
length = seq_lengths[seq_index]
if length > 0:
position_ids[row_idx, cursor : cursor + length] = (
torch.arange(
length,
dtype = torch.long,
device = position_ids.device,
)
)
cursor += length
seq_index += 1
batch["position_ids"] = position_ids

if "attention_mask" in batch and getattr(
collator, "return_position_ids", False
):
if "attention_mask" in batch:
batch.pop("attention_mask")
return batch

Expand All @@ -201,23 +165,12 @@ def torch_call_with_lengths(examples: Sequence[dict]):

def enable_padding_free_metadata(model, trainer):
"""Inject seq-length metadata when padding-free batching is enabled without packing."""

trainer_args = getattr(trainer, "args", None)
if (
trainer_args is not None
and hasattr(trainer_args, "remove_unused_columns")
and trainer_args.remove_unused_columns
):
trainer_args.remove_unused_columns = False

_ensure_trl_warning_filter()
collator = getattr(trainer, "data_collator", None)
if (
collator is None
or getattr(collator, "_unsloth_padding_free_lengths_wrapped", False)
or not getattr(collator, "padding_free", False)
):
# Nothing to do if there's no collator, we've already wrapped it, or padding-free is off.
return

mark_allow_overlength(model)
Expand Down