Skip to content

Commit a9ab7e7

Browse files
bnativi and b.nativi authored
Add Retries for LLM-Guided Metrics (#728)
Co-authored-by: b.nativi <[email protected]>
1 parent 3c56069 commit a9ab7e7

File tree

7 files changed

+419
-44
lines changed

7 files changed

+419
-44
lines changed

api/tests/functional-tests/backend/core/test_llm_clients.py

+104-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime
22
import os
3-
from unittest.mock import MagicMock
3+
from unittest.mock import MagicMock, Mock
44

55
import pytest
66
from mistralai.models import (
@@ -1320,6 +1320,109 @@ def _return_invalid4_toxicity_response(*args, **kwargs):
13201320
client.toxicity("some text")
13211321

13221322

1323+
def test_LLMClient_retries(monkeypatch):
1324+
"""
1325+
Test the retry functionality for structuring LLM API calls.
1326+
"""
1327+
1328+
def _return_valid_summary_coherence_response(*args, **kwargs):
1329+
return "5"
1330+
1331+
errors = ["The score is 5."] * 3 + ["5"]
1332+
1333+
def _return_invalid_summary_coherence_response(*args, **kwargs):
1334+
return "The score is 5."
1335+
1336+
monkeypatch.setattr(
1337+
"valor_api.backend.core.llm_clients.LLMClient.__call__",
1338+
_return_valid_summary_coherence_response,
1339+
)
1340+
1341+
# Test with retries=None
1342+
client = LLMClient(api_key=None, model_name="model_name", retries=None)
1343+
assert 5 == client.summary_coherence("some text", "some summary")
1344+
1345+
# Test with retries=0
1346+
client = LLMClient(api_key=None, model_name="model_name", retries=0)
1347+
assert 5 == client.summary_coherence("some text", "some summary")
1348+
1349+
# Test with retries=3 and valid response
1350+
client = LLMClient(api_key=None, model_name="model_name", retries=3)
1351+
assert 5 == client.summary_coherence("some text", "some summary")
1352+
1353+
# mock_method returns a bad response three times but on the fourth call returns a valid response.
1354+
monkeypatch.setattr(
1355+
"valor_api.backend.core.llm_clients.LLMClient.__call__",
1356+
Mock(side_effect=errors),
1357+
)
1358+
client = LLMClient(api_key=None, model_name="model_name", retries=3)
1359+
assert 5 == client.summary_coherence("some text", "some summary")
1360+
1361+
# Test with retries=2 and invalid response
1362+
monkeypatch.setattr(
1363+
"valor_api.backend.core.llm_clients.LLMClient.__call__",
1364+
Mock(side_effect=errors),
1365+
)
1366+
with pytest.raises(InvalidLLMResponseError):
1367+
client = LLMClient(api_key=None, model_name="model_name", retries=2)
1368+
client.summary_coherence("some text", "some summary")
1369+
1370+
monkeypatch.setattr(
1371+
"valor_api.backend.core.llm_clients.LLMClient.__call__",
1372+
_return_invalid_summary_coherence_response,
1373+
)
1374+
1375+
# Test with retries=None and invalid response
1376+
with pytest.raises(InvalidLLMResponseError):
1377+
client = LLMClient(api_key=None, model_name="model_name", retries=None)
1378+
client.summary_coherence("some text", "some summary")
1379+
1380+
# Test with retries=3 and invalid response
1381+
with pytest.raises(InvalidLLMResponseError):
1382+
client = LLMClient(api_key=None, model_name="model_name", retries=3)
1383+
client.summary_coherence("some text", "some summary")
1384+
1385+
# Test WrappedOpenAIClient
1386+
monkeypatch.setattr(
1387+
"valor_api.backend.core.llm_clients.WrappedOpenAIClient.__call__",
1388+
Mock(side_effect=errors),
1389+
)
1390+
client = WrappedOpenAIClient(
1391+
api_key=None, model_name="model_name", retries=3
1392+
)
1393+
assert 5 == client.summary_coherence("some text", "some summary")
1394+
1395+
with pytest.raises(InvalidLLMResponseError):
1396+
monkeypatch.setattr(
1397+
"valor_api.backend.core.llm_clients.WrappedOpenAIClient.__call__",
1398+
Mock(side_effect=errors),
1399+
)
1400+
client = WrappedOpenAIClient(
1401+
api_key=None, model_name="model_name", retries=2
1402+
)
1403+
client.summary_coherence("some text", "some summary")
1404+
1405+
# Test WrappedMistralAIClient
1406+
monkeypatch.setattr(
1407+
"valor_api.backend.core.llm_clients.WrappedMistralAIClient.__call__",
1408+
Mock(side_effect=errors),
1409+
)
1410+
client = WrappedMistralAIClient(
1411+
api_key=None, model_name="model_name", retries=3
1412+
)
1413+
assert 5 == client.summary_coherence("some text", "some summary")
1414+
1415+
with pytest.raises(InvalidLLMResponseError):
1416+
monkeypatch.setattr(
1417+
"valor_api.backend.core.llm_clients.WrappedMistralAIClient.__call__",
1418+
Mock(side_effect=errors),
1419+
)
1420+
client = WrappedMistralAIClient(
1421+
api_key=None, model_name="model_name", retries=2
1422+
)
1423+
client.summary_coherence("some text", "some summary")
1424+
1425+
13231426
def test_WrappedOpenAIClient():
13241427
def _create_bad_request(model, messages, seed) -> ChatCompletion:
13251428
raise ValueError

0 commit comments

Comments
 (0)