diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml
index abe1b7ed98..b938a9d321 100644
--- a/examples/configs/sft.yaml
+++ b/examples/configs/sft.yaml
@@ -1,13 +1,23 @@
 # SFT Algorithm Configuration
 sft:
-  num_steps: 100
-  #val_period: 10
-  #val_at_start: true
-  #checkpoint_dir: "results/sft"
+  num_steps: 20
+  val_period: 10
+  val_batches: 8
+  val_global_batch_size: 32
+  val_micro_batch_size: 2
+  val_at_start: true
+
+checkpointing:
+  enabled: true
+  checkpoint_dir: "results/sft"
+  metric_name: "val_loss"
+  higher_is_better: false
+  keep_top_k: 3
+  save_period: 10
 
 policy:
   model_name: "meta-llama/Llama-3.2-1B-Instruct"
-  train_global_batch_size: 8
+  train_global_batch_size: 32
   train_micro_batch_size: 2
   learning_rate: 5.0e-6
   max_total_sequence_length: 1024
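Note: the new `checkpointing` block selects which checkpoints to keep by a validation metric. Below is a minimal sketch of the keep-top-k retention policy that `metric_name`, `higher_is_better`, and `keep_top_k` describe; the real logic lives in nemo_reinforcer.utils.checkpoint's CheckpointManager and may differ in detail.

# Illustrative only: the retention policy implied by the config keys above.
from typing import List, Tuple

def retained_checkpoints(
    ckpts: List[Tuple[float, str]],  # (metric value, checkpoint dir)
    keep_top_k: int = 3,
    higher_is_better: bool = False,
) -> List[Tuple[float, str]]:
    # For metric_name "val_loss", lower is better, so sort ascending.
    ranked = sorted(ckpts, key=lambda c: c[0], reverse=higher_is_better)
    return ranked[:keep_top_k]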
diff --git a/examples/run_sft.py b/examples/run_sft.py
index f8649a9484..de0ef8c1ce 100644
--- a/examples/run_sft.py
+++ b/examples/run_sft.py
@@ -15,12 +15,20 @@
 import argparse
 import os
 import pprint
+from typing import Dict, Any
 
 from omegaconf import OmegaConf
 
 from nemo_reinforcer.algorithms.sft import MasterConfig, sft_train, setup
 from nemo_reinforcer.distributed.virtual_cluster import init_ray
+from nemo_reinforcer.utils.config import load_config
 from nemo_reinforcer.utils.logger import get_next_experiment_dir
+from nemo_reinforcer.data import DataConfig, hf_datasets
+from nemo_reinforcer.data.datasets import AllTaskProcessedDataset
+from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec
+from nemo_reinforcer.data.llm_message_utils import get_formatted_message_log
+from transformers import AutoTokenizer
+from nemo_reinforcer.models.policy import PolicyConfig
 
 
 def parse_args():
@@ -39,6 +47,80 @@ def parse_args():
     return args, overrides
 
 
+# =======================================================
+# Data Processing
+# =======================================================
+def sft_preprocessor(
+    datum_dict: Dict[str, Any],
+    task_data_spec: TaskDataSpec,
+    tokenizer,
+    max_seq_length: int,
+    idx: int,
+) -> DatumSpec:
+    """Process a datum dictionary for SFT training."""
+    message_log = get_formatted_message_log(
+        datum_dict["messages"], tokenizer, task_data_spec
+    )
+
+    length = sum(len(m["token_ids"]) for m in message_log)
+
+    loss_multiplier = 1.0
+    if length > max_seq_length:
+        # truncate messages and mask the sample out of the loss
+        for message in message_log:
+            message["token_ids"] = message["token_ids"][
+                : min(4, max_seq_length // len(message_log))
+            ]
+        loss_multiplier = 0.0
+
+    output = {
+        "message_log": message_log,
+        "length": length,
+        "extra_env_info": None,
+        "loss_multiplier": loss_multiplier,
+        "idx": idx,
+    }
+    return output
+
+
+def setup_data(data_config: DataConfig, policy_config: PolicyConfig):
+    print("\nā–¶ Setting up data...")
+    data_cls = data_config["dataset_name"]
+    if data_cls == "open_assistant":
+        data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
+    elif data_cls == "squad":
+        data = hf_datasets.SquadDataset()
+    else:
+        raise ValueError(f"Unknown dataset class: {data_cls}")
+    print(
+        f"  āœ“ Training and validation datasets loaded with {len(data.formatted_ds['train'])} and {len(data.formatted_ds['validation'])} samples, respectively."
+    )
+
+    train_dataset = data.formatted_ds["train"]
+    val_dataset = data.formatted_ds["validation"]
+    sft_task_spec = data.task_spec
+
+    tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
+
+    train_dataset = AllTaskProcessedDataset(
+        train_dataset,
+        tokenizer,
+        sft_task_spec,
+        sft_preprocessor,
+        max_seq_length=data_config["max_input_seq_length"],
+    )
+
+    val_dataset = AllTaskProcessedDataset(
+        val_dataset,
+        tokenizer,
+        sft_task_spec,
+        sft_preprocessor,
+        max_seq_length=data_config["max_input_seq_length"],
+    )
+
+    return train_dataset, val_dataset, tokenizer, sft_task_spec
+
+
 def main():
     """Main entry point."""
     # Parse arguments
@@ -47,13 +129,12 @@ def main():
     if not args.config:
         args.config = os.path.join(os.path.dirname(__file__), "configs", "sft.yaml")
 
-    config = OmegaConf.load(args.config)
+    config = load_config(args.config)
     print(f"Loaded configuration from: {args.config}")
 
     if overrides:
-        override_conf = OmegaConf.from_cli()
-        print(f"Overrides: {override_conf}")
-        config = OmegaConf.merge(config, override_conf)
+        print(f"Overrides: {overrides}")
+        config = OmegaConf.merge(config, OmegaConf.from_dotlist(overrides))
 
     config: MasterConfig = OmegaConf.to_container(config, resolve=True)
     print("Applied CLI overrides")
@@ -66,24 +147,33 @@ def main():
     print(f"šŸ“Š Using log directory: {config['logger']['log_dir']}")
 
     init_ray()
+
+    # setup data
+    dataset, val_dataset, tokenizer, sft_task_spec = setup_data(
+        config["data"], config["policy"]
+    )
 
     (
         policy,
         cluster,
-        dataloader,
-        tokenizer,
+        train_dataloader,
+        val_dataloader,
         loss_fn,
-        master_config,
         logger,
-        sft_task_spec,
-    ) = setup(config)
+        checkpointer,
+        sft_save_state,
+        master_config,
+    ) = setup(config, dataset, val_dataset)
     sft_train(
         policy,
-        dataloader,
+        train_dataloader,
+        val_dataloader,
         tokenizer,
         loss_fn,
         master_config,
         logger,
         sft_task_spec,
+        checkpointer,
+        sft_save_state,
     )
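Note: `sft_preprocessor` never raises on over-length samples; it truncates each message to a few tokens and zeroes `loss_multiplier` so the sample drops out of the loss. A small worked example of that branch (the token data here is made up):

# Worked example of sft_preprocessor's over-length branch (hypothetical data).
message_log = [
    {"role": "user", "token_ids": list(range(700))},
    {"role": "assistant", "token_ids": list(range(600))},
]
max_seq_length = 1024
length = sum(len(m["token_ids"]) for m in message_log)  # 1300 > 1024

if length > max_seq_length:
    for message in message_log:
        # each message keeps at most min(4, 1024 // 2) = 4 tokens
        message["token_ids"] = message["token_ids"][
            : min(4, max_seq_length // len(message_log))
        ]
    loss_multiplier = 0.0  # sample contributes nothing to the loss

assert [len(m["token_ids"]) for m in message_log] == [4, 4]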
diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py
index 1f948b1b5c..402b5bad92 100644
--- a/nemo_reinforcer/algorithms/sft.py
+++ b/nemo_reinforcer/algorithms/sft.py
@@ -11,32 +11,52 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Tuple, TypedDict
-
-from torch.utils.data import DataLoader
-from transformers import AutoTokenizer
+import os
+from pathlib import Path
+from typing import Optional, Tuple, TypedDict
+import torch
+from torchdata.stateful_dataloader import StatefulDataLoader
 
 from nemo_reinforcer.algorithms.loss_functions import (
     NLLLoss,
 )
-from nemo_reinforcer.data import DataConfig, hf_datasets
+from nemo_reinforcer.data import DataConfig
 from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn
-from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec
+from nemo_reinforcer.data.interfaces import TaskDataSpec
 from nemo_reinforcer.data.llm_message_utils import (
     add_loss_mask_to_message_log,
     batched_message_log_to_flat_message,
-    get_formatted_message_log,
 )
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
 from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
+from nemo_reinforcer.models.interfaces import PolicyInterface
 from nemo_reinforcer.models.policy.hf_policy import HfPolicy
 from nemo_reinforcer.models.policy import PolicyConfig
+from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig
 from nemo_reinforcer.utils.logger import Logger, LoggerConfig
 from nemo_reinforcer.utils.timer import Timer
 
 
+class SFTSaveState(TypedDict, total=False):
+    step: int
+    val_loss: float
+    consumed_samples: int
+
+
+def _default_sft_save_state() -> SFTSaveState:
+    return {
+        "step": 0,
+        "consumed_samples": 0,
+    }
+
+
 class SFTConfig(TypedDict):
     num_steps: int
+    val_period: int
+    val_batches: int
+    val_global_batch_size: int
+    val_micro_batch_size: int
+    val_at_start: bool
 
 
 class MasterConfig(TypedDict):
@@ -45,51 +65,26 @@ class MasterConfig(TypedDict):
     sft: SFTConfig
     logger: LoggerConfig
     cluster: ClusterConfig
+    checkpointing: CheckpointingConfig
 
 
-def sft_preprocessor(
-    datum_dict: Dict[str, Any],
-    task_data_spec: TaskDataSpec,
-    tokenizer,
-    max_seq_length: int,
-    idx: int,
-) -> DatumSpec:
-    """Process a datum dictionary for SFT training."""
-    message_log = get_formatted_message_log(
-        datum_dict["messages"], tokenizer, task_data_spec
-    )
-
-    length = sum(len(m["token_ids"]) for m in message_log)
-
-    loss_multiplier = 1.0
-    if length > max_seq_length:
-        # make smaller and mask out
-        for message in message_log:
-            message["token_ids"] = message["token_ids"][
-                : min(4, max_seq_length // len(message_log))
-            ]
-        loss_multiplier = 0.0
-
-    output = {
-        "message_log": message_log,
-        "length": length,
-        "extra_env_info": None,
-        "loss_multiplier": loss_multiplier,
-        "idx": idx,
-    }
-    return output
-
-
+# =======================================================
+# Setup & Initialization
+# =======================================================
 def setup(
     master_config: MasterConfig,
+    train_dataset: AllTaskProcessedDataset,
+    val_dataset: AllTaskProcessedDataset,
 ) -> Tuple[
     HfPolicy,
     RayVirtualCluster,
-    DataLoader,
-    AutoTokenizer,
+    StatefulDataLoader,
+    StatefulDataLoader,
     NLLLoss,
-    MasterConfig,
     Logger,
+    CheckpointManager,
+    SFTSaveState,
+    MasterConfig,
 ]:
     """Main entry point for running SFT algorithm.
@@ -101,36 +96,57 @@ def setup(
     data_config = master_config["data"]
     logger_config = master_config["logger"]
     cluster_config = master_config["cluster"]
-
-    ## TODO: unify this with grpo
-    data_cls = data_config["dataset_name"]
-    if data_cls == "open_assistant":
-        data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
-    elif data_cls == "squad":
-        data = hf_datasets.SquadDataset()
-    else:
-        raise ValueError(f"Unknown dataset class: {data_cls}")
-
-    base_dataset = data.formatted_ds["train"]
-    sft_task_spec = data.task_spec
-
-    tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
-
-    dataset = AllTaskProcessedDataset(
-        base_dataset,
-        tokenizer,
-        sft_task_spec,
-        sft_preprocessor,
-        max_seq_length=data_config["max_input_seq_length"],
+    sft_config = master_config["sft"]
+
+    # ==========================
+    # Checkpointing
+    # ==========================
+    checkpointer = CheckpointManager(master_config["checkpointing"])
+    last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
+    sft_save_state: Optional[SFTSaveState] = checkpointer.load_training_info(
+        last_checkpoint_path
     )
+    # config validation checks
+    if master_config["checkpointing"]["enabled"]:
+        assert master_config["checkpointing"]["save_period"] > 0
+        assert (
+            master_config["checkpointing"]["save_period"]
+            % master_config["sft"]["val_period"]
+            == 0
+        ), (
+            f"Checkpointing save period {master_config['checkpointing']['save_period']} "
+            f"must be a multiple of validation period {master_config['sft']['val_period']}"
+            f", or we won't know what metric to save!"
+        )
 
-    dataloader = DataLoader(
-        dataset,
+    # ==========================
+    # Data
+    # ==========================
+    train_dataloader = StatefulDataLoader(
+        train_dataset,
         batch_size=policy_config["train_global_batch_size"],
+        shuffle=True,
+        collate_fn=rl_collate_fn,
+    )
+
+    if last_checkpoint_path is not None:
+        dataloader_state_dict = torch.load(
+            os.path.join(last_checkpoint_path, "train_dataloader.pt")
+        )
+        train_dataloader.load_state_dict(dataloader_state_dict)
+
+    val_dataloader = StatefulDataLoader(
+        val_dataset,
+        batch_size=sft_config["val_global_batch_size"],
         shuffle=False,
-        collate_fn=rl_collate_fn,  ## TODO: change this for sft! or make it more general
+        collate_fn=rl_collate_fn,
+        drop_last=True,
     )
 
+    # ==========================
+    # Cluster
+    # ==========================
+    print("\nā–¶ Setting up compute cluster...")
     cluster = RayVirtualCluster(
         name="sft_cluster",
         bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
@@ -139,72 +155,287 @@ def setup(
         num_gpus_per_node=cluster_config["gpus_per_node"],
         max_colocated_worker_groups=1,
     )
-
-    policy = HfPolicy(cluster=cluster, config=policy_config)
+    print(f"  āœ“ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
+
+    # ==========================
+    # Training
+    # ==========================
+    print("\nā–¶ Setting up model...")
+    policy = HfPolicy(
+        cluster=cluster,
+        config=policy_config,
+        weights_path=Path(last_checkpoint_path) / "policy.pt"
+        if last_checkpoint_path
+        else None,
+        optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt"
+        if last_checkpoint_path
+        else None,
+        init_optimizer=True,
+    )
 
     loss_fn = NLLLoss()
+    print("  āœ“ Model initialized")
 
     logger = Logger(logger_config)
 
+    print("\n" + "=" * 60)
+    print(" " * 18 + "SETUP COMPLETE")
+    print("=" * 60 + "\n")
+
     return (
         policy,
         cluster,
-        dataloader,
-        tokenizer,
+        train_dataloader,
+        val_dataloader,
         loss_fn,
-        master_config,
         logger,
-        sft_task_spec,
+        checkpointer,
+        sft_save_state,
+        master_config,
     )
 
 
-def sft_train(
-    policy, dataloader, tokenizer, loss_fn, master_config, logger, sft_task_spec
+# =======================================================
+# Training & Validation
+# =======================================================
+def validate(
+    policy: PolicyInterface,
+    val_dataloader: StatefulDataLoader,
+    tokenizer,
+    loss_fn,
+    step: int,
+    master_config: MasterConfig,
+    sft_task_spec: TaskDataSpec,
+    val_batches: int,
+    val_batch_size: int,
+    val_mbs: int,
 ):
-    # Run basic sft training
+    """Run validation on the validation dataset."""
+    if val_dataloader is None:
+        print("  āš ļø No validation dataloader provided, skipping validation")
+        return
+
     timer = Timer()
-    policy.prepare_for_training()
+    with timer.time("total_validation_time"):
+        print(f"ā–¶ Starting validation at step {step}...")
 
-    for step, batch in enumerate(dataloader):
-        timer.start("sft_train_step")
+        val_metrics = {"val_loss": 0.0}
 
-        timer.start("data_processing")
-        ## add loss mask based on role to every message
-        add_loss_mask_to_message_log(
-            batch["message_log"],
-            roles_to_train_on=["assistant"],
-        )
+        for batch_idx, val_batch in enumerate(val_dataloader):
+            ## add loss mask based on role to every message
+            add_loss_mask_to_message_log(
+                val_batch["message_log"],
+                roles_to_train_on=["assistant"],
+            )
+
+            cat_and_padded, input_lengths = batched_message_log_to_flat_message(
+                val_batch["message_log"],
+                pad_value_dict={"token_ids": tokenizer.eos_token_id},
+            )
+
+            val_data: BatchedDataDict = BatchedDataDict(
+                {
+                    "input_ids": cat_and_padded["token_ids"],
+                    "input_lengths": input_lengths,
+                    "token_mask": cat_and_padded["token_loss_mask"],
+                    "sample_mask": val_batch["loss_multiplier"],
+                }
+            )
+
+            ## just run model fwd
+            val_results = policy.train(
+                val_data,
+                loss_fn,
+                eval_mode=True,
+                gbs=val_batch_size,
+                mbs=val_mbs,
+            )
+            val_metrics["val_loss"] += float(val_results["loss"])
+
+            # stop after val_batches batches (val_batches <= 0 means the full set)
+            if val_batches > 0 and batch_idx + 1 >= val_batches:
+                break
+
+        # average over the number of batches actually evaluated
+        val_metrics["val_loss"] /= batch_idx + 1
+
+    # Restore training mode after validation
+    policy.prepare_for_training()
+
+    # Get timing metrics
+    timing_metrics = timer.get_timing_metrics(reduction_op="sum")
+    validation_time = timing_metrics.get("total_validation_time", 0)
+
+    # Print summary of validation results
+    print("\nšŸ“Š Validation Results:")
+    print(f"  • Validation loss: {val_metrics['val_loss']:.4f}")
+
+    # Print timing information
+    print("\n  ā±ļø Validation Timing:")
+    print(f"  • Total validation time: {validation_time:.2f}s")
+
+    # Make sure to reset the timer after validation
+    timer.reset()
+
+    return val_metrics, timing_metrics
 
-        cat_and_padded, input_lengths = batched_message_log_to_flat_message(
-            batch["message_log"],
-            pad_value_dict={"token_ids": tokenizer.eos_token_id},
-        )
-        train_data: BatchedDataDict = BatchedDataDict(
-            {
-                "input_ids": cat_and_padded["token_ids"],
-                "input_lengths": input_lengths,
-                "token_mask": cat_and_padded["token_loss_mask"],
-                "sample_mask": batch["loss_multiplier"],
-            }
+
+def sft_train(
+    policy,
+    train_dataloader,
+    val_dataloader,
+    tokenizer,
+    loss_fn,
+    master_config,
+    logger,
+    sft_task_spec,
+    checkpointer,
+    sft_save_state,
+):
+    # Run basic sft training
+    timer = Timer()
+
+    if sft_save_state is None:
+        sft_save_state = _default_sft_save_state()
+        step = 0
+    else:
+        step = (
+            sft_save_state["step"] + 1
+        )  # N+1 because the checkpoint is _after_ SFT iteration N
+
+    sft_config = master_config["sft"]
+    # Validation configuration
+    val_period = sft_config["val_period"]
+    val_at_start = sft_config["val_at_start"]
+
+    # Run validation at the start if configured
+    if val_at_start and step == 0:
+        print("\nšŸ” Running initial validation...")
+        val_metrics, validation_timings = validate(
+            policy,
+            val_dataloader,
+            tokenizer,
+            loss_fn,
+            step=0,
+            master_config=master_config,
+            sft_task_spec=sft_task_spec,
+            val_batches=sft_config["val_batches"],
+            val_batch_size=sft_config["val_global_batch_size"],
+            val_mbs=sft_config["val_micro_batch_size"],
         )
-        timer.stop("data_processing")
 
-        ## train_data.to("cpu")
-        train_results = policy.train(train_data, loss_fn)
-        timer.stop("sft_train_step")
+        logger.log_metrics(val_metrics, step, prefix="validation")
+        logger.log_metrics(validation_timings, step, prefix="timing/validation")
+
+    policy.prepare_for_training()
+
+    for batch in train_dataloader:
+        print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}")
+
+        with timer.time("total_step_time"):
+            # Prepare batch and generate responses
+            print("ā–¶ Preparing batch...")
+            with timer.time("data_processing"):
+                ## add loss mask based on role to every message
+                add_loss_mask_to_message_log(
+                    batch["message_log"],
+                    roles_to_train_on=["assistant"],
+                )
+
+                cat_and_padded, input_lengths = batched_message_log_to_flat_message(
+                    batch["message_log"],
+                    pad_value_dict={"token_ids": tokenizer.eos_token_id},
+                )
+
+                train_data: BatchedDataDict = BatchedDataDict(
+                    {
+                        "input_ids": cat_and_padded["token_ids"],
+                        "input_lengths": input_lengths,
+                        "token_mask": cat_and_padded["token_loss_mask"],
+                        "sample_mask": batch["loss_multiplier"],
+                    }
+                )
+
+            ## train_data.to("cpu")
+            print("ā–¶ Taking a training step...")
+            train_results = policy.train(train_data, loss_fn)
+
+            # Run validation if it's a validation step
+            if val_period > 0 and (step + 1) % val_period == 0:
+                val_metrics, validation_timings = validate(
+                    policy,
+                    val_dataloader,
+                    tokenizer,
+                    loss_fn,
+                    step=step + 1,
+                    master_config=master_config,
+                    sft_task_spec=sft_task_spec,
+                    val_batches=sft_config["val_batches"],
+                    val_batch_size=sft_config["val_global_batch_size"],
+                    val_mbs=sft_config["val_micro_batch_size"],
+                )
+                logger.log_metrics(
+                    validation_timings, step + 1, prefix="timing/validation"
+                )
+                logger.log_metrics(val_metrics, step + 1, prefix="validation")
+
+            ## Checkpointing
+            sft_save_state["consumed_samples"] += master_config["policy"][
+                "train_global_batch_size"
+            ]
+            if (
+                master_config["checkpointing"]["enabled"]
+                and (step + 1) % master_config["checkpointing"]["save_period"] == 0
+            ):  # +1 because step is 0-indexed
+                sft_save_state["step"] = step
+                sft_save_state["val_loss"] = val_metrics["val_loss"]
+                with timer.time("checkpointing"):
+                    print(f"Saving checkpoint for step {step + 1}...")
+                    checkpoint_path = checkpointer.init_tmp_checkpoint(
+                        step + 1, sft_save_state, master_config
+                    )
+                    policy.save_checkpoint(
+                        os.path.join(checkpoint_path, "policy.pt"),
+                        os.path.join(checkpoint_path, "policy_optimizer.pt"),
+                        ## NOTE: below is a workaround to avoid a bug with checkpointing
+                        ## this should be removed once the bug is fixed
+                        offload_to_cpu=False,
+                    )
+                    torch.save(
+                        train_dataloader.state_dict(),
+                        os.path.join(checkpoint_path, "train_dataloader.pt"),
+                    )
+                    checkpointer.finalize_checkpoint(checkpoint_path)
+
         losses = train_results["loss"]
         timing_metrics = timer.get_timing_metrics(reduction_op="sum")
-        print(f"Step {step} completed. Loss: {losses[-1].item()}")
+        metrics = {
+            "loss": losses.numpy(),
+        }
+
+        print("\nšŸ“Š Training Results:")
+        print(f"  • Loss: {float(metrics['loss']):.4f}")
+        print("\nā±ļø Timing:")
+        # Display total time first, separately
+        total_time = timing_metrics.get("total_step_time", 0)
+        print(f"  • Total step time: {total_time:.2f}s")
+
+        # Display all other timing metrics (if any)
+        for k, v in sorted(
+            timing_metrics.items(), key=lambda item: item[1], reverse=True
+        ):
+            if k != "total_step_time":
+                percent = (v / total_time * 100) if total_time > 0 else 0
+                print(f"  • {k}: {v:.2f}s ({percent:.1f}%)")
+
+        logger.log_metrics(metrics, step + 1, prefix="train")
+        logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")
 
-        logger.log_metrics(
-            {"loss": losses[-1].item()},
-            step,
-            prefix="train",
-        )
-        logger.log_metrics(timing_metrics, step, prefix="timing/train")
         timer.reset()
+        step += 1
 
-        if step >= master_config["sft"]["num_steps"] - 1:
+        if step >= master_config["sft"]["num_steps"]:
             break
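Note: the resume path above has a small contract worth spelling out: checkpoints are written after SFT iteration N, the dataloader state is stored next to the weights, and training restarts at N + 1. A condensed sketch follows; the CheckpointManager calls are the ones used above, but the helper itself is hypothetical.

# Hypothetical helper condensing the resume logic from setup()/sft_train() above.
import os
import torch

def resume_step(checkpointer, train_dataloader) -> int:
    last = checkpointer.get_latest_checkpoint_path()
    if last is None:
        return 0  # fresh run
    save_state = checkpointer.load_training_info(last)
    # Restore the exact position in the training stream.
    train_dataloader.load_state_dict(
        torch.load(os.path.join(last, "train_dataloader.pt"))
    )
    # Checkpoints are written *after* SFT iteration N, so resume at N + 1.
    return save_state["step"] + 1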
diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py
index 16b048845a..6f44574c71 100644
--- a/nemo_reinforcer/models/policy/hf_policy.py
+++ b/nemo_reinforcer/models/policy/hf_policy.py
@@ -225,10 +225,19 @@ def get_gpu_info(self):
             },
         }
 
-    def train(self, data: BatchedDataDict, loss_fn: LossFunction) -> Dict[str, Any]:
+    def train(
+        self,
+        data: BatchedDataDict,
+        loss_fn: LossFunction,
+        eval_mode: bool = False,
+        gbs: Optional[int] = None,
+        mbs: Optional[int] = None,
+    ) -> Dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
-        mbs = self.cfg["train_micro_batch_size"]
-        gbs = self.cfg["train_global_batch_size"]
+        if gbs is None:
+            gbs = self.cfg["train_global_batch_size"]
+        if mbs is None:
+            mbs = self.cfg["train_micro_batch_size"]
         local_gbs = gbs // torch.distributed.get_world_size()
         dataset_size = data.get("input_ids").shape[0]
 
@@ -271,16 +280,18 @@ def train(self, data: BatchedDataDict, loss_fn: LossFunction) -> Dict[str, Any]:
                 loss, loss_metrics = loss_fn(logits, mb)
 
                 # Backward pass
-                loss.backward()
+                if not eval_mode:
+                    loss.backward()
 
                 mb_losses.append(loss.item())
                 all_mb_metrics.append(loss_metrics)
 
             # Clip gradients
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            if not eval_mode:
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
 
-            # Update parameters
-            self.optimizer.step()
-            self.scheduler.step()
+                # Update parameters
+                self.optimizer.step()
+                self.scheduler.step()
 
             losses.append(torch.tensor(mb_losses).mean().item())
 
         # Compute global loss across all ranks
@@ -731,9 +742,16 @@ def move_to_cpu(self, model):
 
         return model
 
-    def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None):
+    def save_checkpoint(
+        self,
+        weights_path: str,
+        optimizer_path: Optional[str] = None,
+        offload_to_cpu: bool = True,
+    ):
         # Config to save full state dict on rank 0, offloaded to CPU
-        state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        state_dict_config = FullStateDictConfig(
+            offload_to_cpu=offload_to_cpu, rank0_only=True
+        )
 
         with FullyShardedDataParallel.state_dict_type(
             self.model,
@@ -869,7 +887,14 @@ def get_reference_policy_logprobs(
         )
         return logprobs
 
-    def train(self, data: BatchedDataDict, loss_fn: LossFunction):
+    def train(
+        self,
+        data: BatchedDataDict,
+        loss_fn: LossFunction,
+        eval_mode: bool = False,
+        gbs: Optional[int] = None,
+        mbs: Optional[int] = None,
+    ):
         """Train the policy on a batch of data with a given loss function."""
         # Shard and replicate the batch
         shards = self.dp_size
@@ -879,7 +904,14 @@ def train(self, data: BatchedDataDict, loss_fn: LossFunction):
 
         # Train each shard in parallel
         futures = self.worker_group.run_all_workers_multiple_data(
-            "train", sharded_data, common_kwargs={"loss_fn": loss_fn}
+            "train",
+            sharded_data,
+            common_kwargs={
+                "loss_fn": loss_fn,
+                "eval_mode": eval_mode,
+                "gbs": gbs,
+                "mbs": mbs,
+            },
         )
         results = self.worker_group.get_all_worker_results(futures)
 
@@ -992,12 +1024,18 @@ def offload_after_refit(self):
         )
         ray.get(futures)
 
-    def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None):
+    def save_checkpoint(
+        self,
+        weights_path: str,
+        optimizer_path: Optional[str] = None,
+        offload_to_cpu: bool = True,
+    ):
         """Save a checkpoint of the model."""
         futures = self.worker_group.run_all_workers_single_data(
             "save_checkpoint",
             weights_path,
            optimizer_path,
+            offload_to_cpu=offload_to_cpu,
             respect_tied_workers=True,
         )
         ray.get(futures)
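Taken together, the `eval_mode`/`gbs`/`mbs` plumbing lets `validate()` reuse the training path as a pure forward pass: backward, gradient clipping, and optimizer/scheduler steps are all skipped. A usage sketch, with names taken from the code above (`val_data` is the BatchedDataDict built in `validate()`, so this fragment is not self-contained):

# Validation-style call into the extended train() API (mirrors validate() above).
val_results = policy.train(
    val_data,
    loss_fn,
    eval_mode=True,  # skips backward(), grad clipping, and optimizer/scheduler steps
    gbs=sft_config["val_global_batch_size"],
    mbs=sft_config["val_micro_batch_size"],
)
val_loss = float(val_results["loss"])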