Skip to content

Commit

Permalink
feat: allow passing meta in the run method of FileTypeRouter (#…
Browse files Browse the repository at this point in the history
…8486)

* initial refactoring

* progress

* refinements

* serde methods + tests

* release note

* comment

* make additional_mimetypes internal attribute
  • Loading branch information
anakin87 authored Oct 24, 2024
1 parent c24814c commit 7829242
Show file tree
Hide file tree
Showing 3 changed files with 283 additions and 38 deletions.
108 changes: 70 additions & 38 deletions haystack/components/routers/file_type_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from haystack import component, logging
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata
from haystack.dataclasses import ByteStream

logger = logging.getLogger(__name__)


# we add markdown because it is not added by the mimetypes module
# see https://github.com/python/cpython/pull/17995
CUSTOM_MIMETYPES = {".md": "text/markdown", ".markdown": "text/markdown"}


@component
class FileTypeRouter:
"""
Expand Down Expand Up @@ -50,19 +56,19 @@ class FileTypeRouter:
# PosixPath('song.mp3')], 'text/plain': [PosixPath('file.txt')], 'unclassified': [PosixPath('document.pdf')
# ]}
```
:param mime_types: A list of MIME types or regex patterns to classify the input files or byte streams.
"""

def __init__(self, mime_types: List[str], additional_mimetypes: Optional[Dict[str, str]] = None):
"""
Initialize the FileTypeRouter component.
:param mime_types: A list of MIME types or regex patterns to classify the input files or byte streams.
:param mime_types:
A list of MIME types or regex patterns to classify the input files or byte streams.
(for example: `["text/plain", "audio/x-wav", "image/jpeg"]`).
:param additional_mimetypes: A dictionary containing the MIME type to add to the mimetypes package to prevent
unsupported or non native packages from being unclassified.
:param additional_mimetypes:
A dictionary containing the MIME type to add to the mimetypes package to prevent unsupported or non native
packages from being unclassified.
(for example: `{"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"}`).
"""
if not mime_types:
Expand All @@ -74,35 +80,84 @@ def __init__(self, mime_types: List[str], additional_mimetypes: Optional[Dict[st

self.mime_type_patterns = []
for mime_type in mime_types:
if not self._is_valid_mime_type_format(mime_type):
raise ValueError(f"Invalid mime type or regex pattern: '{mime_type}'.")
pattern = re.compile(mime_type)
try:
pattern = re.compile(mime_type)
except re.error:
raise ValueError(f"Invalid regex pattern '{mime_type}'.")
self.mime_type_patterns.append(pattern)

component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
# the actual output type is List[Union[Path, ByteStream]],
# but this would cause PipelineConnectError with Converters
component.set_output_types(
self,
unclassified=List[Union[str, Path, ByteStream]],
**{mime_type: List[Union[str, Path, ByteStream]] for mime_type in mime_types},
)
self.mime_types = mime_types
self._additional_mimetypes = additional_mimetypes

def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this component into a dictionary.

    The init parameters written out are `mime_types` and `additional_mimetypes`
    (the latter is read from the internal `_additional_mimetypes` attribute).

    :returns:
        Dictionary with serialized data.
    """
    init_parameters = {
        "mime_types": self.mime_types,
        "additional_mimetypes": self._additional_mimetypes,
    }
    return default_to_dict(self, **init_parameters)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileTypeRouter":
    """
    Build a component instance from its dictionary representation.

    :param data:
        The dictionary to deserialize from.
    :returns:
        The deserialized component.
    """
    router = default_from_dict(cls, data)
    return router

def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]:
def run(
self,
sources: List[Union[str, Path, ByteStream]],
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
) -> Dict[str, List[Union[ByteStream, Path]]]:
"""
Categorize files or byte streams according to their MIME types.
:param sources: A list of file paths or byte streams to categorize.
:param sources:
A list of file paths or byte streams to categorize.
:param meta:
Optional metadata to attach to the sources.
When provided, the sources are internally converted to ByteStream objects and the metadata is added.
This value can be a list of dictionaries or a single dictionary.
If it's a single dictionary, its content is added to the metadata of all ByteStream objects.
If it's a list, its length must match the number of sources, as they are zipped together.
:returns: A dictionary where the keys are MIME types (or `"unclassified"`) and the values are lists of data
sources.
"""

mime_types = defaultdict(list)
for source in sources:
meta_list = normalize_metadata(meta=meta, sources_count=len(sources))

for source, meta_dict in zip(sources, meta_list):
if isinstance(source, str):
source = Path(source)

if isinstance(source, Path):
mime_type = self._get_mime_type(source)
elif isinstance(source, ByteStream):
mime_type = source.mime_type
else:
raise ValueError(f"Unsupported data source type: {type(source).__name__}")

# If we have metadata, we convert the source to ByteStream and add the metadata
if meta_dict:
source = get_bytestream_from_source(source)
source.meta.update(meta_dict)

matched = False
if mime_type:
for pattern in self.mime_type_patterns:
Expand All @@ -126,27 +181,4 @@ def _get_mime_type(self, path: Path) -> Optional[str]:
extension = path.suffix.lower()
mime_type = mimetypes.guess_type(path.as_posix())[0]
# lookup custom mappings if the mime type is not found
return self._get_custom_mime_mappings().get(extension, mime_type)

def _is_valid_mime_type_format(self, mime_type: str) -> bool:
"""
Checks if the provided MIME type string is a valid regex pattern.
:param mime_type: The MIME type or regex pattern to validate.
:raises ValueError: If the mime_type is not a valid regex pattern.
:returns: Always True because a ValueError is raised for invalid patterns.
"""
try:
re.compile(mime_type)
return True
except re.error:
raise ValueError(f"Invalid regex pattern '{mime_type}'.")

@staticmethod
def _get_custom_mime_mappings() -> Dict[str, str]:
"""
Returns a dictionary of custom file extension to MIME type mappings.
"""
# we add markdown because it is not added by the mimetypes module
# see https://github.com/python/cpython/pull/17995
return {".md": "text/markdown", ".markdown": "text/markdown"}
return CUSTOM_MIMETYPES.get(extension, mime_type)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
The `FileTypeRouter` now supports passing metadata (`meta`) in the `run` method.
When metadata is provided, the sources are internally converted to `ByteStream` objects and the metadata is added.
This new parameter simplifies working with preprocessing/indexing pipelines.
Loading

0 comments on commit 7829242

Please sign in to comment.