diff --git a/qiskit/passmanager/base_tasks.py b/qiskit/passmanager/base_tasks.py index 84a944d78042..64ebafa225da 100644 --- a/qiskit/passmanager/base_tasks.py +++ b/qiskit/passmanager/base_tasks.py @@ -19,7 +19,7 @@ from collections.abc import Iterable, Callable, Generator from typing import Any -from .compilation_status import RunState, PassManagerState +from .compilation_status import RunState, PassManagerState, PropertySet logger = logging.getLogger(__name__) @@ -62,6 +62,7 @@ class GenericPass(Task, ABC): """ def __init__(self): + self.property_set = PropertySet() self.requires: Iterable[Task] = [] def name(self) -> str: @@ -77,6 +78,7 @@ def execute( # Overriding this method is not safe. # Pass subclass must keep current implementation. # Especially, task execution may break when method signature is modified. + self.property_set = state.property_set if self.requires: # pylint: disable=cyclic-import diff --git a/qiskit/passmanager/flow_controllers.py b/qiskit/passmanager/flow_controllers.py index bf831b94b35c..eb06202e8673 100644 --- a/qiskit/passmanager/flow_controllers.py +++ b/qiskit/passmanager/flow_controllers.py @@ -133,6 +133,8 @@ def iter_tasks(self, state: PassManagerState) -> Generator[Task, PassManagerStat state = yield task if not self.do_while(state.property_set): return + # Remove stored tasks from the completed task collection for next loop + state.workflow_status.completed_passes.difference_update(self.tasks) raise PassManagerError("Maximum iteration reached. max_iteration=%i" % max_iteration) diff --git a/qiskit/passmanager/passmanager.py b/qiskit/passmanager/passmanager.py index fb9c8b523ba6..74d5feb91088 100644 --- a/qiskit/passmanager/passmanager.py +++ b/qiskit/passmanager/passmanager.py @@ -15,7 +15,7 @@ import logging from abc import ABC, abstractmethod -from collections.abc import Callable, Sequence, Iterable +from collections.abc import Callable, Iterable from itertools import chain from typing import Any @@ -169,14 +169,16 @@ def _passmanager_backend( def run( self, - in_programs: Any, + in_programs: Any | list[Any], callback: Callable = None, **kwargs, ) -> Any: - """Run all the passes on the specified ``circuits``. + """Run all the passes on the specified ``in_programs``. Args: in_programs: Input programs to transform via all the registered passes. + A single input object cannot be a Python builtin list object. + A list object is considered as multiple input objects to optimize. callback: A callback function that will be called after each pass execution. The function will be called with 4 keyword arguments:: @@ -212,7 +214,7 @@ def callback_func(**kwargs): return in_programs is_list = True - if not isinstance(in_programs, Sequence): + if not isinstance(in_programs, list): in_programs = [in_programs] is_list = False diff --git a/qiskit/transpiler/basepasses.py b/qiskit/transpiler/basepasses.py index 02593c0bb9fd..c09ee190e38b 100644 --- a/qiskit/transpiler/basepasses.py +++ b/qiskit/transpiler/basepasses.py @@ -74,7 +74,6 @@ class BasePass(GenericPass, metaclass=MetaPass): def __init__(self): super().__init__() self.preserves: Iterable[GenericPass] = [] - self.property_set = PropertySet() self._hash = hash(None) def __hash__(self): @@ -118,21 +117,6 @@ def is_analysis_pass(self): """ return isinstance(self, AnalysisPass) - def execute( - self, - passmanager_ir: PassManagerIR, - state: PassManagerState, - callback: Callable = None, - ) -> tuple[PassManagerIR, PassManagerState]: - # For backward compatibility. - # Circuit passes access self.property_set. - self.property_set = state.property_set - return super().execute( - passmanager_ir=passmanager_ir, - state=state, - callback=callback, - ) - def __call__( self, circuit: QuantumCircuit, diff --git a/test/python/passmanager/__init__.py b/test/python/passmanager/__init__.py new file mode 100644 index 000000000000..d5b924250dc4 --- /dev/null +++ b/test/python/passmanager/__init__.py @@ -0,0 +1,54 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023 +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Pass manager test cases.""" + +import contextlib +import logging +import re +from itertools import zip_longest +from logging import getLogger + +from qiskit.test import QiskitTestCase + + +class PassManagerTestCase(QiskitTestCase): + """Test case for the pass manager module.""" + + @contextlib.contextmanager + def assertLogContains(self, expected_lines): + """A context manager that capture pass manager log. + + Args: + expected_lines (List[str]): Expected logs. Each element can be regular expression. + """ + try: + logger = getLogger() + with self.assertLogs(logger=logger, level=logging.DEBUG) as cm: + yield cm + finally: + recorded_lines = cm.output + for i, (expected, recorded) in enumerate(zip_longest(expected_lines, recorded_lines)): + expected = expected or "" + recorded = recorded or "" + if not re.search(expected, recorded): + raise AssertionError( + f"Log didn't match. Mismatch found at line #{i}.\n\n" + f"Expected:\n{self._format_log(expected_lines)}\n" + f"Recorded:\n{self._format_log(recorded_lines)}" + ) + + def _format_log(self, lines): + out = "" + for i, line in enumerate(lines): + out += f"#{i:02d}: {line}\n" + return out diff --git a/test/python/passmanager/test_generic_pass.py b/test/python/passmanager/test_generic_pass.py new file mode 100644 index 000000000000..c0071b966d65 --- /dev/null +++ b/test/python/passmanager/test_generic_pass.py @@ -0,0 +1,143 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023 +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=missing-class-docstring + +"""Pass manager test cases.""" + +from test.python.passmanager import PassManagerTestCase + +from logging import getLogger + +from qiskit.passmanager import GenericPass +from qiskit.passmanager import PassManagerState, WorkflowStatus, PropertySet +from qiskit.passmanager.compilation_status import RunState + + +class TestGenericPass(PassManagerTestCase): + """Tests for the GenericPass subclass.""" + + def setUp(self): + super().setUp() + + self.state = PassManagerState( + workflow_status=WorkflowStatus(), + property_set=PropertySet(), + ) + + def test_run_task(self): + """Test case: Simple successful task execution.""" + + class Task(GenericPass): + def run(self, passmanager_ir): + return passmanager_ir + + task = Task() + data = "test_data" + expected = [r"Pass: Task - (\d*\.)?\d+ \(ms\)"] + + with self.assertLogContains(expected): + task.execute(passmanager_ir=data, state=self.state) + self.assertEqual(self.state.workflow_status.count, 1) + self.assertIn(task, self.state.workflow_status.completed_passes) + self.assertEqual(self.state.workflow_status.previous_run, RunState.SUCCESS) + + def test_failure_task(self): + """Test case: Log is created regardless of success.""" + + class TestError(Exception): + pass + + class RaiseError(GenericPass): + def run(self, passmanager_ir): + raise TestError() + + task = RaiseError() + data = "test_data" + expected = [r"Pass: RaiseError - (\d*\.)?\d+ \(ms\)"] + + with self.assertLogContains(expected): + with self.assertRaises(TestError): + task.execute(passmanager_ir=data, state=self.state) + self.assertEqual(self.state.workflow_status.count, 0) + self.assertNotIn(task, self.state.workflow_status.completed_passes) + self.assertEqual(self.state.workflow_status.previous_run, RunState.FAIL) + + def test_requires(self): + """Test case: Dependency tasks are run in advance to user provided task.""" + + class TaskA(GenericPass): + def run(self, passmanager_ir): + return passmanager_ir + + class TaskB(GenericPass): + def __init__(self): + super().__init__() + self.requires = [TaskA()] + + def run(self, passmanager_ir): + return passmanager_ir + + task = TaskB() + data = "test_data" + expected = [ + r"Pass: TaskA - (\d*\.)?\d+ \(ms\)", + r"Pass: TaskB - (\d*\.)?\d+ \(ms\)", + ] + with self.assertLogContains(expected): + task.execute(passmanager_ir=data, state=self.state) + self.assertEqual(self.state.workflow_status.count, 2) + + def test_requires_in_list(self): + """Test case: Dependency tasks are not executed multiple times.""" + + class TaskA(GenericPass): + def run(self, passmanager_ir): + return passmanager_ir + + class TaskB(GenericPass): + def __init__(self): + super().__init__() + self.requires = [TaskA()] + + def run(self, passmanager_ir): + return passmanager_ir + + task = TaskB() + data = "test_data" + expected = [ + r"Pass: TaskB - (\d*\.)?\d+ \(ms\)", + ] + self.state.workflow_status.completed_passes.add(task.requires[0]) # already done + with self.assertLogContains(expected): + task.execute(passmanager_ir=data, state=self.state) + self.assertEqual(self.state.workflow_status.count, 1) + + def test_run_with_callable(self): + """Test case: Callable is called after generic pass is run.""" + + # pylint: disable=unused-argument + def test_callable(task, passmanager_ir, property_set, running_time, count): + logger = getLogger() + logger.info("%s is running on %s", task.name(), passmanager_ir) + + class Task(GenericPass): + def run(self, passmanager_ir): + return passmanager_ir + + task = Task() + data = "test_data" + expected = [ + r"Pass: Task - (\d*\.)?\d+ \(ms\)", + r"Task is running on test_data", + ] + with self.assertLogContains(expected): + task.execute(passmanager_ir=data, state=self.state, callback=test_callable) diff --git a/test/python/passmanager/test_passmanager.py b/test/python/passmanager/test_passmanager.py new file mode 100644 index 000000000000..b7ffdbd6029b --- /dev/null +++ b/test/python/passmanager/test_passmanager.py @@ -0,0 +1,126 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023 +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=missing-class-docstring + +"""Pass manager test cases.""" + +from test.python.passmanager import PassManagerTestCase + +from qiskit.passmanager import GenericPass, BasePassManager +from qiskit.passmanager.flow_controllers import DoWhileController, ConditionalController + + +class RemoveFive(GenericPass): + def run(self, passmanager_ir): + return passmanager_ir.replace("5", "") + + +class AddDigit(GenericPass): + def run(self, passmanager_ir): + return passmanager_ir + "0" + + +class CountDigits(GenericPass): + def run(self, passmanager_ir): + self.property_set["ndigits"] = len(passmanager_ir) + + +class ToyPassManager(BasePassManager): + def _passmanager_frontend(self, input_program, **kwargs): + return str(input_program) + + def _passmanager_backend(self, passmanager_ir, in_program, **kwargs): + return int(passmanager_ir) + + +class TestPassManager(PassManagerTestCase): + def test_single_task(self): + """Test case: Pass manager with a single task.""" + + task = RemoveFive() + data = 12345 + pm = ToyPassManager(task) + expected = [r"Pass: RemoveFive - (\d*\.)?\d+ \(ms\)"] + with self.assertLogContains(expected): + out = pm.run(data) + self.assertEqual(out, 1234) + + def test_property_set(self): + """Test case: Pass manager can access property set.""" + + task = CountDigits() + data = 12345 + pm = ToyPassManager(task) + pm.run(data) + self.assertDictEqual(pm.property_set, {"ndigits": 5}) + + def test_do_while_controller(self): + """Test case: Do while controller that repeats tasks until the condition is met.""" + + def _condition(property_set): + return property_set["ndigits"] < 7 + + controller = DoWhileController([AddDigit(), CountDigits()], do_while=_condition) + data = 12345 + pm = ToyPassManager(controller) + pm.property_set["ndigits"] = 5 + expected = [ + r"Pass: AddDigit - (\d*\.)?\d+ \(ms\)", + r"Pass: CountDigits - (\d*\.)?\d+ \(ms\)", + r"Pass: AddDigit - (\d*\.)?\d+ \(ms\)", + r"Pass: CountDigits - (\d*\.)?\d+ \(ms\)", + ] + with self.assertLogContains(expected): + out = pm.run(data) + self.assertEqual(out, 1234500) + + def test_conditional_controller(self): + """Test case: Conditional controller that run task when the condition is met.""" + + def _condition(property_set): + return property_set["ndigits"] > 6 + + controller = ConditionalController([RemoveFive()], condition=_condition) + data = [123456789, 45654, 36785554] + pm = ToyPassManager([CountDigits(), controller]) + out = pm.run(data) + self.assertListEqual(out, [12346789, 45654, 36784]) + + def test_string_input(self): + """Test case: Running tasks once for a single string input. + + Details: + When the pass manager receives a sequence of input values, + it duplicates itself and run the tasks on each input element in parallel. + If the input is string, this can be accidentally recognized as a sequence. + """ + + class StringPassManager(BasePassManager): + def _passmanager_frontend(self, input_program, **kwargs): + return input_program + + def _passmanager_backend(self, passmanager_ir, in_program, **kwargs): + return passmanager_ir + + class Task(GenericPass): + def run(self, passmanager_ir): + return passmanager_ir + + task = Task() + data = "12345" + pm = StringPassManager(task) + + # Should be run only one time + expected = [r"Pass: Task - (\d*\.)?\d+ \(ms\)"] + with self.assertLogContains(expected): + out = pm.run(data) + self.assertEqual(out, data)