From 559730a7cc5ef21571f4f8fc7ad578638d1f7855 Mon Sep 17 00:00:00 2001 From: Michal Futrega Date: Mon, 9 Feb 2026 12:54:11 +0100 Subject: [PATCH 1/2] Add callbacks to finetune function --- src/megatron/bridge/training/finetune.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/megatron/bridge/training/finetune.py b/src/megatron/bridge/training/finetune.py index 8ab3956148..1464286fec 100644 --- a/src/megatron/bridge/training/finetune.py +++ b/src/megatron/bridge/training/finetune.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from megatron.bridge.training.callbacks import Callback, CallbackManager from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.forward_step_func_types import ForwardStepCallable from megatron.bridge.training.pretrain import pretrain @@ -22,6 +23,7 @@ def finetune( config: ConfigContainer, forward_step_func: ForwardStepCallable, + callbacks: list[Callback] | CallbackManager | None = None, ) -> None: """Main function to run the finetuning. @@ -34,6 +36,7 @@ def finetune( - 3 args: (data_iterator, model, return_schedule_plan=False) OR (state: GlobalState, data_iterator, model) - 4 args: (state: GlobalState, data_iterator, model, return_schedule_plan=False) + callbacks: Optional list of Callback instances, a CallbackManager, or None. Note: Use the signature with GlobalState type hint for full access to configuration, timers, and training state. @@ -47,4 +50,4 @@ def finetune( assert config.checkpoint.pretrained_checkpoint is not None or config.checkpoint.load is not None, ( "Finetuning requires a loading from a pretrained checkpoint or resuming from a checkpoint" ) - return pretrain(config, forward_step_func) + return pretrain(config, forward_step_func, callbacks=callbacks) From cdbb03ef008f7576ffcd6b59c61d6badaf77ee0d Mon Sep 17 00:00:00 2001 From: Michal Futrega Date: Mon, 9 Feb 2026 16:24:42 +0100 Subject: [PATCH 2/2] fix tests --- tests/unit_tests/training/test_finetune.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/training/test_finetune.py b/tests/unit_tests/training/test_finetune.py index adf5a8f1b8..33aa23fb0f 100644 --- a/tests/unit_tests/training/test_finetune.py +++ b/tests/unit_tests/training/test_finetune.py @@ -80,7 +80,7 @@ def test_finetune_succeeds_with_pretrained_checkpoint(self): # This should not raise an AssertionError finetune(container, mock_forward_step_func) # Verify that pretrain was called with the correct arguments - mock_pretrain.assert_called_once_with(container, mock_forward_step_func) + mock_pretrain.assert_called_once_with(container, mock_forward_step_func, callbacks=None) finally: restore_get_world_size_safe(og_ws, cfg_mod) @@ -108,6 +108,6 @@ def test_finetune_succeeds_with_load_checkpoint(self): # This should not raise an AssertionError finetune(container, mock_forward_step_func) # Verify that pretrain was called with the correct arguments - mock_pretrain.assert_called_once_with(container, mock_forward_step_func) + mock_pretrain.assert_called_once_with(container, mock_forward_step_func, callbacks=None) finally: restore_get_world_size_safe(og_ws, cfg_mod)