28 changes: 15 additions & 13 deletions src/strands/models/llamacpp.py
@@ -33,7 +33,8 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
 from .model import Model

 logger = logging.getLogger(__name__)
@@ -149,12 +150,15 @@ def __init__(
                 (connect, read) timeouts.
             **model_config: Configuration options for the llama.cpp model.
         """
+        validate_config_keys(model_config, self.LlamaCppConfig)
+
         # Set default model_id if not provided
         if "model_id" not in model_config:
            model_config["model_id"] = "default"

         self.base_url = base_url.rstrip("/")
         self.config = dict(model_config)
+        logger.debug("config=<%s> | initializing", self.config)

         # Configure HTTP client
         if isinstance(timeout, tuple):
@@ -173,19 +177,14 @@
             timeout=timeout_obj,
         )

-        logger.debug(
-            "base_url=<%s>, model_id=<%s> | initializing llama.cpp provider",
-            base_url,
-            model_config.get("model_id"),
-        )
-
     @override
     def update_config(self, **model_config: Unpack[LlamaCppConfig]) -> None:  # type: ignore[override]
         """Update the llama.cpp model configuration with provided arguments.

         Args:
             **model_config: Configuration overrides.
         """
+        validate_config_keys(model_config, self.LlamaCppConfig)
         self.config.update(model_config)

     @override
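For context, a minimal caller-side sketch of how the new key validation surfaces. It assumes construction does not contact the server (only the httpx client is built) and that `validate_config_keys` flags unrecognized keys; the exact behavior (warning versus exception) lives in `_validation.py`, not in this diff.

```python
from strands.models.llamacpp import LlamaCppModel

# "model_id" is a known LlamaCppConfig key and defaults to "default" if omitted.
model = LlamaCppModel(base_url="http://localhost:8080", model_id="default")

# With this change, a mistyped key is checked by validate_config_keys in both
# __init__ and update_config instead of being silently stored in self.config.
model.update_config(model_idd="other-model")  # hypothetical typo; expected to be flagged
```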
@@ -514,6 +513,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the llama.cpp model.
@@ -522,6 +522,8 @@
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
@@ -531,19 +533,21 @@
             ContextWindowOverflowException: When the context window is exceeded.
             ModelThrottledException: When the llama.cpp server is overloaded.
         """
+        warn_on_tool_choice_not_supported(tool_choice)
+
         # Track request start time for latency calculation
         start_time = time.perf_counter()

         try:
-            logger.debug("formatting request for llama.cpp server")
+            logger.debug("formatting request")
             request = self._format_request(messages, tool_specs, system_prompt)
             logger.debug("request=<%s>", request)

-            logger.debug("sending request to llama.cpp server")
+            logger.debug("invoking model")
             response = await self.client.post("/v1/chat/completions", json=request)
             response.raise_for_status()

-            logger.debug("processing streaming response")
+            logger.debug("got response from model")
             yield self._format_chunk({"chunk_type": "message_start"})
             yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"})

@@ -648,12 +652,10 @@ async def stream(
             yield self._format_chunk({"chunk_type": "content_stop"})

             # Send stop reason
-            logger.debug("finish_reason=%s, tool_calls=%s", finish_reason, bool(tool_calls))
             if finish_reason == "tool_calls" or tool_calls:
                 stop_reason = "tool_calls"  # Changed from "tool_use" to match format_chunk expectations
             else:
                 stop_reason = finish_reason or "end_turn"
-            logger.debug("stop_reason=%s", stop_reason)
             yield self._format_chunk({"chunk_type": "message_stop", "data": stop_reason})

             # Send usage metadata if available
@@ -676,7 +678,7 @@
                     }
                 )

-            logger.debug("finished streaming response")
+            logger.debug("finished streaming response from model")

         except httpx.HTTPStatusError as e:
             if e.response.status_code == 400:
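Finally, a sketch of the updated `stream()` call from the caller's side. The message shape follows the SDK's content-block format and the `{"auto": {}}` value is assumed from the `ToolChoice` type used elsewhere in the SDK; neither is shown in this diff, and a llama.cpp server must be running at `base_url` for the loop to yield events.

```python
import asyncio

from strands.models.llamacpp import LlamaCppModel


async def main() -> None:
    model = LlamaCppModel(base_url="http://localhost:8080")
    messages = [{"role": "user", "content": [{"text": "Say hello."}]}]

    # tool_choice is now accepted for interface consistency, but
    # warn_on_tool_choice_not_supported(tool_choice) is expected to emit a
    # warning and the value is otherwise ignored by this provider.
    async for event in model.stream(messages, tool_choice={"auto": {}}):
        print(event)


asyncio.run(main())
```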