From d89987c2f85b94ba2dfc45fada5842d98c09a9d0 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 9 Sep 2025 21:10:25 -0400 Subject: [PATCH 1/3] feat: allow max polls and poll interval --- datalab_sdk/client.py | 51 ++++++++-- tests/test_client_methods.py | 186 ++++++++++++++++++++++++++++++++++- 2 files changed, 229 insertions(+), 8 deletions(-) diff --git a/datalab_sdk/client.py b/datalab_sdk/client.py index de262f7..4647b9e 100644 --- a/datalab_sdk/client.py +++ b/datalab_sdk/client.py @@ -156,7 +156,7 @@ def get_form_params(self, file_path=None, file_url=None, options=None): if file_url and file_path: raise ValueError("Either file_path or file_url must be provided, not both.") - + # Use either file_url or file upload, not both if file_url: form_data.add_field("file_url", file_url) @@ -184,13 +184,19 @@ async def convert( file_url: Optional[str] = None, options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> ConversionResult: """Convert a document using the marker endpoint""" if options is None: options = ConvertOptions() initial_data = await self._make_request( - "POST", "/api/v1/marker", data=self.get_form_params(file_path=file_path, file_url=file_url, options=options) + "POST", + "/api/v1/marker", + data=self.get_form_params( + file_path=file_path, file_url=file_url, options=options + ), ) if not initial_data.get("success"): @@ -198,7 +204,11 @@ async def convert( f"Request failed: {initial_data.get('error', 'Unknown error')}" ) - result_data = await self._poll_result(initial_data["request_check_url"]) + result_data = await self._poll_result( + initial_data["request_check_url"], + max_polls=max_polls, + poll_interval=poll_interval, + ) result = ConversionResult( success=result_data.get("success", False), @@ -227,13 +237,17 @@ async def ocr( file_path: Union[str, Path], options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> OCRResult: """Perform OCR on a document""" if options is None: options = OCROptions() initial_data = await self._make_request( - "POST", "/api/v1/ocr", data=self.get_form_params(file_path=file_path, options=options) + "POST", + "/api/v1/ocr", + data=self.get_form_params(file_path=file_path, options=options), ) if not initial_data.get("success"): @@ -241,7 +255,11 @@ async def ocr( f"Request failed: {initial_data.get('error', 'Unknown error')}" ) - result_data = await self._poll_result(initial_data["request_check_url"]) + result_data = await self._poll_result( + initial_data["request_check_url"], + max_polls=max_polls, + poll_interval=poll_interval, + ) result = OCRResult( success=result_data.get("success", False), @@ -299,10 +317,19 @@ def convert( file_url: Optional[str] = None, options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> ConversionResult: """Convert a document using the marker endpoint (sync version)""" return self._run_async( - self._async_client.convert(file_path=file_path, file_url=file_url, options=options, save_output=save_output) + self._async_client.convert( + file_path=file_path, + file_url=file_url, + options=options, + save_output=save_output, + max_polls=max_polls, + poll_interval=poll_interval, + ) ) def ocr( @@ -310,6 +337,16 @@ def ocr( file_path: Union[str, Path], options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> OCRResult: """Perform OCR on a document (sync version)""" - return self._run_async(self._async_client.ocr(file_path, options, save_output)) + return self._run_async( + self._async_client.ocr( + file_path=file_path, + options=options, + save_output=save_output, + max_polls=max_polls, + poll_interval=poll_interval, + ) + ) diff --git a/tests/test_client_methods.py b/tests/test_client_methods.py index 9b62d79..65c65f9 100644 --- a/tests/test_client_methods.py +++ b/tests/test_client_methods.py @@ -8,7 +8,11 @@ from datalab_sdk import DatalabClient, AsyncDatalabClient from datalab_sdk.models import ConversionResult, OCRResult, ConvertOptions, OCROptions -from datalab_sdk.exceptions import DatalabAPIError, DatalabFileError +from datalab_sdk.exceptions import ( + DatalabAPIError, + DatalabFileError, + DatalabTimeoutError, +) class TestConvertMethod: @@ -169,6 +173,50 @@ def test_convert_sync_with_processing_options(self, temp_dir): assert result.html == "

