Notebook checks (#333)
* add checks for notebooks

* format

* Fix mypy

* format

---------

Co-authored-by: Eric Zhu <[email protected]>
jackgerrits and ekzhu authored Aug 7, 2024
1 parent c7f5931 commit 33649c3
Showing 8 changed files with 83 additions and 40 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/checks.yml
@@ -87,6 +87,15 @@ jobs:
hatch run +python=${{ matrix.python-version }} teamone-test-matrix:pytest -n auto
working-directory: ./python/teams/team-one
mypy-notebooks:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install Hatch
uses: pypa/hatch@install
- run: hatch run nbqa mypy docs/src
working-directory: ./python

docs:
runs-on: ubuntu-latest
steps:
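The new mypy-notebooks job above type-checks the documentation notebooks by running mypy through nbqa. The same check should be reproducible locally by executing hatch run nbqa mypy docs/src from the ./python directory, assuming Hatch is installed and nbqa and mypy are available in the project's default Hatch environment.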
4 changes: 2 additions & 2 deletions python/docs/src/cookbook/langgraph-agent.ipynb
@@ -201,10 +201,10 @@
" \"Tool use agent\",\n",
" ChatOpenAI(\n",
" model=\"gpt-4o\",\n",
" api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
" # api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
" ),\n",
" # AzureChatOpenAI(\n",
" # azure_deployment=os.getenv(\"AZURE_OPENAI_DEPLOYMENT\"), \n",
" # azure_deployment=os.getenv(\"AZURE_OPENAI_DEPLOYMENT\"),\n",
" # azure_endpoint=os.getenv(\"AZURE_OPENAI_ENDPOINT\"),\n",
" # api_version=os.getenv(\"AZURE_OPENAI_API_VERSION\"),\n",
" # # Using Azure Active Directory authentication.\n",
14 changes: 8 additions & 6 deletions python/docs/src/getting-started/agent-and-agent-runtime.ipynb
@@ -58,13 +58,15 @@
"from dataclasses import dataclass\n",
"from agnext.core import BaseAgent, CancellationToken\n",
"\n",
"\n",
"@dataclass\n",
"class MyMessage:\n",
" content: str\n",
"\n",
"\n",
"class MyAgent(BaseAgent):\n",
" def __init__(self):\n",
" super().__init__(\"MyAgent\", subscriptions=[MyMessage])\n",
" def __init__(self) -> None:\n",
" super().__init__(\"MyAgent\", subscriptions=[\"MyMessage\"])\n",
"\n",
" async def on_message(self, message: MyMessage, cancellation_token: CancellationToken) -> None:\n",
" print(f\"Received message: {message.content}\")"
@@ -131,9 +133,9 @@
],
"source": [
"agent_id = await runtime.get(\"my_agent\")\n",
"run_context = runtime.start() # Start processing messages in the background.\n",
"run_context = runtime.start() # Start processing messages in the background.\n",
"await runtime.send_message(MyMessage(content=\"Hello, World!\"), agent_id)\n",
"await run_context.stop() # Stop processing messages in the background."
"await run_context.stop() # Stop processing messages in the background."
]
},
{
@@ -173,7 +175,7 @@
"source": [
"run_context = runtime.start()\n",
"# ... Send messages, publish messages, etc.\n",
"await run_context.stop() # This will return immediately but will not cancel\n",
"await run_context.stop() # This will return immediately but will not cancel\n",
"# any in-progress message handling."
]
},
@@ -198,7 +200,7 @@
"source": [
"run_context = runtime.start()\n",
"# ... Send messages, publish messages, etc.\n",
"await run_context.stop_when_idle() # This will block until the runtime is idle."
"await run_context.stop_when_idle() # This will block until the runtime is idle."
]
},
{
@@ -33,11 +33,13 @@
"source": [
"from dataclasses import dataclass\n",
"\n",
"\n",
"@dataclass\n",
"class TextMessage:\n",
" content: str\n",
" source: str\n",
"\n",
"\n",
"@dataclass\n",
"class ImageMessage:\n",
" url: str\n",
@@ -83,6 +85,7 @@
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import CancellationToken\n",
"\n",
"\n",
"class MyAgent(TypeRoutedAgent):\n",
" @message_handler\n",
" async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:\n",
@@ -183,15 +186,18 @@
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import CancellationToken, AgentId\n",
"\n",
"\n",
"@dataclass\n",
"class Message:\n",
" content: str\n",
"\n",
"\n",
"class InnerAgent(TypeRoutedAgent):\n",
" @message_handler\n",
" async def on_my_message(self, message: Message, cancellation_token: CancellationToken) -> Message:\n",
" return Message(content=f\"Hello from inner, {message.content}\")\n",
"\n",
"\n",
"class OuterAgent(TypeRoutedAgent):\n",
" def __init__(self, description: str, inner_agent_id: AgentId):\n",
" super().__init__(description)\n",
@@ -289,12 +295,14 @@
"from agnext.components import TypeRoutedAgent, message_handler\n",
"from agnext.core import CancellationToken\n",
"\n",
"\n",
"class BroadcastingAgent(TypeRoutedAgent):\n",
" @message_handler\n",
" async def on_my_message(self, message: Message, cancellation_token: CancellationToken) -> None:\n",
" # Publish a message to all agents in the same namespace.\n",
" await self.publish_message(Message(f\"Publishing a message: {message.content}!\"))\n",
"\n",
"\n",
"class ReceivingAgent(TypeRoutedAgent):\n",
" @message_handler\n",
" async def on_my_message(self, message: Message, cancellation_token: CancellationToken) -> None:\n",
40 changes: 22 additions & 18 deletions python/docs/src/getting-started/model-clients.ipynb
@@ -120,7 +120,7 @@
"\n",
"# Iterate over the stream and print the responses.\n",
"print(\"Streamed responses:\")\n",
"async for response in stream:\n",
"async for response in stream: # type: ignore\n",
" if isinstance(response, str):\n",
" # A partial response is a string.\n",
" print(response, flush=True, end=\"\")\n",
@@ -184,21 +184,19 @@
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
"\n",
"# Create the token provider\n",
"token_provider = get_bearer_token_provider(\n",
" DefaultAzureCredential(), \"https://cognitiveservices.azure.com/.default\"\n",
")\n",
"token_provider = get_bearer_token_provider(DefaultAzureCredential(), \"https://cognitiveservices.azure.com/.default\")\n",
"\n",
"model_client = AzureOpenAIChatCompletionClient(\n",
"az_model_client = AzureOpenAIChatCompletionClient(\n",
" model=\"{your-azure-deployment}\",\n",
" api_version=\"2024-06-01\",\n",
" azure_endpoint=\"https://{your-custom-endpoint}.openai.azure.com/\",\n",
" azure_ad_token_provider=token_provider, # Optional if you choose key-based authentication.\n",
" azure_ad_token_provider=token_provider, # Optional if you choose key-based authentication.\n",
" # api_key=\"sk-...\", # For key-based authentication.\n",
" model_capabilities={\n",
" \"vision\":True,\n",
" \"function_calling\":True,\n",
" \"json_output\":True,\n",
" }\n",
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" },\n",
")"
]
},
@@ -232,23 +230,27 @@
"from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage, OpenAIChatCompletionClient\n",
"from agnext.core import CancellationToken\n",
"\n",
"\n",
"@dataclass\n",
"class Message:\n",
" content: str\n",
"\n",
"class SimpleAgent(TypeRoutedAgent):\n",
"\n",
"class SimpleAgent(TypeRoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
" super().__init__(\"A simple agent\")\n",
" self._system_messages = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
" self._model_client = model_client\n",
" \n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:\n",
" # Prepare input to the chat completion model.\n",
" user_message = UserMessage(content=message.content, source=\"user\")\n",
" response = await self._model_client.create(self._system_messages + [user_message], cancellation_token=cancellation_token)\n",
" response = await self._model_client.create(\n",
" self._system_messages + [user_message], cancellation_token=cancellation_token\n",
" )\n",
" # Return with the model's response.\n",
" assert isinstance(response.content, str)\n",
" return Message(content=response.content)"
]
},
@@ -308,14 +310,16 @@
}
],
"source": [
" # Create the runtime and register the agent.\n",
"# Create the runtime and register the agent.\n",
"runtime = SingleThreadedAgentRuntime()\n",
"agent = await runtime.register_and_get(\n",
" \"simple-agent\",\n",
" lambda: SimpleAgent(OpenAIChatCompletionClient(\n",
" model=\"gpt-4o-mini\",\n",
" # api_key=\"sk-...\", # Optional if you have an OPENAI_API_KEY set in the environment.\n",
" )), \n",
" lambda: SimpleAgent(\n",
" OpenAIChatCompletionClient(\n",
" model=\"gpt-4o-mini\",\n",
" # api_key=\"sk-...\", # Optional if you have an OPENAI_API_KEY set in the environment.\n",
" )\n",
" ),\n",
")\n",
"# Start the runtime processing messages.\n",
"run_context = runtime.start()\n",
21 changes: 13 additions & 8 deletions python/docs/src/getting-started/multi-agent-design-patterns.ipynb
@@ -63,6 +63,7 @@
"source": [
"from dataclasses import dataclass\n",
"\n",
"\n",
"@dataclass\n",
"class CodeWritingTask:\n",
" task: str\n",
@@ -154,7 +155,7 @@
"\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
" super().__init__(\"A code writing agent.\")\n",
" self._system_messages = [\n",
" self._system_messages: List[LLMMessage] = [\n",
" SystemMessage(\n",
" content=\"\"\"You are a proficient coder. You write code to solve problems.\n",
"Work with the reviewer to improve your code.\n",
@@ -182,7 +183,7 @@
" self._session_memory.setdefault(session_id, []).append(message)\n",
" # Generate a response using the chat completion API.\n",
" response = await self._model_client.create(\n",
" self._system_messages + [UserMessage(content=message.task, source=self.metadata[\"name\"])],\n",
" self._system_messages + [UserMessage(content=message.task, source=self.metadata[\"type\"])],\n",
" cancellation_token=cancellation_token,\n",
" )\n",
" assert isinstance(response.content, str)\n",
@@ -296,7 +297,7 @@
"\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
" super().__init__(\"A code reviewer agent.\")\n",
" self._system_messages = [\n",
" self._system_messages: List[LLMMessage] = [\n",
" SystemMessage(\n",
" content=\"\"\"You are a code reviewer. You focus on correctness, efficiency and safety of the code.\n",
"Respond using the following JSON format:\n",
@@ -316,9 +317,15 @@
" @message_handler\n",
" async def handle_code_review_task(self, message: CodeReviewTask, cancellation_token: CancellationToken) -> None:\n",
" # Format the prompt for the code review.\n",
" # Gather the previous feedback if available.\n",
" previous_feedback = \"\"\n",
" if message.session_id in self._session_memory:\n",
" previous_feedback = self._session_memory[message.session_id][-1].review\n",
" previous_review = next(\n",
" (m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewResult)),\n",
" None,\n",
" )\n",
" if previous_review is not None:\n",
" previous_feedback = previous_review.review\n",
" # Store the messages in a temporary memory for this request only.\n",
" self._session_memory.setdefault(message.session_id, []).append(message)\n",
" prompt = f\"\"\"The problem statement is: {message.code_writing_task}\n",
@@ -334,7 +341,7 @@
"\"\"\"\n",
" # Generate a response using the chat completion API.\n",
" response = await self._model_client.create(\n",
" self._system_messages + [UserMessage(content=prompt, source=self.metadata[\"name\"])],\n",
" self._system_messages + [UserMessage(content=prompt, source=self.metadata[\"type\"])],\n",
" cancellation_token=cancellation_token,\n",
" json_output=True,\n",
" )\n",
@@ -500,9 +507,7 @@
")\n",
"run_context = runtime.start()\n",
"await runtime.publish_message(\n",
" message=CodeWritingTask(\n",
" task=\"Write a function to find the sum of all even numbers in a list.\"\n",
" ),\n",
" message=CodeWritingTask(task=\"Write a function to find the sum of all even numbers in a list.\"),\n",
" namespace=\"default\",\n",
")\n",
"\n",
10 changes: 7 additions & 3 deletions python/docs/src/getting-started/tools.ipynb
@@ -94,10 +94,12 @@
"from agnext.components.tools import FunctionTool\n",
"from agnext.core import CancellationToken\n",
"\n",
"\n",
"async def get_stock_price(ticker: str, date: Annotated[str, \"Date in YYYY/MM/DD\"]) -> float:\n",
" # Returns a random stock price for demonstration purposes.\n",
" return random.uniform(10, 200)\n",
"\n",
"\n",
"# Create a function tool.\n",
"stock_price_tool = FunctionTool(get_stock_price, description=\"Get the stock price.\")\n",
"\n",
@@ -141,6 +143,7 @@
" OpenAIChatCompletionClient,\n",
" SystemMessage,\n",
" UserMessage,\n",
" LLMMessage,\n",
")\n",
"from agnext.components.tool_agent import ToolAgent, ToolException\n",
"from agnext.components.tools import FunctionTool, Tool, ToolSchema\n",
@@ -155,15 +158,15 @@
"class ToolUseAgent(TypeRoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent: AgentId) -> None:\n",
" super().__init__(\"An agent with tools\")\n",
" self._system_messages = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
" self._system_messages: List[LLMMessage] = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
" self._model_client = model_client\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent = tool_agent\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:\n",
" # Create a session of messages.\n",
" session = [UserMessage(content=message.content, source=\"user\")]\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
" # Get a response from the model.\n",
" response = await self._model_client.create(\n",
" self._system_messages + session, tools=self._tool_schema, cancellation_token=cancellation_token\n",
@@ -192,9 +195,10 @@
" response = await self._model_client.create(\n",
" self._system_messages + session, tools=self._tool_schema, cancellation_token=cancellation_token\n",
" )\n",
" session.append(AssistantMessage(content=response.content, source=self.metadata[\"name\"]))\n",
" session.append(AssistantMessage(content=response.content, source=self.metadata[\"type\"]))\n",
"\n",
" # Return the final response.\n",
" assert isinstance(response.content, str)\n",
" return Message(content=response.content)"
]
},
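The notebook changes above follow a handful of recurring mypy fixes: adding return-type annotations, asserting that response.content is a str before returning it, and annotating message lists as List[LLMMessage] up front so that later appends of other message subtypes still type-check. A minimal, self-contained sketch of that last pattern is shown below; the message classes here are simplified stand-ins rather than the actual agnext definitions.

    from dataclasses import dataclass
    from typing import List, Union


    # Simplified stand-ins for agnext's SystemMessage and UserMessage.
    @dataclass
    class SystemMessage:
        content: str


    @dataclass
    class UserMessage:
        content: str
        source: str


    # A union covering every message kind the list may hold.
    LLMMessage = Union[SystemMessage, UserMessage]

    # Without the explicit List[LLMMessage] annotation, mypy infers
    # List[SystemMessage] from the initializer and rejects the append below.
    messages: List[LLMMessage] = [SystemMessage("You are a helpful AI assistant.")]
    messages.append(UserMessage(content="Hello!", source="user"))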