diff --git a/docs/training/callbacks.md b/docs/training/callbacks.md index 575ff324d0..a445a6814d 100644 --- a/docs/training/callbacks.md +++ b/docs/training/callbacks.md @@ -9,8 +9,12 @@ Megatron Bridge provides a lightweight callback system for injecting custom logi Subclass {py:class}`bridge.training.callbacks.Callback` and override event methods: ```python +import time + from megatron.bridge.training.callbacks import Callback +from megatron.bridge.training.gpt_step import forward_step from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.recipes.qwen import qwen25_500m_pretrain_config class MyCallback(Callback): def on_train_start(self, context): @@ -25,8 +29,11 @@ class MyCallback(Callback): elapsed = time.time() - context.user_state['start_time'] print(f"Training completed in {elapsed:.2f}s") +# Create a config that fits on a single GPU +config = qwen25_500m_pretrain_config() + # Pass callbacks to pretrain -pretrain(config, forward_step_func, callbacks=[MyCallback()]) +pretrain(config, forward_step, callbacks=[MyCallback()]) ``` ### Functional Callbacks @@ -35,7 +42,9 @@ Register functions directly with {py:class}`bridge.training.callbacks.CallbackMa ```python from megatron.bridge.training.callbacks import CallbackManager -from megatron.bridge.training import pretrain +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.recipes.qwen import qwen25_500m_pretrain_config def log_step(context): step = context.state.train_state.step @@ -45,7 +54,10 @@ def log_step(context): callback_manager = CallbackManager() callback_manager.register("on_train_step_end", log_step) -pretrain(config, forward_step_func, callbacks=callback_manager) +# Create a config that fits on a single GPU +config = qwen25_500m_pretrain_config() + +pretrain(config, forward_step, callbacks=callback_manager) ``` ### Mixing Both Patterns @@ -53,12 +65,20 @@ pretrain(config, forward_step_func, callbacks=callback_manager) Both registration patterns can be combined: ```python +from megatron.bridge.training.callbacks import CallbackManager +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.recipes.qwen import qwen25_500m_pretrain_config + manager = CallbackManager() manager.add(MyCallback()) manager.add([TimingCallback(), MetricsCallback()]) manager.register("on_eval_end", lambda ctx: print("Evaluation complete!")) -pretrain(config, forward_step_func, callbacks=manager) +# Create a config that fits on a single GPU +config = qwen25_500m_pretrain_config() + +pretrain(config, forward_step, callbacks=manager) ``` ## Available Events