Skip to content

Commit

Permalink
Feature/merge with dev (#916)
Browse files Browse the repository at this point in the history
* Fix CLI Tests (#912)

Fix CLI tests

* Shreyas/kg runtime cfg (#913)

add kg runtime config

* rename kgenrichmentresponse (#914)

* revert change to chunking by_title

---------

Co-authored-by: Nolan Tremelling <[email protected]>
Co-authored-by: Shreyas Pimpalgaonkar <[email protected]>
  • Loading branch information
3 people authored Aug 21, 2024
1 parent 622b52c commit c0db4ec
Show file tree
Hide file tree
Showing 29 changed files with 899 additions and 52 deletions.
2 changes: 1 addition & 1 deletion py/cli/commands/restructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def enrich_graph(client):
Perform graph enrichment over the entire graph.
"""
with timer():
response = client.restructure()
response = client.enrich_graph()

click.echo(response)
2 changes: 2 additions & 0 deletions py/cli/utils/param_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class JsonParamType(click.ParamType):
name = "json"

def convert(self, value, param, ctx) -> Dict[str, Any]:
if value is None:
return None
if isinstance(value, dict):
return value
try:
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
# Restructure abstractions
"KGEnrichmentSettings",
# User abstractions
"Token",
"TokenData",
Expand Down
3 changes: 3 additions & 0 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
VectorSearchResult,
VectorSearchSettings,
)
from .restructure import KGEnrichmentSettings
from .user import Token, TokenData, UserStats
from .vector import Vector, VectorEntry, VectorType

Expand Down Expand Up @@ -81,6 +82,8 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
# Restructure abstractions
"KGEnrichmentSettings",
# User abstractions
"Token",
"TokenData",
Expand Down
23 changes: 23 additions & 0 deletions py/core/base/abstractions/restructure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Optional

from pydantic import BaseModel, Field

from .llm import GenerationConfig


class KGEnrichmentSettings(BaseModel):
    """Settings for knowledge graph enrichment."""

    # NOTE: fields are documented with `#` comments; the class docstring is
    # left as-is because pydantic surfaces it as the model's schema description.

    # Upper bound on triples extracted per chunk. The KG extraction pipe
    # forwards this value as the `max_knowledge_triples` prompt input.
    max_knowledge_triples: int = Field(
        default=100,
        description="The maximum number of knowledge triples to extract from each chunk.",
    )

    # Passed as `generation_config` to the LLM completion call made by the
    # KG clustering pipe when it generates a community description.
    generation_config: GenerationConfig = Field(
        default_factory=GenerationConfig,
        description="Configuration for text generation during graph enrichment.",
    )
    # Unpacked as keyword arguments into graspologic's `hierarchical_leiden`
    # by the clustering pipe. With the default empty dict nothing is forwarded
    # (no `max_cluster_size`, no `random_seed`), so the library defaults apply.
    leiden_params: dict = Field(
        default_factory=dict,
        description="Parameters for the Leiden algorithm.",
    )
8 changes: 4 additions & 4 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
WrappedUserOverviewResponse,
)
from .restructure.responses import (
KGEnrichementResponse,
WrappedKGEnrichementResponse,
KGEnrichmentResponse,
WrappedKGEnrichmentResponse,
)
from .retrieval.responses import (
RAGAgentResponse,
Expand All @@ -57,8 +57,8 @@
"IngestionResponse",
"WrappedIngestionResponse",
# Restructure Responses
"KGEnrichementResponse",
"WrappedKGEnrichementResponse",
"KGEnrichmentResponse",
"WrappedKGEnrichmentResponse",
# Management Responses
"PromptResponse",
"ServerStats",
Expand Down
4 changes: 2 additions & 2 deletions py/core/base/api/models/restructure/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from pydantic import BaseModel


class KGEnrichementResponse(BaseModel):
class KGEnrichmentResponse(BaseModel):
enriched_content: Dict[str, Any]


WrappedKGEnrichementResponse = ResultsWrapper[KGEnrichementResponse]
WrappedKGEnrichmentResponse = ResultsWrapper[KGEnrichmentResponse]
4 changes: 3 additions & 1 deletion py/core/base/providers/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...base.utils.base_utils import RelationshipType
from ..abstractions.graph import Entity, KGExtraction, Triple
from ..abstractions.llm import GenerationConfig
from ..abstractions.restructure import KGEnrichmentSettings
from .base import ProviderConfig

logger = logging.getLogger(__name__)
Expand All @@ -20,8 +21,9 @@ class KGConfig(ProviderConfig):
kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction"
kg_search_prompt: Optional[str] = "kg_search"
kg_extraction_config: Optional[GenerationConfig] = None
kg_search_config: Optional[GenerationConfig] = None
kg_store_path: Optional[str] = None
max_knowledge_triples: Optional[int] = 100
kg_enrichment_settings: Optional[KGEnrichmentSettings] = KGEnrichmentSettings()

def validate(self) -> None:
if self.provider not in self.supported_providers:
Expand Down
20 changes: 14 additions & 6 deletions py/core/configs/neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,20 @@ kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot"


[kg.kg_extraction_config]
model = "gpt-4o-mini"
temperature = 1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }
model = "gpt-4o-mini"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[kg.kg_search_config]
model = "gpt-4o-mini"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[database]
provider = "postgres"
Expand Down
9 changes: 7 additions & 2 deletions py/core/main/api/routes/restructure/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from core.base import KGEnrichmentSettings
from core.main.api.routes.base_router import BaseRouter
from core.main.engine import R2REngine
from fastapi import Depends

from typing import Union
from fastapi import Body, Depends

class RestructureRouter(BaseRouter):
def __init__(self, engine: R2REngine):
Expand All @@ -12,6 +13,10 @@ def setup_routes(self):
@self.router.post("/enrich_graph")
@self.base_endpoint
async def enrich_graph(
KGEnrichmentSettings: Union[dict, KGEnrichmentSettings] = Body(
...,
description="Settings for knowledge graph enrichment",
),
auth_user=(
Depends(self.engine.providers.auth.auth_wrapper)
if self.engine.providers.auth
Expand Down
1 change: 1 addition & 0 deletions py/core/main/assembly/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self.provider_factory_override: Optional[Type[R2RProviderFactory]] = (
None
)

self.pipe_factory_override: Optional[R2RPipeFactory] = None
self.pipeline_factory_override: Optional[R2RPipelineFactory] = None

Expand Down
5 changes: 4 additions & 1 deletion py/core/main/services/restructure_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List

from core.base import R2RException, RunLoggingSingleton, RunManager
from core.base.abstractions import KGEnrichmentSettings

from ..abstractions import R2RAgents, R2RPipelines, R2RProviders
from ..assembly.config import R2RConfig
Expand Down Expand Up @@ -30,7 +31,9 @@ def __init__(
logging_connection,
)

async def enrich_graph(self) -> Dict[str, Any]:
async def enrich_graph(
self, enrich_graph_settings: KGEnrichmentSettings = KGEnrichmentSettings()
) -> Dict[str, Any]:
"""
Perform graph enrichment.
Expand Down
32 changes: 16 additions & 16 deletions py/core/pipes/kg/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PromptProvider,
RunLoggingSingleton,
Triple,
KGEnrichmentSettings,
)

logger = logging.getLogger(__name__)
Expand All @@ -38,9 +39,6 @@ def __init__(
llm_provider: CompletionProvider,
prompt_provider: PromptProvider,
embedding_provider: EmbeddingProvider,
cluster_batch_size: int = 100,
max_cluster_size: int = 10,
use_lcc: bool = True,
pipe_logger: Optional[RunLoggingSingleton] = None,
type: PipeType = PipeType.OTHER,
config: Optional[AsyncPipe.PipeConfig] = None,
Expand All @@ -57,23 +55,20 @@ def __init__(
)
self.kg_provider = kg_provider
self.llm_provider = llm_provider
self.cluster_batch_size = cluster_batch_size
self.max_cluster_size = max_cluster_size
self.use_lcc = use_lcc
self.prompt_provider = prompt_provider
self.embedding_provider = embedding_provider

def _compute_leiden_communities(
self,
graph: nx.Graph,
seed: int = 0xDEADBEEF,
settings: KGEnrichmentSettings,
) -> dict[int, dict[str, int]]:
"""Compute Leiden communities."""
try:
from graspologic.partition import hierarchical_leiden

community_mapping = hierarchical_leiden(
graph, max_cluster_size=self.max_cluster_size, random_seed=seed
graph, **settings.leiden_params
)
results: dict[int, dict[str, int]] = {}
for partition in community_mapping:
Expand All @@ -84,7 +79,9 @@ def _compute_leiden_communities(
except ImportError as e:
raise ImportError("Please install the graspologic package.") from e

async def cluster_kg(self, triples: list[Triple]) -> list[Community]:
async def cluster_kg(
self, triples: list[Triple], settings: KGEnrichmentSettings = KGEnrichmentSettings()
) -> list[Community]:
"""
Clusters the knowledge graph triples into communities using hierarchical Leiden algorithm.
"""
Expand All @@ -100,7 +97,9 @@ async def cluster_kg(self, triples: list[Triple]) -> list[Community]:
id=f"{triple.subject}->{triple.predicate}->{triple.object}",
)

hierarchical_communities = self._compute_leiden_communities(G)
hierarchical_communities = self._compute_leiden_communities(
G, settings=settings
)

community_details = {}

Expand Down Expand Up @@ -172,9 +171,7 @@ async def process_community(community_key, community):
"input_text": input_text,
},
),
generation_config=GenerationConfig(
model="gpt-4o-mini",
),
generation_config=settings.generation_config,
)

description = description.choices[0].message.content
Expand Down Expand Up @@ -202,8 +199,11 @@ async def process_community(community_key, community):
)
)

