 import math
 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union
 
 import pytest
@@ -66,6 +67,57 @@ def compute_threshold(num_samples: int,
     return ref_accuracy - z_alpha * scale
 
 
+@dataclass(slots=True)
+class HypothesisTestingParams:
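+    """Bundles the statistical settings for an accuracy check.
+
+    alpha and beta are the Type I/II error rates and sigma the assumed score
+    standard deviation; theta (minimum detectable effect) and threshold are
+    derived in __post_init__ via compute_theta and compute_threshold.
+    """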
+    ref_accuracy: float
+    num_samples: int
+    alpha: float = 0.05
+    beta: float = 0.2
+    sigma: float = 50.0
+    higher_is_better: bool = True
+    theta: float = field(init=False)
+    threshold: float = field(init=False)
+
+    def __post_init__(self) -> None:
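+        # Derived once at construction from the error rates, sigma, and
+        # sample count.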
+        self.theta = compute_theta(self.num_samples,
+                                   sigma=self.sigma,
+                                   alpha=self.alpha,
+                                   beta=self.beta)
+        self.threshold = compute_threshold(
+            self.num_samples,
+            self.ref_accuracy,
+            sigma=self.sigma,
+            alpha=self.alpha,
+            higher_is_better=self.higher_is_better)
+
+    def report(self, accuracy: Optional[float] = None) -> str:
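+        """Format the hypothesis testing report, optionally appending the
+        evaluated accuracy."""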
+        report = f"""===========================================================
+= ACCURACY HYPOTHESIS TESTING
+===========================================================
+Alpha (Type I: False Positive): {self.alpha:.3f}
+Beta (Type II: False Negative): {self.beta:.3f}
+Sigma (Standard deviation): {self.sigma:.3f}
+#Samples: {self.num_samples}
+Higher is better: {self.higher_is_better}
+Theta (Minimum detectable effect): {self.theta:.3f}
+Reference accuracy: {self.ref_accuracy:.3f}
+Threshold: {self.threshold:.3f}
+==========================================================="""
+        if accuracy is not None:
+            report = f"""{report}
+Evaluated accuracy: {accuracy:.3f}
+==========================================================="""
+        return report
+
+    def assert_passing(self, accuracy: float) -> None:
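+        """Assert that accuracy meets the threshold; the failure message
+        embeds the full hypothesis testing report."""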
+        compare_op = ">=" if self.higher_is_better else "<="
+        err_msg = f"Reference accuracy is {self.ref_accuracy:.3f}, threshold is {self.threshold:.3f}. Expected accuracy {compare_op} threshold, but got {accuracy:.3f}. Please see hypothesis testing report:\n{self.report(accuracy)}"
+        if self.higher_is_better:
+            assert accuracy >= self.threshold, err_msg
+        else:
+            assert accuracy <= self.threshold, err_msg
+
+
 class AccuracyTask:
     REFERENCE_DIR = f"{os.path.dirname(__file__)}/references"
 
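For illustration only (not part of the diff), a minimal usage sketch of the new dataclass, with made-up numbers:

    params = HypothesisTestingParams(ref_accuracy=75.0, num_samples=512)
    print(params.report(accuracy=74.2))  # banner including theta and threshold
    params.assert_passing(74.2)  # raises AssertionError if 74.2 misses the threshold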
@@ -93,8 +145,9 @@ def __init__(self, model_name: str):
         with open(f"{self.REFERENCE_DIR}/{self.DATASET}.yaml") as f:
             self.reference: List[dict] = yaml.safe_load(f).get(model_name, [])
 
-    def get_num_samples_and_threshold(self, **acc_specs):
-        """Get num_samples and threshold via accuracy specifications.
+    def get_hypothesis_testing_params(self,
+                                      **acc_specs) -> HypothesisTestingParams:
+        """Get hypothesis testing parameters via accuracy specifications.
 
         Args:
             acc_specs: Accuracy specifications, currently including:
@@ -119,30 +172,14 @@ def get_num_samples_and_threshold(self, **acc_specs):
         else:
             raise ValueError(f"Not registered specs: {acc_specs}.")
 
-        accuracy = entry.get("accuracy")
-        alpha = entry.get("alpha", self.ALPHA)
-        beta = entry.get("beta", self.BETA)
-        sigma = entry.get("sigma", self.SIGMA)
-        num_samples = entry.get("num_samples", self.NUM_SAMPLES)
-        higher_is_better = entry.get("higher_is_better", self.HIGHER_IS_BETTER)
-        theta = compute_theta(num_samples, sigma=sigma, alpha=alpha, beta=beta)
-        threshold = compute_threshold(num_samples,
-                                      accuracy,
-                                      sigma=sigma,
-                                      alpha=alpha,
-                                      higher_is_better=higher_is_better)
-        print("===========================================================\n"
-              "= ACCURACY HYPOTHESIS TESTING\n"
-              "===========================================================\n"
-              f"Alpha (Type I: False Positive): {alpha:.3f}\n"
-              f"Beta (Type II: False Negative): {beta:.3f}\n"
-              f"Sigma (Standard deviation): {sigma:.3f}\n"
-              f"#Samples: {num_samples}\n"
-              f"Theta (Minimum detectable effect): {theta:.3f}\n"
-              f"Reference accuracy: {accuracy:.3f}\n"
-              f"Threshold: {threshold:.3f}\n"
-              "===========================================================\n")
-        return num_samples, threshold
+        return HypothesisTestingParams(
+            ref_accuracy=entry.get("accuracy"),
+            alpha=entry.get("alpha", self.ALPHA),
+            beta=entry.get("beta", self.BETA),
+            sigma=entry.get("sigma", self.SIGMA),
+            num_samples=entry.get("num_samples", self.NUM_SAMPLES),
+            higher_is_better=entry.get("higher_is_better",
+                                       self.HIGHER_IS_BETTER))
 
     def evaluate(self,
                  llm: Union[LLM, PyTorchLLM, AutoDeployLLM],
@@ -165,13 +202,13 @@ def evaluate(self,
         is_integration_test = os.getenv('INTEGRATION_TEST', '0') == '1'
 
         if is_integration_test:
-            num_samples = 1
             logger.info(
                 "Running in INTEGRATION_TEST mode: using only 1 sample and skipping accuracy verification"
            )
-            threshold = 0
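+            # Sentinel parameters: ref_accuracy=0 yields a trivially passable
+            # threshold, so the accuracy check is effectively a no-op here.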
+            hypothesis_testing_params = HypothesisTestingParams(ref_accuracy=0,
+                                                                num_samples=1)
         else:
-            num_samples, threshold = self.get_num_samples_and_threshold(
+            hypothesis_testing_params = self.get_hypothesis_testing_params(
                 dtype=llm.args.dtype,
                 quant_algo=llm.args.quant_config.quant_algo,
                 kv_cache_quant_algo=llm.args.quant_config.kv_cache_quant_algo,
@@ -193,17 +230,19 @@ def evaluate(self,
             evaluator_kwargs.update(self.EVALUATOR_KWARGS)
         if extra_evaluator_kwargs is not None:
             evaluator_kwargs.update(extra_evaluator_kwargs)
-        evaluator = self.EVALUATOR_CLS(num_samples=num_samples,
-                                       **evaluator_kwargs)
+        evaluator = self.EVALUATOR_CLS(
+            num_samples=hypothesis_testing_params.num_samples,
+            **evaluator_kwargs)
         evaluate_kwargs = {}
         if hasattr(self, 'EVALUATE_KWARGS'):
             evaluate_kwargs.update(self.EVALUATE_KWARGS)
         accuracy = evaluator.evaluate(llm, sampling_params, streaming,
                                       **evaluate_kwargs)
-        if self.HIGHER_IS_BETTER:
-            assert accuracy >= threshold, f"Expected accuracy >= {threshold}, but got {accuracy}."
-        else:
-            assert accuracy <= threshold, f"Expected accuracy <= {threshold}, but got {accuracy}."
+
+        logger.info(
+            f"Hypothesis testing report:\n{hypothesis_testing_params.report(accuracy)}"
+        )
+        hypothesis_testing_params.assert_passing(accuracy)
 
 
 class CnnDailymail(AccuracyTask):
@@ -457,7 +496,7 @@ def initialize_case(self,
         self.env = env
 
     def convert(self):
-        print("Converting model to TensorRT LLM checkpoint...")
+        logger.info("Converting model to TensorRT LLM checkpoint...")
 
         is_prequantized = False
         for quant_config_file in [
@@ -559,7 +598,7 @@ def convert(self):
         venv_check_call(self.llm_venv, convert_cmd)
 
     def build(self):
-        print("Building engines...")
+        logger.info("Building engines...")
         max_batch_size = max(task.MAX_BATCH_SIZE for task in self.tasks)
         max_input_len = max(task.MAX_INPUT_LEN for task in self.tasks)
         max_seq_len = max(task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN
@@ -578,7 +617,7 @@ def build(self):
         check_call(" ".join(build_cmd), shell=True, env=self.llm_venv._new_env)
 
     def summarize(self, task: AccuracyTask):
-        print("Running summarize...")
+        logger.info("Running summarize...")
         summarize_cmd = [
             f"{self.llm_root}/examples/summarize.py",
             f"--engine_dir={self.engine_dir}",
@@ -595,12 +634,16 @@ def summarize(self, task: AccuracyTask):
595634 "--no_add_special_tokens"
596635 ])
597636
598- num_samples , threshold = task .get_num_samples_and_threshold (
637+ hypothesis_testing_params = task .get_hypothesis_testing_params (
599638 dtype = self .dtype ,
600639 quant_algo = self .quant_algo ,
601640 kv_cache_quant_algo = self .kv_cache_quant_algo ,
602641 spec_dec_algo = self .spec_dec_algo ,
603642 extra_acc_spec = self .extra_acc_spec )
643+ logger .info (
644+ f"Hypothesis testing report:\n { hypothesis_testing_params .report ()} " )
645+ num_samples = hypothesis_testing_params .num_samples
646+ threshold = hypothesis_testing_params .threshold
604647
605648 if num_samples < task .MAX_BATCH_SIZE :
606649 max_ite = 1
@@ -642,13 +685,17 @@ def summarize(self, task: AccuracyTask):
             str(world_size), "--allow-run-as-root"], summarize_cmd)
 
     def mmlu(self, task: AccuracyTask):
-        print("Running mmlu...")
-        num_samples, threshold = task.get_num_samples_and_threshold(
+        logger.info("Running mmlu...")
+        hypothesis_testing_params = task.get_hypothesis_testing_params(
             dtype=self.dtype,
             quant_algo=self.quant_algo,
             kv_cache_quant_algo=self.kv_cache_quant_algo,
             spec_dec_algo=self.spec_dec_algo,
             extra_acc_spec=self.extra_acc_spec)
+        logger.info(
+            f"Hypothesis testing report:\n{hypothesis_testing_params.report()}")
+        num_samples = hypothesis_testing_params.num_samples
+        threshold = hypothesis_testing_params.threshold
 
         mmlu_cmd = [
             "trtllm-eval",
@@ -669,27 +716,31 @@ def mmlu(self, task: AccuracyTask):
         check_call(" ".join(mmlu_cmd), shell=True, env=self.llm_venv._new_env)
 
     def eval_long_context(self, task: AccuracyTask):
-        print("Running construct_synthetic_dataset...")
+        logger.info("Running construct_synthetic_dataset...")
         data_gen_cmd = [
             f"{self.llm_root}/examples/infinitebench/construct_synthetic_dataset.py",
             "--test_case=build_passkey", f"--test_level={task.LEVEL}"
         ]
         venv_check_call(self.llm_venv, data_gen_cmd)
 
-        print("Running eval_long_context...")
+        logger.info("Running eval_long_context...")
         eval_cmd = [
             f"{self.llm_root}/examples/eval_long_context.py", "--task=passkey",
             f"--engine_dir={self.engine_dir}",
             f"--tokenizer_dir={self.MODEL_PATH}",
             f"--max_input_length={task.MAX_INPUT_LEN}",
             "--enable_chunked_context"
         ]
-        num_samples, threshold = task.get_num_samples_and_threshold(
+        hypothesis_testing_params = task.get_hypothesis_testing_params(
             dtype=self.dtype,
             quant_algo=self.quant_algo,
             kv_cache_quant_algo=self.kv_cache_quant_algo,
             spec_dec_algo=self.spec_dec_algo,
             extra_acc_spec=self.extra_acc_spec)
+        logger.info(
+            f"Hypothesis testing report:\n{hypothesis_testing_params.report()}")
+        num_samples = hypothesis_testing_params.num_samples
+        threshold = hypothesis_testing_params.threshold
 
         batch_size = min(task.MAX_BATCH_SIZE, num_samples)
         eval_cmd.extend([