|
8 | 8 |
|
9 | 9 | from datalab_sdk import DatalabClient, AsyncDatalabClient |
10 | 10 | from datalab_sdk.models import ConversionResult, OCRResult, ConvertOptions, OCROptions |
11 | | -from datalab_sdk.exceptions import DatalabAPIError, DatalabFileError |
| 11 | +from datalab_sdk.exceptions import ( |
| 12 | + DatalabAPIError, |
| 13 | + DatalabFileError, |
| 14 | + DatalabTimeoutError, |
| 15 | +) |
12 | 16 |
|
13 | 17 |
|
14 | 18 | class TestConvertMethod: |
@@ -169,6 +173,50 @@ def test_convert_sync_with_processing_options(self, temp_dir): |
169 | 173 | assert result.html == "<h1>Test Document</h1>" |
170 | 174 | assert result.output_format == "html" |
171 | 175 |
|
| 176 | + @pytest.mark.asyncio |
| 177 | + async def test_convert_async_respects_polling_params(self, temp_dir): |
| 178 | + """Verify convert passes max_polls and poll_interval to poller""" |
| 179 | + # Create test file |
| 180 | + pdf_file = temp_dir / "test.pdf" |
| 181 | + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") |
| 182 | + |
| 183 | + # Mock API responses |
| 184 | + mock_initial_response = { |
| 185 | + "success": True, |
| 186 | + "request_id": "rid-1", |
| 187 | + "request_check_url": "https://api.datalab.to/api/v1/marker/rid-1", |
| 188 | + } |
| 189 | + |
| 190 | + mock_result_response = { |
| 191 | + "success": True, |
| 192 | + "status": "complete", |
| 193 | + "output_format": "markdown", |
| 194 | + "markdown": "ok", |
| 195 | + } |
| 196 | + |
| 197 | + async with AsyncDatalabClient(api_key="test-key") as client: |
| 198 | + with patch.object( |
| 199 | + client, "_make_request", new_callable=AsyncMock |
| 200 | + ) as mock_req: |
| 201 | + with patch.object( |
| 202 | + client, "_poll_result", new_callable=AsyncMock |
| 203 | + ) as mock_poll: |
| 204 | + mock_req.return_value = mock_initial_response |
| 205 | + mock_poll.return_value = mock_result_response |
| 206 | + |
| 207 | + max_polls = 7 |
| 208 | + poll_interval = 3 |
| 209 | + await client.convert( |
| 210 | + pdf_file, max_polls=max_polls, poll_interval=poll_interval |
| 211 | + ) |
| 212 | + |
| 213 | + mock_poll.assert_awaited_once() |
| 214 | + # Verify kwargs were forwarded |
| 215 | + args, kwargs = mock_poll.await_args |
| 216 | + assert args[0] == mock_initial_response["request_check_url"] |
| 217 | + assert kwargs["max_polls"] == max_polls |
| 218 | + assert kwargs["poll_interval"] == poll_interval |
| 219 | + |
172 | 220 |
|
173 | 221 | class TestOCRMethod: |
174 | 222 | """Test the ocr method""" |
@@ -356,6 +404,77 @@ def test_ocr_sync_with_max_pages(self, temp_dir): |
356 | 404 | assert "Page 1 content" in all_text |
357 | 405 | assert "Page 2 content" in all_text |
358 | 406 |
|
| 407 | + @pytest.mark.asyncio |
| 408 | + async def test_ocr_async_respects_polling_params(self, temp_dir): |
| 409 | + """Verify ocr passes max_polls and poll_interval to poller""" |
| 410 | + pdf_file = temp_dir / "test.pdf" |
| 411 | + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") |
| 412 | + |
| 413 | + mock_initial_response = { |
| 414 | + "success": True, |
| 415 | + "request_id": "rid-2", |
| 416 | + "request_check_url": "https://api.datalab.to/api/v1/ocr/rid-2", |
| 417 | + } |
| 418 | + |
| 419 | + mock_result_response = { |
| 420 | + "success": True, |
| 421 | + "status": "complete", |
| 422 | + "pages": [], |
| 423 | + } |
| 424 | + |
| 425 | + async with AsyncDatalabClient(api_key="test-key") as client: |
| 426 | + with patch.object( |
| 427 | + client, "_make_request", new_callable=AsyncMock |
| 428 | + ) as mock_req: |
| 429 | + with patch.object( |
| 430 | + client, "_poll_result", new_callable=AsyncMock |
| 431 | + ) as mock_poll: |
| 432 | + mock_req.return_value = mock_initial_response |
| 433 | + mock_poll.return_value = mock_result_response |
| 434 | + |
| 435 | + max_polls = 11 |
| 436 | + poll_interval = 2 |
| 437 | + await client.ocr( |
| 438 | + pdf_file, max_polls=max_polls, poll_interval=poll_interval |
| 439 | + ) |
| 440 | + |
| 441 | + mock_poll.assert_awaited_once() |
| 442 | + args, kwargs = mock_poll.await_args |
| 443 | + assert args[0] == mock_initial_response["request_check_url"] |
| 444 | + assert kwargs["max_polls"] == max_polls |
| 445 | + assert kwargs["poll_interval"] == poll_interval |
| 446 | + |
| 447 | + def test_sync_wrappers_forward_polling_params(self, temp_dir): |
| 448 | + """Ensure sync client forwards polling params to async client""" |
| 449 | + pdf_file = temp_dir / "test.pdf" |
| 450 | + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") |
| 451 | + |
| 452 | + client = DatalabClient(api_key="test-key") |
| 453 | + |
| 454 | + # Patch async convert/ocr to capture kwargs |
| 455 | + with patch.object( |
| 456 | + client._async_client, "convert", new_callable=AsyncMock |
| 457 | + ) as mock_conv: |
| 458 | + with patch.object( |
| 459 | + client._async_client, "ocr", new_callable=AsyncMock |
| 460 | + ) as mock_ocr: |
| 461 | + mock_conv.return_value = ConversionResult( |
| 462 | + success=True, output_format="markdown", markdown="ok" |
| 463 | + ) |
| 464 | + mock_ocr.return_value = OCRResult(success=True, pages=[]) |
| 465 | + |
| 466 | + client.convert(pdf_file, max_polls=5, poll_interval=9) |
| 467 | + client.ocr(pdf_file, max_polls=6, poll_interval=4) |
| 468 | + |
| 469 | + # Assert called with forwarded kwargs |
| 470 | + _, conv_kwargs = mock_conv.await_args |
| 471 | + assert conv_kwargs["max_polls"] == 5 |
| 472 | + assert conv_kwargs["poll_interval"] == 9 |
| 473 | + |
| 474 | + _, ocr_kwargs = mock_ocr.await_args |
| 475 | + assert ocr_kwargs["max_polls"] == 6 |
| 476 | + assert ocr_kwargs["poll_interval"] == 4 |
| 477 | + |
359 | 478 |
|
360 | 479 | class TestClientErrorHandling: |
361 | 480 | """Test error handling in client methods""" |
@@ -416,3 +535,68 @@ def test_convert_unsuccessful_response(self, temp_dir): |
416 | 535 | DatalabAPIError, match="Request failed: Processing failed" |
417 | 536 | ): |
418 | 537 | client.convert(pdf_file) |
| 538 | + |
| 539 | + def test_convert_timeout_bubbles_up(self, temp_dir): |
| 540 | + """Polling timeout surfaces as DatalabTimeoutError for sync convert""" |
| 541 | + pdf_file = temp_dir / "test.pdf" |
| 542 | + pdf_file.write_bytes(b"%PDF-1.4\n%Test PDF content\n%%EOF\n") |
| 543 | + |
| 544 | + mock_initial_response = { |
| 545 | + "success": True, |
| 546 | + "request_id": "rid-timeout", |
| 547 | + "request_check_url": "https://api.datalab.to/api/v1/marker/rid-timeout", |
| 548 | + } |
| 549 | + |
| 550 | + client = DatalabClient(api_key="test-key") |
| 551 | + with patch.object( |
| 552 | + client._async_client, "_make_request", new_callable=AsyncMock |
| 553 | + ) as mock_request: |
| 554 | + with patch.object( |
| 555 | + client._async_client, "_poll_result", new_callable=AsyncMock |
| 556 | + ) as mock_poll: |
| 557 | + mock_request.return_value = mock_initial_response |
| 558 | + mock_poll.side_effect = DatalabTimeoutError("Polling timed out") |
| 559 | + |
| 560 | + with pytest.raises(DatalabTimeoutError, match="Polling timed out"): |
| 561 | + client.convert(pdf_file) |
| 562 | + |
| 563 | + |
| 564 | +class TestPollingLoop: |
| 565 | + """Direct tests for the internal polling helper""" |
| 566 | + |
| 567 | + @pytest.mark.asyncio |
| 568 | + async def test_poll_result_times_out(self): |
| 569 | + async with AsyncDatalabClient(api_key="test-key") as client: |
| 570 | + with ( |
| 571 | + patch.object( |
| 572 | + client, "_make_request", new_callable=AsyncMock |
| 573 | + ) as mock_req, |
| 574 | + patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep, |
| 575 | + ): |
| 576 | + # Always return processing so we hit timeout |
| 577 | + mock_req.return_value = {"status": "processing", "success": True} |
| 578 | + |
| 579 | + with pytest.raises(DatalabTimeoutError): |
| 580 | + await client._poll_result( |
| 581 | + "https://api.example.com/check", max_polls=3, poll_interval=0 |
| 582 | + ) |
| 583 | + |
| 584 | + assert mock_req.await_count == 3 |
| 585 | + assert mock_sleep.await_count >= 1 |
| 586 | + |
| 587 | + @pytest.mark.asyncio |
| 588 | + async def test_poll_result_raises_on_failed_status(self): |
| 589 | + async with AsyncDatalabClient(api_key="test-key") as client: |
| 590 | + with patch.object( |
| 591 | + client, "_make_request", new_callable=AsyncMock |
| 592 | + ) as mock_req: |
| 593 | + mock_req.return_value = { |
| 594 | + "status": "failed", |
| 595 | + "success": False, |
| 596 | + "error": "boom", |
| 597 | + } |
| 598 | + |
| 599 | + with pytest.raises(DatalabAPIError, match="Processing failed: boom"): |
| 600 | + await client._poll_result( |
| 601 | + "https://api.example.com/check", max_polls=1, poll_interval=0 |
| 602 | + ) |
0 commit comments