Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions datalab_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_form_params(self, file_path=None, file_url=None, options=None):

if file_url and file_path:
raise ValueError("Either file_path or file_url must be provided, not both.")

# Use either file_url or file upload, not both
if file_url:
form_data.add_field("file_url", file_url)
Expand Down Expand Up @@ -184,21 +184,31 @@ async def convert(
file_url: Optional[str] = None,
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
max_polls: int = 300,
poll_interval: int = 1,
) -> ConversionResult:
"""Convert a document using the marker endpoint"""
if options is None:
options = ConvertOptions()

initial_data = await self._make_request(
"POST", "/api/v1/marker", data=self.get_form_params(file_path=file_path, file_url=file_url, options=options)
"POST",
"/api/v1/marker",
data=self.get_form_params(
file_path=file_path, file_url=file_url, options=options
),
)

if not initial_data.get("success"):
raise DatalabAPIError(
f"Request failed: {initial_data.get('error', 'Unknown error')}"
)

result_data = await self._poll_result(initial_data["request_check_url"])
result_data = await self._poll_result(
initial_data["request_check_url"],
max_polls=max_polls,
poll_interval=poll_interval,
)

result = ConversionResult(
success=result_data.get("success", False),
Expand Down Expand Up @@ -227,21 +237,29 @@ async def ocr(
file_path: Union[str, Path],
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
max_polls: int = 300,
poll_interval: int = 1,
) -> OCRResult:
"""Perform OCR on a document"""
if options is None:
options = OCROptions()

initial_data = await self._make_request(
"POST", "/api/v1/ocr", data=self.get_form_params(file_path=file_path, options=options)
"POST",
"/api/v1/ocr",
data=self.get_form_params(file_path=file_path, options=options),
)

if not initial_data.get("success"):
raise DatalabAPIError(
f"Request failed: {initial_data.get('error', 'Unknown error')}"
)

result_data = await self._poll_result(initial_data["request_check_url"])
result_data = await self._poll_result(
initial_data["request_check_url"],
max_polls=max_polls,
poll_interval=poll_interval,
)

result = OCRResult(
success=result_data.get("success", False),
Expand Down Expand Up @@ -299,17 +317,36 @@ def convert(
file_url: Optional[str] = None,
options: Optional[ProcessingOptions] = None,
save_output: Optional[Union[str, Path]] = None,
max_polls: int = 300,
poll_interval: int = 1,
) -> ConversionResult:
"""Convert a document using the marker endpoint (sync version)"""
return self._run_async(
self._async_client.convert(file_path=file_path, file_url=file_url, options=options, save_output=save_output)
self._async_client.convert(
file_path=file_path,
file_url=file_url,
options=options,
save_output=save_output,
max_polls=max_polls,
poll_interval=poll_interval,
)
)

def ocr(
    self,
    file_path: Union[str, Path],
    options: Optional[ProcessingOptions] = None,
    save_output: Optional[Union[str, Path]] = None,
    max_polls: int = 300,
    poll_interval: int = 1,
) -> OCRResult:
    """Perform OCR on a document (sync version).

    Builds the async client's ``ocr`` call and hands it to
    ``self._run_async``, returning whatever that produces.

    Args:
        file_path: Document to process; passed through unchanged.
        options: Processing options; passed through unchanged (``None`` is
            forwarded as-is).
        save_output: Output location; passed through unchanged.
        max_polls: Polling limit forwarded to the async client.
        poll_interval: Delay between polls forwarded to the async client
            (presumably seconds -- confirm against ``_poll_result``).

    Returns:
        The result of running the async ``ocr`` call via ``_run_async``.
    """
    # Single return that forwards every argument by keyword. The previous
    # positional call omitted max_polls/poll_interval, so caller-supplied
    # polling settings were never passed to the async client.
    return self._run_async(
        self._async_client.ocr(
            file_path=file_path,
            options=options,
            save_output=save_output,
            max_polls=max_polls,
            poll_interval=poll_interval,
        )
    )
186 changes: 185 additions & 1 deletion tests/test_client_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

from datalab_sdk import DatalabClient, AsyncDatalabClient
from datalab_sdk.models import ConversionResult, OCRResult, ConvertOptions, OCROptions
from datalab_sdk.exceptions import DatalabAPIError, DatalabFileError
from datalab_sdk.exceptions import (
DatalabAPIError,
DatalabFileError,
DatalabTimeoutError,
)


class TestConvertMethod:
Expand Down Expand Up @@ -169,6 +173,50 @@ def test_convert_sync_with_processing_options(self, temp_dir):
assert result.html == "<h1>Test Document</h1>"
assert result.output_format == "html"

@pytest.mark.asyncio
async def test_convert_async_respects_polling_params(self, temp_dir):
    """Check that convert hands max_polls and poll_interval to the poller."""
    # Small PDF-looking file on disk; the request and polling helpers are
    # both patched below.
    sample_pdf = temp_dir / "test.pdf"
    sample_pdf.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")

    check_url = "https://api.datalab.to/api/v1/marker/rid-1"
    submit_payload = {
        "success": True,
        "request_id": "rid-1",
        "request_check_url": check_url,
    }
    final_payload = {
        "success": True,
        "status": "complete",
        "output_format": "markdown",
        "markdown": "ok",
    }
    expected_polls, expected_interval = 7, 3

    async with AsyncDatalabClient(api_key="test-key") as client:
        with (
            patch.object(
                client, "_make_request", new_callable=AsyncMock
            ) as submit_mock,
            patch.object(
                client, "_poll_result", new_callable=AsyncMock
            ) as poll_mock,
        ):
            submit_mock.return_value = submit_payload
            poll_mock.return_value = final_payload

            await client.convert(
                sample_pdf,
                max_polls=expected_polls,
                poll_interval=expected_interval,
            )

            # Exactly one poll call, carrying the check URL positionally and
            # the polling settings as keywords.
            poll_mock.assert_awaited_once()
            positional, keyword = poll_mock.await_args
            assert positional[0] == check_url
            assert keyword["max_polls"] == expected_polls
            assert keyword["poll_interval"] == expected_interval


class TestOCRMethod:
"""Test the ocr method"""
Expand Down Expand Up @@ -356,6 +404,77 @@ def test_ocr_sync_with_max_pages(self, temp_dir):
assert "Page 1 content" in all_text
assert "Page 2 content" in all_text

@pytest.mark.asyncio
async def test_ocr_async_respects_polling_params(self, temp_dir):
    """Check that ocr hands max_polls and poll_interval to the poller."""
    sample_pdf = temp_dir / "test.pdf"
    sample_pdf.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")

    check_url = "https://api.datalab.to/api/v1/ocr/rid-2"
    submit_payload = {
        "success": True,
        "request_id": "rid-2",
        "request_check_url": check_url,
    }
    final_payload = {
        "success": True,
        "status": "complete",
        "pages": [],
    }
    expected_polls, expected_interval = 11, 2

    async with AsyncDatalabClient(api_key="test-key") as client:
        with (
            patch.object(
                client, "_make_request", new_callable=AsyncMock
            ) as submit_mock,
            patch.object(
                client, "_poll_result", new_callable=AsyncMock
            ) as poll_mock,
        ):
            submit_mock.return_value = submit_payload
            poll_mock.return_value = final_payload

            await client.ocr(
                sample_pdf,
                max_polls=expected_polls,
                poll_interval=expected_interval,
            )

            # One poll call: URL positional, polling settings as keywords.
            poll_mock.assert_awaited_once()
            positional, keyword = poll_mock.await_args
            assert positional[0] == check_url
            assert keyword["max_polls"] == expected_polls
            assert keyword["poll_interval"] == expected_interval

def test_sync_wrappers_forward_polling_params(self, temp_dir):
    """Sync convert/ocr must pass polling settings on to the async client."""
    sample_pdf = temp_dir / "test.pdf"
    sample_pdf.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")

    sync_client = DatalabClient(api_key="test-key")
    async_side = sync_client._async_client

    # Replace both async entry points so the keyword arguments they receive
    # can be inspected afterwards.
    with (
        patch.object(async_side, "convert", new_callable=AsyncMock) as convert_mock,
        patch.object(async_side, "ocr", new_callable=AsyncMock) as ocr_mock,
    ):
        convert_mock.return_value = ConversionResult(
            success=True, output_format="markdown", markdown="ok"
        )
        ocr_mock.return_value = OCRResult(success=True, pages=[])

        sync_client.convert(sample_pdf, max_polls=5, poll_interval=9)
        sync_client.ocr(sample_pdf, max_polls=6, poll_interval=4)

        # Each wrapper should have forwarded its own pair of values.
        _, convert_kw = convert_mock.await_args
        assert convert_kw["max_polls"] == 5
        assert convert_kw["poll_interval"] == 9

        _, ocr_kw = ocr_mock.await_args
        assert ocr_kw["max_polls"] == 6
        assert ocr_kw["poll_interval"] == 4


class TestClientErrorHandling:
"""Test error handling in client methods"""
Expand Down Expand Up @@ -416,3 +535,68 @@ def test_convert_unsuccessful_response(self, temp_dir):
DatalabAPIError, match="Request failed: Processing failed"
):
client.convert(pdf_file)

def test_convert_timeout_bubbles_up(self, temp_dir):
    """A DatalabTimeoutError raised while polling reaches the sync caller."""
    sample_pdf = temp_dir / "test.pdf"
    sample_pdf.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")

    submit_payload = {
        "success": True,
        "request_id": "rid-timeout",
        "request_check_url": "https://api.datalab.to/api/v1/marker/rid-timeout",
    }

    sync_client = DatalabClient(api_key="test-key")
    async_side = sync_client._async_client
    with (
        patch.object(
            async_side, "_make_request", new_callable=AsyncMock
        ) as submit_mock,
        patch.object(
            async_side, "_poll_result", new_callable=AsyncMock
        ) as poll_mock,
    ):
        submit_mock.return_value = submit_payload
        # The submit step succeeds; the poller is what raises.
        poll_mock.side_effect = DatalabTimeoutError("Polling timed out")

        with pytest.raises(DatalabTimeoutError, match="Polling timed out"):
            sync_client.convert(sample_pdf)


class TestPollingLoop:
    """Direct tests for the internal polling helper"""

    @pytest.mark.asyncio
    async def test_poll_result_times_out(self):
        """A status that never leaves "processing" ends in DatalabTimeoutError."""
        still_running = {"status": "processing", "success": True}

        async with AsyncDatalabClient(api_key="test-key") as client:
            with patch.object(
                client, "_make_request", new_callable=AsyncMock
            ) as request_mock:
                with patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock:
                    # Every poll reports "processing", so the poll budget
                    # runs out.
                    request_mock.return_value = still_running

                    with pytest.raises(DatalabTimeoutError):
                        await client._poll_result(
                            "https://api.example.com/check",
                            max_polls=3,
                            poll_interval=0,
                        )

                    assert request_mock.await_count == 3
                    assert sleep_mock.await_count >= 1

    @pytest.mark.asyncio
    async def test_poll_result_raises_on_failed_status(self):
        """A "failed" status surfaces as DatalabAPIError carrying the error text."""
        failed_payload = {
            "status": "failed",
            "success": False,
            "error": "boom",
        }

        async with AsyncDatalabClient(api_key="test-key") as client:
            with patch.object(
                client, "_make_request", new_callable=AsyncMock
            ) as request_mock:
                request_mock.return_value = failed_payload

                with pytest.raises(DatalabAPIError, match="Processing failed: boom"):
                    await client._poll_result(
                        "https://api.example.com/check",
                        max_polls=1,
                        poll_interval=0,
                    )