Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/make parsing chunking providers #820

Merged
merged 6 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
613 changes: 553 additions & 60 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ bcrypt = "^4.1.3"
pyjwt = "^2.8.0"
toml = "^0.10.2"
pyyaml = "^6.0.1"
unstructured = "^0.15.0"

[tool.poetry.extras]
all = ["moviepy", "opencv-python"]
ingest-movies = ["moviepy", "opencv-python"]


[tool.poetry.group.dev.dependencies]
black = "^24.3.0"
codecov = "^2.1.13"
Expand Down
27 changes: 15 additions & 12 deletions r2r.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ require_email_verification = false
default_admin_email = "admin@example.com"
default_admin_password = "change_me_immediately"

# Chunking settings. These keys correspond to the fields of ChunkingConfig
# (r2r/base/providers/chunking.py).
[chunking]
# Chunking backend; ChunkingConfig lists "r2r" and "unstructured" as supported.
provider = "unstructured"
# Chunking strategy; ChunkingConfig's Method enum defines "by_title", "basic"
# and "recursive".
method = "by_title"
# Target chunk size.
# NOTE(review): the unit (characters vs. tokens) depends on the provider - confirm.
chunk_size = 512
# Overlap between consecutive chunks (presumably in the same unit as chunk_size).
chunk_overlap = 50
# Optional upper bound on chunk size; ChunkingConfig defaults this to None when omitted.
max_chunk_size = 1024

[completion]
provider = "litellm"
concurrent_request_limit = 16
Expand Down Expand Up @@ -37,18 +44,6 @@ concurrent_request_limit = 256
[eval]
provider = "None"

[ingestion]
excluded_parsers = [ "mp4" ]

[[ingestion.override_parsers]]
document_type = "pdf"
parser = "PDFParser"

[ingestion.text_splitter]
type = "recursive_character"
chunk_size = 512
chunk_overlap = 20

[kg]
provider = "None"

Expand All @@ -57,5 +52,13 @@ provider = "local"
log_table = "logs"
log_info_table = "log_info"

# Parsing settings. These keys correspond to the fields of ParsingConfig
# (r2r/base/providers/parsing.py).
[parsing]
# Parsing backend; ParsingConfig lists "r2r" and "unstructured" as supported.
provider = "r2r"
# Document types listed here are read into ParsingConfig.excluded_parsers.
excluded_parsers = ["mp4"]

# Each entry maps one document type to a named parser (see OverrideParser).
[[parsing.override_parsers]]
document_type = "pdf"
parser = "PDFParser"

