diff --git a/unsloth_zoo/fused_losses/cross_entropy_loss.py b/unsloth_zoo/fused_losses/cross_entropy_loss.py
index 4e9a8ede3..17276c417 100644
--- a/unsloth_zoo/fused_losses/cross_entropy_loss.py
+++ b/unsloth_zoo/fused_losses/cross_entropy_loss.py
@@ -33,6 +33,23 @@
 TARGET_GB = os.environ.get("UNSLOTH_CE_LOSS_TARGET_GB", None)
 N_CHUNKS = os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None)
 
+# Register grad_and_value_impl in trace_rules as defense-in-depth.
+# grad_impl is registered but grad_and_value_impl is not, which can cause
+# GB0149 "Unsupported functorch tracing attempt" in some configurations.
+try:
+    from torch._dynamo.trace_rules import manual_torch_name_rule_map as _trace_map
+    from torch._dynamo.variables.higher_order_ops import FunctorchHigherOrderVariable as _FHOV
+    _key = "torch._functorch.eager_transforms.grad_and_value_impl"
+    if _key not in _trace_map:
+        _trace_map[_key] = _FHOV
+        torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear()
+except Exception:
+    pass
+
+# Module-level flag: None = untested, True = works, False = skip compile.
+_FUSED_CE_COMPILE_SUPPORTED = None if \
+    os.environ.get("UNSLOTH_FUSED_CE_COMPILE_DISABLE", "0") != "1" else False
+
 @functools.cache
 def _get_mapping(autograd):
     parameters = inspect.signature(getattr(autograd, "forward")).parameters
@@ -301,30 +318,82 @@ def accumulate_chunk(
             accumulated_loss.add_(unscaled_loss)
             grad_inputs_j[:] = chunk_grad_input
         pass
-        if torch_compile:
-            accumulate_chunk = torch.compile(
-                accumulate_chunk,
-                dynamic = True,
-                fullgraph = True,
-                options = torch_compile_options,
-            )
+        global _FUSED_CE_COMPILE_SUPPORTED
+        uncompiled_accumulate_chunk = accumulate_chunk
+
+        if torch_compile and _FUSED_CE_COMPILE_SUPPORTED is not False:
+            try:
+                accumulate_chunk = torch.compile(
+                    accumulate_chunk,
+                    dynamic = True,
+                    fullgraph = True,
+                    options = torch_compile_options,
+                )
+            except Exception:
+                _FUSED_CE_COMPILE_SUPPORTED = False
+                accumulate_chunk = uncompiled_accumulate_chunk
 
-        for (grad_inputs_j, hidden_states_j, labels_j,) in \
-            zip(__grad_inputs, __shift_states, __shift_labels,):
-            accumulate_chunk(
-                n_chunks = n_chunks,
-                grad_inputs_j = grad_inputs_j,
-                grad_lm_head = grad_lm_head,
-                grad_lm_head_bias = grad_lm_head_bias,
-                hidden_states_j = hidden_states_j,
-                lm_head_weight = lm_head_weight,
-                lm_head_bias = lm_head_bias,
-                labels_j = labels_j,
-                divisor = divisor,
-                scaling = scaling,
-                shift_labels = shift_labels,
-                **extra_kwargs,
-            )
+        # Keyword arguments shared by every chunk call; only the three
+        # per-chunk tensors differ between calls.
+        chunk_kwargs = dict(
+            n_chunks = n_chunks,
+            grad_lm_head = grad_lm_head,
+            grad_lm_head_bias = grad_lm_head_bias,
+            lm_head_weight = lm_head_weight,
+            lm_head_bias = lm_head_bias,
+            divisor = divisor,
+            scaling = scaling,
+            shift_labels = shift_labels,
+            **extra_kwargs,
+        )
+        chunks = zip(__grad_inputs, __shift_states, __shift_labels)
+
+        # Probe path: while compile support is still untested, run the first
+        # chunk through the compiled function and cache the outcome module-wide.
+        probing = torch_compile and _FUSED_CE_COMPILE_SUPPORTED is None and \
+            accumulate_chunk is not uncompiled_accumulate_chunk
+        # next(..., None) keeps an empty chunk list a no-op, as in the plain loop
+        first_chunk = next(chunks, None) if probing else None
+        if first_chunk is not None:
+            grad_inputs_j, hidden_states_j, labels_j = first_chunk
+            try:
+                accumulate_chunk(
+                    grad_inputs_j = grad_inputs_j,
+                    hidden_states_j = hidden_states_j,
+                    labels_j = labels_j,
+                    **chunk_kwargs,
+                )
+                _FUSED_CE_COMPILE_SUPPORTED = True
+            except Exception as error:
+                _FUSED_CE_COMPILE_SUPPORTED = False
+                import warnings
+                warnings.warn(
+                    "Unsloth: compiled fused cross entropy failed; "
+                    f"falling back to eager mode ({type(error).__name__}: {error})"
+                )
+                torch._dynamo.reset()
+                # Drop anything the failed call may have accumulated
+                accumulated_loss.zero_()
+                if not overwrite:
+                    grad_inputs.zero_()
+                if grad_lm_head is not None: grad_lm_head.zero_()
+                if grad_lm_head_bias is not None: grad_lm_head_bias.zero_()
+                accumulate_chunk = uncompiled_accumulate_chunk
+                accumulate_chunk(
+                    grad_inputs_j = grad_inputs_j,
+                    hidden_states_j = hidden_states_j,
+                    labels_j = labels_j,
+                    **chunk_kwargs,
+                )
+
+        # Remaining chunks, or all of them when no probe was needed
+        for (grad_inputs_j, hidden_states_j, labels_j,) in chunks:
+            accumulate_chunk(
+                grad_inputs_j = grad_inputs_j,
+                hidden_states_j = hidden_states_j,
+                labels_j = labels_j,
+                **chunk_kwargs,
+            )
         pass
         ctx.save_for_backward(grad_inputs, grad_lm_head, grad_lm_head_bias)
         ctx.scaling = scaling