|
1 | 1 | import datetime
|
2 | 2 | import os
|
3 |
| -from unittest.mock import MagicMock |
| 3 | +from unittest.mock import MagicMock, Mock |
4 | 4 |
|
5 | 5 | import pytest
|
6 | 6 | from mistralai.models import (
|
@@ -1320,6 +1320,109 @@ def _return_invalid4_toxicity_response(*args, **kwargs):
|
1320 | 1320 | client.toxicity("some text")
|
1321 | 1321 |
|
1322 | 1322 |
|
def test_LLMClient_retries(monkeypatch):
    """
    Test the retry functionality for structuring LLM API calls.
    """

    def _valid_call(*args, **kwargs):
        # Well-formed response: a bare integer score.
        return "5"

    def _invalid_call(*args, **kwargs):
        # Malformed response: the score is buried in prose.
        return "The score is 5."

    # Three bad responses followed by one good one; only a caller willing
    # to retry at least three times ever sees the valid answer.
    bad_then_good = ["The score is 5."] * 3 + ["5"]

    monkeypatch.setattr(
        "valor_api.backend.core.llm_clients.LLMClient.__call__",
        _valid_call,
    )

    # With a valid first response, every retry setting succeeds immediately.
    for retry_count in (None, 0, 3):
        client = LLMClient(
            api_key=None, model_name="model_name", retries=retry_count
        )
        assert 5 == client.summary_coherence("some text", "some summary")

    # Three failures then a success: retries=3 (four attempts) recovers.
    monkeypatch.setattr(
        "valor_api.backend.core.llm_clients.LLMClient.__call__",
        Mock(side_effect=bad_then_good),
    )
    client = LLMClient(api_key=None, model_name="model_name", retries=3)
    assert 5 == client.summary_coherence("some text", "some summary")

    # retries=2 allows only three attempts, all of which fail here.
    monkeypatch.setattr(
        "valor_api.backend.core.llm_clients.LLMClient.__call__",
        Mock(side_effect=bad_then_good),
    )
    client = LLMClient(api_key=None, model_name="model_name", retries=2)
    with pytest.raises(InvalidLLMResponseError):
        client.summary_coherence("some text", "some summary")

    # A permanently invalid response fails regardless of the retry budget.
    monkeypatch.setattr(
        "valor_api.backend.core.llm_clients.LLMClient.__call__",
        _invalid_call,
    )
    for retry_count in (None, 3):
        client = LLMClient(
            api_key=None, model_name="model_name", retries=retry_count
        )
        with pytest.raises(InvalidLLMResponseError):
            client.summary_coherence("some text", "some summary")

    # The wrapped clients inherit the same retry behavior.
    wrapped_clients = [
        (
            WrappedOpenAIClient,
            "valor_api.backend.core.llm_clients.WrappedOpenAIClient.__call__",
        ),
        (
            WrappedMistralAIClient,
            "valor_api.backend.core.llm_clients.WrappedMistralAIClient.__call__",
        ),
    ]
    for client_cls, patch_target in wrapped_clients:
        # retries=3 survives three bad responses and reads the fourth.
        monkeypatch.setattr(patch_target, Mock(side_effect=bad_then_good))
        client = client_cls(
            api_key=None, model_name="model_name", retries=3
        )
        assert 5 == client.summary_coherence("some text", "some summary")

        # retries=2 exhausts its attempts on bad responses and raises.
        monkeypatch.setattr(patch_target, Mock(side_effect=bad_then_good))
        client = client_cls(
            api_key=None, model_name="model_name", retries=2
        )
        with pytest.raises(InvalidLLMResponseError):
            client.summary_coherence("some text", "some summary")
| 1425 | + |
1323 | 1426 | def test_WrappedOpenAIClient():
|
1324 | 1427 | def _create_bad_request(model, messages, seed) -> ChatCompletion:
|
1325 | 1428 | raise ValueError
|
|
0 commit comments