Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,22 @@ def create_new_function(
overwrite = True
pass
if os.environ.get("UNSLOTH_COMPILE_OVERWRITE", "1") == "0":
overwrite = False
# Even with OVERWRITE disabled, force recompile on transformers version mismatch
if file_source is not None and "__UNSLOTH_VERSIONING__" in file_source:
cached_versions = file_source[:file_source.find("__UNSLOTH_VERSIONING__")]
cached_lines = [l.strip() for l in cached_versions.strip().strip('"').split("\n") if l.strip()]
# Format: [unsloth_zoo_version, unsloth_version, transformers_version, trl_version]
cached_tf_version = cached_lines[2] if len(cached_lines) > 2 else "0"
if cached_tf_version != transformers_version:
logger.warning_once(
f"Unsloth: UNSLOTH_COMPILE_OVERWRITE=0 is set, but transformers version changed "
f"({cached_tf_version} -> {transformers_version}). Forcing recompile of {name}."
)
# Don't set overwrite = False; keep overwrite = True from version mismatch detection
else:
overwrite = False
else:
overwrite = False

# Check location
def write_file(function_location, write_new_source):
Expand Down Expand Up @@ -3165,15 +3180,26 @@ def replaced_tqdm(*args, **kwargs):
bad_torch_modules.add(module)
pass

# Check if arrays are created inside the function
# Error: DataDependentOutputException: aten._local_scalar_dense.default
# Check for data-dependent control flow that breaks torch.compile(fullgraph=True)
# Tier 1: Direct data escapes from tensor to Python
# .nonzero() -> data-dependent output shape (variable-length)
# .tolist() -> materializes tensor values into Python list
# .item() -> materializes tensor scalar into Python
# Tier 2: MoE expert dispatch via torch.where + index_add
# 1-arg torch.where returns data-dependent indices; combined with
# index_add this is the standard MoE routing loop pattern
if (
"torch.arange(" in source
or "torch.zeros(" in source
or "torch.ones(" in source
".nonzero()" in source
or ".tolist()" in source
or ".item()" in source
):
print(
f"Unsloth: Failed compiling function {module} since array creations are done."
f"Unsloth: Will not compile {module} since data-dependent operations are done."
)
bad_torch_modules.add(module)
elif "torch.where(" in source and ".index_add" in source:
print(
f"Unsloth: Will not compile {module} since data-dependent routing is done."
)
bad_torch_modules.add(module)
pass
Expand Down
32 changes: 31 additions & 1 deletion unsloth_zoo/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,28 +340,58 @@ def _train_on_responses_only(examples):
# Set it to int(memory_gb_left) so 16Gb = 16
num_proc = min(num_proc, int(memory_gb_left))

# In transformers 5.0+, VLM models skip dataset preparation in SFTTrainer.__init__
# (skip_prepare_dataset=True when _is_vlm=True). This means the dataset may not be
# tokenized yet. We need to tokenize it before applying _train_on_responses_only.
def _maybe_tokenize_dataset(dataset):
if dataset is None:
return dataset
sample = next(iter(dataset))
if "input_ids" in sample:
return dataset # Already tokenized
# Need to tokenize - get the processing class from trainer
_tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer
# Get the actual tokenizer (not processor) for tokenization
if hasattr(_tokenizer, "tokenizer"):
_tok = _tokenizer.tokenizer
else:
_tok = _tokenizer
max_length = getattr(trainer.args, "max_length", None) or getattr(trainer.args, "max_seq_length", 2048)
text_field = getattr(trainer.args, "dataset_text_field", "text")
def _tokenize_fn(examples):
texts = examples.get(text_field) or examples.get("text", [])
return _tok(texts, truncation=True, max_length=max_length, padding=False)
_map_kwargs = {"batched": True, "num_proc": num_proc}
if isinstance(dataset, IterableDataset):
_map_kwargs = {"batched": True}
return dataset.map(_tokenize_fn, **_map_kwargs)
pass

if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
if not hasattr(trainer.train_dataset, "map"):
raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
trainer.train_dataset = _maybe_tokenize_dataset(trainer.train_dataset)
if isinstance(trainer.train_dataset, IterableDataset):
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batch_size = trainer.train_dataset._ex_iterable.batch_size, batched = True)
else:
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc)
pass

if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None:
# Eval datasets could be a dict!
if type(trainer.eval_dataset) is dict:
for key, value in trainer.eval_dataset.items():
if not hasattr(value, "map"):
raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
value = _maybe_tokenize_dataset(value)
if isinstance(value, IterableDataset):
trainer.eval_dataset[key] = value.map(_train_on_responses_only, batch_size = value._ex_iterable.batch_size, batched = True)
else:
trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True, num_proc = num_proc)
else:
if not hasattr(trainer.eval_dataset, "map"):
raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
trainer.eval_dataset = _maybe_tokenize_dataset(trainer.eval_dataset)
if isinstance(trainer.eval_dataset, IterableDataset):
trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batch_size = trainer.eval_dataset._ex_iterable.batch_size, batched = True)
else:
Expand Down
14 changes: 11 additions & 3 deletions unsloth_zoo/temporary_patches/bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,18 @@ def patch_bitsandbytes_linear4bit_forward():
return raise_error("bitsandbytes.Linear4bit", e)

def forward(self, x: torch.Tensor):
fix_4bit_weight_quant_state_from_module(self)
# In transformers 5.0+, weights may not be in packed format yet during init
if self.weight.shape[-1] == 1:
fix_4bit_weight_quant_state_from_module(self)

# Some layers may not be quantized (no quant_state) - fall back to regular matmul
quant_state = getattr(self.weight, "quant_state", None)
if quant_state is None:
bias = None if self.bias is None else self.bias
return torch.nn.functional.linear(x, self.weight, bias)

# weights are cast automatically as Int8Params, but the bias has to be cast manually

# ** Errors out in torch.compile so remove it
# if self.bias is not None and self.bias.dtype != x.dtype:
# self.bias.data = self.bias.data.to(x.dtype)
Expand All @@ -72,7 +80,7 @@ def forward(self, x: torch.Tensor):
# Cannot do .t() on Params4bit, instead do it on torch.Tensor
weight = self.weight.data.t()

return bitsandbytes.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
return bitsandbytes.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype)

patch_function(bitsandbytes.nn.modules.Linear4bit, "forward", forward)
try:
Expand Down
Loading