Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/megatron/bridge/training/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron.bridge.training.callbacks import Callback, CallbackManager
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.forward_step_func_types import ForwardStepCallable
from megatron.bridge.training.pretrain import pretrain
Expand All @@ -22,6 +23,7 @@
def finetune(
config: ConfigContainer,
forward_step_func: ForwardStepCallable,
callbacks: list[Callback] | CallbackManager | None = None,
) -> None:
"""Main function to run the finetuning.

Expand All @@ -34,6 +36,7 @@ def finetune(
- 3 args: (data_iterator, model, return_schedule_plan=False)
OR (state: GlobalState, data_iterator, model)
- 4 args: (state: GlobalState, data_iterator, model, return_schedule_plan=False)
callbacks: Optional list of Callback instances, a CallbackManager, or None.

Note:
Use the signature with GlobalState type hint for full access to configuration, timers, and training state.
Expand All @@ -47,4 +50,4 @@ def finetune(
assert config.checkpoint.pretrained_checkpoint is not None or config.checkpoint.load is not None, (
"Finetuning requires a loading from a pretrained checkpoint or resuming from a checkpoint"
)
return pretrain(config, forward_step_func)
return pretrain(config, forward_step_func, callbacks=callbacks)
4 changes: 2 additions & 2 deletions tests/unit_tests/training/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_finetune_succeeds_with_pretrained_checkpoint(self):
# This should not raise an AssertionError
finetune(container, mock_forward_step_func)
# Verify that pretrain was called with the correct arguments
mock_pretrain.assert_called_once_with(container, mock_forward_step_func)
mock_pretrain.assert_called_once_with(container, mock_forward_step_func, callbacks=None)
finally:
restore_get_world_size_safe(og_ws, cfg_mod)

Expand Down Expand Up @@ -108,6 +108,6 @@ def test_finetune_succeeds_with_load_checkpoint(self):
# This should not raise an AssertionError
finetune(container, mock_forward_step_func)
# Verify that pretrain was called with the correct arguments
mock_pretrain.assert_called_once_with(container, mock_forward_step_func)
mock_pretrain.assert_called_once_with(container, mock_forward_step_func, callbacks=None)
finally:
restore_get_world_size_safe(og_ws, cfg_mod)