Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graphrag integration #4612

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e3e8f45
add initial global search draft
lpinheiroms Dec 7, 2024
8242378
add graphrag dep
lpinheiroms Dec 9, 2024
fb2fb19
Merge branch 'main' into lpinheiro/feat/add-graphrag-tools
lpinheiroms Dec 10, 2024
a13c18b
fix local search embedding
lpinheiroms Dec 17, 2024
8f3c484
linting
lpinheiroms Dec 17, 2024
0c05047
add from config constructor
lpinheiroms Dec 17, 2024
0e53f91
Merge branch 'main' into lpinheiro/feat/add-graphrag-tools
lspinheiro Dec 17, 2024
c1e7ea2
remove draft notebook
lpinheiroms Dec 17, 2024
a8b38ad
Merge branch 'main' into lpinheiro/feat/add-graphrag-tools
lspinheiro Dec 19, 2024
6d61c8e
update config factory and add docstrings
lpinheiroms Dec 20, 2024
1c4ed3d
add graphrag sample
lpinheiroms Dec 20, 2024
95f329c
add sample prompts
lpinheiroms Dec 20, 2024
3bc104b
update readme
lpinheiroms Dec 20, 2024
2ae6812
Merge branch 'main' into lpinheiro/feat/add-graphrag-tools
lspinheiro Dec 20, 2024
33523df
update deps
lpinheiroms Dec 20, 2024
8080ddb
Add API docs
ekzhu Dec 30, 2024
603c1c9
Update python/samples/agentchat_graphrag/requirements.txt
ekzhu Dec 30, 2024
934230b
Update python/samples/agentchat_graphrag/requirements.txt
ekzhu Dec 30, 2024
1c5fcd3
merge main, fix conflicts
lpinheiroms Dec 30, 2024
4f0c71f
update docstrings with snippet and doc ref
lpinheiroms Dec 30, 2024
e3dc1f9
lint
lpinheiroms Dec 30, 2024
f24fb6c
improve set up instructions in docstring
lpinheiroms Jan 3, 2025
4a5d611
lint
lpinheiroms Jan 3, 2025
74a2a23
Merge branch 'main' into lpinheiro/feat/add-graphrag-tools
lpinheiroms Jan 3, 2025
cac2aef
update lock
lpinheiroms Jan 3, 2025
e42f027
Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_gl…
lspinheiro Jan 4, 2025
e60a9aa
Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_lo…
lspinheiro Jan 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/packages/autogen-ext/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ langchain = ["langchain_core~= 0.3.3"]
azure = ["azure-core", "azure-identity"]
docker = ["docker~=7.0"]
openai = ["openai>=1.3", "aiofiles"]
graphrag = ["graphrag==0.9.0"]
file-surfer = ["markitdown>=0.0.1a2"]
file-surfer = [
"autogen-agentchat==0.4.0.dev11",
"markitdown>=0.0.1a2",
]
graphrag = ["graphrag>=1.0.0"]
web-surfer = [
"playwright>=1.48.0",
"pillow>=11.0.0",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
from ._global_search import ContextConfig, DataConfig, GlobalSearchTool, MapReduceConfig
from ._config import (
EmbeddingConfig,
GlobalContextConfig,
GlobalDataConfig,
LocalContextConfig,
LocalDataConfig,
MapReduceConfig,
SearchConfig,
)
from ._global_search import GlobalSearchTool
from ._local_search import LocalSearchTool

__all__ = ["GlobalSearchTool", "DataConfig", "ContextConfig", "MapReduceConfig"]
__all__ = [
"GlobalSearchTool",
"LocalSearchTool",
"GlobalDataConfig",
"LocalDataConfig",
"GlobalContextConfig",
"LocalContextConfig",
"MapReduceConfig",
"SearchConfig",
"EmbeddingConfig",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Callable, Literal, Optional

from pydantic import BaseModel


class DataConfig(BaseModel):
    """Base configuration describing where a graphrag index's parquet tables live.

    Subclassed by :class:`GlobalDataConfig` and :class:`LocalDataConfig`, which add
    the tables specific to each search mode.
    """

    # Directory containing the parquet files produced by the graphrag indexer.
    input_dir: str
    # Parquet file (basename, without extension) holding the entity nodes.
    entity_table: str = "create_final_nodes"
    # Parquet file holding the entity embeddings.
    entity_embedding_table: str = "create_final_entities"
    # Community hierarchy level to read entities at — presumably a graphrag
    # Leiden-hierarchy level; TODO confirm valid range against graphrag docs.
    community_level: int = 2


class GlobalDataConfig(DataConfig):
    """Data configuration for global search: adds the community-level tables."""

    # Parquet file holding the community records.
    community_table: str = "create_final_communities"
    # Parquet file holding the community summary reports.
    community_report_table: str = "create_final_community_reports"


class LocalDataConfig(DataConfig):
    """Data configuration for local search: adds relationship and text-unit tables."""

    # Parquet file holding entity-to-entity relationships.
    relationship_table: str = "create_final_relationships"
    # Parquet file holding the source text units (chunks).
    text_unit_table: str = "create_final_text_units"


class ContextConfig(BaseModel):
    """Base configuration for context building, shared by global and local search."""

    # Token budget for the data packed into the search context window.
    max_data_tokens: int = 8000


class GlobalContextConfig(ContextConfig):
    """Context-building options for global search.

    Field names mirror the keyword arguments of graphrag's
    ``GlobalCommunityContext`` builder — verify against the installed
    graphrag version when upgrading.
    """

    # Use the short community summary instead of the full report content.
    use_community_summary: bool = False
    # Shuffle community records before packing them into the context.
    shuffle_data: bool = True
    # Include each community's rank column in the context.
    include_community_rank: bool = True
    # Minimum rank a community must have to be included.
    min_community_rank: int = 0
    # Column/attribute name used for the community rank.
    community_rank_name: str = "rank"
    # Include each community's weight in the context.
    include_community_weight: bool = True
    # Column/attribute name used for the community weight.
    community_weight_name: str = "occurrence weight"
    # Normalize community weights before use.
    normalize_community_weight: bool = True
    # Override of the base token budget: global search defaults to a larger window.
    max_data_tokens: int = 12000


class LocalContextConfig(ContextConfig):
    """Context-building options for local search.

    Field names mirror graphrag's local-search context builder parameters —
    verify against the installed graphrag version when upgrading.
    """

    # Fraction of the context token budget allocated to text units.
    text_unit_prop: float = 0.5
    # Fraction of the context token budget allocated to community reports.
    community_prop: float = 0.25
    # Include each entity's rank in the context.
    include_entity_rank: bool = True
    # Human-readable description attached to the entity rank column.
    rank_description: str = "number of relationships"
    # Include relationship weights in the context.
    include_relationship_weight: bool = True
    # Attribute used to rank relationships when selecting them.
    relationship_ranking_attribute: str = "rank"


class MapReduceConfig(BaseModel):
    """LLM settings for the map and reduce phases of global search."""

    # Max tokens the LLM may generate per map call.
    map_max_tokens: int = 1000
    # Sampling temperature for map calls (0.0 = deterministic).
    map_temperature: float = 0.0
    # Max tokens the LLM may generate for the final reduce call.
    reduce_max_tokens: int = 2000
    # Sampling temperature for the reduce call.
    reduce_temperature: float = 0.0
    # Allow the LLM to draw on knowledge outside the indexed data.
    allow_general_knowledge: bool = False
    # Request JSON-mode responses from the LLM.
    json_mode: bool = False
    # Desired response format, passed through to the search engine
    # (free-form description, e.g. "multiple paragraphs").
    response_type: str = "multiple paragraphs"


class SearchConfig(BaseModel):
    """LLM settings for a (local) search call."""

    # Max tokens the LLM may generate for the answer.
    max_tokens: int = 1500
    # Sampling temperature (0.0 = deterministic).
    temperature: float = 0.0
    # Desired response format (free-form description, e.g. "multiple paragraphs").
    response_type: str = "multiple paragraphs"


class EmbeddingConfig(BaseModel):
    """Configuration for the embedding client used by local search.

    Fields mirror the constructor arguments of an OpenAI/Azure OpenAI
    embedding client; only ``model`` is required.
    """

    # API key; may be omitted when using Azure AD token auth.
    api_key: Optional[str] = None
    # Embedding model name (required).
    model: str
    # Base URL of the API endpoint (required for Azure).
    api_base: Optional[str] = None
    # Azure deployment name, when api_type is "azure".
    deployment_name: Optional[str] = None
    # Azure API version string, when api_type is "azure".
    api_version: Optional[str] = None
    # Which backend to talk to: Azure OpenAI or OpenAI.
    api_type: Literal["azure", "openai"] = "openai"
    # OpenAI organization id, if any.
    organization: Optional[str] = None
    # Callable returning an Azure AD bearer token; alternative to api_key on Azure.
    azure_ad_token_provider: Optional[Callable[[], str]] = None
    # Max retry attempts for failed embedding requests.
    max_retries: int = 10
    # Per-request timeout in seconds.
    request_timeout: float = 180.0
Original file line number Diff line number Diff line change
@@ -1,57 +1,25 @@
# tool_global_search.py

import json
from typing import Any

# mypy: disable-error-code="no-any-unimported,misc"
import pandas as pd
import tiktoken
from autogen_core import CancellationToken
from autogen_core.components.tools import BaseTool
from autogen_core.tools import BaseTool
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
from pydantic import BaseModel, Field

from graphrag.query.indexer_adapters import (
read_indexer_communities,
read_indexer_entities,
read_indexer_reports,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from pydantic import BaseModel, Field

from autogen_ext.models.openai import OpenAIChatCompletionClient

from ._config import GlobalContextConfig as ContextConfig
from ._config import GlobalDataConfig as DataConfig
from ._config import MapReduceConfig
from ._model_adapter import GraphragOpenAiModelAdapter


class DataConfig(BaseModel):
input_dir: str
community_table: str = "create_final_communities"
community_report_table: str = "create_final_community_reports"
entity_table: str = "create_final_nodes"
entity_embedding_table: str = "create_final_entities"
community_level: int = 2


class ContextConfig(BaseModel):
use_community_summary: bool = False
shuffle_data: bool = True
include_community_rank: bool = True
min_community_rank: int = 0
community_rank_name: str = "rank"
include_community_weight: bool = True
community_weight_name: str = "occurrence weight"
normalize_community_weight: bool = True
max_data_tokens: int = 12000


class MapReduceConfig(BaseModel):
map_max_tokens: int = 1000
map_temperature: float = 0.0
reduce_max_tokens: int = 2000
reduce_temperature: float = 0.0
allow_general_knowledge: bool = False
json_mode: bool = False
response_type: str = "multiple paragraphs"


_default_context_config = ContextConfig()
_default_mapreduce_config = MapReduceConfig()

Expand All @@ -60,32 +28,37 @@ class GlobalSearchToolArgs(BaseModel):
query: str = Field(..., description="The user query to perform global search on.")


class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, str]):
class GlobalSearchToolReturn(BaseModel):
    """Result of a global search tool invocation."""

    # The search engine's textual answer to the query.
    answer: str


class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
def __init__(
self,
openai_client: OpenAIChatCompletionClient,
token_encoder: tiktoken.Encoding,
llm: BaseLLM,
data_config: DataConfig,
context_config: ContextConfig = _default_context_config,
mapreduce_config: MapReduceConfig = _default_mapreduce_config,
):
super().__init__(
args_type=GlobalSearchToolArgs,
return_type=str,
return_type=GlobalSearchToolReturn,
name="global_search_tool",
description="Perform a global search with given parameters using graphrag.",
)
# We use the adapter here
self._llm_adapter = GraphragOpenAiModelAdapter(openai_client)

# Set up credentials and LLM
model_name = self._llm_adapter._client._raw_config["model"]
token_encoder = tiktoken.encoding_for_model(model_name)
# Use the provided LLM
self._llm = llm

# Load parquet files
community_df = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet")
entity_df = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet")
report_df = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_report_table}.parquet")
entity_embedding_df = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet")
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore
report_df: pd.DataFrame = pd.read_parquet( # type: ignore
f"{data_config.input_dir}/{data_config.community_report_table}.parquet"
)
entity_embedding_df: pd.DataFrame = pd.read_parquet( # type: ignore
f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet"
)

