Skip to content

Commit 8f49489

Browse files
authored
Merge pull request #4 from guardrails-ai/fix_api_key
adding load dot env to args to correctly set api key
2 parents d3b315a + 3dc53cf commit 8f49489

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

validator/main.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import contextvars
22
import json
3+
import os
34
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
45

5-
from guardrails.utils.casting_utils import to_int
6+
from dotenv import load_dotenv
67
from guardrails.validator_base import (
78
FailResult,
89
PassResult,
@@ -225,23 +226,26 @@ def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]:
225226
found_topics.append(llm_result["name"])
226227
return found_topics
227228

228-
def get_client_args(self) -> Tuple[Optional[str], Optional[str]]:
229+
def get_client_args(self) -> str:
229230
"""Returns neccessary data for api calls.
230231
231232
Returns:
232-
Tuple[Optional[str], Optional[str]]: api key and api base values
233+
str: api key
233234
"""
234-
kwargs = {}
235-
context_copy = contextvars.copy_context()
236-
for key, context_var in context_copy.items():
237-
if key.name == "kwargs" and isinstance(kwargs, dict):
238-
kwargs = context_var
239-
break
240235

241-
api_key = kwargs.get("api_key")
242-
api_base = kwargs.get("api_base")
236+
load_dotenv()
237+
api_key = os.getenv("OPENAI_API_KEY")
238+
if not api_key:
239+
kwargs = {}
240+
context_copy = contextvars.copy_context()
241+
for key, context_var in context_copy.items():
242+
if key.name == "kwargs" and isinstance(kwargs, dict):
243+
kwargs = context_var
244+
break
243245

244-
return (api_key, api_base)
246+
api_key = kwargs.get("api_key")
247+
248+
return api_key
245249

246250
@retry(
247251
wait=wait_random_exponential(min=1, max=60),
@@ -280,8 +284,8 @@ def set_callable(self, llm_callable: Union[str, Callable, None]) -> None:
280284
)
281285

282286
def openai_callable(text: str) -> str:
283-
api_key, api_base = self.get_client_args()
284-
client = OpenAI()
287+
api_key = self.get_client_args()
288+
client = OpenAI(api_key=api_key)
285289
response = client.chat.completions.create(
286290
model=llm_callable,
287291
response_format={"type": "json_object"},

0 commit comments

Comments
 (0)