Skip to content

Commit d89987c

Browse files
committed
feat: allow max polls and poll interval
1 parent 5df99e4 commit d89987c

File tree

2 files changed

+229
-8
lines changed

2 files changed

+229
-8
lines changed

datalab_sdk/client.py

Lines changed: 44 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -156,7 +156,7 @@ def get_form_params(self, file_path=None, file_url=None, options=None):
156156

157157
if file_url and file_path:
158158
raise ValueError("Either file_path or file_url must be provided, not both.")
159-
159+
160160
# Use either file_url or file upload, not both
161161
if file_url:
162162
form_data.add_field("file_url", file_url)
@@ -184,21 +184,31 @@ async def convert(
184184
file_url: Optional[str] = None,
185185
options: Optional[ProcessingOptions] = None,
186186
save_output: Optional[Union[str, Path]] = None,
187+
max_polls: int = 300,
188+
poll_interval: int = 1,
187189
) -> ConversionResult:
188190
"""Convert a document using the marker endpoint"""
189191
if options is None:
190192
options = ConvertOptions()
191193

192194
initial_data = await self._make_request(
193-
"POST", "/api/v1/marker", data=self.get_form_params(file_path=file_path, file_url=file_url, options=options)
195+
"POST",
196+
"/api/v1/marker",
197+
data=self.get_form_params(
198+
file_path=file_path, file_url=file_url, options=options
199+
),
194200
)
195201

196202
if not initial_data.get("success"):
197203
raise DatalabAPIError(
198204
f"Request failed: {initial_data.get('error', 'Unknown error')}"
199205
)
200206

201-
result_data = await self._poll_result(initial_data["request_check_url"])
207+
result_data = await self._poll_result(
208+
initial_data["request_check_url"],
209+
max_polls=max_polls,
210+
poll_interval=poll_interval,
211+
)
202212

