|
| 1 | +from autogen import ConversableAgent, Agent, OpenAIWrapper, ModelClient |
| 2 | +from typing import Any, Dict, List, Protocol |
| 3 | + |
| 4 | + |
| 5 | +class UpdateableModelClient(ModelClient, Protocol): |
| 6 | + def update_model( |
| 7 | + self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any |
| 8 | + ) -> Dict[str, Any]: |
| 9 | + """Optional method to learn from the preference data, if the model supports learning. Can be omitted. |
| 10 | +
|
| 11 | + Learn from the preference data. |
| 12 | +
|
| 13 | + Args: |
| 14 | + preference_data: The preference data. |
| 15 | + inference_messages: The messages used for inference. |
| 16 | + **kwargs: other arguments. |
| 17 | +
|
| 18 | + Returns: |
| 19 | + Dict of learning stats. |
| 20 | + """ |
| 21 | + ... # pragma: no cover |
| 22 | + |
| 23 | + |
| 24 | +def _client_wrapper_update_model( |
| 25 | + oai_wrapper_client: OpenAIWrapper, |
| 26 | + preference_data: List[Any], |
| 27 | + inference_messages: List[Dict[str, Any]], |
| 28 | + **kwargs: Any, |
| 29 | +) -> Dict[str, Any]: |
| 30 | + """Learn from the preference data. |
| 31 | +
|
| 32 | + update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages. |
| 33 | +
|
| 34 | + Args: |
| 35 | + oai_wrapper_client: The OpenAIWrapper client. |
| 36 | + preference_data: The preference data. |
| 37 | + inference_messages: The messages that were used during inference between the agent that is being updated and another agent. |
| 38 | + **kwargs: other arguments. |
| 39 | +
|
| 40 | + Returns: |
| 41 | + Learning stats. |
| 42 | +
|
| 43 | + Raises: |
| 44 | + ValueError: If multiple model clients are registered. |
| 45 | + NotImplementedError: If update_model is not implemented for the client. |
| 46 | + """ |
| 47 | + |
| 48 | + clients = oai_wrapper_client._clients |
| 49 | + |
| 50 | + if len(clients) != 1: |
| 51 | + raise ValueError("update_model is not supported for multiple model clients.") |
| 52 | + client = clients[0] |
| 53 | + if hasattr(client, "update_model") and callable(getattr(client, "update_model")): |
| 54 | + return client.update_model(preference_data, inference_messages, **kwargs) |
| 55 | + else: |
| 56 | + raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.") |
| 57 | + |
| 58 | + |
| 59 | +def update_model( |
| 60 | + update_agent: ConversableAgent, preference_data: List[Dict[str, Any]], other_agent: Agent, **kwargs |
| 61 | +) -> Dict[str, Any]: |
| 62 | + """Update the model using the preference data and the conversation history. |
| 63 | +
|
| 64 | + Args: |
| 65 | + update_agent (ConversableAgent): the agent whose model will be updated. |
| 66 | + preference_data (List[Dict]): a list of dictionaries containing the preference data. |
| 67 | + other_agent (Agent): the agent whose conversation history will be used to update the model. |
| 68 | + **kwargs: additional keyword arguments for the update model function. |
| 69 | +
|
| 70 | + Returns: |
| 71 | + Dict: a dictionary containing the update stats, inference_messages, and preference data, like so: |
| 72 | + { |
| 73 | + "update_stats": update_model_stats, |
| 74 | + "inference_messages": inference_messages, |
| 75 | + "preference_data": preference_data |
| 76 | + } |
| 77 | +
|
| 78 | + Raises: |
| 79 | + ValueError: If no OpenAIWrapper client is found. |
| 80 | + ValueError: If multiple model clients are registered. |
| 81 | + NotImplementedError: If update_model is not implemented for the underlying client. |
| 82 | + """ |
| 83 | + if update_agent.client is None: |
| 84 | + raise ValueError("No OpenAIWrapper client is found.") |
| 85 | + inference_messages = update_agent._oai_messages[other_agent] |
| 86 | + update_model_stats = _client_wrapper_update_model( |
| 87 | + update_agent.client, preference_data, inference_messages, **kwargs |
| 88 | + ) |
| 89 | + return { |
| 90 | + "update_stats": update_model_stats, |
| 91 | + "inference_messages": inference_messages, |
| 92 | + "preference_data": preference_data, |
| 93 | + } |
0 commit comments