from pathlib import Path
from unittest import mock

import pandas as pd
import torch
from pytorch_lightning import Trainer

from lighter.callbacks.writer.table import LighterTableWriter


def custom_writer(tensor):
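    """Callable writer used by the tests below; reduces a tensor to its scalar sum."""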
| 14 | + return {"custom": tensor.sum().item()} |
| 15 | + |
| 16 | + |
| 17 | +def test_table_writer_initialization(): |
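    """The given path should be stored on the writer as a pathlib.Path."""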
    writer = LighterTableWriter(path="test.csv", writer="tensor")
    assert writer.path == Path("test.csv")


def test_table_writer_custom_writer():
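    """A callable writer is applied to the tensor and its return value is stored as the record's pred value."""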
    writer = LighterTableWriter(path="test.csv", writer=custom_writer)
    test_tensor = torch.tensor([1, 2, 3])
    writer.write(tensor=test_tensor, id=1)
    assert writer.csv_records[0]["pred"] == {"custom": 6}


def test_table_writer_write():
    """Test LighterTableWriter write functionality with various inputs."""
    test_file = Path("test.csv")
    writer = LighterTableWriter(path="test.csv", writer="tensor")

    expected_records = [
        {"id": 1, "pred": [1, 2, 3]},
        {"id": "some_id", "pred": -1},
        {"id": 1331, "pred": [1.5, 2.5]},
    ]
    # Test basic write
    writer.write(tensor=torch.tensor(expected_records[0]["pred"]), id=expected_records[0]["id"])
    assert len(writer.csv_records) == 1
    assert writer.csv_records[0]["pred"] == expected_records[0]["pred"]
    assert writer.csv_records[0]["id"] == expected_records[0]["id"]

    # Test edge cases
    writer.write(tensor=torch.tensor(expected_records[1]["pred"]), id=expected_records[1]["id"])
    writer.write(tensor=torch.tensor(expected_records[2]["pred"]), id=expected_records[2]["id"])
    trainer = Trainer(max_epochs=1)
    writer.on_predict_epoch_end(trainer, mock.Mock())

    # Verify file creation and content
    assert test_file.exists()
    df = pd.read_csv(test_file)
    df["id"] = df["id"].astype(str)
    df["pred"] = df["pred"].apply(eval)

    for record in expected_records:
        row = df[df["id"] == str(record["id"])]
        assert not row.empty
        pred_value = row["pred"].iloc[0]  # get the value from the Series
        assert pred_value == record["pred"]

    # Cleanup
    test_file.unlink()


def test_table_writer_write_multi_process(tmp_path, monkeypatch):
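    """Simulate a two-rank run by mocking torch.distributed and check that records from all ranks end up in the CSV."""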
    test_file = tmp_path / "test.csv"
    writer = LighterTableWriter(path=test_file, writer="tensor")
    trainer = Trainer(max_epochs=1)

    # Expected records after gathering from all processes
    rank0_records = [{"id": 1, "pred": [1, 2, 3]}]  # records from rank 0
    rank1_records = [{"id": 2, "pred": [4, 5, 6]}]  # records from rank 1
    expected_records = rank0_records + rank1_records

    # Mock distributed functions for multi-process simulation
    def mock_gather(obj, gather_list, dst=0):
        if gather_list is not None:
            # Fill gather_list with records from each rank
            gather_list[0] = rank0_records
            gather_list[1] = rank1_records

    def mock_get_rank():
        return 0

    monkeypatch.setattr(torch.distributed, "gather_object", mock_gather)
    monkeypatch.setattr(torch.distributed, "get_rank", mock_get_rank)
    # world_size is a read-only property on the strategy instance, so patch it on the class
    monkeypatch.setattr(type(trainer.strategy), "world_size", 2)

    writer.on_predict_epoch_end(trainer, mock.Mock())

    # Verify file creation
    assert test_file.exists()

    # Verify file content
    df = pd.read_csv(test_file)
    df["id"] = df["id"].astype(str)
    df["pred"] = df["pred"].apply(eval)

    # Check that all expected records are in the CSV
    for record in expected_records:
        row = df[df["id"] == str(record["id"])]
        assert not row.empty
        pred_value = row["pred"].iloc[0]
        assert pred_value == record["pred"]

    # Verify total number of records
    assert len(df) == len(expected_records)