diff --git a/qiskit_ibm_runtime/runtime_job.py b/qiskit_ibm_runtime/runtime_job.py index 1b0061be81..26ab35fc3a 100644 --- a/qiskit_ibm_runtime/runtime_job.py +++ b/qiskit_ibm_runtime/runtime_job.py @@ -13,14 +13,12 @@ """Qiskit runtime job.""" from typing import Any, Optional, Callable, Dict, Type -import time import logging from concurrent import futures import traceback import queue from datetime import datetime -from qiskit.providers.exceptions import JobTimeoutError from qiskit.providers.backend import Backend from qiskit.providers.jobstatus import JobStatus, JOB_FINAL_STATES @@ -159,14 +157,12 @@ def interim_results(self, decoder: Optional[Type[ResultDecoder]] = None) -> Any: def result( self, timeout: Optional[float] = None, - wait: float = 5, decoder: Optional[Type[ResultDecoder]] = None, ) -> Any: """Return the results of the job. Args: timeout: Number of seconds to wait for job. - wait: Seconds between queries. decoder: A :class:`ResultDecoder` subclass used to decode job results. Returns: @@ -177,7 +173,7 @@ def result( """ _decoder = decoder or self._result_decoder if self._results is None or (_decoder != self._result_decoder): - self.wait_for_final_state(timeout=timeout, wait=wait) + self.wait_for_final_state(timeout=timeout) if self._status == JobStatus.ERROR: raise RuntimeJobFailureError( f"Unable to retrieve job result. " f"{self.error_message()}" @@ -222,29 +218,18 @@ def error_message(self) -> Optional[str]: self._set_status_and_error_message() return self._error_message - def wait_for_final_state( - self, timeout: Optional[float] = None, wait: float = 5 - ) -> None: - """Poll the job status until it progresses to a final state such as ``DONE`` or ``ERROR``. + def wait_for_final_state(self, timeout: Optional[float] = None) -> None: + """Use the websocket server to wait for the final state of a job. The server + will remain open if the job is still running and the connection will be terminated + once the job completes. 
Then update and return the status of the job. Args: timeout: Seconds to wait for the job. If ``None``, wait indefinitely. - wait: Seconds between queries. - - Raises: - JobTimeoutError: If the job does not reach a final state before the - specified timeout. """ - start_time = time.time() - status = self.status() - while status not in JOB_FINAL_STATES: - elapsed_time = time.time() - start_time - if timeout is not None and elapsed_time >= timeout: - raise JobTimeoutError( - "Timeout while waiting for job {}.".format(self.job_id) - ) - time.sleep(wait) - status = self.status() + if self._status not in JOB_FINAL_STATES: + self._ws_client_future = self._executor.submit(self._start_websocket_client) + self._ws_client_future.result(timeout) + self.status() def stream_results( self, callback: Callable, decoder: Optional[Type[ResultDecoder]] = None @@ -264,14 +249,12 @@ def stream_results( RuntimeInvalidStateError: If a callback function is already streaming results or if the job already finished. """ + if self._status in JOB_FINAL_STATES: + raise RuntimeInvalidStateError("Job already finished.") if self._is_streaming(): raise RuntimeInvalidStateError( "A callback function is already streaming results." 
) - - if self._status in JOB_FINAL_STATES: - raise RuntimeInvalidStateError("Job already finished.") - self._ws_client_future = self._executor.submit(self._start_websocket_client) self._executor.submit( self._stream_results, diff --git a/test/integration/test_interim_results.py b/test/integration/test_interim_results.py index 0c1be36604..695f25c933 100644 --- a/test/integration/test_interim_results.py +++ b/test/integration/test_interim_results.py @@ -83,16 +83,17 @@ def test_stream_results_done(self, service): def result_callback(job_id, interim_result): # pylint: disable=unused-argument - nonlocal called_back - called_back = True + nonlocal called_back_count + called_back_count += 1 - called_back = False + called_back_count = 0 job = self._run_program(service, interim_results="foobar") job.wait_for_final_state() job._status = JobStatus.RUNNING # Allow stream_results() job.stream_results(result_callback) time.sleep(2) - self.assertFalse(called_back) + # Callback is expected twice because both interim and final results are returned + self.assertEqual(2, called_back_count) self.assertIsNotNone(job._ws_client._server_close_code) @run_integration_test diff --git a/test/unit/mock/fake_runtime_client.py b/test/unit/mock/fake_runtime_client.py index f52c350cb8..d74c6d736f 100644 --- a/test/unit/mock/fake_runtime_client.py +++ b/test/unit/mock/fake_runtime_client.py @@ -164,6 +164,10 @@ def interim_results(self): """Return job interim results.""" return self._interim_results + def status(self): + """Return job status.""" + return self._status + class FailedRuntimeJob(BaseFakeRuntimeJob): """Class for faking a failed runtime job.""" @@ -451,6 +455,13 @@ def job_delete(self, job_id): self._get_job(job_id) del self._jobs[job_id] + def wait_for_final_state(self, job_id): + """Wait for the final state of a program job.""" + final_states = ["COMPLETED", "FAILED", "CANCELLED", "CANCELLED - RAN TOO LONG"] + status = self._get_job(job_id).status() + while status not in 
final_states: + status = self._get_job(job_id).status() + def _get_program(self, program_id): """Get program.""" if program_id not in self._programs: diff --git a/test/unit/test_job_retrieval.py b/test/unit/test_job_retrieval.py index 852fb509c1..937a63bb88 100644 --- a/test/unit/test_job_retrieval.py +++ b/test/unit/test_job_retrieval.py @@ -17,6 +17,7 @@ from ..ibm_test_case import IBMTestCase from ..decorators import run_legacy_and_cloud_fake from ..program import run_program, upload_program +from ..utils import mock_wait_for_final_state class TestRetrieveJobs(IBMTestCase): @@ -182,8 +183,9 @@ def test_jobs_filter_by_program_id(self, service): job = run_program(service=service, program_id=program_id) job_1 = run_program(service=service, program_id=program_id_1) - job.wait_for_final_state() - job_1.wait_for_final_state() + with mock_wait_for_final_state(service, job): + job.wait_for_final_state() + job_1.wait_for_final_state() rjobs = service.jobs(program_id=program_id) self.assertEqual(program_id, rjobs[0].program_id) self.assertEqual(1, len(rjobs)) @@ -195,7 +197,8 @@ def test_jobs_filter_by_instance(self): instance = FakeRuntimeService.DEFAULT_HGPS[1] job = run_program(service=service, program_id=program_id, instance=instance) - job.wait_for_final_state() + with mock_wait_for_final_state(service, job): + job.wait_for_final_state() rjobs = service.jobs(program_id=program_id, instance=instance) self.assertTrue(rjobs) self.assertEqual(program_id, rjobs[0].program_id) diff --git a/test/unit/test_jobs.py b/test/unit/test_jobs.py index 2c5ff3574a..9eeaf8912a 100644 --- a/test/unit/test_jobs.py +++ b/test/unit/test_jobs.py @@ -37,6 +37,7 @@ from ..decorators import run_legacy_and_cloud_fake from ..program import run_program, upload_program from ..serialization import get_complex_types +from ..utils import mock_wait_for_final_state class TestRuntimeJob(IBMTestCase): @@ -51,9 +52,10 @@ def test_run_program(self, service): self.assertIsInstance(job, RuntimeJob) 
self.assertIsInstance(job.status(), JobStatus) self.assertEqual(job.inputs, params) - job.wait_for_final_state() - self.assertEqual(job.status(), JobStatus.DONE) - self.assertTrue(job.result()) + with mock_wait_for_final_state(service, job): + job.wait_for_final_state() + self.assertEqual(job.status(), JobStatus.DONE) + self.assertTrue(job.result()) @run_legacy_and_cloud_fake def test_run_phantom_program(self, service): @@ -148,9 +150,10 @@ def test_run_program_with_custom_runtime_image(self, service): self.assertIsInstance(job, RuntimeJob) self.assertIsInstance(job.status(), JobStatus) self.assertEqual(job.inputs, params) - job.wait_for_final_state() + with mock_wait_for_final_state(service, job): + job.wait_for_final_state() + self.assertTrue(job.result()) self.assertEqual(job.status(), JobStatus.DONE) - self.assertTrue(job.result()) self.assertEqual(job.image, image) @run_legacy_and_cloud_fake @@ -164,31 +167,33 @@ def test_run_program_with_custom_log_level(self, service): def test_run_program_failed(self, service): """Test a failed program execution.""" job = run_program(service=service, job_classes=FailedRuntimeJob) - job.wait_for_final_state() - job_result_raw = service._api_client.job_results(job.job_id) - self.assertEqual(JobStatus.ERROR, job.status()) - self.assertEqual( - API_TO_JOB_ERROR_MESSAGE["FAILED"].format(job.job_id, job_result_raw), - job.error_message(), - ) - with self.assertRaises(RuntimeJobFailureError): - job.result() + with mock_wait_for_final_state(service, job): + job.wait_for_final_state() + job_result_raw = service._api_client.job_results(job.job_id) + self.assertEqual(JobStatus.ERROR, job.status()) + self.assertEqual( + API_TO_JOB_ERROR_MESSAGE["FAILED"].format(job.job_id, job_result_raw), + job.error_message(), + ) + with self.assertRaises(RuntimeJobFailureError): + job.result() @run_legacy_and_cloud_fake def test_run_program_failed_ran_too_long(self, service): """Test a program that failed since it ran longer than maximum execution 
time.""" job = run_program(service=service, job_classes=FailedRanTooLongRuntimeJob) - job.wait_for_final_state() - job_result_raw = service._api_client.job_results(job.job_id) - self.assertEqual(JobStatus.ERROR, job.status()) - self.assertEqual( - API_TO_JOB_ERROR_MESSAGE["CANCELLED - RAN TOO LONG"].format( - job.job_id, job_result_raw - ), - job.error_message(), - ) - with self.assertRaises(RuntimeJobFailureError): - job.result() + with mock_wait_for_final_state(service, job): + job.wait_for_final_state() + job_result_raw = service._api_client.job_results(job.job_id) + self.assertEqual(JobStatus.ERROR, job.status()) + self.assertEqual( + API_TO_JOB_ERROR_MESSAGE["CANCELLED - RAN TOO LONG"].format( + job.job_id, job_result_raw + ), + job.error_message(), + ) + with self.assertRaises(RuntimeJobFailureError): + job.result() @run_legacy_and_cloud_fake def test_program_params_namespace(self, service): @@ -212,8 +217,9 @@ def test_cancel_job(self, service): def test_final_result(self, service): """Test getting final result.""" job = run_program(service) - result = job.result() - self.assertTrue(result) + with mock_wait_for_final_state(service, job): + result = job.result() + self.assertTrue(result) @run_legacy_and_cloud_fake def test_interim_results(self, service): @@ -248,7 +254,8 @@ def test_job_program_id(self, service): def test_wait_for_final_state(self, service): """Test wait for final state.""" job = run_program(service) - job.wait_for_final_state() + with mock_wait_for_final_state(service, job): + job.wait_for_final_state() self.assertEqual(JobStatus.DONE, job.status()) @run_legacy_and_cloud_fake @@ -259,8 +266,9 @@ def test_get_result_twice(self, service): job_cls.custom_result = custom_result job = run_program(service=service, job_classes=job_cls) - _ = job.result() - _ = job.result() + with mock_wait_for_final_state(service, job): + _ = job.result() + _ = job.result() @run_legacy_and_cloud_fake def test_delete_job(self, service): diff --git a/test/utils.py 
b/test/utils.py index da5a2e065a..7fae9311b2 100644 --- a/test/utils.py +++ b/test/utils.py @@ -16,6 +16,7 @@ import logging import time import unittest +from unittest import mock from qiskit import QuantumCircuit from qiskit.providers.jobstatus import JOB_FINAL_STATES, JobStatus @@ -137,3 +138,12 @@ def get_real_device(service): ).name() except QiskitBackendNotFoundError: raise unittest.SkipTest("No real device") # cloud has no real device + + +def mock_wait_for_final_state(service, job): + """replace `wait_for_final_state` with a mock function""" + return mock.patch.object( + RuntimeJob, + "wait_for_final_state", + side_effect=service._api_client.wait_for_final_state(job.job_id), + )