diff --git a/.gitignore b/.gitignore index 9590e2b..42c4f52 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ python/examples e2e.sh TODO.md .vscode/ +.DS_Store # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/datalab_sdk/client.py b/datalab_sdk/client.py index de262f7..fe22228 100644 --- a/datalab_sdk/client.py +++ b/datalab_sdk/client.py @@ -5,6 +5,13 @@ import asyncio import mimetypes import aiohttp +from tenacity import ( + retry, + retry_if_exception, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) from pathlib import Path from typing import Union, Optional, Dict, Any @@ -119,7 +126,7 @@ async def _poll_result( ) for i in range(max_polls): - data = await self._make_request("GET", full_url) + data = await self._poll_get_with_retry(full_url) if data.get("status") == "complete": return data @@ -135,6 +142,32 @@ async def _poll_result( f"Polling timed out after {max_polls * poll_interval} seconds" ) + @retry( + retry=( + retry_if_exception_type(DatalabTimeoutError) + | retry_if_exception( + lambda e: isinstance(e, DatalabAPIError) + and ( + # retry request timeout or too many requests + getattr(e, "status_code", None) in (408, 429) + or ( + # or if there's a server error + getattr(e, "status_code", None) is not None + and getattr(e, "status_code") >= 500 + ) + # or datalab api error without status code (e.g., connection errors) + or getattr(e, "status_code", None) is None + ) + ) + ), + stop=stop_after_attempt(2), + wait=wait_exponential_jitter(max=0.5), + reraise=True, + ) + async def _poll_get_with_retry(self, url: str) -> Dict[str, Any]: + """GET wrapper for polling with scoped retries for transient failures""" + return await self._make_request("GET", url) + def _prepare_file_data(self, file_path: Union[str, Path]) -> tuple: """Prepare file data for upload""" file_path = Path(file_path) @@ -156,7 +189,7 @@ def get_form_params(self, file_path=None, file_url=None, options=None): if file_url and file_path: raise ValueError("Either file_path or file_url must be provided, not both.") - + # Use either file_url or file upload, not both if file_url: form_data.add_field("file_url", file_url) @@ -184,13 +217,19 @@ async def convert( file_url: Optional[str] = None, options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> ConversionResult: """Convert a document using the marker endpoint""" if options is None: options = ConvertOptions() initial_data = await self._make_request( - "POST", "/api/v1/marker", data=self.get_form_params(file_path=file_path, file_url=file_url, options=options) + "POST", + "/api/v1/marker", + data=self.get_form_params( + file_path=file_path, file_url=file_url, options=options + ), ) if not initial_data.get("success"): @@ -198,7 +237,11 @@ async def convert( f"Request failed: {initial_data.get('error', 'Unknown error')}" ) - result_data = await self._poll_result(initial_data["request_check_url"]) + result_data = await self._poll_result( + initial_data["request_check_url"], + max_polls=max_polls, + poll_interval=poll_interval, + ) result = ConversionResult( success=result_data.get("success", False), @@ -227,13 +270,17 @@ async def ocr( file_path: Union[str, Path], options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> OCRResult: """Perform OCR on a document""" if options is None: options = OCROptions() initial_data = await self._make_request( - "POST", "/api/v1/ocr", data=self.get_form_params(file_path=file_path, options=options) + "POST", + "/api/v1/ocr", + data=self.get_form_params(file_path=file_path, options=options), ) if not initial_data.get("success"): @@ -241,7 +288,11 @@ async def ocr( f"Request failed: {initial_data.get('error', 'Unknown error')}" ) - result_data = await self._poll_result(initial_data["request_check_url"]) + result_data = await self._poll_result( + initial_data["request_check_url"], + max_polls=max_polls, + poll_interval=poll_interval, + ) result = OCRResult( success=result_data.get("success", False), @@ -299,10 +350,19 @@ def convert( file_url: Optional[str] = None, options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> ConversionResult: """Convert a document using the marker endpoint (sync version)""" return self._run_async( - self._async_client.convert(file_path=file_path, file_url=file_url, options=options, save_output=save_output) + self._async_client.convert( + file_path=file_path, + file_url=file_url, + options=options, + save_output=save_output, + max_polls=max_polls, + poll_interval=poll_interval, + ) ) def ocr( @@ -310,6 +370,16 @@ def ocr( file_path: Union[str, Path], options: Optional[ProcessingOptions] = None, save_output: Optional[Union[str, Path]] = None, + max_polls: int = 300, + poll_interval: int = 1, ) -> OCRResult: """Perform OCR on a document (sync version)""" - return self._run_async(self._async_client.ocr(file_path, options, save_output)) + return self._run_async( + self._async_client.ocr( + file_path=file_path, + options=options, + save_output=save_output, + max_polls=max_polls, + poll_interval=poll_interval, + ) + ) diff --git a/datalab_sdk/models.py b/datalab_sdk/models.py index 0e0d486..07c79bc 100644 --- a/datalab_sdk/models.py +++ b/datalab_sdk/models.py @@ -47,7 +47,8 @@ class ConvertOptions(ProcessingOptions): block_correction_prompt: Optional[str] = None additional_config: Optional[Dict[str, Any]] = None page_schema: Optional[Dict[str, Any]] = None - output_format: str = "markdown" # markdown, json, html + output_format: str = "markdown" # markdown, json, html, chunks + mode: str = "fast" # fast, balanced, accurate @dataclass @@ -91,7 +92,11 @@ def save_output( json.dump(self.json, f, indent=2) if self.extraction_schema_json: - with open(output_path.with_suffix("_extraction_results.json"), "w", encoding="utf-8") as f: + with open( + output_path.with_suffix("_extraction_results.json"), + "w", + encoding="utf-8", + ) as f: f.write(self.extraction_schema_json) # Save images if present diff --git a/datalab_sdk/settings.py b/datalab_sdk/settings.py index 3c2fb86..605a008 100644 --- a/datalab_sdk/settings.py +++ b/datalab_sdk/settings.py @@ -6,7 +6,7 @@ class Settings(BaseSettings): # Paths BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) LOGLEVEL: str = "DEBUG" - VERSION: str = "0.1.3" + VERSION: str = "0.1.6" # Base settings DATALAB_API_KEY: str | None = None diff --git a/integration/test_live_api.py b/integration/test_live_api.py index 19c9a40..45970b2 100644 --- a/integration/test_live_api.py +++ b/integration/test_live_api.py @@ -62,6 +62,14 @@ def test_convert_office_document(self): assert len(result.html) > 0 assert result.output_format == "html" + def test_convert_pdf_high_accuracy(self): + client = DatalabClient() + pdf_file = DATA_DIR / "adversarial.pdf" + options = ConvertOptions(mode="accurate", max_pages=1) + result = client.convert(pdf_file, options=options) + + assert "subspace" in result.markdown.lower() + @pytest.mark.asyncio async def test_convert_async_with_json(self): """Test async conversion with JSON output""" diff --git a/pyproject.toml b/pyproject.toml index fb63155..e1ce78b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ readme = "README.md" license = "MIT" repository = "https://github.com/datalab-to/sdk" keywords = ["datalab", "sdk", "document-intelligence", "api"] -version = "0.1.5" +version = "0.1.6" description = "SDK for the Datalab document intelligence API" requires-python = ">=3.10" dependencies = [ @@ -15,6 +15,7 @@ dependencies = [ "click>=8.2.1", "pydantic>=2.11.7,<3.0.0", "pydantic-settings>=2.10.1,<3.0.0", + "tenacity>=8.2.3,<9.0.0", ] [project.scripts] @@ -27,6 +28,7 @@ test = [ "pytest-mock>=3.11.0", "pytest-cov>=4.1.0", "aiofiles>=23.2.0", + "pytest-xdist>=3.8.0", ] [build-system] diff --git a/tests/test_client_methods.py b/tests/test_client_methods.py index 9b62d79..65c65f9 100644 --- a/tests/test_client_methods.py +++ b/tests/test_client_methods.py @@ -8,7 +8,11 @@ from datalab_sdk import DatalabClient, AsyncDatalabClient from datalab_sdk.models import ConversionResult, OCRResult, ConvertOptions, OCROptions -from datalab_sdk.exceptions import DatalabAPIError, DatalabFileError +from datalab_sdk.exceptions import ( + DatalabAPIError, + DatalabFileError, + DatalabTimeoutError, +) class TestConvertMethod: @@ -169,6 +173,50 @@ def test_convert_sync_with_processing_options(self, temp_dir): assert result.html == "