communities = read_indexer_communities(community_df, entity_df, report_df)
reports = read_indexer_reports(report_df, entity_df, data_config.community_level)
Expand Down Expand Up @@ -123,7 +96,7 @@ def __init__(
}

self._search_engine = GlobalSearch(
llm=self._llm_adapter,
llm=self._llm,
context_builder=context_builder,
token_encoder=token_encoder,
max_data_tokens=context_config.max_data_tokens,
Expand All @@ -136,6 +109,37 @@ def __init__(
response_type=mapreduce_config.response_type,
)

async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> str:
async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
result = await self._search_engine.asearch(args.query)
return result.response
assert isinstance(result.response, str), "Expected response to be a string"
return GlobalSearchToolReturn(answer=result.response)

@classmethod
def from_config(
cls,
openai_client: AzureOpenAIChatCompletionClient,
data_config: DataConfig,
context_config: ContextConfig = _default_context_config,
mapreduce_config: MapReduceConfig = _default_mapreduce_config,
) -> "GlobalSearchTool":
"""Create a GlobalSearchTool instance from configuration.

Args:
openai_client: The Azure OpenAI client to use
data_config: Configuration for data sources
context_config: Configuration for context building
mapreduce_config: Configuration for map-reduce operations

Returns:
An initialized GlobalSearchTool instance
"""
llm_adapter = GraphragOpenAiModelAdapter(openai_client)
token_encoder = tiktoken.encoding_for_model(llm_adapter.model_name)

return cls(
token_encoder=token_encoder,
llm=llm_adapter,
data_config=data_config,
context_config=context_config,
mapreduce_config=mapreduce_config,
)
Loading