[prompt]
provider = "r2r"
6 changes: 6 additions & 0 deletions r2r/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from .pipeline.base_pipeline import AsyncPipeline
from .pipes.base_pipe import AsyncPipe, AsyncState, PipeType
from .providers.auth import AuthConfig, AuthProvider
from .providers.chunking import ChunkingConfig, ChunkingProvider
from .providers.crypto import CryptoConfig, CryptoProvider
from .providers.database import (
DatabaseConfig,
Expand All @@ -75,6 +76,7 @@
from .providers.eval import EvalConfig, EvalProvider
from .providers.kg import KGConfig, KGProvider, update_kg_prompt
from .providers.llm import CompletionConfig, CompletionProvider
from .providers.parsing import ParsingConfig, ParsingProvider
from .providers.prompt import PromptConfig, PromptProvider
from .utils import (
EntityType,
Expand Down Expand Up @@ -154,6 +156,10 @@
# Pipelines
"AsyncPipeline",
# Providers
"ParsingConfig",
"ParsingProvider",
"ChunkingConfig",
"ChunkingProvider",
"EmbeddingConfig",
"EmbeddingProvider",
"EvalConfig",
Expand Down
41 changes: 41 additions & 0 deletions r2r/base/providers/chunking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, AsyncGenerator, List, Optional

from pydantic import BaseModel, Field

from ..abstractions.document import Document, DocumentType
from .base import Provider, ProviderConfig


class Method(str, Enum):
    """Chunking strategies selectable through ``ChunkingConfig.method``.

    The ``str`` mix-in makes every member compare equal to its raw string
    value, so a plain string read from configuration (for example
    ``"by_title"``) can be coerced with ``Method("by_title")``.
    """

    BY_TITLE = "by_title"
    BASIC = "basic"
    RECURSIVE = "recursive"


class ChunkingConfig(ProviderConfig):
    """Configuration for a chunking provider.

    Raises:
        ValueError: from ``validate`` when ``provider`` is not one of
            ``supported_providers``.
    """

    # Name of the chunking backend; checked against ``supported_providers``.
    provider: str = "r2r"
    # Chunking strategy; see ``Method`` for the accepted values.
    method: Method = Method.RECURSIVE
    # Target chunk size.
    # NOTE(review): the unit (characters vs. tokens) is decided by the
    # provider implementation - confirm.
    chunk_size: int = 512
    # Overlap between consecutive chunks (presumably same unit as chunk_size).
    chunk_overlap: int = 0
    # Optional upper bound on chunk size; ``None`` leaves it unset.
    max_chunk_size: Optional[int] = None

    def validate(self) -> None:
        """Raise ``ValueError`` if ``provider`` is not a supported provider."""
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider {self.provider} is not supported.")

    @property
    def supported_providers(self) -> list[Optional[str]]:
        """Return the provider names accepted by ``validate``.

        ``None`` is part of the list, hence the ``Optional`` element type.
        """
        return ["r2r", "unstructured", None]


class ChunkingProvider(Provider, ABC):
    """Abstract base class for providers that split parsed text into chunks."""

    def __init__(self, config: ChunkingConfig):
        """Pass ``config`` to the base ``Provider`` and keep it on ``self.config``."""
        super().__init__(config)
        self.config = config

    @abstractmethod
    async def chunk(self, parsed_document: str) -> AsyncGenerator[str, None]:
        """Chunk the parsed document using the configured chunking strategy.

        Args:
            parsed_document: Text of the already-parsed document.

        Yields:
            The chunks produced from ``parsed_document``, one string at a time.
        """
        # An ``async def`` without ``yield`` is a coroutine function, which
        # contradicts the ``AsyncGenerator`` annotation and makes type checkers
        # reject async-generator overrides. The unreachable ``yield`` turns
        # this abstract stub into an async generator that yields nothing.
        if False:
            yield ""
42 changes: 42 additions & 0 deletions r2r/base/providers/parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Optional

from pydantic import BaseModel, Field

from ..abstractions.document import Document, DocumentType
from .base import Provider, ProviderConfig


class OverrideParser(BaseModel):
    """One ``override_parsers`` entry: the parser to use for a document type."""

    # Document type the override applies to.
    document_type: DocumentType
    # Name of the parser to use for that document type.
    parser: str


class ParsingConfig(ProviderConfig):
    """Configuration for a parsing provider.

    Raises:
        ValueError: from ``validate`` when ``provider`` is not one of
            ``supported_providers``.
    """

    # Name of the parsing backend; checked against ``supported_providers``.
    provider: str = "r2r"
    # Document types listed under ``excluded_parsers``; empty by default.
    excluded_parsers: List[DocumentType] = Field(default_factory=list)
    # Per-document-type parser overrides; empty by default.
    override_parsers: List[OverrideParser] = Field(default_factory=list)

    @property
    def supported_providers(self) -> list[Optional[str]]:
        """Return the provider names accepted by ``validate``.

        ``None`` is part of the list, hence the ``Optional`` element type.
        """
        return ["r2r", "unstructured", None]

    def validate(self) -> None:
        """Raise ``ValueError`` if ``provider`` is not a supported provider."""
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider {self.provider} is not supported.")


class ParsingProvider(Provider, ABC):
    """Abstract base class for providers that parse documents."""

    def __init__(self, config: ParsingConfig):
        """Pass ``config`` to the base ``Provider`` and keep it on ``self.config``."""
        super().__init__(config)
        self.config = config

    @abstractmethod
    async def parse(self, document: Document) -> AsyncGenerator[Any, None]:
        """Parse the document using the configured parsing strategy.

        Args:
            document: The document to parse.

        Yields:
            The pieces extracted from ``document``; their type is left to the
            implementation (annotated as ``Any``).
        """
        # An ``async def`` without ``yield`` is a coroutine function, which
        # contradicts the ``AsyncGenerator`` annotation and makes type checkers
        # reject async-generator overrides. The unreachable ``yield`` turns
        # this abstract stub into an async generator that yields nothing.
        if False:
            yield None

    @abstractmethod
    def get_parser_for_document_type(self, doc_type: DocumentType) -> str:
        """Get the appropriate parser for a given document type.

        Args:
            doc_type: The document type to look up.

        Returns:
            The parser for ``doc_type``, as a string.
        """
2 changes: 1 addition & 1 deletion r2r/examples/configs/local_llm_neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ kg_extraction_prompt = "zero_shot_ner_kg_extraction"
stream = false
add_generation_kwargs = { }

[vector_database]
[database]
provider = "pgvector"
7 changes: 6 additions & 1 deletion r2r/main/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from r2r.base import (
AsyncPipe,
AuthProvider,
ChunkingProvider,
CompletionProvider,
DatabaseProvider,
EmbeddingProvider,
EvalProvider,
KGProvider,
ParsingProvider,
PromptProvider,
)
from r2r.pipelines import (
Expand All @@ -23,19 +25,22 @@

class R2RProviders(BaseModel):
    """Bundle of the provider instances, one field per provider kind.

    Every field is annotated ``Optional``. The ``llm`` field is declared once,
    in the position of its first declaration, so field order is unchanged.
    """

    auth: Optional[AuthProvider]
    chunking: Optional[ChunkingProvider]
    llm: Optional[CompletionProvider]
    database: Optional[DatabaseProvider]
    embedding: Optional[EmbeddingProvider]
    prompt: Optional[PromptProvider]
    eval: Optional[EvalProvider]
    kg: Optional[KGProvider]
    parsing: Optional[ParsingProvider]

    class Config:
        # Lets pydantic accept the provider classes as field types
        # (presumably because they are not pydantic models - confirm).
        arbitrary_types_allowed = True


class R2RPipes(BaseModel):
parsing_pipe: Optional[AsyncPipe]
chunking_pipe: Optional[AsyncPipe]
embedding_pipe: Optional[AsyncPipe]
vector_storage_pipe: Optional[AsyncPipe]
vector_search_pipe: Optional[AsyncPipe]
Expand Down
20 changes: 12 additions & 8 deletions r2r/main/assembly/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
from ...base.abstractions.llm import GenerationConfig
from ...base.logging.kv_logger import LoggingConfig
from ...base.providers.auth import AuthConfig
from ...base.providers.chunking import ChunkingConfig
from ...base.providers.crypto import CryptoConfig
from ...base.providers.database import DatabaseConfig, ProviderConfig
from ...base.providers.embedding import EmbeddingConfig
from ...base.providers.eval import EvalConfig
from ...base.providers.kg import KGConfig
from ...base.providers.llm import CompletionConfig
from ...base.providers.parsing import ParsingConfig
from ...base.providers.prompt import PromptConfig

logger = logging.getLogger(__name__)
Expand All @@ -37,20 +39,24 @@ class R2RConfig:
"batch_size",
"kg_extraction_config",
],
"ingestion": ["excluded_parsers", "text_splitter"],
"parsing": ["provider", "excluded_parsers"],
"chunking": ["provider", "method"],
"completion": ["provider"],
"logging": ["provider", "log_table"],
"prompt": ["provider"],
"database": ["provider"],
}
auth: AuthConfig
chunking: ChunkingConfig
completion: CompletionConfig
crypto: CryptoConfig
database: DatabaseConfig
embedding: EmbeddingConfig
eval: EvalConfig
kg: KGConfig
completion: CompletionConfig
logging: LoggingConfig
parsing: ParsingConfig
prompt: PromptConfig
database: DatabaseConfig

def __init__(self, config_data: dict[str, Any]):
# Load the default configuration
Expand All @@ -75,19 +81,17 @@ def __init__(self, config_data: dict[str, Any]):
self._validate_config_section(default_config, section, keys)
setattr(self, section, default_config[section])
self.auth = AuthConfig.create(**self.auth)
self.chunking = ChunkingConfig.create(**self.chunking)
self.completion = CompletionConfig.create(**self.completion)
self.crypto = CryptoConfig.create(**self.crypto)
self.database = DatabaseConfig.create(**self.database)
self.embedding = EmbeddingConfig.create(**self.embedding)
self.eval = EvalConfig.create(**self.eval, llm=None)
self.logging = LoggingConfig.create(**self.logging)
self.kg = KGConfig.create(**self.kg)
self.logging = LoggingConfig.create(**self.logging)
self.parsing = ParsingConfig.create(**self.parsing)
self.prompt = PromptConfig.create(**self.prompt)

self.ingestion = self.ingestion # for type hinting
self.ingestion["excluded_parsers"] = [
DocumentType(k) for k in self.ingestion["excluded_parsers"]
]
# override GenerationConfig defaults
GenerationConfig.set_default(
**self.completion.generation_config.dict()
Expand Down
Loading
Loading