From c8d0c21df415c818455c82fc5b2d81272aed8bba Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Mon, 20 Oct 2025 18:25:31 -0700 Subject: [PATCH 1/6] persisting epoch on checkpointing logic Signed-off-by: JasonLi1909 --- .../pytorch/pytorch-fsdp/README.ipynb | 35 ++++++++++++------- .../examples/pytorch/pytorch-fsdp/README.md | 25 +++++++++---- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb index eaacfc01c137..7ccceb8bc30c 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb @@ -192,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -218,9 +218,12 @@ " optimizer = Adam(model.parameters(), lr=config.get('learning_rate', 0.001))\n", "\n", " # Load from checkpoint if available (for resuming training)\n", + " start_epoch = 0\n", " loaded_checkpoint = ray.train.get_checkpoint()\n", " if loaded_checkpoint:\n", - " load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n", + " latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n", + " start_epoch = latest_epoch + 1\n", + " logger.info(f\"Resuming training from epoch {start_epoch}\")\n", "\n", " # Prepare training data\n", " transform = Compose([\n", @@ -258,7 +261,7 @@ " num_batches = 0\n", " epochs = config.get('epochs', 5)\n", " \n", - " for epoch in range(epochs):\n", + " for epoch in range(start_epoch, epochs):\n", " # Set epoch for distributed sampler to ensure proper shuffling\n", " if ray.train.get_context().get_world_size() > 1:\n", " train_loader.sampler.set_epoch(epoch)\n", @@ -518,7 +521,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -531,16 +534,18 @@ " and optimizer.\n", " \"\"\"\n", "\n", - " def __init__(self, model, optimizer=None):\n", + " def __init__(self, model, optimizer=None, epoch=0):\n", " self.model = model\n", " self.optimizer = optimizer\n", + " self.epoch = epoch\n", "\n", " def state_dict(self):\n", " # this line automatically manages FSDP2 FQN's (Fully Qualified Name), as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT\n", " model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)\n", " return {\n", " \"model\": model_state_dict,\n", - " \"optim\": optimizer_state_dict\n", + " \"optim\": optimizer_state_dict,\n", + " \"epoch\": self.epoch\n", " }\n", "\n", " def load_state_dict(self, state_dict):\n", @@ -550,7 +555,10 @@ " self.optimizer,\n", " model_state_dict=state_dict[\"model\"],\n", " optim_state_dict=state_dict[\"optim\"],\n", - " )" + " )\n", + " # Load epoch information if available\n", + " if \"epoch\" in state_dict:\n", + " self.epoch = state_dict[\"epoch\"]" ] }, { @@ -574,7 +582,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -595,6 +603,7 @@ " try:\n", " with ckpt.as_directory() as checkpoint_dir:\n", " # Create state wrapper for DCP loading\n", + " app_state = AppState(model, optimizer)\n", " state_dict = {\"app\": AppState(model, optimizer)}\n", " \n", " # Load the distributed checkpoint\n", @@ -603,7 +612,8 @@ " checkpoint_id=checkpoint_dir\n", " )\n", " \n", - " logger.info(\"Successfully loaded distributed checkpoint\")\n", + " logger.info(f\"Successfully loaded 
distributed checkpoint from epoch {app_state.epoch}\")\n", + " return app_state.epoch\n", " except Exception as e:\n", " logger.error(f\"Failed to load checkpoint: {e}\")\n", " raise RuntimeError(f\"Checkpoint loading failed: {e}\") from e" @@ -620,7 +630,7 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -636,13 +646,14 @@ " Args:\n", " model: The FSDP-wrapped model to checkpoint\n", " optimizer: The optimizer to checkpoint\n", - " metrics: Dictionary of metrics to report (e.g., loss, accuracy)\n", + " metrics: Dictionary of metrics to report (e.g., loss, accuracy, epoch)\n", " \"\"\"\n", " logger.info(\"Saving checkpoint and reporting metrics...\")\n", " \n", " with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n", " # Perform a distributed checkpoint with DCP\n", - " state_dict = {\"app\": AppState(model, optimizer)}\n", + " epoch = metrics.get(\"epoch\", 0)\n", + " state_dict = {\"app\": AppState(model, optimizer, epoch)}\n", " dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir)\n", "\n", " # Report each checkpoint shard from all workers\n", diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md index 9679d01333e2..3f1aa541891c 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md @@ -166,9 +166,12 @@ def train_func(config): optimizer = Adam(model.parameters(), lr=config.get('learning_rate', 0.001)) # Load from checkpoint if available (for resuming training) + start_epoch = 0 loaded_checkpoint = ray.train.get_checkpoint() if loaded_checkpoint: - load_fsdp_checkpoint(model, optimizer, loaded_checkpoint) + latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint) + start_epoch = latest_epoch + 1 + logger.info(f"Resuming training from epoch {start_epoch}") # Prepare training data transform = Compose([ @@ -206,7 +209,7 @@ def train_func(config): num_batches = 0 epochs = config.get('epochs', 5) - for epoch in range(epochs): + for epoch in range(start_epoch, epochs): # Set epoch for distributed sampler to ensure proper shuffling if ray.train.get_context().get_world_size() > 1: train_loader.sampler.set_epoch(epoch) @@ -444,16 +447,18 @@ class AppState(Stateful): and optimizer. 
""" - def __init__(self, model, optimizer=None): + def __init__(self, model, optimizer=None, epoch=0): self.model = model self.optimizer = optimizer + self.epoch = epoch def state_dict(self): # this line automatically manages FSDP2 FQN's (Fully Qualified Name), as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) return { "model": model_state_dict, - "optim": optimizer_state_dict + "optim": optimizer_state_dict, + "epoch": self.epoch } def load_state_dict(self, state_dict): @@ -464,6 +469,9 @@ class AppState(Stateful): model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"], ) + # Load epoch information if available + if "epoch" in state_dict: + self.epoch = state_dict["epoch"] ``` ### Load distributed model from checkpoint @@ -495,6 +503,7 @@ def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ck try: with ckpt.as_directory() as checkpoint_dir: # Create state wrapper for DCP loading + app_state = AppState(model, optimizer) state_dict = {"app": AppState(model, optimizer)} # Load the distributed checkpoint @@ -503,7 +512,8 @@ def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ck checkpoint_id=checkpoint_dir ) - logger.info("Successfully loaded distributed checkpoint") + logger.info(f"Successfully loaded distributed checkpoint from epoch {app_state.epoch}") + return app_state.epoch except Exception as e: logger.error(f"Failed to load checkpoint: {e}") raise RuntimeError(f"Checkpoint loading failed: {e}") from e @@ -527,13 +537,14 @@ def report_metrics_and_save_fsdp_checkpoint( Args: model: The FSDP-wrapped model to checkpoint optimizer: The optimizer to checkpoint - metrics: Dictionary of metrics to report (e.g., loss, accuracy) + metrics: Dictionary of metrics to report (e.g., loss, accuracy, epoch) """ logger.info("Saving checkpoint and reporting metrics...") with tempfile.TemporaryDirectory() as temp_checkpoint_dir: # Perform a distributed checkpoint with DCP - state_dict = {"app": AppState(model, optimizer)} + epoch = metrics.get("epoch", 0) + state_dict = {"app": AppState(model, optimizer, epoch)} dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir) # Report each checkpoint shard from all workers From 764e8597a26be77619e62e8c4e75c0adf06063ce Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Mon, 20 Oct 2025 18:40:05 -0700 Subject: [PATCH 2/6] load_fsdp_checkpoint fix Signed-off-by: JasonLi1909 --- doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb | 2 +- doc/source/train/examples/pytorch/pytorch-fsdp/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb index 7ccceb8bc30c..4f4a1d4d4f24 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb @@ -604,7 +604,7 @@ " with ckpt.as_directory() as checkpoint_dir:\n", " # Create state wrapper for DCP loading\n", " app_state = AppState(model, optimizer)\n", - " state_dict = {\"app\": AppState(model, optimizer)}\n", + " state_dict = {\"app\": app_state}\n", " \n", " # Load the distributed checkpoint\n", " dcp.load(\n", diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md index 3f1aa541891c..6b0f2a58b500 100644 --- 
a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md @@ -504,7 +504,7 @@ def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ck with ckpt.as_directory() as checkpoint_dir: # Create state wrapper for DCP loading app_state = AppState(model, optimizer) - state_dict = {"app": AppState(model, optimizer)} + state_dict = {"app": app_state} # Load the distributed checkpoint dcp.load( From 147a3d9b3dd2de52309f5dfffd90a819c4141c7a Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Tue, 21 Oct 2025 15:14:08 -0700 Subject: [PATCH 3/6] moved epoch out of metrics Signed-off-by: JasonLi1909 --- .../examples/pytorch/pytorch-fsdp/README.ipynb | 15 +++++++++------ .../train/examples/pytorch/pytorch-fsdp/README.md | 15 +++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb index 4f4a1d4d4f24..f2e672ddf859 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb @@ -285,8 +285,8 @@ "\n", " # Report metrics and save checkpoint after each epoch\n", " avg_loss = running_loss / num_batches\n", - " metrics = {\"loss\": avg_loss, \"epoch\": epoch}\n", - " report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics)\n", + " metrics = {\"loss\": avg_loss}\n", + " report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch)\n", "\n", " # Log metrics from rank 0 only to avoid duplicate outputs\n", " if world_rank == 0:\n", @@ -586,7 +586,7 @@ "metadata": {}, "outputs": [], "source": [ - "def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint):\n", + "def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int:\n", " \"\"\"Load an FSDP checkpoint into the model and optimizer.\n", " \n", " This function handles distributed checkpoint loading with automatic resharding\n", @@ -597,6 +597,9 @@ " model: The FSDP-wrapped model to load state into\n", " optimizer: The optimizer to load state into\n", " ckpt: Ray Train checkpoint containing the saved state\n", + "\n", + " Returns:\n", + " int: The epoch number saved within the checkpoint.\n", " \"\"\"\n", " logger.info(\"Loading distributed checkpoint for resuming training...\")\n", " \n", @@ -635,7 +638,7 @@ "outputs": [], "source": [ "def report_metrics_and_save_fsdp_checkpoint(\n", - " model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict\n", + " model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict, epoch: int = 0\n", ") -> None:\n", " \"\"\"Report training metrics and save an FSDP checkpoint.\n", " \n", @@ -646,13 +649,13 @@ " Args:\n", " model: The FSDP-wrapped model to checkpoint\n", " optimizer: The optimizer to checkpoint\n", - " metrics: Dictionary of metrics to report (e.g., loss, accuracy, epoch)\n", + " metrics: Dictionary of metrics to report (e.g., loss, accuracy)\n", + " epoch: The current epoch to be saved\n", " \"\"\"\n", " logger.info(\"Saving checkpoint and reporting metrics...\")\n", " \n", " with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n", " # Perform a distributed checkpoint with DCP\n", - " epoch = metrics.get(\"epoch\", 0)\n", " state_dict = {\"app\": AppState(model, optimizer, epoch)}\n", " dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir)\n", "\n", diff --git 
a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md index 6b0f2a58b500..aeb048f42539 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md @@ -233,8 +233,8 @@ def train_func(config): # Report metrics and save checkpoint after each epoch avg_loss = running_loss / num_batches - metrics = {"loss": avg_loss, "epoch": epoch} - report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics) + metrics = {"loss": avg_loss} + report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch) # Log metrics from rank 0 only to avoid duplicate outputs if world_rank == 0: @@ -486,7 +486,7 @@ import torch.distributed.checkpoint as dcp ```python -def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint): +def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int: """Load an FSDP checkpoint into the model and optimizer. This function handles distributed checkpoint loading with automatic resharding @@ -497,6 +497,9 @@ def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ck model: The FSDP-wrapped model to load state into optimizer: The optimizer to load state into ckpt: Ray Train checkpoint containing the saved state + + Returns: + int: The epoch number saved within the checkpoint. """ logger.info("Loading distributed checkpoint for resuming training...") @@ -526,7 +529,7 @@ The following function handles periodic checkpoint saving during training, combi ```python def report_metrics_and_save_fsdp_checkpoint( - model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict + model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict, epoch: int = 0 ) -> None: """Report training metrics and save an FSDP checkpoint. 
@@ -537,13 +540,13 @@ def report_metrics_and_save_fsdp_checkpoint( Args: model: The FSDP-wrapped model to checkpoint optimizer: The optimizer to checkpoint - metrics: Dictionary of metrics to report (e.g., loss, accuracy, epoch) + metrics: Dictionary of metrics to report (e.g., loss, accuracy) + epoch: The current epoch to be saved """ logger.info("Saving checkpoint and reporting metrics...") with tempfile.TemporaryDirectory() as temp_checkpoint_dir: # Perform a distributed checkpoint with DCP - epoch = metrics.get("epoch", 0) state_dict = {"app": AppState(model, optimizer, epoch)} dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir) From d061e11ed398115b8eef5a0ae01f2117ecfb50c0 Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Tue, 21 Oct 2025 15:22:35 -0700 Subject: [PATCH 4/6] start epoch set to 0 if no epoch in checkpoint Signed-off-by: JasonLi1909 --- doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb | 4 ++-- doc/source/train/examples/pytorch/pytorch-fsdp/README.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb index f2e672ddf859..52d9e4db9e1d 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb @@ -222,7 +222,7 @@ " loaded_checkpoint = ray.train.get_checkpoint()\n", " if loaded_checkpoint:\n", " latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n", - " start_epoch = latest_epoch + 1\n", + " start_epoch = latest_epoch + 1 if latest_epoch else 0\n", " logger.info(f\"Resuming training from epoch {start_epoch}\")\n", "\n", " # Prepare training data\n", @@ -534,7 +534,7 @@ " and optimizer.\n", " \"\"\"\n", "\n", - " def __init__(self, model, optimizer=None, epoch=0):\n", + " def __init__(self, model, optimizer=None, epoch=None):\n", " self.model = model\n", " self.optimizer = optimizer\n", " self.epoch = epoch\n", diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md index aeb048f42539..0595a84e6599 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md @@ -170,7 +170,7 @@ def train_func(config): loaded_checkpoint = ray.train.get_checkpoint() if loaded_checkpoint: latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint) - start_epoch = latest_epoch + 1 + start_epoch = latest_epoch + 1 if latest_epoch else 0 logger.info(f"Resuming training from epoch {start_epoch}") # Prepare training data @@ -447,7 +447,7 @@ class AppState(Stateful): and optimizer. 
""" - def __init__(self, model, optimizer=None, epoch=0): + def __init__(self, model, optimizer=None, epoch=None): self.model = model self.optimizer = optimizer self.epoch = epoch From 4cd8251a2ec9c490eca695219ab7404b2c92763d Mon Sep 17 00:00:00 2001 From: JasonLi1909 Date: Wed, 22 Oct 2025 10:31:45 -0700 Subject: [PATCH 5/6] fix Signed-off-by: JasonLi1909 --- doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb | 4 ++-- doc/source/train/examples/pytorch/pytorch-fsdp/README.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb index 52d9e4db9e1d..2f6474654901 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb @@ -222,7 +222,7 @@ " loaded_checkpoint = ray.train.get_checkpoint()\n", " if loaded_checkpoint:\n", " latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n", - " start_epoch = latest_epoch + 1 if latest_epoch else 0\n", + " start_epoch = latest_epoch + 1 if latest_epoch != None else 0\n", " logger.info(f\"Resuming training from epoch {start_epoch}\")\n", "\n", " # Prepare training data\n", @@ -586,7 +586,7 @@ "metadata": {}, "outputs": [], "source": [ - "def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int:\n", + "def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int | None:\n", " \"\"\"Load an FSDP checkpoint into the model and optimizer.\n", " \n", " This function handles distributed checkpoint loading with automatic resharding\n", diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md index 0595a84e6599..d13e3733c898 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md @@ -170,7 +170,7 @@ def train_func(config): loaded_checkpoint = ray.train.get_checkpoint() if loaded_checkpoint: latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint) - start_epoch = latest_epoch + 1 if latest_epoch else 0 + start_epoch = latest_epoch + 1 if latest_epoch != None else 0 logger.info(f"Resuming training from epoch {start_epoch}") # Prepare training data @@ -486,7 +486,7 @@ import torch.distributed.checkpoint as dcp ```python -def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int: +def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int | None: """Load an FSDP checkpoint into the model and optimizer. 
This function handles distributed checkpoint loading with automatic resharding From a3bd0a6064acd6326d6b71c0a6c42847a36678c9 Mon Sep 17 00:00:00 2001 From: Jason Li <57246540+JasonLi1909@users.noreply.github.com> Date: Wed, 22 Oct 2025 11:14:16 -0700 Subject: [PATCH 6/6] Update doc/source/train/examples/pytorch/pytorch-fsdp/README.md Co-authored-by: matthewdeng Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com> --- doc/source/train/examples/pytorch/pytorch-fsdp/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md index d13e3733c898..8292bf7cd49a 100644 --- a/doc/source/train/examples/pytorch/pytorch-fsdp/README.md +++ b/doc/source/train/examples/pytorch/pytorch-fsdp/README.md @@ -170,7 +170,7 @@ def train_func(config): loaded_checkpoint = ray.train.get_checkpoint() if loaded_checkpoint: latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint) - start_epoch = latest_epoch + 1 if latest_epoch != None else 0 + start_epoch = latest_epoch + 1 if latest_epoch is not None else 0 logger.info(f"Resuming training from epoch {start_epoch}") # Prepare training data
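
Note on how the resume path introduced in this series gets exercised: `ray.train.get_checkpoint()` only returns a checkpoint when the run is restored or a failed worker is retried, so the `start_epoch = latest_epoch + 1` logic in `train_func` takes effect on the next invocation of the training loop. The sketch below shows one plausible way to wire that up with Ray Train's fault-tolerance settings; the run name, worker count, and config values are illustrative assumptions, not part of the documented example.

```python
# Minimal sketch (assumed values; train_func is the training loop from the README example).
# With a FailureConfig, a retried worker re-enters train_func, ray.train.get_checkpoint()
# returns the latest DCP checkpoint, and training resumes from the persisted epoch + 1.
from ray.train import FailureConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func,
    train_loop_config={"epochs": 5, "learning_rate": 1e-3},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        name="fsdp-epoch-resume",                      # hypothetical run name
        failure_config=FailureConfig(max_failures=3),  # retry workers after failures
    ),
)
result = trainer.fit()
```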