Test Document

" assert result.output_format == "html" + @pytest.mark.asyncio + async def test_convert_async_respects_polling_params(self, temp_dir): + """Verify convert passes max_polls and poll_interval to poller""" + # Create test file + pdf_file = temp_dir / "test.pdf" + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") + + # Mock API responses + mock_initial_response = { + "success": True, + "request_id": "rid-1", + "request_check_url": "https://api.datalab.to/api/v1/marker/rid-1", + } + + mock_result_response = { + "success": True, + "status": "complete", + "output_format": "markdown", + "markdown": "ok", + } + + async with AsyncDatalabClient(api_key="test-key") as client: + with patch.object( + client, "_make_request", new_callable=AsyncMock + ) as mock_req: + with patch.object( + client, "_poll_result", new_callable=AsyncMock + ) as mock_poll: + mock_req.return_value = mock_initial_response + mock_poll.return_value = mock_result_response + + max_polls = 7 + poll_interval = 3 + await client.convert( + pdf_file, max_polls=max_polls, poll_interval=poll_interval + ) + + mock_poll.assert_awaited_once() + # Verify kwargs were forwarded + args, kwargs = mock_poll.await_args + assert args[0] == mock_initial_response["request_check_url"] + assert kwargs["max_polls"] == max_polls + assert kwargs["poll_interval"] == poll_interval + class TestOCRMethod: """Test the ocr method""" @@ -356,6 +404,77 @@ def test_ocr_sync_with_max_pages(self, temp_dir): assert "Page 1 content" in all_text assert "Page 2 content" in all_text + @pytest.mark.asyncio + async def test_ocr_async_respects_polling_params(self, temp_dir): + """Verify ocr passes max_polls and poll_interval to poller""" + pdf_file = temp_dir / "test.pdf" + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") + + mock_initial_response = { + "success": True, + "request_id": "rid-2", + "request_check_url": "https://api.datalab.to/api/v1/ocr/rid-2", + } + + mock_result_response = { + "success": True, + "status": "complete", + "pages": [], + } + + async with AsyncDatalabClient(api_key="test-key") as client: + with patch.object( + client, "_make_request", new_callable=AsyncMock + ) as mock_req: + with patch.object( + client, "_poll_result", new_callable=AsyncMock + ) as mock_poll: + mock_req.return_value = mock_initial_response + mock_poll.return_value = mock_result_response + + max_polls = 11 + poll_interval = 2 + await client.ocr( + pdf_file, max_polls=max_polls, poll_interval=poll_interval + ) + + mock_poll.assert_awaited_once() + args, kwargs = mock_poll.await_args + assert args[0] == mock_initial_response["request_check_url"] + assert kwargs["max_polls"] == max_polls + assert kwargs["poll_interval"] == poll_interval + + def test_sync_wrappers_forward_polling_params(self, temp_dir): + """Ensure sync client forwards polling params to async client""" + pdf_file = temp_dir / "test.pdf" + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") + + client = DatalabClient(api_key="test-key") + + # Patch async convert/ocr to capture kwargs + with patch.object( + client._async_client, "convert", new_callable=AsyncMock + ) as mock_conv: + with patch.object( + client._async_client, "ocr", new_callable=AsyncMock + ) as mock_ocr: + mock_conv.return_value = ConversionResult( + success=True, output_format="markdown", markdown="ok" + ) + mock_ocr.return_value = OCRResult(success=True, pages=[]) + + client.convert(pdf_file, max_polls=5, poll_interval=9) + client.ocr(pdf_file, max_polls=6, poll_interval=4) + + # Assert called with forwarded kwargs + _, conv_kwargs = mock_conv.await_args + assert conv_kwargs["max_polls"] == 5 + assert conv_kwargs["poll_interval"] == 9 + + _, ocr_kwargs = mock_ocr.await_args + assert ocr_kwargs["max_polls"] == 6 + assert ocr_kwargs["poll_interval"] == 4 + class TestClientErrorHandling: """Test error handling in client methods""" @@ -416,3 +535,68 @@ def test_convert_unsuccessful_response(self, temp_dir): DatalabAPIError, match="Request failed: Processing failed" ): client.convert(pdf_file) + + def test_convert_timeout_bubbles_up(self, temp_dir): + """Polling timeout surfaces as DatalabTimeoutError for sync convert""" + pdf_file = temp_dir / "test.pdf" + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") + + mock_initial_response = { + "success": True, + "request_id": "rid-timeout", + "request_check_url": "https://api.datalab.to/api/v1/marker/rid-timeout", + } + + client = DatalabClient(api_key="test-key") + with patch.object( + client._async_client, "_make_request", new_callable=AsyncMock + ) as mock_request: + with patch.object( + client._async_client, "_poll_result", new_callable=AsyncMock + ) as mock_poll: + mock_request.return_value = mock_initial_response + mock_poll.side_effect = DatalabTimeoutError("Polling timed out") + + with pytest.raises(DatalabTimeoutError, match="Polling timed out"): + client.convert(pdf_file) + + +class TestPollingLoop: + """Direct tests for the internal polling helper""" + + @pytest.mark.asyncio + async def test_poll_result_times_out(self): + async with AsyncDatalabClient(api_key="test-key") as client: + with ( + patch.object( + client, "_make_request", new_callable=AsyncMock + ) as mock_req, + patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + ): + # Always return processing so we hit timeout + mock_req.return_value = {"status": "processing", "success": True} + + with pytest.raises(DatalabTimeoutError): + await client._poll_result( + "https://api.example.com/check", max_polls=3, poll_interval=0 + ) + + assert mock_req.await_count == 3 + assert mock_sleep.await_count >= 1 + + @pytest.mark.asyncio + async def test_poll_result_raises_on_failed_status(self): + async with AsyncDatalabClient(api_key="test-key") as client: + with patch.object( + client, "_make_request", new_callable=AsyncMock + ) as mock_req: + mock_req.return_value = { + "status": "failed", + "success": False, + "error": "boom", + } + + with pytest.raises(DatalabAPIError, match="Processing failed: boom"): + await client._poll_result( + "https://api.example.com/check", max_polls=1, poll_interval=0 + ) From 5543bff1fd8451cd3d7b857ee892ab8a8e3884cc Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Wed, 10 Sep 2025 12:09:33 -0400 Subject: [PATCH 2/3] fix: add retries for make_request on polling in case of timeout errors --- datalab_sdk/client.py | 33 ++++++++++++++++++++++++++++++++- pyproject.toml | 1 + uv.lock | 13 ++++++++++++- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/datalab_sdk/client.py b/datalab_sdk/client.py index 4647b9e..1b7c783 100644 --- a/datalab_sdk/client.py +++ b/datalab_sdk/client.py @@ -5,6 +5,13 @@ import asyncio import mimetypes import aiohttp +from tenacity import ( + AsyncRetrying, + retry_if_exception, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) from pathlib import Path from typing import Union, Optional, Dict, Any @@ -119,7 +126,31 @@ async def _poll_result( ) for i in range(max_polls): - data = await self._make_request("GET", full_url) + # Retry transient failures for the polling GET using tenacity + async for attempt in AsyncRetrying( + retry=( + retry_if_exception_type(DatalabTimeoutError) + | retry_if_exception( + lambda e: isinstance(e, DatalabAPIError) + and ( + # retry request timeout or too many requests + getattr(e, "status_code", None) in (408, 429) + or ( + # or if there's a server error + getattr(e, "status_code", None) is not None + and getattr(e, "status_code") >= 500 + ) + # or datalab api error without status code + or getattr(e, "status_code", None) is None + ) + ) + ), + stop=stop_after_attempt(2), + wait=wait_exponential_jitter(max=0.5), + reraise=True, + ): + with attempt: + data = await self._make_request("GET", full_url) if data.get("status") == "complete": return data diff --git a/pyproject.toml b/pyproject.toml index 3b8167a..29c65d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "click>=8.2.1", "pydantic>=2.11.7,<3.0.0", "pydantic-settings>=2.10.1,<3.0.0", + "tenacity>=8.2.3,<9.0.0", ] [project.scripts] diff --git a/uv.lock b/uv.lock index 810a5dd..b1cb84a 100644 --- a/uv.lock +++ b/uv.lock @@ -169,13 +169,14 @@ wheels = [ [[package]] name = "datalab-python-sdk" -version = "0.1.3" +version = "0.1.4" source = { editable = "." } dependencies = [ { name = "aiohttp" }, { name = "click" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "tenacity" }, ] [package.dev-dependencies] @@ -195,6 +196,7 @@ requires-dist = [ { name = "click", specifier = ">=8.2.1" }, { name = "pydantic", specifier = ">=2.11.7,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.10.1,<3.0.0" }, + { name = "tenacity", specifier = ">=8.2.3,<9.0.0" }, ] [package.metadata.requires-dev] @@ -857,6 +859,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" }, ] +[[package]] +name = "tenacity" +version = "8.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/4d/6a19536c50b849338fcbe9290d562b52cbdcf30d8963d3588a68a4107df1/tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78", size = 47309, upload-time = "2024-07-05T07:25:31.836Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/3f/8ba87d9e287b9d385a02a7114ddcef61b26f86411e121c9003eb509a1773/tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687", size = 28165, upload-time = "2024-07-05T07:25:29.591Z" }, +] + [[package]] name = "tomli" version = "2.2.1" From 91b90f75567f0095549a3c1a38c7d61dbfb57476 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Wed, 10 Sep 2025 13:38:04 -0400 Subject: [PATCH 3/3] refactor: retry func w decorator --- datalab_sdk/client.py | 54 ++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/datalab_sdk/client.py b/datalab_sdk/client.py index 1b7c783..fe22228 100644 --- a/datalab_sdk/client.py +++ b/datalab_sdk/client.py @@ -6,7 +6,7 @@ import mimetypes import aiohttp from tenacity import ( - AsyncRetrying, + retry, retry_if_exception, retry_if_exception_type, stop_after_attempt, @@ -126,31 +126,7 @@ async def _poll_result( ) for i in range(max_polls): - # Retry transient failures for the polling GET using tenacity - async for attempt in AsyncRetrying( - retry=( - retry_if_exception_type(DatalabTimeoutError) - | retry_if_exception( - lambda e: isinstance(e, DatalabAPIError) - and ( - # retry request timeout or too many requests - getattr(e, "status_code", None) in (408, 429) - or ( - # or if there's a server error - getattr(e, "status_code", None) is not None - and getattr(e, "status_code") >= 500 - ) - # or datalab api error without status code - or getattr(e, "status_code", None) is None - ) - ) - ), - stop=stop_after_attempt(2), - wait=wait_exponential_jitter(max=0.5), - reraise=True, - ): - with attempt: - data = await self._make_request("GET", full_url) + data = await self._poll_get_with_retry(full_url) if data.get("status") == "complete": return data @@ -166,6 +142,32 @@ async def _poll_result( f"Polling timed out after {max_polls * poll_interval} seconds" ) + @retry( + retry=( + retry_if_exception_type(DatalabTimeoutError) + | retry_if_exception( + lambda e: isinstance(e, DatalabAPIError) + and ( + # retry request timeout or too many requests + getattr(e, "status_code", None) in (408, 429) + or ( + # or if there's a server error + getattr(e, "status_code", None) is not None + and getattr(e, "status_code") >= 500 + ) + # or datalab api error without status code (e.g., connection errors) + or getattr(e, "status_code", None) is None + ) + ) + ), + stop=stop_after_attempt(2), + wait=wait_exponential_jitter(max=0.5), + reraise=True, + ) + async def _poll_get_with_retry(self, url: str) -> Dict[str, Any]: + """GET wrapper for polling with scoped retries for transient failures""" + return await self._make_request("GET", url) + def _prepare_file_data(self, file_path: Union[str, Path]) -> tuple: """Prepare file data for upload""" file_path = Path(file_path)