From 5d629426b0c4c8437e30b5390c4f3b542c4314fa Mon Sep 17 00:00:00 2001 From: Appointat Date: Fri, 15 Nov 2024 04:37:35 +0800 Subject: [PATCH] feat: Integrate Persona Hub Techniques into CAMEL for Enhanced Agent Diversity (#716) Co-authored-by: Zheng-Lu Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com> Co-authored-by: Harry-QY Co-authored-by: Harry Ye <116691547+harryeqs@users.noreply.github.com> Co-authored-by: Wendong --- camel/personas/__init__.py | 17 ++ camel/personas/persona.py | 86 +++++++ camel/personas/persona_hub.py | 283 +++++++++++++++++++++++ camel/prompts/__init__.py | 2 + camel/prompts/persona_hub.py | 61 +++++ examples/personas/personas_generation.py | 66 ++++++ poetry.lock | 29 ++- pyproject.toml | 1 + test/personas/test_persona_generator.py | 235 +++++++++++++++++++ 9 files changed, 765 insertions(+), 15 deletions(-) create mode 100644 camel/personas/__init__.py create mode 100644 camel/personas/persona.py create mode 100644 camel/personas/persona_hub.py create mode 100644 camel/prompts/persona_hub.py create mode 100644 examples/personas/personas_generation.py create mode 100644 test/personas/test_persona_generator.py diff --git a/camel/personas/__init__.py b/camel/personas/__init__.py new file mode 100644 index 0000000000..69efbe3e44 --- /dev/null +++ b/camel/personas/__init__.py @@ -0,0 +1,17 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Persona(BaseModel):
    r"""A persona is a character in the society.

    Attributes:
        name (Optional[str]): Name of the persona.
        description (Optional[str]): Description of the persona.
        t2p_prompt (Union[TextPrompt, str]): Text to Persona Prompt.
        p2p_prompt (Union[TextPrompt, str]): Persona to Persona Prompt.
        id (uuid.UUID): The unique identifier for the persona, automatically
            generated. Stored in the private attribute ``_id`` and exposed
            read-only through the :obj:`id` property.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    # Private attribute: not a model field, so it never appears in
    # `model_dump()`; `dict()` below adds it explicitly.
    _id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)

    # Field with default_factory to avoid circular import issues
    # Union type allows either TextPrompt or str
    t2p_prompt: Union[TextPrompt, str] = Field(
        default_factory=lambda: PersonaHubPrompt.TEXT_TO_PERSONA,
        description="Text to Persona Prompt",
    )

    # Similar to t2p_prompt, using default_factory for lazy evaluation
    p2p_prompt: Union[TextPrompt, str] = Field(
        default_factory=lambda: PersonaHubPrompt.PERSONA_TO_PERSONA,
        description="Persona to Persona Prompt",
    )

    # Class-level configuration for Pydantic model
    # ClassVar indicates this is a class variable, not an instance variable
    # NOTE(review): pydantic appears to apply `json_schema_extra` with a
    # shallow dict update, so this "properties" entry may replace (not merge
    # into) the generated properties -- confirm before relying on the schema.
    model_config: ClassVar[ConfigDict] = ConfigDict(
        # Allow the use of custom types TextPrompt
        arbitrary_types_allowed=True,
        # Custom JSON schema configuration
        json_schema_extra={
            "properties": {
                # Ensure t2p_prompt and p2p_prompt are treated as strings in
                # JSON schema
                "t2p_prompt": {"type": "string"},
                "p2p_prompt": {"type": "string"},
            }
        },
    )

    @property
    def id(self) -> uuid.UUID:
        r"""Return the automatically generated identifier of this persona."""
        return self._id

    @classmethod
    def model_json_schema(cls, *args, **kwargs):
        r"""Return the JSON schema of the model, extended with the ``id``
        property.

        Positional and keyword arguments are forwarded unchanged to
        :obj:`BaseModel.model_json_schema`.

        Returns:
            dict: The JSON schema, with ``properties['id']`` declared as a
                UUID-formatted string.
        """
        # Call the pydantic v2 API directly. The deprecated `schema()` alias
        # dispatches back to `cls.model_json_schema(by_alias=...,
        # ref_template=...)`, i.e. to this override, which previously did not
        # accept those keywords.
        schema = super().model_json_schema(*args, **kwargs)
        schema.setdefault('properties', {})['id'] = {
            'type': 'string',
            'format': 'uuid',
        }
        return schema

    def dict(self, *args, **kwargs):
        r"""Dump the persona to a dictionary that also contains its ``id``.

        Arguments are forwarded to :obj:`BaseModel.model_dump`.

        Returns:
            dict: The dumped fields plus ``'id'`` rendered as a string.
        """
        # Output: {'name': 'Alice', 'description': None, 't2p_prompt': '...', 'p2p_prompt': '...', 'id': 'f47ac10b-58cc-4372-a567-0e02b2c3d479'} # noqa: E501
        d = super().model_dump(*args, **kwargs)
        d['id'] = str(self.id)
        return d

    def json(self, *args, **kwargs):
        r"""Serialize the persona, including its ``id``, to a JSON string.

        Arguments are forwarded to :obj:`dict`.

        Returns:
            str: The JSON encoding of :obj:`dict`'s result.
        """
        # Local import: the file's import block is left untouched.
        import json as _json

        # Output: '{"name": "Alice", "description": null, "t2p_prompt": "...", "p2p_prompt": "...", "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479"}' # noqa: E501
        d = self.dict(*args, **kwargs)
        # Encode the dict ourselves: the deprecated `BaseModel.json` takes
        # keyword-only arguments and cannot serialize a pre-built dict.
        return _json.dumps(d)
class PersonaHub:
    r"""PersonaHub proposes a novel persona-driven data synthesis methodology
    that leverages various perspectives within a large language model (LLM) to
    create diverse synthetic data. By showcasing PersonaHub's use cases in
    synthesizing high-quality mathematical and logical reasoning problems,
    instructions (i.e., user prompts), knowledge-rich texts, game NPCs and
    tools (functions) at scale, the authors demonstrate persona-driven data
    synthesis is versatile, scalable, flexible, and easy to use, potentially
    driving a paradigm shift in synthetic data creation and applications in
    practice, which may have a profound impact on LLM research and development.
    Please refer to the paper for more details: https://arxiv.org/pdf/2406.20094

    Args:
        model (BaseModelBackend, optional): The model to use for persona
            generation and manipulation. (default: :obj:`None`)
    """

    # One numbered "persona_name / persona_description" entry. The lookahead
    # only ends a description at the next "<n>. persona_name:" header (or at
    # the end of the text), so numbers inside a description such as
    # "3.5 years" do not truncate it.
    _RELATED_PERSONA_PATTERN = re.compile(
        r"(\d+)\.\s*persona_name:\s*(.*?)\s*persona_description:\s*(.*?)\s*"
        r"(?=\d+\.\s*persona_name:|$)",
        re.DOTALL,
    )

    def __init__(
        self,
        model: Optional["BaseModelBackend"] = None,
    ):
        self.model = model
        self.personas: Dict[uuid.UUID, Persona] = {}

    def __setitem__(
        self,
        persona: Union[uuid.UUID, "Persona"],
        value: Optional["Persona"] = None,
    ):
        r"""Add a persona to the group.

        Supports both the original direct call ``hub.__setitem__(persona)``
        and the subscript form ``hub[persona.id] = persona``.

        Args:
            persona (Union[uuid.UUID, Persona]): The persona to add, or, in
                the subscript form, the ID under which to store it.
            value (Persona, optional): The persona to add when the subscript
                form is used. (default: :obj:`None`)

        Raises:
            ValueError: If the subscript key differs from ``value.id``.
        """
        if value is None:
            # Original single-argument form.
            self.personas[persona.id] = persona
            return
        if persona != value.id:
            raise ValueError("Persona ID does not match the persona's id")
        self.personas[value.id] = value

    def __delitem__(self, persona_id: uuid.UUID):
        r"""Remove a persona from the group by ID.

        Args:
            persona_id (uuid.UUID): The ID of the persona to remove.

        Raises:
            KeyError: If no persona with this ID is stored.
        """
        if persona_id in self.personas:
            del self.personas[persona_id]
        else:
            raise KeyError("Persona ID not found")

    def __getitem__(self, persona_id: uuid.UUID) -> "Persona":
        r"""Get a persona by ID.

        Args:
            persona_id (uuid.UUID): The ID of the persona to retrieve.

        Returns:
            Persona: The stored persona.

        Raises:
            KeyError: If no persona with this ID is stored.
        """
        if persona_id in self.personas:
            return self.personas[persona_id]
        else:
            raise KeyError("Persona ID not found")

    @staticmethod
    def _parse_persona_payload(content: str) -> dict:
        r"""Decode the structured text-to-persona response.

        Args:
            content (str): The raw message content returned by the agent.

        Returns:
            dict: The decoded payload.
        """
        # Local import: the file's import block is left untouched.
        import json

        try:
            # Structured output is JSON; `ast.literal_eval` alone rejects
            # JSON-only tokens such as `null`, `true` and `false`.
            return json.loads(content)
        except json.JSONDecodeError:
            # Keep accepting Python-literal dicts (e.g. single-quoted keys).
            return ast.literal_eval(content.strip())

    @classmethod
    def _parse_related_personas(cls, content: str) -> List[tuple[str, str]]:
        r"""Extract ``(name, description)`` pairs from a numbered answer.

        Args:
            content (str): Agent output following the answer template.

        Returns:
            List[tuple[str, str]]: The stripped name/description pairs in the
                order they appear.
        """
        return [
            (name.strip(), description.strip())
            for _, name, description in cls._RELATED_PERSONA_PATTERN.findall(
                content
            )
        ]

    def text_to_persona(
        self,
        text: str,
        action: Literal["read", "write", "like", "dislike"] = "read",
    ) -> "Persona":
        r"""Infers a specific persona who is likely to [read|write|like|dislike
        |...] the given text.

        Args:
            text (str): The input text for which to infer a persona.
            action (str): The action associated with the persona (default is
                "read").

        Returns:
            Persona: The inferred persona.

        Raises:
            RuntimeError: If the agent step or the parsing of its response
                fails; the original exception is chained as the cause.
        """
        persona = Persona()

        t2p_prompt: Union[TextPrompt, str] = persona.t2p_prompt
        t2p_prompt_instruction = t2p_prompt.format(action=action, text=text)

        # Set Agent to generate persona
        t2p_agent = ChatAgent(
            system_message="You are a helpful assistant", model=self.model
        )
        t2p_agent.reset()

        # Get output from agent
        try:
            response = t2p_agent.step(
                t2p_prompt_instruction,
                response_format=PersonaResponse,  # type: ignore[arg-type]
            )
            parsed_content = self._parse_persona_payload(response.msg.content)
            persona.name = parsed_content["persona_name"]
            persona.description = parsed_content["persona_description"]
        except Exception as e:
            raise RuntimeError(f"Text to persona step failed: {e}") from e

        return persona

    def persona_to_persona(
        self, persona: "Persona"
    ) -> Dict[uuid.UUID, "Persona"]:
        r"""Derives additional personas based on interpersonal relationships
        from this persona.

        Args:
            persona (Persona): The persona from which to derive related
                personas.

        Returns:
            Dict[uuid.UUID, Persona]: A dictionary of related personas.

        Raises:
            RuntimeError: If the agent step or the parsing of its response
                fails; the original exception is chained as the cause.
        """
        p2p_prompt: Union[TextPrompt, str] = persona.p2p_prompt
        answer_template = """
You MUST answer the question according to the format of the ANSWER TEMPLATE, and you can only modify the content within <BLANK>.
===== ANSWER TEMPLATE =====
1. persona_name: <BLANK>
persona_description: <BLANK>
...
n. persona_name: <BLANK>
persona_description: <BLANK>
"""  # noqa: E501
        p2p_prompt_instruction = (
            p2p_prompt.format(
                persona_name=persona.name,
                persona_description=persona.description,
            )
            + answer_template
        )

        p2p_agent = ChatAgent(
            system_message="You're a helpful assistant.", model=self.model
        )
        p2p_agent.reset()

        # Get output from agent
        try:
            response = p2p_agent.step(
                p2p_prompt_instruction  # type: ignore[arg-type]
            )
            # Structured output (TODO: Use a more robust parser)
            personas: Dict[uuid.UUID, Persona] = {}
            for name, description in self._parse_related_personas(
                response.msg.content
            ):
                new_persona = Persona(name=name, description=description)
                personas[new_persona.id] = new_persona
        except Exception as e:
            raise RuntimeError(f"Persona to persona step failed: {e}") from e

        return personas

    def deduplicate(
        self,
        embedding_model: Optional["BaseEmbedding"] = None,
        similarity_threshold: float = 0.85,
    ) -> None:
        r"""Remove similar personas from the group.

        Personas are visited in insertion order; one is kept only if it is
        not similar to any persona already kept.

        Args:
            embedding_model (BaseEmbedding): The embedding model
                for similarity comparison. (default is `None`).
            similarity_threshold (float): The similarity threshold for
                deduplication (default is `0.85`).
        """
        # Changed to default similarity threshold to 0.85 as the default
        # text-embedding-3-small model may give lower similarities than others
        # This is a simplified version. Need to implement a more
        # sophisticated deduplication algorithm as described in the paper.
        if not embedding_model:
            from camel.embeddings import OpenAIEmbedding

            embedding_model = OpenAIEmbedding()
        unique_personas: Dict[uuid.UUID, Persona] = {}
        for persona_id, persona in self.personas.items():
            if not any(
                self._is_similar(
                    persona, up, similarity_threshold, embedding_model
                )
                for up in unique_personas.values()
            ):
                unique_personas[persona_id] = persona
        self.personas = unique_personas

    @staticmethod
    @lru_cache(maxsize=128)
    def _get_embedding(
        embedding_model: "BaseEmbedding", description: Optional[str]
    ) -> list[float]:
        r"""Cache embeddings to reduce recomputation."""
        return embedding_model.embed(description)

    @staticmethod
    def _cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
        r"""Compute the cosine similarity of two vectors.

        Args:
            vec1 (np.ndarray): Vector 1
            vec2 (np.ndarray): Vector 2

        Returns:
            float: The cosine similarity, or ``0.0`` when either vector has
                zero norm (the similarity is undefined in that case).
        """
        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denominator == 0:
            # Avoid a division by zero that would yield NaN.
            return 0.0
        return float(np.dot(vec1, vec2) / denominator)

    def _is_similar(
        self,
        persona1: "Persona",
        persona2: "Persona",
        similarity_threshold: float,
        embedding_model: "BaseEmbedding",
    ) -> bool:
        r"""Check if two personas are similar by cosine similarity
        of the embeddings of their descriptions.

        Args:
            persona1 (Persona): A persona.
            persona2 (Persona): The other persona.
            similarity_threshold (float): The threshold on cosine similarity
                to determine whether the two personas are similar.
            embedding_model (BaseEmbedding): The embedding model
                for similarity comparison.

        Returns:
            bool: Whether the similarity reaches ``similarity_threshold``.
        """

        # Ensure persona descriptions are not None
        persona1_description = persona1.description or ""
        persona2_description = persona2.description or ""

        persona1_embeddings = self._get_embedding(
            embedding_model, persona1_description
        )
        persona2_embeddings = self._get_embedding(
            embedding_model, persona2_description
        )

        similarity = self._cosine_similarity(
            np.array(persona1_embeddings), np.array(persona2_embeddings)
        )

        return similarity >= similarity_threshold

    def __len__(self):
        r"""Return the number of stored personas."""
        return len(self.personas)

    def __iter__(self):
        r"""Iterate over the stored personas (values, not IDs)."""
        return iter(self.personas.values())

    def get_all_personas(self) -> List["Persona"]:
        r"""Return a list of all personas."""
        return list(self.personas.values())
class PersonaHubPrompt(TextPromptDict):
    r"""Prompt dictionary for persona generation.

    Holds the :obj:`TextPrompt` templates used to infer a persona from a
    piece of text and to derive related personas from an existing one. The
    templates are available both as class attributes and, after
    construction, under string keys of the dictionary.

    Attributes:
        TEXT_TO_PERSONA (TextPrompt): Template asking who is likely to
            perform ``{action}`` on ``{text}`` and requesting a detailed,
            specific persona description.
        PERSONA_TO_PERSONA (TextPrompt): Template that presents
            ``{persona_name}`` and ``{persona_description}`` and asks which
            personas are likely to be in a close relationship with it.
    """

    TEXT_TO_PERSONA = TextPrompt("""
Who is likely to {action} the following text? Provide a detailed and specific persona description.

Text: {text}
""")  # noqa: E501

    PERSONA_TO_PERSONA = TextPrompt("""
Given the following persona:
{persona_name}
{persona_description}

Who is likely to be in a close relationship with this persona? Describe the related personas and their relationships.
""")  # noqa: E501

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Register both templates under their lookup keys.
        registered_prompts = {
            "text_to_persona": self.TEXT_TO_PERSONA,
            "persona_to_persona": self.PERSONA_TO_PERSONA,
        }
        self.update(registered_prompts)
from camel.personas.persona_hub import PersonaHub

# Example: infer a persona from a text, then derive the personas related to it.
hub = PersonaHub()

# Step 1: text -> persona
example_text = """Clinical Guideline: Administration of Injections in
Pediatric Patients Purpose: To provide standardized care for pediatric
patients requiring injections, ensuring safety, ..."""

reader_persona = hub.text_to_persona(example_text, action="read")
inferred_summary = (
    f"Inferred Persona:\n{reader_persona.name}"
    f"\n{reader_persona.description}\n"
)
print(inferred_summary)

# Step 2: persona -> related personas
related_personas = hub.persona_to_persona(persona=reader_persona)
print("Related Personas:\n")
for related_id in related_personas:
    related = related_personas[related_id]
    print(f"ID: {related_id}")
    print(f"Name: {related.name}")
    print(f"Description: {related.description}")
    print()
'''
===============================================================================
Inferred Persona:
Pediatric Nurse
A healthcare professional specializing in the care of children, with expertise in administering medications and following clinical guidelines for pediatric patients.

Related Personas:

ID: 123e4567-e89b-12d3-a456-426614174000
Name: Pediatrician
Description: A medical doctor who specializes in the care of infants, children, and adolescents. They work closely with pediatric nurses to ensure proper treatment and medication administration for young patients.

ID: 123e4567-e89b-12d3-a456-426614174001
Name: Child Life Specialist
Description: A professional who helps children and families cope with the challenges of hospitalization, illness, and disability. They often collaborate with medical staff to make medical procedures less stressful for pediatric patients.

ID: 123e4567-e89b-12d3-a456-426614174002
Name: Pediatric Pharmacist
Description: A pharmacist who specializes in medications for children, ensuring proper dosing and formulations. They work with the medical team to optimize medication regimens for pediatric patients.

ID: 123e4567-e89b-12d3-a456-426614174003
Name: Parent or Guardian
Description: The primary caregiver of a pediatric patient, who needs to understand and consent to medical procedures, including injections. They often have concerns and questions about their child's treatment.

ID: 123e4567-e89b-12d3-a456-426614174004
Name: Pediatric Hospital Administrator
Description: A healthcare manager responsible for overseeing pediatric departments or hospitals. They ensure that clinical guidelines are implemented and followed to maintain high standards of care for young patients.
===============================================================================
'''  # noqa: E501
modules" optional = true python-versions = ">=3.8" files = [ + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, ] @@ -6222,6 +6225,7 @@ python-versions = ">=3.9" files = [ {file = "PyMuPDF-1.24.13-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c11bb9198af69d490b4b346421db827d875a28fbc760d239e691d4b3ed12b5ad"}, {file = "PyMuPDF-1.24.13-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:240d5c43daa9278db50d609162b48f673ab256d7e5c73eea67af517c1fc2d47c"}, + {file = "PyMuPDF-1.24.13-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e4c8808e62afbbde0f7b9c4151c4b1a5735911c2d39c34332860df600dba76f8"}, {file = "PyMuPDF-1.24.13-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c830610e4fde237fcf0532f1f8c1381453f48c164a5eadd0c6e5fd0bea1ca8e3"}, {file = "PyMuPDF-1.24.13-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4520558580ac6b5a7164fda29fbc14e39d3114fd803420721500edbf47d04872"}, {file = "PyMuPDF-1.24.13-cp39-abi3-win32.whl", hash = "sha256:ab22828d4fc205791ef1332a64893cbfc38cd9c331c5f46ae4537372ffee6fc1"}, @@ -7341,40 +7345,30 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a606ef75a60ecf3d924613892cc603b154178ee25abb3055db5062da811fd969"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd5415dded15c3822597455bc02bcd66e81ef8b7a48cb71a33628fc9fdde39df"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = 
"sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d84318609196d6bd6da0edfa25cedfbabd8dbde5140a0a23af29ad4b8f91fb1e"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb43a269eb827806502c7c8efb7ae7e9e9d0573257a46e8e952f4d4caba4f31e"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"}, {file = 
"ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:943f32bc9dedb3abff9879edc134901df92cfce2c3d5c9348f172f62eb2d771d"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c3829bb364fdb8e0332c9931ecf57d9be3519241323c5274bd82f709cebc0c"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:e7e3736715fbf53e9be2a79eb4db68e4ed857017344d697e8b9749444ae57475"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7e75b4965e1d4690e93021adfcecccbca7d61c7bddd8e22406ef2ff20d74ef"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"}, - {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"}, - {file = 
"ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:bc5f1e1c28e966d61d2519f2a3d451ba989f9ea0f2307de7bc45baa526de9e45"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a0e060aace4c24dcaf71023bbd7d42674e3b230f7e7b97317baf1e953e5b519"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"}, {file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"}, @@ -7589,6 +7583,11 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -9899,13 +9898,13 @@ propcache = ">=0.2.0" [[package]] name = "zipp" -version = "3.20.2" +version = "3.21.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = true -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "zipp-3.20.2-py3-none-any.whl", hash = 
"sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, - {file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"}, + {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, + {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, ] [package.extras] @@ -9934,4 +9933,4 @@ vector-databases = ["pymilvus", "qdrant-client"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "b49bd4106eb3ed70a5cc22e20eb21592995f49e785c30e79749569f3e0b6a4d0" +content-hash = "3b39b36b876fe09771ae568690559c83b030becae57cfb2cddd0482d21d395e7" diff --git a/pyproject.toml b/pyproject.toml index 82cc220aed..045a866fc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,6 +177,7 @@ encoders = [ tools = [ "beautifulsoup4", "docx2txt", + "functools", "PyMuPDF", "wikipedia", "duckduckgo-search", diff --git a/test/personas/test_persona_generator.py b/test/personas/test_persona_generator.py new file mode 100644 index 0000000000..70a07a27e3 --- /dev/null +++ b/test/personas/test_persona_generator.py @@ -0,0 +1,235 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. 
=========== +import uuid +from typing import Dict +from unittest.mock import MagicMock, patch + +import pytest + +from camel.embeddings import OpenAIEmbedding +from camel.messages import BaseMessage +from camel.personas import Persona, PersonaHub +from camel.types import EmbeddingModelType, RoleType + +# Mock responses +MOCK_TEXT_TO_PERSONA_RESPONSE = """ +{ + "persona_name": "Data Scientist", + "persona_description": "A professional with expertise in statistical analysis, machine learning, and data visualization. They have strong programming skills, particularly in Python and R, and are experienced in working with large datasets to extract meaningful insights." +} +""" # noqa: E501 + +MOCK_PERSONA_TO_PERSONA_RESPONSE = """ +1. persona_name: Machine Learning Engineer +persona_description: A professional who specializes in developing and implementing machine learning models. They work closely with Data Scientists to turn data insights into practical applications. +2. persona_name: Business Analyst +persona_description: A professional who bridges the gap between data insights and business strategy. They collaborate with Data Scientists to translate complex analytical findings into actionable business recommendations. +3. persona_name: Data Engineer +persona_description: A professional who designs and maintains the data infrastructure that supports the work of Data Scientists. They ensure that data is collected, stored, and processed efficiently to enable data-driven decision-making. 
+""" # noqa: E501 + + +@pytest.fixture +def persona_generator(): + return PersonaHub(model=MagicMock()) + + +def test_init(persona_generator: PersonaHub): + assert isinstance(persona_generator, PersonaHub) + assert isinstance(persona_generator.personas, Dict) + assert len(persona_generator.personas) == 0 + + +def test___setitem__(persona_generator: PersonaHub): + persona = Persona( + name="Test Persona", + description="Test Description", + ) + persona_generator.__setitem__(persona) + assert persona_generator.__len__() == 1 + assert persona_generator.personas[persona.id] == persona + + +def test_remove_persona(persona_generator: PersonaHub): + persona1 = Persona( + name="Test Persona 1", + description="Test Description 1", + ) + persona2 = Persona( + name="Test Persona 2", + description="Test Description 2", + ) + persona_generator.__setitem__(persona1) + persona_generator.__setitem__(persona2) + + persona_generator.__delitem__(persona1.id) + assert persona_generator.__len__() == 1 + assert persona_generator.personas[persona2.id] == persona2 + + with pytest.raises(KeyError): + persona_generator.__delitem__(uuid.uuid4()) + + +def test_get_persona(persona_generator: PersonaHub): + persona = Persona( + name="Test Persona", + description="Test Description", + ) + persona_generator.__setitem__(persona) + + assert persona_generator.__getitem__(persona.id) == persona + + with pytest.raises(KeyError): + persona_generator.__getitem__(uuid.uuid4()) + + +def test_text_to_persona(persona_generator: PersonaHub): + mock_response = MagicMock() + mock_response.terminated = False + mock_response.msg = BaseMessage( + role_name="Assistant", + role_type=RoleType.ASSISTANT, + content=MOCK_TEXT_TO_PERSONA_RESPONSE, + meta_dict={}, + ) + + with patch('camel.agents.ChatAgent.step', return_value=mock_response): + persona = persona_generator.text_to_persona(text="Test text") + + assert isinstance(persona, Persona) + assert persona.name == "Data Scientist" + assert ( + persona.description + 
and "expertise in statistical analysis" in persona.description + ) + + +def test_persona_to_persona(persona_generator: PersonaHub): + mock_response = MagicMock() + mock_response.terminated = False + mock_response.msg = BaseMessage( + role_name="Assistant", + role_type=RoleType.ASSISTANT, + content=MOCK_PERSONA_TO_PERSONA_RESPONSE, + meta_dict={}, + ) + + with patch('camel.agents.ChatAgent.step', return_value=mock_response): + base_persona = Persona( + name="Data Scientist", description="A data expert" + ) + related_personas = persona_generator.persona_to_persona(base_persona) + + assert isinstance(related_personas, dict) + assert len(related_personas) == 3 + assert any( + p.name == "Machine Learning Engineer" + for p in related_personas.values() + ) + assert any(p.name == "Business Analyst" for p in related_personas.values()) + assert any(p.name == "Data Engineer" for p in related_personas.values()) + + +def test_deduplicate(persona_generator: PersonaHub): + persona1 = Persona( + name="Test Persona 1", + description="Test Description 1", + ) + persona2 = Persona( + name="Test Persona 2", + description="Test Description 2", + ) + persona_generator.__setitem__(persona1) + persona_generator.__setitem__(persona2) + + persona_generator.deduplicate( + embedding_model=OpenAIEmbedding( + model_type=EmbeddingModelType.TEXT_EMBEDDING_3_SMALL + ) + ) + + assert ( + len(persona_generator.personas) == 1 + ) # Only one persona left as the persona descriptions are very similar + + +def test_is_similar(persona_generator: PersonaHub): + persona1 = Persona( + name="Test Persona 1", + description="Test Description 1", + ) + persona2 = Persona( + name="Test Persona 2", + description="Test Description 2", + ) + assert persona_generator._is_similar( + persona1=persona1, + persona2=persona2, + similarity_threshold=0.85, + embedding_model=OpenAIEmbedding( + model_type=EmbeddingModelType.TEXT_EMBEDDING_3_SMALL + ), + ) + + +def test_len(persona_generator: PersonaHub): + persona1 = Persona( 
+ name="Test Persona 1", + description="Test Description 1", + ) + persona2 = Persona( + name="Test Persona 2", + description="Test Description 2", + ) + persona_generator.__setitem__(persona1) + persona_generator.__setitem__(persona2) + + assert persona_generator.__len__() == 2 + + +def test_iter(persona_generator: PersonaHub): + persona1 = Persona( + name="Test Persona 1", + description="Test Description 1", + ) + persona2 = Persona( + name="Test Persona 2", + description="Test Description 2", + ) + persona_generator.__setitem__(persona1) + persona_generator.__setitem__(persona2) + + personas = list(persona_generator) + assert len(personas) == 2 + assert persona1 in personas + assert persona2 in personas + + +def test_get_all_personas(persona_generator: PersonaHub): + persona1 = Persona( + name="Test Persona 1", + description="Test Description 1", + ) + persona2 = Persona( + name="Test Persona 2", + description="Test Description 2", + ) + persona_generator.__setitem__(persona1) + persona_generator.__setitem__(persona2) + + all_personas = persona_generator.get_all_personas() + assert isinstance(all_personas, list) + assert len(all_personas) == 2 + assert persona1 in all_personas + assert persona2 in all_personas