diff --git a/src/megatron/bridge/training/optim.py b/src/megatron/bridge/training/optim.py index 38426990ba..658bf85b54 100644 --- a/src/megatron/bridge/training/optim.py +++ b/src/megatron/bridge/training/optim.py @@ -19,8 +19,20 @@ MegatronOptimizer, OptimizerConfig, get_megatron_optimizer, - get_mup_config_overrides, ) + + +# TODO: Remove try/except once `get_mup_config_overrides` lands in mcore main. +# This guard exists because the symbol lives in mcore dev but not yet in +# the main branch that the submodule tracks. +# +# We assign None (not a bool flag) so the module attribute always exists +# and tests can patch it without AttributeError. +try: + from megatron.core.optimizer import get_mup_config_overrides +except ImportError: + get_mup_config_overrides = None # type: ignore[assignment] + from megatron.core.optimizer.muon import get_megatron_muon_optimizer from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.process_groups_config import ProcessGroupCollection @@ -65,10 +77,12 @@ def setup_optimizer( OptimizerConfigOverrideProviderContext(scheduler_config, optimizer_config, model) ) - # Apply μP optimizer scaling if enabled on the model config + # Apply μP optimizer scaling if enabled on the model config. + # Guard on the callable itself (None when mcore main lacks the symbol) so + # unit tests can patch the module attribute without hitting AttributeError. model_chunks = model if isinstance(model, list) else [model] model_config = get_model_config(model_chunks[0]) - if getattr(model_config, "use_mup", False): + if get_mup_config_overrides is not None and getattr(model_config, "use_mup", False): mup_overrides = get_mup_config_overrides( config=optimizer_config, mup_width_mult=model_config.mup_width_mult,