Skip to content

Commit 2f97d51

Browse files
committed
Added validator arg to parse_image, and used in testing to confirm concurrency works
1 parent 4c05ebe commit 2f97d51

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

src/paperqa/readers.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import os
5+
from collections.abc import Awaitable, Callable
56
from math import ceil
67
from pathlib import Path
78
from typing import Literal, Protocol, cast, overload, runtime_checkable
@@ -32,9 +33,18 @@ def __call__(
3233
) -> ParsedText: ...
3334

3435

35-
async def parse_image(path: str | os.PathLike, **_) -> ParsedText:
36+
async def parse_image(
37+
path: str | os.PathLike, validator: Callable[[bytes], Awaitable] | None = None, **_
38+
) -> ParsedText:
3639
apath = anyio.Path(path)
3740
image_data = await anyio.Path(path).read_bytes()
41+
if validator:
42+
try:
43+
await validator(image_data)
44+
except Exception as exc:
45+
raise ImpossibleParsingError(
46+
f"Image validation failed for the image at path {path}."
47+
) from exc
3848
parsed_media = ParsedMedia(index=0, data=image_data, info={"suffix": apath.suffix})
3949
metadata = ParsedMetadata(
4050
parsing_libraries=[],

tests/test_paperqa.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from io import BytesIO
1515
from pathlib import Path
1616
from typing import cast
17+
from unittest.mock import MagicMock, call
1718
from uuid import UUID
1819

1920
import httpx
@@ -1427,13 +1428,31 @@ async def test_read_doc_images_metadata(stub_data_dir: Path) -> None:
14271428
async def test_read_doc_images_concurrency(stub_data_dir: Path) -> None:
14281429
png_path = stub_data_dir / "sf_districts.png"
14291430
doc = Doc(docname="stub", citation="stub", dockey="stub")
1431+
validation_mock = MagicMock()
1432+
1433+
async def validate(data: bytes) -> None: # noqa: RUF029
1434+
validate_image(io.BytesIO(data))
1435+
validation_mock(data)
14301436

14311437
# Check we can concurrently read in the same image many times
1432-
bulk_texts = await asyncio.gather(*(read_doc(png_path, doc) for _ in range(10)))
1438+
concurrent_call_count = 10
1439+
seen_media = set()
1440+
bulk_texts = await asyncio.gather(
1441+
*(
1442+
read_doc(png_path, doc, validator=validate)
1443+
for _ in range(concurrent_call_count)
1444+
)
1445+
)
14331446
for (text,) in bulk_texts:
14341447
assert text.doc == doc
14351448
assert len(text.media) == 1
1436-
validate_image(io.BytesIO(text.media[0].data))
1449+
seen_media.add(text.media[0])
1450+
assert (
1451+
len(seen_media) == 1
1452+
), "Expected the concurrent reads to all have the same parsed result"
1453+
validation_mock.assert_has_calls(
1454+
[call(next(iter(seen_media)).data)] * concurrent_call_count
1455+
)
14371456

14381457

14391458
@pytest.mark.asyncio

0 commit comments

Comments
 (0)