diff --git a/portkey_ai/__init__.py b/portkey_ai/__init__.py index 10479a8b..52fd9344 100644 --- a/portkey_ai/__init__.py +++ b/portkey_ai/__init__.py @@ -164,6 +164,13 @@ PORTKEY_PROXY_ENV, PORTKEY_GATEWAY_URL, ) +from portkey_ai.utils.json_utils import ( + PortkeyJSONEncoder, + enable_notgiven_serialization, + disable_notgiven_serialization, +) + +enable_notgiven_serialization() api_key = os.environ.get(PORTKEY_API_KEY_ENV) base_url = os.environ.get(PORTKEY_PROXY_ENV, PORTKEY_BASE_URL) @@ -175,6 +182,9 @@ "LLMOptions", "Modes", "PortkeyResponse", + "PortkeyJSONEncoder", + "enable_notgiven_serialization", + "disable_notgiven_serialization", "ModesLiteral", "ProviderTypes", "ProviderTypesLiteral", diff --git a/portkey_ai/utils/json_utils.py b/portkey_ai/utils/json_utils.py index 6b016caf..a302dcbc 100644 --- a/portkey_ai/utils/json_utils.py +++ b/portkey_ai/utils/json_utils.py @@ -1,33 +1,64 @@ import json +from portkey_ai._vendor.openai._utils import is_given -def serialize_kwargs(**kwargs): - # Function to check if a value is serializable - def is_serializable(value): - try: - json.dumps(value) - return True - except (TypeError, ValueError): - return False +_BASE_JSON_DEFAULT = json.JSONEncoder.default +_patched_notgiven_serialization = False + + +class PortkeyJSONEncoder(json.JSONEncoder): + """JSON encoder that treats OpenAI/Portkey "not provided" markers as null.""" + + def default(self, obj): # type: ignore[override] + # If this is one of OpenAI's internal "not provided" / omit markers, + # encode it as None (null in JSON) instead of raising TypeError. + if not is_given(obj): + return None + return super().default(obj) + + +def enable_notgiven_serialization() -> None: + """Globally encode NotGiven / Omit markers as null in json.dumps.""" + global _patched_notgiven_serialization + if _patched_notgiven_serialization: + return + + def patched_default(self, obj): # type: ignore[override] + if not is_given(obj): + return None + return _BASE_JSON_DEFAULT(self, obj) - # Filter out non-serializable items - serializable_kwargs = {k: v for k, v in kwargs.items() if is_serializable(v)} + json.JSONEncoder.default = patched_default + _patched_notgiven_serialization = True - # Convert to string representation - return json.dumps(serializable_kwargs) + +def disable_notgiven_serialization() -> None: + """Restore the original json.JSONEncoder.default implementation.""" + global _patched_notgiven_serialization + if not _patched_notgiven_serialization: + return + + json.JSONEncoder.default = _BASE_JSON_DEFAULT + _patched_notgiven_serialization = False + + +def _is_serializable(value) -> bool: + try: + json.dumps(value, cls=PortkeyJSONEncoder) + except (TypeError, ValueError): + return False + return True + + +def serialize_kwargs(**kwargs): + return json.dumps( + {k: v for k, v in kwargs.items() if _is_serializable(v)}, + cls=PortkeyJSONEncoder, + ) def serialize_args(*args): - # Function to check if a value is serializable - def is_serializable(value): - try: - json.dumps(value) - return True - except (TypeError, ValueError): - return False - - # Filter out non-serializable items - serializable_args = [arg for arg in args if is_serializable(arg)] - - # Convert to string representation - return json.dumps(serializable_args) + return json.dumps( + [arg for arg in args if _is_serializable(arg)], + cls=PortkeyJSONEncoder, + )