diff --git a/extensions/thunder/pretrain.py b/extensions/thunder/pretrain.py
index 8fd7ab4ce3..11e2bfe9f1 100644
--- a/extensions/thunder/pretrain.py
+++ b/extensions/thunder/pretrain.py
@@ -150,6 +150,7 @@ def setup(
     main(
         fabric,
         devices,
+        num_nodes,
         seed,
         initial_checkpoint_dir,
         resume,
@@ -168,6 +169,7 @@ def setup(
 def main(
     fabric: L.Fabric,
     devices: int,
+    num_nodes: int,
     seed: int,
     initial_checkpoint_dir: Optional[Path],
     resume: Union[bool, Literal["auto"], Path],
@@ -229,7 +231,11 @@ def main(
         fabric.load(resume, state)
 
     train_time = time.perf_counter()
-    fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval, optimizer)
+    fit(
+        fabric=fabric, devices=devices, num_nodes=num_nodes, state=state,
+        train_dataloader=train_dataloader, val_dataloader=val_dataloader,
+        out_dir=out_dir, tokenizer_dir=tokenizer_dir, train=train, eval=eval, optimizer=optimizer
+    )
     fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
 
     # Save final checkpoint
@@ -242,6 +248,7 @@ def main(
 def fit(
     fabric: L.Fabric,
     devices: int,
+    num_nodes: int,
     state: dict,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
@@ -269,18 +276,18 @@ def fit(
     max_tokens_per_device = train.max_tokens // fabric.world_size
     tokens_per_iter = train.micro_batch_size * model.max_seq_length
     max_iters = max_tokens_per_device // tokens_per_iter
-    log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices)
+    log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes)
     initial_iter = state["iter_num"]
     train_iterator = CycleIterator(train_dataloader)
 
-    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
+    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
         fabric.device
     )
     fabric.barrier()
     total_t0 = time.perf_counter()
     val_loss = "n/a"
 
-    warmup_iters = train.warmup_iters(devices, max_iters, train_dataloader)
+    warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader)
 
     for train_data in train_iterator:
         if state["iter_num"] >= max_iters:
@@ -297,10 +304,10 @@ def fit(
         input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long()
         targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long()
 
-        is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0
+        is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
         with fabric.no_backward_sync(model, enabled=is_accumulating):
             loss = forward_and_loss(model, input_ids, targets)
-            fabric.backward(loss / train.gradient_accumulation_iters(devices))
+            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
         running_loss.update(loss.detach())
 
diff --git a/litgpt/args.py b/litgpt/args.py
index 62c644f423..1401b95c18 100644
--- a/litgpt/args.py
+++ b/litgpt/args.py
@@ -50,24 +50,24 @@ def __post_init__(self) -> None:
                 "`--train.lr_warmup_steps` should be less than `--train.max_steps`."
f" Got {self.lr_warmup_steps} lr_warmup_steps and {self.max_steps} max_steps.", UserWarning) - def gradient_accumulation_iters(self, devices: int) -> int: + def gradient_accumulation_iters(self, devices: int, num_nodes: int) -> int: """Number of iterations between gradient synchronizations""" - gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size + gradient_accumulation_iters = self.batch_size(devices, num_nodes) // self.micro_batch_size assert gradient_accumulation_iters > 0 return gradient_accumulation_iters - def batch_size(self, devices: int) -> int: + def batch_size(self, devices: int, num_nodes: int) -> int: """Number of samples between optimizer steps per data-parallel rank""" - batch_size = self.global_batch_size // devices + batch_size = self.global_batch_size // (devices * num_nodes) assert batch_size > 0 return batch_size - def warmup_iters(self, devices: int, max_iters: int, train_dataloader) -> int: + def warmup_iters(self, devices: int, num_nodes: int, max_iters: int, train_dataloader) -> int: """Number of iterations to warm up the learning rate.""" if self.lr_warmup_fraction: return min(max_iters, math.ceil(self.lr_warmup_fraction * len(train_dataloader))) if self.lr_warmup_steps: - return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices)) + return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices, num_nodes)) return 0 diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py index 2f7801b8f1..f75ab29655 100644 --- a/litgpt/finetune/adapter.py +++ b/litgpt/finetune/adapter.py @@ -138,12 +138,13 @@ def setup( if torch.cuda.is_available() and devices > 1: check_nvlink_connectivity(fabric) - fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer) + fabric.launch(main, devices, num_nodes, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer) def main( fabric: L.Fabric, devices: int, + num_nodes: int, seed: int, config: Config, data: DataModule, @@ -157,7 +158,7 @@ def main( tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) - steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) + steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) fabric.seed_everything(seed) # same seed for every process to init model (FSDP) @@ -201,6 +202,7 @@ def main( train_dataloader, val_dataloader, devices, + num_nodes, checkpoint_dir, out_dir, train, @@ -237,6 +239,7 @@ def fit( train_dataloader: DataLoader, val_dataloader: DataLoader, devices: int, + num_nodes: int, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, @@ -261,7 +264,7 @@ def fit( train_iterator = CycleIterator(train_dataloader) throughput = ThroughputMonitor(fabric, window_size=50) - running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to( + running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to( fabric.device ) max_steps = train.max_steps or float("inf") @@ -284,13 +287,13 @@ def fit( break input_ids, targets = batch["input_ids"], batch["labels"] - is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0 + is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0 with fabric.no_backward_sync(model, 
             logits = model(input_ids, lm_head_chunk_size=128)
             # shift the targets such that output n predicts token n+1
             logits[-1] = logits[-1][..., :-1, :]
             loss = chunked_cross_entropy(logits, targets[..., 1:])
-            fabric.backward(loss / train.gradient_accumulation_iters(devices))
+            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
         running_loss.update(loss.detach())
 
diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index f05fd0d4d3..d7a522845a 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -138,12 +138,13 @@ def setup(
     if torch.cuda.is_available() and devices > 1:
         check_nvlink_connectivity(fabric)
 
-    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
+    fabric.launch(main, devices, num_nodes, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
 
 
 def main(
     fabric: L.Fabric,
     devices: int,
+    num_nodes: int,
     seed: int,
     config: Config,
     data: DataModule,
@@ -157,7 +158,7 @@ def main(
 
     tokenizer = Tokenizer(checkpoint_dir)
     train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
-    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
+    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
     lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))
 
     fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)
@@ -201,6 +202,7 @@ def main(
         train_dataloader,
         val_dataloader,
         devices,
+        num_nodes,
         checkpoint_dir,
         out_dir,
         train,
@@ -237,6 +239,7 @@ def fit(
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     devices: int,
+    num_nodes: int,
     checkpoint_dir: Path,
     out_dir: Path,
     train: TrainArgs,
@@ -261,7 +264,7 @@ def fit(
 
     train_iterator = CycleIterator(train_dataloader)
     throughput = ThroughputMonitor(fabric, window_size=50)
-    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
+    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
         fabric.device
     )
     max_steps = train.max_steps or float("inf")
@@ -285,13 +288,13 @@ def fit(
 
         input_ids, targets = batch["input_ids"], batch["labels"]
 
-        is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
+        is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
         with fabric.no_backward_sync(model, enabled=is_accumulating):
             logits = model(input_ids, lm_head_chunk_size=128)
             # shift the targets such that output n predicts token n+1
             logits[-1] = logits[-1][..., :-1, :]
             loss = chunked_cross_entropy(logits, targets[..., 1:])
-            fabric.backward(loss / train.gradient_accumulation_iters(devices))
+            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
         running_loss.update(loss.detach())
 
diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index b507aa58e4..4e3438cd50 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -119,12 +119,13 @@ def setup(
     if torch.cuda.is_available() and devices > 1:
         check_nvlink_connectivity(fabric)
 
-    fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
+    fabric.launch(main, devices, num_nodes, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
 
 
 def main(
     fabric: L.Fabric,
     devices: int,
+    num_nodes: int,
     resume: Union[bool, Literal["auto"], Path],
     seed: int,
     config: Config,
@@ -139,7 +140,7 @@ def main(
 
     tokenizer = Tokenizer(checkpoint_dir)
     train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
-    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
+    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
     lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))
 
     fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)
@@ -168,7 +169,7 @@ def main(
         load_checkpoint(fabric, state["model"], checkpoint_path)
 
     train_time = time.perf_counter()
-    token_counts = fit(fabric, state, train_dataloader, val_dataloader, devices, resume, checkpoint_dir, out_dir, train, eval, data)
+    token_counts = fit(fabric, state, train_dataloader, val_dataloader, devices, num_nodes, resume, checkpoint_dir, out_dir, train, eval, data)
     training_time = time.perf_counter() - train_time
     output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)
     fabric.print(output)
@@ -197,6 +198,7 @@ def fit(
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     devices: int,
+    num_nodes: int,
     resume: Union[bool, Literal["auto"], Path],
     checkpoint_dir: Path,
     out_dir: Path,
@@ -246,7 +248,7 @@ def fit(
             f" {initial_iter}."
         )
 
-    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
+    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
         fabric.device
     )
     fabric.barrier()
@@ -259,12 +261,12 @@ def fit(
             break
 
         input_ids, targets = batch["input_ids"], batch["labels"]
-        is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0
+        is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
         with fabric.no_backward_sync(model, enabled=is_accumulating):
             logits = model(input_ids)
             # shift the targets such that output n predicts token n+1
             loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])
-            fabric.backward(loss / train.gradient_accumulation_iters(devices))
+            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
         running_loss.update(loss.detach())
 
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index af88afb0ec..aec9f430dd 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -168,12 +168,13 @@ def setup(
     if torch.cuda.is_available() and devices > 1:
         check_nvlink_connectivity(fabric)
 
-    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
+    fabric.launch(main, devices, num_nodes, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
 
 
 def main(
     fabric: L.Fabric,
     devices: int,
+    num_nodes: int,
     seed: int,
     config: Config,
     data: DataModule,
@@ -187,7 +188,7 @@ def main(
 
     tokenizer = Tokenizer(checkpoint_dir)
     train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
-    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
+    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
     lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))
 
     fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)
@@ -231,6 +232,7 @@ def main(
         train_dataloader,
         val_dataloader,
         devices,
+        num_nodes,
         checkpoint_dir,
         out_dir,
         train,
@@ -269,6 +271,7 @@ def fit(
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     devices: int,
+    num_nodes: int,
     checkpoint_dir: Path,
     out_dir: Path,
     train: TrainArgs,
@@ -293,7 +296,7 @@ def fit(
 
     train_iterator = CycleIterator(train_dataloader)
     throughput = ThroughputMonitor(fabric, window_size=50)
-    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
+    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
         fabric.device
     )
     max_steps = train.max_steps or float("inf")
@@ -316,13 +319,13 @@ def fit(
             break
 
         input_ids, targets = batch["input_ids"], batch["labels"]
-        is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
+        is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
         with fabric.no_backward_sync(model, enabled=is_accumulating):
             logits = model(input_ids, lm_head_chunk_size=128)
             # shift the targets such that output n predicts token n+1
             logits[-1] = logits[-1][..., :-1, :]
             loss = chunked_cross_entropy(logits, targets[..., 1:])
-            fabric.backward(loss / train.gradient_accumulation_iters(devices))
+            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
         running_loss.update(loss.detach())
 
diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py
index 739ac2df77..3bc4174e54 100644
--- a/litgpt/pretrain.py
+++ b/litgpt/pretrain.py
@@ -155,6 +155,7 @@ def setup(
     main(
         fabric,
         devices,
+        num_nodes,
         seed,
         initial_checkpoint_dir,
         resume,
@@ -172,6 +173,7 @@ def setup(
 def main(
     fabric: L.Fabric,
     devices: int,
+    num_nodes: int,
     seed: int,
     initial_checkpoint_dir: Optional[Path],
     resume: Union[bool, Literal["auto"], Path],
@@ -232,7 +234,11 @@ def main(
         fabric.load(resume, state)
 
     train_time = time.perf_counter()
-    fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
+    fit(
+        fabric=fabric, devices=devices, num_nodes=num_nodes, state=state,
+        train_dataloader=train_dataloader, val_dataloader=val_dataloader,
+        out_dir=out_dir, tokenizer_dir=tokenizer_dir, train=train, eval=eval
+    )
 
     # Save final checkpoint
     save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth")
@@ -258,6 +264,7 @@ def main(
 def fit(
     fabric: L.Fabric,
     devices: int,
+    num_nodes: int,
     state: dict,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
@@ -291,17 +298,17 @@ def fit(
     max_tokens_per_device = train.max_tokens // fabric.world_size
     tokens_per_iter = train.micro_batch_size * model.max_seq_length
     max_iters = max_tokens_per_device // tokens_per_iter
-    log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices)
+    log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes)
     initial_iter = state["iter_num"]
     train_iterator = CycleIterator(train_dataloader)
 
-    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
+    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
         fabric.device
     )
     fabric.barrier()
     total_t0 = time.perf_counter()
 
-    warmup_iters = train.warmup_iters(devices, max_iters, train_dataloader)
+    warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader)
 
     for train_data in train_iterator:
         if state["iter_num"] >= max_iters:
@@ -318,11 +325,11 @@ def fit(
         input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long()
         targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long()
 
-        is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0
+        is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
         with fabric.no_backward_sync(model, enabled=is_accumulating):
             logits = model(input_ids)
             loss = chunked_cross_entropy(logits, targets)
-            fabric.backward(loss / train.gradient_accumulation_iters(devices))
+            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
         running_loss.update(loss.detach())
 
diff --git a/tests/test_args.py b/tests/test_args.py
index 0b13c83976..e4749b818f 100644
--- a/tests/test_args.py
+++ b/tests/test_args.py
@@ -7,7 +7,7 @@
 def test_compute_warmup_iters():
     # warmup disabled
     train = TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=0)
-    assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(10)) == 0
+    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 0
 
     # lr_warmup_steps and lr_warmup_fraction both are not allowed
     with pytest.raises(ValueError, match="Can't provide both `--train.lr_warmup_fraction`"):
@@ -19,18 +19,18 @@ def test_compute_warmup_iters():
 
     # lr_warmup_steps
     train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=100, lr_warmup_fraction=0)
-    assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(10)) == 100
+    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 100
     # lr_warmup_steps multiplied by accumulation factor
     train.global_batch_size = 4
-    assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(10)) == 400
-    assert train.warmup_iters(devices=2, max_iters=1000, train_dataloader=range(10)) == 200
+    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 400
+    assert train.warmup_iters(devices=2, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 200
     # lr_warmup_steps truncated by max iters
-    assert train.warmup_iters(devices=1, max_iters=120, train_dataloader=range(10)) == 120
+    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=120, train_dataloader=range(10)) == 120
 
     # lr_warmup_fraction
     train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=0, lr_warmup_fraction=0.3)
-    assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(100)) == 30
+    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(100)) == 30
     # lr_warmup_fraction truncated by max iters
-    assert train.warmup_iters(devices=1, max_iters=20, train_dataloader=range(100)) == 20
+    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=20, train_dataloader=range(100)) == 20
     # lr_warmup_fraction rounds up
-    assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(5)) == 2
+    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(5)) == 2
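
For reference, a minimal standalone sketch (not part of the patch) of the arithmetic that the new num_nodes argument feeds into. It mirrors the updated TrainArgs.batch_size, gradient_accumulation_iters, and warmup_iters logic from litgpt/args.py above; the concrete numbers are hypothetical.

# Hypothetical configuration; only the formulas come from the patch.
global_batch_size = 64      # samples per optimizer step across the whole cluster
micro_batch_size = 4        # samples per forward/backward pass on one rank
devices = 4                 # GPUs per node
num_nodes = 2               # nodes; previously ignored by the divisor

# TrainArgs.batch_size: samples per optimizer step on each data-parallel rank
batch_size = global_batch_size // (devices * num_nodes)        # 64 // 8 = 8
assert batch_size > 0

# TrainArgs.gradient_accumulation_iters: micro-batches accumulated per optimizer step
gradient_accumulation_iters = batch_size // micro_batch_size   # 8 // 4 = 2
assert gradient_accumulation_iters > 0

# TrainArgs.warmup_iters (lr_warmup_steps branch): optimizer steps -> loop iterations
lr_warmup_steps = 100
warmup_iters = lr_warmup_steps * gradient_accumulation_iters   # 200

Before this change, the per-rank batch size was divided only by devices, so a multi-node run with the same settings would have accumulated num_nodes times more samples per optimizer step than the configured global_batch_size.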