Skip to content

Commit 2f055b0

Browse files
Martin Føll (martinfoell)
authored and committed
Implementation of a new shuffling strategy in RBatchGenerator
This commit introduces a new shuffling strategy for creating batches of data that are used in ML training and loaded lazily by RDataFrame. Shuffling the data is important in ML to avoid bias towards any particular class, label, or type used as the target of a classification task during training. Lazy loading of data into batches allows one to manage datasets that do not fit in memory. The shuffling strategy is implemented as follows. There are three important components: the chunk, the block, and the batch. A chunk is the largest piece of data that is loaded into memory at a time. A block is a logical range of consecutive entries; multiple blocks make up a chunk. A batch is the container that the model expects. The sizes of the chunks, blocks, and batches are set by the user to specify how shuffled the data should be. Given the total number of entries in the dataset, the chunk size, and the block size, we compute how many blocks make up a chunk. The logic of assembling chunks from blocks is implemented in the class RChunkConstructor, where each chunk is defined by taking blocks from random parts of the dataset. Once the chunks are defined by their indices, they are loaded into memory one at a time by the class RChunkLoader, where the entries of the chunk are further shuffled. Finally, each chunk is split into batches by the class RBatchLoader. The class RBatchGenerator connects the steps described above, creating training and validation batches according to the split defined by the user.
1 parent 7b6cd60 commit 2f055b0

File tree

11 files changed

+1428
-634
lines changed

11 files changed

+1428
-634
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py

Lines changed: 90 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
# Author: Kristupas Pranckietis, Vilnius University 05/2024
33
# Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
44
# Author: Vincenzo Eduardo Padulano, CERN 10/2024
5+
# Author: Martin Føll, University of Oslo (UiO) & CERN 05/2025
56

67
################################################################################
7-
# Copyright (C) 1995-2024, Rene Brun and Fons Rademakers. #
8+
# Copyright (C) 1995-2025, Rene Brun and Fons Rademakers. #
89
# All rights reserved. #
910
# #
1011
# For the licensing terms see $ROOTSYS/LICENSE. #
@@ -20,12 +21,13 @@
2021
import numpy as np
2122
import tensorflow as tf
2223
import torch
24+
import ROOT
2325

2426

