diff --git a/configs/mcli/mitchish-instruct.yml b/configs/mcli/mitchish-instruct.yml
new file mode 100644
index 000000000..66c8d3bd7
--- /dev/null
+++ b/configs/mcli/mitchish-instruct.yml
@@ -0,0 +1,102 @@
+run_name: olmo-7b-instruct
+image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
+gpu_num: 64
+#gpu_num: 8
+#cluster: r12z3
+cluster: r7z2
+gpu_type: a100_40gb
+integrations:
+  - integration_type: git_repo
+    git_repo: allenai/LLM
+    git_branch: epwalsh/tulu-fine-tune
+    pip_install: -e .
+    ssh_clone: true
+command: |-
+  # NOTE: For some reason getting S3 and R2 authentication working both from the command line and
+  # from Python proved to be challenging, maybe because Mosaic's servers are in Australia.
+  # In the end I had to use separate methods to get everything working:
+  #  1. AWS config files for CLI access.
+  #  2. Environment variables for boto3 access (to S3 only).
+  # Since we only need CLI access prior to training, we remove the AWS config files before launching
+  # the training job. Otherwise the environment variables won't work.
+
+  # Install aws cli
+  apt-get update
+  apt-get install zip unzip
+  curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
+  unzip awscliv2.zip
+  sudo ./aws/install
+
+  cd LLM
+
+  pip freeze
+
+  # Prepare environment including AWS config files for both S3 and R2 access.
+  mkdir -p /root/.cache/torch
+  mkdir /root/checkpoint-unsharded
+  mkdir /root/data
+  mkdir /root/.aws
+  touch /root/.aws/credentials /root/.aws/config
+  echo '[s3]' >> /root/.aws/credentials
+  echo "aws_access_key_id = ${AWS_ACCESS_KEY_ID}" >> /root/.aws/credentials
+  echo "aws_secret_access_key = ${AWS_SECRET_ACCESS_KEY}" >> /root/.aws/credentials
+  echo '' >> /root/.aws/credentials
+  echo '[r2]' >> /root/.aws/credentials
+  echo "aws_access_key_id = ${R2_ACCESS_KEY_ID}" >> /root/.aws/credentials
+  echo "aws_secret_access_key = ${R2_SECRET_ACCESS_KEY}" >> /root/.aws/credentials
+  echo "[default]" >> /root/.aws/config
+  echo "region = auto" >> /root/.aws/config
+  echo "output = json" >> /root/.aws/config
+
+  #export S3_PROFILE=s3
+  #export R2_PROFILE=r2
+  export OMP_NUM_THREADS=8
+  export LOG_FILTER_TYPE=local_rank0_only
+
+  # Download checkpoint (everything except optimizer state).
+  checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
+  echo "Downloading checkpoint '${checkpoint}'..."
+
+  # Download config.
+  aws s3 cp --profile=r2 --region=auto \
+    --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
+    "${checkpoint}/config.yaml" /root/checkpoint-unsharded/
+
+  # Download trainer state.
+  aws s3 cp --profile=r2 --region=auto \
+    --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
+    "${checkpoint}/train.pt" /root/checkpoint-unsharded/
+
+  # Download model weights.
+  aws s3 cp --profile=r2 --region=auto \
+    --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
+    "${checkpoint}/model.pt" /root/checkpoint-unsharded/
+
+  # Now remove the AWS configs so they don't mess with data loading / uploading checkpoints to/from S3.
+  rm -rf /root/.aws
+
+  # Download data (it's small enough so might as well).
+  echo "Downloading data..."
+  aws s3 cp \
+    s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy \
+    /root/data/data.npy
+
+  torchrun \
+    --master_addr "$MASTER_ADDR" \
+    --master_port "$MASTER_PORT" \
+    --nnodes "$NUM_NODES" \
+    --node_rank "$NODE_RANK" \
+    --nproc_per_node 8 \
+    scripts/train.py configs/mitchish-instruct.yaml \
+      --run_name=mitchish-mcli-2.5T-instruct-2e-6 \
+      --optimizer.learning_rate=2e-6 \
+      --save_overwrite \
+      --time_limit=169200 \
+      --data.paths=[/root/data/data.npy] \
+      --save_interval_unsharded=10000 \
+      --load_path=/root/checkpoint-unsharded \
+      --reset_optimizer_state \
+      --reset_trainer_state \
+      --compile=null \
+      --activation_checkpointing=fine_grained \
+      --fsdp.wrapping_strategy=size_based
diff --git a/configs/mitchish-instruct.yaml b/configs/mitchish-instruct.yaml
new file mode 100644
index 000000000..a21f1ade6
--- /dev/null
+++ b/configs/mitchish-instruct.yaml
@@ -0,0 +1,182 @@
+run_name: v1_5-mix-medium-mitch-ish
+seed: 6198
+dry_run: false
+
+wandb:
+  name: ${run_name}
+  project: olmo-medium
+  group: v1_5-mix
+
+model:
+  d_model: 4096
+  n_heads: 32
+  n_layers: 32
+  # mlp_ratio: 6
+  mlp_hidden_size: 22016
+  weight_tying: false
+  alibi: false
+  rope: true
+  flash_attention: true
+  attention_dropout: 0.0
+  attention_layer_norm: false
+  multi_query_attention: false
+  include_bias: false
+  block_type: sequential
+  layer_norm_type: default
+  layer_norm_with_affine: false
+  bias_for_layer_norm: false
+  attention_layer_norm_with_affine: false
+  activation_type: swiglu
+  residual_dropout: 0.0
+  embedding_dropout: 0.0
+  max_sequence_length: 2048
+  vocab_size: 50280
+  embedding_size: 50304
+  eos_token_id: 0
+  pad_token_id: 1
+  init_device: meta
+  init_fn: mitchell
+
+compile:
+  fullgraph: false
+
+optimizer:
+  name: adamw
+  learning_rate: 2e-5
+  weight_decay: 0.0
+  betas:
+  - 0.9
+  - 0.999
+  metrics_log_interval: 10
+
+scheduler:
+  name: linear_with_warmup
+  t_warmup: 100
+  alpha_f: 0.001
+
+tokenizer:
+  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
+  truncate_direction: right
+
+save_folder: runs/${run_name}
+remote_save_folder: s3://ai2-llm/checkpoints/7b/${run_name}
+save_overwrite: false
+# Sharded checkpoints (best for restarts)
+save_interval: 1000
+save_num_checkpoints_to_keep: -1
+# Unsharded checkpoints (for final storage)
+save_interval_unsharded: null  # getting errors on LUMI right now
+save_num_unsharded_checkpoints_to_keep: -1
+
+load_path: null
+
+max_duration: 2ep
+global_train_batch_size: 128
+device_train_microbatch_size: 2
+time_limit: null
+
+precision: amp_bf16
+
+fsdp:
+  wrapping_strategy: by_block
+  precision: mixed
+
+max_grad_norm: 1.0
+max_grad_norm_ratio: null
+
+speed_monitor:
+  window_size: 20
+
+eval_interval: ${save_interval}
+eval_subset_num_batches: -1
+device_eval_batch_size: ${device_train_microbatch_size}
+evaluators:
+  - label: all-small-ppl-validation
+    data:
+      num_workers: 0
+      drop_last: true
+      # pin_memory: true
+      # prefetch_factor: 1
+      # persistent_workers: false
+      # timeout: 0
+      datasets:
+        4chan-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy
+        c4_100_domains-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
+        c4_en-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy
+        gab-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy
+        ice-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy
+        m2d2_s2orc-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
+        m2d2_wiki-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
+        manosphere-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy
+        mc4_en-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy
+        pile-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy
+        ptb-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy
+        twitterAEE-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
+        wikitext_103-validation:
+        - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy
+
+  ##########################
+  # Downstream evaluations #
+  ##########################
+  - label: piqa
+    type: downstream
+
+  - label: hellaswag
+    type: downstream
+
+  - label: winogrande
+    type: downstream
+
+  - label: openbook_qa
+    type: downstream
+
+  # - label: boolq  # requires implementation of the pmi_dc matrix
+  #   type: downstream
+
+  - label: sciq
+    type: downstream
+
+  - label: arc_easy
+    type: downstream
+
+  # - label: arc_challenge  # requires implementation of the pmi_dc matrix
+  #   type: downstream
+
+  - label: copa
+    type: downstream
+
+  - label: rte
+    type: downstream
+
+  - label: commitment_bank
+    type: downstream
+
+  - label: mrpc
+    type: downstream
+
+  - label: sst2
+    type: downstream
+
+data:
+  pad_direction: right
+  num_workers: 0
+  drop_last: true
+  pin_memory: true
+  prefetch_factor: 1
+  persistent_workers: true
+  timeout: 0
+  generate_attention_mask: true
+  paths:
+  - s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy
diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index ba875a801..1e2aba816 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -502,8 +502,10 @@ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
         barrier()

         # Finally if all went well replace the temporary directory with the actual
-        # checkpoint directory.
-        if get_fs_local_rank() == 0:
+        # checkpoint directory. Note that for some checkpointers the local rank 0 might
+        # not use this folder, so it may not exist; FullCheckpointer, for example, only creates
+        # this for global rank 0.
+        if get_fs_local_rank() == 0 and checkpoint_dir_tmp.exists():
             # Replace temp directory with target checkpoint directory.
             try:
                 checkpoint_dir_tmp.replace(checkpoint_dir)
diff --git a/olmo/config.py b/olmo/config.py
index 9a353f540..b4b0576f9 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -510,6 +510,7 @@ class DataConfig(BaseConfig):
     paths: Optional[List[str]] = None
     datasets: Optional[Dict[str, List[str]]] = None
     pad_direction: PaddingDirection = PaddingDirection.right
+    generate_attention_mask: bool = False
     num_workers: int = 0
     drop_last: bool = False
     pin_memory: bool = False
@@ -683,7 +684,7 @@ class TrainConfig(BaseConfig):
     Used to seed all initial RNG states.
     """

-    epoch: int = 0
+    epoch: Optional[int] = None
     """
     Increment this when starting a new epoch.
     """
@@ -832,6 +833,11 @@ class TrainConfig(BaseConfig):
     curve (according to the current learning rate schedule settings), and continues from there.
     """

+    reset_trainer_state: bool = False
+    """
+    When this is set we don't restore the trainer state from a checkpoint.
+    """
+
     sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
     """
     The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
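
For illustration (outside the diff), here is a minimal sketch of what the new `generate_attention_mask` / `pad_token_id` options produce per instance. It mirrors the `MemMapDataset.__getitem__` change further down and assumes `pad_token_id: 1` as in the instruct config above; the token IDs are made up:

    import torch

    # Made-up token IDs; pad_token_id=1 matches the model config above.
    pad_token_id = 1
    input_ids = torch.tensor([5, 17, 42, pad_token_id, pad_token_id])

    # Ones everywhere, zeroed wherever the token equals the pad token
    # (same logic as the new MemMapDataset.__getitem__ branch).
    attention_mask = torch.ones_like(input_ids)
    attention_mask.masked_fill_(input_ids == pad_token_id, 0)
    print(attention_mask)  # tensor([1, 1, 1, 0, 0])
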
diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py
index 65b6599bd..bc08ff863 100644
--- a/olmo/data/__init__.py
+++ b/olmo/data/__init__.py
@@ -37,6 +37,8 @@ def build_memmap_dataset(
         chunk_size=train_config.model.max_sequence_length,
         metadata=metadata,
         include_instance_metadata=include_instance_metadata,
+        pad_token_id=train_config.model.pad_token_id,
+        generate_attention_mask=data_config.generate_attention_mask,
     )


@@ -93,7 +95,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
         IterableDataset(
             dataset,  # type: ignore
             train_config.global_train_batch_size,
-            seed=train_config.seed + train_config.epoch,
+            seed=train_config.seed + (train_config.epoch or 0),
             shuffle=True,
             drop_last=train_config.data.drop_last,
             work_dir=work_dir,
diff --git a/olmo/data/iterable_dataset.py b/olmo/data/iterable_dataset.py
index 7a1771649..bae152446 100644
--- a/olmo/data/iterable_dataset.py
+++ b/olmo/data/iterable_dataset.py
@@ -64,21 +64,26 @@ def __init__(
         assert global_batch_size % self.world_size == 0
         self.device_batch_size = global_batch_size // self.world_size
         self.global_indices_file: Optional[Path] = None
+        self.work_dir = work_dir
         if work_dir is not None:
-            self.global_indices_file = Path(work_dir) / "global_indices.npy"
-            if self.fs_local_rank == 0:
-                log.info("Saving global data order indices...")
-                self.global_indices_file.parent.mkdir(parents=True, exist_ok=True)
-                global_indices = self._build_global_indices()
-                global_indices_mmap = np.memmap(
-                    self.global_indices_file, dtype=np.uint32, mode="w+", shape=(len(global_indices),)
-                )
-                global_indices_mmap[:] = global_indices
-                global_indices_mmap.flush()
-                del global_indices_mmap
-                log.info("Global data order indices saved to '%s'", self.global_indices_file)
-            barrier()
+            self._build_and_save_global_indices()
+
+    def _build_and_save_global_indices(self):
+        assert self.work_dir is not None
+        self.global_indices_file = Path(self.work_dir) / "global_indices.npy"
+        if self.fs_local_rank == 0:
+            log.info("Saving global data order indices...")
+            self.global_indices_file.parent.mkdir(parents=True, exist_ok=True)
+            global_indices = self._build_global_indices()
+            global_indices_mmap = np.memmap(
+                self.global_indices_file, dtype=np.uint32, mode="w+", shape=(len(global_indices),)
+            )
+            global_indices_mmap[:] = global_indices
+            global_indices_mmap.flush()
+            del global_indices_mmap
+            log.info("Global data order indices saved to '%s'", self.global_indices_file)
+        barrier()

     def _build_global_indices(self) -> np.ndarray:
         assert len(self.dataset) < np.iinfo(np.uint32).max
@@ -111,6 +116,11 @@ def get_global_indices(self) -> np.ndarray:
         else:
             return self._build_global_indices()

+    def reshuffle(self):
+        self.seed += 1
+        if self.work_dir is not None:
+            self._build_and_save_global_indices()
+
     def __iter__(self) -> Iterator[Dict[str, Any]]:
         indices = self.get_global_indices()

diff --git a/olmo/data/memmap_dataset.py b/olmo/data/memmap_dataset.py
index 9d7d5fd99..69f2d85b9 100644
--- a/olmo/data/memmap_dataset.py
+++ b/olmo/data/memmap_dataset.py
@@ -44,6 +44,8 @@ def __init__(
         memmap_dtype=np.uint16,
         metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
         include_instance_metadata: bool = True,
+        generate_attention_mask: bool = False,
+        pad_token_id: Optional[int] = None,
     ):
         if not paths:
             raise ValueError("At least one path is required")
@@ -60,6 +62,8 @@ def __init__(
         self.dtype = memmap_dtype
         self._item_size = self.dtype(0).itemsize
         self._include_instance_metadata = include_instance_metadata
+        self._generate_attention_mask = generate_attention_mask
+        self._pad_token_id = pad_token_id

     @property
     def chunk_size(self) -> int:
@@ -137,10 +141,18 @@ def __getitem__(self, index: int) -> Dict[str, Any]:

         # Read the data from file.
         input_ids = self._read_chunk_from_memmap(self._memmap_paths[memmap_index], memmap_local_index)
+
         out: Dict[str, Any] = {"input_ids": input_ids}
         if self._include_instance_metadata:
             metadata = self._metadata[memmap_index]
             out["metadata"] = deepcopy(metadata)
+
+        if self._generate_attention_mask:
+            assert self._pad_token_id is not None
+            attn_mask = torch.ones_like(input_ids)
+            attn_mask.masked_fill_(input_ids == self._pad_token_id, 0)
+            out["attention_mask"] = attn_mask
+
         return out

     def __add__(self, other: MemMapDataset) -> MemMapDataset:
diff --git a/olmo/train.py b/olmo/train.py
index c31b7b459..2207b0552 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -102,7 +102,7 @@ class Trainer:
     train_loader: DataLoader
     device: torch.device
     evaluators: List[Evaluator]
-    epoch: int = 0
+    epoch: Optional[int] = None
     global_step: int = 0
     global_train_examples_seen_this_epoch: int = 0
     """Tracks the global number of training examples seen in the current epoch for the purpose of restoring
@@ -117,6 +117,18 @@ class Trainer:
     indices_file: Optional[TextIO] = None
     _start_time: float = 0.0

+    @property
+    def dataset(self) -> IterableDataset:
+        assert isinstance(self.train_loader.dataset, IterableDataset)
+        return self.train_loader.dataset
+
+    @property
+    def max_epochs(self) -> int:
+        if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
+            return int(self.cfg.max_duration[:-2].strip())
+        else:
+            return 1
+
     @property
     def max_steps(self) -> int:
         if isinstance(self.cfg.max_duration, int):
@@ -129,6 +141,11 @@ def max_steps(self) -> int:
             tokens_per_batch = self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
             steps_remaining = tokens_remaining // tokens_per_batch
             return self.global_step + steps_remaining
+        elif self.cfg.max_duration.endswith("ep"):
+            max_epochs = int(self.cfg.max_duration[:-2].strip())
+            examples_per_epoch = self.dataset.total_size
+            steps_per_epoch = examples_per_epoch // self.cfg.global_train_batch_size
+            return max_epochs * steps_per_epoch
         else:
             # convert to float *first* to handle scientific notation
             return int(float(self.cfg.max_duration))
@@ -192,6 +209,8 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
             self.epoch = 0
             self.global_train_tokens_seen = 0
             self.global_train_examples_seen_this_epoch = 0
+        elif self.epoch is None:
+            self.epoch = checkpoint_epoch
         elif checkpoint_epoch != self.epoch:
             log.info(f"Starting new epoch (epoch = {self.epoch})")
             self.global_train_examples_seen_this_epoch = 0
@@ -207,9 +226,9 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # that variable is meant to track the actual number of tokens trained on.

         if self.global_train_examples_seen_this_epoch > 0:
-            assert isinstance(self.train_loader.dataset, IterableDataset)
+            assert isinstance(self.dataset, IterableDataset)
             log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
-            self.train_loader.dataset.start_index = self.global_train_examples_seen_this_epoch
+            self.dataset.start_index = self.global_train_examples_seen_this_epoch

         # Reset learning rate and weight decay to the values from the config, not the checkpoint.
log.info("Resetting learning rate...") @@ -344,6 +363,7 @@ def restore_sharded_checkpoint( local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True, + load_trainer_state: bool = True, sharded_checkpointer: Optional[ShardedCheckpointerType] = None, ): # Zero-gradients to avoid gathering them. @@ -356,7 +376,8 @@ def restore_sharded_checkpoint( local_cache=local_cache, load_optimizer_state=load_optimizer_state, ) - self.load_trainer_state_dict(trainer_state) + if load_trainer_state: + self.load_trainer_state_dict(trainer_state) barrier() def save_unsharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]: @@ -374,7 +395,12 @@ def remove_unsharded_checkpoint(self, idx: int = 0): barrier() def restore_unsharded_checkpoint( - self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True + self, + load_path: PathOrStr, + local_cache: Optional[PathOrStr] = None, + *, + load_optimizer_state: bool = True, + load_trainer_state: bool = True, ): # Zero-gradients to avoid gathering them. self.optim.zero_grad(set_to_none=True) @@ -386,7 +412,8 @@ def restore_unsharded_checkpoint( local_cache=local_cache, load_optimizer_state=load_optimizer_state, ) - self.load_trainer_state_dict(trainer_state) + if load_trainer_state: + self.load_trainer_state_dict(trainer_state) barrier() def save_checkpoint( @@ -408,19 +435,24 @@ def restore_checkpoint( checkpoint_type: Optional[CheckpointType] = None, local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, + load_trainer_state: bool = True, sharded_checkpointer: Optional[ShardedCheckpointerType] = None, ): if checkpoint_type == CheckpointType.unsharded or ( checkpoint_type is None and str(load_path).rstrip("/").endswith("-unsharded") ): self.restore_unsharded_checkpoint( - load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state + load_path, + local_cache=local_cache, + load_optimizer_state=load_optimizer_state, + load_trainer_state=load_trainer_state, ) elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None: self.restore_sharded_checkpoint( load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state, + load_trainer_state=load_trainer_state, sharded_checkpointer=sharded_checkpointer, ) elif checkpoint_type is not None: @@ -832,145 +864,153 @@ def on_trace_ready(p): save_checkpoints: bool = True with torch_profiler as p: - for batch in self.train_loader: - # Bookkeeping. - # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all - # batches see the same number of tokens, which should be the case for language model pre-training - # (at least when drop_last=True). - # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that - # overhead. So for now I'm putting these assertions here so if the assumption is violated it will - # fail loudly. 
-                batch_size, seq_len = batch["input_ids"].shape
-                assert seq_len == self.cfg.model.max_sequence_length
-                assert batch_size == self.cfg.device_train_batch_size
-                global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
-                self.global_step += 1
-                self.global_train_examples_seen_this_epoch += global_batch_size
-                self.global_train_tokens_seen += global_batch_size * seq_len
-                speed_monitor.batch_start(
-                    self.global_train_tokens_seen,
-                    batch_size * seq_len,  # num tokens in batch for this device
-                    # We start monitoring speed after the first batch since the first
-                    # batch might be an outlier due to compiling and other initialization overhead.
-                    record=not first_batch,
-                )
+            for epoch in range(self.epoch or 0, self.max_epochs):
+                for batch in self.train_loader:
+                    # Bookkeeping.
+                    # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
+                    # batches see the same number of tokens, which should be the case for language model pre-training
+                    # (at least when drop_last=True).
+                    # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
+                    # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
+                    # fail loudly.
+                    batch_size, seq_len = batch["input_ids"].shape
+                    assert seq_len == self.cfg.model.max_sequence_length
+                    assert batch_size == self.cfg.device_train_batch_size
+                    global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
+                    self.global_step += 1
+                    self.global_train_examples_seen_this_epoch += global_batch_size
+                    self.global_train_tokens_seen += global_batch_size * seq_len
+                    speed_monitor.batch_start(
+                        self.global_train_tokens_seen,
+                        batch_size * seq_len,  # num tokens in batch for this device
+                        # We start monitoring speed after the first batch since the first
+                        # batch might be an outlier due to compiling and other initialization overhead.
+                        record=not first_batch,
+                    )

-                should_log_this_step = self.should_log_this_step()
-
-                # Run train step on batch.
-                metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)
-
-                # Maybe collect other metrics.
-                if should_log_this_step:
-                    # Speed metrics.
-                    metrics.update(speed_monitor.check())
-                    # System metrics.
-                    metrics.update(self.system_metrics())
-                    # Learning rate metrics.
-                    metrics.update(lr_monitor.check())
-
-                # Log metrics to console.
-                if self.global_step % self.cfg.console_log_interval == 0:
-                    self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)
-
-                # Log metrics to W&B.
-                if (
-                    wandb.run is not None
-                    and self.cfg.wandb is not None
-                    and self.global_step % self.cfg.wandb.log_interval == 0
-                ):
-                    wandb.log(metrics, step=self.global_step)
-
-                # Check if/when run should be canceled.
-                if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
-                    cancel_initiated, extra_steps = self.check_if_cancelled()
-                    if cancel_initiated:
-                        stop_at = (
-                            self.global_step + extra_steps
-                            if stop_at is None
-                            else min(self.global_step + extra_steps, stop_at)
-                        )
+                    should_log_this_step = self.should_log_this_step()

-                # Maybe save sharded checkpoint.
-                if save_checkpoints and (
-                    cancel_initiated
-                    or (
-                        self.global_step % self.cfg.save_interval == 0
-                        and self.cfg.save_num_checkpoints_to_keep != 0
-                    )
-                ):
-                    log.info("Saving checkpoint...")
-                    checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
-                    log.info(f"Checkpoint saved to {checkpoint_path}")
-
-                    # Remove any ephemeral checkpoints.
-                    while self.ephemeral_checkpoints:
-                        self.remove_ephemeral_checkpoint()
-
-                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
-                    speed_monitor.reset()
-
-                    # If the run was just canceled this will be the final checkpoint.
-                    if cancel_initiated:
-                        save_checkpoints = False
-                elif (
-                    self.cfg.save_interval_ephemeral is not None
-                    and self.global_step % self.cfg.save_interval_ephemeral == 0
-                ):
-                    log.info("Saving ephemeral checkpoint...")
-                    checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
-                    log.info(f"Checkpoint saved to {checkpoint_path}")
-
-                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
-                    speed_monitor.reset()
-
-                # Maybe save unsharded checkpoint.
-                if (
-                    save_checkpoints
-                    and self.cfg.save_interval_unsharded is not None
-                    and self.global_step % self.cfg.save_interval_unsharded == 0
-                    and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
-                ):
-                    log.info("Saving unsharded checkpoint...")
-                    checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
-                    log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
-
-                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
-                    speed_monitor.reset()
-
-                # Maybe run evaluations.
-                if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
-                    eval_metrics = self.eval()

+                    # Run train step on batch.
+                    metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)
+
+                    # Maybe collect other metrics.
+                    if should_log_this_step:
+                        # Speed metrics.
+                        metrics.update(speed_monitor.check())
+                        # System metrics.
+                        metrics.update(self.system_metrics())
+                        # Learning rate metrics.
+                        metrics.update(lr_monitor.check())
+
+                    # Log metrics to console.
+                    if self.global_step % self.cfg.console_log_interval == 0:
+                        self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

-                    # Log metrics to W&B.
-                    if wandb.run is not None:
-                        wandb.log(eval_metrics, step=self.global_step)
-
-                    # Reset speed monitor so that we don't count the time taken to run evaluations.
-                    speed_monitor.reset()
-
-                    # Reset model to 'train' mode.
-                    self.fsdp_model.train()
-
-                # End of batch.
-                first_batch = False
-                if p is not None:
-                    p.step()
-
-                if stop_at is not None and self.global_step >= stop_at:
-                    break
-
-                # Python Profiler stuff
-                # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
-                if python_profiler is not None:
-                    if self.global_step == 5:
-                        python_profiler.enable()
-                    elif self.global_step == 8:
-                        python_profiler.disable()
-                        python_profiler.print_stats(sort=SortKey.CUMULATIVE)
-                        python_profiler = None
-            else:
-                log.info("Training loop complete")

+                    # Log metrics to W&B.
+                    if (
+                        wandb.run is not None
+                        and self.cfg.wandb is not None
+                        and self.global_step % self.cfg.wandb.log_interval == 0
+                    ):
+                        wandb.log(metrics, step=self.global_step)
+
+                    # Check if/when run should be canceled.
+                    if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
+                        cancel_initiated, extra_steps = self.check_if_cancelled()
+                        if cancel_initiated:
+                            stop_at = (
+                                self.global_step + extra_steps
+                                if stop_at is None
+                                else min(self.global_step + extra_steps, stop_at)
+                            )
+
+                    # Maybe save sharded checkpoint.
+                    if save_checkpoints and (
+                        cancel_initiated
+                        or (
+                            self.global_step % self.cfg.save_interval == 0
+                            and self.cfg.save_num_checkpoints_to_keep != 0
+                        )
+                    ):
+                        log.info("Saving checkpoint...")
+                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
+                        log.info(f"Checkpoint saved to {checkpoint_path}")
+
+                        # Remove any ephemeral checkpoints.
+                        while self.ephemeral_checkpoints:
+                            self.remove_ephemeral_checkpoint()
+
+                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
+                        speed_monitor.reset()
+
+                        # If the run was just canceled this will be the final checkpoint.
+                        if cancel_initiated:
+                            save_checkpoints = False
+                    elif (
+                        self.cfg.save_interval_ephemeral is not None
+                        and self.global_step % self.cfg.save_interval_ephemeral == 0
+                    ):
+                        log.info("Saving ephemeral checkpoint...")
+                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
+                        log.info(f"Checkpoint saved to {checkpoint_path}")
+
+                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
+                        speed_monitor.reset()
+
+                    # Maybe save unsharded checkpoint.
+                    if (
+                        save_checkpoints
+                        and self.cfg.save_interval_unsharded is not None
+                        and self.global_step % self.cfg.save_interval_unsharded == 0
+                        and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
+                    ):
+                        log.info("Saving unsharded checkpoint...")
+                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
+                        log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
+
+                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
+                        speed_monitor.reset()
+
+                    # Maybe run evaluations.
+                    if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
+                        eval_metrics = self.eval()
+
+                        # Log metrics to W&B.
+                        if wandb.run is not None:
+                            wandb.log(eval_metrics, step=self.global_step)
+
+                        # Reset speed monitor so that we don't count the time taken to run evaluations.
+                        speed_monitor.reset()
+
+                        # Reset model to 'train' mode.
+                        self.fsdp_model.train()
+
+                    # End of batch.
+                    first_batch = False
+                    if p is not None:
+                        p.step()
+
+                    if stop_at is not None and self.global_step >= stop_at:
+                        break
+
+                    # Python Profiler stuff
+                    # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
+                    if python_profiler is not None:
+                        if self.global_step == 5:
+                            python_profiler.enable()
+                        elif self.global_step == 8:
+                            python_profiler.disable()
+                            python_profiler.print_stats(sort=SortKey.CUMULATIVE)
+                            python_profiler = None
+                else:
+                    log.info("Training epoch complete")
+                    self.epoch = epoch + 1
+                    self.global_train_examples_seen_this_epoch = 0
+                    if self.epoch < self.max_epochs:
+                        self.dataset.reshuffle()
+                    continue
+
+                break

         # Save final checkpoint.
         if save_checkpoints:
diff --git a/scripts/train.py b/scripts/train.py
index e7c16bb96..710cf0255 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -40,9 +40,9 @@ def main(cfg: TrainConfig) -> None:
     log_extra_field("run_name", cfg.run_name)

     # Sanity check
-    if cfg.reset_optimizer_state and cfg.load_path is None:
+    if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None:
         log.warning(
-            "You want to reset the optimizer state, but we're not loading from the checkpoint. The"
+            "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The "
             "setting has no effect."
         )

@@ -202,12 +202,13 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         trainer.restore_checkpoint(
             cfg.load_path,
             load_optimizer_state=not cfg.reset_optimizer_state,
+            load_trainer_state=not cfg.reset_trainer_state,
             sharded_checkpointer=cfg.load_path_sharded_checkpointer,
         )
         log.info("Checkpoint successfully loaded")

         # If we have to, set a new scheduler:
-        if cfg.reset_optimizer_state:
+        if cfg.reset_optimizer_state and not cfg.reset_trainer_state:
             trainer.scheduler = BoltOnWarmupScheduler.wrap(
                 trainer.scheduler,
                 trainer.global_step,
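
As a quick sanity check on the new epoch-based duration handling (`max_duration: 2ep`), here is a minimal sketch (outside the diff) of how `Trainer.max_steps` resolves an "ep"-style value; the dataset size below is a made-up placeholder for `self.dataset.total_size`:

    # Mirrors the new "ep" branch of Trainer.max_steps.
    max_duration = "2ep"               # from configs/mitchish-instruct.yaml
    global_train_batch_size = 128      # from configs/mitchish-instruct.yaml
    examples_per_epoch = 300_000       # made-up stand-in for self.dataset.total_size

    max_epochs = int(max_duration[:-2].strip())                      # 2
    steps_per_epoch = examples_per_epoch // global_train_batch_size  # 2343
    max_steps = max_epochs * steps_per_epoch                         # 4686
    print(max_steps)
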