for completed_task in asyncio.as_completed(tasks):
yield await completed_task
total_tasks = len(tasks)
for i, completed_task in enumerate(asyncio.as_completed(tasks), 1):
result = await completed_task
logger.info(f"Progress: {i}/{total_tasks} communities completed ({i/total_tasks*100:.2f}%)")
yield result

async def _run_logic(
self,
Expand All @@ -230,5 +230,5 @@ async def _run_logic(

triples = self.kg_provider.get_triples()

async for community in self.cluster_kg(triples):
async for community in self.cluster_kg(triples, self.kg_provider.config.kg_enrichment_settings):
yield community
2 changes: 1 addition & 1 deletion py/core/pipes/kg/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def extract_kg(

task_inputs = {"input": fragment.data}
task_inputs["max_knowledge_triples"] = (
self.kg_provider.config.max_knowledge_triples
self.kg_provider.config.kg_enrichment_settings.max_knowledge_triples
)

messages = self.prompt_provider._get_message_payload(
Expand Down
8 changes: 4 additions & 4 deletions py/core/providers/chunking/r2r_chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def __init__(self, config: ChunkingConfig):
)

def _initialize_text_splitter(self) -> TextSplitter:
if self.config.method == Method.RECURSIVE:
logger.info(
f"Initializing text splitter with method: {self.config.method}"
) # Debug log
if self.config.method == Method.RECURSIVE or self.config.method == Method.BASIC:
return RecursiveCharacterTextSplitter(
chunk_size=self.config.chunk_size,
chunk_overlap=self.config.chunk_overlap,
)
elif self.config.method == Method.BASIC:
# Implement basic method
raise NotImplementedError("Basic method not implemented yet")
elif self.config.method == Method.BY_TITLE:
# Implement by_title method
raise NotImplementedError("By_title method not implemented yet")
Expand Down
19 changes: 11 additions & 8 deletions py/core/providers/parsing/unstructured_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@

class UnstructuredParsingProvider(ParsingProvider):
def __init__(self, use_api, config):
try:
from unstructured.partition.auto import partition

self.partition = partition
except ImportError as e:
raise ImportError(
"Please install the unstructured package to use the unstructured parsing provider."
) from e
if config.excluded_parsers:
logger.warning(
"Excluded parsers are not supported by the unstructured parsing provider."
Expand Down Expand Up @@ -57,6 +49,17 @@ def __init__(self, use_api, config):
self.operations = operations
self.dict_to_elements = dict_to_elements

else:
try:
from unstructured.partition.auto import partition

self.partition = partition

except ImportError:
raise ImportError(
"Please install the unstructured package to use the unstructured parsing provider."
)

super().__init__(config)

async def parse(
Expand Down
16 changes: 16 additions & 0 deletions py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,23 @@ def model_dump(self, *args, **kwargs):
str(uuid) for uuid in dump["selected_group_ids"]
]
return dump

class KGEnrichmentSettings(BaseModel):
    # SDK-side settings model for knowledge graph enrichment. It declares the
    # same three field names as the server-side
    # `core.base.abstractions.restructure.KGEnrichmentSettings`.
    # NOTE(review): the two classes are defined independently — presumably
    # they have to be kept in sync by hand; confirm.

    # Maximum number of knowledge triples to extract from each chunk.
    max_knowledge_triples: int = Field(
        default=100,
        description="The maximum number of knowledge triples to extract from each chunk.",
    )
    # Generation settings used for KG enrichment; a fresh `GenerationConfig`
    # is built per instance when none is supplied.
    generation_config: GenerationConfig = Field(
        default_factory=GenerationConfig,
        description="The generation configuration for the KG enrichment.",
    )
    # Free-form parameters for the Leiden algorithm; defaults to a new empty
    # dict per instance.
    leiden_params: dict = Field(
        default_factory=dict,
        description="The parameters for the Leiden algorithm.",
    )

class KGEnrichmentResponse(BaseModel):
    # SDK-side response model for graph enrichment. It declares the same single
    # field as `core.base.api.models.restructure.responses.KGEnrichmentResponse`.

    # String-keyed payload with arbitrary values; its inner schema is not
    # defined in this model.
    enriched_content: Dict[str, Any]

class UserResponse(BaseModel):
id: UUID
Expand Down
Loading

0 comments on commit c0db4ec

Please sign in to comment.