diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb
index eaacfc01c137..2f6474654901 100644
--- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb
+++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb
@@ -192,7 +192,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 127,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -218,9 +218,12 @@
     "    optimizer = Adam(model.parameters(), lr=config.get('learning_rate', 0.001))\n",
     "\n",
     "    # Load from checkpoint if available (for resuming training)\n",
+    "    start_epoch = 0\n",
     "    loaded_checkpoint = ray.train.get_checkpoint()\n",
     "    if loaded_checkpoint:\n",
-    "        load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n",
+    "        latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n",
+    "        start_epoch = latest_epoch + 1 if latest_epoch is not None else 0\n",
+    "        logger.info(f\"Resuming training from epoch {start_epoch}\")\n",
     "\n",
     "    # Prepare training data\n",
     "    transform = Compose([\n",
@@ -258,7 +261,7 @@
     "    num_batches = 0\n",
     "    epochs = config.get('epochs', 5)\n",
     "    \n",
-    "    for epoch in range(epochs):\n",
+    "    for epoch in range(start_epoch, epochs):\n",
     "        # Set epoch for distributed sampler to ensure proper shuffling\n",
     "        if ray.train.get_context().get_world_size() > 1:\n",
     "            train_loader.sampler.set_epoch(epoch)\n",
@@ -282,8 +285,8 @@
     "\n",
     "        # Report metrics and save checkpoint after each epoch\n",
     "        avg_loss = running_loss / num_batches\n",
-    "        metrics = {\"loss\": avg_loss, \"epoch\": epoch}\n",
-    "        report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics)\n",
+    "        metrics = {\"loss\": avg_loss}\n",
+    "        report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch)\n",
     "\n",
     "        # Log metrics from rank 0 only to avoid duplicate outputs\n",
     "        if world_rank == 0:\n",
@@ -518,7 +521,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 132,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -531,16 +534,18 @@
     "    and optimizer.\n",
     "    \"\"\"\n",
     "\n",
-    "    def __init__(self, model, optimizer=None):\n",
+    "    def __init__(self, model, optimizer=None, epoch=None):\n",
     "        self.model = model\n",
     "        self.optimizer = optimizer\n",
+    "        self.epoch = epoch\n",
     "\n",
     "    def state_dict(self):\n",
     "        # this line automatically manages FSDP2 FQN's (Fully Qualified Name), as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT\n",
     "        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)\n",
     "        return {\n",
     "            \"model\": model_state_dict,\n",
-    "            \"optim\": optimizer_state_dict\n",
+    "            \"optim\": optimizer_state_dict,\n",
+    "            \"epoch\": self.epoch\n",
     "        }\n",
     "\n",
     "    def load_state_dict(self, state_dict):\n",
@@ -550,7 +555,10 @@
     "            self.optimizer,\n",
     "            model_state_dict=state_dict[\"model\"],\n",
     "            optim_state_dict=state_dict[\"optim\"],\n",
-    "        )"
+    "        )\n",
+    "        # Load epoch information if available\n",
+    "        if \"epoch\" in state_dict:\n",
+    "            self.epoch = state_dict[\"epoch\"]"
    ]
   },
   {
@@ -574,11 +582,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 134,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint):\n",
+    "def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int | None:\n",
     "    \"\"\"Load an FSDP checkpoint into the model and optimizer.\n",
     "    \n",
     "    This function handles distributed checkpoint loading with automatic resharding\n",
@@ -589,13 +597,17 @@
     "        model: The FSDP-wrapped model to load state into\n",
     "        optimizer: The optimizer to load state into\n",
     "        ckpt: Ray Train checkpoint containing the saved state\n",
+    "\n",
+    "    Returns:\n",
+    "        int | None: The epoch number saved within the checkpoint, or None if the checkpoint does not include one.\n",
     "    \"\"\"\n",
     "    logger.info(\"Loading distributed checkpoint for resuming training...\")\n",
     "    \n",
     "    try:\n",
     "        with ckpt.as_directory() as checkpoint_dir:\n",
     "            # Create state wrapper for DCP loading\n",
-    "            state_dict = {\"app\": AppState(model, optimizer)}\n",
+    "            app_state = AppState(model, optimizer)\n",
+    "            state_dict = {\"app\": app_state}\n",
     "            \n",
     "            # Load the distributed checkpoint\n",
     "            dcp.load(\n",
@@ -603,7 +615,8 @@
     "                checkpoint_id=checkpoint_dir\n",
     "            )\n",
     "            \n",
-    "            logger.info(\"Successfully loaded distributed checkpoint\")\n",
+    "            logger.info(f\"Successfully loaded distributed checkpoint from epoch {app_state.epoch}\")\n",
+    "            return app_state.epoch\n",
     "    except Exception as e:\n",
     "        logger.error(f\"Failed to load checkpoint: {e}\")\n",
     "        raise RuntimeError(f\"Checkpoint loading failed: {e}\") from e"
@@ -620,12 +633,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 135,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "def report_metrics_and_save_fsdp_checkpoint(\n",
-    "    model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict\n",
+    "    model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict, epoch: int = 0\n",
     ") -> None:\n",
     "    \"\"\"Report training metrics and save an FSDP checkpoint.\n",
     "    \n",
@@ -637,12 +650,13 @@
     "        model: The FSDP-wrapped model to checkpoint\n",
     "        optimizer: The optimizer to checkpoint\n",
     "        metrics: Dictionary of metrics to report (e.g., loss, accuracy)\n",
+    "        epoch: The current epoch to be saved\n",
     "    \"\"\"\n",
     "    logger.info(\"Saving checkpoint and reporting metrics...\")\n",
     "    \n",
     "    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n",
     "        # Perform a distributed checkpoint with DCP\n",
-    "        state_dict = {\"app\": AppState(model, optimizer)}\n",
+    "        state_dict = {\"app\": AppState(model, optimizer, epoch)}\n",
     "        dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir)\n",
     "\n",
     "        # Report each checkpoint shard from all workers\n",
diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md
index 9679d01333e2..8292bf7cd49a 100644
--- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md
+++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md
@@ -166,9 +166,12 @@ def train_func(config):
     optimizer = Adam(model.parameters(), lr=config.get('learning_rate', 0.001))
 
     # Load from checkpoint if available (for resuming training)
+    start_epoch = 0
     loaded_checkpoint = ray.train.get_checkpoint()
     if loaded_checkpoint:
-        load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)
+        latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)
+        start_epoch = latest_epoch + 1 if latest_epoch is not None else 0
+        logger.info(f"Resuming training from epoch {start_epoch}")
 
     # Prepare training data
     transform = Compose([
@@ -206,7 +209,7 @@ def train_func(config):
     num_batches = 0
     epochs = config.get('epochs', 5)
     
-    for epoch in range(epochs):
+    for epoch in range(start_epoch, epochs):
         # Set epoch for distributed sampler to ensure proper shuffling
         if ray.train.get_context().get_world_size() > 1:
             train_loader.sampler.set_epoch(epoch)
@@ -230,8 +233,8 @@ def train_func(config):
 
         # Report metrics and save checkpoint after each epoch
         avg_loss = running_loss / num_batches
-        metrics = {"loss": avg_loss, "epoch": epoch}
-        report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics)
+        metrics = {"loss": avg_loss}
+        report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch)
 
         # Log metrics from rank 0 only to avoid duplicate outputs
         if world_rank == 0:
@@ -444,16 +447,18 @@ class AppState(Stateful):
     and optimizer.
     """
 
-    def __init__(self, model, optimizer=None):
+    def __init__(self, model, optimizer=None, epoch=None):
         self.model = model
         self.optimizer = optimizer
+        self.epoch = epoch
 
     def state_dict(self):
         # this line automatically manages FSDP2 FQN's (Fully Qualified Name), as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
         model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
         return {
             "model": model_state_dict,
-            "optim": optimizer_state_dict
+            "optim": optimizer_state_dict,
+            "epoch": self.epoch
         }
 
     def load_state_dict(self, state_dict):
@@ -464,6 +469,9 @@ class AppState(Stateful):
             model_state_dict=state_dict["model"],
             optim_state_dict=state_dict["optim"],
         )
+        # Load epoch information if available
+        if "epoch" in state_dict:
+            self.epoch = state_dict["epoch"]
 ```
 
 ### Load distributed model from checkpoint
@@ -478,7 +486,7 @@
 import torch.distributed.checkpoint as dcp
 ```
 ```python
-def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint):
+def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int | None:
     """Load an FSDP checkpoint into the model and optimizer.
     
     This function handles distributed checkpoint loading with automatic resharding
@@ -489,13 +497,17 @@ def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ck
         model: The FSDP-wrapped model to load state into
         optimizer: The optimizer to load state into
         ckpt: Ray Train checkpoint containing the saved state
+
+    Returns:
+        int | None: The epoch number saved within the checkpoint, or None if the checkpoint does not include one.
     """
     logger.info("Loading distributed checkpoint for resuming training...")
     
     try:
         with ckpt.as_directory() as checkpoint_dir:
             # Create state wrapper for DCP loading
-            state_dict = {"app": AppState(model, optimizer)}
+            app_state = AppState(model, optimizer)
+            state_dict = {"app": app_state}
             
             # Load the distributed checkpoint
             dcp.load(
@@ -503,7 +515,8 @@ def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ck
                 checkpoint_id=checkpoint_dir
             )
             
-            logger.info("Successfully loaded distributed checkpoint")
+            logger.info(f"Successfully loaded distributed checkpoint from epoch {app_state.epoch}")
+            return app_state.epoch
     except Exception as e:
         logger.error(f"Failed to load checkpoint: {e}")
         raise RuntimeError(f"Checkpoint loading failed: {e}") from e
@@ -516,7 +529,7 @@ The following function handles periodic checkpoint saving during training, combi
 
 ```python
 def report_metrics_and_save_fsdp_checkpoint(
-    model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict
+    model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict, epoch: int = 0
 ) -> None:
     """Report training metrics and save an FSDP checkpoint.
     
@@ -528,12 +541,13 @@ def report_metrics_and_save_fsdp_checkpoint(
         model: The FSDP-wrapped model to checkpoint
         optimizer: The optimizer to checkpoint
         metrics: Dictionary of metrics to report (e.g., loss, accuracy)
+        epoch: The current epoch to be saved
     """
     logger.info("Saving checkpoint and reporting metrics...")
 
     with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
         # Perform a distributed checkpoint with DCP
-        state_dict = {"app": AppState(model, optimizer)}
+        state_dict = {"app": AppState(model, optimizer, epoch)}
         dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir)

         # Report each checkpoint shard from all workers
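The change threads the epoch through DCP's `Stateful` protocol: `dcp.save` serializes whatever `AppState.state_dict()` returns, and `dcp.load` calls `AppState.load_state_dict()` in place, so the trainer reads the restored epoch back off the wrapper and resumes at the following epoch. Below is a minimal, non-distributed sketch of that round trip; the class and the plain-dict hand-off are illustrative stand-ins for `AppState` and `dcp.save`/`dcp.load`, not part of this diff.

```python
# Illustrative stand-in for the AppState pattern above: a Stateful-style wrapper whose
# load_state_dict() restores the saved epoch in place. Plain dicts replace
# dcp.save/dcp.load here, so the sketch runs without any distributed setup.
class EpochOnlyState:
    def __init__(self, epoch=None):
        self.epoch = epoch

    def state_dict(self):
        return {"epoch": self.epoch}

    def load_state_dict(self, state_dict):
        # Checkpoints written before this change have no "epoch" key; keep None then.
        if "epoch" in state_dict:
            self.epoch = state_dict["epoch"]


saved = EpochOnlyState(epoch=3).state_dict()   # what dcp.save would serialize
restored = EpochOnlyState()
restored.load_state_dict(saved)                # what dcp.load does to each Stateful value
start_epoch = restored.epoch + 1 if restored.epoch is not None else 0
assert start_epoch == 4                        # epoch 3 already completed, so resume at 4
```

Resuming at `latest_epoch + 1` rather than `latest_epoch` avoids repeating the epoch that was already checkpointed, while a checkpoint without the key falls back to a fresh start at epoch 0.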