Skip to content

Commit ce71d85

Browse files
olgavrouekzhu
andauthored
Ability to fine tune custom model on conversable agents (#1787)
* uAbility to update_model on conversable agents * formatting * formatting * move code from conversable agent into samples/tools and add testing and README * forgot install step * fix * leave core lib unchanged and move everything to samples/tools * remove skip openai --------- Co-authored-by: Eric Zhu <[email protected]>
1 parent b93e2c5 commit ce71d85

File tree

5 files changed

+445
-0
lines changed

5 files changed

+445
-0
lines changed
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2+
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3+
4+
name: SamplesToolsTests
5+
6+
on:
7+
pull_request:
8+
branches: ["main"]
9+
paths:
10+
- "autogen/**"
11+
- "samples/tools/**"
12+
- ".github/workflows/samples-tools-tests.yml"
13+
- "setup.py"
14+
15+
concurrency:
16+
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
17+
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
18+
permissions: {}
19+
jobs:
20+
SamplesToolsFineTuningTests:
21+
runs-on: ${{ matrix.os }}
22+
strategy:
23+
fail-fast: false
24+
matrix:
25+
os: [ubuntu-latest, macos-latest]
26+
python-version: ["3.9", "3.10", "3.11"]
27+
steps:
28+
- uses: actions/checkout@v3
29+
- name: Set up Python ${{ matrix.python-version }}
30+
uses: actions/setup-python@v4
31+
with:
32+
python-version: ${{ matrix.python-version }}
33+
- name: Install packages and dependencies for all tests
34+
run: |
35+
python -m pip install --upgrade pip wheel
36+
pip install -e .
37+
pip install pytest
38+
- name: Set AUTOGEN_USE_DOCKER based on OS
39+
shell: bash
40+
run: |
41+
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
42+
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
43+
fi
44+
- name: Test finetuning tools
45+
run: |
46+
pytest samples/tools/finetuning/tests/

samples/tools/finetuning/README.md

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Tools for fine-tuning the local models that power agents
2+
3+
This directory aims to contain tools for fine-tuning the local models that power agents.
4+
5+
## Fine tune a custom model client
6+
7+
AutoGen supports the use of custom models to power agents [see blog post here](https://microsoft.github.io/autogen/blog/2024/01/26/Custom-Models). This directory contains a tool to provide feedback to that model, that can be used to fine-tune the model.
8+
9+
The creator of the Custom Model Client will have to decide what kind of data is going to be fed back and how it will be used to fine-tune the model. This tool is designed to be flexible and allow for a wide variety of feedback mechanisms.
10+
11+
Custom Model Client will have follow the protocol client defined in `update_model.py` `UpdateableModelClient` which is a subclass of `ModelClient` and adds the following method:
12+
13+
```python
14+
def update_model(
15+
self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
16+
) -> Dict[str, Any]:
17+
"""Optional method to learn from the preference data, if the model supports learning. Can be omitted.
18+
19+
Learn from the preference data.
20+
21+
Args:
22+
preference_data: The preference data.
23+
inference_messages: The messages that were used during inference between the agent that is being updated and another agent.
24+
**kwargs: other arguments.
25+
26+
Returns:
27+
Dict of learning stats.
28+
"""
29+
```
30+
31+
The function provided in the file `update_model.py` is called by passing these arguments:
32+
33+
- the agent whose model is to be updated
34+
- the preference data
35+
- the agent whose conversation is being used to provide the inference messages
36+
37+
The function will find the conversation thread that occurred between the "update agent" and the "other agent", and call the `update_model` method of the model client. It will return a dictionary containing the update stats, inference messages, and preference data:
38+
39+
```python
40+
{
41+
"update_stats": <the dictionary returned by the custom model client implementation>,
42+
"inference_messages": <message used for inference>,
43+
"preference_data": <the preference data passed in when update_model was called>
44+
}
45+
```
46+
47+
**NOTES**:
48+
49+
`inference_messages` will contain messages that were passed into the custom model client when `create` was called and a response was needed from the model. It is up to the author of the custom model client to decide which parts of the conversation are needed and how to use this data to fine-tune the model.
50+
51+
If a conversation has been long-running before `update_model` is called, then the `inference_messages` will contain a conversation thread that was used for multiple inference steps. It is again up to the author of the custom model client to decide which parts of the conversation correspond to the preference data and how to use this data to fine-tune the model.
52+
53+
An example of how to use this tool is shown below:
54+
55+
```python
56+
from finetuning.update_model import update_model
57+
58+
assistant = AssistantAgent(
59+
"assistant",
60+
system_message="You are a helpful assistant.",
61+
human_input_mode="NEVER",
62+
llm_config={
63+
"config_list": [<the config list containing the custom model>],
64+
},
65+
)
66+
67+
assistant.register_model_client(model_client_cls=<TheCustomModelClientClass>)
68+
69+
user_proxy = UserProxyAgent(
70+
"user_proxy",
71+
human_input_mode="NEVER",
72+
max_consecutive_auto_reply=1,
73+
code_execution_config=False,
74+
llm_config=False,
75+
)
76+
77+
res = user_proxy.initiate_chat(assistant, message="the message")
78+
response_content = res.summary
79+
80+
# Evaluate the summary here and provide feedback. Pretending I am going to perform DPO on the response.
81+
82+
# preference_data will be passed on as-is to the custom model client's update_model implementation
83+
# so it should be in the format that the custom model client expects and is completely up to the author of the custom model client
84+
preference_data = [("this is what the response should have been like", response_content)]
85+
86+
update_model_stats = update_model(assistant, preference_data, user_proxy)
87+
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .update_model import update_model
2+
3+
__all__ = ["update_model"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from autogen import ConversableAgent, Agent, OpenAIWrapper, ModelClient
2+
from typing import Any, Dict, List, Protocol
3+
4+
5+
class UpdateableModelClient(ModelClient, Protocol):
6+
def update_model(
7+
self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
8+
) -> Dict[str, Any]:
9+
"""Optional method to learn from the preference data, if the model supports learning. Can be omitted.
10+
11+
Learn from the preference data.
12+
13+
Args:
14+
preference_data: The preference data.
15+
inference_messages: The messages used for inference.
16+
**kwargs: other arguments.
17+
18+
Returns:
19+
Dict of learning stats.
20+
"""
21+
... # pragma: no cover
22+
23+
24+
def _client_wrapper_update_model(
25+
oai_wrapper_client: OpenAIWrapper,
26+
preference_data: List[Any],
27+
inference_messages: List[Dict[str, Any]],
28+
**kwargs: Any,
29+
) -> Dict[str, Any]:
30+
"""Learn from the preference data.
31+
32+
update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages.
33+
34+
Args:
35+
oai_wrapper_client: The OpenAIWrapper client.
36+
preference_data: The preference data.
37+
inference_messages: The messages that were used during inference between the agent that is being updated and another agent.
38+
**kwargs: other arguments.
39+
40+
Returns:
41+
Learning stats.
42+
43+
Raises:
44+
ValueError: If multiple model clients are registered.
45+
NotImplementedError: If update_model is not implemented for the client.
46+
"""
47+
48+
clients = oai_wrapper_client._clients
49+
50+
if len(clients) != 1:
51+
raise ValueError("update_model is not supported for multiple model clients.")
52+
client = clients[0]
53+
if hasattr(client, "update_model") and callable(getattr(client, "update_model")):
54+
return client.update_model(preference_data, inference_messages, **kwargs)
55+
else:
56+
raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.")
57+
58+
59+
def update_model(
60+
update_agent: ConversableAgent, preference_data: List[Dict[str, Any]], other_agent: Agent, **kwargs
61+
) -> Dict[str, Any]:
62+
"""Update the model using the preference data and the conversation history.
63+
64+
Args:
65+
update_agent (ConversableAgent): the agent whose model will be updated.
66+
preference_data (List[Dict]): a list of dictionaries containing the preference data.
67+
other_agent (Agent): the agent whose conversation history will be used to update the model.
68+
**kwargs: additional keyword arguments for the update model function.
69+
70+
Returns:
71+
Dict: a dictionary containing the update stats, inference_messages, and preference data, like so:
72+
{
73+
"update_stats": update_model_stats,
74+
"inference_messages": inference_messages,
75+
"preference_data": preference_data
76+
}
77+
78+
Raises:
79+
ValueError: If no OpenAIWrapper client is found.
80+
ValueError: If multiple model clients are registered.
81+
NotImplementedError: If update_model is not implemented for the underlying client.
82+
"""
83+
if update_agent.client is None:
84+
raise ValueError("No OpenAIWrapper client is found.")
85+
inference_messages = update_agent._oai_messages[other_agent]
86+
update_model_stats = _client_wrapper_update_model(
87+
update_agent.client, preference_data, inference_messages, **kwargs
88+
)
89+
return {
90+
"update_stats": update_model_stats,
91+
"inference_messages": inference_messages,
92+
"preference_data": preference_data,
93+
}

0 commit comments

Comments
 (0)