19 changes: 13 additions & 6 deletions extensions/thunder/pretrain.py
@@ -150,6 +150,7 @@ def setup(
main(
fabric,
devices,
num_nodes,
seed,
initial_checkpoint_dir,
resume,
@@ -168,6 +169,7 @@ def setup(
def main(
fabric: L.Fabric,
devices: int,
num_nodes: int,
seed: int,
initial_checkpoint_dir: Optional[Path],
resume: Union[bool, Literal["auto"], Path],
@@ -229,7 +231,11 @@ def main(
fabric.load(resume, state)

train_time = time.perf_counter()
fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval, optimizer)
fit(
fabric=fabric, devices=devices, num_nodes=num_nodes, state=state,
train_dataloader=train_dataloader, val_dataloader=val_dataloader,
out_dir=out_dir, tokenizer_dir=tokenizer_dir, train=train, eval=eval, optimizer=optimizer
)
fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")

# Save final checkpoint
@@ -242,6 +248,7 @@ def main(
def fit(
fabric: L.Fabric,
devices: int,
num_nodes: int,
state: dict,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
@@ -269,18 +276,18 @@ def fit(
max_tokens_per_device = train.max_tokens // fabric.world_size
tokens_per_iter = train.micro_batch_size * model.max_seq_length
max_iters = max_tokens_per_device // tokens_per_iter
log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices)
log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes)
initial_iter = state["iter_num"]
train_iterator = CycleIterator(train_dataloader)

running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
fabric.barrier()
total_t0 = time.perf_counter()
val_loss = "n/a"

warmup_iters = train.warmup_iters(devices, max_iters, train_dataloader)
warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader)

for train_data in train_iterator:
if state["iter_num"] >= max_iters:
@@ -297,10 +304,10 @@ def fit(
input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long()
targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long()

is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0
is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
loss = forward_and_loss(model, input_ids, targets)
fabric.backward(loss / train.gradient_accumulation_iters(devices))
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

running_loss.update(loss.detach())

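The pretrain.py change above only swaps the accumulation divisor from gradient_accumulation_iters(devices) to gradient_accumulation_iters(devices, num_nodes); the loop structure is untouched. As a standalone illustration (assumed names and shapes, not the repository's fit() itself), the pattern that count feeds into looks like this: gradients are synchronized only on the last micro-batch of each accumulation window, and the loss is scaled by the window size.

# Illustrative sketch, not part of the diff. `accumulation_iters` stands for
# train.gradient_accumulation_iters(devices, num_nodes); model, optimizer and
# batches are assumed to come from fabric.setup(...) and a dataloader.
import torch

def accumulate_and_step(fabric, model, optimizer, batches, accumulation_iters: int) -> None:
    for iter_num, (input_ids, targets) in enumerate(batches, start=1):
        is_accumulating = iter_num % accumulation_iters != 0
        # Skip the cross-rank gradient all-reduce while still accumulating micro-batches.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids)  # (batch, seq_len, vocab) assumed
            loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
            # Scale so the accumulated gradient matches the mean over the full batch.
            fabric.backward(loss / accumulation_iters)
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()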
12 changes: 6 additions & 6 deletions litgpt/args.py
@@ -50,24 +50,24 @@ def __post_init__(self) -> None:
"`--train.lr_warmup_steps` should be less than `--train.max_steps`."
f" Got {self.lr_warmup_steps} lr_warmup_steps and {self.max_steps} max_steps.", UserWarning)

def gradient_accumulation_iters(self, devices: int) -> int:
def gradient_accumulation_iters(self, devices: int, num_nodes: int) -> int:
"""Number of iterations between gradient synchronizations"""
gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size
gradient_accumulation_iters = self.batch_size(devices, num_nodes) // self.micro_batch_size
assert gradient_accumulation_iters > 0
return gradient_accumulation_iters

def batch_size(self, devices: int) -> int:
def batch_size(self, devices: int, num_nodes: int) -> int:
"""Number of samples between optimizer steps per data-parallel rank"""
batch_size = self.global_batch_size // devices
batch_size = self.global_batch_size // (devices * num_nodes)
assert batch_size > 0
return batch_size

def warmup_iters(self, devices: int, max_iters: int, train_dataloader) -> int:
def warmup_iters(self, devices: int, num_nodes: int, max_iters: int, train_dataloader) -> int:
"""Number of iterations to warm up the learning rate."""
if self.lr_warmup_fraction:
return min(max_iters, math.ceil(self.lr_warmup_fraction * len(train_dataloader)))
if self.lr_warmup_steps:
return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices))
return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices, num_nodes))
return 0


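The substantive change in litgpt/args.py is batch_size(): the per-rank batch is now global_batch_size // (devices * num_nodes) instead of global_batch_size // devices, so a multi-node run no longer accumulates an effective global batch num_nodes times larger than configured. A worked example with assumed values (2 nodes of 8 devices, global batch 512, micro-batch 4):

# Worked example of the TrainArgs arithmetic above, with assumed values.
global_batch_size = 512
micro_batch_size = 4
devices, num_nodes = 8, 2
lr_warmup_steps = 100

batch_size = global_batch_size // (devices * num_nodes)       # 512 // 16 = 32 samples per rank per optimizer step
gradient_accumulation_iters = batch_size // micro_batch_size  # 32 // 4 = 8 micro-batches between synchronizations
warmup_iters = lr_warmup_steps * gradient_accumulation_iters  # 100 optimizer steps -> 800 iterations (capped at max_iters in warmup_iters())
assert batch_size > 0 and gradient_accumulation_iters > 0
print(batch_size, gradient_accumulation_iters, warmup_iters)  # 32 8 800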
13 changes: 8 additions & 5 deletions litgpt/finetune/adapter.py
@@ -138,12 +138,13 @@ def setup(
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)

fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
fabric.launch(main, devices, num_nodes, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)


def main(
fabric: L.Fabric,
devices: int,
num_nodes: int,
seed: int,
config: Config,
data: DataModule,
@@ -157,7 +158,7 @@ def main(

tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))

fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
@@ -201,6 +202,7 @@ def main(
train_dataloader,
val_dataloader,
devices,
num_nodes,
checkpoint_dir,
out_dir,
train,
@@ -237,6 +239,7 @@ def fit(
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
num_nodes: int,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
@@ -261,7 +264,7 @@ def fit(

train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
max_steps = train.max_steps or float("inf")
@@ -284,13 +287,13 @@ def fit(
break
input_ids, targets = batch["input_ids"], batch["labels"]

is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices))
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

running_loss.update(loss.detach())

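A downstream effect of the node-aware accumulation count, visible in main() above, is the length of the learning-rate schedule: steps_per_epoch counts optimizer steps rather than dataloader batches. A small sketch with assumed numbers:

# Sketch of the scheduler step budget computed in main() above (assumed values).
epochs = 3
max_steps = None              # train.max_steps
batches_per_epoch = 1_000     # len(train_dataloader), assumed
grad_accum_iters = 8          # train.gradient_accumulation_iters(devices, num_nodes)

steps_per_epoch = batches_per_epoch // grad_accum_iters                  # 125 optimizer steps per epoch
lr_max_steps = min(epochs * steps_per_epoch, max_steps or float("inf"))  # 375
print(steps_per_epoch, lr_max_steps)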
13 changes: 8 additions & 5 deletions litgpt/finetune/adapter_v2.py
@@ -138,12 +138,13 @@ def setup(
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)

fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
fabric.launch(main, devices, num_nodes, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)


def main(
fabric: L.Fabric,
devices: int,
num_nodes: int,
seed: int,
config: Config,
data: DataModule,
@@ -157,7 +158,7 @@ def main(

tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))

fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
@@ -201,6 +202,7 @@ def main(
train_dataloader,
val_dataloader,
devices,
num_nodes,
checkpoint_dir,
out_dir,
train,
@@ -237,6 +239,7 @@ def fit(
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
num_nodes: int,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
@@ -261,7 +264,7 @@ def fit(

train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
max_steps = train.max_steps or float("inf")
@@ -285,13 +288,13 @@ def fit(

input_ids, targets = batch["input_ids"], batch["labels"]

is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices))
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

running_loss.update(loss.detach())

14 changes: 8 additions & 6 deletions litgpt/finetune/full.py
@@ -119,12 +119,13 @@ def setup(
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)

fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
fabric.launch(main, devices, num_nodes, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)


def main(
fabric: L.Fabric,
devices: int,
num_nodes: int,
resume: Union[bool, Literal["auto"], Path],
seed: int,
config: Config,
@@ -139,7 +140,7 @@ def main(

tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))

fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
@@ -168,7 +169,7 @@ def main(
load_checkpoint(fabric, state["model"], checkpoint_path)

train_time = time.perf_counter()
token_counts = fit(fabric, state, train_dataloader, val_dataloader, devices, resume, checkpoint_dir, out_dir, train, eval, data)
token_counts = fit(fabric, state, train_dataloader, val_dataloader, devices, num_nodes, resume, checkpoint_dir, out_dir, train, eval, data)
training_time = time.perf_counter() - train_time
output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)
fabric.print(output)
@@ -197,6 +198,7 @@ def fit(
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
num_nodes: int,
resume: Union[bool, Literal["auto"], Path],
checkpoint_dir: Path,
out_dir: Path,
@@ -246,7 +248,7 @@ def fit(
f" {initial_iter}."
)

running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
fabric.barrier()
@@ -259,12 +261,12 @@ def fit(
break
input_ids, targets = batch["input_ids"], batch["labels"]

is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0
is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids)
# shift the targets such that output n predicts token n+1
loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices))
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

running_loss.update(loss.detach())

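The unchanged context lines above also show the target shift every fit() loop relies on ("shift the targets such that output n predicts token n+1"). A standalone sketch with assumed shapes, using plain cross-entropy in place of litgpt's chunked_cross_entropy:

# Standalone sketch of the shift-by-one language-modeling loss (assumed shapes).
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 8, 32
logits = torch.randn(batch, seq_len, vocab)
targets = torch.randint(0, vocab, (batch, seq_len))

# Drop the last logit and the first target so position n is scored against token n+1.
loss = F.cross_entropy(
    logits[..., :-1, :].reshape(-1, vocab),
    targets[..., 1:].reshape(-1),
)
print(loss.item())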
13 changes: 8 additions & 5 deletions litgpt/finetune/lora.py
@@ -168,12 +168,13 @@ def setup(
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)

fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
fabric.launch(main, devices, num_nodes, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)


def main(
fabric: L.Fabric,
devices: int,
num_nodes: int,
seed: int,
config: Config,
data: DataModule,
@@ -187,7 +188,7 @@ def main(

tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))

fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
@@ -231,6 +232,7 @@ def main(
train_dataloader,
val_dataloader,
devices,
num_nodes,
checkpoint_dir,
out_dir,
train,
@@ -269,6 +271,7 @@ def fit(
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
num_nodes: int,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
@@ -293,7 +296,7 @@ def fit(

train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
max_steps = train.max_steps or float("inf")
@@ -316,13 +319,13 @@ def fit(
break
input_ids, targets = batch["input_ids"], batch["labels"]

is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices))
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

running_loss.update(loss.detach())

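Across every file in this PR the pattern is the same: setup() forwards num_nodes through fabric.launch() into main() and fit(), where it only changes the accumulation arithmetic. The sketch below is an assumption about the surrounding setup code (the Fabric constructor call is not shown in this diff): Fabric itself is given num_nodes so that fabric.world_size equals devices * num_nodes, and the same value is passed along to the launched entry point.

# Minimal sketch (assumed, not the repository's actual setup code) of threading
# num_nodes from setup() down to Fabric and the launched entry point.
import lightning as L

def main(fabric: L.Fabric, devices: int, num_nodes: int) -> None:
    fabric.print(f"world size: {fabric.world_size} = {num_nodes} node(s) x {devices} device(s)")

def setup(devices: int = 1, num_nodes: int = 1) -> None:
    fabric = L.Fabric(accelerator="cpu", devices=devices, num_nodes=num_nodes)
    # fabric.launch passes the Fabric object as the first argument, matching the
    # main(fabric, devices, num_nodes, ...) signatures in the diffs above.
    fabric.launch(main, devices, num_nodes)

if __name__ == "__main__":
    setup()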