Commit 5e94eb8

Fix TableWriter not using ids

1 parent fc6fc7a · commit 5e94eb8
1 file changed: +10 −13

lighter/callbacks/writer/table.py (+10 −13)
@@ -19,17 +19,17 @@ class LighterTableWriter(LighterBaseWriter):
         path (Path): CSV filepath.
         writer (Union[str, Callable]): Name of the writer function registered in `self.writers` or a custom writer function.
             Available writers: "tensor". A custom writer function must take a single argument: `tensor`, and return the record
-            to be saved in the CSV file. The tensor will be a single tensor without the batch dimension.
+            to be saved in the CSV file under 'pred' column. The tensor will be a single tensor without the batch dimension.
     """
 
     def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None:
         super().__init__(path, writer)
-        self.csv_records = {}
+        self.csv_records = []
 
     @property
     def writers(self) -> Dict[str, Callable]:
         return {
-            "tensor": lambda tensor: tensor.tolist(),
+            "tensor": lambda tensor: tensor.item() if tensor.numel() == 1 else tensor.tolist(),
         }
 
     def write(self, tensor: Any, id: Union[int, str]) -> None:
@@ -40,9 +40,7 @@ def write(self, tensor: Any, id: Union[int, str]) -> None:
             tensor (Any): Tensor, without the batch dimension, to be recorded.
             id (Union[int, str]): Identifier, used as the key for the record.
         """
-        column = "pred"
-        record = self.writer(tensor)
-        self.csv_records.setdefault(id, {})[column] = record
+        self.csv_records.append({"id": id, "pred": self.writer(tensor)})
 
     def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
         """
@@ -52,19 +50,18 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
         If training was done in a distributed setting, it gathers predictions from all processes
         and then saves them from the rank 0 process.
         """
-        # Sort the records by ID and convert the dictionary to a list
-        self.csv_records = [self.csv_records[id] for id in sorted(self.csv_records)]
-
         # If in distributed data parallel mode, gather records from all processes to rank 0.
         if trainer.world_size > 1:
-            # Create a list to hold the records from each process. Used on rank 0 only.
             gather_csv_records = [None] * trainer.world_size if trainer.is_global_zero else None
-            # Each process sends its records to rank 0, which stores them in the `gather_csv_records`.
             torch.distributed.gather_object(self.csv_records, gather_csv_records, dst=0)
-            # Concatenate the gathered records
             if trainer.is_global_zero:
                 self.csv_records = list(itertools.chain(*gather_csv_records))
 
         # Save the records to a CSV file
         if trainer.is_global_zero:
-            pd.DataFrame(self.csv_records).to_csv(self.path)
+            df = pd.DataFrame(self.csv_records)
+            df = df.sort_values("id").set_index("id")
+            df.to_csv(self.path)
+
+        # Clear the records after saving
+        self.csv_records = []
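
For context, a minimal sketch of what the new record format and the rank-0 save step produce. The `mean_writer` function, the sample tensors, and the `predictions.csv` path below are illustrative assumptions, not part of this commit.

import pandas as pd
import torch

# Hypothetical custom writer: like the built-in "tensor" writer above, it takes a
# single un-batched tensor and returns the value stored under the 'pred' column.
def mean_writer(tensor):
    return tensor.mean().item()

# Each write(tensor, id) call now appends one {"id": ..., "pred": ...} record.
csv_records = []
for sample_id, tensor in [(2, torch.tensor([0.25, 0.75])), (1, torch.tensor([1.0, 3.0]))]:
    csv_records.append({"id": sample_id, "pred": mean_writer(tensor)})

# Mirrors the rank-0 save in on_predict_epoch_end: sort by id, index by id, write CSV.
df = pd.DataFrame(csv_records).sort_values("id").set_index("id")
df.to_csv("predictions.csv")  # columns: id (index), pred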
