diff --git a/.gitignore b/.gitignore index 9590e2b..42c4f52 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ python/examples e2e.sh TODO.md .vscode/ +.DS_Store # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/datalab_sdk/client.py b/datalab_sdk/client.py index de262f7..fe22228 100644 --- a/datalab_sdk/client.py +++ b/datalab_sdk/client.py @@ -5,6 +5,13 @@ import asyncio import mimetypes import aiohttp +from tenacity import ( + retry, + retry_if_exception, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) from pathlib import Path from typing import Union, Optional, Dict, Any @@ -119,7 +126,7 @@ async def _poll_result( ) for i in range(max_polls): - data = await self._make_request("GET", full_url) + data = await self._poll_get_with_retry(full_url) if data.get("status") == "complete": return data @@ -135,6 +142,32 @@ async def _poll_result( f"Polling timed out after {max_polls * poll_interval} seconds" ) + @retry( + retry=( + retry_if_exception_type(DatalabTimeoutError) + | retry_if_exception( + lambda e: isinstance(e, DatalabAPIError) + and ( + # retry request timeout or too many requests + getattr(e, "status_code", None) in (408, 429) + or ( + # or if there's a server error + getattr(e, "status_code", None) is not None + and getattr(e, "status_code") >= 500 + ) + # or datalab api error without status code (e.g., connection errors) + or getattr(e, "status_code", None) is None + ) + ) + ), + stop=stop_after_attempt(2), + wait=wait_exponential_jitter(max=0.5), + reraise=True, + ) + async def _poll_get_with_retry(self, url: str) -> Dict[str, Any]: + """GET wrapper for polling with scoped retries for transient failures""" + return await self._make_request("GET", url) + def _prepare_file_data(self, file_path: Union[str, Path]) -> tuple: """Prepare file data for upload""" file_path = Path(file_path) @@ -156,7 +189,7 @@ def get_form_params(self, file_path=None, file_url=None, options=None): if file_url and file_path: raise ValueError("Either file_path or file_url must be provided, not both.") - + # Use either file_url or file upload, not both if file_url: form_data.add_field("file_url", file_url) @@ -184,13 +217,19 @@ async def convert( file_url: Optional[str] = None, options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> ConversionResult: """Convert a document using the marker endpoint""" if options is None: options = ConvertOptions() initial_data = await self._make_request( - "POST", "/api/v1/marker", data=self.get_form_params(file_path=file_path, file_url=file_url, options=options) + "POST", + "/api/v1/marker", + data=self.get_form_params( + file_path=file_path, file_url=file_url, options=options + ), ) if not initial_data.get("success"): @@ -198,7 +237,11 @@ async def convert( f"Request failed: {initial_data.get('error', 'Unknown error')}" ) - result_data = await self._poll_result(initial_data["request_check_url"]) + result_data = await self._poll_result( + initial_data["request_check_url"], + max_polls=max_polls, + poll_interval=poll_interval, + ) result = ConversionResult( success=result_data.get("success", False), @@ -227,13 +270,17 @@ async def ocr( file_path: Union[str, Path], options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> OCRResult: """Perform OCR on a document""" if options is None: options = OCROptions() initial_data = await self._make_request( - "POST", "/api/v1/ocr", data=self.get_form_params(file_path=file_path, options=options) + "POST", + "/api/v1/ocr", + data=self.get_form_params(file_path=file_path, options=options), ) if not initial_data.get("success"): @@ -241,7 +288,11 @@ async def ocr( f"Request failed: {initial_data.get('error', 'Unknown error')}" ) - result_data = await self._poll_result(initial_data["request_check_url"]) + result_data = await self._poll_result( + initial_data["request_check_url"], + max_polls=max_polls, + poll_interval=poll_interval, + ) result = OCRResult( success=result_data.get("success", False), @@ -299,10 +350,19 @@ def convert( file_url: Optional[str] = None, options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> ConversionResult: """Convert a document using the marker endpoint (sync version)""" return self._run_async( - self._async_client.convert(file_path=file_path, file_url=file_url, options=options, save_output=save_output) + self._async_client.convert( + file_path=file_path, + file_url=file_url, + options=options, + save_output=save_output, + max_polls=max_polls, + poll_interval=poll_interval, + ) ) def ocr( @@ -310,6 +370,16 @@ def ocr( file_path: Union[str, Path], options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> OCRResult: """Perform OCR on a document (sync version)""" - return self._run_async(self._async_client.ocr(file_path, options, save_output)) + return self._run_async( + self._async_client.ocr( + file_path=file_path, + options=options, + save_output=save_output, + max_polls=max_polls, + poll_interval=poll_interval, + ) + ) diff --git a/datalab_sdk/models.py b/datalab_sdk/models.py index 0e0d486..07c79bc 100644 --- a/datalab_sdk/models.py +++ b/datalab_sdk/models.py @@ -47,7 +47,8 @@ class ConvertOptions(ProcessingOptions): block_correction_prompt: Optional[str] = None additional_config: Optional[Dict[str, Any]] = None page_schema: Optional[Dict[str, Any]] = None - output_format: str = "markdown" # markdown, json, html + output_format: str = "markdown" # markdown, json, html, chunks + mode: str = "fast" # fast, balanced, accurate @dataclass @@ -91,7 +92,11 @@ def save_output( json.dump(self.json, f, indent=2) if self.extraction_schema_json: - with open(output_path.with_suffix("_extraction_results.json"), "w", encoding="utf-8") as f: + with open( + output_path.with_suffix("_extraction_results.json"), + "w", + encoding="utf-8", + ) as f: f.write(self.extraction_schema_json) # Save images if present diff --git a/datalab_sdk/settings.py b/datalab_sdk/settings.py index 3c2fb86..605a008 100644 --- a/datalab_sdk/settings.py +++ b/datalab_sdk/settings.py @@ -6,7 +6,7 @@ class Settings(BaseSettings): # Paths BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) LOGLEVEL: str = "DEBUG" - VERSION: str = "0.1.3" + VERSION: str = "0.1.6" # Base settings DATALAB_API_KEY: str | None = None diff --git a/integration/test_live_api.py b/integration/test_live_api.py index 19c9a40..45970b2 100644 --- a/integration/test_live_api.py +++ b/integration/test_live_api.py @@ -62,6 +62,14 @@ def test_convert_office_document(self): assert len(result.html) > 0 assert result.output_format == "html" + def test_convert_pdf_high_accuracy(self): + client = DatalabClient() + pdf_file = DATA_DIR / "adversarial.pdf" + options = ConvertOptions(mode="accurate", max_pages=1) + result = client.convert(pdf_file, options=options) + + assert "subspace" in result.markdown.lower() + @pytest.mark.asyncio async def test_convert_async_with_json(self): """Test async conversion with JSON output""" diff --git a/pyproject.toml b/pyproject.toml index fb63155..e1ce78b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ readme = "README.md" license = "MIT" repository = "https://github.com/datalab-to/sdk" keywords = ["datalab", "sdk", "document-intelligence", "api"] -version = "0.1.5" +version = "0.1.6" description = "SDK for the Datalab document intelligence API" requires-python = ">=3.10" dependencies = [ @@ -15,6 +15,7 @@ dependencies = [ "click>=8.2.1", "pydantic>=2.11.7,<3.0.0", "pydantic-settings>=2.10.1,<3.0.0", + "tenacity>=8.2.3,<9.0.0", ] [project.scripts] @@ -27,6 +28,7 @@ test = [ "pytest-mock>=3.11.0", "pytest-cov>=4.1.0", "aiofiles>=23.2.0", + "pytest-xdist>=3.8.0", ] [build-system] diff --git a/tests/test_client_methods.py b/tests/test_client_methods.py index 9b62d79..65c65f9 100644 --- a/tests/test_client_methods.py +++ b/tests/test_client_methods.py @@ -8,7 +8,11 @@ from datalab_sdk import DatalabClient, AsyncDatalabClient from datalab_sdk.models import ConversionResult, OCRResult, ConvertOptions, OCROptions -from datalab_sdk.exceptions import DatalabAPIError, DatalabFileError +from datalab_sdk.exceptions import ( + DatalabAPIError, + DatalabFileError, + DatalabTimeoutError, +) class TestConvertMethod: @@ -169,6 +173,50 @@ def test_convert_sync_with_processing_options(self, temp_dir): assert result.html == "

