Commit 30d94f5

[BUG] tool_calls (#429)
1 parent 474386e commit 30d94f5

2 files changed: +14 -8 lines

Diff for: .gitignore (+1)
@@ -8,6 +8,7 @@ results
 ./*.json
 client_configs/*.yaml
 old_results
+results_evaluators
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

Diff for: src/alpaca_eval/decoders/openai.py (+13 -8)
@@ -132,10 +132,10 @@ def openai_completions(
 
     prompt_batches = [prompts[batch_id * batch_size : (batch_id + 1) * batch_size] for batch_id in range(n_batches)]
 
-    if isinstance(max_tokens, int):
-        max_tokens = [max_tokens] * n_examples
-
-    inputs = zip(prompt_batches, max_tokens)
+    try:
+        inputs = zip(prompt_batches, max_tokens)
+    except TypeError:
+        inputs = zip(prompt_batches, [max_tokens] * n_batches)
 
     kwargs = dict(model=model_name, **decoding_kwargs)
     kwargs_to_log = {k: v for k, v in kwargs.items() if "api_key" not in k}
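
The old code only broadcast max_tokens when it was exactly an int; the new version instead lets zip() raise TypeError for any non-iterable and broadcasts the scalar in the fallback, once per batch (n_batches, matching the length of prompt_batches) rather than once per example. A minimal standalone sketch of that behavior, with illustrative names that are not from the repository:

```python
# Toy reproduction of the scalar-vs-sequence handling of max_tokens above.
prompt_batches = [["p1", "p2"], ["p3", "p4"]]
n_batches = len(prompt_batches)

def pair_batches_with_max_tokens(prompt_batches, max_tokens):
    try:
        # max_tokens is already a sequence: one value per batch
        return list(zip(prompt_batches, max_tokens))
    except TypeError:
        # max_tokens is a scalar (zip() rejects non-iterables): broadcast it per batch
        return list(zip(prompt_batches, [max_tokens] * n_batches))

print(pair_batches_with_max_tokens(prompt_batches, 512))
# [(['p1', 'p2'], 512), (['p3', 'p4'], 512)]
print(pair_batches_with_max_tokens(prompt_batches, [256, 512]))
# [(['p1', 'p2'], 256), (['p3', 'p4'], 512)]
```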
@@ -216,7 +216,11 @@ def _openai_completion_helper(
     client = all_clients[curr_client_idx]
 
     # copy shared_kwargs to avoid modifying it
-    kwargs.update(dict(max_tokens=max_tokens, top_p=top_p, temperature=temperature))
+    to_update = dict()
+    for k in ["max_tokens", "top_p", "temperature"]:
+        if locals()[k] is not None:
+            to_update[k] = locals()[k]
+    kwargs.update(to_update)
     curr_kwargs = copy.deepcopy(kwargs)
 
     # ensure no infinite loop
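
The effect of the new update loop is that max_tokens, top_p, and temperature are only copied into the request kwargs when they were actually set; values left as None no longer overwrite whatever is already in the shared kwargs. A standalone sketch of the pattern, using a toy function rather than the repository helper:

```python
def build_kwargs(max_tokens=None, top_p=None, temperature=None, **shared_kwargs):
    kwargs = dict(shared_kwargs)
    to_update = dict()
    # locals() exposes the parameters by name; only copy the ones that were set
    for k in ["max_tokens", "top_p", "temperature"]:
        if locals()[k] is not None:
            to_update[k] = locals()[k]
    kwargs.update(to_update)
    return kwargs

print(build_kwargs(temperature=0.7, model="toy-model"))
# {'model': 'toy-model', 'temperature': 0.7} -- no max_tokens/top_p keys forced to None
```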
@@ -242,7 +246,7 @@ def _openai_completion_helper(
             # currently we only use function calls to get a JSON object => return raw text of json
             choices[i]["text"] = choice.message.function_call.arguments
 
-        if choice.message.tool_calls is not None:
+        if choice.message.tool_calls:
             # currently we only use function calls to get a JSON object => return raw text of json
             choices[i]["text"] = choice.message.tool_calls[0].function.arguments
 
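
Switching from an "is not None" test to a plain truthiness check matters if the client ever returns an empty tool_calls list: an empty list passes "is not None", so the following [0] index would raise IndexError, while the truthy check skips both None and []. A toy illustration with plain dicts standing in for the SDK objects (the JSON payload is made up):

```python
sample_tool_call = {"function": {"arguments": '{"answer": "A"}'}}

for tool_calls in (None, [], [sample_tool_call]):
    if tool_calls:  # skips None *and* the empty list
        print(tool_calls[0]["function"]["arguments"])
    # an `is not None` check would index tool_calls[0] for [] and raise IndexError
```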
@@ -273,10 +277,11 @@ def _openai_completion_helper(
             return choices
 
         else:
-            if "rate limit" in str(e).lower():
+            if "rate " in str(e).lower():
                 logging.warning(f"Hit request rate limit; retrying...")
             else:
-                logging.warning(f"Unknown error. \n It's likely a rate limit so we are retrying...")
+                logging.exception("Unknown error:")
+                raise e
             if len(all_clients) > 1:
                 curr_client_idx = random.choice([idx for idx in client_idcs if idx != curr_client_idx])
                 client = all_clients[curr_client_idx]
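
The retry branch now matches any error message containing "rate " instead of only the exact phrase "rate limit", and every other exception is logged with its traceback via logging.exception and re-raised rather than being retried indefinitely. A rough sketch of the resulting control flow, as a toy retry loop rather than the repository code:

```python
import logging
import time

def call_with_retries(make_request, max_attempts=5):
    for attempt in range(max_attempts):
        try:
            return make_request()
        except Exception as e:  # broad catch, mirroring the helper's structure
            if "rate " in str(e).lower():
                logging.warning("Hit request rate limit; retrying...")
                time.sleep(2 ** attempt)  # simple backoff for the sketch
            else:
                logging.exception("Unknown error:")  # full traceback in the logs
                raise  # surface real bugs instead of retrying them
    raise RuntimeError("exceeded retry budget")
```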
