22# Author: Kristupas Pranckietis, Vilnius University 05/2024
33# Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
44# Author: Vincenzo Eduardo Padulano, CERN 10/2024
5+ # Author: Martin Føll, University of Oslo (UiO) & CERN 05/2025
56
67################################################################################
7- # Copyright (C) 1995-2024 , Rene Brun and Fons Rademakers. #
8+ # Copyright (C) 1995-2025 , Rene Brun and Fons Rademakers. #
89# All rights reserved. #
910# #
1011# For the licensing terms see $ROOTSYS/LICENSE. #
2021 import numpy as np
2122 import tensorflow as tf
2223 import torch
24+ import ROOT
2325
2426
2527class BaseGenerator :
2628 def get_template (
2729 self ,
28- x_rdf : RNode ,
30+ x_rdf : ROOT . RDF . RNode ,
2931 columns : list [str ] = list (),
3032 max_vec_sizes : dict [str , int ] = dict (),
3133 ) -> Tuple [str , list [int ]]:
@@ -80,9 +82,10 @@ def get_template(
8082
8183 def __init__ (
8284 self ,
83- rdataframe : RNode ,
85+ rdataframe : ROOT . RDF . RNode ,
8486 batch_size : int ,
8587 chunk_size : int ,
88+ block_size : int ,
8689 columns : list [str ] = list (),
8790 max_vec_sizes : dict [str , int ] = dict (),
8891 vec_padding : int = 0 ,
@@ -92,6 +95,7 @@ def __init__(
9295 max_chunks : int = 0 ,
9396 shuffle : bool = True ,
9497 drop_remainder : bool = True ,
98+ set_seed : int = 0 ,
9599 ):
96100 """Wrapper around the Cpp RBatchGenerator
97101
@@ -126,6 +130,10 @@ def __init__(
126130 drop_remainder (bool):
127131 Drop the remainder of data that is too small to compose full batch.
128132 Defaults to True.
133+ set_seed (int):
134+ For reproducibility: set the seed of the random number generator used
135+ to split the dataset into training and validation sets and to shuffle the chunks.
136+ Defaults to 0, which means that the seed is set from the random device.
129137 """
130138
131139 import ROOT
@@ -154,11 +162,6 @@ def __init__(
154162
155163 self .noded_rdf = RDF .AsRNode (rdataframe )
156164
157- if ROOT .Internal .RDF .GetDataSourceLabel (self .noded_rdf ) != "TTreeDS" :
158- raise ValueError (
159- "RNode object must be created out of TTrees or files of TTree"
160- )
161-
162165 if isinstance (target , str ):
163166 target = [target ]
164167
@@ -221,15 +224,16 @@ def __init__(
221224 self .generator = TMVA .Experimental .Internal .RBatchGenerator (template )(
222225 self .noded_rdf ,
223226 chunk_size ,
227+ block_size ,
224228 batch_size ,
225229 self .given_columns ,
226- self .num_columns ,
227230 max_vec_sizes_list ,
228231 vec_padding ,
229232 validation_split ,
230233 max_chunks ,
231234 shuffle ,
232235 drop_remainder ,
236+ set_seed ,
233237 )
234238
235239 atexit .register (self .DeActivate )
@@ -238,6 +242,9 @@ def __init__(
238242 def is_active (self ):
239243 return self .generator .IsActive ()
240244
def is_training_active(self):
    """Return the result of ``TrainingIsActive()`` from the wrapped generator."""
    training_state = self.generator.TrainingIsActive()
    return training_state
247+
def Activate(self):
    """Prepare the wrapped generator so that it can be used in a loop."""
    wrapped = self.generator
    wrapped.Activate()
@@ -246,6 +253,30 @@ def DeActivate(self):
246253 """Deactivate the generator"""
247254 self .generator .DeActivate ()
248255
def ActivateTrainingEpoch(self):
    """Activate a training epoch on the wrapped generator.

    Forwards to ``ActivateTrainingEpoch()`` of the underlying generator
    object; nothing is returned.
    """
    self.generator.ActivateTrainingEpoch()
259+
def ActivateValidationEpoch(self):
    """Activate a validation epoch on the wrapped generator.

    Forwards to ``ActivateValidationEpoch()`` of the underlying generator
    object; nothing is returned.
    """
    self.generator.ActivateValidationEpoch()
263+
def DeActivateTrainingEpoch(self):
    """Deactivate the training epoch on the wrapped generator.

    Forwards to ``DeActivateTrainingEpoch()`` of the underlying generator
    object; nothing is returned.
    """
    self.generator.DeActivateTrainingEpoch()
267+
def DeActivateValidationEpoch(self):
    """Deactivate the validation epoch on the wrapped generator.

    Forwards to ``DeActivateValidationEpoch()`` of the underlying generator
    object; nothing is returned.
    """
    self.generator.DeActivateValidationEpoch()
271+
def CreateTrainBatches(self):
    """Create the training batches on the wrapped generator.

    Forwards to ``CreateTrainBatches()`` of the underlying generator
    object; nothing is returned.
    """
    self.generator.CreateTrainBatches()
275+
def CreateValidationBatches(self):
    """Create the validation batches on the wrapped generator.

    Forwards to ``CreateValidationBatches()`` of the underlying generator
    object; nothing is returned.
    """
    self.generator.CreateValidationBatches()
279+
249280 def GetSample (self ):
250281 """
251282 Return a sample of data that has the same size and types as the actual
@@ -445,12 +476,14 @@ def GetValidationBatch(self) -> Any:
class LoadingThreadContext:
    """Context manager that brackets one pass over the training batches.

    Constructing the context asks the base generator to create its training
    batches. Entering the ``with`` block activates the training epoch and
    leaving it (normally or through an exception) deactivates the epoch again.

    Args:
        base_generator (BaseGenerator): Generator whose training epoch is
            managed by this context.
    """

    def __init__(self, base_generator: BaseGenerator):
        self.base_generator = base_generator
        # create training batches from the first chunk
        self.base_generator.CreateTrainBatches()

    def __enter__(self):
        self.base_generator.ActivateTrainingEpoch()

    def __exit__(self, type, value, traceback):
        self.base_generator.DeActivateTrainingEpoch()
        # A false return value lets exceptions raised inside the ``with``
        # body propagate to the caller. Returning True here would silently
        # discard every error (including KeyboardInterrupt) and make a
        # failing epoch look like a short but successful one.
        return False
455488
456489
@@ -469,6 +502,7 @@ def __init__(self, base_generator: BaseGenerator, conversion_function: Callable)
469502 self .base_generator = base_generator
470503 self .conversion_function = conversion_function
471504
505+
472506 def Activate (self ):
473507 """Start the loading of training batches"""
474508 self .base_generator .Activate ()
@@ -503,6 +537,7 @@ def last_batch_no_of_rows(self) -> int:
503537 return self .base_generator .generator .TrainRemainderRows ()
504538
def __iter__(self):
    """Start a fresh pass over the batches and return this object as the iterator.

    The generator object produced by calling this instance is stored in
    ``self._callable``.
    """
    batch_stream = self.__call__()
    self._callable = batch_stream
    return self
@@ -522,16 +557,28 @@ def __call__(self) -> Any:
522557 Union[np.NDArray, torch.Tensor]: A batch of data
523558 """
524559
525- with LoadingThreadContext (self .base_generator ):
560+ with LoadingThreadContext (self .base_generator ):
526561 while True :
527562 batch = self .base_generator .GetTrainBatch ()
528-
529563 if batch is None :
530564 break
531-
532565 yield self .conversion_function (batch )
566+
567+ return None
568+
class LoadingThreadContextVal:
    """Context manager that brackets one pass over the validation batches.

    Constructing the context asks the base generator to create its validation
    batches. Entering the ``with`` block activates the validation epoch and
    leaving it (normally or through an exception) deactivates the epoch again.

    Args:
        base_generator (BaseGenerator): Generator whose validation epoch is
            managed by this context.
    """

    def __init__(self, base_generator: BaseGenerator):
        self.base_generator = base_generator
        # create validation batches from the first chunk
        self.base_generator.CreateValidationBatches()

    def __enter__(self):
        self.base_generator.ActivateValidationEpoch()

    def __exit__(self, type, value, traceback):
        self.base_generator.DeActivateValidationEpoch()
        # A false return value lets exceptions raised inside the ``with``
        # body propagate to the caller. Returning True here would silently
        # discard every error (including KeyboardInterrupt) and make a
        # failing validation pass look like a short but successful one.
        return False
581+
535582
536583
537584class ValidationRBatchGenerator :
@@ -588,27 +635,27 @@ def __next__(self):
588635 return batch
589636
def __call__(self) -> Any:
    """Start the loading of validation batches and yield the results

    The loop runs inside a ``LoadingThreadContextVal``, which activates the
    validation epoch on entry and deactivates it on exit, so no explicit
    deactivation is needed when the batches run out.

    Yields:
        Union[np.NDArray, torch.Tensor]: A batch of data
    """

    with LoadingThreadContextVal(self.base_generator):
        while True:
            batch = self.base_generator.GetValidationBatch()
            if batch is None:
                # No more batches: leaving the ``with`` block deactivates
                # the validation epoch exactly once.
                break
            yield self.conversion_function(batch)
653+
608654def CreateNumPyGenerators (
609- rdataframe : RNode ,
655+ rdataframe : ROOT . RDF . RNode ,
610656 batch_size : int ,
611657 chunk_size : int ,
658+ block_size : int ,
612659 columns : list [str ] = list (),
613660 max_vec_sizes : dict [str , int ] = dict (),
614661 vec_padding : int = 0 ,
@@ -618,6 +665,7 @@ def CreateNumPyGenerators(
618665 max_chunks : int = 0 ,
619666 shuffle : bool = True ,
620667 drop_remainder = True ,
668+ set_seed : int = 0 ,
621669) -> Tuple [TrainRBatchGenerator , ValidationRBatchGenerator ]:
622670 """
623671 Return two batch generators based on the given ROOT file and tree or RDataFrame
@@ -676,6 +724,7 @@ def CreateNumPyGenerators(
676724 rdataframe ,
677725 batch_size ,
678726 chunk_size ,
727+ block_size ,
679728 columns ,
680729 max_vec_sizes ,
681730 vec_padding ,
@@ -685,6 +734,7 @@ def CreateNumPyGenerators(
685734 max_chunks ,
686735 shuffle ,
687736 drop_remainder ,
737+ set_seed ,
688738 )
689739
690740 train_generator = TrainRBatchGenerator (
@@ -702,9 +752,10 @@ def CreateNumPyGenerators(
702752
703753
704754def CreateTFDatasets (
705- rdataframe : RNode ,
755+ rdataframe : ROOT . RDF . RNode ,
706756 batch_size : int ,
707757 chunk_size : int ,
758+ block_size : int ,
708759 columns : list [str ] = list (),
709760 max_vec_sizes : dict [str , int ] = dict (),
710761 vec_padding : int = 0 ,
@@ -714,6 +765,7 @@ def CreateTFDatasets(
714765 max_chunks : int = 0 ,
715766 shuffle : bool = True ,
716767 drop_remainder = True ,
768+ set_seed : int = 0 ,
717769) -> Tuple [tf .data .Dataset , tf .data .Dataset ]:
718770 """
719771 Return two Tensorflow Datasets based on the given ROOT file and tree or RDataFrame
@@ -771,6 +823,7 @@ def CreateTFDatasets(
771823 rdataframe ,
772824 batch_size ,
773825 chunk_size ,
826+ block_size ,
774827 columns ,
775828 max_vec_sizes ,
776829 vec_padding ,
@@ -780,6 +833,7 @@ def CreateTFDatasets(
780833 max_chunks ,
781834 shuffle ,
782835 drop_remainder ,
836+ set_seed ,
783837 )
784838
785839 train_generator = TrainRBatchGenerator (
@@ -847,9 +901,10 @@ def CreateTFDatasets(
847901
848902
849903def CreatePyTorchGenerators (
850- rdataframe : RNode ,
904+ rdataframe : ROOT . RDF . RNode ,
851905 batch_size : int ,
852906 chunk_size : int ,
907+ block_size : int ,
853908 columns : list [str ] = list (),
854909 max_vec_sizes : dict [str , int ] = dict (),
855910 vec_padding : int = 0 ,
@@ -859,6 +914,7 @@ def CreatePyTorchGenerators(
859914 max_chunks : int = 0 ,
860915 shuffle : bool = True ,
861916 drop_remainder = True ,
917+ set_seed : int = 0 ,
862918) -> Tuple [TrainRBatchGenerator , ValidationRBatchGenerator ]:
863919 """
864920 Return two batch generators based on the given ROOT file and tree or RDataFrame
@@ -914,6 +970,7 @@ def CreatePyTorchGenerators(
914970 rdataframe ,
915971 batch_size ,
916972 chunk_size ,
973+ block_size ,
917974 columns ,
918975 max_vec_sizes ,
919976 vec_padding ,
@@ -923,6 +980,7 @@ def CreatePyTorchGenerators(
923980 max_chunks ,
924981 shuffle ,
925982 drop_remainder ,
983+ set_seed ,
926984 )
927985
928986 train_generator = TrainRBatchGenerator (
0 commit comments