diff --git a/datalab_sdk/client.py b/datalab_sdk/client.py
index de262f7..fe22228 100644
--- a/datalab_sdk/client.py
+++ b/datalab_sdk/client.py
@@ -5,6 +5,13 @@
import asyncio
import mimetypes
import aiohttp
+from tenacity import (
+ retry,
+ retry_if_exception,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential_jitter,
+)
from pathlib import Path
from typing import Union, Optional, Dict, Any
@@ -119,7 +126,7 @@ async def _poll_result(
)
for i in range(max_polls):
- data = await self._make_request("GET", full_url)
+ data = await self._poll_get_with_retry(full_url)
if data.get("status") == "complete":
return data
@@ -135,6 +142,32 @@ async def _poll_result(
f"Polling timed out after {max_polls * poll_interval} seconds"
)
+ @retry(
+ retry=(
+ retry_if_exception_type(DatalabTimeoutError)
+ | retry_if_exception(
+ lambda e: isinstance(e, DatalabAPIError)
+ and (
+ # retry request timeout or too many requests
+ getattr(e, "status_code", None) in (408, 429)
+ or (
+ # or if there's a server error
+ getattr(e, "status_code", None) is not None
+ and getattr(e, "status_code") >= 500
+ )
+ # or datalab api error without status code (e.g., connection errors)
+ or getattr(e, "status_code", None) is None
+ )
+ )
+ ),
+ stop=stop_after_attempt(2),
+ wait=wait_exponential_jitter(max=0.5),
+ reraise=True,
+ )
+ async def _poll_get_with_retry(self, url: str) -> Dict[str, Any]:
+ """GET wrapper for polling with scoped retries for transient failures"""
+ return await self._make_request("GET", url)
+
def _prepare_file_data(self, file_path: Union[str, Path]) -> tuple:
"""Prepare file data for upload"""
file_path = Path(file_path)
@@ -156,7 +189,7 @@ def get_form_params(self, file_path=None, file_url=None, options=None):
if file_url and file_path:
raise ValueError("Either file_path or file_url must be provided, not both.")
-
+
# Use either file_url or file upload, not both
if file_url:
form_data.add_field("file_url", file_url)
@@ -184,13 +217,19 @@ async def convert(
file_url: Optional[str] = None,
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
+ max_polls: int = 300,
+ poll_interval: int = 1,
) -> ConversionResult:
"""Convert a document using the marker endpoint"""
if options is None:
options = ConvertOptions()
initial_data = await self._make_request(
- "POST", "/api/v1/marker", data=self.get_form_params(file_path=file_path, file_url=file_url, options=options)
+ "POST",
+ "/api/v1/marker",
+ data=self.get_form_params(
+ file_path=file_path, file_url=file_url, options=options
+ ),
)
if not initial_data.get("success"):
@@ -198,7 +237,11 @@ async def convert(
f"Request failed: {initial_data.get('error', 'Unknown error')}"
)
- result_data = await self._poll_result(initial_data["request_check_url"])
+ result_data = await self._poll_result(
+ initial_data["request_check_url"],
+ max_polls=max_polls,
+ poll_interval=poll_interval,
+ )
result = ConversionResult(
success=result_data.get("success", False),
@@ -227,13 +270,17 @@ async def ocr(
file_path: Union[str, Path],
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
+ max_polls: int = 300,
+ poll_interval: int = 1,
) -> OCRResult:
"""Perform OCR on a document"""
if options is None:
options = OCROptions()
initial_data = await self._make_request(
- "POST", "/api/v1/ocr", data=self.get_form_params(file_path=file_path, options=options)
+ "POST",
+ "/api/v1/ocr",
+ data=self.get_form_params(file_path=file_path, options=options),
)
if not initial_data.get("success"):
@@ -241,7 +288,11 @@ async def ocr(
f"Request failed: {initial_data.get('error', 'Unknown error')}"
)
- result_data = await self._poll_result(initial_data["request_check_url"])
+ result_data = await self._poll_result(
+ initial_data["request_check_url"],
+ max_polls=max_polls,
+ poll_interval=poll_interval,
+ )
result = OCRResult(
success=result_data.get("success", False),
@@ -299,10 +350,19 @@ def convert(
file_url: Optional[str] = None,
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
+ max_polls: int = 300,
+ poll_interval: int = 1,
) -> ConversionResult:
"""Convert a document using the marker endpoint (sync version)"""
return self._run_async(
- self._async_client.convert(file_path=file_path, file_url=file_url, options=options, save_output=save_output)
+ self._async_client.convert(
+ file_path=file_path,
+ file_url=file_url,
+ options=options,
+ save_output=save_output,
+ max_polls=max_polls,
+ poll_interval=poll_interval,
+ )
)
def ocr(
@@ -310,6 +370,16 @@ def ocr(
file_path: Union[str, Path],
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
+ max_polls: int = 300,
+ poll_interval: int = 1,
) -> OCRResult:
"""Perform OCR on a document (sync version)"""
- return self._run_async(self._async_client.ocr(file_path, options, save_output))
+ return self._run_async(
+ self._async_client.ocr(
+ file_path=file_path,
+ options=options,
+ save_output=save_output,
+ max_polls=max_polls,
+ poll_interval=poll_interval,
+ )
+ )
diff --git a/pyproject.toml b/pyproject.toml
index 3b8167a..29c65d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
"click>=8.2.1",
"pydantic>=2.11.7,<3.0.0",
"pydantic-settings>=2.10.1,<3.0.0",
+ "tenacity>=8.2.3,<9.0.0",
]
[project.scripts]
diff --git a/tests/test_client_methods.py b/tests/test_client_methods.py
index 9b62d79..65c65f9 100644
--- a/tests/test_client_methods.py
+++ b/tests/test_client_methods.py
@@ -8,7 +8,11 @@
from datalab_sdk import DatalabClient, AsyncDatalabClient
from datalab_sdk.models import ConversionResult, OCRResult, ConvertOptions, OCROptions
-from datalab_sdk.exceptions import DatalabAPIError, DatalabFileError
+from datalab_sdk.exceptions import (
+ DatalabAPIError,
+ DatalabFileError,
+ DatalabTimeoutError,
+)
class TestConvertMethod:
@@ -169,6 +173,50 @@ def test_convert_sync_with_processing_options(self, temp_dir):
assert result.html == "
Test Document
"
assert result.output_format == "html"
+ @pytest.mark.asyncio
+ async def test_convert_async_respects_polling_params(self, temp_dir):
+ """Verify convert passes max_polls and poll_interval to poller"""
+ # Create test file
+ pdf_file = temp_dir / "test.pdf"
+ pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")
+
+ # Mock API responses
+ mock_initial_response = {
+ "success": True,
+ "request_id": "rid-1",
+ "request_check_url": "https://api.datalab.to/api/v1/marker/rid-1",
+ }
+
+ mock_result_response = {
+ "success": True,
+ "status": "complete",
+ "output_format": "markdown",
+ "markdown": "ok",
+ }
+
+ async with AsyncDatalabClient(api_key="test-key") as client:
+ with patch.object(
+ client, "_make_request", new_callable=AsyncMock
+ ) as mock_req:
+ with patch.object(
+ client, "_poll_result", new_callable=AsyncMock
+ ) as mock_poll:
+ mock_req.return_value = mock_initial_response
+ mock_poll.return_value = mock_result_response
+
+ max_polls = 7
+ poll_interval = 3
+ await client.convert(
+ pdf_file, max_polls=max_polls, poll_interval=poll_interval
+ )
+
+ mock_poll.assert_awaited_once()
+ # Verify kwargs were forwarded
+ args, kwargs = mock_poll.await_args
+ assert args[0] == mock_initial_response["request_check_url"]
+ assert kwargs["max_polls"] == max_polls
+ assert kwargs["poll_interval"] == poll_interval
+
class TestOCRMethod:
"""Test the ocr method"""
@@ -356,6 +404,77 @@ def test_ocr_sync_with_max_pages(self, temp_dir):
assert "Page 1 content" in all_text
assert "Page 2 content" in all_text
+ @pytest.mark.asyncio
+ async def test_ocr_async_respects_polling_params(self, temp_dir):
+ """Verify ocr passes max_polls and poll_interval to poller"""
+ pdf_file = temp_dir / "test.pdf"
+ pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")
+
+ mock_initial_response = {
+ "success": True,
+ "request_id": "rid-2",
+ "request_check_url": "https://api.datalab.to/api/v1/ocr/rid-2",
+ }
+
+ mock_result_response = {
+ "success": True,
+ "status": "complete",
+ "pages": [],
+ }
+
+ async with AsyncDatalabClient(api_key="test-key") as client:
+ with patch.object(
+ client, "_make_request", new_callable=AsyncMock
+ ) as mock_req:
+ with patch.object(
+ client, "_poll_result", new_callable=AsyncMock
+ ) as mock_poll:
+ mock_req.return_value = mock_initial_response
+ mock_poll.return_value = mock_result_response
+
+ max_polls = 11
+ poll_interval = 2
+ await client.ocr(
+ pdf_file, max_polls=max_polls, poll_interval=poll_interval
+ )
+
+ mock_poll.assert_awaited_once()
+ args, kwargs = mock_poll.await_args
+ assert args[0] == mock_initial_response["request_check_url"]
+ assert kwargs["max_polls"] == max_polls
+ assert kwargs["poll_interval"] == poll_interval
+
+ def test_sync_wrappers_forward_polling_params(self, temp_dir):
+ """Ensure sync client forwards polling params to async client"""
+ pdf_file = temp_dir / "test.pdf"
+ pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")
+
+ client = DatalabClient(api_key="test-key")
+
+ # Patch async convert/ocr to capture kwargs
+ with patch.object(
+ client._async_client, "convert", new_callable=AsyncMock
+ ) as mock_conv:
+ with patch.object(
+ client._async_client, "ocr", new_callable=AsyncMock
+ ) as mock_ocr:
+ mock_conv.return_value = ConversionResult(
+ success=True, output_format="markdown", markdown="ok"
+ )
+ mock_ocr.return_value = OCRResult(success=True, pages=[])
+
+ client.convert(pdf_file, max_polls=5, poll_interval=9)
+ client.ocr(pdf_file, max_polls=6, poll_interval=4)
+
+ # Assert called with forwarded kwargs
+ _, conv_kwargs = mock_conv.await_args
+ assert conv_kwargs["max_polls"] == 5
+ assert conv_kwargs["poll_interval"] == 9
+
+ _, ocr_kwargs = mock_ocr.await_args
+ assert ocr_kwargs["max_polls"] == 6
+ assert ocr_kwargs["poll_interval"] == 4
+
class TestClientErrorHandling:
"""Test error handling in client methods"""
@@ -416,3 +535,68 @@ def test_convert_unsuccessful_response(self, temp_dir):
DatalabAPIError, match="Request failed: Processing failed"
):
client.convert(pdf_file)
+
+ def test_convert_timeout_bubbles_up(self, temp_dir):
+ """Polling timeout surfaces as DatalabTimeoutError for sync convert"""
+ pdf_file = temp_dir / "test.pdf"
+ pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")
+
+ mock_initial_response = {
+ "success": True,
+ "request_id": "rid-timeout",
+ "request_check_url": "https://api.datalab.to/api/v1/marker/rid-timeout",
+ }
+
+ client = DatalabClient(api_key="test-key")
+ with patch.object(
+ client._async_client, "_make_request", new_callable=AsyncMock
+ ) as mock_request:
+ with patch.object(
+ client._async_client, "_poll_result", new_callable=AsyncMock
+ ) as mock_poll:
+ mock_request.return_value = mock_initial_response
+ mock_poll.side_effect = DatalabTimeoutError("Polling timed out")
+
+ with pytest.raises(DatalabTimeoutError, match="Polling timed out"):
+ client.convert(pdf_file)
+
+
+class TestPollingLoop:
+ """Direct tests for the internal polling helper"""
+
+ @pytest.mark.asyncio
+ async def test_poll_result_times_out(self):
+ async with AsyncDatalabClient(api_key="test-key") as client:
+ with (
+ patch.object(
+ client, "_make_request", new_callable=AsyncMock
+ ) as mock_req,
+ patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep,
+ ):
+ # Always return processing so we hit timeout
+ mock_req.return_value = {"status": "processing", "success": True}
+
+ with pytest.raises(DatalabTimeoutError):
+ await client._poll_result(
+ "https://api.example.com/check", max_polls=3, poll_interval=0
+ )
+
+ assert mock_req.await_count == 3
+ assert mock_sleep.await_count >= 1
+
+ @pytest.mark.asyncio
+ async def test_poll_result_raises_on_failed_status(self):
+ async with AsyncDatalabClient(api_key="test-key") as client:
+ with patch.object(
+ client, "_make_request", new_callable=AsyncMock
+ ) as mock_req:
+ mock_req.return_value = {
+ "status": "failed",
+ "success": False,
+ "error": "boom",
+ }
+
+ with pytest.raises(DatalabAPIError, match="Processing failed: boom"):
+ await client._poll_result(
+ "https://api.example.com/check", max_polls=1, poll_interval=0
+ )
diff --git a/uv.lock b/uv.lock
index 810a5dd..b1cb84a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -169,13 +169,14 @@ wheels = [
[[package]]
name = "datalab-python-sdk"
-version = "0.1.3"
+version = "0.1.4"
source = { editable = "." }
dependencies = [
{ name = "aiohttp" },
{ name = "click" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
+ { name = "tenacity" },
]
[package.dev-dependencies]
@@ -195,6 +196,7 @@ requires-dist = [
{ name = "click", specifier = ">=8.2.1" },
{ name = "pydantic", specifier = ">=2.11.7,<3.0.0" },
{ name = "pydantic-settings", specifier = ">=2.10.1,<3.0.0" },
+ { name = "tenacity", specifier = ">=8.2.3,<9.0.0" },
]
[package.metadata.requires-dev]
@@ -857,6 +859,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" },
]
+[[package]]
+name = "tenacity"
+version = "8.5.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a3/4d/6a19536c50b849338fcbe9290d562b52cbdcf30d8963d3588a68a4107df1/tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78", size = 47309, upload-time = "2024-07-05T07:25:31.836Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d2/3f/8ba87d9e287b9d385a02a7114ddcef61b26f86411e121c9003eb509a1773/tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687", size = 28165, upload-time = "2024-07-05T07:25:29.591Z" },
+]
+
[[package]]
name = "tomli"
version = "2.2.1"