Skip to content
Merged

Dev #13

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ python/examples
e2e.sh
TODO.md
.vscode/
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
86 changes: 78 additions & 8 deletions datalab_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
import asyncio
import mimetypes
import aiohttp
from tenacity import (
retry,
retry_if_exception,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from pathlib import Path
from typing import Union, Optional, Dict, Any

Expand Down Expand Up @@ -119,7 +126,7 @@ async def _poll_result(
)

for i in range(max_polls):
data = await self._make_request("GET", full_url)
data = await self._poll_get_with_retry(full_url)

if data.get("status") == "complete":
return data
Expand All @@ -135,6 +142,32 @@ async def _poll_result(
f"Polling timed out after {max_polls * poll_interval} seconds"
)

    @retry(
        retry=(
            # Retry on polling timeouts raised by the SDK itself...
            retry_if_exception_type(DatalabTimeoutError)
            # ...or on API errors that look transient.
            | retry_if_exception(
                lambda e: isinstance(e, DatalabAPIError)
                and (
                    # retry request timeout or too many requests
                    getattr(e, "status_code", None) in (408, 429)
                    or (
                        # or if there's a server error
                        getattr(e, "status_code", None) is not None
                        and getattr(e, "status_code") >= 500
                    )
                    # or datalab api error without status code (e.g., connection errors)
                    or getattr(e, "status_code", None) is None
                )
            )
        ),
        # stop_after_attempt(2) == one original call plus at most ONE retry.
        stop=stop_after_attempt(2),
        # Jittered exponential backoff, capped at 0.5s between attempts.
        wait=wait_exponential_jitter(max=0.5),
        # On exhaustion, re-raise the last real exception rather than tenacity's RetryError.
        reraise=True,
    )
    async def _poll_get_with_retry(self, url: str) -> Dict[str, Any]:
        """GET wrapper for polling with scoped retries for transient failures.

        Args:
            url: Fully-qualified status URL to poll (built by the caller).

        Returns:
            The decoded JSON response from ``self._make_request``.

        Raises:
            DatalabTimeoutError / DatalabAPIError: propagated unchanged once the
                single retry is exhausted (``reraise=True``). Client errors other
                than 408/429 are never retried.
        """
        return await self._make_request("GET", url)

def _prepare_file_data(self, file_path: Union[str, Path]) -> tuple:
"""Prepare file data for upload"""
file_path = Path(file_path)
Expand All @@ -156,7 +189,7 @@ def get_form_params(self, file_path=None, file_url=None, options=None):

if file_url and file_path:
raise ValueError("Either file_path or file_url must be provided, not both.")

# Use either file_url or file upload, not both
if file_url:
form_data.add_field("file_url", file_url)
Expand Down Expand Up @@ -184,21 +217,31 @@ async def convert(
file_url: Optional[str] = None,
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
max_polls: int = 300,
poll_interval: int = 1,
) -> ConversionResult:
"""Convert a document using the marker endpoint"""
if options is None:
options = ConvertOptions()

initial_data = await self._make_request(
"POST", "/api/v1/marker", data=self.get_form_params(file_path=file_path, file_url=file_url, options=options)
"POST",
"/api/v1/marker",
data=self.get_form_params(
file_path=file_path, file_url=file_url, options=options
),
)

if not initial_data.get("success"):
raise DatalabAPIError(
f"Request failed: {initial_data.get('error', 'Unknown error')}"
)

result_data = await self._poll_result(initial_data["request_check_url"])
result_data = await self._poll_result(
initial_data["request_check_url"],
max_polls=max_polls,
poll_interval=poll_interval,
)

result = ConversionResult(
success=result_data.get("success", False),
Expand Down Expand Up @@ -227,21 +270,29 @@ async def ocr(
file_path: Union[str, Path],
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
max_polls: int = 300,
poll_interval: int = 1,
) -> OCRResult:
"""Perform OCR on a document"""
if options is None:
options = OCROptions()

initial_data = await self._make_request(
"POST", "/api/v1/ocr", data=self.get_form_params(file_path=file_path, options=options)
"POST",
"/api/v1/ocr",
data=self.get_form_params(file_path=file_path, options=options),
)

if not initial_data.get("success"):
raise DatalabAPIError(
f"Request failed: {initial_data.get('error', 'Unknown error')}"
)

result_data = await self._poll_result(initial_data["request_check_url"])
result_data = await self._poll_result(
initial_data["request_check_url"],
max_polls=max_polls,
poll_interval=poll_interval,
)

result = OCRResult(
success=result_data.get("success", False),
Expand Down Expand Up @@ -299,17 +350,36 @@ def convert(
file_url: Optional[str] = None,
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
max_polls: int = 300,
poll_interval: int = 1,
) -> ConversionResult:
"""Convert a document using the marker endpoint (sync version)"""
return self._run_async(
self._async_client.convert(file_path=file_path, file_url=file_url, options=options, save_output=save_output)
self._async_client.convert(
file_path=file_path,
file_url=file_url,
options=options,
save_output=save_output,
max_polls=max_polls,
poll_interval=poll_interval,
)
)

def ocr(
self,
file_path: Union[str, Path],
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
max_polls: int = 300,
poll_interval: int = 1,
) -> OCRResult:
"""Perform OCR on a document (sync version)"""
return self._run_async(self._async_client.ocr(file_path, options, save_output))
return self._run_async(
self._async_client.ocr(
file_path=file_path,
options=options,
save_output=save_output,
max_polls=max_polls,
poll_interval=poll_interval,
)
)
9 changes: 7 additions & 2 deletions datalab_sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class ConvertOptions(ProcessingOptions):
block_correction_prompt: Optional[str] = None
additional_config: Optional[Dict[str, Any]] = None
page_schema: Optional[Dict[str, Any]] = None
output_format: str = "markdown" # markdown, json, html
output_format: str = "markdown" # markdown, json, html, chunks
mode: str = "fast" # fast, balanced, accurate


@dataclass
Expand Down Expand Up @@ -91,7 +92,11 @@ def save_output(
json.dump(self.json, f, indent=2)

if self.extraction_schema_json:
with open(output_path.with_suffix("_extraction_results.json"), "w", encoding="utf-8") as f:
with open(
output_path.with_suffix("_extraction_results.json"),
"w",
encoding="utf-8",
) as f:
f.write(self.extraction_schema_json)

# Save images if present
Expand Down
2 changes: 1 addition & 1 deletion datalab_sdk/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class Settings(BaseSettings):
# Paths
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
LOGLEVEL: str = "DEBUG"
VERSION: str = "0.1.3"
VERSION: str = "0.1.6"

# Base settings
DATALAB_API_KEY: str | None = None
Expand Down
8 changes: 8 additions & 0 deletions integration/test_live_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ def test_convert_office_document(self):
assert len(result.html) > 0
assert result.output_format == "html"

def test_convert_pdf_high_accuracy(self):
client = DatalabClient()
pdf_file = DATA_DIR / "adversarial.pdf"
options = ConvertOptions(mode="accurate", max_pages=1)
result = client.convert(pdf_file, options=options)

assert "subspace" in result.markdown.lower()

@pytest.mark.asyncio
async def test_convert_async_with_json(self):
"""Test async conversion with JSON output"""
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ readme = "README.md"
license = "MIT"
repository = "https://github.com/datalab-to/sdk"
keywords = ["datalab", "sdk", "document-intelligence", "api"]
version = "0.1.5"
version = "0.1.6"
description = "SDK for the Datalab document intelligence API"
requires-python = ">=3.10"
dependencies = [
"aiohttp>=3.12.14",
"click>=8.2.1",
"pydantic>=2.11.7,<3.0.0",
"pydantic-settings>=2.10.1,<3.0.0",
"tenacity>=8.2.3,<9.0.0",
]

[project.scripts]
Expand All @@ -27,6 +28,7 @@ test = [
"pytest-mock>=3.11.0",
"pytest-cov>=4.1.0",
"aiofiles>=23.2.0",
"pytest-xdist>=3.8.0",
]

[build-system]
Expand Down
Loading