203213
result = ConversionResult(
204214
success=result_data.get("success", False),
@@ -227,21 +237,29 @@ async def ocr(
227237
file_path: Union[str, Path],
228238
options: Optional[ProcessingOptions] = None,
229239
save_output: Optional[Union[str, Path]] = None,
240+
max_polls: int = 300,
241+
poll_interval: int = 1,
230242
) -> OCRResult:
231243
"""Perform OCR on a document"""
232244
if options is None:
233245
options = OCROptions()
234246

235247
initial_data = await self._make_request(
236-
"POST", "/api/v1/ocr", data=self.get_form_params(file_path=file_path, options=options)
248+
"POST",
249+
"/api/v1/ocr",
250+
data=self.get_form_params(file_path=file_path, options=options),
237251
)
238252

239253
if not initial_data.get("success"):
240254
raise DatalabAPIError(
241255
f"Request failed: {initial_data.get('error', 'Unknown error')}"
242256
)
243257

244-
result_data = await self._poll_result(initial_data["request_check_url"])
258+
result_data = await self._poll_result(
259+
initial_data["request_check_url"],
260+
max_polls=max_polls,
261+
poll_interval=poll_interval,
262+
)
245263

246264
result = OCRResult(
247265
success=result_data.get("success", False),
@@ -299,17 +317,36 @@ def convert(
299317
file_url: Optional[str] = None,
300318
options: Optional[ProcessingOptions] = None,
301319
save_output: Optional[Union[str, Path]] = None,
320+
max_polls: int = 300,
321+
poll_interval: int = 1,
302322
) -> ConversionResult:
303323
"""Convert a document using the marker endpoint (sync version)"""
304324
return self._run_async(
305-
self._async_client.convert(file_path=file_path, file_url=file_url, options=options, save_output=save_output)
325+
self._async_client.convert(
326+
file_path=file_path,
327+
file_url=file_url,
328+
options=options,
329+
save_output=save_output,
330+
max_polls=max_polls,
331+
poll_interval=poll_interval,
332+
)
306333
)
307334

308335
def ocr(
    self,
    file_path: Union[str, Path],
    options: Optional[ProcessingOptions] = None,
    save_output: Optional[Union[str, Path]] = None,
    max_polls: int = 300,
    poll_interval: int = 1,
) -> OCRResult:
    """Perform OCR on a document (sync version)

    Builds the async client's ``ocr`` coroutine, passing every argument by
    keyword, and hands it to ``self._run_async``; whatever that returns is
    returned to the caller.

    Args:
        file_path: Document to OCR; forwarded unchanged.
        options: Processing options; forwarded unchanged (the async ``ocr``
            substitutes ``OCROptions()`` when this is ``None``).
        save_output: Forwarded unchanged to the async client.
            NOTE(review): presumably a destination for saved results --
            confirm against the async implementation.
        max_polls: Forwarded to the async client, which passes it on to
            its ``_poll_result`` helper.
        poll_interval: Forwarded to the async client, which passes it on
            to its ``_poll_result`` helper.

    Returns:
        The ``OCRResult`` produced by the async client.
    """
    # Keyword arguments keep the forwarding robust against parameter
    # reordering in the async signature.
    pending = self._async_client.ocr(
        file_path=file_path,
        options=options,
        save_output=save_output,
        max_polls=max_polls,
        poll_interval=poll_interval,
    )
    return self._run_async(pending)

tests/test_client_methods.py

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99
from datalab_sdk import DatalabClient, AsyncDatalabClient
1010
from datalab_sdk.models import ConversionResult, OCRResult, ConvertOptions, OCROptions
11-
from datalab_sdk.exceptions import DatalabAPIError, DatalabFileError
11+
from datalab_sdk.exceptions import (
12+
DatalabAPIError,
13+
DatalabFileError,
14+
DatalabTimeoutError,
15+
)
1216

1317

1418
class TestConvertMethod:
@@ -169,6 +173,50 @@ def test_convert_sync_with_processing_options(self, temp_dir):
169173
assert result.html == "<h1>Test Document</h1>"
170174
assert result.output_format == "html"
171175

176+
@pytest.mark.asyncio
async def test_convert_async_respects_polling_params(self, temp_dir):
    """Check that ``convert`` hands its polling settings to ``_poll_result``.

    ``_make_request`` and ``_poll_result`` are both replaced with
    ``AsyncMock`` objects, so the test does not depend on real API
    responses; it only inspects the arguments the poller was awaited with.
    """
    # Minimal PDF-looking payload on disk.
    sample_pdf = temp_dir / "test.pdf"
    sample_pdf.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")

    check_url = "https://api.datalab.to/api/v1/marker/rid-1"
    submit_reply = {
        "success": True,
        "request_id": "rid-1",
        "request_check_url": check_url,
    }
    finished_reply = {
        "success": True,
        "status": "complete",
        "output_format": "markdown",
        "markdown": "ok",
    }
    # Deliberately different from the signature defaults (300 / 1) so that
    # forwarding is observable.
    wanted_polls, wanted_interval = 7, 3

    async with AsyncDatalabClient(api_key="test-key") as client:
        with (
            patch.object(
                client, "_make_request", new_callable=AsyncMock
            ) as submit_mock,
            patch.object(
                client, "_poll_result", new_callable=AsyncMock
            ) as poll_mock,
        ):
            submit_mock.return_value = submit_reply
            poll_mock.return_value = finished_reply

            await client.convert(
                sample_pdf, max_polls=wanted_polls, poll_interval=wanted_interval
            )

            poll_mock.assert_awaited_once()
            # The check URL travels positionally, the polling knobs by keyword.
            positional, keyword = poll_mock.await_args
            assert positional[0] == check_url
            assert keyword["max_polls"] == wanted_polls
            assert keyword["poll_interval"] == wanted_interval
219+
172220

173221
class TestOCRMethod:
174222
"""Test the ocr method"""
@@ -356,6 +404,77 @@ def test_ocr_sync_with_max_pages(self, temp_dir):
356404
assert "Page 1 content" in all_text
357405
assert "Page 2 content" in all_text
358406

407+
@pytest.mark.asyncio
async def test_ocr_async_respects_polling_params(self, temp_dir):
    """Check that ``ocr`` hands its polling settings to ``_poll_result``.

    Both ``_make_request`` and ``_poll_result`` are ``AsyncMock``
    stand-ins; only the arguments of the single poller await are checked.
    """
    sample_pdf = temp_dir / "test.pdf"
    sample_pdf.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")

    check_url = "https://api.datalab.to/api/v1/ocr/rid-2"
    submit_reply = {
        "success": True,
        "request_id": "rid-2",
        "request_check_url": check_url,
    }
    finished_reply = {
        "success": True,
        "status": "complete",
        "pages": [],
    }
    # Non-default values (the signature defaults are 300 / 1).
    wanted_polls, wanted_interval = 11, 2

    async with AsyncDatalabClient(api_key="test-key") as client:
        with (
            patch.object(
                client, "_make_request", new_callable=AsyncMock
            ) as submit_mock,
            patch.object(
                client, "_poll_result", new_callable=AsyncMock
            ) as poll_mock,
        ):
            submit_mock.return_value = submit_reply
            poll_mock.return_value = finished_reply

            await client.ocr(
                sample_pdf, max_polls=wanted_polls, poll_interval=wanted_interval
            )

            poll_mock.assert_awaited_once()
            positional, keyword = poll_mock.await_args
            assert positional[0] == check_url
            assert keyword["max_polls"] == wanted_polls
            assert keyword["poll_interval"] == wanted_interval
446+
447+
def test_sync_wrappers_forward_polling_params(self, temp_dir):
    """Ensure sync client forwards polling params to async client

    The async client's ``convert`` and ``ocr`` are replaced with
    ``AsyncMock`` objects; the sync wrappers are then called with
    non-default ``max_polls`` / ``poll_interval`` values and the keyword
    arguments received by each mock are checked.
    """
    pdf_file = temp_dir / "test.pdf"
    pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")

    client = DatalabClient(api_key="test-key")

    # Patch async convert/ocr to capture kwargs
    with (
        patch.object(
            client._async_client, "convert", new_callable=AsyncMock
        ) as mock_conv,
        patch.object(
            client._async_client, "ocr", new_callable=AsyncMock
        ) as mock_ocr,
    ):
        mock_conv.return_value = ConversionResult(
            success=True, output_format="markdown", markdown="ok"
        )
        mock_ocr.return_value = OCRResult(success=True, pages=[])

        client.convert(pdf_file, max_polls=5, poll_interval=9)
        client.ocr(pdf_file, max_polls=6, poll_interval=4)

        # Fail with a clear message if a wrapper never awaited its async
        # counterpart; otherwise ``await_args`` is None and the unpacking
        # below would die with an unrelated-looking TypeError.
        mock_conv.assert_awaited_once()
        mock_ocr.assert_awaited_once()

        # Assert called with forwarded kwargs
        _, conv_kwargs = mock_conv.await_args
        assert conv_kwargs["max_polls"] == 5
        assert conv_kwargs["poll_interval"] == 9

        _, ocr_kwargs = mock_ocr.await_args
        assert ocr_kwargs["max_polls"] == 6
        assert ocr_kwargs["poll_interval"] == 4
477+
359478

360479
class TestClientErrorHandling:
361480
"""Test error handling in client methods"""
@@ -416,3 +535,68 @@ def test_convert_unsuccessful_response(self, temp_dir):
416535
DatalabAPIError, match="Request failed: Processing failed"
417536
):
418537
client.convert(pdf_file)
538+
539+
def test_convert_timeout_bubbles_up(self, temp_dir):
    """A ``DatalabTimeoutError`` raised while polling reaches the sync caller.

    The submit request is mocked to succeed and ``_poll_result`` is mocked
    to raise; the sync ``convert`` must re-raise that same error type with
    its message intact.
    """
    sample_pdf = temp_dir / "test.pdf"
    sample_pdf.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n")

    submit_reply = {
        "success": True,
        "request_id": "rid-timeout",
        "request_check_url": "https://api.datalab.to/api/v1/marker/rid-timeout",
    }

    sync_client = DatalabClient(api_key="test-key")
    # The sync client delegates to this wrapped async client, so the mocks
    # are installed there.
    inner = sync_client._async_client
    with (
        patch.object(inner, "_make_request", new_callable=AsyncMock) as submit_mock,
        patch.object(inner, "_poll_result", new_callable=AsyncMock) as poll_mock,
    ):
        submit_mock.return_value = submit_reply
        poll_mock.side_effect = DatalabTimeoutError("Polling timed out")

        with pytest.raises(DatalabTimeoutError, match="Polling timed out"):
            sync_client.convert(sample_pdf)
562+
563+
564+
class TestPollingLoop:
    """Direct tests for the internal polling helper"""

    @pytest.mark.asyncio
    async def test_poll_result_times_out(self):
        """``_poll_result`` raises ``DatalabTimeoutError`` once polls run out.

        Every status check reports ``processing``; with ``max_polls=3`` the
        helper is expected to issue exactly three requests, sleep at least
        once between them, and then raise.
        """
        still_running = {"status": "processing", "success": True}

        async with AsyncDatalabClient(api_key="test-key") as client:
            with (
                patch.object(
                    client, "_make_request", new_callable=AsyncMock
                ) as status_mock,
                # Patched so the loop never actually waits.
                patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock,
            ):
                # Never completes, which forces the timeout path.
                status_mock.return_value = still_running

                with pytest.raises(DatalabTimeoutError):
                    await client._poll_result(
                        "https://api.example.com/check", max_polls=3, poll_interval=0
                    )

                assert status_mock.await_count == 3
                assert sleep_mock.await_count >= 1

    @pytest.mark.asyncio
    async def test_poll_result_raises_on_failed_status(self):
        """A ``failed`` status is surfaced as ``DatalabAPIError`` with its error text."""
        failure_reply = {
            "status": "failed",
            "success": False,
            "error": "boom",
        }

        async with AsyncDatalabClient(api_key="test-key") as client:
            with patch.object(
                client, "_make_request", new_callable=AsyncMock
            ) as status_mock:
                status_mock.return_value = failure_reply

                with pytest.raises(DatalabAPIError, match="Processing failed: boom"):
                    await client._poll_result(
                        "https://api.example.com/check", max_polls=1, poll_interval=0
                    )

0 commit comments

Comments
 (0)