Commit d650320

[None][infra] Improve the failure message for accuracy test suite (#7994)
Signed-off-by: Enwei Zhu <[email protected]>
1 parent 108248e commit d650320

File tree

2 files changed: +96 -45 lines changed

tests/integration/defs/accuracy/README.md

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ meta-llama/Llama-3.1-8B-Instruct:
 
 The first item is the default accuracy specification (i.e., using original Hugging Face model data type and no quantization), and the reference accuracy is 68.17. The second item is an accuracy specification with FP8 GEMM quantization, with a slightly lower reference accuracy 67.93. The third item is a specification with FP8 GEMM and KV cache quantization, with a further slightly lower reference accuracy 67.87.
 
-Model data type and quantization decide the precision in model computation, so accuracy differences can be *justified* if different data types or quantizations are used. Hence, they are the most typical components in accuracy specifications. Please see other categories of accuracy specifications documented in `AccuracyTask.get_num_samples_and_threshold` in [accuracy_core.py](./accuracy_core.py). Note that we exclude most inference features such as parallelism, because theoretically they should not affect model accuracy. Think from the opposite perspective, if enabling tensor parallelism results in statistically significant accuracy loss, we might need to check whether some accuracy bugs exist.
+Model data type and quantization decide the precision in model computation, so accuracy differences can be *justified* if different data types or quantizations are used. Hence, they are the most typical components in accuracy specifications. Please see other categories of accuracy specifications documented in `AccuracyTask.get_hypothesis_testing_params` in [accuracy_core.py](./accuracy_core.py). Note that we exclude most inference features such as parallelism, because theoretically they should not affect model accuracy. Think from the opposite perspective, if enabling tensor parallelism results in statistically significant accuracy loss, we might need to check whether some accuracy bugs exist.
 
 A direct implication is that multiple test cases with different features may share the same accuracy reference. This is by design. For example, we should expect a test case with tensor parallelism to have very similar accuracy to its single-GPU counterpart.
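
For orientation, the reference entries this README describes might look roughly like the sketch below; the data is inlined purely for illustration and mirrors the numbers quoted above (68.17 / 67.93 / 67.87). The exact schema of the real files under tests/integration/defs/accuracy/references/ may differ and may carry extra per-entry overrides (e.g. alpha, beta, sigma, num_samples, higher_is_better), so treat this as an assumption-labeled example, not the canonical format:

import yaml

# Illustrative reference data only; the real references/*.yaml schema may
# differ in key names and extra per-entry overrides.
REFERENCE_YAML = """
meta-llama/Llama-3.1-8B-Instruct:
- accuracy: 68.17                 # default: original HF dtype, no quantization
- quant_algo: FP8                 # FP8 GEMM quantization
  accuracy: 67.93
- quant_algo: FP8                 # FP8 GEMM + FP8 KV cache quantization
  kv_cache_quant_algo: FP8
  accuracy: 67.87
"""

reference = yaml.safe_load(REFERENCE_YAML)
for entry in reference["meta-llama/Llama-3.1-8B-Instruct"]:
    # Each entry pairs an accuracy specification with its reference accuracy.
    print(entry.get("quant_algo"), entry.get("kv_cache_quant_algo"),
          entry["accuracy"])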

tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 95 additions & 44 deletions
@@ -16,6 +16,7 @@
 import math
 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union
 
 import pytest
@@ -66,6 +67,57 @@ def compute_threshold(num_samples: int,
     return ref_accuracy - z_alpha * scale
 
 
+@dataclass(slots=True)
+class HypothesisTestingParams:
+    ref_accuracy: float
+    num_samples: int
+    alpha: float = 0.05
+    beta: float = 0.2
+    sigma: float = 50.0
+    higher_is_better: bool = True
+    theta: float = field(init=False)
+    threshold: float = field(init=False)
+
+    def __post_init__(self) -> None:
+        self.theta = compute_theta(self.num_samples,
+                                   sigma=self.sigma,
+                                   alpha=self.alpha,
+                                   beta=self.beta)
+        self.threshold = compute_threshold(
+            self.num_samples,
+            self.ref_accuracy,
+            sigma=self.sigma,
+            alpha=self.alpha,
+            higher_is_better=self.higher_is_better)
+
+    def report(self, accuracy: Optional[float] = None) -> str:
+        report = f"""===========================================================
+= ACCURACY HYPOTHESIS TESTING
+===========================================================
+Alpha (Type I: False Positive): {self.alpha:.3f}
+Beta (Type II: False Negative): {self.beta:.3f}
+Sigma (Standard deviation): {self.sigma:.3f}
+#Samples: {self.num_samples}
+Higher is better: {self.higher_is_better}
+Theta (Minimum detectable effect): {self.theta:.3f}
+Reference accuracy: {self.ref_accuracy:.3f}
+Threshold: {self.threshold:.3f}
+==========================================================="""
+        if accuracy is not None:
+            report = f"""{report}
+Evaluated accuracy: {accuracy:.3f}
+==========================================================="""
+        return report
+
+    def assert_passing(self, accuracy: float) -> None:
+        compare_op = ">=" if self.higher_is_better else "<="
+        err_msg = f"Reference accuracy is {self.ref_accuracy:.3f}, threshold is {self.threshold:.3f}. Expected accuracy {compare_op} threshold, but got {accuracy:.3f}. Please see hypothesis testing report:\n{self.report(accuracy)}"
+        if self.higher_is_better:
+            assert accuracy >= self.threshold, err_msg
+        else:
+            assert accuracy <= self.threshold, err_msg
+
+
 class AccuracyTask:
     REFERENCE_DIR = f"{os.path.dirname(__file__)}/references"
 
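The dataclass above derives theta and threshold from compute_theta and compute_threshold, whose bodies lie mostly outside this diff; only the line `return ref_accuracy - z_alpha * scale` is visible. Below is a minimal, self-contained sketch of the one-sided z-test these helpers presumably implement, using the stdlib statistics.NormalDist for the inverse normal CDF — an assumption-labeled reconstruction, not the actual accuracy_core.py code:

import math
from statistics import NormalDist  # stdlib inverse normal CDF (Python 3.8+)

def compute_theta(num_samples: int, sigma: float = 50.0,
                  alpha: float = 0.05, beta: float = 0.2) -> float:
    """Minimum detectable effect: the smallest true accuracy regression the
    test can flag with Type I error rate alpha and Type II error rate beta."""
    scale = sigma / math.sqrt(num_samples)  # standard error of the mean score
    z_alpha = NormalDist().inv_cdf(1 - alpha)
    z_beta = NormalDist().inv_cdf(1 - beta)
    return (z_alpha + z_beta) * scale

def compute_threshold(num_samples: int, ref_accuracy: float,
                      sigma: float = 50.0, alpha: float = 0.05,
                      higher_is_better: bool = True) -> float:
    """Pass/fail cutoff: fail only when the measured accuracy deviates from
    the reference by more than z_alpha standard errors in the bad direction."""
    scale = sigma / math.sqrt(num_samples)
    z_alpha = NormalDist().inv_cdf(1 - alpha)
    if higher_is_better:
        return ref_accuracy - z_alpha * scale
    return ref_accuracy + z_alpha * scale

# Example with the class defaults (sigma=50, alpha=0.05, beta=0.2) and a
# hypothetical 512 samples: the standard error is ~2.21, so the threshold
# sits ~3.6 points below the reference, and only regressions of ~5.5+
# points are reliably detected.
print(compute_theta(512))             # ≈ 5.49
print(compute_threshold(512, 68.17))  # ≈ 64.54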
@@ -93,8 +145,9 @@ def __init__(self, model_name: str):
         with open(f"{self.REFERENCE_DIR}/{self.DATASET}.yaml") as f:
             self.reference: List[dict] = yaml.safe_load(f).get(model_name, [])
 
-    def get_num_samples_and_threshold(self, **acc_specs):
-        """Get num_samples and threshold via accuracy specifications.
+    def get_hypothesis_testing_params(self,
+                                      **acc_specs) -> HypothesisTestingParams:
+        """Get hypothesis testing parameters via accuracy specifications.
 
         Args:
             acc_specs: Accuracy specifications, currently including:
@@ -119,30 +172,14 @@ def get_num_samples_and_threshold(self, **acc_specs):
         else:
             raise ValueError(f"Not registered specs: {acc_specs}.")
 
-        accuracy = entry.get("accuracy")
-        alpha = entry.get("alpha", self.ALPHA)
-        beta = entry.get("beta", self.BETA)
-        sigma = entry.get("sigma", self.SIGMA)
-        num_samples = entry.get("num_samples", self.NUM_SAMPLES)
-        higher_is_better = entry.get("higher_is_better", self.HIGHER_IS_BETTER)
-        theta = compute_theta(num_samples, sigma=sigma, alpha=alpha, beta=beta)
-        threshold = compute_threshold(num_samples,
-                                      accuracy,
-                                      sigma=sigma,
-                                      alpha=alpha,
-                                      higher_is_better=higher_is_better)
-        print("===========================================================\n"
-              "= ACCURACY HYPOTHESIS TESTING\n"
-              "===========================================================\n"
-              f"Alpha (Type I: False Positive): {alpha:.3f}\n"
-              f"Beta (Type II: False Negative): {beta:.3f}\n"
-              f"Sigma (Standard deviation): {sigma:.3f}\n"
-              f"#Samples: {num_samples}\n"
-              f"Theta (Minimum detectable effect): {theta:.3f}\n"
-              f"Reference accuracy: {accuracy:.3f}\n"
-              f"Threshold: {threshold:.3f}\n"
-              "===========================================================\n")
-        return num_samples, threshold
+        return HypothesisTestingParams(
+            ref_accuracy=entry.get("accuracy"),
+            alpha=entry.get("alpha", self.ALPHA),
+            beta=entry.get("beta", self.BETA),
+            sigma=entry.get("sigma", self.SIGMA),
+            num_samples=entry.get("num_samples", self.NUM_SAMPLES),
+            higher_is_better=entry.get("higher_is_better",
+                                       self.HIGHER_IS_BETTER))
 
     def evaluate(self,
                  llm: Union[LLM, PyTorchLLM, AutoDeployLLM],
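
To make the refactor concrete, a small usage sketch of the class added above; the values are invented for illustration, whereas in the test suite they come from the YAML reference entry selected by get_hypothesis_testing_params:

# Hypothetical values for illustration only.
params = HypothesisTestingParams(ref_accuracy=68.17, num_samples=512)
print(params.report())        # full hypothesis-testing table, no measurement yet
params.assert_passing(67.5)   # passes: 67.5 >= threshold (~64.5 with defaults)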
@@ -165,13 +202,13 @@ def evaluate(self,
         is_integration_test = os.getenv('INTEGRATION_TEST', '0') == '1'
 
         if is_integration_test:
-            num_samples = 1
             logger.info(
                 "Running in INTEGRATION_TEST mode: using only 1 sample and skipping accuracy verification"
             )
-            threshold = 0
+            hypothesis_testing_params = HypothesisTestingParams(ref_accuracy=0,
+                                                                num_samples=1)
         else:
-            num_samples, threshold = self.get_num_samples_and_threshold(
+            hypothesis_testing_params = self.get_hypothesis_testing_params(
                 dtype=llm.args.dtype,
                 quant_algo=llm.args.quant_config.quant_algo,
                 kv_cache_quant_algo=llm.args.quant_config.kv_cache_quant_algo,
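
A note on the INTEGRATION_TEST branch above: assuming the threshold formula visible earlier (ref_accuracy - z_alpha * scale), the placeholder HypothesisTestingParams(ref_accuracy=0, num_samples=1) yields a strongly negative threshold, so any real benchmark score passes and accuracy verification is effectively skipped, as the log message says:

from statistics import NormalDist

# Sketch: with ref_accuracy=0, num_samples=1 and the class defaults
# alpha=0.05, sigma=50, the cutoff is 0 - z_0.95 * 50 / sqrt(1) ≈ -82.2,
# so any non-negative evaluated accuracy satisfies accuracy >= threshold.
print(0 - NormalDist().inv_cdf(0.95) * 50.0)  # ≈ -82.24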
@@ -193,17 +230,19 @@ def evaluate(self,
         evaluator_kwargs.update(self.EVALUATOR_KWARGS)
         if extra_evaluator_kwargs is not None:
             evaluator_kwargs.update(extra_evaluator_kwargs)
-        evaluator = self.EVALUATOR_CLS(num_samples=num_samples,
-                                       **evaluator_kwargs)
+        evaluator = self.EVALUATOR_CLS(
+            num_samples=hypothesis_testing_params.num_samples,
+            **evaluator_kwargs)
         evaluate_kwargs = {}
         if hasattr(self, 'EVALUATE_KWARGS'):
             evaluate_kwargs.update(self.EVALUATE_KWARGS)
         accuracy = evaluator.evaluate(llm, sampling_params, streaming,
                                       **evaluate_kwargs)
-        if self.HIGHER_IS_BETTER:
-            assert accuracy >= threshold, f"Expected accuracy >= {threshold}, but got {accuracy}."
-        else:
-            assert accuracy <= threshold, f"Expected accuracy <= {threshold}, but got {accuracy}."
+
+        logger.info(
+            f"Hypothesis testing report:\n{hypothesis_testing_params.report(accuracy)}"
+        )
+        hypothesis_testing_params.assert_passing(accuracy)
 
 
 class CnnDailymail(AccuracyTask):
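
This hunk is the heart of the commit's stated goal: instead of a bare "Expected accuracy >= {threshold}" assertion, the failure message now embeds the full hypothesis-testing report. A hedged demonstration of the failure path, reusing the hypothetical parameters from the earlier sketch:

# Hypothetical values; assert_passing raises with the full report embedded.
params = HypothesisTestingParams(ref_accuracy=68.17, num_samples=512)
try:
    params.assert_passing(60.0)  # well below the ~64.5 threshold
except AssertionError as e:
    print(e)  # message includes report(60.0) with the evaluated accuracy line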
@@ -457,7 +496,7 @@ def initialize_case(self,
         self.env = env
 
     def convert(self):
-        print("Converting model to TensorRT LLM checkpoint...")
+        logger.info("Converting model to TensorRT LLM checkpoint...")
 
         is_prequantized = False
         for quant_config_file in [
@@ -559,7 +598,7 @@ def convert(self):
         venv_check_call(self.llm_venv, convert_cmd)
 
     def build(self):
-        print("Building engines...")
+        logger.info("Building engines...")
         max_batch_size = max(task.MAX_BATCH_SIZE for task in self.tasks)
         max_input_len = max(task.MAX_INPUT_LEN for task in self.tasks)
         max_seq_len = max(task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN
@@ -578,7 +617,7 @@ def build(self):
         check_call(" ".join(build_cmd), shell=True, env=self.llm_venv._new_env)
 
     def summarize(self, task: AccuracyTask):
-        print("Running summarize...")
+        logger.info("Running summarize...")
         summarize_cmd = [
             f"{self.llm_root}/examples/summarize.py",
             f"--engine_dir={self.engine_dir}",
@@ -595,12 +634,16 @@ def summarize(self, task: AccuracyTask):
             "--no_add_special_tokens"
         ])
 
-        num_samples, threshold = task.get_num_samples_and_threshold(
+        hypothesis_testing_params = task.get_hypothesis_testing_params(
             dtype=self.dtype,
             quant_algo=self.quant_algo,
             kv_cache_quant_algo=self.kv_cache_quant_algo,
             spec_dec_algo=self.spec_dec_algo,
             extra_acc_spec=self.extra_acc_spec)
+        logger.info(
+            f"Hypothesis testing report:\n{hypothesis_testing_params.report()}")
+        num_samples = hypothesis_testing_params.num_samples
+        threshold = hypothesis_testing_params.threshold
 
         if num_samples < task.MAX_BATCH_SIZE:
             max_ite = 1
@@ -642,13 +685,17 @@
                 str(world_size), "--allow-run-as-root"], summarize_cmd)
 
     def mmlu(self, task: AccuracyTask):
-        print("Running mmlu...")
-        num_samples, threshold = task.get_num_samples_and_threshold(
+        logger.info("Running mmlu...")
+        hypothesis_testing_params = task.get_hypothesis_testing_params(
             dtype=self.dtype,
             quant_algo=self.quant_algo,
             kv_cache_quant_algo=self.kv_cache_quant_algo,
             spec_dec_algo=self.spec_dec_algo,
             extra_acc_spec=self.extra_acc_spec)
+        logger.info(
+            f"Hypothesis testing report:\n{hypothesis_testing_params.report()}")
+        num_samples = hypothesis_testing_params.num_samples
+        threshold = hypothesis_testing_params.threshold
 
         mmlu_cmd = [
             "trtllm-eval",
@@ -669,27 +716,31 @@
         check_call(" ".join(mmlu_cmd), shell=True, env=self.llm_venv._new_env)
 
     def eval_long_context(self, task: AccuracyTask):
-        print("Running construct_synthetic_dataset...")
+        logger.info("Running construct_synthetic_dataset...")
         data_gen_cmd = [
             f"{self.llm_root}/examples/infinitebench/construct_synthetic_dataset.py",
             "--test_case=build_passkey", f"--test_level={task.LEVEL}"
         ]
         venv_check_call(self.llm_venv, data_gen_cmd)
 
-        print("Running eval_long_context...")
+        logger.info("Running eval_long_context...")
         eval_cmd = [
             f"{self.llm_root}/examples/eval_long_context.py", "--task=passkey",
             f"--engine_dir={self.engine_dir}",
             f"--tokenizer_dir={self.MODEL_PATH}",
             f"--max_input_length={task.MAX_INPUT_LEN}",
             "--enable_chunked_context"
         ]
-        num_samples, threshold = task.get_num_samples_and_threshold(
+        hypothesis_testing_params = task.get_hypothesis_testing_params(
             dtype=self.dtype,
             quant_algo=self.quant_algo,
             kv_cache_quant_algo=self.kv_cache_quant_algo,
             spec_dec_algo=self.spec_dec_algo,
             extra_acc_spec=self.extra_acc_spec)
+        logger.info(
+            f"Hypothesis testing report:\n{hypothesis_testing_params.report()}")
+        num_samples = hypothesis_testing_params.num_samples
+        threshold = hypothesis_testing_params.threshold
 
         batch_size = min(task.MAX_BATCH_SIZE, num_samples)
         eval_cmd.extend([
