86 changes: 86 additions & 0 deletions microagent/llm/cerebras_client.py
@@ -0,0 +1,86 @@
from typing import Dict, Any, List
from cerebras.cloud.sdk import Cerebras
from .base import LLMClient

class CerebrasClient(LLMClient):
    def __init__(self):
        # The Cerebras SDK reads the API key from the CEREBRAS_API_KEY environment variable by default
        self.client = Cerebras()

def chat_completion(self, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
prepared_messages = self.prepare_messages(messages)
chat_params = self.prepare_chat_params(messages=prepared_messages, **kwargs)

response = self.client.chat.completions.create(**chat_params)
return self.parse_response(response)

    def stream_chat_completion(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
        prepared_messages = self.prepare_messages(messages)
        chat_params = self.prepare_chat_params(messages=prepared_messages, **kwargs)
        # Returns the SDK's streaming iterator; callers consume chunks incrementally
        return self.client.chat.completions.create(stream=True, **chat_params)

    def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Strip internal bookkeeping keys ('sender', 'tool_name') that the API does not accept
        return [
            {k: v for k, v in msg.items() if k not in ['sender', 'tool_name']}
            for msg in messages
        ]

    def prepare_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Tool definitions are passed through unchanged
        return tools

def parse_response(self, response: Any) -> Dict[str, Any]:
if isinstance(response, dict):
return response

# Extract the first choice from the response
choice = response.choices[0] if response.choices else None

if choice and choice.message:
parsed_response = {
"role": choice.message.role,
"content": choice.message.content,
}

# Handle tool calls
if choice.message.tool_calls:
parsed_response["tool_calls"] = [
{
"id": tool_call.id,
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
for tool_call in choice.message.tool_calls
]

return parsed_response
        else:
            # No usable choice in the response; return an empty parsed message
            return {
                "role": None,
                "content": None,
                "tool_calls": None
            }

    def prepare_chat_params(self, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
        # Only model, tools, and tool_choice are forwarded; any other kwargs are dropped
        params = {
            "model": kwargs.get('model', 'llama3.1-70b'),  # Default model for Cerebras
            "messages": messages
        }
if 'tools' in kwargs and kwargs['tools']:
params["tools"] = kwargs['tools']
if 'tool_choice' in kwargs:
params["tool_choice"] = kwargs['tool_choice']
return params

def prepare_system_message(self, instructions: str) -> Dict[str, Any]:
return {"role": "system", "content": instructions}

def prepare_tool_response(self, tool_call_id: str, tool_name: str, content: str) -> Dict[str, Any]:
return {
"role": "tool",
"tool_call_id": tool_call_id,
"name": tool_name,
"content": content,
}
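
A minimal usage sketch for the new client, assuming CEREBRAS_API_KEY is set in the environment and that streaming chunks follow the SDK's OpenAI-style delta format (choices[0].delta.content):

from microagent.llm.cerebras_client import CerebrasClient

client = CerebrasClient()  # picks up CEREBRAS_API_KEY from the environment

# Blocking call: returns the parsed {"role", "content", ...} dict
reply = client.chat_completion(
    messages=[
        client.prepare_system_message("You are a concise assistant."),
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    model="llama3.1-70b",
)
print(reply["content"])

# Streaming call: yields raw SDK chunks (OpenAI-style delta format assumed)
for chunk in client.stream_chat_completion(
    messages=[{"role": "user", "content": "Count to three."}]
):
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)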
3 changes: 3 additions & 0 deletions microagent/llm/factory.py
@@ -1,6 +1,7 @@
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .groq_client import GroqClient
from .cerebras_client import CerebrasClient

class LLMFactory:
@staticmethod
@@ -11,5 +12,7 @@ def create(llm_type):
return AnthropicClient()
elif llm_type == 'groq':
return GroqClient()
elif llm_type == 'cerebras':
return CerebrasClient()
else:
raise ValueError(f"Unsupported LLM type: {llm_type}")
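
With the factory branch in place, callers can select the new backend by name; a short sketch, assuming the import path matches the repo layout above:

from microagent.llm.factory import LLMFactory

client = LLMFactory.create('cerebras')  # returns a CerebrasClient
# Any other name raises ValueError("Unsupported LLM type: ...")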