TLDR: We can’t (cleanly) checkpoint everything yet (e.g., dataloader, replay buffer, RNG, etc.) because Titan’s checkpointer lives inside `trainer.engine` and exclusively controls the `step-<N>` folders, while those other pieces live in separate actors. There’s no safe, centralized way to co-write into the same step folder right now, so we will need to defer full multi-component checkpointing until after PTC.
What does work today: with #444, we can enable saving and resuming model weights + optimizer + LR scheduler via Titan’s checkpointer.
Scope: This RFC only discusses the challenges of saving checkpoints. For loading, see the discussion in #425.
1) Context (today’s flow)
We spin up components in main:
```python
(
    dataloader,
    policy,
    trainer,  # has ForgeEngine
    replay_buffer,
    compute_advantages,
    ref_model,
    reward_actor,
) = await asyncio.gather(...)
```
Model checkpointing is delegated to TorchTitan via the trainer’s engine:
- `trainer` creates the engine and loads a checkpoint:

```python
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.step)
```

- Every training step (in `train_step`), the trainer asks Titan to save:

```python
self.engine.checkpointer.save(
    curr_step=self.step,
    last_step=self.step == self.num_training_steps,
)
```

Titan’s checkpointer writes model weights, optimizer state, and LR schedulers into:

```
<folder>/step-<N>/__0_0.distcp
```

For example:

```
./checkpoint/step-100/__0_0.distcp
```
Per issue #362, we also want to save/load the states for:
- data step,
- replay buffer data,
- RNG states,
- etc.
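As a rough illustration only (none of these names exist in the codebase; this is just a sketch of the data involved), that extra per-step state could be gathered into a small dict along these lines:

```python
import random

import torch


def collect_extra_state(data_step: int, replay_buffer_entries: list) -> dict:
    """Gather the non-model state we would also like to save each step (sketch)."""
    return {
        "data_step": data_step,                  # batches consumed so far
        "replay_buffer": replay_buffer_entries,  # serialized buffer contents
        "rng": {
            "torch": torch.get_rng_state().tolist(),
            "python": random.getstate(),
        },
    }
```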
2) The problems
Problem 1: Step-folder ownership
We have a Titan-owned directory per save step (e.g., `step-200`) created from inside `trainer.engine.checkpointer`. Other actors (dataloader, replay buffer) have no access to the trainer’s engine or to Titan’s private folder-naming method. That leaves us with two awkward choices:
- Two folders per step

  ```
  checkpoint/
    step-200/          # Titan
      __0_0.distcp
    step-200-other/    # Ours
      dataloader.json
      replay_buffer.bin
      rng.json
  ```

  Downsides: clunky UX, hard to purge/retain atomically, and easy to drift.
- Single folder per step (preferred)

  To co-locate our files inside `step-200/`, we must either:

  - call Titan’s private `_create_checkpoint_id` to learn the folder name, but other components (e.g., `dataloader`) don’t have an `engine`;
  - re-implement a look-alike function and hope it never diverges; or
  - (preferred) add a `path` parameter to `checkpointer.save` (see the sketch below).
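A minimal sketch of what the preferred option might look like from the trainer’s side, assuming Titan accepts the proposed `path` argument (it does not today); `checkpoint_folder`, `step`, `num_training_steps`, `dump_json`, and `dataloader` stand in for values and helpers from the surrounding script:

```python
# Hypothetical: checkpointer.save grows an optional `path` argument so every
# component can agree on the same step folder.
step_path = f"{checkpoint_folder}/step-{step}"

trainer.engine.checkpointer.save(
    curr_step=step,
    last_step=(step == num_training_steps),
    path=step_path,  # proposed new argument
)

# Other actors could then co-locate their state in the same folder:
dump_json(dataloader.state_dict(), f"{step_path}/dataloader.json")
```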
Problem 2: No unified checkpoint scope
Currently, non-model state (dataloader, replay buffer, RNG, etc.) lives in separate actors/services, and the trainer, which owns Titan’s checkpointer, only manages the model, optimizer, and LR scheduler.
Because Titan’s checkpointing is embedded inside trainer.engine.checkpointer, there’s no single coordinator that can write all components into the same step-<N> folder in a clean, synchronized way.
3) Proposal for Problem 2
Option 1: Make trainer own all other components
```python
class RLTrainer:
    def __init__(self):
        ...
        self.dataloader = ...
        self.replay_buffer = ...
        self.rng = ...
```

This is very fast to implement. However, it introduces heavy coupling, breaks actor/service boundaries, and hurts scalability and reuse.
Option 2: Reimplement checkpointing ourselves
Write our own model/optim/lr checkpointing.
It gives us full control over layout and atomicity.
Downsides
- Rebuilding the hardest part (distributed model/optim/LR checkpointing, DCP, async saves, fault tolerance); see the sketch below.
- High risk, high effort; guaranteed divergence from Titan.
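For a sense of the effort involved, even a bare-bones version built directly on PyTorch Distributed Checkpoint (DCP) would look roughly like the sketch below, before we even get to LR schedulers, async saves, fault tolerance, or compatibility with Titan’s on-disk layout:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
)


def save_model_and_optim(model, optimizer, step: int, folder: str) -> None:
    """Minimal DCP save of model + optimizer state into a step folder (sketch)."""
    state_dict = {
        "model": get_model_state_dict(model),
        "optim": get_optimizer_state_dict(model, optimizer),
    }
    dcp.save(state_dict, checkpoint_id=f"{folder}/step-{step}")
```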
Option 3: Coordinator layered above current actors
Introduce a light Checkpoint Coordinator that:
- calls Titan to save the model/optimizer/LR schedulers to a specified path (requires one small API addition to Titan’s save: `path=`),
- then asks each actor for its `state_dict()` and writes it to the same folder (e.g., `step-200/dataloader.json`, etc.),
- on load, after Titan resolves the step it will load, the coordinator tries to restore each component’s state by calling its `load_state()`.
```python
class CheckpointCoordinator:
    def __init__(self):
        self._trainer: RLTrainer | None = None
        self._components: dict[str, ForgeActor | ServiceInterface] = {}

    def set_trainer(self, trainer: RLTrainer):
        self._trainer = trainer

    def register(self, name: str, comp: ForgeActor | ServiceInterface):
        self._components[name] = comp

    async def save(self, step: int, path: str):
        step_path = get_path(path, step)  # e.g., <path>/step-<N>
        if self._trainer:
            self._trainer.engine.checkpointer.save(path=step_path)
        for name, comp in self._components.items():
            states = comp.state_dict()
            dump_json(states, f"{step_path}/{name}.json")

    async def load(self, step: int, path: str):
        ...
```

The changes required in `grpo/main`:
```python
coord = CheckpointCoordinator()
coord.set_trainer(trainer)
coord.register("dataloader", dataloader)
coord.register("replay_buffer", replay_buffer)
...

await coord.load(step, path=checkpoint_folder)

async def continuous_training():
    ...
    await coord.save(step, path=checkpoint_folder)
```

This is a relatively easy change and keeps most of the existing machinery in place.
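For this to work, each registered component has to expose the `state_dict()`/`load_state()` pair the coordinator expects. A minimal sketch of that contract on the dataloader side (the class body and field names here are hypothetical, not the real actor):

```python
class DatasetActor(ForgeActor):
    """Sketch only: shows the checkpoint contract, not the real actor."""

    def __init__(self):
        self._step = 0  # number of batches handed out so far

    def state_dict(self) -> dict:
        # Return everything needed to resume iteration deterministically.
        return {"data_step": self._step}

    def load_state(self, state: dict) -> None:
        # Restore the saved position so resumed training skips consumed data.
        self._step = state["data_step"]
```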
Downsides
- Nested structure: we still call `coord.save`, which in turn calls `self._trainer.engine.checkpointer.save`.
- Slightly specific to the existing GRPO script; may have generalizability issues in the future.
Option 4: Standalone ForgeCheckpointManager
Longer-term, we could create a standalone `ForgeCheckpointManager` that inherits from Titan’s `CheckpointManager` and orchestrates both Titan’s and forge’s components in one `save()`/`load()` call. Actors register their `export_state`/`import_state` with this one object; `main` just calls it.
Open question: where does `ForgeCheckpointManager` live if the engine stays inside the trainer? And how does it read/write model/optim/LR state without re-nesting the trainer or breaking actor decoupling?
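A very rough sketch of the shape this could take, assuming Titan’s `CheckpointManager` is subclassable and its `save(curr_step, last_step)` signature matches the call shown earlier; the import path may differ across Titan versions, and `dump_json` plus the private `_create_checkpoint_id` are reused here purely for illustration:

```python
from torchtitan.components.checkpoint import CheckpointManager  # path may vary


class ForgeCheckpointManager(CheckpointManager):
    """Sketch: save Titan state and forge component state in one call."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._components = {}  # name -> object exposing state_dict()/load_state()

    def register(self, name: str, comp) -> None:
        self._components[name] = comp

    def save(self, curr_step: int, last_step: bool = False) -> None:
        # Titan writes model/optim/LR state into its step-<N> folder first.
        super().save(curr_step, last_step)
        step_path = self._create_checkpoint_id(curr_step)  # private today
        for name, comp in self._components.items():
            dump_json(comp.state_dict(), f"{step_path}/{name}.json")
```

Note that constructing this object still requires the model/optim/LR parts that live inside the trainer’s engine, which is exactly the open question above.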