forked from microsoft/autogen
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add examples constrained generation + multi-step reasoning via Guidan…
…ce + AutoGen (microsoft#937) * Add examples of constrained generation * Add guidance based coder that always returns valid code blocks * Add guidance based agent that always generates a valid json * Add link to guidance nb
- Loading branch information
Showing
2 changed files
with
313 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Using Guidance with AutoGen\n", | ||
"\n", | ||
"This notebook shows how Guidance can be used to enable structured responses from AutoGen agents. In particular, this notebook focuses on creating agents that always output a valid code block or valid json object.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from guidance import models, gen, user, assistant, system\n", | ||
"from autogen import AssistantAgent, UserProxyAgent, Agent\n", | ||
"from autogen import config_list_from_json\n", | ||
"\n", | ||
"llm_config = config_list_from_json(\"OAI_CONFIG_LIST\")[0] # use the first config\n", | ||
"gpt = models.OpenAI(\"gpt-4\", api_key=llm_config.get(\"api_key\"))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The example below uses guidance to create a `guidance_coder` agent that only responds with valid code blocks." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 20, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[33muser\u001b[0m (to guidance_coder):\n", | ||
"\n", | ||
"Plot and save a chart of nvidia and tsla stock price change YTD.\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[33mguidance_coder\u001b[0m (to user):\n", | ||
"\n", | ||
"```python\n", | ||
"# filename: stock_price_change.py\n", | ||
"\n", | ||
"import pandas as pd\n", | ||
"import yfinance as yf\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from datetime import datetime\n", | ||
"\n", | ||
"# Get today's date\n", | ||
"today = datetime.today().strftime('%Y-%m-%d')\n", | ||
"\n", | ||
"# Download stock data\n", | ||
"nvda = yf.download('NVDA', start='2022-01-01', end=today)\n", | ||
"tsla = yf.download('TSLA', start='2022-01-01', end=today)\n", | ||
"\n", | ||
"# Calculate percentage change in closing price\n", | ||
"nvda['Pct Change'] = nvda['Close'].pct_change()\n", | ||
"tsla['Pct Change'] = tsla['Close'].pct_change()\n", | ||
"\n", | ||
"# Plot percentage change\n", | ||
"plt.figure(figsize=(14,7))\n", | ||
"plt.plot(nvda['Pct Change'], label='NVDA')\n", | ||
"plt.plot(tsla['Pct Change'], label='TSLA')\n", | ||
"plt.title('Nvidia and Tesla Stock Price Change YTD')\n", | ||
"plt.xlabel('Date')\n", | ||
"plt.ylabel('Percentage Change')\n", | ||
"plt.legend()\n", | ||
"plt.grid(True)\n", | ||
"\n", | ||
"# Save the plot as a PNG file\n", | ||
"plt.savefig('stock_price_change.png')\n", | ||
"```\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[31m\n", | ||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n", | ||
"\u001b[31m\n", | ||
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"execute_code was called without specifying a value for use_docker. Since the python docker package is not available, code will be run natively. Note: this fallback behavior is subject to change\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[33muser\u001b[0m (to guidance_coder):\n", | ||
"\n", | ||
"exitcode: 0 (execution succeeded)\n", | ||
"Code output: \n", | ||
"\n", | ||
"[*********************100%%**********************] 1 of 1 completed\n", | ||
"\n", | ||
"[*********************100%%**********************] 1 of 1 completed\n", | ||
"\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[33mguidance_coder\u001b[0m (to user):\n", | ||
"\n", | ||
"Great! The code executed successfully and the chart of Nvidia and Tesla stock price change Year-To-Date (YTD) has been saved as 'stock_price_change.png' in the current directory. You can open this file to view the chart.\n", | ||
"\n", | ||
"TERMINATE\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import re\n", | ||
"\n", | ||
"def is_valid_code_block(code):\n", | ||
" pattern = r\"```[\\w\\s]*\\n([\\s\\S]*?)\\n```\"\n", | ||
" match = re.search(pattern, code)\n", | ||
" if match:\n", | ||
" return True\n", | ||
" else:\n", | ||
" return False\n", | ||
"\n", | ||
"\n", | ||
"def generate_structured_response(recipient, messages, sender, config):\n", | ||
" gpt = models.OpenAI(\"gpt-4\", api_key=llm_config.get(\"api_key\"), echo=False)\n", | ||
" \n", | ||
" # populate the recipient with the messages from the history\n", | ||
" with system():\n", | ||
" lm = gpt + recipient.system_message\n", | ||
" \n", | ||
" for message in messages:\n", | ||
" if message.get(\"role\") == \"user\":\n", | ||
" with user():\n", | ||
" lm += message.get(\"content\")\n", | ||
" else:\n", | ||
" with assistant():\n", | ||
" lm += message.get(\"content\")\n", | ||
"\n", | ||
" # generate a new response and store it\n", | ||
" with assistant():\n", | ||
" lm += gen(name=\"initial_response\")\n", | ||
" # ask the agent to reflect on the nature of the response and store it\n", | ||
" with user():\n", | ||
" lm += \"Does the very last response from you contain code? Respond with yes or no.\"\n", | ||
" with assistant():\n", | ||
" lm += gen(name=\"contains_code\")\n", | ||
" # if the response contains code, ask the agent to generate a proper code block\n", | ||
" if \"yes\" in lm[\"contains_code\"].lower():\n", | ||
" with user():\n", | ||
" lm += \"Respond with a single blocks containing the valid code. Valid code blocks start with ```\"\n", | ||
" with assistant():\n", | ||
" lm += \"```\" + gen(name=\"code\")\n", | ||
" response = \"```\" + lm[\"code\"]\n", | ||
" \n", | ||
" is_valid = is_valid_code_block(response)\n", | ||
" if not is_valid:\n", | ||
" raise ValueError(f\"Failed to generate a valid code block\\n {response}\")\n", | ||
" \n", | ||
" # otherwise, just use the initial response\n", | ||
" else:\n", | ||
" response = lm[\"initial_response\"]\n", | ||
" \n", | ||
" return True, response\n", | ||
"\n", | ||
"\n", | ||
"guidance_agent = AssistantAgent(\"guidance_coder\", llm_config=llm_config)\n", | ||
"guidance_agent.register_reply(Agent, generate_structured_response, 1)\n", | ||
"user_proxy = UserProxyAgent(\"user\", human_input_mode=\"TERMINATE\", code_execution_config={\"work_dir\": \"coding\"},\n", | ||
" is_termination_msg=lambda msg: \"TERMINATE\" in msg.get(\"content\"))\n", | ||
"user_proxy.initiate_chat(guidance_agent, message=\"Plot and save a chart of nvidia and tsla stock price change YTD.\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The example below uses Guidance to enable a `guidance_labeler` agent that only responds with a valid JSON that labels a given comment/joke." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[33muser\u001b[0m (to guidance_labeler):\n", | ||
"\n", | ||
"\n", | ||
"Label the TEXT via the following instructions:\n", | ||
" \n", | ||
"The label must be a JSON of the format:\n", | ||
"{\n", | ||
" \"label\": str,\n", | ||
" \"explanation\": str\n", | ||
"}\n", | ||
" \n", | ||
"TEXT: what did the fish say when it bumped into a wall? Dam!\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[33mguidance_labeler\u001b[0m (to user):\n", | ||
"\n", | ||
"{\"label\":\"Joke\",\"explanation\":\"The text is a joke, using a play on words where the fish says 'Dam!' after bumping into a wall, which is a pun on the word 'damn' and a 'dam' which is a barrier that stops or restricts the flow of water, often creating a reservoir, and is something a fish might encounter.\"}\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from pydantic import BaseModel\n", | ||
"\n", | ||
"class Response(BaseModel):\n", | ||
" label: str\n", | ||
" explanation: str\n", | ||
"\n", | ||
"response_prompt_instructions = \"\"\"The label must be a JSON of the format:\n", | ||
"{\n", | ||
" \"label\": str,\n", | ||
" \"explanation\": str\n", | ||
"}\"\"\"\n", | ||
"\n", | ||
"def generate_structured_response(recipient, messages, sender, config):\n", | ||
" gpt = models.OpenAI(\"gpt-4\", api_key=llm_config.get(\"api_key\"), echo=False)\n", | ||
" \n", | ||
" # populate the recipient with the messages from the history\n", | ||
" with system():\n", | ||
" lm = gpt + recipient.system_message\n", | ||
" \n", | ||
" for message in messages:\n", | ||
" if message.get(\"role\") == \"user\":\n", | ||
" with user():\n", | ||
" lm += message.get(\"content\")\n", | ||
" else:\n", | ||
" with assistant():\n", | ||
" lm += message.get(\"content\")\n", | ||
"\n", | ||
" # generate a new response and store it\n", | ||
" with assistant():\n", | ||
" lm += gen(name=\"initial_response\")\n", | ||
" # ask the agent to reflect on the nature of the response and store it\n", | ||
" with user():\n", | ||
" lm += \"Does the very last response from you contain JSON object? Respond with yes or no.\"\n", | ||
" with assistant():\n", | ||
" lm += gen(name=\"contains_json\")\n", | ||
" # if the response contains code, ask the agent to generate a proper code block\n", | ||
" if \"yes\" in lm[\"contains_json\"].lower():\n", | ||
" with user():\n", | ||
" lm += \"What was that JSON object? Only respond with that valid JSON string. A valid JSON string starts with {\"\n", | ||
" with assistant():\n", | ||
" lm += \"{\" + gen(name=\"json\")\n", | ||
" response = \"{\" + lm[\"json\"]\n", | ||
" # verify that the response is valid json\n", | ||
" try:\n", | ||
" response_obj = Response.model_validate_json(response)\n", | ||
" response = response_obj.model_dump_json()\n", | ||
" except Exception as e:\n", | ||
" response = str(e)\n", | ||
" # otherwise, just use the initial response\n", | ||
" else:\n", | ||
" response = lm[\"initial_response\"]\n", | ||
"\n", | ||
" return True, response\n", | ||
"\n", | ||
"guidance_agent = AssistantAgent(\"guidance_labeler\", llm_config=llm_config, system_message=\"You are a helpful assistant\")\n", | ||
"guidance_agent.register_reply(Agent, generate_structured_response, 1)\n", | ||
"user_proxy = UserProxyAgent(\"user\", human_input_mode=\"ALWAYS\", code_execution_config=False)\n", | ||
"user_proxy.initiate_chat(guidance_agent, message=f\"\"\"\n", | ||
"Label the TEXT via the following instructions:\n", | ||
" \n", | ||
"{response_prompt_instructions}\n", | ||
" \n", | ||
"TEXT: what did the fish say when it bumped into a wall? Dam!\n", | ||
"\n", | ||
"\"\"\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters