|
1 | 1 | import contextvars
|
2 | 2 | import json
|
| 3 | +import os |
3 | 4 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
4 | 5 |
|
5 |
| -from guardrails.utils.casting_utils import to_int |
| 6 | +from dotenv import load_dotenv |
6 | 7 | from guardrails.validator_base import (
|
7 | 8 | FailResult,
|
8 | 9 | PassResult,
|
@@ -225,23 +226,26 @@ def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]:
|
225 | 226 | found_topics.append(llm_result["name"])
|
226 | 227 | return found_topics
|
227 | 228 |
|
228 |
| - def get_client_args(self) -> Tuple[Optional[str], Optional[str]]: |
| 229 | + def get_client_args(self) -> str: |
229 | 230 | """Returns neccessary data for api calls.
|
230 | 231 |
|
231 | 232 | Returns:
|
232 |
| - Tuple[Optional[str], Optional[str]]: api key and api base values |
| 233 | + str: api key |
233 | 234 | """
|
234 |
| - kwargs = {} |
235 |
| - context_copy = contextvars.copy_context() |
236 |
| - for key, context_var in context_copy.items(): |
237 |
| - if key.name == "kwargs" and isinstance(kwargs, dict): |
238 |
| - kwargs = context_var |
239 |
| - break |
240 | 235 |
|
241 |
| - api_key = kwargs.get("api_key") |
242 |
| - api_base = kwargs.get("api_base") |
| 236 | + load_dotenv() |
| 237 | + api_key = os.getenv("OPENAI_API_KEY") |
| 238 | + if not api_key: |
| 239 | + kwargs = {} |
| 240 | + context_copy = contextvars.copy_context() |
| 241 | + for key, context_var in context_copy.items(): |
| 242 | + if key.name == "kwargs" and isinstance(kwargs, dict): |
| 243 | + kwargs = context_var |
| 244 | + break |
243 | 245 |
|
244 |
| - return (api_key, api_base) |
| 246 | + api_key = kwargs.get("api_key") |
| 247 | + |
| 248 | + return api_key |
245 | 249 |
|
246 | 250 | @retry(
|
247 | 251 | wait=wait_random_exponential(min=1, max=60),
|
@@ -280,8 +284,8 @@ def set_callable(self, llm_callable: Union[str, Callable, None]) -> None:
|
280 | 284 | )
|
281 | 285 |
|
282 | 286 | def openai_callable(text: str) -> str:
|
283 |
| - api_key, api_base = self.get_client_args() |
284 |
| - client = OpenAI() |
| 287 | + api_key = self.get_client_args() |
| 288 | + client = OpenAI(api_key=api_key) |
285 | 289 | response = client.chat.completions.create(
|
286 | 290 | model=llm_callable,
|
287 | 291 | response_format={"type": "json_object"},
|
|
0 commit comments