@@ -19,17 +19,17 @@ class LighterTableWriter(LighterBaseWriter):
19
19
path (Path): CSV filepath.
20
20
writer (Union[str, Callable]): Name of the writer function registered in `self.writers` or a custom writer function.
21
21
Available writers: "tensor". A custom writer function must take a single argument: `tensor`, and return the record
22
- to be saved in the CSV file. The tensor will be a single tensor without the batch dimension.
22
+ to be saved in the CSV file under 'pred' column . The tensor will be a single tensor without the batch dimension.
23
23
"""
24
24
25
25
def __init__ (self , path : Union [str , Path ], writer : Union [str , Callable ]) -> None :
26
26
super ().__init__ (path , writer )
27
- self .csv_records = {}
27
+ self .csv_records = []
28
28
29
29
@property
30
30
def writers (self ) -> Dict [str , Callable ]:
31
31
return {
32
- "tensor" : lambda tensor : tensor .tolist (),
32
+ "tensor" : lambda tensor : tensor .item () if tensor . numel () == 1 else tensor . tolist (),
33
33
}
34
34
35
35
def write (self , tensor : Any , id : Union [int , str ]) -> None :
@@ -40,9 +40,7 @@ def write(self, tensor: Any, id: Union[int, str]) -> None:
40
40
tensor (Any): Tensor, without the batch dimension, to be recorded.
41
41
id (Union[int, str]): Identifier, used as the key for the record.
42
42
"""
43
- column = "pred"
44
- record = self .writer (tensor )
45
- self .csv_records .setdefault (id , {})[column ] = record
43
+ self .csv_records .append ({"id" : id , "pred" : self .writer (tensor )})
46
44
47
45
def on_predict_epoch_end (self , trainer : Trainer , pl_module : LighterSystem ) -> None :
48
46
"""
@@ -52,19 +50,18 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> No
52
50
If training was done in a distributed setting, it gathers predictions from all processes
53
51
and then saves them from the rank 0 process.
54
52
"""
55
- # Sort the records by ID and convert the dictionary to a list
56
- self .csv_records = [self .csv_records [id ] for id in sorted (self .csv_records )]
57
-
58
53
# If in distributed data parallel mode, gather records from all processes to rank 0.
59
54
if trainer .world_size > 1 :
60
- # Create a list to hold the records from each process. Used on rank 0 only.
61
55
gather_csv_records = [None ] * trainer .world_size if trainer .is_global_zero else None
62
- # Each process sends its records to rank 0, which stores them in the `gather_csv_records`.
63
56
torch .distributed .gather_object (self .csv_records , gather_csv_records , dst = 0 )
64
- # Concatenate the gathered records
65
57
if trainer .is_global_zero :
66
58
self .csv_records = list (itertools .chain (* gather_csv_records ))
67
59
68
60
# Save the records to a CSV file
69
61
if trainer .is_global_zero :
70
- pd .DataFrame (self .csv_records ).to_csv (self .path )
62
+ df = pd .DataFrame (self .csv_records )
63
+ df = df .sort_values ("id" ).set_index ("id" )
64
+ df .to_csv (self .path )
65
+
66
+ # Clear the records after saving
67
+ self .csv_records = []
0 commit comments