
Commit 4d0dd74

Add unit tests for all modules (#138)
* test: Add unit tests for LighterSystem functionality and behavior
* test: Add comprehensive unit tests for all modules in the lighter package
* fix: Import torch in test files to resolve undefined name errors
* fix: Resolve optimizer empty parameter list and update test assertions
* fix: Import Path from pathlib to resolve undefined name errors in tests
* fix: Ensure layers are frozen correctly in LighterFreezer test and fix writer path error
* test: Remove unnecessary frozen state assignment in freezer test
* test: Include dummy Trainer in freezer and system tests for validation
* fix: Import missing classes in test_freezer and test_system files
* fix: Import missing Dataset, DataLoader, and Accuracy in test files
* chore: Update imports in test files for consistency and clarity
* fix: Resolve TypeError in trainer setup and import DataLoader correctly
* feat: Wrap dummy models into a minimal PyTorch Lightning system
* fix: Correct test case for replacing layer with identity in unit tests
* chore: Import Module from torch.nn in test_freezer.py
* test: Fix NameError and improve tests for LighterFreezer functionality
* refactor: Simplify dataset return structure in test_freezer.py
* refactor: Replace DummySystem with LighterSystem and add model tests
* feat: Update DummyModel architecture and enhance dummy_system setup
* test: Ensure optimizer is correctly set up in freezer exception test
* test: Update assertion to check for grad_fn in freezer test case
* test: Remove redundant optimizer setup assertions from test_freezer.py
* fix: Correct freezing logic in freezer tests to match expected behavior
* Fixes, reorganize
* test: Remove skip markers from training and validation step tests
* fix: Attach DummySystem to Trainer and resolve batch size mismatches
* fix: Ensure DummySystem is attached to Trainer in test_valid_batch_formats
* test: Refactor training, validation, and prediction steps in tests
* feat: Add predict dataset to DummySystem and update corresponding tests
* test: Update unit tests to use training_step with mock logging
* test: Add unit tests for empty datasets, model forward pass, and system initialization
* Fix tests
* Update tests/unit/test_callbacks_writer_file.py (Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>)
* Update tests/unit/test_callbacks_writer_table.py (Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>)
* Update tests/unit/test_callbacks_writer_file.py (Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>)
* Update tests/unit/test_utils_collate.py (Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>)
* Update tests/unit/test_callbacks_writer_base.py (Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>)
* test: Fix assertion in table writer test for correct key usage
* test: Add comprehensive tests for LighterTableWriter functionality
* test: Replace mocked Trainer with a real Trainer instance in tests
* test: Update Trainer instantiation in table writer tests for simplicity
* test: Update test_table_writer_write to validate all CSV entries with pandas
* test: Update unit test for LighterTableWriter with expected records
* test: Fix DataFrame ID type comparison in table writer tests
* test: Update test cases for LighterTableWriter initialization and validation
* fix: Mock world_size and is_global_zero methods in tests
* test: Update test directory name and improve error handling in tests
* test: Fix assertion and error handling in file writer tests
* test: Fix test assertions and mock properties in writer tests
* fix: Mock world_size and is_global_zero methods in tests
* Improve tests provided by coderabbit

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent c332a2b commit 4d0dd74

17 files changed: +562 -2 lines changed

.gitignore

+2

```diff
@@ -152,3 +152,5 @@ projects/*
 **/predictions/
 */.DS_Store
 .DS_Store
+.aider*
+test_dir/
```

lighter/callbacks/writer/table.py

+5 -1

```diff
@@ -60,7 +60,11 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
         # Save the records to a CSV file
         if trainer.is_global_zero:
             df = pd.DataFrame(self.csv_records)
-            df = df.sort_values("id").set_index("id")
+            try:
+                df = df.sort_values("id")
+            except TypeError:
+                pass
+            df = df.set_index("id")
             df.to_csv(self.path)
 
             # Clear the records after saving
```
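The try/except here tolerates mixed-type id columns (for example int and str ids in the same run, as exercised by the table-writer test further down), where pandas cannot compare values while sorting. A minimal standalone sketch of that failure mode, assuming only pandas:

```python
import pandas as pd

# Mixed int and str ids, as in the table-writer test below (1, "some_id", 1331).
df = pd.DataFrame({"id": [1, "some_id", 1331], "pred": [[1, 2, 3], -1, [1.5, 2.5]]})

try:
    df = df.sort_values("id")  # raises TypeError: '<' not supported between 'str' and 'int'
except TypeError:
    pass  # keep insertion order instead, mirroring the patched writer

print(df.set_index("id"))
```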

lighter/utils/runner.py

+1 -1

```diff
@@ -36,7 +36,7 @@ def parse_config(**kwargs) -> ConfigParser:
         raise ValueError("'--config' not specified. Please provide a valid configuration file.")
 
     # Initialize the parser with the predefined structure.
-    parser = ConfigParser(ConfigSchema().dict(), globals=False)
+    parser = ConfigParser(ConfigSchema().model_dump(), globals=False)
     # Update the parser with the configuration file.
     parser.update(parser.load_config_files(kwargs.pop("config")))
    # Update the parser with the provided cli arguments.
```
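For context, .dict() is the Pydantic v1 API; Pydantic v2 renames it to .model_dump(), and calling .dict() on a v2 model still works but emits a deprecation warning. A minimal sketch with a hypothetical schema, not the actual lighter ConfigSchema:

```python
from pydantic import BaseModel


class ConfigSchema(BaseModel):
    # Hypothetical fields for illustration only.
    trainer: dict = {}
    system: dict = {}


# Pydantic v2 replacement for .dict(); prints {'trainer': {}, 'system': {}}
print(ConfigSchema().model_dump())
```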

tests/unit/test_callbacks_freezer.py

+65

```python
import pytest
import torch
from pytorch_lightning import Trainer
from torch.nn import Module
from torch.utils.data import Dataset

from lighter.callbacks.freezer import LighterFreezer
from lighter.system import LighterSystem


class DummyModel(Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = torch.nn.Linear(10, 4)
        self.layer3 = torch.nn.Linear(4, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


class DummyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {"input": torch.randn(10), "target": torch.tensor(0)}


@pytest.fixture
def dummy_system():
    model = DummyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    dataset = DummyDataset()
    criterion = torch.nn.CrossEntropyLoss()
    return LighterSystem(model=model, batch_size=32, criterion=criterion, optimizer=optimizer, datasets={"train": dataset})


def test_freezer_initialization():
    freezer = LighterFreezer(names=["layer1"])
    assert freezer.names == ["layer1"]


def test_freezer_functionality(dummy_system):
    freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"])
    trainer = Trainer(callbacks=[freezer], max_epochs=1)
    trainer.fit(dummy_system)
    assert not dummy_system.model.layer1.weight.requires_grad
    assert not dummy_system.model.layer1.bias.requires_grad
    assert dummy_system.model.layer2.weight.requires_grad


def test_freezer_with_exceptions(dummy_system):
    freezer = LighterFreezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"])
    trainer = Trainer(callbacks=[freezer], max_epochs=1)
    trainer.fit(dummy_system)
    assert not dummy_system.model.layer1.weight.requires_grad
    assert not dummy_system.model.layer1.bias.requires_grad
    assert dummy_system.model.layer2.weight.requires_grad
    assert dummy_system.model.layer2.bias.requires_grad
    assert not dummy_system.model.layer3.weight.requires_grad
    assert not dummy_system.model.layer3.bias.requires_grad
```
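These assertions only check requires_grad flags, so the behavior being tested can be illustrated with a few lines of plain PyTorch. The helper below is an illustrative sketch of name-based freezing, not the actual LighterFreezer implementation:

```python
import torch
from torch import nn


def freeze_by_name(model: nn.Module, names=(), name_starts_with=(), except_names=()):
    """Illustrative helper: clear requires_grad on parameters selected by name."""
    for name, param in model.named_parameters():
        if name in except_names:
            continue
        if name in names or any(name.startswith(prefix) for prefix in name_starts_with):
            param.requires_grad = False


model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 4), nn.Linear(4, 1))
freeze_by_name(model, name_starts_with=("0.", "2."))
assert not model[0].weight.requires_grad  # frozen
assert model[1].weight.requires_grad      # untouched
assert not model[2].weight.requires_grad  # frozen
```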

tests/unit/test_callbacks_utils.py

+19

```python
import torch

from lighter.callbacks.utils import preprocess_image


def test_preprocess_image_2d():
    image = torch.rand(1, 3, 64, 64)  # Batch of 2D images
    processed_image = preprocess_image(image)
    assert processed_image.shape == (3, 64, 64)


def test_preprocess_image_3d():
    batch_size = 8
    depth = 20
    height = 64
    width = 64
    image = torch.rand(batch_size, 1, depth, height, width)  # Batch of 3D images
    processed_image = preprocess_image(image)
    assert processed_image.shape == (1, depth * height, batch_size * width)
```
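The 3D assertion implies the volume batch is tiled into a single 2D image: depth slices stacked vertically, batch items placed side by side. A rough sketch of one way to produce that layout (not necessarily how preprocess_image does it internally):

```python
import torch

batch_size, channels, depth, height, width = 8, 1, 20, 64, 64
image = torch.rand(batch_size, channels, depth, height, width)

# (B, C, D, H, W) -> (C, D, H, B, W) -> (C, D*H, B*W):
# rows are depth slices stacked vertically, columns are batch items side by side.
grid = image.permute(1, 2, 3, 0, 4).reshape(channels, depth * height, batch_size * width)
assert grid.shape == (1, depth * height, batch_size * width)
```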
tests/unit/test_callbacks_writer_base.py

+8

```python
import pytest

from lighter.callbacks.writer.base import LighterBaseWriter


def test_writer_initialization():
    with pytest.raises(TypeError):
        LighterBaseWriter(path="test", writer="tensor")
```
tests/unit/test_callbacks_writer_file.py

+38

```python
import shutil
from pathlib import Path

import torch

from lighter.callbacks.writer.file import LighterFileWriter


def test_file_writer_initialization():
    """Test LighterFileWriter initialization with proper attributes."""
    path = Path("test_dir")
    path.mkdir(exist_ok=True)  # Ensure the directory exists
    try:
        writer = LighterFileWriter(path=path, writer="tensor")
        assert writer.path == Path("test_dir")
        assert writer.writer.__name__ == "write_tensor"  # Verify writer function
    finally:
        shutil.rmtree(path)  # Clean up after test


def test_file_writer_write_tensor():
    """Test LighterFileWriter's ability to write and persist tensors correctly."""
    test_dir = Path("test_dir")
    test_dir.mkdir(exist_ok=True)
    try:
        writer = LighterFileWriter(path=test_dir, writer="tensor")
        tensor = torch.tensor([1, 2, 3])
        writer.write(tensor, id=1)

        # Verify file exists
        saved_path = writer.path / "1.pt"
        assert saved_path.exists()

        # Verify tensor contents
        loaded_tensor = torch.load(saved_path)
        assert torch.equal(loaded_tensor, tensor)
    finally:
        shutil.rmtree(test_dir)
```
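As a side note, pytest's built-in tmp_path fixture could replace the manual mkdir/rmtree bookkeeping in these two tests. A sketch of the write test using it, assuming the same LighterFileWriter API shown above:

```python
import torch

from lighter.callbacks.writer.file import LighterFileWriter


def test_file_writer_write_tensor_tmp_path(tmp_path):
    # tmp_path is a unique temporary directory managed by pytest; no manual cleanup needed.
    writer = LighterFileWriter(path=tmp_path, writer="tensor")
    writer.write(torch.tensor([1, 2, 3]), id=1)

    loaded = torch.load(tmp_path / "1.pt")
    assert torch.equal(loaded, torch.tensor([1, 2, 3]))
```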
tests/unit/test_callbacks_writer_table.py

+109

```python
from pathlib import Path
from unittest import mock

import pandas as pd
import pytest
import torch
from pytorch_lightning import Trainer

from lighter.callbacks.writer.table import LighterTableWriter
from lighter.system import LighterSystem


def custom_writer(tensor):
    return {"custom": tensor.sum().item()}


def test_table_writer_initialization():
    writer = LighterTableWriter(path="test.csv", writer="tensor")
    assert writer.path == Path("test.csv")


def test_table_writer_custom_writer():
    writer = LighterTableWriter(path="test.csv", writer=custom_writer)
    test_tensor = torch.tensor([1, 2, 3])
    writer.write(tensor=test_tensor, id=1)
    assert writer.csv_records[0]["pred"] == {"custom": 6}


def test_table_writer_write():
    """Test LighterTableWriter write functionality with various inputs."""
    test_file = Path("test.csv")
    writer = LighterTableWriter(path="test.csv", writer="tensor")

    expected_records = [
        {"id": 1, "pred": [1, 2, 3]},
        {"id": "some_id", "pred": -1},
        {"id": 1331, "pred": [1.5, 2.5]},
    ]
    # Test basic write
    writer.write(tensor=torch.tensor(expected_records[0]["pred"]), id=expected_records[0]["id"])
    assert len(writer.csv_records) == 1
    assert writer.csv_records[0]["pred"] == expected_records[0]["pred"]
    assert writer.csv_records[0]["id"] == expected_records[0]["id"]

    # Test edge cases
    writer.write(tensor=torch.tensor(expected_records[1]["pred"]), id=expected_records[1]["id"])
    writer.write(tensor=torch.tensor(expected_records[2]["pred"]), id=expected_records[2]["id"])
    trainer = Trainer(max_epochs=1)
    writer.on_predict_epoch_end(trainer, mock.Mock())

    # Verify file creation and content
    assert test_file.exists()
    df = pd.read_csv(test_file)
    df["id"] = df["id"].astype(str)
    df["pred"] = df["pred"].apply(eval)

    for record in expected_records:
        row = df[df["id"] == str(record["id"])]
        assert not row.empty
        pred_value = row["pred"].iloc[0]  # get the value from the Series
        assert pred_value == record["pred"]

    # Cleanup
    test_file.unlink()


def test_table_writer_write_multi_process(tmp_path, monkeypatch):
    test_file = tmp_path / "test.csv"
    writer = LighterTableWriter(path=test_file, writer="tensor")
    trainer = Trainer(max_epochs=1)

    # Expected records after gathering from all processes
    rank0_records = [{"id": 1, "pred": [1, 2, 3]}]  # records from rank 0
    rank1_records = [{"id": 2, "pred": [4, 5, 6]}]  # records from rank 1
    expected_records = rank0_records + rank1_records

    # Mock distributed functions for multi-process simulation
    def mock_gather(obj, gather_list, dst=0):
        if gather_list is not None:
            # Fill gather_list with records from each rank
            gather_list[0] = rank0_records
            gather_list[1] = rank1_records

    def mock_get_rank():
        return 0

    monkeypatch.setattr(torch.distributed, "gather_object", mock_gather)
    monkeypatch.setattr(torch.distributed, "get_rank", mock_get_rank)
    monkeypatch.setattr(trainer.strategy, "world_size", 2)

    writer.on_predict_epoch_end(trainer, mock.Mock())

    # Verify file creation
    assert test_file.exists()

    # Verify file content
    df = pd.read_csv(test_file)
    df["id"] = df["id"].astype(str)
    df["pred"] = df["pred"].apply(eval)

    # Check that all expected records are in the CSV
    for record in expected_records:
        row = df[df["id"] == str(record["id"])]
        assert not row.empty
        pred_value = row["pred"].iloc[0]
        assert pred_value == record["pred"]

    # Verify total number of records
    assert len(df) == len(expected_records)
```
