From 7d057a93b2bb51348ed73237bb60070794ea30dd Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sun, 9 Jun 2024 20:29:34 -0700 Subject: [PATCH 1/5] Autogenstudio docs (#2890) * add autogenstudio docs * update ags readme to point to docs page * update docs * update docs * update faqs * update, fix typos --- samples/apps/autogen-studio/README.md | 66 +------ .../blog/2023-12-01-AutoGenStudio/index.mdx | 3 + website/docs/autogen-studio/faqs.md | 86 +++++++++ .../docs/autogen-studio/getting-started.md | 121 +++++++++++++ .../autogen-studio/img/agent_assistant.png | 3 + .../autogen-studio/img/agent_groupchat.png | 3 + website/docs/autogen-studio/img/agent_new.png | 3 + .../autogen-studio/img/agent_skillsmodel.png | 3 + .../autogen-studio/img/ara_stockprices.png | 3 + website/docs/autogen-studio/img/model_new.png | 3 + .../docs/autogen-studio/img/model_openai.png | 3 + website/docs/autogen-studio/img/skill.png | 3 + .../docs/autogen-studio/img/workflow_chat.png | 3 + .../autogen-studio/img/workflow_export.png | 3 + .../docs/autogen-studio/img/workflow_new.png | 3 + .../autogen-studio/img/workflow_profile.png | 3 + .../img/workflow_sequential.png | 3 + .../docs/autogen-studio/img/workflow_test.png | 3 + website/docs/autogen-studio/usage.md | 114 ++++++++++++ website/docusaurus.config.js | 33 ++-- website/sidebars.js | 165 +++++++++++------- 21 files changed, 484 insertions(+), 146 deletions(-) create mode 100644 website/docs/autogen-studio/faqs.md create mode 100644 website/docs/autogen-studio/getting-started.md create mode 100644 website/docs/autogen-studio/img/agent_assistant.png create mode 100644 website/docs/autogen-studio/img/agent_groupchat.png create mode 100644 website/docs/autogen-studio/img/agent_new.png create mode 100644 website/docs/autogen-studio/img/agent_skillsmodel.png create mode 100644 website/docs/autogen-studio/img/ara_stockprices.png create mode 100644 website/docs/autogen-studio/img/model_new.png create mode 100644 
website/docs/autogen-studio/img/model_openai.png create mode 100644 website/docs/autogen-studio/img/skill.png create mode 100644 website/docs/autogen-studio/img/workflow_chat.png create mode 100644 website/docs/autogen-studio/img/workflow_export.png create mode 100644 website/docs/autogen-studio/img/workflow_new.png create mode 100644 website/docs/autogen-studio/img/workflow_profile.png create mode 100644 website/docs/autogen-studio/img/workflow_sequential.png create mode 100644 website/docs/autogen-studio/img/workflow_test.png create mode 100644 website/docs/autogen-studio/usage.md diff --git a/samples/apps/autogen-studio/README.md b/samples/apps/autogen-studio/README.md index 1e60b5362dba..05a2a58f800a 100644 --- a/samples/apps/autogen-studio/README.md +++ b/samples/apps/autogen-studio/README.md @@ -12,24 +12,14 @@ Code for AutoGen Studio is on GitHub at [microsoft/autogen](https://github.com/m > **Note**: AutoGen Studio is meant to help you rapidly prototype multi-agent workflows and demonstrate an example of end user interfaces built with AutoGen. It is not meant to be a production-ready app. > [!WARNING] -> AutoGen Studio is currently under active development and we are iterating quickly. Kindly consider that we may introduce breaking changes in the releases during the upcoming weeks, and also the `README` might be outdated. We'll update the `README` as soon as we stabilize the API. +> AutoGen Studio is currently under active development and we are iterating quickly. Kindly consider that we may introduce breaking changes in the releases during the upcoming weeks, and also the `README` might be outdated. Please see the AutoGen Studio [docs](https://microsoft.github.io/autogen/docs/autogen-studio/getting-started) page for the most up-to-date information. + +**Updates** -> [!NOTE] Updates > April 17: AutoGen Studio database layer is now rewritten to use [SQLModel](https://sqlmodel.tiangolo.com/) (Pydantic + SQLAlchemy). 
This provides entity linking (skills, models, agents and workflows are linked via association tables) and supports multiple [database backend dialects](https://docs.sqlalchemy.org/en/20/dialects/) supported in SQLAlchemy (SQLite, PostgreSQL, MySQL, Oracle, Microsoft SQL Server). The backend database can be specified a `--database-uri` argument when running the application. For example, `autogenstudio ui --database-uri sqlite:///database.sqlite` for SQLite and `autogenstudio ui --database-uri postgresql+psycopg://user:password@localhost/dbname` for PostgreSQL. > March 12: Default directory for AutoGen Studio is now /home//.autogenstudio. You can also specify this directory using the `--appdir` argument when running the application. For example, `autogenstudio ui --appdir /path/to/folder`. This will store the database and other files in the specified directory e.g. `/path/to/folder/database.sqlite`. `.env` files in that directory will be used to set environment variables for the app. -### Capabilities / Roadmap - -Some of the capabilities supported by the app frontend include the following: - -- [x] Build / Configure agents (currently supports two agent workflows based on `UserProxyAgent` and `AssistantAgent`), modify their configuration (e.g. skills, temperature, model, agent system message, model etc) and compose them into workflows. -- [x] Chat with agent works and specify tasks. -- [x] View agent messages and output files in the UI from agent runs. -- [x] Add interaction sessions to a gallery. -- [ ] Support for more complex agent workflows (e.g. `GroupChat` workflows). -- [ ] Improved user experience (e.g., streaming intermediate model output, better summarization of agent responses, etc). 
- Project Structure: - _autogenstudio/_ code for the backend classes and web api (FastAPI) @@ -97,32 +87,6 @@ AutoGen Studio also takes several parameters to customize the application: Now that you have AutoGen Studio installed and running, you are ready to explore its capabilities, including defining and modifying agent workflows, interacting with agents and sessions, and expanding agent skills. -## Capabilities - -AutoGen Studio proposes some high-level concepts. - -**Agent Workflow**: An agent workflow is a specification of a set of agents that can work together to accomplish a task. The simplest version of this is a setup with two agents – a user proxy agent (that represents a user i.e. it compiles code and prints result) and an assistant that can address task requests (e.g., generating plans, writing code, evaluating responses, proposing error recovery steps, etc.). A more complex flow could be a group chat where even more agents work towards a solution. - -**Session**: A session refers to a period of continuous interaction or engagement with an agent workflow, typically characterized by a sequence of activities or operations aimed at achieving specific objectives. It includes the agent workflow configuration, the interactions between the user and the agents. A session can be “published” to a “gallery”. - -**Skills**: Skills are functions (e.g., Python functions) that describe how to solve a task. In general, a good skill has a descriptive name (e.g. `generate_images`), extensive docstrings and good defaults (e.g., writing out files to disk for persistence and reuse). You can add new skills AutoGen Studio app via the provided UI. At inference time, these skills are made available to the assistant agent as they address your tasks. - -## Example Usage - -Consider the following query. - -``` -Plot a chart of NVDA and TESLA stock price YTD. 
Save the result to a file named nvda_tesla.png -``` - -The agent workflow responds by _writing and executing code_ to create a python program to generate the chart with the stock prices. - -> Note than there could be multiple turns between the `AssistantAgent` and the `UserProxyAgent` to produce and execute the code in order to complete the task. - -![ARA](./docs/ara_stockprices.png) - -> Note: You can also view the debug console that generates useful information to see how the agents are interacting in the background. - ## Contribution Guide We welcome contributions to AutoGen Studio. We recommend the following general steps to contribute to the project: @@ -137,29 +101,7 @@ We welcome contributions to AutoGen Studio. We recommend the following general s ## FAQ -**Q: How do I specify the directory where files(e.g. database) are stored?** - -A: You can specify the directory where files are stored by setting the `--appdir` argument when running the application. For example, `autogenstudio ui --appdir /path/to/folder`. This will store the database (default) and other files in the specified directory e.g. `/path/to/folder/database.sqlite`. - -**Q: Where can I adjust the default skills, agent and workflow configurations?** -A: You can modify agent configurations directly from the UI or by editing the [dbdefaults.json](autogenstudio/utils/dbdefaults.json) file which is used to initialize the database. - -**Q: If I want to reset the entire conversation with an agent, how do I go about it?** -A: To reset your conversation history, you can delete the `database.sqlite` file in the `--appdir` directory. This will reset the entire conversation history. To delete user files, you can delete the `files` directory in the `--appdir` directory. - -**Q: Is it possible to view the output and messages generated by the agents during interactions?** -A: Yes, you can view the generated messages in the debug console of the web UI, providing insights into the agent interactions. 
Alternatively, you can inspect the `database.sqlite` file for a comprehensive record of messages. - -**Q: Can I use other models with AutoGen Studio?** -Yes. AutoGen standardizes on the openai model api format, and you can use any api server that offers an openai compliant endpoint. In the AutoGen Studio UI, each agent has an `llm_config` field where you can input your model endpoint details including `model`, `api key`, `base url`, `model type` and `api version`. For Azure OpenAI models, you can find these details in the Azure portal. Note that for Azure OpenAI, the `model name` is the deployment id or engine, and the `model type` is "azure". -For other OSS models, we recommend using a server such as vllm to instantiate an openai compliant endpoint. - -**Q: The server starts but I can't access the UI** -A: If you are running the server on a remote machine (or a local machine that fails to resolve localhost correstly), you may need to specify the host address. By default, the host address is set to `localhost`. You can specify the host address using the `--host ` argument. For example, to start the server on port 8081 and local address such that it is accessible from other machines on the network, you can run the following command: - -```bash -autogenstudio ui --port 8081 --host 0.0.0.0 -``` +Please refer to the AutoGen Studio [FAQs](https://microsoft.github.io/autogen/docs/autogen-studio/faqs) page for more information. ## Acknowledgements diff --git a/website/blog/2023-12-01-AutoGenStudio/index.mdx b/website/blog/2023-12-01-AutoGenStudio/index.mdx index 014aa870e0c5..49151f7b355f 100644 --- a/website/blog/2023-12-01-AutoGenStudio/index.mdx +++ b/website/blog/2023-12-01-AutoGenStudio/index.mdx @@ -25,6 +25,9 @@ To help you rapidly prototype multi-agent solutions for your tasks, we are intro - Explicitly add skills to your agents and accomplish more tasks. - Publish your sessions to a local gallery. 
+ +See the official AutoGen Studio documentation [here](https://microsoft.github.io/autogen/docs/autogen-studio/getting-started) for more details. + + AutoGen Studio is open source [code here](https://github.com/microsoft/autogen/tree/main/samples/apps/autogen-studio), and can be installed via pip. Give it a try! ```bash diff --git a/website/docs/autogen-studio/faqs.md b/website/docs/autogen-studio/faqs.md new file mode 100644 index 000000000000..82426578daad --- /dev/null +++ b/website/docs/autogen-studio/faqs.md @@ -0,0 +1,86 @@ +# AutoGen Studio FAQs + +## Q: How do I specify the directory where files (e.g., database) are stored? + +A: You can specify the directory where files are stored by setting the `--appdir` argument when running the application. For example, `autogenstudio ui --appdir /path/to/folder`. This will store the database (default) and other files in the specified directory, e.g., `/path/to/folder/database.sqlite`. + +## Q: Where can I adjust the default skills, agent and workflow configurations? + +A: You can modify agent configurations directly from the UI or by editing the `init_db_samples` function in the `autogenstudio/database/utils.py` file, which is used to initialize the database. + +## Q: If I want to reset the entire conversation with an agent, how do I go about it? + +A: To reset your conversation history, you can delete the `database.sqlite` file in the `--appdir` directory. This will reset the entire conversation history. To delete user files, you can delete the `files` directory in the `--appdir` directory. + +## Q: Is it possible to view the output and messages generated by the agents during interactions? + +A: Yes, you can view the generated messages in the debug console of the web UI, providing insights into the agent interactions. Alternatively, you can inspect the `database.sqlite` file for a comprehensive record of messages. + +## Q: Can I use other models with AutoGen Studio? + +Yes.
AutoGen standardizes on the OpenAI model API format, and you can use any API server that offers an OpenAI-compliant endpoint. In the AutoGen Studio UI, each agent has an `llm_config` field where you can input your model endpoint details including `model`, `api key`, `base url`, `model type` and `api version`. For Azure OpenAI models, you can find these details in the Azure portal. Note that for Azure OpenAI, the `model name` is the deployment id or engine, and the `model type` is "azure". +For other OSS models, we recommend using a server such as vLLM, LM Studio, or Ollama to instantiate an OpenAI-compliant endpoint. + +## Q: The server starts but I can't access the UI + +A: If you are running the server on a remote machine (or a local machine that fails to resolve localhost correctly), you may need to specify the host address. By default, the host address is set to `localhost`. You can specify the host address using the `--host <host>` argument. For example, to start the server on port 8081 on an address that is accessible from other machines on the network, you can run the following command: + +```bash +autogenstudio ui --port 8081 --host 0.0.0.0 +``` + +## Q: Can I export my agent workflows for use in a Python app? + +Yes. In the Build view, you can click the export button to save your agent workflow as a JSON file. This file can be imported in a Python application using the `WorkflowManager` class. For example: + +```python + +from autogenstudio import WorkflowManager +# load workflow from exported json workflow file. +workflow_manager = WorkflowManager(workflow="path/to/your/workflow_.json") + +# run the workflow on a task +task_query = "What is the height of the Eiffel Tower? Don't write code, just respond to the question." +workflow_manager.run(message=task_query) + +``` + +## Q: Can I deploy my agent workflows as APIs? + +Yes. You can launch the workflow as an API endpoint from the command line using the `autogenstudio` command-line tool.
For example: + +```bash +autogenstudio serve --workflow=workflow.json --port=5000 +``` + +Similarly, the workflow launch command above can be wrapped into a Dockerfile that can be deployed on cloud services like Azure Container Apps or Azure Web Apps. + +## Q: Can I run AutoGen Studio in a Docker container? + +A: Yes, you can run AutoGen Studio in a Docker container. You can build the Docker image using the provided [Dockerfile](https://github.com/microsoft/autogen/blob/autogenstudio/samples/apps/autogen-studio/Dockerfile), reproduced below: + +```Dockerfile +FROM python:3.10 + +WORKDIR /code + +RUN pip install -U gunicorn autogenstudio + +RUN useradd -m -u 1000 user +USER user +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH \ + AUTOGENSTUDIO_APPDIR=/home/user/app + +WORKDIR $HOME/app + +COPY --chown=user . $HOME/app + +CMD gunicorn -w $((2 * $(getconf _NPROCESSORS_ONLN) + 1)) --timeout 12600 -k uvicorn.workers.UvicornWorker autogenstudio.web.app:app --bind "0.0.0.0:8081" +``` + +Using Gunicorn as the application server for improved performance is recommended.
To run AutoGen Studio with Gunicorn, you can use the following command: + +```bash +gunicorn -w $((2 * $(getconf _NPROCESSORS_ONLN) + 1)) --timeout 12600 -k uvicorn.workers.UvicornWorker autogenstudio.web.app:app --bind +``` diff --git a/website/docs/autogen-studio/getting-started.md b/website/docs/autogen-studio/getting-started.md new file mode 100644 index 000000000000..eee1707b7cc3 --- /dev/null +++ b/website/docs/autogen-studio/getting-started.md @@ -0,0 +1,121 @@ +# AutoGen Studio - Getting Started + +[![PyPI version](https://badge.fury.io/py/autogenstudio.svg)](https://badge.fury.io/py/autogenstudio) +[![Downloads](https://static.pepy.tech/badge/autogenstudio/week)](https://pepy.tech/project/autogenstudio) + +![ARA](./img/ara_stockprices.png) + +AutoGen Studio is a low-code interface built to help you rapidly prototype AI agents, enhance them with skills, compose them into workflows and interact with them to accomplish tasks. It is built on top of the [AutoGen](https://microsoft.github.io/autogen) framework, which is a toolkit for building AI agents. + +Code for AutoGen Studio is on GitHub at [microsoft/autogen](https://github.com/microsoft/autogen/tree/main/samples/apps/autogen-studio). + +> **Note**: AutoGen Studio is meant to help you rapidly prototype multi-agent workflows and demonstrate an example of end user interfaces built with AutoGen. It is not meant to be a production-ready app. Developers are encouraged to use the AutoGen framework to build their own applications, implementing authentication, security and other features required for deployed applications. + +**Updates** + +- April 17: AutoGen Studio database layer is now rewritten to use [SQLModel](https://sqlmodel.tiangolo.com/) (Pydantic + SQLAlchemy).
This provides entity linking (skills, models, agents and workflows are linked via association tables) and supports multiple [database backend dialects](https://docs.sqlalchemy.org/en/20/dialects/) supported in SQLAlchemy (SQLite, PostgreSQL, MySQL, Oracle, Microsoft SQL Server). The backend database can be specified with a `--database-uri` argument when running the application. For example, `autogenstudio ui --database-uri sqlite:///database.sqlite` for SQLite and `autogenstudio ui --database-uri postgresql+psycopg://user:password@localhost/dbname` for PostgreSQL. + +- March 12: Default directory for AutoGen Studio is now /home//.autogenstudio. You can also specify this directory using the `--appdir` argument when running the application. For example, `autogenstudio ui --appdir /path/to/folder`. This will store the database and other files in the specified directory e.g. `/path/to/folder/database.sqlite`. `.env` files in that directory will be used to set environment variables for the app. + +### Installation + +There are two ways to install AutoGen Studio - from PyPi or from source. We **recommend installing from PyPi** unless you plan to modify the source code. + +1. **Install from PyPi** + + We recommend using a virtual environment (e.g., conda) to avoid conflicts with existing Python packages. With Python 3.10 or newer active in your virtual environment, use pip to install AutoGen Studio: + + ```bash + pip install autogenstudio + ``` + +2. **Install from Source** + + > Note: This approach requires some familiarity with building interfaces in React. + + If you prefer to install from source, ensure you have Python 3.10+ and Node.js (version above 14.15.0) installed. Here's how you get started: + + - Clone the AutoGen Studio repository and install its Python dependencies: + + ```bash + pip install -e . 
+ ``` + + - Navigate to the `samples/apps/autogen-studio/frontend` directory, install dependencies, and build the UI: + + ```bash + npm install -g gatsby-cli + npm install --global yarn + cd frontend + yarn install + yarn build + ``` + +For Windows users, you may need alternative commands to build the frontend. + +```bash + + gatsby clean && rmdir /s /q ..\\autogenstudio\\web\\ui 2>nul & (set \"PREFIX_PATH_VALUE=\" || ver>nul) && gatsby build --prefix-paths && xcopy /E /I /Y public ..\\autogenstudio\\web\\ui + +``` + +### Running the Application + +Once installed, run the web UI by entering the following in your terminal: + +```bash +autogenstudio ui --port 8081 +``` + +This will start the application on the specified port. Open your web browser and go to `http://localhost:8081/` to begin using AutoGen Studio. + +AutoGen Studio also takes several parameters to customize the application: + +- `--host <host>` argument to specify the host address. By default, it is set to `localhost`. +- `--appdir <folder>` argument to specify the directory where the app files (e.g., database and generated user files) are stored. By default, it is set to a `.autogenstudio` directory in the user's home directory. +- `--port <port>` argument to specify the port number. By default, it is set to `8080`. +- `--reload` argument to enable auto-reloading of the server when changes are made to the code. By default, it is set to `False`. +- `--database-uri` argument to specify the database URI. Example values include `sqlite:///database.sqlite` for SQLite and `postgresql+psycopg://user:password@localhost/dbname` for PostgreSQL. If this is not specified, the database URI defaults to a `database.sqlite` file in the `--appdir` directory. + +Now that you have AutoGen Studio installed and running, you are ready to explore its capabilities, including defining and modifying agent workflows, interacting with agents and sessions, and expanding agent skills.
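Putting these options together, a typical invocation might look like the following (the folder path and port values here are illustrative, not required):

```bash
# Serve the UI on all network interfaces on port 8081,
# storing the database and generated user files under /path/to/folder
autogenstudio ui --host 0.0.0.0 --port 8081 --appdir /path/to/folder
```

With `--appdir` set as above, the database defaults to `/path/to/folder/database.sqlite`, and a `.env` file in that directory will be used to set environment variables for the app.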
+ +### Capabilities / Roadmap + +Some of the capabilities supported by the app frontend include the following: + +- [x] Build / Configure agents (currently supports two agent workflows based on `UserProxyAgent` and `AssistantAgent`), modify their configuration (e.g., skills, temperature, model, agent system message, etc.) and compose them into workflows. +- [x] Chat with agent workflows and specify tasks. +- [x] View agent messages and output files in the UI from agent runs. +- [x] Support for more complex agent workflows (e.g. `GroupChat` and `Sequential` workflows). +- [x] Improved user experience (e.g., streaming intermediate model output, better summarization of agent responses, etc). + +Review the project roadmap and issues [here](https://github.com/microsoft/autogen/issues/737). + +Project Structure: + +- _autogenstudio/_ code for the backend classes and web API (FastAPI) +- _frontend/_ code for the web UI, built with Gatsby and TailwindCSS + +## Contribution Guide + +We welcome contributions to AutoGen Studio. We recommend the following general steps to contribute to the project: + +- Review the overall AutoGen project [contribution guide](https://github.com/microsoft/autogen?tab=readme-ov-file#contributing). +- Please review the AutoGen Studio [roadmap](https://github.com/microsoft/autogen/issues/737) to get a sense of the current priorities for the project. Help is appreciated especially with Studio issues tagged with `help-wanted`. +- Please initiate a discussion on the roadmap issue or a new issue to discuss your proposed contribution. +- Please review the autogenstudio [dev branch](https://github.com/microsoft/autogen/tree/autogenstudio) and use it as the base for your contribution. This way, your contribution will be aligned with the latest changes in the AutoGen Studio project. +- Submit a pull request with your contribution! +- If you are modifying AutoGen Studio, it has its own devcontainer.
See instructions in `.devcontainer/README.md` to use it. +- Please use the tag `studio` for any issues, questions, and PRs related to Studio. + +## A Note on Security + +AutoGen Studio is a research prototype and is not meant to be used in a production environment. Some baseline practices are encouraged, e.g., using a Docker code execution environment for your agents. + +However, other considerations such as rigorous tests related to jailbreaking, ensuring LLMs only have access to the right keys of data given the end user's permissions, and other security features are not implemented in AutoGen Studio. + +If you are building a production application, please use the AutoGen framework and implement the necessary security features. + +## Acknowledgements + +AutoGen Studio is based on the [AutoGen](https://microsoft.github.io/autogen) project. It was adapted from a research prototype built in October 2023 (original credits: Gagan Bansal, Adam Fourney, Victor Dibia, Piali Choudhury, Saleema Amershi, Ahmed Awadallah, Chi Wang).
diff --git a/website/docs/autogen-studio/img/agent_assistant.png b/website/docs/autogen-studio/img/agent_assistant.png new file mode 100644 index 000000000000..e34e90e36654 --- /dev/null +++ b/website/docs/autogen-studio/img/agent_assistant.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd8eff59d97c9fbdf118eefe071894125d6421cad6b428c3427d61630d57a3e8 +size 133246 diff --git a/website/docs/autogen-studio/img/agent_groupchat.png b/website/docs/autogen-studio/img/agent_groupchat.png new file mode 100644 index 000000000000..516cc28c6f98 --- /dev/null +++ b/website/docs/autogen-studio/img/agent_groupchat.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:060d2bfb9c38da7535015718202e17cdd6c51545a0a986dedfe6c91bd8accb52 +size 146086 diff --git a/website/docs/autogen-studio/img/agent_new.png b/website/docs/autogen-studio/img/agent_new.png new file mode 100644 index 000000000000..696e794320ac --- /dev/null +++ b/website/docs/autogen-studio/img/agent_new.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18de683a302d4cfaf218b3d57d6f72d17925b589c15ea0604e4b3bd03f6b464c +size 141037 diff --git a/website/docs/autogen-studio/img/agent_skillsmodel.png b/website/docs/autogen-studio/img/agent_skillsmodel.png new file mode 100644 index 000000000000..fea9113e503d --- /dev/null +++ b/website/docs/autogen-studio/img/agent_skillsmodel.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d85a85e931123b404ab1f3d20e2fe52a0e874479f5b36a6d56cd3ffaa0f9991b +size 147060 diff --git a/website/docs/autogen-studio/img/ara_stockprices.png b/website/docs/autogen-studio/img/ara_stockprices.png new file mode 100644 index 000000000000..f5adf6256e55 --- /dev/null +++ b/website/docs/autogen-studio/img/ara_stockprices.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e3340a765da6dff6585c8b2e8a4014df0c94b537d62d341d2d0d45627bbc345 +size 198222 diff --git 
a/website/docs/autogen-studio/img/model_new.png b/website/docs/autogen-studio/img/model_new.png new file mode 100644 index 000000000000..424c7e437273 --- /dev/null +++ b/website/docs/autogen-studio/img/model_new.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:82cf098881c1b318aeec3858aedbc80dea3e80e6d34c0dbd36d721a8e14cc058 +size 94667 diff --git a/website/docs/autogen-studio/img/model_openai.png b/website/docs/autogen-studio/img/model_openai.png new file mode 100644 index 000000000000..9b107c60439a --- /dev/null +++ b/website/docs/autogen-studio/img/model_openai.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:696397efe3a289f5dd084a5a7fbfe3f151adb21a19be617d3e66255acc4a404d +size 90123 diff --git a/website/docs/autogen-studio/img/skill.png b/website/docs/autogen-studio/img/skill.png new file mode 100644 index 000000000000..5357fff4c05e --- /dev/null +++ b/website/docs/autogen-studio/img/skill.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f675444d66c0f6dbca9756b92d2bd166cd29eb645efadc38b7331ab891bef204 +size 232801 diff --git a/website/docs/autogen-studio/img/workflow_chat.png b/website/docs/autogen-studio/img/workflow_chat.png new file mode 100644 index 000000000000..a83462146b3f --- /dev/null +++ b/website/docs/autogen-studio/img/workflow_chat.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b14764c16149a1094ba95612e84fd28ef778485cc026a1cb4904a8c3f0b7815 +size 127639 diff --git a/website/docs/autogen-studio/img/workflow_export.png b/website/docs/autogen-studio/img/workflow_export.png new file mode 100644 index 000000000000..b8ef14c4c11a --- /dev/null +++ b/website/docs/autogen-studio/img/workflow_export.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db5e6f7171a4de9ddfeb6d1e29d7dac2464f727720438ae3433cf78ffe8b75ce +size 204265 diff --git a/website/docs/autogen-studio/img/workflow_new.png 
b/website/docs/autogen-studio/img/workflow_new.png new file mode 100644 index 000000000000..52f864016e15 --- /dev/null +++ b/website/docs/autogen-studio/img/workflow_new.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64614cd603aa384270075788253566a8035bd0d0011c28af0476f6e484111e4c +size 90426 diff --git a/website/docs/autogen-studio/img/workflow_profile.png b/website/docs/autogen-studio/img/workflow_profile.png new file mode 100644 index 000000000000..0464bccfc483 --- /dev/null +++ b/website/docs/autogen-studio/img/workflow_profile.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ad630cdf09112be8831c830f516a2ec061de1d0097e03d205eda982ab408a63 +size 283288 diff --git a/website/docs/autogen-studio/img/workflow_sequential.png b/website/docs/autogen-studio/img/workflow_sequential.png new file mode 100644 index 000000000000..2fe6c76f0061 --- /dev/null +++ b/website/docs/autogen-studio/img/workflow_sequential.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:670715663ec78b47d53e2689ad2853e07f99ac498ed890f9bdd36c309e52758f +size 117232 diff --git a/website/docs/autogen-studio/img/workflow_test.png b/website/docs/autogen-studio/img/workflow_test.png new file mode 100644 index 000000000000..1146ecfa1f99 --- /dev/null +++ b/website/docs/autogen-studio/img/workflow_test.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0567731649d732d1bacd557b94b5eec87b8d491fa4207f4a8e29170ee56419d +size 258139 diff --git a/website/docs/autogen-studio/usage.md b/website/docs/autogen-studio/usage.md new file mode 100644 index 000000000000..e63c040fb325 --- /dev/null +++ b/website/docs/autogen-studio/usage.md @@ -0,0 +1,114 @@ +# Using AutoGen Studio + +AutoGen Studio supports the declarative creation of an agent workflow and tasks can be specified and run in a chat interface for the agents to complete. 
The expected usage behavior is that developers can create skills and models, _attach_ them to agents, and compose agents into workflows that can be tested interactively in the chat interface. + +## Building an Agent Workflow + +AutoGen Studio implements several entities that are ultimately composed into a workflow. + +### Skills + +A skill is a Python function that implements the solution to a task. In general, a good skill has a descriptive name (e.g., `generate_images`), extensive docstrings and good defaults (e.g., writing out files to disk for persistence and reuse). Skills can be _associated with_ or _attached to_ agent specifications. + +![AutoGen Studio Skill Interface](./img/skill.png) + +### Models + +A model refers to the configuration of an LLM. Similar to skills, a model can be attached to an agent specification. +The AutoGen Studio interface supports multiple model types including OpenAI models (and any other model endpoint provider that supports the OpenAI endpoint specification), Azure OpenAI models and Gemini models. + +![AutoGen Studio Create new model](./img/model_new.png) +![AutoGen Studio Create new model](./img/model_openai.png) + +### Agents + +An agent entity declaratively specifies properties for an AutoGen agent (mirrors most but not all of the members of a base AutoGen `ConversableAgent` class). Currently, the `UserProxyAgent`, `AssistantAgent`, and `GroupChat` agent abstractions are supported. + +![AutoGen Studio Create new agent](./img/agent_new.png) +![AutoGen Studio Create an assistant agent](./img/agent_groupchat.png) + +Once agents have been created, existing models or skills can be _added_ to the agent. + +![AutoGen Studio Add skills and models to agent](./img/agent_skillsmodel.png) + +### Workflows + +An agent workflow is a specification of a set of agents (a team of agents) that can work together to accomplish a task.
AutoGen Studio supports two types of high-level workflow patterns: + +#### Autonomous Chat + +This workflow implements a paradigm where agents are defined and a chat is initiated between the agents to accomplish a task. AutoGen simplifies this into defining an `initiator` agent and a `receiver` agent, where the receiver agent is selected from a list of previously created agents. Note that when the receiver is a `GroupChat` agent (i.e., contains multiple agents), the communication pattern between those agents is determined by the `speaker_selection_method` parameter in the `GroupChat` agent configuration. + +![AutoGen Studio Autonomous Chat Workflow](./img/workflow_chat.png) + +#### Sequential Chat + +This workflow allows users to specify a list of `AssistantAgent` agents that are executed in sequence to accomplish a task. The runtime behavior follows this pattern: at each step, each `AssistantAgent` is _paired_ with a `UserProxyAgent`, and a chat is initiated between this pair to process the input task. The result of this exchange is summarized and provided to the next `AssistantAgent`, which is also paired with a `UserProxyAgent`; its summarized result is passed on to the next `AssistantAgent` in the sequence. This continues until the last `AssistantAgent` in the sequence is reached. + +![AutoGen Studio Sequential Workflow](./img/workflow_sequential.png) + +## Testing an Agent Workflow + +AutoGen Studio allows users to interactively test workflows on tasks and review resulting artifacts (such as images, code, and documents). + +![AutoGen Studio Test Workflow](./img/workflow_test.png) + +Users can also review the “inner monologue” of agent workflows as they address tasks, and view profiling information such as costs associated with the run (e.g., number of turns and number of tokens) and agent actions (e.g., whether tools were called and the outcomes of code execution).
+ +![AutoGen Studio Profile Workflow Results](./img/workflow_profile.png) + +## Exporting Agent Workflows + +Users can download the skills, agents, and workflow configurations they create, as well as share and reuse these artifacts. AutoGen Studio also offers a seamless process to export workflows and deploy them as application programming interfaces (APIs) that can be consumed in other applications. + +### Export Workflow + +AutoGen Studio allows you to export a selected workflow as a JSON configuration file. + +Build -> Workflows -> (On workflow card) -> Export + +![AutoGen Studio Export Workflow](./img/workflow_export.png) + +### Using AutoGen Studio Workflows in a Python Application + +An exported workflow can be easily integrated into any Python application using the `WorkflowManager` class with just two lines of code. Under the hood, the `WorkflowManager` rehydrates the workflow specification into AutoGen agents that are subsequently used to address tasks. + +```python +from autogenstudio import WorkflowManager + +# load the workflow from an exported JSON workflow file +workflow_manager = WorkflowManager(workflow="path/to/your/workflow_.json") + +# run the workflow on a task +task_query = "What is the height of the Eiffel Tower? Don't write code, just respond to the question." +workflow_manager.run(message=task_query) +``` + +### Deploying AutoGen Studio Workflows as APIs + +The workflow can be launched as an API endpoint from the command line using the `autogenstudio` command line tool. + +```bash +autogenstudio serve --workflow=workflow.json --port=5000 +``` + +Similarly, the workflow launch command above can be wrapped into a Dockerfile that can be deployed on cloud services like Azure Container Apps or Azure Web Apps.
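As a complement to the Skills section above, here is a minimal sketch of what a skill might look like. The function name and behavior are illustrative only (not part of AutoGen Studio); the point is the pattern the docs recommend — a descriptive name, an extensive docstring, and good defaults that persist output for reuse:

```python
def save_text_report(content: str, filename: str = "report.txt") -> str:
    """
    Save a text report to disk and return the path of the saved file.

    A descriptive name and an extensive docstring matter because agents read
    them to decide when to call the skill; the default filename is a sensible
    default that persists output to disk for later reuse.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write(content)
    return filename
```

Once a function like this is registered in the Skills interface, it can be attached to an agent, which may then invoke it while addressing tasks.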
diff --git a/website/docusaurus.config.js b/website/docusaurus.config.js index f60ac41fd611..efc13096b0f7 100644 --- a/website/docusaurus.config.js +++ b/website/docusaurus.config.js @@ -8,9 +8,9 @@ customPostCssPlugin = () => { configurePostCss(options) { options.plugins.push(require("postcss-preset-env")); return options; - } + }, }; -} +}; module.exports = { title: "AutoGen", @@ -24,13 +24,13 @@ module.exports = { projectName: "AutoGen", // Usually your repo name. scripts: [ { - src: '/autogen/js/custom.js', + src: "/autogen/js/custom.js", async: true, defer: true, }, ], markdown: { - format: 'detect', // Support for MD files with .md extension + format: "detect", // Support for MD files with .md extension }, themeConfig: { docs: { @@ -80,6 +80,11 @@ module.exports = { docId: "FAQ", label: "FAQs", }, + { + type: "doc", + docId: "autogen-studio/getting-started", + label: "AutoGen Studio", + }, { type: "doc", docId: "ecosystem", @@ -127,9 +132,8 @@ module.exports = { { label: "Dotnet", href: "https://microsoft.github.io/autogen-for-net/", - } + }, ], - }, { to: "blog", @@ -175,7 +179,6 @@ module.exports = { { label: "Discord", href: "https://aka.ms/autogen-dc", - }, { label: "Twitter", @@ -187,17 +190,17 @@ module.exports = { copyright: `Copyright © ${new Date().getFullYear()} AutoGen Authors | Privacy and Cookies`, }, announcementBar: { - id: 'whats_new', + id: "whats_new", content: 'What\'s new in AutoGen? 
Read this blog for an overview of updates', - backgroundColor: '#fafbfc', - textColor: '#091E42', + backgroundColor: "#fafbfc", + textColor: "#091E42", isCloseable: true, }, /* Clarity Config */ clarity: { ID: "lnxpe6skj1", // The Tracking ID provided by Clarity - } + }, }, presets: [ [ @@ -281,14 +284,10 @@ module.exports = { { to: "/docs/contributor-guide/contributing", from: ["/docs/Contribute"], - } + }, ], }, ], - [ - 'docusaurus-plugin-clarity', - { - } - ], + ["docusaurus-plugin-clarity", {}], ], }; diff --git a/website/sidebars.js b/website/sidebars.js index 49d8fbf87e95..589c0ee9ba4d 100644 --- a/website/sidebars.js +++ b/website/sidebars.js @@ -9,114 +9,143 @@ Create as many sidebars as you want. */ - module.exports = { +module.exports = { docsSidebar: [ - 'Getting-Started', + "Getting-Started", { type: "category", label: "Installation", collapsed: true, items: ["installation/Docker", "installation/Optional-Dependencies"], link: { - type: 'doc', - id: "installation/Installation" + type: "doc", + id: "installation/Installation", }, }, { - type: 'category', - label: 'Tutorial', + type: "category", + label: "Tutorial", collapsed: false, link: { - type: 'generated-index', - title: 'Tutorial', - description: 'Tutorial on the basic concepts of AutoGen', - slug: 'tutorial', + type: "generated-index", + title: "Tutorial", + description: "Tutorial on the basic concepts of AutoGen", + slug: "tutorial", }, items: [ { - type: 'doc', - id: 'tutorial/introduction', - label: 'Introduction', + type: "doc", + id: "tutorial/introduction", + label: "Introduction", }, { - type: 'doc', - id: 'tutorial/chat-termination', - label: 'Chat Termination', + type: "doc", + id: "tutorial/chat-termination", + label: "Chat Termination", }, { - type: 'doc', - id: 'tutorial/human-in-the-loop', - label: 'Human in the Loop', + type: "doc", + id: "tutorial/human-in-the-loop", + label: "Human in the Loop", }, { - type: 'doc', - id: 'tutorial/code-executors', - label: 'Code Executors', + type: 
"doc", + id: "tutorial/code-executors", + label: "Code Executors", }, { - type: 'doc', - id: 'tutorial/tool-use', - label: 'Tool Use', + type: "doc", + id: "tutorial/tool-use", + label: "Tool Use", }, { - type: 'doc', - id: 'tutorial/conversation-patterns', - label: 'Conversation Patterns', + type: "doc", + id: "tutorial/conversation-patterns", + label: "Conversation Patterns", }, { - type: 'doc', - id: 'tutorial/what-next', - label: 'What Next?', - } + type: "doc", + id: "tutorial/what-next", + label: "What Next?", + }, ], }, - {'Use Cases': [{type: 'autogenerated', dirName: 'Use-Cases'}]}, + { "Use Cases": [{ type: "autogenerated", dirName: "Use-Cases" }] }, { - type: 'category', - label: 'User Guide', + type: "category", + label: "User Guide", collapsed: false, link: { - type: 'generated-index', - title: 'User Guide', - slug: 'topics', + type: "generated-index", + title: "User Guide", + slug: "topics", }, - items: [{type: 'autogenerated', dirName: 'topics'}] + items: [{ type: "autogenerated", dirName: "topics" }], }, { - type: 'link', - label: 'API Reference', - href: '/docs/reference/agentchat/conversable_agent', + type: "link", + label: "API Reference", + href: "/docs/reference/agentchat/conversable_agent", }, { - type: 'doc', - label: 'FAQs', - id: 'FAQ', + type: "doc", + label: "FAQs", + id: "FAQ", }, + { - 'type': 'category', - 'label': 'Ecosystem', - 'link': { - type: 'generated-index', - title: 'Ecosystem', - description: 'Learn about the ecosystem of AutoGen', - slug: 'ecosystem', + type: "category", + label: "AutoGen Studio", + collapsed: true, + items: [ + { + type: "doc", + id: "autogen-studio/getting-started", + label: "Getting Started", + }, + { + type: "doc", + id: "autogen-studio/usage", + label: "Using AutoGen Studio", + }, + { + type: "doc", + id: "autogen-studio/faqs", + label: "AutoGen Studio FAQs", + }, + ], + link: { + type: "generated-index", + title: "AutoGen Studio", + description: "Learn about AutoGen Studio", + slug: "autogen-studio", + 
}, + }, + { + type: "category", + label: "Ecosystem", + link: { + type: "generated-index", + title: "Ecosystem", + description: "Learn about the ecosystem of AutoGen", + slug: "ecosystem", }, - 'items': [{type: 'autogenerated', dirName: 'ecosystem'}] + items: [{ type: "autogenerated", dirName: "ecosystem" }], }, { type: "category", label: "Contributor Guide", collapsed: true, - items: [{type: 'autogenerated', dirName: 'contributor-guide'}], + items: [{ type: "autogenerated", dirName: "contributor-guide" }], link: { - type: 'generated-index', - title: 'Contributor Guide', - description: 'Learn how to contribute to AutoGen', - slug: 'contributor-guide', + type: "generated-index", + title: "Contributor Guide", + description: "Learn how to contribute to AutoGen", + slug: "contributor-guide", }, }, - 'Research', - 'Migration-Guide' + "Research", + "Migration-Guide", ], // pydoc-markdown auto-generated markdowns from docstrings referenceSideBar: [require("./docs/reference/sidebar.json")], @@ -124,14 +153,16 @@ { type: "category", label: "Notebooks", - items: [{ - type: "autogenerated", - dirName: "notebooks", - },], + items: [ + { + type: "autogenerated", + dirName: "notebooks", + }, + ], link: { - type: 'doc', - id: "notebooks" + type: "doc", + id: "notebooks", }, }, - ] + ], }; From a16b307dc0622c5e2d799267c403731da3cc9123 Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Mon, 10 Jun 2024 10:31:45 -0700 Subject: [PATCH 2/5] [.Net] Add Goolge gemini (#2868) * update * add vertex gemini test * remove DTO * add test for vertexGeminiAgent * update test name * update IGeminiClient interface * add test for streaming * add message connector * add gemini message extension * add tests * update * add gemnini sample * update examples * add test for iamge * fix test * add more tests * add streaming message test * add comment * remove unused json * implement google gemini client * update * fix comment --- dotnet/AutoGen.sln | 20 + dotnet/Directory.Build.props | 25 + 
dotnet/eng/Version.props | 1 + .../images/background.png | 0 dotnet/resource/images/square.png | 3 + .../Example05_Dalle_And_GPT4V.cs | 2 +- .../AutoGen.Gemini.Sample.csproj | 19 + .../Chat_With_Google_Gemini.cs | 38 ++ .../Chat_With_Vertex_Gemini.cs | 39 ++ .../Function_Call_With_Gemini.cs | 129 +++++ .../Image_Chat_With_Vertex_Gemini.cs | 44 ++ .../sample/AutoGen.Gemini.Sample/Program.cs | 6 + .../AutoGen.Ollama.Sample.csproj | 7 +- .../AutoGen.Ollama.Sample/Chat_With_LLaVA.cs | 2 +- .../DTO/ChatCompletionRequest.cs | 7 +- .../src/AutoGen.Core/Message/ImageMessage.cs | 39 +- .../src/AutoGen.Gemini/AutoGen.Gemini.csproj | 18 + .../Extension/FunctionContractExtension.cs | 90 ++++ dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs | 268 ++++++++++ .../src/AutoGen.Gemini/GoogleGeminiClient.cs | 83 +++ dotnet/src/AutoGen.Gemini/IGeminiClient.cs | 15 + .../Middleware/GeminiAgentExtension.cs | 40 ++ .../Middleware/GeminiMessageConnector.cs | 483 ++++++++++++++++++ .../src/AutoGen.Gemini/VertexGeminiClient.cs | 38 ++ .../AnthropicClientTest.cs | 4 +- .../AutoGen.Anthropic.Tests.csproj | 9 +- .../AutoGen.DotnetInteractive.Tests.csproj | 10 +- ....ItGenerateGetWeatherToolTest.approved.txt | 17 + .../AutoGen.Gemini.Tests.csproj | 19 + .../FunctionContractExtensionTests.cs | 27 + dotnet/test/AutoGen.Gemini.Tests/Functions.cs | 28 + .../AutoGen.Gemini.Tests/GeminiAgentTests.cs | 311 +++++++++++ .../GeminiMessageTests.cs | 380 ++++++++++++++ .../GoogleGeminiClientTests.cs | 132 +++++ .../test/AutoGen.Gemini.Tests/SampleTests.cs | 28 + .../VertexGeminiClientTests.cs | 134 +++++ .../AutoGen.Mistral.Tests.csproj | 10 +- .../AutoGen.Ollama.Tests.csproj | 10 +- .../AutoGen.OpenAI.Tests.csproj | 10 +- .../AutoGen.SemanticKernel.Tests.csproj | 10 +- .../AutoGen.SourceGenerator.Tests.csproj | 10 +- .../test/AutoGen.Tests/AutoGen.Tests.csproj | 10 +- .../test/AutoGen.Tests/ImageMessageTests.cs | 38 ++ dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 26 +- 44 files changed, 2530 insertions(+), 
109 deletions(-) rename dotnet/{sample/AutoGen.Ollama.Sample => resource}/images/background.png (100%) create mode 100644 dotnet/resource/images/square.png create mode 100644 dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj create mode 100644 dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Google_Gemini.cs create mode 100644 dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Vertex_Gemini.cs create mode 100644 dotnet/sample/AutoGen.Gemini.Sample/Function_Call_With_Gemini.cs create mode 100644 dotnet/sample/AutoGen.Gemini.Sample/Image_Chat_With_Vertex_Gemini.cs create mode 100644 dotnet/sample/AutoGen.Gemini.Sample/Program.cs create mode 100644 dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj create mode 100644 dotnet/src/AutoGen.Gemini/Extension/FunctionContractExtension.cs create mode 100644 dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs create mode 100644 dotnet/src/AutoGen.Gemini/GoogleGeminiClient.cs create mode 100644 dotnet/src/AutoGen.Gemini/IGeminiClient.cs create mode 100644 dotnet/src/AutoGen.Gemini/Middleware/GeminiAgentExtension.cs create mode 100644 dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs create mode 100644 dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs create mode 100644 dotnet/test/AutoGen.Gemini.Tests/ApprovalTests/FunctionContractExtensionTests.ItGenerateGetWeatherToolTest.approved.txt create mode 100644 dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj create mode 100644 dotnet/test/AutoGen.Gemini.Tests/FunctionContractExtensionTests.cs create mode 100644 dotnet/test/AutoGen.Gemini.Tests/Functions.cs create mode 100644 dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs create mode 100644 dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs create mode 100644 dotnet/test/AutoGen.Gemini.Tests/GoogleGeminiClientTests.cs create mode 100644 dotnet/test/AutoGen.Gemini.Tests/SampleTests.cs create mode 100644 dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs create mode 100644 
dotnet/test/AutoGen.Tests/ImageMessageTests.cs diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index 2bc106c0acad..6c4e8f0396b6 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -53,6 +53,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic.Tests", " EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic.Samples", "sample\AutoGen.Anthropic.Samples\AutoGen.Anthropic.Samples.csproj", "{834B4E85-64E5-4382-8465-548F332E5298}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini", "src\AutoGen.Gemini\AutoGen.Gemini.csproj", "{EFE0DC86-80FC-4D52-95B7-07654BA1A769}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini.Tests", "test\AutoGen.Gemini.Tests\AutoGen.Gemini.Tests.csproj", "{8EA16BAB-465A-4C07-ABC4-1070D40067E9}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Gemini.Sample", "sample\AutoGen.Gemini.Sample\AutoGen.Gemini.Sample.csproj", "{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.AotCompatibility.Tests", "test\AutoGen.AotCompatibility.Tests\AutoGen.AotCompatibility.Tests.csproj", "{6B82F26D-5040-4453-B21B-C8D1F913CE4C}" EndProject Global @@ -149,6 +154,18 @@ Global {834B4E85-64E5-4382-8465-548F332E5298}.Debug|Any CPU.Build.0 = Debug|Any CPU {834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.ActiveCfg = Release|Any CPU {834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.Build.0 = Release|Any CPU + {EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Release|Any CPU.Build.0 = Release|Any CPU + {8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Debug|Any 
CPU.Build.0 = Debug|Any CPU + {8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Release|Any CPU.Build.0 = Release|Any CPU + {19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Release|Any CPU.Build.0 = Release|Any CPU {6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.Build.0 = Debug|Any CPU {6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -180,6 +197,9 @@ Global {6A95E113-B824-4524-8F13-CD0C3E1C8804} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {834B4E85-64E5-4382-8465-548F332E5298} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} + {EFE0DC86-80FC-4D52-95B7-07654BA1A769} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {8EA16BAB-465A-4C07-ABC4-1070D40067E9} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {19679B75-CE3A-4DF0-A3F0-CA369D2760A4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} {6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props index aeb667438e26..4b3e9441f1ee 100644 --- a/dotnet/Directory.Build.props +++ b/dotnet/Directory.Build.props @@ -13,12 +13,37 @@ CS1998;CS1591 $(NoWarn);$(CSNoWarn);NU5104 true + true false true true + false $(MSBuildThisFileDirectory) + + + + + + + + + + + + + Always + testData/%(RecursiveDir)%(Filename)%(Extension) + + + + + + Always + resource/%(RecursiveDir)%(Filename)%(Extension) + + diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props index a43da436b388..0b8dcaa565cb 100644 --- 
a/dotnet/eng/Version.props +++ b/dotnet/eng/Version.props @@ -12,6 +12,7 @@ 17.7.0 1.0.0-beta.24229.4 8.0.0 + 3.0.0 4.3.0.2 \ No newline at end of file diff --git a/dotnet/sample/AutoGen.Ollama.Sample/images/background.png b/dotnet/resource/images/background.png similarity index 100% rename from dotnet/sample/AutoGen.Ollama.Sample/images/background.png rename to dotnet/resource/images/background.png diff --git a/dotnet/resource/images/square.png b/dotnet/resource/images/square.png new file mode 100644 index 000000000000..afb4f4cd4df8 --- /dev/null +++ b/dotnet/resource/images/square.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8323d0b8eceb752e14c29543b2e28bb2fc648ed9719095c31b7708867a4dc918 +size 491 diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs index 2d21615ef71b..67fd40ea3ac4 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs @@ -93,7 +93,7 @@ public static async Task RunAsync() if (reply.GetContent() is string content && content.Contains("IMAGE_GENERATION")) { var imageUrl = content.Split("\n").Last(); - var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From); + var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From, mimeType: "image/png"); Console.WriteLine($"download image from {imageUrl} to {imagePath}"); var httpClient = new HttpClient(); diff --git a/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj b/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj new file mode 100644 index 000000000000..b1779b56c390 --- /dev/null +++ b/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj @@ -0,0 +1,19 @@ + + + + Exe + net8.0 + enable + enable + true + True + + + + + + + + + + diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Google_Gemini.cs 
b/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Google_Gemini.cs new file mode 100644 index 000000000000..233c35c81222 --- /dev/null +++ b/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Google_Gemini.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Chat_With_Google_Gemini.cs + +using AutoGen.Core; +using AutoGen.Gemini.Middleware; +using FluentAssertions; + +namespace AutoGen.Gemini.Sample; + +public class Chat_With_Google_Gemini +{ + public static async Task RunAsync() + { + var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY"); + + if (apiKey is null) + { + Console.WriteLine("Please set GOOGLE_GEMINI_API_KEY environment variable."); + return; + } + + #region Create_Gemini_Agent + var geminiAgent = new GeminiChatAgent( + name: "gemini", + model: "gemini-1.5-flash-001", + apiKey: apiKey, + systemMessage: "You are a helpful C# engineer, put your code between ```csharp and ```, don't explain the code") + .RegisterMessageConnector() + .RegisterPrintMessage(); + #endregion Create_Gemini_Agent + + var reply = await geminiAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?"); + + #region verify_reply + reply.Should().BeOfType(); + #endregion verify_reply + } +} diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Vertex_Gemini.cs b/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Vertex_Gemini.cs new file mode 100644 index 000000000000..679a07ed69b9 --- /dev/null +++ b/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Vertex_Gemini.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Chat_With_Vertex_Gemini.cs + +using AutoGen.Core; +using AutoGen.Gemini.Middleware; +using FluentAssertions; + +namespace AutoGen.Gemini.Sample; + +public class Chat_With_Vertex_Gemini +{ + public static async Task RunAsync() + { + var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID"); + + if (projectID is null) + { + Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable."); + return; + } + + #region Create_Gemini_Agent + var geminiAgent = new GeminiChatAgent( + name: "gemini", + model: "gemini-1.5-flash-001", + location: "us-east1", + project: projectID, + systemMessage: "You are a helpful C# engineer, put your code between ```csharp and ```, don't explain the code") + .RegisterMessageConnector() + .RegisterPrintMessage(); + #endregion Create_Gemini_Agent + + var reply = await geminiAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?"); + + #region verify_reply + reply.Should().BeOfType(); + #endregion verify_reply + } +} diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Function_Call_With_Gemini.cs b/dotnet/sample/AutoGen.Gemini.Sample/Function_Call_With_Gemini.cs new file mode 100644 index 000000000000..d1b681d87096 --- /dev/null +++ b/dotnet/sample/AutoGen.Gemini.Sample/Function_Call_With_Gemini.cs @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Function_Call_With_Gemini.cs + +using AutoGen.Core; +using AutoGen.Gemini.Middleware; +using FluentAssertions; +using Google.Cloud.AIPlatform.V1; + +namespace AutoGen.Gemini.Sample; + +public partial class MovieFunction +{ + /// + /// find movie titles currently playing in theaters based on any description, genre, title words, etc. + /// + /// The city and state, e.g. San Francisco, CA or a zip code e.g. 95616 + /// Any kind of description including category or genre, title words, attributes, etc. 
+ /// + [Function] + public async Task FindMovies(string location, string description) + { + // dummy implementation + var movies = new List { "Barbie", "Spiderman", "Batman" }; + var result = $"Movies playing in {location} based on {description} are: {string.Join(", ", movies)}"; + + return result; + } + + /// + /// find theaters based on location and optionally movie title which is currently playing in theaters + /// + /// The city and state, e.g. San Francisco, CA or a zip code e.g. 95616 + /// Any movie title + [Function] + public async Task FindTheaters(string location, string movie) + { + // dummy implementation + var theaters = new List { "AMC", "Regal", "Cinemark" }; + var result = $"Theaters playing {movie} in {location} are: {string.Join(", ", theaters)}"; + + return result; + } + + /// + /// Find the start times for movies playing in a specific theater + /// + /// The city and state, e.g. San Francisco, CA or a zip code e.g. 95616 + /// Any movie title + /// Name of the theater + /// Date for requested showtime + /// + [Function] + public async Task GetShowtimes(string location, string movie, string theater, string date) + { + // dummy implementation + var showtimes = new List { "10:00 AM", "12:00 PM", "2:00 PM", "4:00 PM", "6:00 PM", "8:00 PM" }; + var result = $"Showtimes for {movie} at {theater} in {location} are: {string.Join(", ", showtimes)}"; + + return result; + } + +} + +/// +/// Modified from https://ai.google.dev/gemini-api/docs/function-calling +/// +public partial class Function_Call_With_Gemini +{ + public static async Task RunAsync() + { + var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID"); + + if (projectID is null) + { + Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable."); + return; + } + + var movieFunction = new MovieFunction(); + var functionMiddleware = new FunctionCallMiddleware( + functions: [ + movieFunction.FindMoviesFunctionContract, + movieFunction.FindTheatersFunctionContract, + 
movieFunction.GetShowtimesFunctionContract + ], + functionMap: new Dictionary>> + { + { movieFunction.FindMoviesFunctionContract.Name!, movieFunction.FindMoviesWrapper }, + { movieFunction.FindTheatersFunctionContract.Name!, movieFunction.FindTheatersWrapper }, + { movieFunction.GetShowtimesFunctionContract.Name!, movieFunction.GetShowtimesWrapper }, + }); + + #region Create_Gemini_Agent + var geminiAgent = new GeminiChatAgent( + name: "gemini", + model: "gemini-1.5-flash-001", + location: "us-central1", + project: projectID, + systemMessage: "You are a helpful AI assistant", + toolConfig: new ToolConfig() + { + FunctionCallingConfig = new FunctionCallingConfig() + { + Mode = FunctionCallingConfig.Types.Mode.Auto, + } + }) + .RegisterMessageConnector() + .RegisterPrintMessage() + .RegisterStreamingMiddleware(functionMiddleware); + #endregion Create_Gemini_Agent + + #region Single_turn + var question = new TextMessage(Role.User, "What movies are showing in North Seattle tonight?"); + var functionCallReply = await geminiAgent.SendAsync(question); + #endregion Single_turn + + #region Single_turn_verify_reply + functionCallReply.Should().BeOfType(); + #endregion Single_turn_verify_reply + + #region Multi_turn + var finalReply = await geminiAgent.SendAsync(chatHistory: [question, functionCallReply]); + #endregion Multi_turn + + #region Multi_turn_verify_reply + finalReply.Should().BeOfType(); + #endregion Multi_turn_verify_reply + } +} diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Image_Chat_With_Vertex_Gemini.cs b/dotnet/sample/AutoGen.Gemini.Sample/Image_Chat_With_Vertex_Gemini.cs new file mode 100644 index 000000000000..86193b653d9c --- /dev/null +++ b/dotnet/sample/AutoGen.Gemini.Sample/Image_Chat_With_Vertex_Gemini.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Image_Chat_With_Vertex_Gemini.cs + +using AutoGen.Core; +using AutoGen.Gemini.Middleware; +using FluentAssertions; + +namespace AutoGen.Gemini.Sample; + +public class Image_Chat_With_Vertex_Gemini +{ + public static async Task RunAsync() + { + var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID"); + + if (projectID is null) + { + Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable."); + return; + } + + #region Create_Gemini_Agent + var geminiAgent = new GeminiChatAgent( + name: "gemini", + model: "gemini-1.5-flash-001", + location: "us-east4", + project: projectID, + systemMessage: "You explain image content to user") + .RegisterMessageConnector() + .RegisterPrintMessage(); + #endregion Create_Gemini_Agent + + #region Send_Image_Request + var imagePath = Path.Combine("resource", "images", "background.png"); + var image = await File.ReadAllBytesAsync(imagePath); + var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(image, "image/png")); + var reply = await geminiAgent.SendAsync("what's in the image", [imageMessage]); + #endregion Send_Image_Request + + #region Verify_Reply + reply.Should().BeOfType(); + #endregion Verify_Reply + } +} diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Program.cs b/dotnet/sample/AutoGen.Gemini.Sample/Program.cs new file mode 100644 index 000000000000..5e76942209aa --- /dev/null +++ b/dotnet/sample/AutoGen.Gemini.Sample/Program.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Program.cs + +using AutoGen.Gemini.Sample; + +Image_Chat_With_Vertex_Gemini.RunAsync().Wait(); diff --git a/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj index 1dc94400869e..5277408d595d 100644 --- a/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj +++ b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj @@ -5,6 +5,7 @@ enable True $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110 + true @@ -15,10 +16,4 @@ - - - PreserveNewest - - - diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs index d52afb075e12..d9e38c886c2e 100644 --- a/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs +++ b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs @@ -28,7 +28,7 @@ public static async Task RunAsync() #endregion Create_Ollama_Agent #region Send_Message - var image = Path.Combine("images", "background.png"); + var image = Path.Combine("resource", "images", "background.png"); var binaryData = BinaryData.FromBytes(File.ReadAllBytes(image), "image/png"); var imageMessage = new ImageMessage(Role.User, binaryData); var textMessage = new TextMessage(Role.User, "what's in this image?"); diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs index fa1654bc11d0..f8967bc7e7fb 100644 --- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs +++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs @@ -1,11 +1,10 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. - +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// ChatCompletionRequest.cs using System.Text.Json.Serialization; +using System.Collections.Generic; namespace AutoGen.Anthropic.DTO; -using System.Collections.Generic; - public class ChatCompletionRequest { [JsonPropertyName("model")] diff --git a/dotnet/src/AutoGen.Core/Message/ImageMessage.cs b/dotnet/src/AutoGen.Core/Message/ImageMessage.cs index d2e2d0803003..685354dfe7a9 100644 --- a/dotnet/src/AutoGen.Core/Message/ImageMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/ImageMessage.cs @@ -7,18 +7,34 @@ namespace AutoGen.Core; public class ImageMessage : IMessage { - public ImageMessage(Role role, string url, string? from = null) + public ImageMessage(Role role, string url, string? from = null, string? mimeType = null) + : this(role, new Uri(url), from, mimeType) { - this.Role = role; - this.From = from; - this.Url = url; } - public ImageMessage(Role role, Uri uri, string? from = null) + public ImageMessage(Role role, Uri uri, string? from = null, string? mimeType = null) { this.Role = role; this.From = from; this.Url = uri.ToString(); + + // try infer mimeType from uri extension if not provided + if (mimeType is null) + { + mimeType = uri switch + { + _ when uri.AbsoluteUri.EndsWith(".png", StringComparison.OrdinalIgnoreCase) => "image/png", + _ when uri.AbsoluteUri.EndsWith(".jpg", StringComparison.OrdinalIgnoreCase) => "image/jpeg", + _ when uri.AbsoluteUri.EndsWith(".jpeg", StringComparison.OrdinalIgnoreCase) => "image/jpeg", + _ when uri.AbsoluteUri.EndsWith(".gif", StringComparison.OrdinalIgnoreCase) => "image/gif", + _ when uri.AbsoluteUri.EndsWith(".bmp", StringComparison.OrdinalIgnoreCase) => "image/bmp", + _ when uri.AbsoluteUri.EndsWith(".webp", StringComparison.OrdinalIgnoreCase) => "image/webp", + _ when uri.AbsoluteUri.EndsWith(".svg", StringComparison.OrdinalIgnoreCase) => "image/svg+xml", + _ => throw new ArgumentException("MimeType is required for ImageMessage", nameof(mimeType)) + }; + } + + this.MimeType = mimeType; } public 
ImageMessage(Role role, BinaryData data, string? from = null) @@ -28,7 +44,7 @@ public ImageMessage(Role role, BinaryData data, string? from = null) throw new ArgumentException("Data cannot be empty", nameof(data)); } - if (string.IsNullOrWhiteSpace(data.MediaType)) + if (data.MediaType is null) { throw new ArgumentException("MediaType is needed for DataUri Images", nameof(data)); } @@ -36,15 +52,18 @@ public ImageMessage(Role role, BinaryData data, string? from = null) this.Role = role; this.From = from; this.Data = data; + this.MimeType = data.MediaType; } - public Role Role { get; set; } + public Role Role { get; } - public string? Url { get; set; } + public string? Url { get; } public string? From { get; set; } - public BinaryData? Data { get; set; } + public BinaryData? Data { get; } + + public string MimeType { get; } public string BuildDataUri() { @@ -53,7 +72,7 @@ public string BuildDataUri() throw new NullReferenceException($"{nameof(Data)}"); } - return $"data:{this.Data.MediaType};base64,{Convert.ToBase64String(this.Data.ToArray())}"; + return $"data:{this.MimeType};base64,{Convert.ToBase64String(this.Data.ToArray())}"; } public override string ToString() diff --git a/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj b/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj new file mode 100644 index 000000000000..5a2a42ceb58c --- /dev/null +++ b/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj @@ -0,0 +1,18 @@ + + + + netstandard2.0 + + + + + + + + + + + + + + diff --git a/dotnet/src/AutoGen.Gemini/Extension/FunctionContractExtension.cs b/dotnet/src/AutoGen.Gemini/Extension/FunctionContractExtension.cs new file mode 100644 index 000000000000..64f78fa165b1 --- /dev/null +++ b/dotnet/src/AutoGen.Gemini/Extension/FunctionContractExtension.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
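The extension-based MIME-type inference added to the `ImageMessage(Role, Uri, ...)` constructor above can be sketched in isolation. This is a minimal sketch only: `MimeSketch.InferMimeType` is a hypothetical standalone helper, not part of the AutoGen API, and it covers just a few of the extensions the constructor handles.

```csharp
using System;

public static class MimeSketch
{
    // Map a URL's file extension to a MIME type, case-insensitively,
    // mirroring the fallback inference used when no mimeType is supplied.
    public static string InferMimeType(string url) => url switch
    {
        _ when url.EndsWith(".png", StringComparison.OrdinalIgnoreCase) => "image/png",
        _ when url.EndsWith(".jpg", StringComparison.OrdinalIgnoreCase) => "image/jpeg",
        _ when url.EndsWith(".jpeg", StringComparison.OrdinalIgnoreCase) => "image/jpeg",
        _ when url.EndsWith(".gif", StringComparison.OrdinalIgnoreCase) => "image/gif",
        _ when url.EndsWith(".webp", StringComparison.OrdinalIgnoreCase) => "image/webp",
        // unknown extensions require an explicit MIME type from the caller
        _ => throw new ArgumentException("MimeType is required", nameof(url)),
    };
}
```

Throwing for unrecognized extensions, rather than guessing a default, matches the constructor's choice to make the missing MIME type an explicit caller error.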
+// FunctionContractExtension.cs + +using System.Collections.Generic; +using System.Linq; +using AutoGen.Core; +using Google.Cloud.AIPlatform.V1; +using Json.Schema; +using Json.Schema.Generation; +using OpenAPISchemaType = Google.Cloud.AIPlatform.V1.Type; +using Type = System.Type; + +namespace AutoGen.Gemini.Extension; + +public static class FunctionContractExtension +{ + /// + /// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDeclaration"/> that can be used in a Gemini function call. + /// + public static FunctionDeclaration ToFunctionDeclaration(this FunctionContract function) + { + var required = function.Parameters!.Where(p => p.IsRequired) + .Select(p => p.Name) + .ToList(); + var parameterProperties = new Dictionary<string, OpenApiSchema>(); + + foreach (var parameter in function.Parameters ?? Enumerable.Empty<FunctionParameterContract>()) + { + var schema = ToOpenApiSchema(parameter.ParameterType); + schema.Description = parameter.Description; + schema.Title = parameter.Name; + schema.Nullable = !parameter.IsRequired; + parameterProperties.Add(parameter.Name!, schema); + } + + return new FunctionDeclaration + { + Name = function.Name, + Description = function.Description, + Parameters = new OpenApiSchema + { + Required = + { + required, + }, + Properties = + { + parameterProperties, + }, + Type = OpenAPISchemaType.Object, + }, + }; + } + + private static OpenApiSchema ToOpenApiSchema(Type?
type) + { + if (type == null) + { + return new OpenApiSchema + { + Type = OpenAPISchemaType.Unspecified + }; + } + + var schema = new JsonSchemaBuilder().FromType(type).Build(); + var openApiSchema = new OpenApiSchema + { + Type = schema.GetJsonType() switch + { + SchemaValueType.Array => OpenAPISchemaType.Array, + SchemaValueType.Boolean => OpenAPISchemaType.Boolean, + SchemaValueType.Integer => OpenAPISchemaType.Integer, + SchemaValueType.Number => OpenAPISchemaType.Number, + SchemaValueType.Object => OpenAPISchemaType.Object, + SchemaValueType.String => OpenAPISchemaType.String, + _ => OpenAPISchemaType.Unspecified + }, + }; + + if (schema.GetJsonType() == SchemaValueType.Object && schema.GetProperties() is var properties && properties != null) + { + foreach (var property in properties) + { + openApiSchema.Properties.Add(property.Key, ToOpenApiSchema(property.Value.GetType())); + } + } + + return openApiSchema; + } +} diff --git a/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs b/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs new file mode 100644 index 000000000000..b081faae8321 --- /dev/null +++ b/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs @@ -0,0 +1,268 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GeminiChatAgent.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Core; +using AutoGen.Gemini.Extension; +using Google.Cloud.AIPlatform.V1; +using Google.Protobuf.Collections; +namespace AutoGen.Gemini; + +public class GeminiChatAgent : IStreamingAgent +{ + private readonly IGeminiClient client; + private readonly string? systemMessage; + private readonly string model; + private readonly ToolConfig? toolConfig; + private readonly RepeatedField? safetySettings; + private readonly string responseMimeType; + private readonly Tool[]? tools; + + /// + /// Create that connects to Gemini. 
+ /// + /// the gemini client to use. e.g. + /// agent name + /// the model id. It needs to be in the format of + /// 'projects/{project}/locations/{location}/publishers/{provider}/models/{model}' if the is + /// system message + /// tool config + /// tools + /// safety settings + /// response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain' + public GeminiChatAgent( + IGeminiClient client, + string name, + string model, + string? systemMessage = null, + ToolConfig? toolConfig = null, + Tool[]? tools = null, + RepeatedField? safetySettings = null, + string responseMimeType = "text/plain") + { + this.client = client; + this.Name = name; + this.systemMessage = systemMessage; + this.model = model; + this.toolConfig = toolConfig; + this.safetySettings = safetySettings; + this.responseMimeType = responseMimeType; + this.tools = tools; + } + + /// + /// Create that connects to Gemini using + /// + /// agent name + /// the name of gemini model, e.g. gemini-1.5-flash-001 + /// google gemini api key + /// system message + /// tool config + /// tools + /// + /// response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain' + /// /// + /// + /// + public GeminiChatAgent( + string name, + string model, + string apiKey, + string systemMessage = "You are a helpful AI assistant", + ToolConfig? toolConfig = null, + Tool[]? tools = null, + RepeatedField? safetySettings = null, + string responseMimeType = "text/plain") + : this( + client: new GoogleGeminiClient(apiKey), + name: name, + model: model, + systemMessage: systemMessage, + toolConfig: toolConfig, + tools: tools, + safetySettings: safetySettings, + responseMimeType: responseMimeType) + { + } + + /// + /// Create that connects to Vertex AI. + /// + /// agent name + /// system message + /// the name of gemini model, e.g. 
gemini-1.5-flash-001 + /// project id + /// model location + /// model provider, default is 'google' + /// tool config + /// tools + /// + /// response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain' + /// + /// + /// + public GeminiChatAgent( + string name, + string model, + string project, + string location, + string provider = "google", + string? systemMessage = null, + ToolConfig? toolConfig = null, + Tool[]? tools = null, + RepeatedField? safetySettings = null, + string responseMimeType = "text/plain") + : this( + client: new VertexGeminiClient(location), + name: name, + model: $"projects/{project}/locations/{location}/publishers/{provider}/models/{model}", + systemMessage: systemMessage, + toolConfig: toolConfig, + tools: tools, + safetySettings: safetySettings, + responseMimeType: responseMimeType) + { + } + + public string Name { get; } + + public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + var request = BuildChatRequest(messages, options); + var response = await this.client.GenerateContentAsync(request, cancellationToken: cancellationToken).ConfigureAwait(false); + + return MessageEnvelope.Create(response, this.Name); + } + + public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var request = BuildChatRequest(messages, options); + var response = this.client.GenerateContentStreamAsync(request); + + await foreach (var item in response.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return MessageEnvelope.Create(item, this.Name); + } + } + + private GenerateContentRequest BuildChatRequest(IEnumerable messages, GenerateReplyOptions? 
options) + { + var geminiMessages = messages.Select(m => m switch + { + IMessage<Content> contentMessage => contentMessage.Content, + _ => throw new NotSupportedException($"Message type {m.GetType()} is not supported.") + }); + + // several rules apply to the messages that can be sent to Gemini in a multi-turn chat + // - The first message must be from the user or function + // - The (user|model) roles must alternate e.g. (user, model, user, model, ...) + // - The last message must be from the user or function + + // check that the first message is from the user or function + if (geminiMessages.FirstOrDefault()?.Role != "user" && geminiMessages.FirstOrDefault()?.Role != "function") + { + throw new ArgumentException("The first message must be from the user or function", nameof(messages)); + } + + // check that the last message is from the user or function + if (geminiMessages.LastOrDefault()?.Role != "user" && geminiMessages.LastOrDefault()?.Role != "function") + { + throw new ArgumentException("The last message must be from the user or function", nameof(messages)); + } + + // merge consecutive messages with the same role into one message + var mergedMessages = geminiMessages.Aggregate(new List<Content>(), (acc, message) => + { + if (acc.Count == 0 || acc.Last().Role != message.Role) + { + acc.Add(message); + } + else + { + acc.Last().Parts.AddRange(message.Parts); + } + + return acc; + }); + + var systemMessage = this.systemMessage switch + { + null => null, + string message => new Content + { + Parts = { new[] { new Part { Text = message } } }, + Role = "system_instruction" + } + }; + + List<Tool> tools = this.tools?.ToList() ?? new List<Tool>(); + + var request = new GenerateContentRequest() + { + Contents = { mergedMessages }, + SystemInstruction = systemMessage, + Model = this.model, + GenerationConfig = new GenerationConfig + { + StopSequences = { options?.StopSequence ??
Enumerable.Empty<string>() }, + ResponseMimeType = this.responseMimeType, + CandidateCount = 1, + }, + }; + + if (this.toolConfig is not null) + { + request.ToolConfig = this.toolConfig; + } + + if (this.safetySettings is not null) + { + request.SafetySettings.Add(this.safetySettings); + } + + if (options?.MaxToken.HasValue is true) + { + request.GenerationConfig.MaxOutputTokens = options.MaxToken.Value; + } + + if (options?.Temperature.HasValue is true) + { + request.GenerationConfig.Temperature = options.Temperature.Value; + } + + if (options?.Functions is { Length: > 0 }) + { + foreach (var function in options.Functions) + { + tools.Add(new Tool + { + FunctionDeclarations = { function.ToFunctionDeclaration() }, + }); + } + } + + // merge tools into one tool + // because multiple tools are currently not supported by Gemini + // see https://github.com/googleapis/python-aiplatform/issues/3771 + var aggregatedTool = new Tool + { + FunctionDeclarations = { tools.SelectMany(t => t.FunctionDeclarations) }, + }; + + if (aggregatedTool is { FunctionDeclarations: { Count: > 0 } }) + { + request.Tools.Add(aggregatedTool); + } + + return request; + } +} diff --git a/dotnet/src/AutoGen.Gemini/GoogleGeminiClient.cs b/dotnet/src/AutoGen.Gemini/GoogleGeminiClient.cs new file mode 100644 index 000000000000..9489061e27e9 --- /dev/null +++ b/dotnet/src/AutoGen.Gemini/GoogleGeminiClient.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved.
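The consecutive-role merge that `BuildChatRequest` performs above can be illustrated without the Gemini SDK types. A hedged sketch, assuming a simplified `Msg` record in place of the SDK's `Content`:

```csharp
using System.Collections.Generic;
using System.Linq;

public record Msg(string Role, List<string> Parts);

public static class MergeSketch
{
    // Fold consecutive messages with the same role into a single message,
    // since Gemini requires strictly alternating user/model turns.
    public static List<Msg> Merge(IEnumerable<Msg> messages) =>
        messages.Aggregate(new List<Msg>(), (acc, m) =>
        {
            if (acc.Count == 0 || acc.Last().Role != m.Role)
                acc.Add(new Msg(m.Role, new List<string>(m.Parts))); // new run starts
            else
                acc.Last().Parts.AddRange(m.Parts); // same role: append parts
            return acc;
        });
}
```

For example, `[user:"a", user:"b", model:"c"]` collapses to two messages, the first carrying parts `a` and `b`.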
+// GoogleGeminiClient.cs + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Google.Cloud.AIPlatform.V1; +using Google.Protobuf; + +namespace AutoGen.Gemini; + +public class GoogleGeminiClient : IGeminiClient +{ + private readonly string apiKey; + private const string endpoint = "https://generativelanguage.googleapis.com/v1beta"; + private readonly HttpClient httpClient = new(); + private const string generateContentPath = "models/{0}:generateContent"; + private const string generateContentStreamPath = "models/{0}:streamGenerateContent"; + + public GoogleGeminiClient(HttpClient httpClient, string apiKey) + { + this.apiKey = apiKey; + this.httpClient = httpClient; + } + + public GoogleGeminiClient(string apiKey) + { + this.apiKey = apiKey; + } + + public async Task GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default) + { + var path = string.Format(generateContentPath, request.Model); + var url = $"{endpoint}/{path}?key={apiKey}"; + + var httpContent = new StringContent(JsonFormatter.Default.Format(request), System.Text.Encoding.UTF8, "application/json"); + var response = await httpClient.PostAsync(url, httpContent, cancellationToken); + + if (!response.IsSuccessStatusCode) + { + throw new Exception($"Failed to generate content. 
Status code: {response.StatusCode}"); + } + + var json = await response.Content.ReadAsStringAsync(); + return GenerateContentResponse.Parser.ParseJson(json); + } + + public async IAsyncEnumerable GenerateContentStreamAsync(GenerateContentRequest request) + { + var path = string.Format(generateContentStreamPath, request.Model); + var url = $"{endpoint}/{path}?key={apiKey}&alt=sse"; + + var httpContent = new StringContent(JsonFormatter.Default.Format(request), System.Text.Encoding.UTF8, "application/json"); + var requestMessage = new HttpRequestMessage(HttpMethod.Post, url) + { + Content = httpContent + }; + + var response = await httpClient.SendAsync(requestMessage, HttpCompletionOption.ResponseHeadersRead); + + if (!response.IsSuccessStatusCode) + { + throw new Exception($"Failed to generate content. Status code: {response.StatusCode}"); + } + + var stream = await response.Content.ReadAsStreamAsync(); + var jp = new JsonParser(JsonParser.Settings.Default.WithIgnoreUnknownFields(true)); + using var streamReader = new System.IO.StreamReader(stream); + while (!streamReader.EndOfStream) + { + var json = await streamReader.ReadLineAsync(); + if (string.IsNullOrWhiteSpace(json)) + { + continue; + } + + json = json.Substring("data:".Length).Trim(); + yield return jp.Parse(json); + } + } +} diff --git a/dotnet/src/AutoGen.Gemini/IGeminiClient.cs b/dotnet/src/AutoGen.Gemini/IGeminiClient.cs new file mode 100644 index 000000000000..2e209e02b030 --- /dev/null +++ b/dotnet/src/AutoGen.Gemini/IGeminiClient.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
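The `alt=sse` stream handling above (skip blank keep-alive lines, strip the `data:` prefix, then parse each JSON payload) can be sketched as a pure function over the raw lines. `ExtractPayloads` is illustrative only, not part of the client API:

```csharp
using System.Collections.Generic;

public static class SseSketch
{
    // Given raw SSE lines, yield the JSON payload carried by each event.
    public static IEnumerable<string> ExtractPayloads(IEnumerable<string?> lines)
    {
        foreach (var line in lines)
        {
            if (string.IsNullOrWhiteSpace(line))
                continue; // blank separator lines between events carry no data

            // drop the "data:" prefix and surrounding whitespace, as the client does
            yield return line.Substring("data:".Length).Trim();
        }
    }
}
```

Each extracted payload is then handed to the protobuf `JsonParser` in the client, one `GenerateContentResponse` per event.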
+// IGeminiClient.cs + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Google.Cloud.AIPlatform.V1; + +namespace AutoGen.Gemini; + +public interface IGeminiClient +{ + Task<GenerateContentResponse> GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default); + IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(GenerateContentRequest request); +} diff --git a/dotnet/src/AutoGen.Gemini/Middleware/GeminiAgentExtension.cs b/dotnet/src/AutoGen.Gemini/Middleware/GeminiAgentExtension.cs new file mode 100644 index 000000000000..8718d54f960a --- /dev/null +++ b/dotnet/src/AutoGen.Gemini/Middleware/GeminiAgentExtension.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GeminiAgentExtension.cs + +using AutoGen.Core; + +namespace AutoGen.Gemini.Middleware; + +public static class GeminiAgentExtension +{ + + /// + /// Register a <see cref="GeminiMessageConnector"/> to the <see cref="GeminiChatAgent"/> + /// + /// the connector to use. If null, a new instance of <see cref="GeminiMessageConnector"/> will be created. + public static MiddlewareStreamingAgent<GeminiChatAgent> RegisterMessageConnector( + this GeminiChatAgent agent, GeminiMessageConnector? connector = null) + { + if (connector == null) + { + connector = new GeminiMessageConnector(); + } + + return agent.RegisterStreamingMiddleware(connector); + } + + /// + /// Register a <see cref="GeminiMessageConnector"/> to the <see cref="MiddlewareStreamingAgent{T}"/> where T is <see cref="GeminiChatAgent"/> + /// + /// the connector to use. If null, a new instance of <see cref="GeminiMessageConnector"/> will be created. + public static MiddlewareStreamingAgent<GeminiChatAgent> RegisterMessageConnector( + this MiddlewareStreamingAgent<GeminiChatAgent> agent, GeminiMessageConnector?
connector = null) + { + if (connector == null) + { + connector = new GeminiMessageConnector(); + } + + return agent.RegisterStreamingMiddleware(connector); + } +} diff --git a/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs b/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs new file mode 100644 index 000000000000..35008ebf00c2 --- /dev/null +++ b/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs @@ -0,0 +1,483 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GeminiMessageConnector.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Core; +using Google.Cloud.AIPlatform.V1; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using static Google.Cloud.AIPlatform.V1.Candidate.Types; +using IMessage = AutoGen.Core.IMessage; + +namespace AutoGen.Gemini.Middleware; + +public class GeminiMessageConnector : IStreamingMiddleware +{ + /// + /// if true, the connector will throw an exception if it encounters an unsupported message type. + /// Otherwise, it will skip processing the message and return the message as is. + /// + private readonly bool strictMode; + + /// + /// Initializes a new instance of the <see cref="GeminiMessageConnector"/> class. + /// + /// whether to throw an exception if it encounters an unsupported message type. + /// If true, the connector will throw an exception if it encounters an unsupported message type. + /// If false, it will skip processing the message and return the message as is.
+ public GeminiMessageConnector(bool strictMode = false) + { + this.strictMode = strictMode; + } + + public string Name => nameof(GeminiMessageConnector); + + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var messages = ProcessMessage(context.Messages, agent); + + var bucket = new List(); + + await foreach (var reply in agent.GenerateStreamingReplyAsync(messages, context.Options, cancellationToken)) + { + if (reply is Core.IMessage m) + { + // if m.Content is empty and stop reason is Stop, ignore the message + if (m.Content.Candidates.Count == 1 && m.Content.Candidates[0].Content.Parts.Count == 1 && m.Content.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.Text) + { + var text = m.Content.Candidates[0].Content.Parts[0].Text; + var stopReason = m.Content.Candidates[0].FinishReason; + if (string.IsNullOrEmpty(text) && stopReason == FinishReason.Stop) + { + continue; + } + } + + bucket.Add(m.Content); + + yield return PostProcessStreamingMessage(m.Content, agent); + } + else if (strictMode) + { + throw new InvalidOperationException($"Unsupported message type: {reply.GetType()}"); + } + else + { + yield return reply; + } + + // aggregate the message updates from bucket into a single message + if (bucket is { Count: > 0 }) + { + var isTextMessageUpdates = bucket.All(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.Text); + var isFunctionCallUpdates = bucket.Any(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.FunctionCall); + if (isTextMessageUpdates) + { + var text = string.Join(string.Empty, bucket.Select(m => m.Candidates[0].Content.Parts[0].Text)); + var textMessage = new TextMessage(Role.Assistant, text, agent.Name); + + yield return textMessage; + } 
+ else if (isFunctionCallUpdates) + { + var functionCallParts = bucket.Where(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.FunctionCall) + .Select(m => m.Candidates[0].Content.Parts[0]).ToList(); + + var toolCalls = new List(); + foreach (var part in functionCallParts) + { + var fc = part.FunctionCall; + var toolCall = new ToolCall(fc.Name, fc.Args.ToString()); + + toolCalls.Add(toolCall); + } + + var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name); + + yield return toolCallMessage; + } + else + { + throw new InvalidOperationException("The response should contain either text or tool calls."); + } + } + } + } + + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + { + var messages = ProcessMessage(context.Messages, agent); + var reply = await agent.GenerateReplyAsync(messages, context.Options, cancellationToken); + + return reply switch + { + Core.IMessage m => PostProcessMessage(m.Content, agent), + _ when strictMode => throw new InvalidOperationException($"Unsupported message type: {reply.GetType()}"), + _ => reply, + }; + } + + private IMessage PostProcessStreamingMessage(GenerateContentResponse m, IAgent agent) + { + this.ValidateGenerateContentResponse(m); + + var candidate = m.Candidates[0]; + var parts = candidate.Content.Parts; + + if (parts.Count == 1 && parts[0].DataCase == Part.DataOneofCase.Text) + { + var content = parts[0].Text; + return new TextMessageUpdate(Role.Assistant, content, agent.Name); + } + else + { + var toolCalls = new List(); + foreach (var part in parts) + { + if (part.DataCase == Part.DataOneofCase.FunctionCall) + { + var fc = part.FunctionCall; + var toolCall = new ToolCall(fc.Name, fc.Args.ToString()); + + toolCalls.Add(toolCall); + } + } + + if (toolCalls.Count > 0) + { + var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name); + return 
toolCallMessage; + } + else + { + throw new InvalidOperationException("The response should contain either text or tool calls."); + } + } + } + + private IMessage PostProcessMessage(GenerateContentResponse m, IAgent agent) + { + this.ValidateGenerateContentResponse(m); + var candidate = m.Candidates[0]; + var parts = candidate.Content.Parts; + + if (parts.Count == 1 && parts[0].DataCase == Part.DataOneofCase.Text) + { + var content = parts[0].Text; + return new TextMessage(Role.Assistant, content, agent.Name); + } + else + { + var toolCalls = new List(); + foreach (var part in parts) + { + if (part.DataCase == Part.DataOneofCase.FunctionCall) + { + var fc = part.FunctionCall; + var toolCall = new ToolCall(fc.Name, fc.Args.ToString()); + + toolCalls.Add(toolCall); + } + } + + if (toolCalls.Count > 0) + { + var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name); + return toolCallMessage; + } + else + { + throw new InvalidOperationException("The response should contain either text or tool calls."); + } + } + } + + private IEnumerable ProcessMessage(IEnumerable messages, IAgent agent) + { + return messages.SelectMany(m => + { + if (m is Core.IMessage messageEnvelope) + { + return [m]; + } + else + { + return m switch + { + TextMessage textMessage => ProcessTextMessage(textMessage, agent), + ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent), + MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent), + ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent), + ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage, agent), + ToolCallAggregateMessage toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent), + _ when strictMode => throw new InvalidOperationException($"Unsupported message type: {m.GetType()}"), + _ => [m], + }; + } + }); + } + + private IEnumerable 
ProcessToolCallAggregateMessage(ToolCallAggregateMessage toolCallAggregateMessage, IAgent agent) + { + var parseAsUser = ShouldParseAsUser(toolCallAggregateMessage, agent); + if (parseAsUser) + { + var content = toolCallAggregateMessage.GetContent(); + + if (content is string str) + { + var textMessage = new TextMessage(Role.User, str, toolCallAggregateMessage.From); + + return ProcessTextMessage(textMessage, agent); + } + + return []; + } + else + { + var toolCallContents = ProcessToolCallMessage(toolCallAggregateMessage.Message1, agent); + var toolCallResultContents = ProcessToolCallResultMessage(toolCallAggregateMessage.Message2, agent); + + return toolCallContents.Concat(toolCallResultContents); + } + } + + private void ValidateGenerateContentResponse(GenerateContentResponse response) + { + if (response.Candidates.Count != 1) + { + throw new InvalidOperationException("The response should contain exactly one candidate."); + } + + var candidate = response.Candidates[0]; + if (candidate.Content is null) + { + var finishReason = candidate.FinishReason; + var finishMessage = candidate.FinishMessage; + + throw new InvalidOperationException($"The response should contain content but the content is empty. 
FinishReason: {finishReason}, FinishMessage: {finishMessage}"); + } + } + + private IEnumerable ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage, IAgent agent) + { + var functionCallResultParts = new List(); + foreach (var toolCallResult in toolCallResultMessage.ToolCalls) + { + if (toolCallResult.Result is null) + { + continue; + } + + // if result is already a json object, use it as is + var json = toolCallResult.Result; + try + { + JsonNode.Parse(json); + } + catch (JsonException) + { + // if the result is not a json object, wrap it in a json object + var result = new { result = json }; + json = JsonSerializer.Serialize(result); + } + var part = new Part + { + FunctionResponse = new FunctionResponse + { + Name = toolCallResult.FunctionName, + Response = Struct.Parser.ParseJson(json), + } + }; + + functionCallResultParts.Add(part); + } + + var content = new Content + { + Parts = { functionCallResultParts }, + Role = "function", + }; + + return [MessageEnvelope.Create(content, toolCallResultMessage.From)]; + } + + private IEnumerable ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent) + { + var shouldParseAsUser = ShouldParseAsUser(toolCallMessage, agent); + if (strictMode && shouldParseAsUser) + { + throw new InvalidOperationException("ToolCallMessage is not supported as user role in Gemini."); + } + + var functionCallParts = new List(); + foreach (var toolCall in toolCallMessage.ToolCalls) + { + var part = new Part + { + FunctionCall = new FunctionCall + { + Name = toolCall.FunctionName, + Args = Struct.Parser.ParseJson(toolCall.FunctionArguments), + } + }; + + functionCallParts.Add(part); + } + var content = new Content + { + Parts = { functionCallParts }, + Role = "model" + }; + + return [MessageEnvelope.Create(content, toolCallMessage.From)]; + } + + private IEnumerable ProcessMultiModalMessage(MultiModalMessage multiModalMessage, IAgent agent) + { + var parts = new List(); + foreach (var message in 
multiModalMessage.Content) + { + if (message is TextMessage textMessage) + { + parts.Add(new Part { Text = textMessage.Content }); + } + else if (message is ImageMessage imageMessage) + { + parts.Add(CreateImagePart(imageMessage)); + } + else + { + throw new InvalidOperationException($"Unsupported message type: {message.GetType()}"); + } + } + + var shouldParseAsUser = ShouldParseAsUser(multiModalMessage, agent); + + if (strictMode && !shouldParseAsUser) + { + // image message is not supported as model role in Gemini + throw new InvalidOperationException("Image message is not supported as model role in Gemini."); + } + + var content = new Content + { + Parts = { parts }, + Role = shouldParseAsUser ? "user" : "model", + }; + + return [MessageEnvelope.Create(content, multiModalMessage.From)]; + } + + private IEnumerable ProcessTextMessage(TextMessage textMessage, IAgent agent) + { + if (textMessage.Role == Role.System) + { + // there are only user | model role in Gemini + // if the role is system and the strict mode is enabled, throw an exception + if (strictMode) + { + throw new InvalidOperationException("System role is not supported in Gemini."); + } + + // if strict mode is not enabled, parse the message as a user message + var content = new Content + { + Parts = { new[] { new Part { Text = textMessage.Content } } }, + Role = "user", + }; + + return [MessageEnvelope.Create(content, textMessage.From)]; + } + + var shouldParseAsUser = ShouldParseAsUser(textMessage, agent); + + if (shouldParseAsUser) + { + var content = new Content + { + Parts = { new[] { new Part { Text = textMessage.Content } } }, + Role = "user", + }; + + return [MessageEnvelope.Create(content, textMessage.From)]; + } + else + { + var content = new Content + { + Parts = { new[] { new Part { Text = textMessage.Content } } }, + Role = "model", + }; + + return [MessageEnvelope.Create(content, textMessage.From)]; + } + } + + private IEnumerable ProcessImageMessage(ImageMessage imageMessage, IAgent 
agent) + { + var imagePart = CreateImagePart(imageMessage); + var shouldParseAsUser = ShouldParseAsUser(imageMessage, agent); + + if (strictMode && !shouldParseAsUser) + { + // image message is not supported as model role in Gemini + throw new InvalidOperationException("Image message is not supported as model role in Gemini."); + } + + var content = new Content + { + Parts = { imagePart }, + Role = shouldParseAsUser ? "user" : "model", + }; + + return [MessageEnvelope.Create(content, imageMessage.From)]; + } + + private Part CreateImagePart(ImageMessage message) + { + if (message.Url is string url) + { + return new Part + { + FileData = new FileData + { + FileUri = url, + MimeType = message.MimeType + } + }; + } + else if (message.Data is BinaryData data) + { + return new Part + { + InlineData = new Blob + { + MimeType = message.MimeType, + Data = ByteString.CopyFrom(data.ToArray()), + } + }; + } + else + { + throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided"); + } + } + + private bool ShouldParseAsUser(IMessage message, IAgent agent) + { + return message switch + { + TextMessage textMessage => (textMessage.Role == Role.User && textMessage.From is null) + || (textMessage.From != agent.Name), + _ => message.From != agent.Name, + }; + } +} diff --git a/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs b/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs new file mode 100644 index 000000000000..c54f2280dfd3 --- /dev/null +++ b/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// VertexGeminiClient.cs + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Google.Cloud.AIPlatform.V1; + +namespace AutoGen.Gemini; + +internal class VertexGeminiClient : IGeminiClient +{ + private readonly PredictionServiceClient client; + public VertexGeminiClient(PredictionServiceClient client) + { + this.client = client; + } + + public VertexGeminiClient(string location) + { + PredictionServiceClientBuilder builder = new() + { + Endpoint = $"{location}-aiplatform.googleapis.com", + }; + + this.client = builder.Build(); + } + + public Task<GenerateContentResponse> GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default) + { + return client.GenerateContentAsync(request, cancellationToken); + } + + public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(GenerateContentRequest request) + { + return client.StreamGenerateContent(request).GetResponseStream(); + } +} diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs index 0b64c9e4e3c2..af050eef9280 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs @@ -1,4 +1,4 @@ -using System.Text; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using AutoGen.Anthropic.DTO; @@ -43,7 +43,7 @@ public async Task AnthropicClientStreamingChatCompletionTestAsync() request.Model = AnthropicConstants.Claude3Haiku; request.Stream = true; request.MaxTokens = 500; - request.SystemMessage = "You are a helpful assistant that convert input to json object"; + request.SystemMessage = "You are a helpful assistant that converts input to a JSON object, use JSON format."; request.Messages = new List() { new("user", "name: John, age: 41, email: g123456@gmail.com") diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj
b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj index 8cd1e3003b0e..8ce405a07ce8 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj +++ b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj @@ -6,16 +6,9 @@ false True AutoGen.Anthropic.Tests + True - - - - - - - - diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj b/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj index cf2c24eaf786..0f77db2c1c36 100644 --- a/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj +++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj @@ -4,18 +4,10 @@ $(TestTargetFramework) enable false + True True - - - - - - - - - diff --git a/dotnet/test/AutoGen.Gemini.Tests/ApprovalTests/FunctionContractExtensionTests.ItGenerateGetWeatherToolTest.approved.txt b/dotnet/test/AutoGen.Gemini.Tests/ApprovalTests/FunctionContractExtensionTests.ItGenerateGetWeatherToolTest.approved.txt new file mode 100644 index 000000000000..d7ec585cb205 --- /dev/null +++ b/dotnet/test/AutoGen.Gemini.Tests/ApprovalTests/FunctionContractExtensionTests.ItGenerateGetWeatherToolTest.approved.txt @@ -0,0 +1,17 @@ +{ + "name": "GetWeatherAsync", + "description": "Get weather for a city.", + "parameters": { + "type": "OBJECT", + "properties": { + "city": { + "type": "STRING", + "description": "city", + "title": "city" +} + }, + "required": [ + "city" + ] + } +} \ No newline at end of file diff --git a/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj b/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj new file mode 100644 index 000000000000..f4fb55825e54 --- /dev/null +++ b/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj @@ -0,0 +1,19 @@ + + + + Exe + $(TestTargetFramework) + enable + enable + True + True + + + + + + + + + + diff --git 
a/dotnet/test/AutoGen.Gemini.Tests/FunctionContractExtensionTests.cs b/dotnet/test/AutoGen.Gemini.Tests/FunctionContractExtensionTests.cs
new file mode 100644
index 000000000000..51d799acc220
--- /dev/null
+++ b/dotnet/test/AutoGen.Gemini.Tests/FunctionContractExtensionTests.cs
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// FunctionContractExtensionTests.cs
+
+using ApprovalTests;
+using ApprovalTests.Namers;
+using ApprovalTests.Reporters;
+using AutoGen.Gemini.Extension;
+using Google.Protobuf;
+using Xunit;
+
+namespace AutoGen.Gemini.Tests;
+
+public class FunctionContractExtensionTests
+{
+    private readonly Functions functions = new Functions();
+    [Fact]
+    [UseReporter(typeof(DiffReporter))]
+    [UseApprovalSubdirectory("ApprovalTests")]
+    public void ItGenerateGetWeatherToolTest()
+    {
+        var contract = functions.GetWeatherAsyncFunctionContract;
+        var tool = contract.ToFunctionDeclaration();
+        var formatter = new JsonFormatter(JsonFormatter.Settings.Default.WithIndentation("  "));
+        var json = formatter.Format(tool);
+        Approvals.Verify(json);
+    }
+}
diff --git a/dotnet/test/AutoGen.Gemini.Tests/Functions.cs b/dotnet/test/AutoGen.Gemini.Tests/Functions.cs
new file mode 100644
index 000000000000..e3e07ee633fb
--- /dev/null
+++ b/dotnet/test/AutoGen.Gemini.Tests/Functions.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Functions.cs
+
+using AutoGen.Core;
+
+namespace AutoGen.Gemini.Tests;
+
+public partial class Functions
+{
+    /// <summary>
+    /// Get weather for a city.
+    /// </summary>
+    /// <param name="city">city</param>
+    /// <returns>weather</returns>
+    [Function]
+    public async Task<string> GetWeatherAsync(string city)
+    {
+        return await Task.FromResult($"The weather in {city} is sunny.");
+    }
+
+    [Function]
+    public async Task<string> GetMovies(string location, string description)
+    {
+        var movies = new List<string> { "Barbie", "Spiderman", "Batman" };
+
+        return await Task.FromResult($"Movies playing in {location} based on {description} are: {string.Join(", ", movies)}");
+    }
+}
diff --git a/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs b/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs
new file mode 100644
index 000000000000..220492d64575
--- /dev/null
+++ b/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs
@@ -0,0 +1,311 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GeminiAgentTests.cs
+
+using AutoGen.Tests;
+using Google.Cloud.AIPlatform.V1;
+using AutoGen.Core;
+using FluentAssertions;
+using AutoGen.Gemini.Extension;
+using static Google.Cloud.AIPlatform.V1.Part;
+using Xunit.Abstractions;
+using AutoGen.Gemini.Middleware;
+namespace AutoGen.Gemini.Tests;
+
+public class GeminiAgentTests
+{
+    private readonly Functions functions = new Functions();
+    private readonly ITestOutputHelper _output;
+
+    public GeminiAgentTests(ITestOutputHelper output)
+    {
+        _output = output;
+    }
+
+    [ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
+    public async Task VertexGeminiAgentGenerateReplyForTextContentAsync()
+    {
+        var location = "us-central1";
+        var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ??
throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
+        var model = "gemini-1.5-flash-001";
+
+        var textContent = new Content
+        {
+            Role = "user",
+            Parts =
+            {
+                new Part
+                {
+                    Text = "Hello",
+                }
+            }
+        };
+
+        var agent = new GeminiChatAgent(
+            name: "assistant",
+            model: model,
+            project: project,
+            location: location,
+            systemMessage: "You are a helpful AI assistant");
+        var message = MessageEnvelope.Create(textContent, from: agent.Name);
+
+        var completion = await agent.SendAsync(message);
+
+        completion.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
+        completion.From.Should().Be(agent.Name);
+
+        var response = ((MessageEnvelope<GenerateContentResponse>)completion).Content;
+        response.Should().NotBeNull();
+        response.Candidates.Count.Should().BeGreaterThan(0);
+        response.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
+    }
+
+    [ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
+    public async Task VertexGeminiAgentGenerateStreamingReplyForTextContentAsync()
+    {
+        var location = "us-central1";
+        var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ??
throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
+        var model = "gemini-1.5-flash-001";
+
+        var textContent = new Content
+        {
+            Role = "user",
+            Parts =
+            {
+                new Part
+                {
+                    Text = "Hello",
+                }
+            }
+        };
+
+        var agent = new GeminiChatAgent(
+            name: "assistant",
+            model: model,
+            project: project,
+            location: location,
+            systemMessage: "You are a helpful AI assistant");
+        var message = MessageEnvelope.Create(textContent, from: agent.Name);
+
+        var completion = agent.GenerateStreamingReplyAsync([message]);
+        var chunks = new List<IStreamingMessage>();
+        IStreamingMessage finalReply = null!;
+
+        await foreach (var item in completion)
+        {
+            item.Should().NotBeNull();
+            item.From.Should().Be(agent.Name);
+            var streamingMessage = (IMessage<GenerateContentResponse>)item;
+            streamingMessage.Content.Candidates.Should().NotBeNullOrEmpty();
+            chunks.Add(item);
+            finalReply = item;
+        }
+
+        chunks.Count.Should().BeGreaterThan(0);
+        finalReply.Should().NotBeNull();
+        finalReply.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
+        var response = ((MessageEnvelope<GenerateContentResponse>)finalReply).Content;
+        response.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
+    }
+
+    [ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
+    public async Task VertexGeminiAgentGenerateReplyWithToolsAsync()
+    {
+        var location = "us-central1";
+        var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ??
throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
+        var model = "gemini-1.5-flash-001";
+        var tools = new Tool[]
+        {
+            new Tool
+            {
+                FunctionDeclarations = {
+                    functions.GetWeatherAsyncFunctionContract.ToFunctionDeclaration(),
+                },
+            },
+            new Tool
+            {
+                FunctionDeclarations =
+                {
+                    functions.GetMoviesFunctionContract.ToFunctionDeclaration(),
+                },
+            },
+        };
+
+        var textContent = new Content
+        {
+            Role = "user",
+            Parts =
+            {
+                new Part
+                {
+                    Text = "what's the weather in seattle",
+                }
+            }
+        };
+
+        var agent = new GeminiChatAgent(
+            name: "assistant",
+            model: model,
+            project: project,
+            location: location,
+            systemMessage: "You are a helpful AI assistant",
+            tools: tools,
+            toolConfig: new ToolConfig()
+            {
+                FunctionCallingConfig = new FunctionCallingConfig()
+                {
+                    Mode = FunctionCallingConfig.Types.Mode.Auto,
+                }
+            });
+
+        var message = MessageEnvelope.Create(textContent, from: agent.Name);
+
+        var completion = await agent.SendAsync(message);
+
+        completion.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
+        completion.From.Should().Be(agent.Name);
+
+        var response = ((MessageEnvelope<GenerateContentResponse>)completion).Content;
+        response.Should().NotBeNull();
+        response.Candidates.Count.Should().BeGreaterThan(0);
+        response.Candidates[0].Content.Parts[0].DataCase.Should().Be(DataOneofCase.FunctionCall);
+    }
+
+    [ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
+    public async Task VertexGeminiAgentGenerateStreamingReplyWithToolsAsync()
+    {
+        var location = "us-central1";
+        var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ??
throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
+        var model = "gemini-1.5-flash-001";
+        var tools = new Tool[]
+        {
+            new Tool
+            {
+                FunctionDeclarations = { functions.GetWeatherAsyncFunctionContract.ToFunctionDeclaration() },
+            },
+        };
+
+        var textContent = new Content
+        {
+            Role = "user",
+            Parts =
+            {
+                new Part
+                {
+                    Text = "what's the weather in seattle",
+                }
+            }
+        };
+
+        var agent = new GeminiChatAgent(
+            name: "assistant",
+            model: model,
+            project: project,
+            location: location,
+            systemMessage: "You are a helpful AI assistant",
+            tools: tools,
+            toolConfig: new ToolConfig()
+            {
+                FunctionCallingConfig = new FunctionCallingConfig()
+                {
+                    Mode = FunctionCallingConfig.Types.Mode.Auto,
+                }
+            });
+
+        var message = MessageEnvelope.Create(textContent, from: agent.Name);
+
+        var chunks = new List<IStreamingMessage>();
+        IStreamingMessage finalReply = null!;
+
+        var completion = agent.GenerateStreamingReplyAsync([message]);
+
+        await foreach (var item in completion)
+        {
+            item.Should().NotBeNull();
+            item.From.Should().Be(agent.Name);
+            var streamingMessage = (IMessage<GenerateContentResponse>)item;
+            streamingMessage.Content.Candidates.Should().NotBeNullOrEmpty();
+            if (streamingMessage.Content.Candidates[0].FinishReason != Candidate.Types.FinishReason.Stop)
+            {
+                streamingMessage.Content.Candidates[0].Content.Parts[0].DataCase.Should().Be(DataOneofCase.FunctionCall);
+            }
+            chunks.Add(item);
+            finalReply = item;
+        }
+
+        chunks.Count.Should().BeGreaterThan(0);
+        finalReply.Should().NotBeNull();
+        finalReply.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
+        var response = ((MessageEnvelope<GenerateContentResponse>)finalReply).Content;
+        response.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
+    }
+
+    [ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
+    public async Task GeminiAgentUpperCaseTestAsync()
+    {
+        var location = "us-central1";
+        var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ??
throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set."); + var model = "gemini-1.5-flash-001"; + var agent = new GeminiChatAgent( + name: "assistant", + model: model, + project: project, + location: location) + .RegisterMessageConnector(); + + var singleAgentTest = new SingleAgentTest(_output); + await singleAgentTest.UpperCaseStreamingTestAsync(agent); + await singleAgentTest.UpperCaseTestAsync(agent); + } + + [ApiKeyFact("GCP_VERTEX_PROJECT_ID")] + public async Task GeminiAgentEchoFunctionCallTestAsync() + { + var location = "us-central1"; + var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set."); + var model = "gemini-1.5-flash-001"; + var singleAgentTest = new SingleAgentTest(_output); + var echoFunctionContract = singleAgentTest.EchoAsyncFunctionContract; + var agent = new GeminiChatAgent( + name: "assistant", + model: model, + project: project, + location: location, + tools: + [ + new Tool + { + FunctionDeclarations = { echoFunctionContract.ToFunctionDeclaration() }, + }, + ]) + .RegisterMessageConnector(); + + await singleAgentTest.EchoFunctionCallTestAsync(agent); + } + + [ApiKeyFact("GCP_VERTEX_PROJECT_ID")] + public async Task GeminiAgentEchoFunctionCallExecutionTestAsync() + { + var location = "us-central1"; + var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? 
throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
+        var model = "gemini-1.5-flash-001";
+        var singleAgentTest = new SingleAgentTest(_output);
+        var echoFunctionContract = singleAgentTest.EchoAsyncFunctionContract;
+        var functionMiddleware = new FunctionCallMiddleware(
+            functions: [echoFunctionContract],
+            functionMap: new Dictionary<string, Func<string, Task<string>>>()
+            {
+                { echoFunctionContract.Name!, singleAgentTest.EchoAsyncWrapper },
+            });
+
+        var agent = new GeminiChatAgent(
+            name: "assistant",
+            model: model,
+            project: project,
+            location: location)
+            .RegisterMessageConnector()
+            .RegisterStreamingMiddleware(functionMiddleware);
+
+        await singleAgentTest.EchoFunctionCallExecutionStreamingTestAsync(agent);
+        await singleAgentTest.EchoFunctionCallExecutionTestAsync(agent);
+    }
+}
diff --git a/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs b/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs
new file mode 100644
index 000000000000..7d72c18f1438
--- /dev/null
+++ b/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs
@@ -0,0 +1,380 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GeminiMessageTests.cs
+
+using AutoGen.Core;
+using AutoGen.Gemini.Middleware;
+using AutoGen.Tests;
+using FluentAssertions;
+using Google.Cloud.AIPlatform.V1;
+using Xunit;
+
+namespace AutoGen.Gemini.Tests;
+
+public class GeminiMessageTests
+{
+    [Fact]
+    public async Task ItProcessUserTextMessageAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Parts.Count.Should().Be(1);
+                message.Content.Role.Should().Be("user");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+
+        // when from is null and role is user
+        await agent.SendAsync("Hello");
+
+        // when from is user and role is user
+        var userMessage = new TextMessage(Role.User, "Hello", from: "user");
+        await agent.SendAsync(userMessage);
+
+        // when from is user but role is assistant
+        userMessage = new TextMessage(Role.Assistant, "Hello", from: "user");
+        await agent.SendAsync(userMessage);
+    }
+
+    [Fact]
+    public async Task ItProcessAssistantTextMessageAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Parts.Count.Should().Be(1);
+                message.Content.Role.Should().Be("model");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+
+        // when from is user and role is assistant
+        var message = new TextMessage(Role.User, "Hello", from: agent.Name);
+        await agent.SendAsync(message);
+
+        // when from is assistant and role is assistant
+        message = new
TextMessage(Role.Assistant, "Hello", from: agent.Name);
+        await agent.SendAsync(message);
+    }
+
+    [Fact]
+    public async Task ItProcessSystemTextMessageAsUserMessageWhenStrictModeIsFalseAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Parts.Count.Should().Be(1);
+                message.Content.Role.Should().Be("user");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+
+        var message = new TextMessage(Role.System, "Hello", from: agent.Name);
+        await agent.SendAsync(message);
+    }
+
+    [Fact]
+    public async Task ItThrowExceptionOnSystemMessageWhenStrictModeIsTrueAsync()
+    {
+        var messageConnector = new GeminiMessageConnector(true);
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(messageConnector);
+
+        var message = new TextMessage(Role.System, "Hello", from: agent.Name);
+        var action = new Func<Task>(async () => await agent.SendAsync(message));
+        await action.Should().ThrowAsync<InvalidOperationException>();
+    }
+
+    [Fact]
+    public async Task ItProcessUserImageMessageAsInlineDataAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Parts.Count.Should().Be(1);
+                message.Content.Role.Should().Be("user");
+                message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.InlineData);
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+
+        var imagePath = Path.Combine("testData", "images", "background.png");
+        var image = File.ReadAllBytes(imagePath);
+        var
message = new ImageMessage(Role.User, BinaryData.FromBytes(image, "image/png"));
+        message.MimeType.Should().Be("image/png");
+
+        await agent.SendAsync(message);
+    }
+
+    [Fact]
+    public async Task ItProcessUserImageMessageAsFileDataAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Parts.Count.Should().Be(1);
+                message.Content.Role.Should().Be("user");
+                message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FileData);
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+
+        var imagePath = Path.Combine("testData", "images", "image.png");
+        var url = new Uri(Path.GetFullPath(imagePath)).AbsoluteUri;
+        var message = new ImageMessage(Role.User, url);
+        message.MimeType.Should().Be("image/png");
+
+        await agent.SendAsync(message);
+    }
+
+    [Fact]
+    public async Task ItProcessMultiModalMessageAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Parts.Count.Should().Be(2);
+                message.Content.Role.Should().Be("user");
+                message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.Text);
+                message.Content.Parts.Last().DataCase.Should().Be(Part.DataOneofCase.FileData);
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+
+        var imagePath = Path.Combine("testData", "images", "image.png");
+        var url = new Uri(Path.GetFullPath(imagePath)).AbsoluteUri;
+        var message = new ImageMessage(Role.User, url);
+
message.MimeType.Should().Be("image/png");
+        var textMessage = new TextMessage(Role.User, "What's in this image?");
+        var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, message]);
+
+        await agent.SendAsync(multiModalMessage);
+    }
+
+    [Fact]
+    public async Task ItProcessToolCallMessageAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Role.Should().Be("model");
+                message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionCall);
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+
+        var toolCallMessage = new ToolCallMessage("test", "{}", "user");
+        await agent.SendAsync(toolCallMessage);
+    }
+
+    [Fact]
+    public async Task ItProcessStreamingTextMessageAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterStreamingMiddleware(messageConnector);
+
+        var messageChunks = Enumerable.Range(0, 10)
+            .Select(i => new GenerateContentResponse()
+            {
+                Candidates =
+                {
+                    new Candidate()
+                    {
+                        Content = new Content()
+                        {
+                            Role = "user",
+                            Parts = { new Part { Text = i.ToString() } },
+                        }
+                    }
+                }
+            })
+            .Select(m => MessageEnvelope.Create(m));
+
+        IStreamingMessage?
finalReply = null;
+        await foreach (var reply in agent.GenerateStreamingReplyAsync(messageChunks))
+        {
+            reply.Should().BeAssignableTo<TextMessage>();
+            finalReply = reply;
+        }
+
+        finalReply.Should().BeOfType<TextMessage>();
+        var textMessage = (TextMessage)finalReply!;
+        textMessage.GetContent().Should().Be("0123456789");
+    }
+
+    [Fact]
+    public async Task ItProcessToolCallResultMessageAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Role.Should().Be("function");
+                message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionResponse);
+                message.Content.Parts.First().FunctionResponse.Response.ToString().Should().Be("{ \"result\": \"result\" }");
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+
+        var message = new ToolCallResultMessage("result", "test", "{}", "user");
+        await agent.SendAsync(message);
+
+        // when the result is already a json object string
+        message = new ToolCallResultMessage("{ \"result\": \"result\" }", "test", "{}", "user");
+        await agent.SendAsync(message);
+    }
+
+    [Fact]
+    public async Task ItProcessToolCallAggregateMessageAsTextContentAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(1);
+                var innerMessage = msgs.First();
+                innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)innerMessage;
+                message.Content.Role.Should().Be("user");
+                message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.Text);
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+        var toolCallMessage = new ToolCallMessage("test", "{}", "user");
+
var toolCallResultMessage = new ToolCallResultMessage("result", "test", "{}", "user");
+        var message = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: "user");
+        await agent.SendAsync(message);
+    }
+
+    [Fact]
+    public async Task ItProcessToolCallAggregateMessageAsFunctionContentAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                msgs.Count().Should().Be(2);
+                var functionCallMessage = msgs.First();
+                functionCallMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                var message = (IMessage<Content>)functionCallMessage;
+                message.Content.Role.Should().Be("model");
+                message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionCall);
+
+                var functionResultMessage = msgs.Last();
+                functionResultMessage.Should().BeOfType<MessageEnvelope<Content>>();
+                message = (IMessage<Content>)functionResultMessage;
+                message.Content.Role.Should().Be("function");
+                message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionResponse);
+
+                return await innerAgent.GenerateReplyAsync(msgs);
+            })
+            .RegisterMiddleware(messageConnector);
+        var toolCallMessage = new ToolCallMessage("test", "{}", agent.Name);
+        var toolCallResultMessage = new ToolCallResultMessage("result", "test", "{}", agent.Name);
+        var message = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: agent.Name);
+        await agent.SendAsync(message);
+    }
+
+    [Fact]
+    public async Task ItThrowExceptionWhenProcessingUnknownMessageTypeInStrictModeAsync()
+    {
+        var messageConnector = new GeminiMessageConnector(true);
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(messageConnector);
+
+        var unknownMessage = new
+        {
+            text = "Hello",
+        };
+
+        var message = MessageEnvelope.Create(unknownMessage, from: agent.Name);
+        var action = new Func<Task>(async () => await agent.SendAsync(message));
+
+        await action.Should().ThrowAsync<InvalidOperationException>();
+    }
+
+    [Fact]
+    public async Task
ItReturnUnknownMessageTypeInNonStrictModeAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                var message = msgs.First();
+                message.Should().BeAssignableTo<IMessage>();
+                return message;
+            })
+            .RegisterMiddleware(messageConnector);
+
+        var unknownMessage = new
+        {
+            text = "Hello",
+        };
+
+        var message = MessageEnvelope.Create(unknownMessage, from: agent.Name);
+        await agent.SendAsync(message);
+    }
+
+    [Fact]
+    public async Task ItShortcircuitContentTypeAsync()
+    {
+        var messageConnector = new GeminiMessageConnector();
+        var agent = new EchoAgent("assistant")
+            .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+            {
+                var message = msgs.First();
+                message.Should().BeOfType<MessageEnvelope<Content>>();
+
+                return message;
+            })
+            .RegisterMiddleware(messageConnector);
+
+        var message = new Content()
+        {
+            Parts = { new Part { Text = "Hello" } },
+            Role = "user",
+        };
+
+        await agent.SendAsync(MessageEnvelope.Create(message, from: agent.Name));
+    }
+}
diff --git a/dotnet/test/AutoGen.Gemini.Tests/GoogleGeminiClientTests.cs b/dotnet/test/AutoGen.Gemini.Tests/GoogleGeminiClientTests.cs
new file mode 100644
index 000000000000..3bda12eda1a6
--- /dev/null
+++ b/dotnet/test/AutoGen.Gemini.Tests/GoogleGeminiClientTests.cs
@@ -0,0 +1,132 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GoogleGeminiClientTests.cs
+
+using AutoGen.Tests;
+using FluentAssertions;
+using Google.Cloud.AIPlatform.V1;
+using Google.Protobuf;
+using static Google.Cloud.AIPlatform.V1.Candidate.Types;
+
+namespace AutoGen.Gemini.Tests;
+
+public class GoogleGeminiClientTests
+{
+    [ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
+    public async Task ItGenerateContentAsync()
+    {
+        var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ??
throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set"); + var client = new GoogleGeminiClient(apiKey); + var model = "gemini-1.5-flash-001"; + + var text = "Write a long, tedious story"; + var request = new GenerateContentRequest + { + Model = model, + Contents = + { + new Content + { + Role = "user", + Parts = + { + new Part + { + Text = text, + } + } + } + } + }; + var completion = await client.GenerateContentAsync(request); + + completion.Should().NotBeNull(); + completion.Candidates.Count.Should().BeGreaterThan(0); + completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty(); + } + + [ApiKeyFact("GOOGLE_GEMINI_API_KEY")] + public async Task ItGenerateContentWithImageAsync() + { + var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set"); + var client = new GoogleGeminiClient(apiKey); + var model = "gemini-1.5-flash-001"; + + var text = "what's in the image"; + var imagePath = Path.Combine("testData", "images", "background.png"); + var image = File.ReadAllBytes(imagePath); + var request = new GenerateContentRequest + { + Model = model, + Contents = + { + new Content + { + Role = "user", + Parts = + { + new Part + { + Text = text, + }, + new Part + { + InlineData = new () + { + MimeType = "image/png", + Data = ByteString.CopyFrom(image), + }, + } + } + } + } + }; + + var completion = await client.GenerateContentAsync(request); + completion.Should().NotBeNull(); + completion.Candidates.Count.Should().BeGreaterThan(0); + completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty(); + } + + [ApiKeyFact("GOOGLE_GEMINI_API_KEY")] + public async Task ItStreamingGenerateContentTestAsync() + { + var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? 
throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
+        var client = new GoogleGeminiClient(apiKey);
+        var model = "gemini-1.5-flash-001";
+
+        var text = "Tell me a long tedious joke";
+        var request = new GenerateContentRequest
+        {
+            Model = model,
+            Contents =
+            {
+                new Content
+                {
+                    Role = "user",
+                    Parts =
+                    {
+                        new Part
+                        {
+                            Text = text,
+                        }
+                    }
+                }
+            }
+        };
+
+        var response = client.GenerateContentStreamAsync(request);
+        var chunks = new List<GenerateContentResponse>();
+        GenerateContentResponse? final = null;
+        await foreach (var item in response)
+        {
+            item.Candidates.Count.Should().BeGreaterThan(0);
+            final = item;
+            chunks.Add(final);
+        }
+
+        chunks.Should().NotBeEmpty();
+        final.Should().NotBeNull();
+        final!.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
+        final!.Candidates[0].FinishReason.Should().Be(FinishReason.Stop);
+    }
+}
diff --git a/dotnet/test/AutoGen.Gemini.Tests/SampleTests.cs b/dotnet/test/AutoGen.Gemini.Tests/SampleTests.cs
new file mode 100644
index 000000000000..1f9b557af246
--- /dev/null
+++ b/dotnet/test/AutoGen.Gemini.Tests/SampleTests.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// SampleTests.cs + +using AutoGen.Gemini.Sample; +using AutoGen.Tests; + +namespace AutoGen.Gemini.Tests; + +public class SampleTests +{ + [ApiKeyFact("GCP_VERTEX_PROJECT_ID")] + public async Task TestChatWithVertexGeminiAsync() + { + await Chat_With_Vertex_Gemini.RunAsync(); + } + + [ApiKeyFact("GCP_VERTEX_PROJECT_ID")] + public async Task TestFunctionCallWithGeminiAsync() + { + await Function_Call_With_Gemini.RunAsync(); + } + + [ApiKeyFact("GCP_VERTEX_PROJECT_ID")] + public async Task TestImageChatWithVertexGeminiAsync() + { + await Image_Chat_With_Vertex_Gemini.RunAsync(); + } +} diff --git a/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs b/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs new file mode 100644 index 000000000000..2f06305ed59f --- /dev/null +++ b/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GeminiVertexClientTests.cs + +using AutoGen.Tests; +using FluentAssertions; +using Google.Cloud.AIPlatform.V1; +using Google.Protobuf; +using static Google.Cloud.AIPlatform.V1.Candidate.Types; +namespace AutoGen.Gemini.Tests; + +public class VertexGeminiClientTests +{ + [ApiKeyFact("GCP_VERTEX_PROJECT_ID")] + public async Task ItGenerateContentAsync() + { + var location = "us-central1"; + var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID"); + var client = new VertexGeminiClient(location); + var model = "gemini-1.5-flash-001"; + + var text = "Hello"; + var request = new GenerateContentRequest + { + Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}", + Contents = + { + new Content + { + Role = "user", + Parts = + { + new Part + { + Text = text, + } + } + } + } + }; + var completion = await client.GenerateContentAsync(request); + + completion.Should().NotBeNull(); + completion.Candidates.Count.Should().BeGreaterThan(0); + 
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
+    }
+
+    [ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
+    public async Task ItGenerateContentWithImageAsync()
+    {
+        var location = "us-central1";
+        var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
+        var client = new VertexGeminiClient(location);
+        var model = "gemini-1.5-flash-001";
+
+        var text = "what's in the image";
+        var imagePath = Path.Combine("testData", "images", "image.png");
+        var image = File.ReadAllBytes(imagePath);
+        var request = new GenerateContentRequest
+        {
+            Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}",
+            Contents =
+            {
+                new Content
+                {
+                    Role = "user",
+                    Parts =
+                    {
+                        new Part
+                        {
+                            Text = text,
+                        },
+                        new Part
+                        {
+                            InlineData = new ()
+                            {
+                                MimeType = "image/png",
+                                Data = ByteString.CopyFrom(image),
+                            },
+                        }
+                    }
+                }
+            }
+        };
+
+        var completion = await client.GenerateContentAsync(request);
+        completion.Should().NotBeNull();
+        completion.Candidates.Count.Should().BeGreaterThan(0);
+        completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
+    }
+
+    [ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
+    public async Task ItStreamingGenerateContentTestAsync()
+    {
+        var location = "us-central1";
+        var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
+        var client = new VertexGeminiClient(location);
+        var model = "gemini-1.5-flash-001";
+
+        var text = "Hello, write a long tedious joke";
+        var request = new GenerateContentRequest
+        {
+            Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}",
+            Contents =
+            {
+                new Content
+                {
+                    Role = "user",
+                    Parts =
+                    {
+                        new Part
+                        {
+                            Text = text,
+                        }
+                    }
+                }
+            }
+        };
+
+        var response = client.GenerateContentStreamAsync(request);
+        var chunks = new List<GenerateContentResponse>();
+        GenerateContentResponse?
final = null; + await foreach (var item in response) + { + item.Candidates.Count.Should().BeGreaterThan(0); + final = item; + chunks.Add(final); + } + + chunks.Should().NotBeEmpty(); + final.Should().NotBeNull(); + final!.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0); + final!.Candidates[0].FinishReason.Should().Be(FinishReason.Stop); + } +} diff --git a/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj b/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj index eff704869280..d734119dbb09 100644 --- a/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj +++ b/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj @@ -4,18 +4,10 @@ $(TestTargetFramework) enable false + True True - - - - - - - - - diff --git a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj index 27f80716f1c0..1e26b38d8a4f 100644 --- a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj +++ b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj @@ -4,18 +4,10 @@ $(TestTargetFramework) enable false + True True - - - - - - - - - diff --git a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj index 044975354b80..ba499232beb9 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj +++ b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj @@ -3,18 +3,10 @@ $(TestTargetFramework) false + True True - - - - - - - - - diff --git a/dotnet/test/AutoGen.SemanticKernel.Tests/AutoGen.SemanticKernel.Tests.csproj b/dotnet/test/AutoGen.SemanticKernel.Tests/AutoGen.SemanticKernel.Tests.csproj index b6d03ddc4af1..8be4b55b1722 100644 --- a/dotnet/test/AutoGen.SemanticKernel.Tests/AutoGen.SemanticKernel.Tests.csproj +++ b/dotnet/test/AutoGen.SemanticKernel.Tests/AutoGen.SemanticKernel.Tests.csproj @@ -5,18 +5,10 @@ enable false $(NoWarn);SKEXP0110 + True True - - - - - - - - 
- diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/AutoGen.SourceGenerator.Tests.csproj b/dotnet/test/AutoGen.SourceGenerator.Tests/AutoGen.SourceGenerator.Tests.csproj index 0d0d91e0522c..2e0ead045bef 100644 --- a/dotnet/test/AutoGen.SourceGenerator.Tests/AutoGen.SourceGenerator.Tests.csproj +++ b/dotnet/test/AutoGen.SourceGenerator.Tests/AutoGen.SourceGenerator.Tests.csproj @@ -4,18 +4,10 @@ $(TestTargetFramework) enable false + True True - - - - - - - - - diff --git a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj index 740772c04079..4def281ed7b4 100644 --- a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj +++ b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj @@ -3,18 +3,10 @@ $(TestTargetFramework) True + True $(NoWarn);xUnit1013;SKEXP0110 - - - - - - - - - diff --git a/dotnet/test/AutoGen.Tests/ImageMessageTests.cs b/dotnet/test/AutoGen.Tests/ImageMessageTests.cs new file mode 100644 index 000000000000..210cb1017ed3 --- /dev/null +++ b/dotnet/test/AutoGen.Tests/ImageMessageTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// ImageMessageTests.cs + +using System; +using System.IO; +using System.Threading.Tasks; +using FluentAssertions; +using Xunit; + +namespace AutoGen.Tests; + +public class ImageMessageTests +{ + [Fact] + public async Task ItCreateFromLocalImage() + { + var image = Path.Combine("testData", "images", "background.png"); + var binary = File.ReadAllBytes(image); + var base64 = Convert.ToBase64String(binary); + var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(binary, "image/png")); + + imageMessage.MimeType.Should().Be("image/png"); + imageMessage.BuildDataUri().Should().Be($"data:image/png;base64,{base64}"); + } + + [Fact] + public async Task ItCreateFromUrl() + { + var image = Path.Combine("testData", "images", "background.png"); + var fullPath = Path.GetFullPath(image); + var localUrl = new Uri(fullPath).AbsoluteUri; + var imageMessage = new ImageMessage(Role.User, localUrl); + + imageMessage.Url.Should().Be(localUrl); + imageMessage.MimeType.Should().Be("image/png"); + imageMessage.Data.Should().BeNull(); + } +} diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index b784ff8da035..418b55e70c7d 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -266,10 +266,10 @@ public async Task GetHighestLabel(string labelName, string color) public async Task EchoFunctionCallTestAsync(IAgent agent) { - var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function"); + //var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function"); var helloWorld = new TextMessage(Role.User, "echo Hello world"); - var reply = await agent.SendAsync(chatHistory: new[] { message, helloWorld }); + var reply = await agent.SendAsync(chatHistory: new[] { helloWorld }); reply.From.Should().Be(agent.Name); reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync)); @@ -277,10 +277,10 
@@ public async Task EchoFunctionCallTestAsync(IAgent agent) public async Task EchoFunctionCallExecutionTestAsync(IAgent agent) { - var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says"); + //var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says"); var helloWorld = new TextMessage(Role.User, "echo Hello world"); - var reply = await agent.SendAsync(chatHistory: new[] { message, helloWorld }); + var reply = await agent.SendAsync(chatHistory: new[] { helloWorld }); reply.GetContent().Should().Be("[ECHO] Hello world"); reply.From.Should().Be(agent.Name); @@ -289,13 +289,13 @@ public async Task EchoFunctionCallExecutionTestAsync(IAgent agent) public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent) { - var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says"); + //var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says"); var helloWorld = new TextMessage(Role.User, "echo Hello world"); var option = new GenerateReplyOptions { Temperature = 0, }; - var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); + var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option); var answer = "[ECHO] Hello world"; IStreamingMessage? 
finalReply = default; await foreach (var reply in replyStream) @@ -319,25 +319,23 @@ public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent ag public async Task UpperCaseTestAsync(IAgent agent) { - var message = new TextMessage(Role.System, "You are a helpful AI assistant that convert user message to upper case"); - var uppCaseMessage = new TextMessage(Role.User, "abcdefg"); + var message = new TextMessage(Role.User, "Please convert abcde to upper case."); - var reply = await agent.SendAsync(chatHistory: new[] { message, uppCaseMessage }); + var reply = await agent.SendAsync(chatHistory: new[] { message }); - reply.GetContent().Should().Contain("ABCDEFG"); + reply.GetContent().Should().Contain("ABCDE"); reply.From.Should().Be(agent.Name); } public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent) { - var message = new TextMessage(Role.System, "You are a helpful AI assistant that convert user message to upper case"); - var helloWorld = new TextMessage(Role.User, "a b c d e f g h i j k l m n"); + var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case"); var option = new GenerateReplyOptions { Temperature = 0, }; - var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); - var answer = "A B C D E F G H I J K L M N"; + var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option); + var answer = "HELLO WORLD"; TextMessage? 
finalReply = default; await foreach (var reply in replyStream) { From d578d0dfd9491f6a5a54c9b214f4baa41ddc552f Mon Sep 17 00:00:00 2001 From: David Luong Date: Mon, 10 Jun 2024 13:32:33 -0400 Subject: [PATCH 3/5] Squash changes (#2849) --- .../src/AutoGen.Anthropic/AnthropicClient.cs | 3 +- .../DTO/ChatCompletionRequest.cs | 8 +- .../Middleware/AnthropicMessageConnector.cs | 112 ++++++++++++++---- .../Agent/MiddlewareStreamingAgent.cs | 1 - .../AnthropicClientAgentTest.cs | 95 +++++++++++++-- .../AnthropicClientTest.cs | 37 +++++- .../AnthropicTestUtils.cs | 8 +- .../AutoGen.Anthropic.Tests.csproj | 6 + .../images/.gitattributes | 1 + .../AutoGen.Anthropic.Tests/images/square.png | 3 + 10 files changed, 239 insertions(+), 35 deletions(-) create mode 100644 dotnet/test/AutoGen.Anthropic.Tests/images/.gitattributes create mode 100644 dotnet/test/AutoGen.Anthropic.Tests/images/square.png diff --git a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs index 8ea0bef86e2c..90bd33683f20 100644 --- a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs +++ b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs @@ -23,7 +23,8 @@ public sealed class AnthropicClient : IDisposable private static readonly JsonSerializerOptions JsonSerializerOptions = new() { - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + Converters = { new ContentBaseConverter() } }; private static readonly JsonSerializerOptions JsonDeserializerOptions = new() diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs index f8967bc7e7fb..0c1749eaa989 100644 --- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs +++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs @@ -49,9 +49,15 @@ public class ChatMessage public string Role { get; set; } [JsonPropertyName("content")] - public string Content { get; set; 
} + public List Content { get; set; } public ChatMessage(string role, string content) + { + Role = role; + Content = new List() { new TextContent { Text = content } }; + } + + public ChatMessage(string role, List content) { Role = role; Content = content; diff --git a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs index bfe79190925f..bb2f5820f74c 100644 --- a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs +++ b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Net.Http; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -19,7 +20,7 @@ public class AnthropicMessageConnector : IStreamingMiddleware public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { var messages = context.Messages; - var chatMessages = ProcessMessage(messages, agent); + var chatMessages = await ProcessMessageAsync(messages, agent); var response = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); return response is IMessage chatMessage @@ -31,7 +32,7 @@ public async IAsyncEnumerable InvokeAsync(MiddlewareContext c [EnumeratorCancellation] CancellationToken cancellationToken = default) { var messages = context.Messages; - var chatMessages = ProcessMessage(messages, agent); + var chatMessages = await ProcessMessageAsync(messages, agent); await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken)) { @@ -53,60 +54,78 @@ public async IAsyncEnumerable InvokeAsync(MiddlewareContext c private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage chatMessage, IStreamingAgent agent) { - Delta? 
delta = chatMessage.Content.Delta; + var delta = chatMessage.Content.Delta; return delta != null && !string.IsNullOrEmpty(delta.Text) ? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name) : null; } - private IEnumerable ProcessMessage(IEnumerable messages, IAgent agent) + private async Task> ProcessMessageAsync(IEnumerable messages, IAgent agent) { - return messages.SelectMany(m => + var processedMessages = new List(); + + foreach (var message in messages) { - return m switch + var processedMessage = message switch { TextMessage textMessage => ProcessTextMessage(textMessage, agent), - _ => [m], + + ImageMessage imageMessage => + new MessageEnvelope(new ChatMessage("user", + new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } } + .ToList()), + from: agent.Name), + + MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent), + _ => message, }; - }); + + processedMessages.Add(processedMessage); + } + + return processedMessages; } private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from) { if (response.Content is null) + { throw new ArgumentNullException(nameof(response.Content)); + } if (response.Content.Count != 1) + { throw new NotSupportedException($"{nameof(response.Content)} != 1"); + } return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? 
string.Empty, from: from.Name); } - private IEnumerable> ProcessTextMessage(TextMessage textMessage, IAgent agent) + private IMessage ProcessTextMessage(TextMessage textMessage, IAgent agent) { - IEnumerable messages; + ChatMessage messages; if (textMessage.From == agent.Name) { - messages = [new ChatMessage( - "assistant", textMessage.Content)]; + messages = new ChatMessage( + "assistant", textMessage.Content); } else if (textMessage.From is null) { if (textMessage.Role == Role.User) { - messages = [new ChatMessage( - "user", textMessage.Content)]; + messages = new ChatMessage( + "user", textMessage.Content); } else if (textMessage.Role == Role.Assistant) { - messages = [new ChatMessage( - "assistant", textMessage.Content)]; + messages = new ChatMessage( + "assistant", textMessage.Content); } else if (textMessage.Role == Role.System) { - messages = [new ChatMessage( - "system", textMessage.Content)]; + messages = new ChatMessage( + "system", textMessage.Content); } else { @@ -116,10 +135,61 @@ private IEnumerable> ProcessTextMessage(TextMessage textMe else { // if from is not null, then the message is from user - messages = [new ChatMessage( - "user", textMessage.Content)]; + messages = new ChatMessage( + "user", textMessage.Content); } - return messages.Select(m => new MessageEnvelope(m, from: textMessage.From)); + return new MessageEnvelope(messages, from: textMessage.From); + } + + private async Task ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent) + { + var content = new List(); + foreach (var message in multiModalMessage.Content) + { + switch (message) + { + case TextMessage textMessage when textMessage.GetContent() is not null: + content.Add(new TextContent { Text = textMessage.GetContent() }); + break; + case ImageMessage imageMessage: + content.Add(new ImageContent() { Source = await ProcessImageSourceAsync(imageMessage) }); + break; + } + } + + var chatMessage = new ChatMessage("user", content); + return 
MessageEnvelope.Create(chatMessage, agent.Name); + } + + private async Task<ImageSource> ProcessImageSourceAsync(ImageMessage imageMessage) + { + if (imageMessage.Data != null) + { + return new ImageSource + { + MediaType = imageMessage.Data.MediaType, + Data = Convert.ToBase64String(imageMessage.Data.ToArray()) + }; + } + + if (imageMessage.Url is null) + { + throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided"); + } + + var uri = new Uri(imageMessage.Url); + using var client = new HttpClient(); + var response = await client.GetAsync(uri); + if (!response.IsSuccessStatusCode) + { + throw new HttpRequestException($"Failed to download the image from {uri}"); + } + + return new ImageSource + { + MediaType = "image/jpeg", + Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync()) + }; + } } diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs index 251d3c110f98..52967d6ff1ce 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs @@ -49,7 +49,6 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat public IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs index ba31f2297ba8..d29025b44aff 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs @@ -1,31 +1,108 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// AnthropicClientAgentTest.cs +using AutoGen.Anthropic.DTO; using AutoGen.Anthropic.Extensions; using AutoGen.Anthropic.Utils; +using AutoGen.Core; using AutoGen.Tests; -using Xunit.Abstractions; +using FluentAssertions; -namespace AutoGen.Anthropic; +namespace AutoGen.Anthropic.Tests; public class AnthropicClientAgentTest { - private readonly ITestOutputHelper _output; - - public AnthropicClientAgentTest(ITestOutputHelper output) => _output = output; - [ApiKeyFact("ANTHROPIC_API_KEY")] public async Task AnthropicAgentChatCompletionTestAsync() { var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are a helpful AI assistant that convert user message to upper case") + .RegisterMessageConnector(); + + var uppCaseMessage = new TextMessage(Role.User, "abcdefg"); + + var reply = await agent.SendAsync(chatHistory: new[] { uppCaseMessage }); + + reply.GetContent().Should().Contain("ABCDEFG"); + reply.From.Should().Be(agent.Name); + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentTestProcessImageAsync() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); var agent = new AnthropicClientAgent( client, name: "AnthropicAgent", AnthropicConstants.Claude3Haiku).RegisterMessageConnector(); - var singleAgentTest = new SingleAgentTest(_output); - await singleAgentTest.UpperCaseTestAsync(agent); - await singleAgentTest.UpperCaseStreamingTestAsync(agent); + var base64Image = await AnthropicTestUtils.Base64FromImageAsync("square.png"); + var imageMessage = new ChatMessage("user", + [new ImageContent { Source = new ImageSource { MediaType = "image/png", Data = base64Image } }]); + + var messages = new IMessage[] { MessageEnvelope.Create(imageMessage) }; + + // test streaming + foreach (var message in messages) + { 
+ var reply = agent.GenerateStreamingReplyAsync([message]); + + await foreach (var streamingMessage in reply) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be(agent.Name); + } + } + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentTestMultiModalAsync() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku) + .RegisterMessageConnector(); + + var image = Path.Combine("images", "square.png"); + var binaryData = BinaryData.FromBytes(await File.ReadAllBytesAsync(image), "image/png"); + var imageMessage = new ImageMessage(Role.User, binaryData); + var textMessage = new TextMessage(Role.User, "What's in this image?"); + var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]); + + var reply = await agent.SendAsync(multiModalMessage); + reply.Should().BeOfType(); + reply.GetRole().Should().Be(Role.Assistant); + reply.GetContent().Should().NotBeNullOrEmpty(); + reply.From.Should().Be(agent.Name); + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentTestImageMessageAsync() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are a helpful AI assistant that is capable of determining what an image is. Tell me a brief description of the image." 
+ ) + .RegisterMessageConnector(); + + var image = Path.Combine("images", "square.png"); + var binaryData = BinaryData.FromBytes(await File.ReadAllBytesAsync(image), "image/png"); + var imageMessage = new ImageMessage(Role.User, binaryData); + + var reply = await agent.SendAsync(imageMessage); + reply.Should().BeOfType<TextMessage>(); + reply.GetRole().Should().Be(Role.Assistant); + reply.GetContent().Should().NotBeNullOrEmpty(); + reply.From.Should().Be(agent.Name); + } } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs index af050eef9280..a0b1f60cfb95 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs @@ -7,7 +7,7 @@ using FluentAssertions; using Xunit; -namespace AutoGen.Anthropic; +namespace AutoGen.Anthropic.Tests; public class AnthropicClientTests { @@ -73,6 +73,41 @@ public async Task AnthropicClientStreamingChatCompletionTestAsync() results.First().streamingMessage!.Role.Should().Be("assistant"); } + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicClientImageChatCompletionTestAsync() + { + var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + + var request = new ChatCompletionRequest(); + request.Model = AnthropicConstants.Claude3Haiku; + request.Stream = false; + request.MaxTokens = 100; + request.SystemMessage = "You are an LLM that is supposed to describe the content of the image. 
Give me a description of the provided image."; + + var base64Image = await AnthropicTestUtils.Base64FromImageAsync("square.png"); + var messages = new List + { + new("user", + [ + new ImageContent { Source = new ImageSource {MediaType = "image/png", Data = base64Image} } + ]) + }; + + request.Messages = messages; + + var response = await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None); + + Assert.NotNull(response); + Assert.NotNull(response.Content); + Assert.NotEmpty(response.Content); + response.Content.Count.Should().Be(1); + response.Content.First().Should().BeOfType(); + var textContent = (TextContent)response.Content.First(); + Assert.Equal("text", textContent.Type); + Assert.NotNull(response.Usage); + response.Usage.OutputTokens.Should().BeGreaterThan(0); + } + private sealed class Person { [JsonPropertyName("name")] diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs index a5b80eee3bdf..de630da6d87c 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs @@ -1,10 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AnthropicTestUtils.cs -namespace AutoGen.Anthropic; +namespace AutoGen.Anthropic.Tests; public static class AnthropicTestUtils { public static string ApiKey => Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? 
throw new Exception("Please set ANTHROPIC_API_KEY environment variable."); + + public static async Task Base64FromImageAsync(string imageName) + { + return Convert.ToBase64String( + await File.ReadAllBytesAsync(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", imageName))); + } } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj index 8ce405a07ce8..0f22d9fe6764 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj +++ b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj @@ -13,4 +13,10 @@ + + + + PreserveNewest + + diff --git a/dotnet/test/AutoGen.Anthropic.Tests/images/.gitattributes b/dotnet/test/AutoGen.Anthropic.Tests/images/.gitattributes new file mode 100644 index 000000000000..56e7c34d4989 --- /dev/null +++ b/dotnet/test/AutoGen.Anthropic.Tests/images/.gitattributes @@ -0,0 +1 @@ +square.png filter=lfs diff=lfs merge=lfs -text diff --git a/dotnet/test/AutoGen.Anthropic.Tests/images/square.png b/dotnet/test/AutoGen.Anthropic.Tests/images/square.png new file mode 100644 index 000000000000..5c2b3ed820b1 --- /dev/null +++ b/dotnet/test/AutoGen.Anthropic.Tests/images/square.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8341030e5b93aab2c55dcd40ffa26ced8e42cc15736a8348176ffd155ad2d937 +size 8167 From 2d6c8c012bfd206998ffd76e470b9db4d2ed2a5b Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Mon, 10 Jun 2024 11:23:51 -0700 Subject: [PATCH 4/5] version update (#2908) * version update * version update --- OAI_CONFIG_LIST_sample | 4 ++-- autogen/oai/client.py | 4 ++-- autogen/oai/completion.py | 2 +- autogen/oai/openai_utils.py | 4 ++-- autogen/version.py | 2 +- notebook/agentchat_MathChat.ipynb | 4 ++-- notebook/agentchat_cost_token_tracking.ipynb | 2 +- notebook/agentchat_custom_model.ipynb | 4 ++-- notebook/agentchat_dalle_and_gpt4v.ipynb | 2 +- notebook/agentchat_function_call.ipynb | 4 
++-- .../agentchat_function_call_currency_calculator.ipynb | 4 ++-- notebook/agentchat_human_feedback.ipynb | 4 ++-- notebook/agentchat_microsoft_fabric.ipynb | 2 +- notebook/agentchat_planning.ipynb | 4 ++-- notebook/agentchat_stream.ipynb | 4 ++-- notebook/agentchat_teachable_oai_assistants.ipynb | 4 ++-- notebook/agentchat_two_users.ipynb | 4 ++-- notebook/agentchat_web_info.ipynb | 4 ++-- notebook/oai_chatgpt_gpt4.ipynb | 4 ++-- notebook/oai_completion.ipynb | 8 ++++---- samples/apps/websockets/application.py | 2 +- test/oai/test_utils.py | 4 ++-- test/test_logging.py | 2 +- website/blog/2023-12-01-AutoGenStudio/index.mdx | 2 +- website/docs/Use-Cases/enhanced_inference.md | 2 +- website/docs/topics/llm_configuration.ipynb | 4 ++-- 26 files changed, 45 insertions(+), 45 deletions(-) diff --git a/OAI_CONFIG_LIST_sample b/OAI_CONFIG_LIST_sample index 9fc0dc803a06..aa0b39216298 100644 --- a/OAI_CONFIG_LIST_sample +++ b/OAI_CONFIG_LIST_sample @@ -13,13 +13,13 @@ "api_key": "", "base_url": "", "api_type": "azure", - "api_version": "2024-02-15-preview" + "api_version": "2024-02-01" }, { "model": "", "api_key": "", "base_url": "", "api_type": "azure", - "api_version": "2024-02-15-preview" + "api_version": "2024-02-01" } ] diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 4c1da7a39311..f1a9c2fbf251 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -349,7 +349,7 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base "api_key": os.environ.get("AZURE_OPENAI_API_KEY"), "api_type": "azure", "base_url": os.environ.get("AZURE_OPENAI_API_BASE"), - "api_version": "2024-02-15-preview", + "api_version": "2024-02-01", }, { "model": "gpt-3.5-turbo", @@ -559,7 +559,7 @@ def yes_or_no_filter(context, response): ``` - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false. - - api_version (str | None): The api version. Default to None. E.g., "2024-02-15-preview". 
+ - api_version (str | None): The api version. Default to None. E.g., "2024-02-01". Raises: - RuntimeError: If all declared custom model clients are not registered - APIError: If any model client create call raises an APIError diff --git a/autogen/oai/completion.py b/autogen/oai/completion.py index e3b01ee4dd80..5a62cde33df0 100644 --- a/autogen/oai/completion.py +++ b/autogen/oai/completion.py @@ -741,7 +741,7 @@ def create( "api_key": os.environ.get("AZURE_OPENAI_API_KEY"), "api_type": "azure", "base_url": os.environ.get("AZURE_OPENAI_API_BASE"), - "api_version": "2024-02-15-preview", + "api_version": "2024-02-01", }, { "model": "gpt-3.5-turbo", diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py index a676e9643904..0c8a0a413375 100644 --- a/autogen/oai/openai_utils.py +++ b/autogen/oai/openai_utils.py @@ -14,7 +14,7 @@ from packaging.version import parse NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"] -DEFAULT_AZURE_API_VERSION = "2024-02-15-preview" +DEFAULT_AZURE_API_VERSION = "2024-02-01" OAI_PRICE1K = { # https://openai.com/api/pricing/ # gpt-4o @@ -127,7 +127,7 @@ def get_config_list( # Optionally, define the API type and version if they are common for all keys api_type = 'azure' - api_version = '2024-02-15-preview' + api_version = '2024-02-01' # Call the get_config_list function to get a list of configuration dictionaries config_list = get_config_list(api_keys, base_urls, api_type, api_version) diff --git a/autogen/version.py b/autogen/version.py index 968391a2dbd2..4f6b515ecb20 100644 --- a/autogen/version.py +++ b/autogen/version.py @@ -1 +1 @@ -__version__ = "0.2.28" +__version__ = "0.2.29" diff --git a/notebook/agentchat_MathChat.ipynb b/notebook/agentchat_MathChat.ipynb index 8a234ede013a..afa00fb7562c 100644 --- a/notebook/agentchat_MathChat.ipynb +++ b/notebook/agentchat_MathChat.ipynb @@ -84,14 +84,14 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 
'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", " {\n", " 'model': 'gpt-3.5-turbo',\n", " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", "]\n", "```\n", diff --git a/notebook/agentchat_cost_token_tracking.ipynb b/notebook/agentchat_cost_token_tracking.ipynb index 7feb7a908f4e..fecc98f32763 100644 --- a/notebook/agentchat_cost_token_tracking.ipynb +++ b/notebook/agentchat_cost_token_tracking.ipynb @@ -88,7 +88,7 @@ " \"model\": \"gpt-35-turbo-0613\", # 0613 or newer is needed to use functions\n", " \"base_url\": \"\", \n", " \"api_type\": \"azure\", \n", - " \"api_version\": \"2024-02-15-preview\", # 2023-07-01-preview or newer is needed to use functions\n", + " \"api_version\": \"2024-02-01\", # 2023-07-01-preview or newer is needed to use functions\n", " \"api_key\": \"\",\n", " \"tags\": [\"gpt-3.5-turbo\", \"0613\"],\n", " }\n", diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index b06d2c3cf4e4..5097713a0923 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -226,14 +226,14 @@ " \"api_key\": \"\",\n", " \"base_url\": \"\",\n", " \"api_type\": \"azure\",\n", - " \"api_version\": \"2024-02-15-preview\"\n", + " \"api_version\": \"2024-02-01\"\n", " },\n", " {\n", " \"model\": \"gpt-4-32k\",\n", " \"api_key\": \"\",\n", " \"base_url\": \"\",\n", " \"api_type\": \"azure\",\n", - " \"api_version\": \"2024-02-15-preview\"\n", + " \"api_version\": \"2024-02-01\"\n", " }\n", "]\n", "```\n", diff --git a/notebook/agentchat_dalle_and_gpt4v.ipynb b/notebook/agentchat_dalle_and_gpt4v.ipynb index 258b49d6976b..e07578016a98 100644 --- a/notebook/agentchat_dalle_and_gpt4v.ipynb +++ b/notebook/agentchat_dalle_and_gpt4v.ipynb @@ -93,7 +93,7 @@ " {\n", " 'model': 'dalle',\n", " 'api_key': 'Your API Key here',\n", - " 
'api_version': '2024-02-15-preview'\n", + " 'api_version': '2024-02-01'\n", " }\n", "]\n", " ```" diff --git a/notebook/agentchat_function_call.ipynb b/notebook/agentchat_function_call.ipynb index c91699d0d44c..2a173c8e2698 100644 --- a/notebook/agentchat_function_call.ipynb +++ b/notebook/agentchat_function_call.ipynb @@ -90,7 +90,7 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " 'tags': ['tool', 'gpt-3.5-turbo'],\n", " },\n", " {\n", @@ -98,7 +98,7 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " 'tags': ['tool', 'gpt-3.5-turbo-16k'],\n", " },\n", "]\n", diff --git a/notebook/agentchat_function_call_currency_calculator.ipynb b/notebook/agentchat_function_call_currency_calculator.ipynb index a7a5a92bbd90..d6ce5a88762a 100644 --- a/notebook/agentchat_function_call_currency_calculator.ipynb +++ b/notebook/agentchat_function_call_currency_calculator.ipynb @@ -90,7 +90,7 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " 'tags': ['tool', '3.5-tool'],\n", " },\n", " {\n", @@ -98,7 +98,7 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " 'tags': ['tool', '3.5-tool'],\n", " },\n", "]\n", diff --git a/notebook/agentchat_human_feedback.ipynb b/notebook/agentchat_human_feedback.ipynb index 75078e67cf9c..000d788d6a56 100644 --- a/notebook/agentchat_human_feedback.ipynb +++ b/notebook/agentchat_human_feedback.ipynb @@ -90,14 +90,14 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", " {\n", " 'model': 'gpt-3.5-turbo-16k',\n", " 'api_key': '',\n", " 
'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", "]\n", "```\n", diff --git a/notebook/agentchat_microsoft_fabric.ipynb b/notebook/agentchat_microsoft_fabric.ipynb index 55793e0abb1f..58e7ddd20234 100644 --- a/notebook/agentchat_microsoft_fabric.ipynb +++ b/notebook/agentchat_microsoft_fabric.ipynb @@ -578,7 +578,7 @@ " \"api_key\": access_token,\n", " \"base_url\": prebuilt_AI_base_url,\n", " \"api_type\": \"azure\",\n", - " \"api_version\": \"2024-02-15-preview\",\n", + " \"api_version\": \"2024-02-01\",\n", " },\n", "]" ] diff --git a/notebook/agentchat_planning.ipynb b/notebook/agentchat_planning.ipynb index 508792f01a57..14b393958dc1 100644 --- a/notebook/agentchat_planning.ipynb +++ b/notebook/agentchat_planning.ipynb @@ -93,14 +93,14 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " }, # Azure OpenAI API endpoint for gpt-4\n", " {\n", " 'model': 'gpt-4-32k',\n", " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " }, # Azure OpenAI API endpoint for gpt-4-32k\n", "]\n", "```\n", diff --git a/notebook/agentchat_stream.ipynb b/notebook/agentchat_stream.ipynb index 8cb899d2b508..8127cdfbab04 100644 --- a/notebook/agentchat_stream.ipynb +++ b/notebook/agentchat_stream.ipynb @@ -90,14 +90,14 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", " {\n", " 'model': 'gpt-3.5-turbo-16k',\n", " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", "]\n", "```\n", diff --git a/notebook/agentchat_teachable_oai_assistants.ipynb 
b/notebook/agentchat_teachable_oai_assistants.ipynb index 9bd69c9d51cd..3753be414f39 100644 --- a/notebook/agentchat_teachable_oai_assistants.ipynb +++ b/notebook/agentchat_teachable_oai_assistants.ipynb @@ -112,14 +112,14 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", " {\n", " 'model': 'gpt-4-32k',\n", " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", "]\n", "```\n", diff --git a/notebook/agentchat_two_users.ipynb b/notebook/agentchat_two_users.ipynb index 217492786885..eb9e0c1fbf28 100644 --- a/notebook/agentchat_two_users.ipynb +++ b/notebook/agentchat_two_users.ipynb @@ -70,14 +70,14 @@ " \"api_key\": \"\",\n", " \"base_url\": \"\",\n", " \"api_type\": \"azure\",\n", - " \"api_version\": \"2024-02-15-preview\"\n", + " \"api_version\": \"2024-02-01\"\n", " },\n", " {\n", " \"model\": \"gpt-4-32k\",\n", " \"api_key\": \"\",\n", " \"base_url\": \"\",\n", " \"api_type\": \"azure\",\n", - " \"api_version\": \"2024-02-15-preview\"\n", + " \"api_version\": \"2024-02-01\"\n", " }\n", "]\n", "```\n", diff --git a/notebook/agentchat_web_info.ipynb b/notebook/agentchat_web_info.ipynb index 31ac248ec9e3..f990c128b78c 100644 --- a/notebook/agentchat_web_info.ipynb +++ b/notebook/agentchat_web_info.ipynb @@ -104,14 +104,14 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", " {\n", " 'model': 'gpt-4-32k-0314',\n", " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " },\n", "]\n", "```\n", diff --git a/notebook/oai_chatgpt_gpt4.ipynb b/notebook/oai_chatgpt_gpt4.ipynb index 34b5e5357fa6..280b7145e931 100644 --- 
a/notebook/oai_chatgpt_gpt4.ipynb +++ b/notebook/oai_chatgpt_gpt4.ipynb @@ -131,13 +131,13 @@ " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " }, # only if at least one Azure OpenAI API key is found\n", " {\n", " 'api_key': '',\n", " 'base_url': '',\n", " 'api_type': 'azure',\n", - " 'api_version': '2024-02-15-preview',\n", + " 'api_version': '2024-02-01',\n", " }, # only if the second Azure OpenAI API key is found\n", "]\n", "```\n", diff --git a/notebook/oai_completion.ipynb b/notebook/oai_completion.ipynb index 514ba6a4edeb..ac1b3f9c95f1 100644 --- a/notebook/oai_completion.ipynb +++ b/notebook/oai_completion.ipynb @@ -97,13 +97,13 @@ "# 'api_key': '',\n", "# 'base_url': '',\n", "# 'api_type': 'azure',\n", - "# 'api_version': '2024-02-15-preview',\n", + "# 'api_version': '2024-02-01',\n", "# }, # Azure OpenAI API endpoint for gpt-4\n", "# {\n", "# 'api_key': '',\n", "# 'base_url': '',\n", "# 'api_type': 'azure',\n", - "# 'api_version': '2024-02-15-preview',\n", + "# 'api_version': '2024-02-01',\n", "# }, # another Azure OpenAI API endpoint for gpt-4\n", "# ]\n", "\n", @@ -131,14 +131,14 @@ "# 'api_key': '',\n", "# 'base_url': '',\n", "# 'api_type': 'azure',\n", - "# 'api_version': '2024-02-15-preview',\n", + "# 'api_version': '2024-02-01',\n", "# }, # Azure OpenAI API endpoint for gpt-3.5-turbo\n", "# {\n", "# 'model': 'gpt-35-turbo-v0301',\n", "# 'api_key': '',\n", "# 'base_url': '',\n", "# 'api_type': 'azure',\n", - "# 'api_version': '2024-02-15-preview',\n", + "# 'api_version': '2024-02-01',\n", "# }, # another Azure OpenAI API endpoint for gpt-3.5-turbo with deployment name gpt-35-turbo-v0301\n", "# ]" ] diff --git a/samples/apps/websockets/application.py b/samples/apps/websockets/application.py index f2e453d92482..fe75d135330e 100755 --- a/samples/apps/websockets/application.py +++ b/samples/apps/websockets/application.py @@ -35,7 +35,7 @@ def 
_get_config_list() -> List[Dict[str, str]]: 'api_key': '0123456789abcdef0123456789abcdef', 'base_url': 'https://my-deployment.openai.azure.com/', 'api_type': 'azure', - 'api_version': '2024-02-15-preview', + 'api_version': '2024-02-01', }, { 'model': 'gpt-4', diff --git a/test/oai/test_utils.py b/test/oai/test_utils.py index d5ad84d8355d..99f8d8d24e8b 100755 --- a/test/oai/test_utils.py +++ b/test/oai/test_utils.py @@ -58,7 +58,7 @@ "api_key": "111113fc7e8a46419bfac511bb301111", "base_url": "https://1111.openai.azure.com", "api_type": "azure", - "api_version": "2024-02-15-preview" + "api_version": "2024-02-01" }, { "model": "gpt", @@ -83,7 +83,7 @@ "expected": JSON_SAMPLE_DICT[2:4], }, { - "filter_dict": {"api_type": "azure", "api_version": "2024-02-15-preview"}, + "filter_dict": {"api_type": "azure", "api_version": "2024-02-01"}, "exclude": False, "expected": [JSON_SAMPLE_DICT[2]], }, diff --git a/test/test_logging.py b/test/test_logging.py index c6f7a182c5c0..bd9a74d3fd47 100644 --- a/test/test_logging.py +++ b/test/test_logging.py @@ -202,7 +202,7 @@ def test_log_oai_client(db_connection): openai_config = { "api_key": "some_key", - "api_version": "2024-02-15-preview", + "api_version": "2024-02-01", "azure_deployment": "gpt-4", "azure_endpoint": "https://foobar.openai.azure.com/", } diff --git a/website/blog/2023-12-01-AutoGenStudio/index.mdx b/website/blog/2023-12-01-AutoGenStudio/index.mdx index 49151f7b355f..4d893b144a78 100644 --- a/website/blog/2023-12-01-AutoGenStudio/index.mdx +++ b/website/blog/2023-12-01-AutoGenStudio/index.mdx @@ -65,7 +65,7 @@ llm_config = LLMConfig( "api_key": "", "base_url": "", "api_type": "azure", - "api_version": "2024-02-15-preview" + "api_version": "2024-02-01" }], temperature=0, ) diff --git a/website/docs/Use-Cases/enhanced_inference.md b/website/docs/Use-Cases/enhanced_inference.md index e97b67fa9dea..14723391e8cf 100644 --- a/website/docs/Use-Cases/enhanced_inference.md +++ b/website/docs/Use-Cases/enhanced_inference.md @@ 
-183,7 +183,7 @@ client = OpenAIWrapper( "api_key": os.environ.get("AZURE_OPENAI_API_KEY"), "api_type": "azure", "base_url": os.environ.get("AZURE_OPENAI_API_BASE"), - "api_version": "2024-02-15-preview", + "api_version": "2024-02-01", }, { "model": "gpt-3.5-turbo", diff --git a/website/docs/topics/llm_configuration.ipynb b/website/docs/topics/llm_configuration.ipynb index c0a1b7e74a98..f6f383cd85d8 100644 --- a/website/docs/topics/llm_configuration.ipynb +++ b/website/docs/topics/llm_configuration.ipynb @@ -92,7 +92,7 @@ " \"api_type\": \"azure\",\n", " \"api_key\": os.environ['AZURE_OPENAI_API_KEY'],\n", " \"base_url\": \"https://ENDPOINT.openai.azure.com/\",\n", - " \"api_version\": \"2024-02-15-preview\"\n", + " \"api_version\": \"2024-02-01\"\n", " }\n", " ]\n", " ```\n", @@ -328,7 +328,7 @@ " \"api_key\": os.environ.get(\"AZURE_OPENAI_API_KEY\"),\n", " \"api_type\": \"azure\",\n", " \"base_url\": os.environ.get(\"AZURE_OPENAI_API_BASE\"),\n", - " \"api_version\": \"2024-02-15-preview\",\n", + " \"api_version\": \"2024-02-01\",\n", " },\n", " {\n", " \"model\": \"llama-7B\",\n", From bf7e4d619c24bc329a62741acf402e7866c211a4 Mon Sep 17 00:00:00 2001 From: Audel Rouhi Date: Tue, 11 Jun 2024 07:16:56 -0500 Subject: [PATCH 5/5] Bugfix: PGVector/RAG - Calculate the Vector Size based on Model Dimensions (#2865) * Calculate the dimension size based off model chosen. * Added example docstring. * Validated working notebook with sentence models of different dimensions. * Validated removal of model_name working. * Second example uses conn object. * embedding_function no longer directly references .encode * Fixed pre-commit issue. * Use try/except to raise error when shape is not found in embedding function. * Re-ran notebook. * Update autogen/agentchat/contrib/vectordb/pgvectordb.py Co-authored-by: Li Jiang * Update autogen/agentchat/contrib/vectordb/pgvectordb.py Co-authored-by: Li Jiang * Added .encode * Removed example comment. 
* Fix overwrite doesn't work with existing collection when custom embedding function has different dimension from default one --------- Co-authored-by: Li Jiang --- .../agentchat/contrib/vectordb/pgvectordb.py | 84 ++-- .../agentchat_pgvector_RetrieveChat.ipynb | 451 ++++++++++++++++-- .../test_pgvector_retrievechat.py | 2 +- 3 files changed, 447 insertions(+), 90 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/pgvectordb.py b/autogen/agentchat/contrib/vectordb/pgvectordb.py index 38507cb7998e..ac86802b6723 100644 --- a/autogen/agentchat/contrib/vectordb/pgvectordb.py +++ b/autogen/agentchat/contrib/vectordb/pgvectordb.py @@ -32,10 +32,11 @@ class Collection: client: The PGVector client. collection_name (str): The name of the collection. Default is "documents". embedding_function (Callable): The embedding function used to generate the vector representation. + Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None. + Models can be chosen from: + https://huggingface.co/models?library=sentence-transformers metadata (Optional[dict]): The metadata of the collection. get_or_create (Optional): The flag indicating whether to get or create the collection. - model_name: (Optional str) | Sentence embedding model to use. Models can be chosen from: - https://huggingface.co/models?library=sentence-transformers """ def __init__( @@ -45,7 +46,6 @@ def __init__( embedding_function: Callable = None, metadata=None, get_or_create=None, - model_name="all-MiniLM-L6-v2", ): """ Initialize the Collection object. @@ -56,30 +56,26 @@ def __init__( embedding_function: The embedding function used to generate the vector representation. metadata: The metadata of the collection. get_or_create: The flag indicating whether to get or create the collection. - model_name: | Sentence embedding model to use. 
Models can be chosen from: - https://huggingface.co/models?library=sentence-transformers Returns: None """ self.client = client - self.embedding_function = embedding_function - self.model_name = model_name self.name = self.set_collection_name(collection_name) self.require_embeddings_or_documents = False self.ids = [] - try: - self.embedding_function = ( - SentenceTransformer(self.model_name) if embedding_function is None else embedding_function - ) - except Exception as e: - logger.error( - f"Validate the model name entered: {self.model_name} " - f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}" - ) - raise e + if embedding_function: + self.embedding_function = embedding_function + else: + self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16} self.documents = "" self.get_or_create = get_or_create + # This will get the model dimension size by computing the embeddings dimensions + sentences = [ + "The weather is lovely today in paradise.", + ] + embeddings = self.embedding_function(sentences) + self.dimension = len(embeddings[0]) def set_collection_name(self, collection_name) -> str: name = re.sub("-", "_", collection_name) @@ -115,14 +111,14 @@ def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metad elif metadatas is not None: for doc_id, metadata, document in zip(ids, metadatas, documents): metadata = re.sub("'", '"', str(metadata)) - embedding = self.embedding_function.encode(document) + embedding = self.embedding_function(document) sql_values.append((doc_id, metadata, embedding, document)) sql_string = ( f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" f"VALUES (%s, %s, %s, %s);\n" ) else: for doc_id, document in zip(ids, documents): - embedding = self.embedding_function.encode(document) + embedding = self.embedding_function(document) sql_values.append((doc_id, 
document, embedding)) sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\n" f"VALUES (%s, %s, %s);\n" logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}") @@ -166,7 +162,7 @@ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, me elif metadatas is not None: for doc_id, metadata, document in zip(ids, metadatas, documents): metadata = re.sub("'", '"', str(metadata)) - embedding = self.embedding_function.encode(document) + embedding = self.embedding_function(document) sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding)) sql_string = ( f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" @@ -176,7 +172,7 @@ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, me ) else: for doc_id, document in zip(ids, documents): - embedding = self.embedding_function.encode(document) + embedding = self.embedding_function(document) sql_values.append((doc_id, document, embedding, document)) sql_string = ( f"INSERT INTO {self.name} (id, documents, embedding)\n" @@ -304,7 +300,7 @@ def get( ) except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e: logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. 
Error: {e}") - self.create_collection(collection_name=self.name) + self.create_collection(collection_name=self.name, dimension=self.dimension) logger.info(f"Created table {self.name}") cursor.close() @@ -419,7 +415,7 @@ def query( cursor = self.client.cursor() results = [] for query_text in query_texts: - vector = self.embedding_function.encode(query_text, convert_to_tensor=False).tolist() + vector = self.embedding_function(query_text, convert_to_tensor=False).tolist() if distance_type.lower() == "cosine": index_function = "<=>" elif distance_type.lower() == "euclidean": @@ -526,22 +522,31 @@ def delete_collection(self, collection_name: Optional[str] = None) -> None: cursor.execute(f"DROP TABLE IF EXISTS {self.name}") cursor.close() - def create_collection(self, collection_name: Optional[str] = None) -> None: + def create_collection( + self, collection_name: Optional[str] = None, dimension: Optional[Union[str, int]] = None + ) -> None: """ Create a new collection. Args: collection_name (Optional[str]): The name of the new collection. + dimension (Optional[Union[str, int]]): The dimension size of the sentence embedding model Returns: None """ if collection_name: self.name = collection_name + + if dimension: + self.dimension = dimension + elif self.dimension is None: + self.dimension = 384 + cursor = self.client.cursor() cursor.execute( f"CREATE TABLE {self.name} (" - f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector(384));" + f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector({self.dimension}));" f"CREATE INDEX " f'ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata["hnsw:M"]}, ' f'ef_construction = {self.metadata["hnsw:construction_ef"]});' @@ -573,7 +578,6 @@ def __init__( connect_timeout: Optional[int] = 10, embedding_function: Callable = None, metadata: Optional[dict] = None, - model_name: Optional[str] = "all-MiniLM-L6-v2", ) -> None: """ Initialize the vector database. 
@@ -591,15 +595,14 @@ def __init__( username: str | The database username to use. Default is None. password: str | The database user password to use. Default is None. connect_timeout: int | The timeout to set for the connection. Default is 10. - embedding_function: Callable | The embedding function used to generate the vector representation - of the documents. Default is None. + embedding_function: Callable | The embedding function used to generate the vector representation. + Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None. + Models can be chosen from: + https://huggingface.co/models?library=sentence-transformers metadata: dict | The metadata of the vector database. Default is None. If None, it will use this setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}. Creates Index on table using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef". For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw - model_name: str | Sentence embedding model to use. 
Models can be chosen from: - https://huggingface.co/models?library=sentence-transformers - Returns: None """ @@ -613,17 +616,10 @@ def __init__( password=password, connect_timeout=connect_timeout, ) - self.model_name = model_name - try: - self.embedding_function = ( - SentenceTransformer(self.model_name) if embedding_function is None else embedding_function - ) - except Exception as e: - logger.error( - f"Validate the model name entered: {self.model_name} " - f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}" - ) - raise e + if embedding_function: + self.embedding_function = embedding_function + else: + self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode self.metadata = metadata register_vector(self.client) self.active_collection = None @@ -738,7 +734,6 @@ def create_collection( embedding_function=self.embedding_function, get_or_create=get_or_create, metadata=self.metadata, - model_name=self.model_name, ) collection.set_collection_name(collection_name=collection_name) collection.create_collection(collection_name=collection_name) @@ -751,7 +746,6 @@ def create_collection( embedding_function=self.embedding_function, get_or_create=get_or_create, metadata=self.metadata, - model_name=self.model_name, ) collection.set_collection_name(collection_name=collection_name) collection.create_collection(collection_name=collection_name) @@ -765,7 +759,6 @@ def create_collection( embedding_function=self.embedding_function, get_or_create=get_or_create, metadata=self.metadata, - model_name=self.model_name, ) collection.set_collection_name(collection_name=collection_name) collection.create_collection(collection_name=collection_name) @@ -797,7 +790,6 @@ def get_collection(self, collection_name: str = None) -> Collection: client=self.client, collection_name=collection_name, embedding_function=self.embedding_function, - model_name=self.model_name, ) return self.active_collection diff --git a/notebook/agentchat_pgvector_RetrieveChat.ipynb 
b/notebook/agentchat_pgvector_RetrieveChat.ipynb index 9b037b7c468d..1a8d70e29654 100644 --- a/notebook/agentchat_pgvector_RetrieveChat.ipynb +++ b/notebook/agentchat_pgvector_RetrieveChat.ipynb @@ -72,14 +72,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "models to use: ['gpt-35-turbo', 'gpt4-1106-preview', 'gpt-35-turbo-0613']\n" + "models to use: ['gpt4-1106-preview', 'gpt-4o', 'gpt-35-turbo', 'gpt-35-turbo-0613']\n" ] } ], @@ -89,6 +89,7 @@ "\n", "import chromadb\n", "import psycopg\n", + "from sentence_transformers import SentenceTransformer\n", "\n", "import autogen\n", "from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent\n", @@ -114,7 +115,10 @@ " \"api_key\": \"...\",\n", " },\n", "]\n", - "\n", + "config_list = autogen.config_list_from_json(\n", + " \"OAI_CONFIG_LIST\",\n", + " file_location=\".\",\n", + ")\n", "assert len(config_list) > 0\n", "print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])" ] @@ -137,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -145,7 +149,7 @@ "output_type": "stream", "text": [ "Accepted file formats for `docs_path`:\n", - "['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n" + "['yaml', 'ppt', 'rst', 'jsonl', 'xml', 'txt', 'yml', 'log', 'rtf', 'msg', 'xlsx', 'htm', 'pdf', 'org', 'pptx', 'md', 'docx', 'epub', 'tsv', 'csv', 'html', 'doc', 'odt', 'json']\n" ] } ], @@ -156,17 +160,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/workspace/anaconda3/envs/autogen/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "/workspace/anaconda3/envs/autogen/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n" + "/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n" ] } ], @@ -185,6 +187,9 @@ "# Optionally create psycopg conn object\n", "# conn = psycopg.connect(conninfo=\"postgresql://postgres:postgres@localhost:5432/postgres\", autocommit=True)\n", "\n", + "# Optionally create embedding function object\n", + "sentence_transformer_ef = SentenceTransformer(\"all-distilroberta-v1\").encode\n", + "\n", "# 2. create the RetrieveUserProxyAgent instance named \"ragproxyagent\"\n", "# By default, the human_input_mode is \"ALWAYS\", which means the agent will ask for human input at every step. We set it to \"NEVER\" here.\n", "# `docs_path` is the path to the docs directory. It can also be the path to a single file, or the url to a single file. 
By default,\n", @@ -218,11 +223,11 @@ " # \"dbname\": \"postgres\", # Optional vector database name\n", " # \"username\": \"postgres\", # Optional vector database username\n", " # \"password\": \"postgres\", # Optional vector database password\n", - " \"model_name\": \"all-MiniLM-L6-v2\", # Sentence embedding model from https://huggingface.co/models?library=sentence-transformers or https://www.sbert.net/docs/pretrained_models.html\n", " # \"conn\": conn, # Optional - conn object to connect to database\n", " },\n", " \"get_or_create\": True, # set to False if you don't want to reuse an existing collection\n", - " \"overwrite\": False, # set to True if you want to overwrite an existing collection\n", + " \"overwrite\": True, # set to True if you want to overwrite an existing collection\n", + " \"embedding_function\": sentence_transformer_ef, # If left out SentenceTransformer(\"all-MiniLM-L6-v2\").encode will be used\n", " },\n", " code_execution_config=False, # set to False if you don't want to execute the code\n", ")" @@ -244,40 +249,43 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trying to create collection.\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-23 08:48:18,875 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - \u001b[32mUse the existing collection `flaml_collection`.\u001b[0m\n" + "2024-06-11 19:57:44,122 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n", + "Model gpt4-1106-preview not found. 
Using cl100k_base encoding.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Trying to create collection.\n" + "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n", + "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-23 08:48:19,975 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n", - "2024-05-23 08:48:19,977 - autogen.agentchat.contrib.vectordb.pgvectordb - INFO - Error executing select on non-existent table: flaml_collection. Creating it instead. Error: relation \"flaml_collection\" does not exist\n", - "LINE 1: SELECT id, metadatas, documents, embedding FROM flaml_collec...\n", - " ^\u001b[0m\n", - "2024-05-23 08:48:19,996 - autogen.agentchat.contrib.vectordb.pgvectordb - INFO - Created table flaml_collection\u001b[0m\n" + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n", - "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n", "\u001b[32mAdding content of doc 7968cf3c to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", @@ -540,7 +548,6 @@ "\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[32mAdding content of doc 7968cf3c to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -804,7 +811,50 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "To use FLAML to perform a classification task and use Spark to do parallel training, you need to use the Spark ML estimators for AutoML. 
First, you need to prepare your data in the required format as described in the previous section. FLAML provides a convenient function \"to_pandas_on_spark\" to convert your data into a pandas-on-spark dataframe/series, which Spark estimators require. After that, use the pandas-on-spark data like non-spark data and pass them using X_train, y_train or dataframe, label. Finally, configure FLAML to use Spark as the parallel backend during parallel tuning by setting the use_spark to true. An example code snippet is provided in the context above.\n", + "Based on the provided context which details the integration of Spark with FLAML for distributed training, and the requirement to perform a classification task with parallel training in Spark, here's a code snippet that configures FLAML to train a classification model for 30 seconds and cancels the jobs if the time limit is reached.\n", + "\n", + "```python\n", + "from flaml import AutoML\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "import pandas as pd\n", + "\n", + "# Your pandas DataFrame 'data' goes here\n", + "# Assuming 'data' is already a pandas DataFrame with appropriate data for classification\n", + "# and 'label_column' is the name of the column that we want to predict.\n", + "\n", + "# First, convert your pandas DataFrame to a pandas-on-spark DataFrame\n", + "psdf = to_pandas_on_spark(data)\n", + "\n", + "# Now, we prepare the settings for the AutoML training with Spark\n", + "automl_settings = {\n", + " \"time_budget\": 30, # Train for 30 seconds\n", + " \"metric\": \"accuracy\", # Assuming you want to use accuracy as the metric\n", + " \"task\": \"classification\",\n", + " \"n_concurrent_trials\": 2, # Adjust the number of concurrent trials depending on your cluster setup\n", + " \"use_spark\": True,\n", + " \"force_cancel\": True, # Force cancel jobs if time limit is reached\n", + "}\n", + "\n", + "# Create an AutoML instance\n", + "automl = AutoML()\n", + "\n", + "# Run the AutoML 
search\n", + "# You need to replace 'psdf' with your actual pandas-on-spark DataFrame variable\n", + "# and 'label_column' with the name of your label column\n", + "automl.fit(dataframe=psdf, label=label_column, **automl_settings)\n", + "```\n", + "\n", + "This code snippet assumes that the `data` variable contains the pandas DataFrame you want to classify and that `label_column` is the name of the target variable for the classification task. Make sure to replace 'data' and 'label_column' with your actual data and label column name before running this code.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", + "UPDATE CONTEXT\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -840,15 +890,51 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trying to create collection.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-11 19:58:21,076 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n", + "Model gpt4-1106-preview not found. 
Using cl100k_base encoding.\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ "VectorDB returns doc_ids: [['7968cf3c', 'bdfbc921']]\n", - "\u001b[32mAdding content of doc 7968cf3c to context.\u001b[0m\n", + "\u001b[32mAdding content of doc 7968cf3c to context.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", @@ -1110,18 +1196,270 @@ "\n", "\n", "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", + "context provided by the user.\n", + "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", + "For code generation, you must obey the following rules:\n", + "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", + "Rule 2. You must follow the formats below to write your code:\n", + "```language\n", + "# your code\n", + "```\n", + "\n", + "User's question is: Who is the author of FLAML?\n", + "\n", + "Context is: # Research\n", + "\n", + "For technical details, please check our research publications.\n", + "\n", + "- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. 
MLSys 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2021flaml,\n", + " title={FLAML: A Fast and Lightweight AutoML Library},\n", + " author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n", + " year={2021},\n", + " booktitle={MLSys},\n", + "}\n", + "```\n", + "\n", + "- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2021cfo,\n", + " title={Frugal Optimization for Cost-related Hyperparameters},\n", + " author={Qingyun Wu and Chi Wang and Silu Huang},\n", + " year={2021},\n", + " booktitle={AAAI},\n", + "}\n", + "```\n", + "\n", + "- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2021blendsearch,\n", + " title={Economical Hyperparameter Optimization With Blended Search Strategy},\n", + " author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n", + " year={2021},\n", + " booktitle={ICLR},\n", + "}\n", + "```\n", + "\n", + "- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{liuwang2021hpolm,\n", + " title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n", + " author={Susan Xueqing Liu and Chi Wang},\n", + " year={2021},\n", + " booktitle={ACL},\n", + "}\n", + "```\n", + "\n", + "- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. 
ICML 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2021chacha,\n", + " title={ChaCha for Online AutoML},\n", + " author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n", + " year={2021},\n", + " booktitle={ICML},\n", + "}\n", + "```\n", + "\n", + "- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", + "\n", + "```bibtex\n", + "@inproceedings{wuwang2021fairautoml,\n", + " title={Fair AutoML},\n", + " author={Qingyun Wu and Chi Wang},\n", + " year={2021},\n", + " booktitle={ArXiv preprint arXiv:2111.06495},\n", + "}\n", + "```\n", + "\n", + "- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", + "\n", + "```bibtex\n", + "@inproceedings{kayaliwang2022default,\n", + " title={Mining Robust Default Configurations for Resource-constrained AutoML},\n", + " author={Moe Kayali and Chi Wang},\n", + " year={2022},\n", + " booktitle={ArXiv preprint arXiv:2202.09927},\n", + "}\n", + "```\n", + "\n", + "- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", + "\n", + "```bibtex\n", + "@inproceedings{zhang2023targeted,\n", + " title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n", + " author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n", + " booktitle={International Conference on Learning Representations},\n", + " year={2023},\n", + " url={https://openreview.net/forum?id=0Ij9_q567Ma},\n", + "}\n", + "```\n", + "\n", + "- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. 
ArXiv preprint arXiv:2303.04673 (2023).\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2023EcoOptiGen,\n", + " title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n", + " author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n", + " year={2023},\n", + " booktitle={ArXiv preprint arXiv:2303.04673},\n", + "}\n", + "```\n", + "\n", + "- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2023empirical,\n", + " title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n", + " author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n", + " year={2023},\n", + " booktitle={ArXiv preprint arXiv:2306.01337},\n", + "}\n", + "```\n", + "# Integrate - Spark\n", + "\n", + "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", + "\n", + "- Use Spark ML estimators for AutoML.\n", + "- Use Spark to run training in parallel spark jobs.\n", + "\n", + "## Spark ML Estimators\n", + "\n", + "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", + "\n", + "### Data\n", + "\n", + "For Spark estimators, AutoML only consumes Spark data. 
FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", + "\n", + "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", + "\n", + "This function also accepts optional arguments `index_col` and `default_index_type`.\n", + "\n", + "- `index_col` is the column name to use as the index, default is None.\n", + "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about the default index type can be found in the official Spark [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", + "\n", + "Here is an example code snippet for Spark Data:\n", + "\n", + "```python\n", + "import pandas as pd\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "\n", + "# Creating a dictionary\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", + "\n", + "# Creating a pandas DataFrame\n", + "dataframe = pd.DataFrame(data)\n", + "label = \"Price\"\n", + "\n", + "# Convert to pandas-on-spark dataframe\n", + "psdf = to_pandas_on_spark(dataframe)\n", + "```\n", + "\n", + "To use Spark ML models you need to format your data appropriately. 
Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", + "\n", + "Here is an example of how to use it:\n", + "\n", + "```python\n", + "from pyspark.ml.feature import VectorAssembler\n", + "\n", + "columns = psdf.columns\n", + "feature_cols = [col for col in columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", + "```\n", + "\n", + "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", + "\n", + "### Estimators\n", + "\n", + "#### Model List\n", + "\n", + "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", + "\n", + "#### Usage\n", + "\n", + "First, prepare your data in the required format as described in the previous section.\n", + "\n", + "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. 
If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", + "\n", + "Here is an example code snippet using SparkML models in AutoML:\n", + "\n", + "```python\n", + "import flaml\n", + "\n", + "# prepare your data in pandas-on-spark format as we previously mentioned\n", + "\n", + "automl = flaml.AutoML()\n", + "settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"r2\",\n", + " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", + " \"task\": \"regression\",\n", + "}\n", + "\n", + "automl.fit(\n", + " dataframe=psdf,\n", + " label=label,\n", + " **settings,\n", + ")\n", + "```\n", + "\n", + "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", + "\n", + "## Parallel Spark Jobs\n", + "\n", + "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", + "\n", + "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", + "\n", + "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", + "\n", + "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. 
This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", + "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performs parallel tuning.\n", + "- `force_cancel`: boolean, default=False | Whether to forcibly cancel Spark jobs if the search time exceeds the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", + "\n", + "An example code snippet for using parallel Spark jobs:\n", + "\n", + "```python\n", + "import flaml\n", + "\n", + "automl_experiment = flaml.AutoML()\n", + "automl_settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"r2\",\n", + " \"task\": \"regression\",\n", + " \"n_concurrent_trials\": 2,\n", + " \"use_spark\": True,\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + "}\n", + "\n", + "automl_experiment.fit(\n", + " dataframe=dataframe,\n", + " label=label,\n", + " **automl_settings,\n", + ")\n", + "```\n", + "\n", + "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", + "\n", + "\n", "\n", - "The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.\n", "\n", 
"--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", "The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.\n", "\n", "--------------------------------------------------------------------------------\n" @@ -1132,16 +1470,43 @@ "# reset the assistant. Always reset the assistant before starting a new conversation.\n", "assistant.reset()\n", "\n", + "# Optionally create psycopg conn object\n", + "conn = psycopg.connect(conninfo=\"postgresql://postgres:postgres@localhost:5432/postgres\", autocommit=True)\n", + "\n", + "ragproxyagent = RetrieveUserProxyAgent(\n", + " name=\"ragproxyagent\",\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=1,\n", + " retrieve_config={\n", + " \"task\": \"code\",\n", + " \"docs_path\": [\n", + " \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md\",\n", + " \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Research.md\",\n", + " os.path.join(os.path.abspath(\"\"), \"..\", \"website\", \"docs\"),\n", + " ],\n", + " \"custom_text_types\": [\"non-existent-type\"],\n", + " \"chunk_token_size\": 2000,\n", + " \"model\": config_list[0][\"model\"],\n", + " \"vector_db\": \"pgvector\", # PGVector database\n", + " \"collection_name\": \"flaml_collection\",\n", + " \"db_config\": {\n", + " # \"connection_string\": \"postgresql://postgres:postgres@localhost:5432/postgres\", # Optional - connect to an external vector database\n", + " # \"host\": \"postgres\", # Optional vector database host\n", + " # \"port\": 5432, # Optional vector database port\n", + " # \"dbname\": \"postgres\", # Optional vector database name\n", + " # \"username\": \"postgres\", # Optional vector database username\n", + " # \"password\": \"postgres\", # Optional vector database password\n", + " \"conn\": conn, # Optional - conn object to connect to database\n", + " 
},\n", + " \"get_or_create\": True, # set to False if you don't want to reuse an existing collection\n", + " \"overwrite\": True, # set to True if you want to overwrite an existing collection\n", + " },\n", + " code_execution_config=False, # set to False if you don't want to execute the code\n", + ")\n", + "\n", "qa_problem = \"Who is the author of FLAML?\"\n", "chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -1166,7 +1531,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.10.13" }, "skip_test": "Requires interactive usage" }, diff --git a/test/agentchat/contrib/retrievechat/test_pgvector_retrievechat.py b/test/agentchat/contrib/retrievechat/test_pgvector_retrievechat.py index b104f25af767..ca24f952f76d 100644 --- a/test/agentchat/contrib/retrievechat/test_pgvector_retrievechat.py +++ b/test/agentchat/contrib/retrievechat/test_pgvector_retrievechat.py @@ -56,7 +56,7 @@ def test_retrievechat(): }, ) - sentence_transformer_ef = SentenceTransformer("all-MiniLM-L6-v2") + sentence_transformer_ef = SentenceTransformer("all-MiniLM-L6-v2").encode ragproxyagent = RetrieveUserProxyAgent( name="ragproxyagent", human_input_mode="NEVER",