Test Document

" assert result.output_format == "html" + @pytest.mark.asyncio + async def test_convert_async_respects_polling_params(self, temp_dir): + """Verify convert passes max_polls and poll_interval to poller""" + # Create test file + pdf_file = temp_dir / "test.pdf" + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") + + # Mock API responses + mock_initial_response = { + "success": True, + "request_id": "rid-1", + "request_check_url": "https://api.datalab.to/api/v1/marker/rid-1", + } + + mock_result_response = { + "success": True, + "status": "complete", + "output_format": "markdown", + "markdown": "ok", + } + + async with AsyncDatalabClient(api_key="test-key") as client: + with patch.object( + client, "_make_request", new_callable=AsyncMock + ) as mock_req: + with patch.object( + client, "_poll_result", new_callable=AsyncMock + ) as mock_poll: + mock_req.return_value = mock_initial_response + mock_poll.return_value = mock_result_response + + max_polls = 7 + poll_interval = 3 + await client.convert( + pdf_file, max_polls=max_polls, poll_interval=poll_interval + ) + + mock_poll.assert_awaited_once() + # Verify kwargs were forwarded + args, kwargs = mock_poll.await_args + assert args[0] == mock_initial_response["request_check_url"] + assert kwargs["max_polls"] == max_polls + assert kwargs["poll_interval"] == poll_interval + class TestOCRMethod: """Test the ocr method""" @@ -356,6 +404,77 @@ def test_ocr_sync_with_max_pages(self, temp_dir): assert "Page 1 content" in all_text assert "Page 2 content" in all_text + @pytest.mark.asyncio + async def test_ocr_async_respects_polling_params(self, temp_dir): + """Verify ocr passes max_polls and poll_interval to poller""" + pdf_file = temp_dir / "test.pdf" + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") + + mock_initial_response = { + "success": True, + "request_id": "rid-2", + "request_check_url": "https://api.datalab.to/api/v1/ocr/rid-2", + } + + mock_result_response = { + "success": True, + "status": "complete", + "pages": [], + } + + async with AsyncDatalabClient(api_key="test-key") as client: + with patch.object( + client, "_make_request", new_callable=AsyncMock + ) as mock_req: + with patch.object( + client, "_poll_result", new_callable=AsyncMock + ) as mock_poll: + mock_req.return_value = mock_initial_response + mock_poll.return_value = mock_result_response + + max_polls = 11 + poll_interval = 2 + await client.ocr( + pdf_file, max_polls=max_polls, poll_interval=poll_interval + ) + + mock_poll.assert_awaited_once() + args, kwargs = mock_poll.await_args + assert args[0] == mock_initial_response["request_check_url"] + assert kwargs["max_polls"] == max_polls + assert kwargs["poll_interval"] == poll_interval + + def test_sync_wrappers_forward_polling_params(self, temp_dir): + """Ensure sync client forwards polling params to async client""" + pdf_file = temp_dir / "test.pdf" + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") + + client = DatalabClient(api_key="test-key") + + # Patch async convert/ocr to capture kwargs + with patch.object( + client._async_client, "convert", new_callable=AsyncMock + ) as mock_conv: + with patch.object( + client._async_client, "ocr", new_callable=AsyncMock + ) as mock_ocr: + mock_conv.return_value = ConversionResult( + success=True, output_format="markdown", markdown="ok" + ) + mock_ocr.return_value = OCRResult(success=True, pages=[]) + + client.convert(pdf_file, max_polls=5, poll_interval=9) + client.ocr(pdf_file, max_polls=6, poll_interval=4) + + # Assert called with forwarded kwargs + _, conv_kwargs = mock_conv.await_args + assert conv_kwargs["max_polls"] == 5 + assert conv_kwargs["poll_interval"] == 9 + + _, ocr_kwargs = mock_ocr.await_args + assert ocr_kwargs["max_polls"] == 6 + assert ocr_kwargs["poll_interval"] == 4 + class TestClientErrorHandling: """Test error handling in client methods""" @@ -416,3 +535,68 @@ def test_convert_unsuccessful_response(self, temp_dir): DatalabAPIError, match="Request failed: Processing failed" ): client.convert(pdf_file) + + def test_convert_timeout_bubbles_up(self, temp_dir): + """Polling timeout surfaces as DatalabTimeoutError for sync convert""" + pdf_file = temp_dir / "test.pdf" + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") + + mock_initial_response = { + "success": True, + "request_id": "rid-timeout", + "request_check_url": "https://api.datalab.to/api/v1/marker/rid-timeout", + } + + client = DatalabClient(api_key="test-key") + with patch.object( + client._async_client, "_make_request", new_callable=AsyncMock + ) as mock_request: + with patch.object( + client._async_client, "_poll_result", new_callable=AsyncMock + ) as mock_poll: + mock_request.return_value = mock_initial_response + mock_poll.side_effect = DatalabTimeoutError("Polling timed out") + + with pytest.raises(DatalabTimeoutError, match="Polling timed out"): + client.convert(pdf_file) + + +class TestPollingLoop: + """Direct tests for the internal polling helper""" + + @pytest.mark.asyncio + async def test_poll_result_times_out(self): + async with AsyncDatalabClient(api_key="test-key") as client: + with ( + patch.object( + client, "_make_request", new_callable=AsyncMock + ) as mock_req, + patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + ): + # Always return processing so we hit timeout + mock_req.return_value = {"status": "processing", "success": True} + + with pytest.raises(DatalabTimeoutError): + await client._poll_result( + "https://api.example.com/check", max_polls=3, poll_interval=0 + ) + + assert mock_req.await_count == 3 + assert mock_sleep.await_count >= 1 + + @pytest.mark.asyncio + async def test_poll_result_raises_on_failed_status(self): + async with AsyncDatalabClient(api_key="test-key") as client: + with patch.object( + client, "_make_request", new_callable=AsyncMock + ) as mock_req: + mock_req.return_value = { + "status": "failed", + "success": False, + "error": "boom", + } + + with pytest.raises(DatalabAPIError, match="Processing failed: boom"): + await client._poll_result( + "https://api.example.com/check", max_polls=1, poll_interval=0 + ) diff --git a/uv.lock b/uv.lock index 810a5dd..13b8c78 100644 --- a/uv.lock +++ b/uv.lock @@ -169,13 +169,14 @@ wheels = [ [[package]] name = "datalab-python-sdk" -version = "0.1.3" +version = "0.1.6" source = { editable = "." } dependencies = [ { name = "aiohttp" }, { name = "click" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "tenacity" }, ] [package.dev-dependencies] @@ -195,6 +196,7 @@ requires-dist = [ { name = "click", specifier = ">=8.2.1" }, { name = "pydantic", specifier = ">=2.11.7,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.10.1,<3.0.0" }, + { name = "tenacity", specifier = ">=8.2.3,<9.0.0" }, ] [package.metadata.requires-dev] @@ -857,6 +859,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" }, ] +[[package]] +name = "tenacity" +version = "8.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/4d/6a19536c50b849338fcbe9290d562b52cbdcf30d8963d3588a68a4107df1/tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78", size = 47309, upload-time = "2024-07-05T07:25:31.836Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/3f/8ba87d9e287b9d385a02a7114ddcef61b26f86411e121c9003eb509a1773/tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687", size = 28165, upload-time = "2024-07-05T07:25:29.591Z" }, +] + [[package]] name = "tomli" version = "2.2.1"