Merged
Changes from all commits
33 commits
0f0f0eb
enhanced lora
ysjprojects Jul 3, 2025
b0a804a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2025
4798b75
moved load_from_full_model_state_dict to utils for better abstraction…
ysjprojects Jul 6, 2025
d0eeb31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2025
e59911f
torch.distributed.barrier -> fabric.barrier
ysjprojects Jul 6, 2025
5468601
Merge branch 'finetune_lora_upgrade' of https://github.com/ysjproject…
ysjprojects Jul 6, 2025
a321f64
fix: rm passing lora_params to merge_lora
ysjprojects Jul 7, 2025
11ef960
Merge branch 'main' into finetune_lora_upgrade
ysjprojects Jul 8, 2025
20a0fea
test cases
ysjprojects Jul 10, 2025
be3fcfc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 10, 2025
f04c4df
Merge branch 'main' into finetune_lora_upgrade
ysjprojects Jul 10, 2025
687dd18
Merge branch 'main' into finetune_lora_upgrade
Borda Jul 15, 2025
e27a064
Merge branch 'main' into finetune_lora_upgrade
Borda Jul 25, 2025
34a6690
Merge branch 'main' into finetune_lora_upgrade
ysjprojects Aug 13, 2025
9234f87
Merge branch 'main' into finetune_lora_upgrade
ysjprojects Aug 13, 2025
9b712fc
fix var shadowing linting err in test_lora.py
ysjprojects Aug 13, 2025
347e953
Merge branch 'main' into finetune_lora_upgrade
ysjprojects Aug 13, 2025
22bf472
Merge branch 'main' into finetune_lora_upgrade
Borda Aug 13, 2025
0a14ad3
Consistency with standard PyTorch checkpoint formats for save_lora_ch…
ysjprojects Aug 13, 2025
fd54bb9
Merge branch 'finetune_lora_upgrade' of github.com:ysjprojects/litgpt…
ysjprojects Aug 13, 2025
041f1a3
fixes to test functions for new lora impl
ysjprojects Aug 13, 2025
6fab829
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2025
f52eb8b
test lora fixes
ysjprojects Aug 13, 2025
96a25d5
Merge branch 'finetune_lora_upgrade' of github.com:ysjprojects/litgpt…
ysjprojects Aug 13, 2025
d91d8a5
fix allclose error in test_lora
ysjprojects Aug 13, 2025
cf2eea9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2025
3ce97f2
fix linting error: no bare except
ysjprojects Aug 13, 2025
1e54923
Merge branch 'finetune_lora_upgrade' of github.com:ysjprojects/litgpt…
ysjprojects Aug 13, 2025
e6a2099
test_lora fixes
ysjprojects Aug 13, 2025
38bc350
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2025
85aa460
full state dict fix test_lora
ysjprojects Aug 13, 2025
1526a61
load_model_from_full_state_dict fix: added bias conversion
ysjprojects Aug 13, 2025
c6037be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2025
2 changes: 2 additions & 0 deletions litgpt/__main__.py
@@ -12,6 +12,7 @@
from litgpt.finetune.adapter_v2 import setup as finetune_adapter_v2_fn
from litgpt.finetune.full import setup as finetune_full_fn
from litgpt.finetune.lora import setup as finetune_lora_fn
from litgpt.finetune.lora_legacy import setup as finetune_lora_legacy_fn
from litgpt.generate.adapter import main as generate_adapter_fn
from litgpt.generate.adapter_v2 import main as generate_adapter_v2_fn
from litgpt.generate.base import main as generate_base_fn
@@ -35,6 +36,7 @@ def main() -> None:
"chat": chat_fn,
"finetune": finetune_lora_fn,
"finetune_lora": finetune_lora_fn,
"finetune_lora_legacy": finetune_lora_legacy_fn,
"finetune_full": finetune_full_fn,
"finetune_adapter": finetune_adapter_fn,
"finetune_adapter_v2": finetune_adapter_v2_fn,
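Note: the previous LoRA recipe remains reachable as `finetune_lora_legacy`, while `finetune` and `finetune_lora` now resolve to the reworked implementation. A minimal sketch of that dispatch, assuming both setup functions keep compatible signatures (not verified by this diff):

```python
# Sketch only: mirrors the command table in litgpt/__main__.py above.
from litgpt.finetune.lora import setup as finetune_lora
from litgpt.finetune.lora_legacy import setup as finetune_lora_legacy


def run_finetune(use_legacy: bool = False, **kwargs):
    # `litgpt finetune_lora ...` resolves to the new recipe,
    # `litgpt finetune_lora_legacy ...` to the previous one.
    setup_fn = finetune_lora_legacy if use_legacy else finetune_lora
    return setup_fn(**kwargs)
```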
2 changes: 2 additions & 0 deletions litgpt/args.py
@@ -28,6 +28,8 @@ class TrainArgs:
"""Total number of tokens to train on"""
max_steps: Optional[int] = None
"""Limits the number of optimizer steps to run"""
max_time: Optional[float] = None
"""Limits the number of seconds to train for"""
max_seq_length: Optional[int] = None
"""Limits the length of samples"""
tie_embeddings: Optional[bool] = None
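Note: the new `max_time` field is a wall-clock budget in seconds; `fit()` in `litgpt/finetune/lora.py` (below) checks it alongside the epoch limit. A small sketch of setting it, with illustrative values:

```python
from litgpt.args import TrainArgs

train = TrainArgs(
    epochs=5,
    lr_warmup_steps=100,
    max_time=2 * 60 * 60,  # stop after ~2 hours of training, even mid-epoch
)
```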
79 changes: 66 additions & 13 deletions litgpt/finetune/lora.py
@@ -11,7 +11,7 @@
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.fabric.utilities import ThroughputMonitor
from lightning_utilities.core.imports import RequirementCache
from torch.utils.data import ConcatDataset, DataLoader
@@ -20,7 +20,7 @@
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable
from litgpt.prompts import save_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
@@ -70,6 +70,7 @@ def setup(
lr_warmup_steps=100,
epochs=5,
max_seq_length=None,
max_time=None,
),
log: LogArgs = LogArgs(),
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
@@ -105,6 +106,7 @@ def setup(
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""

checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
@@ -152,12 +154,10 @@ def setup(
"Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
" when using the --quantize flag."
)
strategy = FSDPStrategy(
auto_wrap_policy={torch.nn.Linear},
activation_checkpointing_policy={Block},
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
strategy = ModelParallelStrategy(
parallelize_fn=parallelize_fn,
data_parallel_size=devices * num_nodes,
tensor_parallel_size=1,
)
else:
strategy = "auto"
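Note: for multi-device runs the recipe now builds a `ModelParallelStrategy` instead of `FSDPStrategy`. The strategy calls `parallelize_fn(model, device_mesh)` on every rank (that function is added further down in this file), and with `tensor_parallel_size=1` the device mesh reduces to pure data parallelism over `devices * num_nodes` ranks. A hedged, standalone sketch of the same wiring, with an illustrative device count:

```python
import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy

from litgpt.finetune.lora import parallelize_fn  # added by this PR

strategy = ModelParallelStrategy(
    parallelize_fn=parallelize_fn,  # applies FSDP2-style fully_shard per rank
    data_parallel_size=2,           # e.g. devices=2, num_nodes=1
    tensor_parallel_size=1,
)
fabric = L.Fabric(devices=2, strategy=strategy, precision="bf16-true")
```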
@@ -174,7 +174,9 @@
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)

fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
fabric.launch(
main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes, precision
)


def main(
@@ -189,6 +191,7 @@ def main(
eval: EvalArgs,
optimizer: Union[str, Dict],
num_nodes: int = 1,
precision: Optional[str] = None,
) -> None:
validate_args(train, eval)

@@ -229,7 +232,6 @@ def main(
optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)

# strict=False because missing keys due to LoRA weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)

train_time = time.perf_counter()
@@ -264,12 +266,19 @@ def main(
save_path = out_dir / "final" / "lit_model.pth.lora"
save_path.parent.mkdir(parents=True, exist_ok=True)
save_lora_checkpoint(fabric, model, save_path)

fabric.barrier()
if fabric.global_rank == 0:
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)
save_prompt_style(data.prompt_style, save_path.parent)
merge_lora(checkpoint_dir=save_path.parent)
merge_lora(
checkpoint_dir=save_path.parent,
pretrained_checkpoint_dir=checkpoint_dir,
precision=precision,
)
fabric.barrier()


def fit(
@@ -316,6 +325,8 @@ def fit(
total_lengths = 0
total_t0 = time.perf_counter()

max_time = train.max_time or float("inf")

token_counts = {
"raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
@@ -327,6 +338,12 @@ def fit(
iter_t0 = time.perf_counter()
batch = next(train_iterator)
if train_iterator.epoch >= train.epochs:
generate_example(fabric, model, tokenizer, eval, data)
fabric.print(f"Number of epochs {train.epochs} reached, stopping training...")
break
if iter_t0 - total_t0 > max_time:
generate_example(fabric, model, tokenizer, eval, data)
fabric.print(f"Max time ({max_time / 60.0:.2f}m) reached, stopping training...")
break
input_ids, targets = batch["input_ids"], batch["labels"]

@@ -497,9 +514,45 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
return longest_seq_length, longest_seq_ix


def parallelize_fn(model, device_mesh, activation_checkpointing=True):
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, checkpoint_wrapper

if activation_checkpointing:
model.transformer.h = torch.nn.ModuleList(
[checkpoint_wrapper(el, preserve_rng_state=False) for el in model.transformer.h]
)

dp_mesh = device_mesh["data_parallel"]

for m in reversed(list(model.modules())):
if (
(isinstance(m, torch.nn.Linear) and m.weight.requires_grad)
or isinstance(m, CheckpointWrapper)
or isinstance(m, Block)
):
fully_shard(m, mesh=dp_mesh)

fully_shard(model, mesh=dp_mesh)

return model
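
Note: inside `parallelize_fn`, the `m.weight.requires_grad` check relies on `mark_only_lora_as_trainable(model)` having already frozen the base weights, so only the LoRA linears get their own `fully_shard` group before the blocks and the full model are sharded. A small sketch of that precondition; the model name and LoRA flags are illustrative, and assume the existing `litgpt.lora.Config` field names:

```python
from litgpt.lora import GPT, Config, mark_only_lora_as_trainable

config = Config.from_name("pythia-14m", lora_r=8, lora_alpha=16, lora_query=True)
model = GPT(config)
mark_only_lora_as_trainable(model)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable[:4])  # only lora_A / lora_B parameters remain trainable
```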


def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
fabric.print(f"Saving LoRA weights to {str(file_path)!r}")
fabric.save(file_path, {"model": model}, filter={"model": lora_filter})
cpu_state_dict = {}
sharded_sd = model.state_dict()
for param_name, param in sharded_sd.items():
if "lora_" not in param_name:
continue
if param.is_cpu:
param = param.to(fabric.device)
if hasattr(param, "_local_tensor"):
param = param.full_tensor()
if fabric.is_global_zero:
cpu_state_dict[param_name] = param.cpu()
fabric.barrier()
if fabric.is_global_zero:
torch.save({"model": cpu_state_dict}, file_path)


def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
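Note: the rewritten `save_lora_checkpoint` gathers sharded (DTensor) LoRA parameters into full tensors and writes a plain `torch.save` payload from rank zero, instead of going through `fabric.save` with a `lora_filter`. A hedged sketch of what the resulting file looks like to a consumer, with an illustrative path:

```python
import torch

# The checkpoint is a standard PyTorch file: {"model": {<"lora_" param name>: full tensor, ...}}
ckpt = torch.load("out/finetune/lora/final/lit_model.pth.lora", map_location="cpu")
lora_state = ckpt["model"]
print(f"{len(lora_state)} LoRA tensors, e.g. {next(iter(lora_state))}")

# These keys follow the litgpt.lora.GPT naming, so they can be applied with
# model.load_state_dict(lora_state, strict=False) before running merge_lora.
```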