diff --git a/qiskit_experiments/database_service/db_experiment_data.py b/qiskit_experiments/database_service/db_experiment_data.py index 12a6d94dd6..35441d71bd 100644 --- a/qiskit_experiments/database_service/db_experiment_data.py +++ b/qiskit_experiments/database_service/db_experiment_data.py @@ -14,7 +14,6 @@ import logging import dataclasses -import threading import uuid from typing import Optional, List, Any, Union, Callable, Dict import copy @@ -70,13 +69,13 @@ def service_exception_to_warning(): @dataclasses.dataclass -class CallbackStatus: - """Dataclass for analysis callback status""" +class Callback: + """Dataclass for analysis callback functions""" - callback: Callable + func: Callable kwargs: Dict = dataclasses.field(default_factory=dict) - status: JobStatus = JobStatus.QUEUED - event: threading.Event = dataclasses.field(default_factory=threading.Event) + callback_id: str = "" + status: JobStatus = JobStatus.INITIALIZING error_msg: Optional[str] = None @@ -104,7 +103,7 @@ class DbExperimentDataV1(DbExperimentData): version = 1 verbose = True # Whether to print messages to the standard output. _metadata_version = 1 - _executor = futures.ThreadPoolExecutor() + _job_executor = futures.ThreadPoolExecutor() """Threads used for asynchronous processing.""" _json_encoder = ExperimentEncoder @@ -166,8 +165,9 @@ def __init__( self._jobs = ThreadSafeOrderedDict(job_ids or []) self._job_futures = ThreadSafeList() - self._callback_statuses = ThreadSafeOrderedDict() - self._callback_future = None + self._callback_executor = futures.ThreadPoolExecutor(max_workers=1) + self._callbacks = ThreadSafeOrderedDict() + self._callback_futures = ThreadSafeOrderedDict() self._data = ThreadSafeList() self._figures = ThreadSafeOrderedDict(figure_names or []) @@ -220,7 +220,7 @@ def add_data( Raises: TypeError: If the input data type is invalid. """ - if any(not status.event.is_set() for status in self._callback_statuses.values()): + if any(not future.done() for future in self._callback_futures.values()): LOG.warning( "Not all post-processing has finished. Adding new data " "may create unexpected analysis results." @@ -265,7 +265,7 @@ def add_data( "timeout": timeout, } self._job_futures.append( - (job_kwargs, self._executor.submit(self._add_jobs_data, **job_kwargs)) + (job_kwargs, self._job_executor.submit(self._add_jobs_data, **job_kwargs)) ) if self.auto_save: @@ -288,52 +288,65 @@ def add_analysis_callback(self, callback: Callable, **kwargs: Any): keywork arguments passed to this method. **kwargs: Keyword arguments to be passed to the callback function. """ - callback_id = uuid.uuid4() - self._callback_statuses[callback_id] = CallbackStatus(callback, kwargs=kwargs) + with self._job_futures.lock and self._callback_futures.lock: + # Create callback dataclass + cid = uuid.uuid4().hex + self._callbacks[cid] = Callback(callback, kwargs=kwargs, callback_id=cid) + + # Get futures to wait for before running callback + if self._callback_futures: + futs = self._callback_futures.values() + else: + futs = [fut for _, fut in self._job_futures.copy()] - # Wrap callback function to handle reporting status and catching - # any exceptions and their error messages - def _wrapped_callback(): - try: - self._callback_statuses[callback_id].status = JobStatus.RUNNING - callback(self, **kwargs) - self._callback_statuses[callback_id].status = JobStatus.DONE - except Exception as ex: # pylint: disable=broad-except - self._callback_statuses[callback_id].status = JobStatus.ERROR - error_msg = f"Analysis callback {callback} failed: \n" "".join( - traceback.format_exception(type(ex), ex, ex.__traceback__) - ) - self._callback_statuses[callback_id].error_msg = error_msg - LOG.warning("Analysis callback %s failed:\n%s", callback, traceback.format_exc()) - self._callback_statuses[callback_id].event.set() + # Add run callback future + self._callback_futures[cid] = self._callback_executor.submit( + self._run_callback, cid, futs + ) - with self._job_futures.lock: - # Determine if a future is running that we need to add callback to - fut_done = True - if self._job_futures: - _, fut = self._job_futures[-1] - fut_done = fut.done() - if fut_done and self._callback_future is not None: - fut = self._callback_future - fut_done = fut.done() - if fut_done: - fut = None - - if fut_done: - # Submit future so analysis can run async even if there are no - # running jobs or running analysis. - self._callback_future = self._executor.submit(_wrapped_callback) - else: - # Wrap the wrapped function for the format expected by Python - # Future.add_done_callback - def _done_callback(fut): - if fut.cancelled(): - self._callback_statuses[callback_id].status = JobStatus.CANCELLED - self._callback_statuses[callback_id].event.set() - else: - _wrapped_callback() + def cancel_callbacks(self) -> None: + """Cancel any queued callbacks. - fut.add_done_callback(_done_callback) + .. note:: + A currently running callback cannot be cancelled. + """ + with self._callback_futures.lock: + for cid, fut in self._callback_futures.items(): + if fut.done(): + continue + if fut.cancel(): + LOG.info("Cancelled queued callback [cid: %s].", cid) + self._callbacks[cid].status = JobStatus.CANCELLED + else: + LOG.warning("Unable to cancel running callback [cid: %s].", cid) + + def _run_callback(self, callback_id: str, futs: Optional[List[futures.Future]] = None): + """Run a callback after specified futures have finished.""" + if callback_id not in self._callbacks: + raise ValueError(f"No callback with id {callback_id}") + + callback = self._callbacks[callback_id] + + # Wait for previous futures to finish_ + LOG.debug("Waiting to run callback [cid %s]", callback_id) + self._callbacks[callback_id].status = JobStatus.QUEUED + if futs: + futures.wait(futs) + + # Run callback function + LOG.debug("Running callback [cid: %s]", callback_id) + self._callbacks[callback_id].status = JobStatus.RUNNING + try: + callback.func(self, **callback.kwargs) + self._callbacks[callback_id].status = JobStatus.DONE + LOG.debug("Callback finished [cid: %s]", callback_id) + except Exception as ex: # pylint: disable=broad-except + self._callbacks[callback_id].status = JobStatus.ERROR + error_msg = f"Analysis callback failed [cid: {callback_id}]:\n" "".join( + traceback.format_exception(type(ex), ex, ex.__traceback__) + ) + self._callbacks[callback_id].error_msg = error_msg + LOG.warning(error_msg) def _add_jobs_data( self, @@ -862,8 +875,8 @@ def save(self) -> None: if self.verbose: print( - "You can view the experiment online at https://quantum-computing.ibm.com/experiments/" - + self.experiment_id + "You can view the experiment online at " + "https://quantum-computing.ibm.com/experiments/" + self.experiment_id ) @classmethod @@ -930,8 +943,11 @@ def block_for_results(self, timeout: Optional[float] = None) -> "DbExperimentDat Returns: The experiment data with finished jobs and post-processing. """ - _, timeout = combined_timeout(self._wait_for_jobs, timeout) - _, timeout = combined_timeout(self._wait_for_callbacks, timeout) + if self._callback_futures: + self._wait_for_callbacks(timeout) + else: + self._wait_for_jobs(timeout) + self._removed_done_futures() return self def _wait_for_jobs(self, timeout: Optional[float] = None): @@ -960,29 +976,27 @@ def _wait_for_jobs(self, timeout: Optional[float] = None): def _wait_for_callbacks(self, timeout: Optional[float] = None): """Wait for analysis callbacks to finish""" - # Wait for analysis callbacks to finish - if self._callback_statuses: - for status in self._callback_statuses.values(): - if status.status in [JobStatus.DONE, JobStatus.CANCELLED]: - continue - LOG.info("Waiting for analysis callback %s to finish.", status.callback) - finished, timeout = combined_timeout(status.event.wait, timeout) - if not finished: - LOG.warning( - "Possibly incomplete analysis results:" - " analysis" - " callback %s timed out.", - status.callback, - ) + try: + LOG.debug("Waiting for all callbacks to finish [eid: %s]", self.experiment_id) + waited = futures.wait(self._callback_futures.values(), timeout=timeout) + if waited.not_done: + raise futures.TimeoutError + LOG.debug("All callbacks finished [eid: %s]", self.experiment_id) + except futures.TimeoutError: + LOG.warning("Waiting for callbacks timed out before completion.") + except futures.CancelledError: + LOG.warning("Callbacks were cancelled before completion.") + + def _removed_done_futures(self): + """Remove futures that have finished""" + with self._callback_futures.lock and self._job_futures.lock: + running_callbacks = [ + (cid, fut) for cid, fut in self._callback_futures.items() if not fut.done() + ] + self._callback_futures = ThreadSafeOrderedDict(running_callbacks) - # Check analysis status and show warning if cancelled or error - callback_status = self._callback_status() - if callback_status == "CANCELLED": - LOG.warning("Possibly incomplete analysis results: an analysis callback was cancelled.") - elif callback_status == "ERROR": - LOG.warning( - "Possibly incomplete analysis results: an analysis callback raised an error." - ) + running_jobs = [(jid, fut) for jid, fut in self._job_futures if not fut.done()] + self._job_futures = ThreadSafeList(running_jobs) def status(self) -> str: """Return the data processing status. @@ -999,6 +1013,8 @@ def status(self) -> str: * POST_PROCESSING - if any analysis callbacks are still running * DONE - if all jobs and analysis callbacks are finished. + If no data has been added the returned status will be EMPTY. + .. note:: If an experiment has status ERROR or CANCELLED there may still @@ -1013,22 +1029,23 @@ def status(self) -> str: for container in [ self._data, self._jobs, - self._callback_statuses, + self._job_futures, + self._callbacks, + self._callback_futures, self._figures, self._analysis_results, ] ): - return "INITIALIZING" + return "EMPTY" job_status = self._job_status() if job_status != "DONE": return job_status callback_status = self._callback_status() - if callback_status in ["QUEUED", "RUNNING"]: - return "POST_PROCESSING" - - return callback_status + if callback_status in ["DONE", "CANCELLED", "ERROR"]: + return callback_status + return "POST_PROCESSING" def _job_status(self) -> str: """Return the experiment job execution status. @@ -1084,34 +1101,33 @@ def _callback_status(self) -> str: If the experiment consists of multiple analysis callbacks, the returned status is mapped in the following order: - * ERROR - if any analysis callback incurred an error. - * CANCELLED - if any analysis callback is cancelled. - * RUNNING - if any analysis callbacks are still running. - * QUEUED - if any analysis callback is queued. - * DONE - if all analysis callbacks are finished. + * ERROR - if any callback incurred an error. + * CANCELLED - if any callback was cancelled. + * RUNNING - if any callback is still running. + * QUEUED - if any callback is queued. + * INITIALIZING - if any callback is being initialized. + * DONE - if all callbacks are finished. Returns: Analysis callback status. """ statuses = set() - for status in self._callback_statuses.values(): + for status in self._callbacks.values(): statuses.add(status.status) - # Remove analysis future if it is done, since we store all statuses - # In the _callback_status field. - if self._callback_future is not None and self._callback_future.done(): - self._callback_future = None - for stat in [ JobStatus.ERROR, JobStatus.CANCELLED, JobStatus.RUNNING, JobStatus.QUEUED, + JobStatus.VALIDATING, + JobStatus.INITIALIZING, + JobStatus.DONE, ]: if stat in statuses: return stat.name - return "DONE" + return JobStatus.DONE.name def errors(self) -> str: """Return errors encountered. @@ -1120,31 +1136,22 @@ def errors(self) -> str: Experiment errors. """ errors = [] - # Get any future errors - for fut_kwargs, fut in self._job_futures: - if fut.done(): - ex = fut.exception() - if ex: - jobs = [job.job_id() for job in fut_kwargs["jobs"]] - errors.append( - f"Job {jobs} failed: \n" - + "".join(traceback.format_exception(type(ex), ex, ex.__traceback__)) - ) # Get any job errors for job in self._jobs.values(): if job and job.status() == JobStatus.ERROR: - job_err = "." if hasattr(job, "error_message"): - job_err = ": " + job.error_message() - errors.append(f"Job {job.job_id()} failed{job_err}") + error_msg = job.error_message() + else: + error_msg = "" + errors.append(f"\n[jid: {job.job_id()}]: {error_msg}") - # Get any analysis callback errors - for status in self._callback_statuses.values(): - if status.error_msg is not None: - errors.append(status.error_msg) + # Get any callback errors + for callback in self._callbacks.values(): + if callback.status == JobStatus.ERROR: + errors.append(f"\n[cid: {callback.callback_id}]: {callback.error_msg}") - return "\n".join(errors) + return "".join(errors) def copy(self, copy_results: bool = True) -> "DbExperimentDataV1": """Make a copy of the experiment data with a new experiment ID. diff --git a/qiskit_experiments/framework/composite/composite_analysis.py b/qiskit_experiments/framework/composite/composite_analysis.py index ddf6671cb8..8f4680bc21 100644 --- a/qiskit_experiments/framework/composite/composite_analysis.py +++ b/qiskit_experiments/framework/composite/composite_analysis.py @@ -96,13 +96,10 @@ def _run_analysis(self, experiment_data: ExperimentData): ) analysis_results.append(result) - # Add callback to wait for all component analysis to finish before returning + # Wait for all component analysis to finish before returning # the parent experiment analysis results - def _wait_for_components(experiment_data, component_ids): - for comp_id in component_ids: - experiment_data.child_data(comp_id).block_for_results() - - experiment_data.add_analysis_callback(_wait_for_components, component_ids=component_ids) + for comp_id in component_ids: + experiment_data.child_data(comp_id).block_for_results() return analysis_results, [] diff --git a/releasenotes/notes/callback-futures-646d9c36f4af3d72.yaml b/releasenotes/notes/callback-futures-646d9c36f4af3d72.yaml new file mode 100644 index 0000000000..cc8e7c1920 --- /dev/null +++ b/releasenotes/notes/callback-futures-646d9c36f4af3d72.yaml @@ -0,0 +1,15 @@ +--- +features: + - | + Improves handling of analysis callbacks in :meth:`.ExperimentData`. + Logging information on execution of analysis callbacks in an experiment + can enabled by setting the log level to DEBUG. + - | + Adds :meth:`.ExperimentData.cancel_callbacks` method to allow cancelling + pending analysis callbacks. Note that analysis callbacks that have already + started running cannot be cancelled. +fixes: + - | + Fixes an issue with :meth:`.ExperimentData.block_for_results` sometimes + having a race issue with all analysis callbacks finishing. This would + often occur when running :meth:`.ExperimentData.block_for_results`. diff --git a/releasenotes/notes/fix-nested-comp-66a2b8b6e3b404be.yaml b/releasenotes/notes/fix-nested-comp-66a2b8b6e3b404be.yaml new file mode 100644 index 0000000000..8840e94d49 --- /dev/null +++ b/releasenotes/notes/fix-nested-comp-66a2b8b6e3b404be.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fixes bug in :class:`.CompositeAnalysis` where analysis of nested + composite experiments could raise a RuntimeError. diff --git a/test/database_service/test_db_experiment_data.py b/test/database_service/test_db_experiment_data.py index 353dde6935..b2b339917e 100644 --- a/test/database_service/test_db_experiment_data.py +++ b/test/database_service/test_db_experiment_data.py @@ -579,6 +579,23 @@ def test_status_post_processing(self): status = exp_data.status() self.assertEqual("POST_PROCESSING", status) + def test_status_cancelled_callback(self): + """Test experiment status during post processing.""" + job = mock.create_autospec(Job, instance=True) + job.result.return_value = self._get_job_result(3) + + event = threading.Event() + self.addCleanup(event.set) + + exp_data = DbExperimentData(experiment_type="qiskit_test") + exp_data.add_data(job) + exp_data.add_analysis_callback((lambda *args, **kwargs: event.wait(timeout=2))) + # Add second callback because the first can't be cancelled once it has started + exp_data.add_analysis_callback((lambda *args, **kwargs: event.wait(timeout=20))) + exp_data.cancel_callbacks() + status = exp_data.status() + self.assertEqual("CANCELLED", status) + def test_status_post_processing_error(self): """Test experiment status when post processing failed.""" diff --git a/test/test_composite.py b/test/test_composite.py index c894a9fbd6..7d6fae2929 100644 --- a/test/test_composite.py +++ b/test/test_composite.py @@ -161,6 +161,20 @@ def test_composite_copy(self): self.check_attributes(new_instance) self.assertEqual(new_instance.parent_id, None) + def test_nested_composite(self): + """ + Test nested parallel experiments. + """ + exp1 = FakeExperiment([0, 2]) + exp2 = FakeExperiment([1, 3]) + exp3 = ParallelExperiment([exp1, exp2]) + exp4 = BatchExperiment([exp3, exp1]) + exp5 = ParallelExperiment([exp4, FakeExperiment([4])]) + nested_exp = BatchExperiment([exp5, exp3]) + expdata = nested_exp.run(FakeBackend()).block_for_results() + status = expdata.status() + self.assertEqual(status, "DONE") + def test_analysis_replace_results_true(self): """ Test replace results when analyzing composite experiment data