|
6 | 6 |
|
7 | 7 | from httpx._config import Timeout
|
8 | 8 |
|
| 9 | +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler |
| 10 | +from litellm.types.utils import CustomStreamingDecoder |
9 | 11 | from litellm.utils import ModelResponse
|
10 | 12 |
|
11 | 13 | from ...groq.chat.transformation import GroqChatConfig
|
12 |
| -from ...OpenAI.openai import OpenAIChatCompletion |
| 14 | +from ...openai_like.chat.handler import OpenAILikeChatHandler |
13 | 15 |
|
14 | 16 |
|
15 |
class GroqChatCompletion(OpenAILikeChatHandler):
    """Groq chat-completions handler.

    Thin provider wrapper over ``OpenAILikeChatHandler``: it rewrites the
    message list through ``GroqChatConfig`` and decides whether streaming
    must be faked before delegating everything else to the base handler.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def completion(
        self,
        *,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: Optional[str],
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        timeout: Optional[Union[float, Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        custom_endpoint: Optional[bool] = None,
        streaming_decoder: Optional[CustomStreamingDecoder] = None,
        fake_stream: bool = False,
    ):
        """Run a Groq chat completion via the OpenAI-like base handler.

        Args:
            model: Model name to send to the Groq API.
            messages: Chat messages; transformed in place by
                ``GroqChatConfig._transform_messages`` before dispatch.
            api_base: Base URL of the Groq endpoint.
            custom_llm_provider: Provider tag forwarded to the base handler.
            custom_prompt_dict: Provider-specific prompt template overrides.
            model_response: ``ModelResponse`` object populated by the base
                handler.
            print_verbose: Logging callable forwarded unchanged.
            encoding: Tokenizer/encoding object forwarded unchanged.
            api_key: Optional API key.
            logging_obj: LiteLLM logging object forwarded unchanged.
            optional_params: Request parameters; ``optional_params["stream"]``
                controls the fake-stream decision below.
            acompletion: Async-completion flag, passed through.
            litellm_params: LiteLLM-internal params, passed through.
            logger_fn: Optional logger callback, passed through.
            headers: Optional extra HTTP headers.
            timeout: Request timeout (float seconds or httpx ``Timeout``).
            client: Optional pre-built sync/async HTTP handler to reuse.
            custom_endpoint: Whether ``api_base`` is a fully custom endpoint.
            streaming_decoder: Optional custom decoder for streamed chunks.
            fake_stream: NOTE(review): this argument is always overwritten
                below before use, so the caller-supplied value has no effect —
                kept only for signature compatibility with the base handler.
                Confirm whether it should be honored when ``stream`` is unset.

        Returns:
            Whatever ``OpenAILikeChatHandler.completion`` returns (response
            object or stream wrapper).
        """
        # Hoisted: one config instance serves both calls below instead of
        # constructing GroqChatConfig twice per request.
        config = GroqChatConfig()
        messages = config._transform_messages(messages)  # type: ignore

        # Only consult _should_fake_stream for genuine streaming requests;
        # non-streaming calls never fake a stream.
        if optional_params.get("stream") is True:
            fake_stream = config._should_fake_stream(optional_params)
        else:
            fake_stream = False

        return super().completion(
            model=model,
            messages=messages,
            api_base=api_base,
            custom_llm_provider=custom_llm_provider,
            custom_prompt_dict=custom_prompt_dict,
            model_response=model_response,
            print_verbose=print_verbose,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging_obj,
            optional_params=optional_params,
            acompletion=acompletion,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            headers=headers,
            timeout=timeout,
            client=client,
            custom_endpoint=custom_endpoint,
            streaming_decoder=streaming_decoder,
            fake_stream=fake_stream,
        )
|
0 commit comments