
Commit 5d172c8

[TEST] fix integration tests despite random max_samples
1 parent bd5513f commit 5d172c8

3 files changed: +32 -22 lines


Diff for: src/alpaca_eval/decoders/huggingface_api.py (+21 -20)
@@ -51,7 +51,7 @@ def huggingface_api_completions(
         token=constants.HUGGINGFACEHUB_API_TOKEN,
     )

-    default_kwargs = dict(do_sample=do_sample, options=dict(wait_for_model=True), return_full_text=False)
+    default_kwargs = dict(do_sample=do_sample, return_full_text=False)
     default_kwargs.update(kwargs)
     logging.info(f"Kwargs to completion: {default_kwargs}")

@@ -72,32 +72,33 @@ def huggingface_api_completions(
     )
     logging.info(f"Time for {n_examples} completions: {t}")

-    completions = [completion["generated_text"] for completion in completions]
-
     # unclear pricing
     price = [np.nan] * len(completions)
     avg_time = [t.duration / n_examples] * len(completions)

     return dict(completions=completions, price_per_example=price, time_per_example=avg_time)


-def inference_helper(prompt: str, inference, params, n_retries=100, waiting_time=2) -> dict:
+def inference_helper(prompt: str, inference, params, n_retries=100, waiting_time=2) -> str:
     for _ in range(n_retries):
-        output = inference(inputs=prompt, params=params)
-        if "error" in output and n_retries > 0:
-            error = output["error"]
-            if "Rate limit reached" in output["error"]:
-                logging.warning(f"Rate limit reached... Trying again in {waiting_time} seconds. Full error: {error}")
-                time.sleep(waiting_time)
-            elif "Input validation error" in error and "max_new_tokens" in error:
-                params["max_new_tokens"] = int(params["max_new_tokens"] * 0.8)
-                logging.warning(
-                    f"`max_new_tokens` too large. Reducing target length to {params['max_new_tokens']}, " f"Retrying..."
-                )
-                if params["max_new_tokens"] == 0:
+        try:
+            # TODO: check why doesn't stop after </s>
+            output = inference(prompt=prompt, **params)
+        except Exception as error:
+            if n_retries > 0:
+                if "Rate limit reached" in error:
+                    logging.warning(f"Rate limit reached... Trying again in {waiting_time} seconds.")
+                    time.sleep(waiting_time)
+                elif "Input validation error" in error and "max_new_tokens" in error:
+                    params["max_new_tokens"] = int(params["max_new_tokens"] * 0.8)
+                    logging.warning(
+                        f"`max_new_tokens` too large. Reducing target length to {params['max_new_tokens']}, "
+                        f"Retrying..."
+                    )
+                    if params["max_new_tokens"] == 0:
+                        raise ValueError(f"Error in inference. Full error: {error}")
+                else:
                     raise ValueError(f"Error in inference. Full error: {error}")
             else:
-                raise ValueError(f"Error in inference. Full error: {error}")
-        else:
-            return output[0]
-    raise ValueError(f"Error in inference. We tried {n_retries} times and failed.")
+                raise ValueError(f"Error in inference. We tried {n_retries} times and failed. Full error: {error}")
+        return output
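For context, the reworked retry loop sleeps on rate-limit errors and shrinks `max_new_tokens` by 20% whenever the backend rejects the request for being too long. Below is a minimal, self-contained sketch of that pattern; `retry_inference` and `fake_inference` are hypothetical stand-ins (not the real huggingface_hub client or the alpaca_eval helper), used only so the snippet runs offline.

import time


def retry_inference(prompt, inference, params, n_retries=5, waiting_time=2):
    """Retry loop sketch: sleep on rate limits, shrink max_new_tokens on length errors."""
    for _ in range(n_retries):
        try:
            return inference(prompt=prompt, **params)
        except Exception as error:
            message = str(error)
            if "Rate limit reached" in message:
                time.sleep(waiting_time)  # wait, then retry with the same params
            elif "Input validation error" in message and "max_new_tokens" in message:
                # Ask for 20% fewer tokens next time; give up once the budget hits zero.
                params["max_new_tokens"] = int(params["max_new_tokens"] * 0.8)
                if params["max_new_tokens"] == 0:
                    raise
            else:
                raise
    raise ValueError(f"Inference failed after {n_retries} retries.")


def fake_inference(prompt, max_new_tokens=100, **_):
    """Hypothetical backend that rejects requests until max_new_tokens is small enough."""
    if max_new_tokens > 64:
        raise RuntimeError("Input validation error: max_new_tokens exceeds limit")
    return f"completion for {prompt!r} ({max_new_tokens} tokens)"


print(retry_inference("Hello", fake_inference, {"max_new_tokens": 100}))

With a starting budget of 100 tokens, the fake backend rejects the first two attempts, the budget drops to 80 and then 64, and the third attempt succeeds.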

Diff for: tests/integration_tests/test_example_integration.py (+10 -1)
@@ -1,10 +1,15 @@
+import os
 import subprocess

 import pytest


+# example file is from 003 so should always lose against gpt4 turbo
 @pytest.mark.slow
 def test_cli_evaluate_example():
+    env = os.environ.copy()
+    env["IS_ALPACA_EVAL_2"] = "True"
+
     result = subprocess.run(
         [
             "alpaca_eval",
@@ -19,15 +24,18 @@ def test_cli_evaluate_example():
         ],
         capture_output=True,
         text=True,
+        env=env,
     )
     normalized_output = " ".join(result.stdout.split())
-    expected_output = " ".join("example 33.33 33.33 3".split())
+    expected_output = " ".join("example 0.00 0.00 3".split())

     assert expected_output in normalized_output


 @pytest.mark.slow
 def test_openai_fn_evaluate_example():
+    env = os.environ.copy()
+    env["IS_ALPACA_EVAL_2"] = "True"
     result = subprocess.run(
         [
             "alpaca_eval",
@@ -42,6 +50,7 @@ def test_openai_fn_evaluate_example():
         ],
         capture_output=True,
         text=True,
+        env=env,
     )
     normalized_output = " ".join(result.stdout.split())
     expected_output = " ".join("example 0.00 0.00 2".split())
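The fix hands the CLI subprocess a copy of the environment with `IS_ALPACA_EVAL_2` pinned to `"True"`, so the expected output no longer depends on whatever the caller's shell happened to export. A minimal, self-contained sketch of that env-forwarding pattern follows; it spawns a throwaway Python child process instead of the alpaca_eval CLI so it runs anywhere.

import os
import subprocess
import sys

# Copy the parent environment and pin the flag so the child process sees a
# deterministic value, regardless of what the caller's shell had set.
env = os.environ.copy()
env["IS_ALPACA_EVAL_2"] = "True"

result = subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['IS_ALPACA_EVAL_2'])"],
    capture_output=True,
    text=True,
    env=env,
)
assert result.stdout.strip() == "True"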

Diff for: tests/test_decoders_unit.py (+1 -1)
@@ -57,7 +57,7 @@ def test_cohere_completions(mocker):
 def test_huggingface_api_completions(mocker):
     mocker.patch(
         "alpaca_eval.decoders.huggingface_api.inference_helper",
-        return_value=dict(generated_text="Mocked completion text"),
+        return_value="Mocked completion text",
     )
     result = huggingface_api_completions(
         ["Prompt 1", "Prompt 2"],
