Skip to content

Commit

Permalink
Feature/make parsing chunking providers (#820)
Browse files Browse the repository at this point in the history
* moving towards pipeline options

* move to parse / chunk providers

* make run serve work.

* refactored ingestion

* fix import issues

* cleanup
  • Loading branch information
emrgnt-cmplxty authored Aug 1, 2024
1 parent 95c0472 commit 7d9b756
Show file tree
Hide file tree
Showing 27 changed files with 1,205 additions and 551 deletions.
613 changes: 553 additions & 60 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ bcrypt = "^4.1.3"
pyjwt = "^2.8.0"
toml = "^0.10.2"
pyyaml = "^6.0.1"
unstructured = "^0.15.0"

[tool.poetry.extras]
all = ["moviepy", "opencv-python"]
ingest-movies = ["moviepy", "opencv-python"]


[tool.poetry.group.dev.dependencies]
black = "^24.3.0"
codecov = "^2.1.13"
Expand Down
27 changes: 15 additions & 12 deletions r2r.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ require_email_verification = false
default_admin_email = "[email protected]"
default_admin_password = "change_me_immediately"

[chunking]
provider = "unstructured"
method = "by_title"
chunk_size = 512
chunk_overlap = 50
max_chunk_size = 1024

[completion]
provider = "litellm"
concurrent_request_limit = 16
Expand Down Expand Up @@ -37,18 +44,6 @@ concurrent_request_limit = 256
[eval]
provider = "None"

[ingestion]
excluded_parsers = [ "mp4" ]

[[ingestion.override_parsers]]
document_type = "pdf"
parser = "PDFParser"

[ingestion.text_splitter]
type = "recursive_character"
chunk_size = 512
chunk_overlap = 20

[kg]
provider = "None"

Expand All @@ -57,5 +52,13 @@ provider = "local"
log_table = "logs"
log_info_table = "log_info"

[parsing]
provider = "r2r"
excluded_parsers = ["mp4"]

[[parsing.override_parsers]]
document_type = "pdf"
parser = "PDFParser"

[prompt]
provider = "r2r"
6 changes: 6 additions & 0 deletions r2r/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from .pipeline.base_pipeline import AsyncPipeline
from .pipes.base_pipe import AsyncPipe, AsyncState, PipeType
from .providers.auth import AuthConfig, AuthProvider
from .providers.chunking import ChunkingConfig, ChunkingProvider
from .providers.crypto import CryptoConfig, CryptoProvider
from .providers.database import (
DatabaseConfig,
Expand All @@ -75,6 +76,7 @@
from .providers.eval import EvalConfig, EvalProvider
from .providers.kg import KGConfig, KGProvider, update_kg_prompt
from .providers.llm import CompletionConfig, CompletionProvider
from .providers.parsing import ParsingConfig, ParsingProvider
from .providers.prompt import PromptConfig, PromptProvider
from .utils import (
EntityType,
Expand Down Expand Up @@ -154,6 +156,10 @@
# Pipelines
"AsyncPipeline",
# Providers
"ParsingConfig",
"ParsingProvider",
"ChunkingConfig",
"ChunkingProvider",
"EmbeddingConfig",
"EmbeddingProvider",
"EvalConfig",
Expand Down
41 changes: 41 additions & 0 deletions r2r/base/providers/chunking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, AsyncGenerator, List, Optional

from pydantic import BaseModel, Field

from ..abstractions.document import Document, DocumentType
from .base import Provider, ProviderConfig


class Method(str, Enum):
    """Names of the chunking strategies accepted by ``ChunkingConfig.method``.

    Because the enum also derives from ``str``, every member compares equal
    to its raw string value, so a plain string such as ``"by_title"`` read
    from a config file maps directly onto a member.
    """

    # The values below are the exact spellings expected in configuration.
    BY_TITLE = "by_title"
    BASIC = "basic"
    RECURSIVE = "recursive"


class ChunkingConfig(ProviderConfig):
    """Settings that select and tune a chunking provider.

    Raises:
        ValueError: from ``validate`` when ``provider`` is not one of
            ``supported_providers``.
    """

    # ``None`` is listed in ``supported_providers`` (chunking disabled), so
    # the field has to admit it; a bare ``str`` annotation would make the
    # model reject a value that ``validate`` explicitly allows.
    provider: Optional[str] = "r2r"
    # Chunking strategy; see ``Method`` for the accepted names.
    method: Method = Method.RECURSIVE
    # Target chunk size. NOTE(review): unit (characters vs. tokens) is
    # decided by the concrete provider -- confirm before relying on it.
    chunk_size: int = 512
    # Amount of content shared between consecutive chunks (same unit as
    # ``chunk_size``, presumably -- TODO confirm).
    chunk_overlap: int = 0
    # Optional hard upper bound on a chunk; ``None`` leaves it unset.
    max_chunk_size: Optional[int] = None

    def validate(self) -> None:
        """Check that the configured provider is a supported one.

        Raises:
            ValueError: if ``self.provider`` is not in
                ``self.supported_providers``.
        """
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider {self.provider} is not supported.")

    @property
    def supported_providers(self) -> list[Optional[str]]:
        """Return the provider names this config accepts, including ``None``."""
        return ["r2r", "unstructured", None]


class ChunkingProvider(Provider, ABC):
    """Abstract base class for providers that split parsed text into chunks."""

    def __init__(self, config: ChunkingConfig) -> None:
        """Initialise the base ``Provider`` and keep the chunking config.

        Args:
            config: Chunking settings; also stored on ``self.config``.
        """
        super().__init__(config)
        # Re-assigned after the base initialiser so the attribute is always
        # the ChunkingConfig instance that was passed in.
        self.config = config

    # NOTE(review): this stub contains no ``yield``, so at runtime it is a
    # coroutine function even though the annotation describes an async
    # generator -- confirm that concrete providers implement it with
    # ``yield`` and are consumed via ``async for``.
    @abstractmethod
    async def chunk(self, parsed_document: str) -> AsyncGenerator[str, None]:
        """Chunk the parsed document using the configured chunking strategy.

        Args:
            parsed_document: Text of an already parsed document.

        Returns:
            An async generator of chunk strings, per the annotation.
        """
        ...
42 changes: 42 additions & 0 deletions r2r/base/providers/parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Optional

from pydantic import BaseModel, Field

from ..abstractions.document import Document, DocumentType
from .base import Provider, ProviderConfig


class OverrideParser(BaseModel):
    """A single parser override, pairing a document type with a parser name.

    Mirrors one ``[[parsing.override_parsers]]`` table in the TOML config,
    e.g. ``document_type = "pdf"`` with ``parser = "PDFParser"``.
    """

    # Document type the override applies to.
    document_type: DocumentType
    # Name of the parser to use for that document type.
    parser: str


class ParsingConfig(ProviderConfig):
    """Settings that select a parsing provider and adjust its parser choices.

    Raises:
        ValueError: from ``validate`` when ``provider`` is not one of
            ``supported_providers``.
    """

    # ``None`` is listed in ``supported_providers``, so the field has to
    # admit it; a bare ``str`` annotation would make the model reject a
    # value that ``validate`` explicitly allows.
    provider: Optional[str] = "r2r"
    # Document types listed under ``excluded_parsers`` in the config
    # (e.g. ``["mp4"]``); empty by default.
    excluded_parsers: List[DocumentType] = Field(default_factory=list)
    # Per-document-type parser overrides; empty by default.
    override_parsers: List[OverrideParser] = Field(default_factory=list)

    @property
    def supported_providers(self) -> list[Optional[str]]:
        """Return the provider names this config accepts, including ``None``."""
        return ["r2r", "unstructured", None]

    def validate(self) -> None:
        """Check that the configured provider is a supported one.

        Raises:
            ValueError: if ``self.provider`` is not in
                ``self.supported_providers``.
        """
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider {self.provider} is not supported.")


class ParsingProvider(Provider, ABC):
    """Abstract base class for providers that parse documents."""

    def __init__(self, config: ParsingConfig) -> None:
        """Initialise the base ``Provider`` and keep the parsing config.

        Args:
            config: Parsing settings; also stored on ``self.config``.
        """
        super().__init__(config)
        # Re-assigned after the base initialiser so the attribute is always
        # the ParsingConfig instance that was passed in.
        self.config = config

    # NOTE(review): this stub contains no ``yield``, so at runtime it is a
    # coroutine function even though the annotation describes an async
    # generator -- confirm that concrete providers implement it with
    # ``yield`` and are consumed via ``async for``.
    @abstractmethod
    async def parse(self, document: Document) -> AsyncGenerator[Any, None]:
        """Parse the document using the configured parsing strategy.

        Args:
            document: The document to parse.

        Returns:
            An async generator of parsed items, per the annotation.
        """
        ...

    @abstractmethod
    def get_parser_for_document_type(self, doc_type: DocumentType) -> str:
        """Get the appropriate parser for a given document type.

        Args:
            doc_type: Document type to look up.

        Returns:
            The parser as a string, per the annotation.
        """
        ...
2 changes: 1 addition & 1 deletion r2r/examples/configs/local_llm_neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ kg_extraction_prompt = "zero_shot_ner_kg_extraction"
stream = false
add_generation_kwargs = { }

[vector_database]
[database]
provider = "pgvector"
7 changes: 6 additions & 1 deletion r2r/main/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from r2r.base import (
AsyncPipe,
AuthProvider,
ChunkingProvider,
CompletionProvider,
DatabaseProvider,
EmbeddingProvider,
EvalProvider,
KGProvider,
ParsingProvider,
PromptProvider,
)
from r2r.pipelines import (
Expand All @@ -23,19 +25,22 @@

class R2RProviders(BaseModel):
    """Container bundling the provider instances used by the application.

    Every slot is annotated ``Optional``, so any individual provider may be
    ``None``.
    """

    auth: Optional[AuthProvider]
    chunking: Optional[ChunkingProvider]
    # Declared exactly once: a second, identical ``llm`` annotation further
    # down was redundant and has been dropped. Field order is unaffected
    # because the first declaration fixes the position.
    llm: Optional[CompletionProvider]
    database: Optional[DatabaseProvider]
    embedding: Optional[EmbeddingProvider]
    prompt: Optional[PromptProvider]
    eval: Optional[EvalProvider]
    kg: Optional[KGProvider]
    parsing: Optional[ParsingProvider]

    class Config:
        # Providers are ordinary classes rather than pydantic models, so the
        # model must be told to accept them as field types.
        arbitrary_types_allowed = True


class R2RPipes(BaseModel):
parsing_pipe: Optional[AsyncPipe]
chunking_pipe: Optional[AsyncPipe]
embedding_pipe: Optional[AsyncPipe]
vector_storage_pipe: Optional[AsyncPipe]
vector_search_pipe: Optional[AsyncPipe]
Expand Down
20 changes: 12 additions & 8 deletions r2r/main/assembly/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
from ...base.abstractions.llm import GenerationConfig
from ...base.logging.kv_logger import LoggingConfig
from ...base.providers.auth import AuthConfig
from ...base.providers.chunking import ChunkingConfig
from ...base.providers.crypto import CryptoConfig
from ...base.providers.database import DatabaseConfig, ProviderConfig
from ...base.providers.embedding import EmbeddingConfig
from ...base.providers.eval import EvalConfig
from ...base.providers.kg import KGConfig
from ...base.providers.llm import CompletionConfig
from ...base.providers.parsing import ParsingConfig
from ...base.providers.prompt import PromptConfig

logger = logging.getLogger(__name__)
Expand All @@ -37,20 +39,24 @@ class R2RConfig:
"batch_size",
"kg_extraction_config",
],
"ingestion": ["excluded_parsers", "text_splitter"],
"parsing": ["provider", "excluded_parsers"],
"chunking": ["provider", "method"],
"completion": ["provider"],
"logging": ["provider", "log_table"],
"prompt": ["provider"],
"database": ["provider"],
}
auth: AuthConfig
chunking: ChunkingConfig
completion: CompletionConfig
crypto: CryptoConfig
database: DatabaseConfig
embedding: EmbeddingConfig
eval: EvalConfig
kg: KGConfig
completion: CompletionConfig
logging: LoggingConfig
parsing: ParsingConfig
prompt: PromptConfig
database: DatabaseConfig

def __init__(self, config_data: dict[str, Any]):
# Load the default configuration
Expand All @@ -75,19 +81,17 @@ def __init__(self, config_data: dict[str, Any]):
self._validate_config_section(default_config, section, keys)
setattr(self, section, default_config[section])
self.auth = AuthConfig.create(**self.auth)
self.chunking = ChunkingConfig.create(**self.chunking)
self.completion = CompletionConfig.create(**self.completion)
self.crypto = CryptoConfig.create(**self.crypto)
self.database = DatabaseConfig.create(**self.database)
self.embedding = EmbeddingConfig.create(**self.embedding)
self.eval = EvalConfig.create(**self.eval, llm=None)
self.logging = LoggingConfig.create(**self.logging)
self.kg = KGConfig.create(**self.kg)
self.logging = LoggingConfig.create(**self.logging)
self.parsing = ParsingConfig.create(**self.parsing)
self.prompt = PromptConfig.create(**self.prompt)

self.ingestion = self.ingestion # for type hinting
self.ingestion["excluded_parsers"] = [
DocumentType(k) for k in self.ingestion["excluded_parsers"]
]
# override GenerationConfig defaults
GenerationConfig.set_default(
**self.completion.generation_config.dict()
Expand Down
Loading

0 comments on commit 7d9b756

Please sign in to comment.