-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create a common type for ingested tokens (#70)
- Loading branch information
1 parent
2c34760
commit 37a9c8b
Showing
3 changed files
with
75 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from typing import Union | ||
|
||
|
||
class IngestedTokens: | ||
def __init__(self, file: str, data: [str], error: str = None) -> None: | ||
self.data = data | ||
self.error = error | ||
self.file = file | ||
if self.file: | ||
file = str(file) | ||
self.extension = file.split(".")[-1] | ||
self.file_id = file.split("/")[-1].split(".")[0] | ||
|
||
def __str__(self): | ||
if self.error: | ||
return f"Error: {self.error}" | ||
return f"Data: {self.data}" | ||
|
||
def is_error(self) -> bool: | ||
return self.error is not None | ||
|
||
def get_file_path(self) -> str: | ||
return self.file | ||
|
||
def get_extension(self) -> str: | ||
return self.extension | ||
|
||
def get_file_id(self) -> str: | ||
return self.file_id | ||
|
||
@classmethod | ||
def success(cls, data: bytes) -> "IngestedTokens": | ||
return cls(data) | ||
|
||
@classmethod | ||
def error(cls, error: str) -> "IngestedTokens": | ||
return cls(None, error) | ||
|
||
def unwrap(self) -> bytes: | ||
if self.error: | ||
raise ValueError(self.error) | ||
return self.data | ||
|
||
def unwrap_or(self, default: bytes) -> bytes: | ||
return self.data if not self.error else default | ||
|
||
def __eq__(self, other: Union[bytes, "IngestedTokens"]) -> bool: | ||
if isinstance(other, IngestedTokens): | ||
return self.data == other.data and self.error == other.error | ||
return self.data == other | ||
|
||
def __hash__(self) -> int: | ||
return hash((self.data, self.error)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,22 @@ | ||
from typing import List | ||
from abc import abstractmethod | ||
from typing import AsyncGenerator, List | ||
from querent.common.types.collected_bytes import CollectedBytes | ||
from querent.common.types.ingested_tokens import IngestedTokens | ||
from querent.processors.async_processor import AsyncProcessor | ||
|
||
|
||
class BaseIngestor: | ||
def __init__(self, processors: List[AsyncProcessor]): | ||
self.processors = processors | ||
|
||
async def process_data(self, text): | ||
# Your common data processing logic here | ||
@abstractmethod | ||
async def ingest( | ||
self, poll_function: AsyncGenerator[CollectedBytes, None] | ||
) -> AsyncGenerator[IngestedTokens, None]: | ||
# Your common ingestion logic here | ||
pass | ||
|
||
async def extract_text_from_file(self, file_path: str) -> str: | ||
# Your common file extraction logic here | ||
@abstractmethod | ||
async def process_data(self, text): | ||
# Your common data processing logic here | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters