Skip to content

Commit 7af49da

Browse files
authored
feat(drivers): add GriptapeCloudPromptDriver (#1692)
1 parent 762958f commit 7af49da

File tree

6 files changed

+589
-0
lines changed

6 files changed

+589
-0
lines changed

docs/griptape-framework/drivers/prompt-drivers.md

+8
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ This driver uses [Azure OpenAi function calling](https://learn.microsoft.com/en-
9898
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_5.py"
9999
```
100100

101+
### Griptape Cloud
102+
103+
The [GriptapeCloudPromptDriver](../../reference/griptape/drivers/prompt/griptape_cloud_prompt_driver.md) connects to the [Griptape Cloud](https://www.griptape.ai/cloud) chat messages API.
104+
105+
```python
106+
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_griptape_cloud.py"
107+
```
108+
101109
### Cohere
102110

103111
The [CoherePromptDriver](../../reference/griptape/drivers/prompt/cohere_prompt_driver.md) connects to the Cohere [Chat](https://docs.cohere.com/docs/chat-api) API.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import os
2+
3+
from griptape.drivers.prompt.griptape_cloud import GriptapeCloudPromptDriver
4+
from griptape.rules import Rule
5+
from griptape.structures import Agent
6+
7+
agent = Agent(
8+
prompt_driver=GriptapeCloudPromptDriver(
9+
api_key=os.environ["GT_CLOUD_API_KEY"],
10+
),
11+
rules=[
12+
Rule(
13+
"You will be provided with a product description and seed words, and your task is to generate product names.",
14+
),
15+
],
16+
)
17+
18+
agent.run("Product description: A home milkshake maker. Seed words: fast, healthy, compact.")

griptape/drivers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .prompt.dummy import DummyPromptDriver
1515
from .prompt.ollama import OllamaPromptDriver
1616
from .prompt.grok import GrokPromptDriver
17+
from .prompt.griptape_cloud import GriptapeCloudPromptDriver
1718

1819
from .memory.conversation import BaseConversationMemoryDriver
1920
from .memory.conversation.local import LocalConversationMemoryDriver
@@ -141,6 +142,7 @@
141142
"DummyPromptDriver",
142143
"OllamaPromptDriver",
143144
"GrokPromptDriver",
145+
"GriptapeCloudPromptDriver",
144146
"BaseConversationMemoryDriver",
145147
"LocalConversationMemoryDriver",
146148
"AmazonDynamoDbConversationMemoryDriver",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from griptape.drivers.prompt.griptape_cloud_prompt_driver import GriptapeCloudPromptDriver
2+
3+
__all__ = ["GriptapeCloudPromptDriver"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import os
5+
from typing import TYPE_CHECKING, Optional
6+
from urllib.parse import urljoin
7+
8+
import requests
9+
from attrs import Factory, define, field
10+
11+
from griptape.common import DeltaMessage, Message, PromptStack, observable
12+
from griptape.configs.defaults_config import Defaults
13+
from griptape.drivers.prompt import BasePromptDriver
14+
from griptape.tokenizers import BaseTokenizer, SimpleTokenizer
15+
16+
if TYPE_CHECKING:
17+
from collections.abc import Iterator
18+
19+
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
20+
from griptape.tools.base_tool import BaseTool
21+
22+
23+
logger = logging.getLogger(Defaults.logging_config.logger_name)
24+
25+
26+
@define
27+
class GriptapeCloudPromptDriver(BasePromptDriver):
28+
model: Optional[str] = field(default=None, kw_only=True)
29+
base_url: str = field(
30+
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
31+
)
32+
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
33+
headers: dict = field(
34+
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
35+
)
36+
tokenizer: BaseTokenizer = field(
37+
default=Factory(
38+
lambda self: SimpleTokenizer(
39+
characters_per_token=4,
40+
max_input_tokens=2000,
41+
max_output_tokens=self.max_tokens,
42+
),
43+
takes_self=True,
44+
),
45+
kw_only=True,
46+
)
47+
use_native_tools: bool = field(default=True, kw_only=True)
48+
structured_output_strategy: StructuredOutputStrategy = field(
49+
default="native", kw_only=True, metadata={"serializable": True}
50+
)
51+
52+
@observable
53+
def try_run(self, prompt_stack: PromptStack) -> Message:
54+
url = urljoin(self.base_url.strip("/"), "/api/chat/messages")
55+
56+
params = self._base_params(prompt_stack)
57+
logger.debug(params)
58+
response = requests.post(url, headers=self.headers, json=params)
59+
response.raise_for_status()
60+
response_json = response.json()
61+
logger.debug(response_json)
62+
63+
return Message.from_dict(response_json)
64+
65+
@observable
66+
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
67+
url = urljoin(self.base_url.strip("/"), "/api/chat/messages/stream")
68+
params = self._base_params(prompt_stack)
69+
logger.debug(params)
70+
with requests.post(url, headers=self.headers, json=params, stream=True) as response:
71+
response.raise_for_status()
72+
for line in response.iter_lines():
73+
if line:
74+
decoded_line = line.decode("utf-8")
75+
if decoded_line.startswith("data:"):
76+
delta_message_payload = decoded_line.removeprefix("data:").strip()
77+
logger.debug(delta_message_payload)
78+
yield DeltaMessage.from_json(delta_message_payload)
79+
80+
def _base_params(self, prompt_stack: PromptStack) -> dict:
81+
return {
82+
"messages": prompt_stack.to_dict()["messages"],
83+
"tools": self.__to_griptape_tools(prompt_stack.tools),
84+
**({"output_schema": prompt_stack.to_output_json_schema()} if prompt_stack.output_schema else {}),
85+
"driver_configuration": {
86+
**({"model": self.model} if self.model else {}),
87+
"max_tokens": self.max_tokens,
88+
"use_native_tools": self.use_native_tools,
89+
"temperature": self.temperature,
90+
"structured_output_strategy": self.structured_output_strategy,
91+
"extra_params": self.extra_params,
92+
},
93+
}
94+
95+
def __to_griptape_tools(self, tools: list[BaseTool]) -> list[dict]:
96+
return [
97+
{
98+
"name": tool.name,
99+
"activities": [
100+
{
101+
"name": activity.__name__,
102+
"description": tool.activity_description(activity),
103+
"json_schema": tool.to_activity_json_schema(activity, "Schema"),
104+
}
105+
for activity in tool.activities()
106+
],
107+
}
108+
for tool in tools
109+
]

0 commit comments

Comments
 (0)