@@ -132,6 +132,7 @@ def __init__(
132132 evaluation_tracker : EvaluationTracker ,
133133 model_config = None ,
134134 model = None ,
135+ metric_options = None ,
135136 ):
136137 if not (model or model_config ):
137138 raise ValueError ("Must provide either a model or model config when creating a pipeline." )
@@ -145,6 +146,7 @@ def __init__(
145146
146147 self .model_config = model_config
147148 self .evaluation_tracker = evaluation_tracker
149+ self ._metric_options = metric_options or {}
148150 self .accelerator , self .parallel_context = self ._init_parallelism_manager ()
149151 self .model = self ._init_model (model_config , model )
150152
@@ -209,6 +211,10 @@ def _init_tasks_and_requests(self, tasks: str):
209211 )
210212 task_names_list , fewshots_dict = taskinfo_selector (tasks , registry )
211213 task_dict = registry .get_task_dict (task_names_list )
214+ # If there are metric_options defined from the yaml file,
215+ # review if they have to be updated.
216+ if self ._metric_options :
217+ self ._update_num_samples (task_dict )
212218 LightevalTask .load_datasets (list (task_dict .values ()), self .pipeline_parameters .dataset_loading_processes )
213219
214220 self .evaluation_tracker .task_config_logger .log (task_dict )
@@ -230,6 +236,19 @@ def _init_tasks_and_requests(self, tasks: str):
230236 self .requests = requests
231237 self .docs = docs
232238
def _update_num_samples(self, task_dict: "dict[str, LightevalTask]"):
    """Apply per-metric ``num_samples`` overrides from the metric options to each task.

    For every task, each of its metrics is looked up by ``metric_name`` in
    ``self._metric_options``. Every truthy ``num_samples`` found there is
    collected, and the task's ``num_samples`` list is replaced by the collected
    values. Tasks for which no metric has a configured ``num_samples`` are left
    untouched.

    All configured values are kept (instead of only the last one seen) so that
    the result does not depend on metric ordering.
    NOTE(review): the original docstring states that request construction later
    takes the max over ``task.num_samples``; that code is not visible here, so
    confirm it before relying on it.

    Args:
        task_dict: Mapping of task name to task object. Each task must expose
            ``metrics`` (objects with a ``metric_name``) and a writable
            ``num_samples`` attribute.
    """
    for task in task_dict.values():
        configured_samples = []
        for metric in task.metrics:
            # A missing (or empty) entry means "no override for this metric".
            metric_data = self._metric_options.get(metric.metric_name) or {}
            num_samples = metric_data.get("num_samples")
            if num_samples:
                configured_samples.append(num_samples)
        # Assign once per task so an earlier metric's override is not clobbered
        # by a later one; only touch the task when something was configured.
        if configured_samples:
            task.num_samples = configured_samples
251+
233252 def _init_random_seeds (self ):
234253 logger .info ("--- INIT SEEDS ---" )
235254 random .seed (1234 )
0 commit comments