diff --git a/litgpt/__main__.py b/litgpt/__main__.py index a92e2027c2..6f63f589a2 100644 --- a/litgpt/__main__.py +++ b/litgpt/__main__.py @@ -12,6 +12,7 @@ from litgpt.finetune.adapter_v2 import setup as finetune_adapter_v2_fn from litgpt.finetune.full import setup as finetune_full_fn from litgpt.finetune.lora import setup as finetune_lora_fn +from litgpt.finetune.lora_legacy import setup as finetune_lora_legacy_fn from litgpt.generate.adapter import main as generate_adapter_fn from litgpt.generate.adapter_v2 import main as generate_adapter_v2_fn from litgpt.generate.base import main as generate_base_fn @@ -35,6 +36,7 @@ def main() -> None: "chat": chat_fn, "finetune": finetune_lora_fn, "finetune_lora": finetune_lora_fn, + "finetune_lora_legacy": finetune_lora_legacy_fn, "finetune_full": finetune_full_fn, "finetune_adapter": finetune_adapter_fn, "finetune_adapter_v2": finetune_adapter_v2_fn, diff --git a/litgpt/args.py b/litgpt/args.py index ee0d99e2e6..1e69883be6 100644 --- a/litgpt/args.py +++ b/litgpt/args.py @@ -28,6 +28,8 @@ class TrainArgs: """Total number of tokens to train on""" max_steps: Optional[int] = None """Limits the number of optimizer steps to run""" + max_time: Optional[float] = None + """Limits the number of seconds to train for""" max_seq_length: Optional[int] = None """Limits the length of samples""" tie_embeddings: Optional[bool] = None diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 9593e1d4fe..0d025ba2bb 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -11,7 +11,7 @@ import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision -from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.strategies import ModelParallelStrategy from lightning.fabric.utilities import ThroughputMonitor from lightning_utilities.core.imports import RequirementCache from torch.utils.data import ConcatDataset, DataLoader @@ -20,7 +20,7 @@ from litgpt.args import EvalArgs, LogArgs, TrainArgs from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate -from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable +from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable from litgpt.prompts import save_prompt_style from litgpt.scripts.merge_lora import merge_lora from litgpt.tokenizer import Tokenizer @@ -70,6 +70,7 @@ def setup( lr_warmup_steps=100, epochs=5, max_seq_length=None, + max_time=None, ), log: LogArgs = LogArgs(), eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), @@ -105,6 +106,7 @@ def setup( seed: The random seed to use for reproducibility. access_token: Optional API token to access models with restrictions. """ + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data @@ -152,12 +154,10 @@ def setup( "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1" " when using the --quantize flag." 
) - strategy = FSDPStrategy( - auto_wrap_policy={torch.nn.Linear}, - activation_checkpointing_policy={Block}, - state_dict_type="full", - limit_all_gathers=True, - cpu_offload=False, + strategy = ModelParallelStrategy( + parallelize_fn=parallelize_fn, + data_parallel_size=devices * num_nodes, + tensor_parallel_size=1, ) else: strategy = "auto" @@ -174,7 +174,9 @@ def setup( if torch.cuda.is_available() and devices > 1: check_nvlink_connectivity(fabric) - fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes) + fabric.launch( + main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes, precision + ) def main( @@ -189,6 +191,7 @@ def main( eval: EvalArgs, optimizer: Union[str, Dict], num_nodes: int = 1, + precision: Optional[str] = None, ) -> None: validate_args(train, eval) @@ -229,7 +232,6 @@ def main( optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) - # strict=False because missing keys due to LoRA weights not contained in state dict load_checkpoint(fabric, model, checkpoint_path, strict=False) train_time = time.perf_counter() @@ -264,12 +266,19 @@ def main( save_path = out_dir / "final" / "lit_model.pth.lora" save_path.parent.mkdir(parents=True, exist_ok=True) save_lora_checkpoint(fabric, model, save_path) + + fabric.barrier() if fabric.global_rank == 0: # Copy checkpoint files from original checkpoint dir copy_config_files(checkpoint_dir, save_path.parent) save_hyperparameters(setup, save_path.parent) save_prompt_style(data.prompt_style, save_path.parent) - merge_lora(checkpoint_dir=save_path.parent) + merge_lora( + checkpoint_dir=save_path.parent, + pretrained_checkpoint_dir=checkpoint_dir, + precision=precision, + ) + fabric.barrier() def fit( @@ -316,6 +325,8 @@ def fit( total_lengths = 0 total_t0 = time.perf_counter() + max_time = train.max_time or float("inf") + token_counts = { "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long), @@ -327,6 +338,12 @@ def fit( iter_t0 = time.perf_counter() batch = next(train_iterator) if train_iterator.epoch >= train.epochs: + generate_example(fabric, model, tokenizer, eval, data) + fabric.print(f"Number of epochs {train.epochs} reached, stopping training...") + break + if iter_t0 - total_t0 > max_time: + generate_example(fabric, model, tokenizer, eval, data) + fabric.print(f"Max time ({max_time / 60.0:.2f}m) reached, stopping training...") break input_ids, targets = batch["input_ids"], batch["labels"] @@ -497,9 +514,45 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: return longest_seq_length, longest_seq_ix +def parallelize_fn(model, device_mesh, activation_checkpointing=True): + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, checkpoint_wrapper + + if activation_checkpointing: + model.transformer.h = torch.nn.ModuleList( + [checkpoint_wrapper(el, preserve_rng_state=False) for el in model.transformer.h] + ) + + dp_mesh = device_mesh["data_parallel"] + + for m in reversed(list(model.modules())): + if ( + (isinstance(m, torch.nn.Linear) and m.weight.requires_grad) + or isinstance(m, CheckpointWrapper) + or isinstance(m, Block) + ): + fully_shard(m, mesh=dp_mesh) + + fully_shard(model, mesh=dp_mesh) + + return model + + def 
save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: - fabric.print(f"Saving LoRA weights to {str(file_path)!r}") - fabric.save(file_path, {"model": model}, filter={"model": lora_filter}) + cpu_state_dict = {} + sharded_sd = model.state_dict() + for param_name, param in sharded_sd.items(): + if "lora_" not in param_name: + continue + if param.is_cpu: + param = param.to(fabric.device) + if hasattr(param, "_local_tensor"): + param = param.full_tensor() + if fabric.is_global_zero: + cpu_state_dict[param_name] = param.cpu() + fabric.barrier() + if fabric.is_global_zero: + torch.save({"model": cpu_state_dict}, file_path) def validate_args(train: TrainArgs, eval: EvalArgs) -> None: diff --git a/litgpt/finetune/lora_legacy.py b/litgpt/finetune/lora_legacy.py new file mode 100644 index 0000000000..10dcc0220b --- /dev/null +++ b/litgpt/finetune/lora_legacy.py @@ -0,0 +1,522 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import dataclasses +import math +import os +import time +import warnings +from pathlib import Path +from pprint import pprint +from typing import Dict, List, Literal, Optional, Tuple, Union + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities import ThroughputMonitor +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import ConcatDataset, DataLoader +from torchmetrics import RunningMean + +from litgpt.args import EvalArgs, LogArgs, TrainArgs +from litgpt.data import Alpaca, DataModule +from litgpt.generate.base import generate +from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable +from litgpt.prompts import save_prompt_style +from litgpt.scripts.merge_lora import merge_lora +from litgpt.tokenizer import Tokenizer +from litgpt.utils import ( + CycleIterator, + auto_download_checkpoint, + check_nvlink_connectivity, + check_valid_checkpoint_dir, + choose_logger, + chunked_cross_entropy, + copy_config_files, + create_finetuning_performance_report, + get_default_supported_precision, + init_out_dir, + instantiate_bnb_optimizer, + instantiate_torch_optimizer, + load_checkpoint, + num_parameters, + parse_devices, + save_hyperparameters, + select_sft_generate_example, +) + + +def setup( + checkpoint_dir: Path, + out_dir: Path = Path("out/finetune/lora"), + precision: Optional[str] = None, + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None, + devices: Union[int, str] = 1, + num_nodes: int = 1, + lora_r: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.05, + lora_query: bool = True, + lora_key: bool = False, + lora_value: bool = True, + lora_projection: bool = False, + lora_mlp: bool = False, + lora_head: bool = False, + data: Optional[DataModule] = None, + train: TrainArgs = TrainArgs( + save_interval=1000, + log_interval=1, + global_batch_size=16, + micro_batch_size=1, + lr_warmup_steps=100, + epochs=5, + max_seq_length=None, + ), + log: LogArgs = LogArgs(), + eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), + optimizer: Union[str, Dict] = "AdamW", + logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv", + seed: int = 1337, + access_token: Optional[str] = None, +) -> None: + """Finetune a model using the LoRA method. 
+ + Arguments: + checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning. + out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in + /teamspace/jobs//share. + precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". + quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. + devices: How many devices/GPUs to use. + num_nodes: How many nodes the code is being run on. + lora_r: The LoRA rank. + lora_alpha: The LoRA alpha. + lora_dropout: The LoRA dropout value. + lora_query: Whether to apply LoRA to the query weights in attention. + lora_key: Whether to apply LoRA to the key weights in attention. + lora_value: Whether to apply LoRA to the value weights in attention. + lora_projection: Whether to apply LoRA to the output projection in the attention block. + lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block. + lora_head: Whether to apply LoRA to output head in GPT. + data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. + train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. + eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. + optimizer: An optimizer name (such as "AdamW") or config. + logger_name: The name of the logger to send metrics to. + seed: The random seed to use for reproducibility. + access_token: Optional API token to access models with restrictions. + """ + + checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) + pprint(locals()) + data = Alpaca() if data is None else data + devices = parse_devices(devices) + out_dir = init_out_dir(out_dir) + + check_valid_checkpoint_dir(checkpoint_dir) + config = Config.from_file( + checkpoint_dir / "model_config.yaml", + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + lora_query=lora_query, + lora_key=lora_key, + lora_value=lora_value, + lora_projection=lora_projection, + lora_mlp=lora_mlp, + lora_head=lora_head, + ) + + precision = precision or get_default_supported_precision(training=True) + logger = choose_logger( + logger_name, + out_dir, + name=f"finetune-{config.name}", + log_interval=train.log_interval, + log_args=dataclasses.asdict(log), + ) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + if RequirementCache("bitsandbytes != 0.42.0"): + warnings.warn( + "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." + ) + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + if devices * num_nodes > 1: + if quantize: + raise NotImplementedError( + "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1" + " when using the --quantize flag." 
+ ) + + strategy = FSDPStrategy( + auto_wrap_policy={torch.nn.Linear}, + activation_checkpointing_policy={Block}, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric( + devices=devices, + num_nodes=num_nodes, + strategy=strategy, + precision=precision, + loggers=logger, + plugins=plugins, + ) + + if torch.cuda.is_available() and devices > 1: + check_nvlink_connectivity(fabric) + + fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes) + + +def main( + fabric: L.Fabric, + devices: int, + seed: int, + config: Config, + data: DataModule, + checkpoint_dir: Path, + out_dir: Path, + train: TrainArgs, + eval: EvalArgs, + optimizer: Union[str, Dict], + num_nodes: int = 1, +) -> None: + validate_args(train, eval) + + tokenizer = Tokenizer(checkpoint_dir) + train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes) + lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) + + fabric.seed_everything(seed) # same seed for every process to init model (FSDP) + + if fabric.global_rank == 0: + os.makedirs(out_dir, exist_ok=True) + + checkpoint_path = checkpoint_dir / "lit_model.pth" + with fabric.init_module(empty_init=(fabric.world_size > 1)): + model = GPT(config) + mark_only_lora_as_trainable(model) + + fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") + fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}") + + model = fabric.setup_module(model) + if isinstance(fabric.strategy.precision, BitsandbytesPrecision): + optimizer = instantiate_bnb_optimizer(optimizer, model.parameters()) + + from bitsandbytes.nn import StableEmbedding + + old_embedding = model.transformer.wte + model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim) + with torch.no_grad(): + model.transformer.wte.weight.copy_(old_embedding.weight) + model.transformer.wte = model.transformer.wte.to( + device=old_embedding.weight.device, dtype=old_embedding.weight.dtype + ) + else: + optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) + + optimizer = fabric.setup_optimizers(optimizer) + scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) + + # strict=False because missing keys due to LoRA weights not contained in state dict + load_checkpoint(fabric, model, checkpoint_path, strict=False) + + train_time = time.perf_counter() + token_counts = fit( + fabric=fabric, + model=model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + devices=devices, + num_nodes=num_nodes, + checkpoint_dir=checkpoint_dir, + out_dir=out_dir, + train=train, + eval=eval, + data=data, + ) + + training_time = time.perf_counter() - train_time + output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type) + fabric.print(output) + + # Final evaluation + if eval.final_validation: + val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) + metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} + fabric.log_dict(metrics) + fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}") + + # Save the final 
LoRA checkpoint at the end of training + save_path = out_dir / "final" / "lit_model.pth.lora" + save_path.parent.mkdir(parents=True, exist_ok=True) + save_lora_checkpoint(fabric, model, save_path) + if fabric.global_rank == 0: + # Copy checkpoint files from original checkpoint dir + copy_config_files(checkpoint_dir, save_path.parent) + save_hyperparameters(setup, save_path.parent) + save_prompt_style(data.prompt_style, save_path.parent) + merge_lora(checkpoint_dir=save_path.parent) + + +def fit( + fabric: L.Fabric, + model: GPT, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + train_dataloader: DataLoader, + val_dataloader: DataLoader, + devices: int, + checkpoint_dir: Path, + out_dir: Path, + train: TrainArgs, + eval: EvalArgs, + data: DataModule, + num_nodes: int = 1, +) -> dict: + tokenizer = Tokenizer(checkpoint_dir) + longest_seq_length, longest_seq_ix = get_longest_seq_length( + ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]) + ) + model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + fabric.print( + f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" + f" {model.max_seq_length} and context length is {model.config.block_size}" + ) + + if eval.initial_validation: + val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) + val_loss = f"{val_loss:.3f}" + else: + fabric.print("Verifying settings ...") + validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check + val_loss = "n/a" + + train_iterator = CycleIterator(train_dataloader) + throughput = ThroughputMonitor(fabric, window_size=50) + running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to( + fabric.device + ) + max_steps = train.max_steps or float("inf") + step_count = 0 + iter_num = 0 + total_lengths = 0 + total_t0 = time.perf_counter() + + token_counts = { + "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long), + "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long), + "raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long), + } + + while step_count < max_steps: + iter_num += 1 + iter_t0 = time.perf_counter() + batch = next(train_iterator) + if train_iterator.epoch >= train.epochs: + break + input_ids, targets = batch["input_ids"], batch["labels"] + + is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids, lm_head_chunk_size=128) + # shift the targets such that output n predicts token n+1 + logits[-1] = logits[-1][..., :-1, :] + loss = chunked_cross_entropy(logits, targets[..., 1:]) + fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes)) + + running_loss.update(loss.detach()) + + if not is_accumulating: + optimizer.step() + optimizer.zero_grad() + scheduler.step() + step_count += 1 + + token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item() + token_counts["raw_tokens_plus_prompt_template"] += ( + batch["token_counts"]["raw_plus_prompt_template"].sum().item() + ) + token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel() + + total_lengths += input_ids.numel() + if iter_num % train.log_interval == 0: + loss = running_loss.compute().item() # expensive 
device-to-host synchronization + t1 = time.perf_counter() + throughput.update( + time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths + ) + throughput.compute_and_log(step=iter_num) + metrics = { + "loss": loss, + "iter": iter_num, + "step": step_count, + "epoch": train_iterator.epoch, + "iter_time": t1 - iter_t0, + "tokens": token_counts["raw_tokens_plus_prompt_template"], + "total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size, + "learning_rate": scheduler.get_last_lr()[0], + } + if isinstance(val_loss, torch.Tensor): + val_loss = f"{val_loss:.3f}" + fabric.print( + f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |" + f" loss train: {metrics['loss']:.3f}," + f" val: {val_loss} |" + f" iter time: {metrics['iter_time'] * 1000:.2f} ms" + f"{' (step)' if not is_accumulating else ''}" + ) + fabric.log_dict(metrics, step=iter_num) + + if not is_accumulating and step_count % eval.interval == 0: + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader, eval) + generate_example(fabric, model, tokenizer, eval, data) + t1 = time.perf_counter() - t0 + + val_loss_tensor = val_loss.detach().clone().to(fabric.device) + val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32) + + fabric.all_reduce(val_loss_tensor, reduce_op="mean") + fabric.all_reduce(val_time_tensor, reduce_op="mean") + + fabric.print( + f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms" + ) + metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)} + fabric.log_dict(metrics, step=iter_num) + fabric.barrier() + + if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0: + checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora" + checkpoint_file.parent.mkdir(parents=True, exist_ok=True) + save_lora_checkpoint(fabric, model, checkpoint_file) + if fabric.global_rank == 0: + copy_config_files(checkpoint_dir, checkpoint_file.parent) + save_hyperparameters(setup, checkpoint_file.parent) + save_prompt_style(data.prompt_style, checkpoint_file.parent) + + total_token_counts = {} + for key in token_counts: + total = fabric.all_reduce(token_counts[key], reduce_op="sum") + total_token_counts[key] = total.item() + + return total_token_counts + + +# FSDP has issues with `inference_mode` +@torch.no_grad() +def validate( + fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True +) -> torch.Tensor: + if verbose: + fabric.print("Validating ...") + model.eval() + losses = torch.zeros(min(len(val_dataloader), eval.max_iters)) + for k, batch in enumerate(val_dataloader): + if k >= eval.max_iters: + break + input_ids, targets = batch["input_ids"], batch["labels"] + logits = model(input_ids) + losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) + + val_loss = losses.mean() + + model.train() + return val_loss + + +@torch.no_grad() +def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule): + instruction = select_sft_generate_example(eval, data) + + fabric.print(instruction) + prompt = data.prompt_style.apply(instruction) + encoded = tokenizer.encode(prompt, device=fabric.device) + model.eval() + + max_returned_tokens = len(encoded) + eval.max_new_tokens + + if max_returned_tokens < model.max_seq_length: + with fabric.init_tensor(): + # do not set 
`max_seq_length=max_returned_token` because memory is not a concern here + model.set_kv_cache(batch_size=1) + output = generate( + model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + ) + model.clear_kv_cache() + model.train() + output = tokenizer.decode(output) + fabric.print(f"{output}\n") + else: + print( + f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) " + f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. " + f"The model's supported context size (post-training) is {model.config.block_size}." + ) + + +def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): + # linear warmup followed by cosine annealing + scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) + scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) + return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) + + +def get_dataloaders( + fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs +) -> Tuple[DataLoader, DataLoader]: + data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + with fabric.rank_zero_first(): + data.prepare_data() + data.setup() + train_dataloader = data.train_dataloader() + val_dataloader = data.val_dataloader() + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + return train_dataloader, val_dataloader + + +def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: + # find out the minimum max_seq_length required during fine-tuning (saves memory!) + lengths = [len(d["input_ids"]) for d in data] + longest_seq_length = max(lengths) + longest_seq_ix = lengths.index(longest_seq_length) + return longest_seq_length, longest_seq_ix + + +def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: + fabric.print(f"Saving LoRA weights to {str(file_path)!r}") + fabric.save(file_path, {"model": model}, filter={"model": lora_filter}) + + +def validate_args(train: TrainArgs, eval: EvalArgs) -> None: + issues = [] + unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] + for args, names in unsupported: + for name in names: + if getattr(args, name) is not None: + issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") + required = [(train, ["epochs"]), (eval, ["max_new_tokens"])] + for args, names in required: + for name in names: + if getattr(args, name) is None: + issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") + if not train.epochs and not train.max_steps: + issues.append(f"{__file__} requires either epochs or max_steps to be set. 
This is set in {train}") + if issues: + raise ValueError("\n".join(issues)) diff --git a/litgpt/utils.py b/litgpt/utils.py index af97fa2f11..a9d998f7af 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -25,7 +25,7 @@ import torch.utils._device import yaml from lightning.fabric.loggers import CSVLogger, TensorBoardLogger -from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.strategies import FSDPStrategy, ModelParallelStrategy from lightning.fabric.utilities.load import _lazy_load as lazy_load from lightning.pytorch.cli import instantiate_class from lightning.pytorch.loggers import MLFlowLogger, WandbLogger @@ -379,6 +379,15 @@ def get_default_supported_precision(training: bool) -> str: def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None: if isinstance(fabric.strategy, FSDPStrategy): fabric.load_raw(checkpoint_path, model, strict=strict) + elif isinstance(fabric.strategy, ModelParallelStrategy): + state_dict = torch.load(checkpoint_path, mmap=True) + load_from_full_model_state_dict( + model=model, + full_sd=state_dict, + device=fabric.device, + strict=strict, + cpu_offload=True, + ) else: state_dict = lazy_load(checkpoint_path) state_dict = state_dict.get("model", state_dict) @@ -398,6 +407,41 @@ def load_checkpoint_update( model.load_state_dict(state_dict, strict=strict) +def load_from_full_model_state_dict( + model: torch.nn.Module, + full_sd: Dict[str, Any], + device: torch.device, + strict: bool = False, + cpu_offload: bool = False, +): + from torch.distributed._tensor import distribute_tensor + + meta_sharded_sd = model.state_dict() + sharded_sd = {} + print(meta_sharded_sd.keys()) + for param_name, full_tensor in full_sd.items(): + if "norm" not in param_name and "wte" not in param_name and "ln_f" not in param_name: + param_name = param_name.replace(".weight", ".linear.weight") + param_name = param_name.replace(".bias", ".linear.bias") + else: + param_name = param_name + + print(param_name) + + sharded_meta_param = meta_sharded_sd.get(param_name) + full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device) + sharded_tensor = distribute_tensor( + full_tensor, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + if cpu_offload: + sharded_tensor = sharded_tensor.cpu() + sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + # choose `assign=True` since we cannot call `copy_` on meta tensor + return model.load_state_dict(sharded_sd, strict=strict, assign=True) + + def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int: flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation # this assumes that all samples have a fixed length equal to the block size diff --git a/tests/test_cli.py b/tests/test_cli.py index 6d1c1091be..1ddbf5588e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -16,7 +16,7 @@ def test_cli(): out = out.getvalue() assert "usage: litgpt" in out assert ( - "{download,chat,finetune,finetune_lora,finetune_full,finetune_adapter,finetune_adapter_v2," + "{download,chat,finetune,finetune_lora,finetune_lora_legacy,finetune_full,finetune_adapter,finetune_adapter_v2," "pretrain,generate,generate_full,generate_adapter,generate_adapter_v2,generate_sequentially," "generate_speculatively,generate_tp,convert_to_litgpt,convert_from_litgpt,convert_pretrained_checkpoint," "merge_lora,evaluate,serve}" in out diff --git a/tests/test_lora.py b/tests/test_lora.py index 7fa0a3f810..1585ea4449 100644 
--- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -14,6 +14,7 @@ from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo.backends import debugging +from torch.distributed.device_mesh import init_device_mesh from torch.nn import functional as F from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM @@ -956,3 +957,213 @@ def test_load_legacy_state_dict(): attention_2 = CausalSelfAttention(config=config, block_idx=0) attention_2.load_state_dict(state_dict) + + +@_RunIf(standalone=True, min_cuda_gpus=2) +def test_parallelize_fn(): + from litgpt.finetune.lora import parallelize_fn + + config = Config( + n_layer=2, + n_head=4, + n_embd=32, + block_size=8, + vocab_size=8, + lora_r=4, + lora_alpha=8, + lora_dropout=0.1, + lora_query=True, + lora_value=True, + lora_projection=True, + ) + + fabric = Fabric(devices=2, strategy="fsdp", precision="16-true") + fabric.launch() + + model = LoRAGPT(config) + mark_only_lora_as_trainable(model) + + # create device mesh for data parallel + device_mesh = init_device_mesh( + device_type=fabric.device.type, + mesh_shape=(2, 1), + mesh_dim_names=("data_parallel", "tensor_parallel"), + ) + + # test with activation checkpointing enabled (default) + parallelized_model = parallelize_fn(model, device_mesh, activation_checkpointing=True) + + # verify the model is still functional + assert parallelized_model is not None + assert isinstance(parallelized_model, LoRAGPT) + + parallelized_model = parallelized_model.to(fabric.device) + + # test forward pass to ensure the parallelized model works + x = torch.randint(0, config.padded_vocab_size, size=(1, config.block_size), dtype=torch.int64, device=fabric.device) + + # verify forward pass works + with torch.no_grad(): + output = parallelized_model(x) + assert output.shape == (1, config.block_size, config.padded_vocab_size) + + # test with activation checkpointing disabled + model_no_checkpoint = LoRAGPT(config) + mark_only_lora_as_trainable(model_no_checkpoint) + + parallelized_model_no_checkpoint = parallelize_fn(model_no_checkpoint, device_mesh, activation_checkpointing=False) + + # verify the model is still functional + assert parallelized_model_no_checkpoint is not None + assert isinstance(parallelized_model_no_checkpoint, LoRAGPT) + + # test forward pass to ensure the parallelized model works + parallelized_model_no_checkpoint = parallelized_model_no_checkpoint.to(fabric.device) + + with torch.no_grad(): + output = parallelized_model_no_checkpoint(x) + assert output.shape == (1, config.block_size, config.padded_vocab_size) + + # verify that all parameters are properly distributed (not on meta device) + for mod in parallelized_model.modules(): + for param_name, param in mod.named_parameters(): + if param.requires_grad: # Only check trainable parameters (LoRA parameters) + assert not param.is_meta, f"Parameter `{param_name}` should not be on meta device" + assert param.device.type == "cuda", f"Parameter `{param_name}` should be on CUDA device" + + +@_RunIf(standalone=True, min_cuda_gpus=2) +def test_load_from_full_model_state_dict(): + from litgpt.finetune.lora import parallelize_fn + from litgpt.utils import load_from_full_model_state_dict + + config = Config( + n_layer=2, + n_head=4, + n_embd=32, + block_size=8, + vocab_size=8, + lora_r=4, + lora_alpha=8, + lora_dropout=0.1, + lora_query=True, + lora_value=True, 
+ lora_projection=True, + lora_mlp=True, + lora_head=True, + ) + + # set up distributed environment with FSDP + fabric = Fabric(devices=2, strategy="fsdp", precision="16-true") + fabric.launch() + + # create a reference model to get the full state dict + reference_model = LoRAGPT(config) + mark_only_lora_as_trainable(reference_model) + + # initialize the reference model with some values + with torch.no_grad(): + for param in reference_model.parameters(): + if param.requires_grad: + param.fill_(0.1) + + # get the full state dict (simulating a checkpoint) + full_state_dict = {} + for name, param in reference_model.named_parameters(): + # Convert parameters to checkpoint format (what load_from_full_model_state_dict expects) + if "norm" not in name and "wte" not in name and "ln_f" not in name: + # For linear layers, remove .linear from the name to simulate checkpoint format + checkpoint_name = name.replace(".linear.weight", ".weight").replace(".linear.bias", ".bias") + else: + # For norm, embedding, and layer norm layers, keep the original name + checkpoint_name = name + full_state_dict[checkpoint_name] = param.detach().clone() + + # create distributed model + model = LoRAGPT(config) + mark_only_lora_as_trainable(model) + + # set up device mesh for distributed model + device_mesh = init_device_mesh( + device_type=fabric.device.type, + mesh_shape=(2, 1), + mesh_dim_names=("data_parallel", "tensor_parallel"), + ) + model = parallelize_fn(model, device_mesh, activation_checkpointing=False) + model = model.to(fabric.device) + + # test with default parameters (strict=False, cpu_offload=False) + result = load_from_full_model_state_dict( + model=model, + full_sd=full_state_dict, + device=fabric.device, + strict=False, + cpu_offload=False, + ) + + # verify that the function returns the missing/unexpected keys + assert hasattr(result, "missing_keys") + assert hasattr(result, "unexpected_keys") + + # verify that parameters are loaded correctly + for name, param in model.named_parameters(): + if param.requires_grad: + # Check that parameter is not on meta device + assert not param.is_meta, f"Parameter {name} should not be on meta device" + # Check that parameter is on the correct device + assert param.device.type == "cuda", f"Parameter {name} should be on CUDA device" + + # test with cpu_offload=True + model_cpu_offload = LoRAGPT(config) + mark_only_lora_as_trainable(model_cpu_offload) + model_cpu_offload = parallelize_fn(model_cpu_offload, device_mesh, activation_checkpointing=False) + model_cpu_offload = model_cpu_offload.to(fabric.device) + + result_cpu_offload = load_from_full_model_state_dict( + model=model_cpu_offload, + full_sd=full_state_dict, + device=fabric.device, + strict=False, + cpu_offload=True, + ) + + # verify that parameters are loaded correctly with CPU offload + for name, param in model_cpu_offload.named_parameters(): + if param.requires_grad: + # Check that parameter is not on meta device + assert not param.is_meta, f"Parameter {name} should not be on meta device" + # With cpu_offload, parameters might be on CPU + assert param.device.type in ["cpu", "cuda"], f"Parameter {name} should be on CPU or CUDA device" + + # test with strict=True + model_strict = LoRAGPT(config) + mark_only_lora_as_trainable(model_strict) + model_strict = parallelize_fn(model_strict, device_mesh, activation_checkpointing=False) + model_strict = model_strict.to(fabric.device) + + try: + result_strict = load_from_full_model_state_dict( + model=model_strict, + full_sd=full_state_dict, + device=fabric.device, + 
strict=True, + cpu_offload=False, + ) + # If strict loading succeeds, verify parameters + for name, param in model_strict.named_parameters(): + if param.requires_grad: + assert not param.is_meta, f"Parameter {name} should not be on meta device" + assert param.device.type == "cuda", f"Parameter {name} should be on CUDA device" + except RuntimeError as e: + # strict=True might fail if there are missing keys, which is expected behavior + assert "Missing key(s)" in str(e) or "Unexpected key(s)" in str(e) + + # test forward pass to ensure model still works after loading + x = torch.randint(0, config.padded_vocab_size, size=(1, config.block_size), dtype=torch.int64, device=fabric.device) + + with torch.no_grad(): + output = model(x) + assert output.shape == (1, config.block_size, config.padded_vocab_size) + + output_cpu_offload = model_cpu_offload(x) + assert output_cpu_offload.shape == (1, config.block_size, config.padded_vocab_size)
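
For context, a minimal sketch of how the pieces introduced above fit together outside the `litgpt finetune_lora` entry point: `ModelParallelStrategy` driven by the new `parallelize_fn` (FSDP2 `fully_shard` over the trainable LoRA linears and transformer blocks), followed by `load_from_full_model_state_dict` to distribute the full checkpoint onto the sharded parameters. This is an illustrative sketch only; the checkpoint path, LoRA hyperparameters, device count, and the final print are placeholders/assumptions, and training, evaluation, and checkpoint saving are omitted.

```python
# Illustrative sketch of the new multi-GPU LoRA wiring (assumes 2 CUDA GPUs and a
# converted LitGPT checkpoint under CHECKPOINT_DIR; paths and hyperparameters are placeholders).
from pathlib import Path

import lightning as L
import torch
from lightning.fabric.strategies import ModelParallelStrategy

from litgpt.finetune.lora import parallelize_fn
from litgpt.lora import GPT, Config, mark_only_lora_as_trainable
from litgpt.utils import load_from_full_model_state_dict

CHECKPOINT_DIR = Path("checkpoints/EleutherAI/pythia-160m")  # placeholder checkpoint directory


def main(fabric: L.Fabric) -> None:
    config = Config.from_file(
        CHECKPOINT_DIR / "model_config.yaml", lora_r=8, lora_alpha=16, lora_query=True, lora_value=True
    )
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
    mark_only_lora_as_trainable(model)  # parallelize_fn shards the linears whose weights require grad
    model = fabric.setup_module(model)  # applies parallelize_fn over the strategy's device mesh

    # Same path load_checkpoint() now takes for ModelParallelStrategy: distribute each full
    # tensor from the checkpoint onto the DTensor layout created by parallelize_fn.
    state_dict = torch.load(CHECKPOINT_DIR / "lit_model.pth", mmap=True)
    load_from_full_model_state_dict(model, state_dict, device=fabric.device, strict=False, cpu_offload=True)
    fabric.print("base weights loaded; only LoRA parameters are trainable")


if __name__ == "__main__":
    strategy = ModelParallelStrategy(
        parallelize_fn=parallelize_fn,
        data_parallel_size=2,  # devices * num_nodes, mirroring setup()
        tensor_parallel_size=1,
    )
    fabric = L.Fabric(devices=2, strategy=strategy, precision="bf16-true")
    fabric.launch(main)
```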