46 changes: 30 additions & 16 deletions doc/source/train/examples/pytorch/pytorch-fsdp/README.ipynb
@@ -192,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 127,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -218,9 +218,12 @@
" optimizer = Adam(model.parameters(), lr=config.get('learning_rate', 0.001))\n",
"\n",
" # Load from checkpoint if available (for resuming training)\n",
" start_epoch = 0\n",
" loaded_checkpoint = ray.train.get_checkpoint()\n",
" if loaded_checkpoint:\n",
" load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n",
" latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n",
" start_epoch = latest_epoch + 1 if latest_epoch != None else 0\n",
" logger.info(f\"Resuming training from epoch {start_epoch}\")\n",
"\n",
" # Prepare training data\n",
" transform = Compose([\n",
@@ -258,7 +261,7 @@
" num_batches = 0\n",
" epochs = config.get('epochs', 5)\n",
" \n",
" for epoch in range(epochs):\n",
" for epoch in range(start_epoch, epochs):\n",
" # Set epoch for distributed sampler to ensure proper shuffling\n",
" if ray.train.get_context().get_world_size() > 1:\n",
" train_loader.sampler.set_epoch(epoch)\n",
@@ -282,8 +285,8 @@
"\n",
" # Report metrics and save checkpoint after each epoch\n",
" avg_loss = running_loss / num_batches\n",
" metrics = {\"loss\": avg_loss, \"epoch\": epoch}\n",
" report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics)\n",
" metrics = {\"loss\": avg_loss}\n",
" report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch)\n",
"\n",
" # Log metrics from rank 0 only to avoid duplicate outputs\n",
" if world_rank == 0:\n",
@@ -518,7 +521,7 @@
},
{
"cell_type": "code",
"execution_count": 132,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -531,16 +534,18 @@
" and optimizer.\n",
" \"\"\"\n",
"\n",
" def __init__(self, model, optimizer=None):\n",
" def __init__(self, model, optimizer=None, epoch=None):\n",
" self.model = model\n",
" self.optimizer = optimizer\n",
" self.epoch = epoch\n",
"\n",
" def state_dict(self):\n",
" # this line automatically manages FSDP2 FQN's (Fully Qualified Name), as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT\n",
" model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)\n",
" return {\n",
" \"model\": model_state_dict,\n",
" \"optim\": optimizer_state_dict\n",
" \"optim\": optimizer_state_dict,\n",
" \"epoch\": self.epoch\n",
" }\n",
"\n",
" def load_state_dict(self, state_dict):\n",
@@ -550,7 +555,10 @@
" self.optimizer,\n",
" model_state_dict=state_dict[\"model\"],\n",
" optim_state_dict=state_dict[\"optim\"],\n",
" )"
" )\n",
" # Load epoch information if available\n",
" if \"epoch\" in state_dict:\n",
" self.epoch = state_dict[\"epoch\"]"
]
},
{
@@ -574,11 +582,11 @@
},
{
"cell_type": "code",
"execution_count": 134,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint):\n",
"def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int | None:\n",
" \"\"\"Load an FSDP checkpoint into the model and optimizer.\n",
" \n",
" This function handles distributed checkpoint loading with automatic resharding\n",
@@ -589,21 +597,26 @@
" model: The FSDP-wrapped model to load state into\n",
" optimizer: The optimizer to load state into\n",
" ckpt: Ray Train checkpoint containing the saved state\n",
"\n",
" Returns:\n",
" int: The epoch number saved within the checkpoint.\n",
" \"\"\"\n",
" logger.info(\"Loading distributed checkpoint for resuming training...\")\n",
" \n",
" try:\n",
" with ckpt.as_directory() as checkpoint_dir:\n",
" # Create state wrapper for DCP loading\n",
" state_dict = {\"app\": AppState(model, optimizer)}\n",
" app_state = AppState(model, optimizer)\n",
" state_dict = {\"app\": app_state}\n",
" \n",
" # Load the distributed checkpoint\n",
" dcp.load(\n",
" state_dict=state_dict,\n",
" checkpoint_id=checkpoint_dir\n",
" )\n",
" \n",
" logger.info(\"Successfully loaded distributed checkpoint\")\n",
" logger.info(f\"Successfully loaded distributed checkpoint from epoch {app_state.epoch}\")\n",
" return app_state.epoch\n",
" except Exception as e:\n",
" logger.error(f\"Failed to load checkpoint: {e}\")\n",
" raise RuntimeError(f\"Checkpoint loading failed: {e}\") from e"
@@ -620,12 +633,12 @@
},
{
"cell_type": "code",
"execution_count": 135,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def report_metrics_and_save_fsdp_checkpoint(\n",
" model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict\n",
" model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict, epoch: int = 0\n",
") -> None:\n",
" \"\"\"Report training metrics and save an FSDP checkpoint.\n",
" \n",
@@ -637,12 +650,13 @@
" model: The FSDP-wrapped model to checkpoint\n",
" optimizer: The optimizer to checkpoint\n",
" metrics: Dictionary of metrics to report (e.g., loss, accuracy)\n",
" epoch: The current epoch to be saved\n",
" \"\"\"\n",
" logger.info(\"Saving checkpoint and reporting metrics...\")\n",
" \n",
" with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n",
" # Perform a distributed checkpoint with DCP\n",
" state_dict = {\"app\": AppState(model, optimizer)}\n",
" state_dict = {\"app\": AppState(model, optimizer, epoch)}\n",
" dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir)\n",
"\n",
" # Report each checkpoint shard from all workers\n",
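
Not part of this PR, but for context: a minimal, hedged sketch of how the updated `train_func` might be launched so that the epoch-aware resume path is actually exercised. The worker count, retry budget, and checkpoint retention below are illustrative assumptions rather than values taken from the example.

```python
from ray.train import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func,  # the per-worker training loop defined in the diff above
    train_loop_config={"epochs": 5, "learning_rate": 0.001},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        # On a worker failure, Ray Train restarts train_func; the new code path
        # then picks up the latest checkpoint via ray.train.get_checkpoint()
        # and continues from start_epoch instead of epoch 0.
        failure_config=FailureConfig(max_failures=3),
        checkpoint_config=CheckpointConfig(num_to_keep=2),
    ),
)
result = trainer.fit()
```
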
36 changes: 25 additions & 11 deletions doc/source/train/examples/pytorch/pytorch-fsdp/README.md
@@ -166,9 +166,12 @@ def train_func(config):
optimizer = Adam(model.parameters(), lr=config.get('learning_rate', 0.001))

# Load from checkpoint if available (for resuming training)
start_epoch = 0
loaded_checkpoint = ray.train.get_checkpoint()
if loaded_checkpoint:
load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)
latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)
start_epoch = latest_epoch + 1 if latest_epoch is not None else 0
logger.info(f"Resuming training from epoch {start_epoch}")

# Prepare training data
transform = Compose([
@@ -206,7 +209,7 @@ def train_func(config):
num_batches = 0
epochs = config.get('epochs', 5)

for epoch in range(epochs):
for epoch in range(start_epoch, epochs):
# Set epoch for distributed sampler to ensure proper shuffling
if ray.train.get_context().get_world_size() > 1:
train_loader.sampler.set_epoch(epoch)
@@ -230,8 +233,8 @@ def train_func(config):

# Report metrics and save checkpoint after each epoch
avg_loss = running_loss / num_batches
metrics = {"loss": avg_loss, "epoch": epoch}
report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics)
metrics = {"loss": avg_loss}
report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch)

# Log metrics from rank 0 only to avoid duplicate outputs
if world_rank == 0:
@@ -444,16 +447,18 @@ class AppState(Stateful):
and optimizer.
"""

def __init__(self, model, optimizer=None):
def __init__(self, model, optimizer=None, epoch=None):
self.model = model
self.optimizer = optimizer
self.epoch = epoch

def state_dict(self):
# This call automatically manages FSDP2 FQNs (Fully Qualified Names) and sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
"optim": optimizer_state_dict,
"epoch": self.epoch
}

def load_state_dict(self, state_dict):
@@ -464,6 +469,9 @@ class AppState(Stateful):
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"],
)
# Load epoch information if available
if "epoch" in state_dict:
self.epoch = state_dict["epoch"]
```

### Load distributed model from checkpoint
Expand All @@ -478,7 +486,7 @@ import torch.distributed.checkpoint as dcp


```python
def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint):
def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int | None:
"""Load an FSDP checkpoint into the model and optimizer.

This function handles distributed checkpoint loading with automatic resharding
@@ -489,21 +497,26 @@ def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ck
model: The FSDP-wrapped model to load state into
optimizer: The optimizer to load state into
ckpt: Ray Train checkpoint containing the saved state

Returns:
int | None: The epoch number saved within the checkpoint, or None if the checkpoint doesn't record one.
"""
logger.info("Loading distributed checkpoint for resuming training...")

try:
with ckpt.as_directory() as checkpoint_dir:
# Create state wrapper for DCP loading
state_dict = {"app": AppState(model, optimizer)}
app_state = AppState(model, optimizer)
state_dict = {"app": app_state}

# Load the distributed checkpoint
dcp.load(
state_dict=state_dict,
checkpoint_id=checkpoint_dir
)

logger.info("Successfully loaded distributed checkpoint")
logger.info(f"Successfully loaded distributed checkpoint from epoch {app_state.epoch}")
return app_state.epoch
except Exception as e:
logger.error(f"Failed to load checkpoint: {e}")
raise RuntimeError(f"Checkpoint loading failed: {e}") from e
@@ -516,7 +529,7 @@ The following function handles periodic checkpoint saving during training, combi

```python
def report_metrics_and_save_fsdp_checkpoint(
model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict
model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict, epoch: int = 0
) -> None:
"""Report training metrics and save an FSDP checkpoint.

@@ -528,12 +541,13 @@ def report_metrics_and_save_fsdp_checkpoint(
model: The FSDP-wrapped model to checkpoint
optimizer: The optimizer to checkpoint
metrics: Dictionary of metrics to report (e.g., loss, accuracy)
epoch: The current epoch to be saved
"""
logger.info("Saving checkpoint and reporting metrics...")

with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
# Perform a distributed checkpoint with DCP
state_dict = {"app": AppState(model, optimizer)}
state_dict = {"app": AppState(model, optimizer, epoch)}
dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir)

# Report each checkpoint shard from all workers
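
A small self-contained sketch of the resume arithmetic this change introduces, for readers skimming the diff. `resolve_start_epoch` is a hypothetical helper name used only for illustration; in the PR the same expression is inlined in `train_func`.

```python
def resolve_start_epoch(latest_epoch):
    """Map the epoch recorded in a checkpoint to the epoch training resumes at."""
    # A checkpoint saved at the end of epoch N means epochs 0..N are complete,
    # so training restarts at N + 1. Checkpoints written before this change
    # carry no epoch and load_fsdp_checkpoint returns None; fall back to 0.
    return latest_epoch + 1 if latest_epoch is not None else 0

assert resolve_start_epoch(None) == 0  # fresh run or pre-change checkpoint
assert resolve_start_epoch(0) == 1     # failed after completing epoch 0
assert resolve_start_epoch(4) == 5     # resumes at the first unfinished epoch
```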