Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to fine tune custom model on conversable agents #1787

Merged
merged 15 commits into from
Mar 11, 2024
Merged
46 changes: 46 additions & 0 deletions .github/workflows/samples-tools-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: SamplesToolsTests

on:
pull_request:
branches: ["main"]
paths:
- "autogen/**"
- "samples/tools/**"
- ".github/workflows/samples-tools-tests.yml"
- "setup.py"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions: {}
jobs:
SamplesToolsFineTuningTests:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ["3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install -e .
pip install pytest
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Test finetuning tools
run: |
pytest samples/tools/finetuning/tests/
87 changes: 87 additions & 0 deletions samples/tools/finetuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Tools for fine-tuning the local models that power agents

This directory aims to contain tools for fine-tuning the local models that power agents.

## Fine tune a custom model client

AutoGen supports the use of custom models to power agents [see blog post here](https://microsoft.github.io/autogen/blog/2024/01/26/Custom-Models). This directory contains a tool to provide feedback to that model, that can be used to fine-tune the model.

The creator of the Custom Model Client will have to decide what kind of data is going to be fed back and how it will be used to fine-tune the model. This tool is designed to be flexible and allow for a wide variety of feedback mechanisms.

Custom Model Client will have follow the protocol client defined in `update_model.py` `UpdateableModelClient` which is a subclass of `ModelClient` and adds the following method:

```python
def update_model(
self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
) -> Dict[str, Any]:
"""Optional method to learn from the preference data, if the model supports learning. Can be omitted.

Learn from the preference data.

Args:
preference_data: The preference data.
inference_messages: The messages that were used during inference between the agent that is being updated and another agent.
**kwargs: other arguments.

Returns:
Dict of learning stats.
"""
```

The function provided in the file `update_model.py` is called by passing these arguments:

- the agent whose model is to be updated
- the preference data
- the agent whose conversation is being used to provide the inference messages

The function will find the conversation thread that occurred between the "update agent" and the "other agent", and call the `update_model` method of the model client. It will return a dictionary containing the update stats, inference messages, and preference data:

```python
{
"update_stats": <the dictionary returned by the custom model client implementation>,
"inference_messages": <message used for inference>,
"preference_data": <the preference data passed in when update_model was called>
}
```

**NOTES**:

`inference_messages` will contain messages that were passed into the custom model client when `create` was called and a response was needed from the model. It is up to the author of the custom model client to decide which parts of the conversation are needed and how to use this data to fine-tune the model.

If a conversation has been long-running before `update_model` is called, then the `inference_messages` will contain a conversation thread that was used for multiple inference steps. It is again up to the author of the custom model client to decide which parts of the conversation correspond to the preference data and how to use this data to fine-tune the model.

An example of how to use this tool is shown below:

```python
from finetuning.update_model import update_model

assistant = AssistantAgent(
"assistant",
system_message="You are a helpful assistant.",
human_input_mode="NEVER",
llm_config={
"config_list": [<the config list containing the custom model>],
},
)

assistant.register_model_client(model_client_cls=<TheCustomModelClientClass>)

user_proxy = UserProxyAgent(
"user_proxy",
human_input_mode="NEVER",
max_consecutive_auto_reply=1,
code_execution_config=False,
llm_config=False,
)

res = user_proxy.initiate_chat(assistant, message="the message")
response_content = res.summary

# Evaluate the summary here and provide feedback. Pretending I am going to perform DPO on the response.

# preference_data will be passed on as-is to the custom model client's update_model implementation
# so it should be in the format that the custom model client expects and is completely up to the author of the custom model client
preference_data = [("this is what the response should have been like", response_content)]

update_model_stats = update_model(assistant, preference_data, user_proxy)
```
3 changes: 3 additions & 0 deletions samples/tools/finetuning/finetuning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .update_model import update_model

__all__ = ["update_model"]
93 changes: 93 additions & 0 deletions samples/tools/finetuning/finetuning/update_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from autogen import ConversableAgent, Agent, OpenAIWrapper, ModelClient
from typing import Any, Dict, List, Protocol


class UpdateableModelClient(ModelClient, Protocol):
def update_model(
self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
) -> Dict[str, Any]:
"""Optional method to learn from the preference data, if the model supports learning. Can be omitted.

Learn from the preference data.

Args:
preference_data: The preference data.
inference_messages: The messages used for inference.
**kwargs: other arguments.

Returns:
Dict of learning stats.
"""
... # pragma: no cover


def _client_wrapper_update_model(
oai_wrapper_client: OpenAIWrapper,
preference_data: List[Any],
inference_messages: List[Dict[str, Any]],
**kwargs: Any,
) -> Dict[str, Any]:
"""Learn from the preference data.

update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages.

Args:
oai_wrapper_client: The OpenAIWrapper client.
preference_data: The preference data.
inference_messages: The messages that were used during inference between the agent that is being updated and another agent.
**kwargs: other arguments.

Returns:
Learning stats.

Raises:
ValueError: If multiple model clients are registered.
NotImplementedError: If update_model is not implemented for the client.
"""

clients = oai_wrapper_client._clients

if len(clients) != 1:
raise ValueError("update_model is not supported for multiple model clients.")
client = clients[0]
if hasattr(client, "update_model") and callable(getattr(client, "update_model")):
return client.update_model(preference_data, inference_messages, **kwargs)
else:
raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.")


def update_model(
update_agent: ConversableAgent, preference_data: List[Dict[str, Any]], other_agent: Agent, **kwargs
) -> Dict[str, Any]:
"""Update the model using the preference data and the conversation history.

Args:
update_agent (ConversableAgent): the agent whose model will be updated.
preference_data (List[Dict]): a list of dictionaries containing the preference data.
other_agent (Agent): the agent whose conversation history will be used to update the model.
**kwargs: additional keyword arguments for the update model function.

Returns:
Dict: a dictionary containing the update stats, inference_messages, and preference data, like so:
{
"update_stats": update_model_stats,
"inference_messages": inference_messages,
"preference_data": preference_data
}

Raises:
ValueError: If no OpenAIWrapper client is found.
ValueError: If multiple model clients are registered.
NotImplementedError: If update_model is not implemented for the underlying client.
"""
if update_agent.client is None:
raise ValueError("No OpenAIWrapper client is found.")
inference_messages = update_agent._oai_messages[other_agent]
update_model_stats = _client_wrapper_update_model(
update_agent.client, preference_data, inference_messages, **kwargs
)
return {
"update_stats": update_model_stats,
"inference_messages": inference_messages,
"preference_data": preference_data,
}
Loading
Loading