Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 54 additions & 37 deletions src/f5_ai_gateway_sdk/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
LICENSE file in the root directory of this source tree.
"""

import inspect
import json
import logging
from abc import ABC
from io import TextIOWrapper, StringIO
from json import JSONDecodeError
from typing import Generic, Any, TypeVar
from collections.abc import Callable, Mapping
from collections.abc import Awaitable, Callable, Mapping
import warnings

from pydantic import JsonValue, ValidationError
Expand Down Expand Up @@ -225,6 +226,15 @@ def __init_subclass__(cls, **kwargs):
"The DEPRECATED 'process' method must not be implemented "
"alongside 'process_input' or 'process_response'."
)
if is_process_overridden and inspect.iscoroutinefunction(
inspect.unwrap(cls.process)
):
# we don't want to add async capabilities to the deprecated function
raise TypeError(
f"Cannot create concrete class {cls.__name__}. "
"The DEPRECATED 'process' method does not support async. "
"Implement 'process_input' and/or 'process_response' instead."
)

return

Expand Down Expand Up @@ -875,15 +885,18 @@ async def _parse_and_process(self, request: Request) -> Response:
prompt_hash, response_hash = (None, None)
if input_direction:
prompt_hash = prompt.hash()
result: Result | Reject = self.process_input(
result = await self._handle_process_function(
self.process_input,
metadata=metadata,
parameters=parameters,
prompt=prompt,
request=request,
)

else:
response_hash = response.hash()
result: Result | Reject = self.process_response(
result = await self._handle_process_function(
self.process_response,
metadata=metadata,
parameters=parameters,
prompt=prompt,
Expand Down Expand Up @@ -1014,13 +1027,22 @@ def _is_method_overridden(self, method_name: str) -> bool:
# the method object directly from the Processor class, then it has been overridden.
return instance_class_method_obj is not base_class_method_obj

def _process_fallback(self, **kwargs) -> Result | Reject:
warnings.warn(
f"{type(self).__name__} uses the deprecated 'process' method. "
"Implement 'process_input' and/or 'process_response' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.process(**kwargs)

def process_input(
self,
prompt: PROMPT,
metadata: Metadata,
parameters: PARAMS,
request: Request,
) -> Result | Reject:
) -> Result | Reject | Awaitable[Result | Reject]:
"""
This abstract method is for implementors of the processor to define
with their own custom logic. Errors should be raised as a subclass
Expand All @@ -1043,23 +1065,17 @@ def process_input(self, prompt, response, metadata, parameters, request):

return Result(processor_result=result)
"""
if self._is_method_overridden("process"):
warnings.warn(
f"{type(self).__name__} uses the deprecated 'process' method for input. "
"Implement 'process_input' instead.",
DeprecationWarning,
stacklevel=2, # Points the warning to the caller of process_input
if not self._is_method_overridden("process"):
raise NotImplementedError(
f"{type(self).__name__} must implement 'process_input' or the "
"deprecated 'process' method to handle input."
)
return self.process(
prompt=prompt,
response=None,
metadata=metadata,
parameters=parameters,
request=request,
)
raise NotImplementedError(
f"{type(self).__name__} must implement 'process_input' or the "
"deprecated 'process' method to handle input."
return self._process_fallback(
prompt=prompt,
response=None,
metadata=metadata,
parameters=parameters,
request=request,
)

def process_response(
Expand All @@ -1069,7 +1085,7 @@ def process_response(
metadata: Metadata,
parameters: PARAMS,
request: Request,
) -> Result | Reject:
) -> Result | Reject | Awaitable[Result | Reject]:
"""
This abstract method is for implementors of the processor to define
with their own custom logic. Errors should be raised as a subclass
Expand All @@ -1096,23 +1112,17 @@ def process_response(self, prompt, response, metadata, parameters, request):
return Result(processor_result=result)
"""

if self._is_method_overridden("process"):
warnings.warn(
f"{type(self).__name__} uses the deprecated 'process' method for response. "
"Implement 'process_response' instead.",
DeprecationWarning,
stacklevel=2, # Points the warning to the caller of process_input
if not self._is_method_overridden("process"):
raise NotImplementedError(
f"{type(self).__name__} must implement 'process_response' or the "
"deprecated 'process' method to handle input."
)
return self.process(
prompt=prompt,
response=response,
metadata=metadata,
parameters=parameters,
request=request,
)
raise NotImplementedError(
f"{type(self).__name__} must implement 'process_response' or the "
"deprecated 'process' method to handle input."
return self._process_fallback(
prompt=prompt,
response=response,
metadata=metadata,
parameters=parameters,
request=request,
)

def process(
Expand Down Expand Up @@ -1159,6 +1169,13 @@ def process(self, prompt, response, metadata, parameters, request):
"'process_input'/'process_response'."
)

async def _handle_process_function(self, func, **kwargs) -> Result | Reject:
if inspect.iscoroutinefunction(func):
result = await func(**kwargs)
else:
result = func(**kwargs)
return result


def _validation_error_as_messages(err: ValidationError) -> list[str]:
return [_error_details_to_str(e) for e in err.errors()]
Expand Down
Loading