Merged
31 commits
7d6edab
test readme commands
rasbt Apr 17, 2024
edcb52c
add finetuning
rasbt Apr 17, 2024
e40c51e
Merge branch 'main' into readme-tests
rasbt Apr 17, 2024
84c3be7
add pretraining
rasbt Apr 17, 2024
7a06e62
update
rasbt Apr 17, 2024
880c264
Merge branch 'main' into readme-tests
rasbt Apr 22, 2024
6dc935f
fix path issue
rasbt Apr 22, 2024
3dc9c59
add serving tests
rasbt Apr 22, 2024
b51c2dd
Merge branch 'main' into readme-tests
rasbt Apr 22, 2024
c24f172
test
rasbt Apr 22, 2024
52fbb5f
Merge branch 'main' into readme-tests
rasbt Apr 24, 2024
ad936df
updates
rasbt Apr 24, 2024
832129f
Merge branch 'main' into readme-tests
rasbt Apr 24, 2024
3cfbeac
fixes
rasbt Apr 24, 2024
631c133
Merge branch 'main' into readme-tests
rasbt Apr 25, 2024
4283611
increase timeout for cpu tests
rasbt Apr 25, 2024
f08ae19
increase timeout for cpu tests
rasbt Apr 25, 2024
20e22f2
make validation cheaper
rasbt Apr 25, 2024
98e7557
accelerate pretrain
rasbt Apr 25, 2024
a712336
fix path for windows
rasbt Apr 25, 2024
fa692c8
cleanuo
rasbt Apr 25, 2024
5af2386
fix windows slash issue
rasbt Apr 25, 2024
ca65add
fix tests for windows
rasbt Apr 25, 2024
5c369c7
udpate
rasbt Apr 25, 2024
ebc4e7f
fix tests
rasbt Apr 25, 2024
834eee9
update
rasbt Apr 25, 2024
c1ddf36
remove windows and mac tests
rasbt Apr 25, 2024
8a68f20
exclude azure
rasbt Apr 25, 2024
a5165a9
Merge branch 'main' into readme-tests
rasbt Apr 26, 2024
f0e1781
Merge branch 'main' into readme-tests
rasbt May 2, 2024
53fac34
Merge branch 'main' into readme-tests
rasbt May 3, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ test = [
"pytest>=8.1.1",
"pytest-rerunfailures>=14.0",
"pytest-timeout>=2.3.1",
"pytest-dependency>=0.6.0",
"transformers>=4.38.0", # numerical comparisons
"einops>=0.7.0",
"protobuf>=4.23.4",
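The new pytest-dependency requirement backs the @pytest.mark.dependency markers used in tests/test_readme.py below: when a prerequisite test fails, its dependents are skipped instead of running against missing artifacts. A minimal sketch of the pattern, with hypothetical test names that are not part of this PR:

import pytest

@pytest.mark.dependency()
def test_download():
    assert True  # stand-in for a real prerequisite, e.g. downloading a checkpoint

@pytest.mark.dependency(depends=["test_download"])
def test_use_download():
    # Runs only if test_download passed; otherwise pytest-dependency skips it
    assert True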
170 changes: 170 additions & 0 deletions tests/test_readme.py
@@ -0,0 +1,170 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from pathlib import Path
import os
import pytest
import requests
import subprocess
import sys
import threading
import time


REPO_ID = Path("EleutherAI/pythia-14m")
CUSTOM_TEXTS_DIR = Path("custom_texts")


def run_command(command):
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
return result.stdout
except subprocess.CalledProcessError as e:
error_message = (
f"Command '{' '.join(command)}' failed with exit status {e.returncode}\n"
f"Output:\n{e.stdout}\n"
f"Error:\n{e.stderr}"
)
        # Surface the full command output in the CI logs, then fail the test
        print(error_message)
        raise RuntimeError(error_message) from None


@pytest.mark.skipif(
sys.platform.startswith("win") or
sys.platform == "darwin" or
'AGENT_NAME' in os.environ,
reason="Does not run on Windows, macOS, or Azure Pipelines"
)
@pytest.mark.dependency()
def test_download_model():
repo_id = str(REPO_ID).replace("\\", "/") # fix for Windows CI
command = ["litgpt", "download", "--repo_id", str(repo_id)]
output = run_command(command)

    checkpoint_dir = Path("checkpoints") / repo_id
    assert f"Saving converted checkpoint to {checkpoint_dir}" in output
    assert checkpoint_dir.exists()


@pytest.mark.dependency()
def test_download_books():
CUSTOM_TEXTS_DIR.mkdir(parents=True, exist_ok=True)

books = [
("https://www.gutenberg.org/cache/epub/24440/pg24440.txt", "book1.txt"),
("https://www.gutenberg.org/cache/epub/26393/pg26393.txt", "book2.txt")
]
for url, filename in books:
subprocess.run(["curl", url, "--output", str(CUSTOM_TEXTS_DIR / filename)], check=True)
# Verify each book is downloaded
assert (CUSTOM_TEXTS_DIR / filename).exists(), f"{filename} not downloaded"


@pytest.mark.dependency(depends=["test_download_model"])
def test_chat_with_model():
command = ["litgpt", "generate", "base", "--checkpoint_dir", f"checkpoints"/REPO_ID]
prompt = "What do Llamas eat?"
result = subprocess.run(command, input=prompt, text=True, capture_output=True, check=True)
assert "What food do llamas eat?" in result.stdout


@pytest.mark.dependency(depends=["test_download_model"])
@pytest.mark.timeout(300)
def test_finetune_model():

OUT_DIR = Path("out") / "lora"
DATASET_PATH = Path("custom_finetuning_dataset.json")
CHECKPOINT_DIR = "checkpoints" / REPO_ID

download_command = ["curl", "-L", "https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json", "-o", str(DATASET_PATH)]
subprocess.run(download_command, check=True)

assert DATASET_PATH.exists(), "Dataset file not downloaded"

finetune_command = [
"litgpt", "finetune", "lora",
"--checkpoint_dir", str(CHECKPOINT_DIR),
"--lora_r", "1",
"--data", "JSON",
"--data.json_path", str(DATASET_PATH),
"--data.val_split_fraction", "0.00001", # Keep small because new final validation is expensive
"--train.max_steps", "1",
"--out_dir", str(OUT_DIR)
]
run_command(finetune_command)

assert (OUT_DIR/"final").exists(), "Finetuning output directory was not created"
assert (OUT_DIR/"final"/"lit_model.pth").exists(), "Model file was not created"


@pytest.mark.dependency(depends=["test_download_model", "test_download_books"])
def test_pretrain_model():
OUT_DIR = Path("out") / "custom_pretrained"
pretrain_command = [
"litgpt", "pretrain",
"--model_name", "pythia-14m",
"--tokenizer_dir", str("checkpoints" / REPO_ID),
"--data", "TextFiles",
"--data.train_data_path", str(CUSTOM_TEXTS_DIR),
"--train.max_tokens", "100", # to accelerate things for CI
"--eval.max_iters", "1", # to accelerate things for CI
"--out_dir", str(OUT_DIR)
]
run_command(pretrain_command)

assert (OUT_DIR / "final").exists(), "Pretraining output directory was not created"
assert (OUT_DIR / "final" / "lit_model.pth").exists(), "Model file was not created"


@pytest.mark.dependency(depends=["test_download_model", "test_download_books"])
def test_continue_pretrain_model():
OUT_DIR = Path("out") / "custom_continue_pretrained"
pretrain_command = [
"litgpt", "pretrain",
"--model_name", "pythia-14m",
"--initial_checkpoint", str("checkpoints" / REPO_ID),
"--tokenizer_dir", str("checkpoints" / REPO_ID),
"--data", "TextFiles",
"--data.train_data_path", str(CUSTOM_TEXTS_DIR),
"--train.max_tokens", "100", # to accelerate things for CI
"--eval.max_iters", "1", # to accelerate things for CI
"--out_dir", str(OUT_DIR)
]
run_command(pretrain_command)

assert (OUT_DIR / "final").exists(), "Continued pretraining output directory was not created"
assert (OUT_DIR / "final" / "lit_model.pth").exists(), "Model file was not created"


@pytest.mark.dependency(depends=["test_download_model"])
def test_serve():
CHECKPOINT_DIR = str("checkpoints" / REPO_ID)
    # A distinct name avoids shadowing the module-level run_command() helper
    serve_command = [
        "litgpt", "serve",
        "--checkpoint_dir", CHECKPOINT_DIR
    ]

process = None

def run_server():
nonlocal process
try:
            process = subprocess.Popen(serve_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
stdout, stderr = process.communicate(timeout=60)
except subprocess.TimeoutExpired:
print('Server start-up timeout expired')

server_thread = threading.Thread(target=run_server)
server_thread.start()

# Allow time to initialize and start serving
time.sleep(30)

try:
response = requests.get("http://127.0.0.1:8000")
print(response.status_code)
assert response.status_code == 200, "Server did not respond as expected."
finally:
if process:
process.kill()
server_thread.join()
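The tests are meant to run in file order, since the later ones reuse the downloaded checkpoint and books. A minimal sketch for invoking just this suite programmatically, assuming it is run from the repository root (not part of this PR):

import pytest

# Run only the README tests; pytest-dependency skips dependents of failed prerequisites
raise SystemExit(pytest.main(["tests/test_readme.py", "-v"]))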