Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 110 additions & 23 deletions unsloth_zoo/fused_losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@
# Optional tuning overrides taken from the environment. Each value is the raw
# environment string, or None when the variable is not set.
TARGET_GB = os.environ.get("UNSLOTH_CE_LOSS_TARGET_GB")
N_CHUNKS = os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS")

# Register grad_and_value_impl in trace_rules as defense-in-depth.
# grad_impl is registered but grad_and_value_impl is not, which can cause
# GB0149 "Unsupported functorch tracing attempt" in some configurations.
try:
from torch._dynamo.trace_rules import manual_torch_name_rule_map as _trace_map
from torch._dynamo.variables.higher_order_ops import FunctorchHigherOrderVariable as _FHOV
_key = "torch._functorch.eager_transforms.grad_and_value_impl"
if _key not in _trace_map:
_trace_map[_key] = _FHOV
torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear()
except Exception:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a broad `except Exception:` can hide unexpected errors and make debugging more difficult. Since you are interacting with internal PyTorch APIs, which are subject to change, it's better to catch more specific exceptions. `ImportError` and `AttributeError` are the most likely candidates here. This will make the code safer and prevent masking unrelated issues.

Suggested change
- except Exception:
+ except (ImportError, AttributeError):

pass

# Tri-state module-level flag describing whether the compiled path may be used:
#   None  -> not tested yet
#   True  -> compile works
#   False -> skip compile
# Exporting UNSLOTH_FUSED_CE_COMPILE_DISABLE=1 starts the flag at False;
# any other value (or leaving it unset) starts it at None.
_FUSED_CE_COMPILE_SUPPORTED = (
    False
    if os.environ.get("UNSLOTH_FUSED_CE_COMPILE_DISABLE", "0") == "1"
    else None
)

@functools.cache
def _get_mapping(autograd):
parameters = inspect.signature(getattr(autograd, "forward")).parameters
Expand Down Expand Up @@ -301,30 +318,100 @@ def accumulate_chunk(
accumulated_loss.add_(unscaled_loss)
grad_inputs_j[:] = chunk_grad_input
pass
if torch_compile:
accumulate_chunk = torch.compile(
accumulate_chunk,
dynamic = True,
fullgraph = True,
options = torch_compile_options,
)
global _FUSED_CE_COMPILE_SUPPORTED
uncompiled_accumulate_chunk = accumulate_chunk

if torch_compile and _FUSED_CE_COMPILE_SUPPORTED is not False:
try:
accumulate_chunk = torch.compile(
accumulate_chunk,
dynamic = True,
fullgraph = True,
options = torch_compile_options,
)
except Exception:
_FUSED_CE_COMPILE_SUPPORTED = False
accumulate_chunk = uncompiled_accumulate_chunk

# Probe path: first-ever forward pass, test if compiled version works
if _FUSED_CE_COMPILE_SUPPORTED is None and torch_compile and \
accumulate_chunk is not uncompiled_accumulate_chunk:

for (grad_inputs_j, hidden_states_j, labels_j,) in \
zip(__grad_inputs, __shift_states, __shift_labels,):
accumulate_chunk(
n_chunks = n_chunks,
grad_inputs_j = grad_inputs_j,
grad_lm_head = grad_lm_head,
grad_lm_head_bias = grad_lm_head_bias,
hidden_states_j = hidden_states_j,
lm_head_weight = lm_head_weight,
lm_head_bias = lm_head_bias,
labels_j = labels_j,
divisor = divisor,
scaling = scaling,
shift_labels = shift_labels,
**extra_kwargs,
)
_iter = iter(zip(__grad_inputs, __shift_states, __shift_labels))
grad_inputs_j, hidden_states_j, labels_j = next(_iter)
try:
accumulate_chunk(
n_chunks = n_chunks,
grad_inputs_j = grad_inputs_j,
grad_lm_head = grad_lm_head,
grad_lm_head_bias = grad_lm_head_bias,
hidden_states_j = hidden_states_j,
lm_head_weight = lm_head_weight,
lm_head_bias = lm_head_bias,
labels_j = labels_j,
divisor = divisor,
scaling = scaling,
shift_labels = shift_labels,
**extra_kwargs,
)
_FUSED_CE_COMPILE_SUPPORTED = True
except Exception:
_FUSED_CE_COMPILE_SUPPORTED = False
torch._dynamo.reset()
accumulated_loss.zero_()
if not overwrite:
grad_inputs.zero_()
if grad_lm_head is not None: grad_lm_head.zero_()
if grad_lm_head_bias is not None: grad_lm_head_bias.zero_()
accumulate_chunk = uncompiled_accumulate_chunk
accumulate_chunk(
n_chunks = n_chunks,
grad_inputs_j = grad_inputs_j,
grad_lm_head = grad_lm_head,
grad_lm_head_bias = grad_lm_head_bias,
hidden_states_j = hidden_states_j,
lm_head_weight = lm_head_weight,
lm_head_bias = lm_head_bias,
labels_j = labels_j,
divisor = divisor,
scaling = scaling,
shift_labels = shift_labels,
**extra_kwargs,
)
# Process remaining chunks via fast path
for (grad_inputs_j, hidden_states_j, labels_j,) in _iter:
accumulate_chunk(
n_chunks = n_chunks,
grad_inputs_j = grad_inputs_j,
grad_lm_head = grad_lm_head,
grad_lm_head_bias = grad_lm_head_bias,
hidden_states_j = hidden_states_j,
lm_head_weight = lm_head_weight,
lm_head_bias = lm_head_bias,
labels_j = labels_j,
divisor = divisor,
scaling = scaling,
shift_labels = shift_labels,
**extra_kwargs,
)
else:
# Fast path: compile status already known, original main branch loop
for (grad_inputs_j, hidden_states_j, labels_j,) in \
zip(__grad_inputs, __shift_states, __shift_labels,):
accumulate_chunk(
n_chunks = n_chunks,
grad_inputs_j = grad_inputs_j,
grad_lm_head = grad_lm_head,
grad_lm_head_bias = grad_lm_head_bias,
hidden_states_j = hidden_states_j,
lm_head_weight = lm_head_weight,
lm_head_bias = lm_head_bias,
labels_j = labels_j,
divisor = divisor,
scaling = scaling,
shift_labels = shift_labels,
**extra_kwargs,
)
pass
ctx.save_for_backward(grad_inputs, grad_lm_head, grad_lm_head_bias)
ctx.scaling = scaling
Expand Down
Loading