2527
class BaseGenerator:
2628
def get_template(
2729
self,
28-
x_rdf: RNode,
30+
x_rdf: ROOT.RDF.RNode,
2931
columns: list[str] = list(),
3032
max_vec_sizes: dict[str, int] = dict(),
3133
) -> Tuple[str, list[int]]:
@@ -80,9 +82,10 @@ def get_template(
8082

8183
def __init__(
8284
self,
83-
rdataframe: RNode,
85+
rdataframe: ROOT.RDF.RNode,
8486
batch_size: int,
8587
chunk_size: int,
88+
block_size: int,
8689
columns: list[str] = list(),
8790
max_vec_sizes: dict[str, int] = dict(),
8891
vec_padding: int = 0,
@@ -92,6 +95,7 @@ def __init__(
9295
max_chunks: int = 0,
9396
shuffle: bool = True,
9497
drop_remainder: bool = True,
98+
set_seed: int = 0,
9599
):
96100
"""Wrapper around the Cpp RBatchGenerator
97101
@@ -126,6 +130,10 @@ def __init__(
126130
drop_remainder (bool):
127131
Drop the remainder of data that is too small to compose full batch.
128132
Defaults to True.
133+
set_seed (int):
134+
For reproducibility: Set the seed for the random number generator used
135+
to split the dataset into training and validation and shuffling of the chunks
136+
Defaults to 0 which means that the seed is set to the random device.
129137
"""
130138

131139
import ROOT
@@ -154,11 +162,6 @@ def __init__(
154162

155163
self.noded_rdf = RDF.AsRNode(rdataframe)
156164

157-
if ROOT.Internal.RDF.GetDataSourceLabel(self.noded_rdf) != "TTreeDS":
158-
raise ValueError(
159-
"RNode object must be created out of TTrees or files of TTree"
160-
)
161-
162165
if isinstance(target, str):
163166
target = [target]
164167

@@ -221,15 +224,16 @@ def __init__(
221224
self.generator = TMVA.Experimental.Internal.RBatchGenerator(template)(
222225
self.noded_rdf,
223226
chunk_size,
227+
block_size,
224228
batch_size,
225229
self.given_columns,
226-
self.num_columns,
227230
max_vec_sizes_list,
228231
vec_padding,
229232
validation_split,
230233
max_chunks,
231234
shuffle,
232235
drop_remainder,
236+
set_seed,
233237
)
234238

235239
atexit.register(self.DeActivate)
@@ -238,6 +242,9 @@ def __init__(
238242
def is_active(self):
239243
return self.generator.IsActive()
240244

245+
def is_training_active(self):
246+
return self.generator.TrainingIsActive()
247+
241248
def Activate(self):
242249
"""Initialize the generator to be used for a loop"""
243250
self.generator.Activate()
@@ -246,6 +253,30 @@ def DeActivate(self):
246253
"""Deactivate the generator"""
247254
self.generator.DeActivate()
248255

256+
def ActivateTrainingEpoch(self):
257+
"""Activate the generator"""
258+
self.generator.ActivateTrainingEpoch()
259+
260+
def ActivateValidationEpoch(self):
261+
"""Activate the generator"""
262+
self.generator.ActivateValidationEpoch()
263+
264+
def DeActivateTrainingEpoch(self):
265+
"""Deactivate the generator"""
266+
self.generator.DeActivateTrainingEpoch()
267+
268+
def DeActivateValidationEpoch(self):
269+
"""Deactivate the generator"""
270+
self.generator.DeActivateValidationEpoch()
271+
272+
def CreateTrainBatches(self):
273+
"""Deactivate the generator"""
274+
self.generator.CreateTrainBatches()
275+
276+
def CreateValidationBatches(self):
277+
"""Deactivate the generator"""
278+
self.generator.CreateValidationBatches()
279+
249280
def GetSample(self):
250281
"""
251282
Return a sample of data that has the same size and types as the actual
@@ -445,12 +476,14 @@ def GetValidationBatch(self) -> Any:
445476
class LoadingThreadContext:
446477
def __init__(self, base_generator: BaseGenerator):
447478
self.base_generator = base_generator
448-
479+
# create training batches from the first chunk
480+
self.base_generator.CreateTrainBatches();
481+
449482
def __enter__(self):
450-
self.base_generator.Activate()
483+
self.base_generator.ActivateTrainingEpoch()
451484

452485
def __exit__(self, type, value, traceback):
453-
self.base_generator.DeActivate()
486+
self.base_generator.DeActivateTrainingEpoch()
454487
return True
455488

456489

@@ -469,6 +502,7 @@ def __init__(self, base_generator: BaseGenerator, conversion_function: Callable)
469502
self.base_generator = base_generator
470503
self.conversion_function = conversion_function
471504

505+
472506
def Activate(self):
473507
"""Start the loading of training batches"""
474508
self.base_generator.Activate()
@@ -503,6 +537,7 @@ def last_batch_no_of_rows(self) -> int:
503537
return self.base_generator.generator.TrainRemainderRows()
504538

505539
def __iter__(self):
540+
506541
self._callable = self.__call__()
507542

508543
return self
@@ -522,16 +557,28 @@ def __call__(self) -> Any:
522557
Union[np.NDArray, torch.Tensor]: A batch of data
523558
"""
524559

525-
with LoadingThreadContext(self.base_generator):
560+
with LoadingThreadContext(self.base_generator):
526561
while True:
527562
batch = self.base_generator.GetTrainBatch()
528-
529563
if batch is None:
530564
break
531-
532565
yield self.conversion_function(batch)
566+
567+
return None
568+
569+
class LoadingThreadContextVal:
570+
def __init__(self, base_generator: BaseGenerator):
571+
self.base_generator = base_generator
572+
# create validation batches from the first chunk
573+
self.base_generator.CreateValidationBatches()
533574

534-
return None
575+
def __enter__(self):
576+
self.base_generator.ActivateValidationEpoch()
577+
578+
def __exit__(self, type, value, traceback):
579+
self.base_generator.DeActivateValidationEpoch()
580+
return True
581+
535582

536583

537584
class ValidationRBatchGenerator:
@@ -588,27 +635,27 @@ def __next__(self):
588635
return batch
589636

590637
def __call__(self) -> Any:
591-
"""Loop through the validation batches
638+
"""Start the loading of batches and yield the results
592639
593640
Yields:
594641
Union[np.NDArray, torch.Tensor]: A batch of data
595642
"""
596-
if self.base_generator.is_active:
597-
self.base_generator.DeActivate()
598-
599-
while True:
600-
batch = self.base_generator.GetValidationBatch()
601-
602-
if not batch:
603-
break
604-
605-
yield self.conversion_function(batch)
606-
607-
643+
644+
with LoadingThreadContextVal(self.base_generator):
645+
while True:
646+
batch = self.base_generator.GetValidationBatch()
647+
if batch is None:
648+
self.base_generator.DeActivateValidationEpoch()
649+
break
650+
yield self.conversion_function(batch)
651+
652+
return None
653+
608654
def CreateNumPyGenerators(
609-
rdataframe: RNode,
655+
rdataframe: ROOT.RDF.RNode,
610656
batch_size: int,
611657
chunk_size: int,
658+
block_size: int,
612659
columns: list[str] = list(),
613660
max_vec_sizes: dict[str, int] = dict(),
614661
vec_padding: int = 0,
@@ -618,6 +665,7 @@ def CreateNumPyGenerators(
618665
max_chunks: int = 0,
619666
shuffle: bool = True,
620667
drop_remainder=True,
668+
set_seed: int = 0,
621669
) -> Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
622670
"""
623671
Return two batch generators based on the given ROOT file and tree or RDataFrame
@@ -676,6 +724,7 @@ def CreateNumPyGenerators(
676724
rdataframe,
677725
batch_size,
678726
chunk_size,
727+
block_size,
679728
columns,
680729
max_vec_sizes,
681730
vec_padding,
@@ -685,6 +734,7 @@ def CreateNumPyGenerators(
685734
max_chunks,
686735
shuffle,
687736
drop_remainder,
737+
set_seed,
688738
)
689739

690740
train_generator = TrainRBatchGenerator(
@@ -702,9 +752,10 @@ def CreateNumPyGenerators(
702752

703753

704754
def CreateTFDatasets(
705-
rdataframe: RNode,
755+
rdataframe: ROOT.RDF.RNode,
706756
batch_size: int,
707757
chunk_size: int,
758+
block_size: int,
708759
columns: list[str] = list(),
709760
max_vec_sizes: dict[str, int] = dict(),
710761
vec_padding: int = 0,
@@ -714,6 +765,7 @@ def CreateTFDatasets(
714765
max_chunks: int = 0,
715766
shuffle: bool = True,
716767
drop_remainder=True,
768+
set_seed: int = 0,
717769
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
718770
"""
719771
Return two Tensorflow Datasets based on the given ROOT file and tree or RDataFrame
@@ -771,6 +823,7 @@ def CreateTFDatasets(
771823
rdataframe,
772824
batch_size,
773825
chunk_size,
826+
block_size,
774827
columns,
775828
max_vec_sizes,
776829
vec_padding,
@@ -780,6 +833,7 @@ def CreateTFDatasets(
780833
max_chunks,
781834
shuffle,
782835
drop_remainder,
836+
set_seed,
783837
)
784838

785839
train_generator = TrainRBatchGenerator(
@@ -847,9 +901,10 @@ def CreateTFDatasets(
847901

848902

849903
def CreatePyTorchGenerators(
850-
rdataframe: RNode,
904+
rdataframe: ROOT.RDF.RNode,
851905
batch_size: int,
852906
chunk_size: int,
907+
block_size: int,
853908
columns: list[str] = list(),
854909
max_vec_sizes: dict[str, int] = dict(),
855910
vec_padding: int = 0,
@@ -859,6 +914,7 @@ def CreatePyTorchGenerators(
859914
max_chunks: int = 0,
860915
shuffle: bool = True,
861916
drop_remainder=True,
917+
set_seed: int = 0,
862918
) -> Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
863919
"""
864920
Return two Tensorflow Datasets based on the given ROOT file and tree or RDataFrame
@@ -914,6 +970,7 @@ def CreatePyTorchGenerators(
914970
rdataframe,
915971
batch_size,
916972
chunk_size,
973+
block_size,
917974
columns,
918975
max_vec_sizes,
919976
vec_padding,
@@ -923,6 +980,7 @@ def CreatePyTorchGenerators(
923980
max_chunks,
924981
shuffle,
925982
drop_remainder,
983+
set_seed,
926984
)
927985

928986
train_generator = TrainRBatchGenerator(

0 commit comments

Comments
 (0)