-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Graph RAG] Init Commit with GraphRag interfaces (#3388)
* [Graph RAG] Init Commit with GraphRag interfaces * Add Document class for graph rag input document * Add Graph RAG Capability * Add unit test for graph rag interfaces --------- Co-authored-by: Li Jiang <[email protected]> Co-authored-by: gagb <[email protected]>
- Loading branch information
1 parent
8be437a
commit e579a46
Showing
6 changed files
with
151 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum, auto | ||
from typing import Optional | ||
|
||
|
||
class DocumentType(Enum): | ||
""" | ||
Enum for supporting document type. | ||
""" | ||
|
||
TEXT = auto() | ||
HTML = auto() | ||
PDF = auto() | ||
|
||
|
||
@dataclass | ||
class Document: | ||
""" | ||
A wrapper of graph store query results. | ||
""" | ||
|
||
doctype: DocumentType | ||
data: Optional[object] = None | ||
path_or_url: Optional[str] = "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from dataclasses import dataclass, field | ||
from typing import List, Optional, Protocol | ||
|
||
from .document import Document | ||
|
||
|
||
@dataclass | ||
class GraphStoreQueryResult: | ||
""" | ||
A wrapper of graph store query results. | ||
answer: human readable answer to question/query. | ||
results: intermediate results to question/query, e.g. node entities. | ||
""" | ||
|
||
answer: Optional[str] = None | ||
results: list = field(default_factory=list) | ||
|
||
|
||
class GraphQueryEngine(Protocol): | ||
"""An abstract base class that represents a graph query engine on top of a underlying graph database. | ||
This interface defines the basic methods for graph rag. | ||
""" | ||
|
||
def init_db(self, input_doc: List[Document] | None = None): | ||
""" | ||
This method initializes graph database with the input documents or records. | ||
Usually, it takes the following steps, | ||
1. connecting to a graph database. | ||
2. extract graph nodes, edges based on input data, graph schema and etc. | ||
3. build indexes etc. | ||
Args: | ||
input_doc: a list of input documents that are used to build the graph in database. | ||
Returns: GraphStore | ||
""" | ||
pass | ||
|
||
def add_records(self, new_records: List) -> bool: | ||
""" | ||
Add new records to the underlying database and add to the graph if required. | ||
""" | ||
pass | ||
|
||
def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult: | ||
""" | ||
This method transform a string format question into database query and return the result. | ||
""" | ||
pass |
56 changes: 56 additions & 0 deletions
56
autogen/agentchat/contrib/graph_rag/graph_rag_capability.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability | ||
from autogen.agentchat.conversable_agent import ConversableAgent | ||
|
||
from .graph_query_engine import GraphQueryEngine | ||
|
||
|
||
class GraphRagCapability(AgentCapability): | ||
""" | ||
A graph rag capability uses a graph query engine to give a conversable agent the graph rag ability. | ||
An agent class with graph rag capability could | ||
1. create a graph in the underlying database with input documents. | ||
2. retrieved relevant information based on messages received by the agent. | ||
3. generate answers from retrieved information and send messages back. | ||
For example, | ||
graph_query_engine = GraphQueryEngine(...) | ||
graph_query_engine.init_db([Document(doc1), Document(doc2), ...]) | ||
graph_rag_agent = ConversableAgent( | ||
name="graph_rag_agent", | ||
max_consecutive_auto_reply=3, | ||
... | ||
) | ||
graph_rag_capability = GraphRagCapbility(graph_query_engine) | ||
graph_rag_capability.add_to_agent(graph_rag_agent) | ||
user_proxy = UserProxyAgent( | ||
name="user_proxy", | ||
code_execution_config=False, | ||
is_termination_msg=lambda msg: "TERMINATE" in msg["content"], | ||
human_input_mode="ALWAYS", | ||
) | ||
user_proxy.initiate_chat(graph_rag_agent, message="Name a few actors who've played in 'The Matrix'") | ||
# ChatResult( | ||
# chat_id=None, | ||
# chat_history=[ | ||
# {'content': 'Name a few actors who've played in \'The Matrix\'', 'role': 'graph_rag_agent'}, | ||
# {'content': 'A few actors who have played in The Matrix are: | ||
# - Keanu Reeves | ||
# - Laurence Fishburne | ||
# - Carrie-Anne Moss | ||
# - Hugo Weaving', | ||
# 'role': 'user_proxy'}, | ||
# ...) | ||
""" | ||
|
||
def __init__(self, query_engine: GraphQueryEngine): | ||
""" | ||
initialize graph rag capability with a graph query engine | ||
""" | ||
... | ||
|
||
def add_to_agent(self, agent: ConversableAgent): ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from unittest.mock import Mock | ||
|
||
from autogen.agentchat.contrib.graph_rag.graph_query_engine import GraphQueryEngine | ||
from autogen.agentchat.contrib.graph_rag.graph_rag_capability import GraphRagCapability | ||
from autogen.agentchat.conversable_agent import ConversableAgent | ||
|
||
|
||
def test_dry_run(): | ||
"""Dry run for basic graph rag objects.""" | ||
mock_graph_query_engine = Mock(spec=GraphQueryEngine) | ||
|
||
graph_rag_agent = ConversableAgent( | ||
name="graph_rag_agent", | ||
max_consecutive_auto_reply=3, | ||
) | ||
graph_rag_capability = GraphRagCapability(mock_graph_query_engine) | ||
graph_rag_capability.add_to_agent(graph_rag_agent) |