Skip to content
This repository has been archived by the owner on Dec 7, 2021. It is now read-only.

Consistent behaviour of StateFn.eval with OperatorBase.eval #1210

Merged
merged 16 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions qiskit/aqua/operators/operator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,19 @@ def eval(self,
defined to be evaluated from Zero implicitly (i.e. it is as if ``.eval('0000')`` is already
called implicitly to always "indexing" from column 0).

If ``front`` is None, the matrix-representation of the operator is returned.

Args:
front: The bitstring, dict of bitstrings (with values being coefficients), or
StateFn to evaluated by the Operator's underlying function.
StateFn to evaluated by the Operator's underlying function, or None.

Returns:
The output of the Operator's evaluation function. If self is a ``StateFn``, the result
is a float or complex. If self is an Operator (``PrimitiveOp, ComposedOp, SummedOp,
EvolvedOp,`` etc.), the result is a StateFn. If either self or front contain proper
EvolvedOp,`` etc.), the result is a StateFn.
If ``front`` is None, the matrix-representation of the operator is returned, which
is a ``MatrixOp`` for the operators and a ``VectorStateFn`` for state-functions.
If either self or front contain proper
``ListOps`` (not ListOp subclasses), the result is an n-dimensional list of complex
or StateFn results, resulting from the recursive evaluation by each OperatorBase
in the ListOps.
Expand Down
6 changes: 3 additions & 3 deletions qiskit/aqua/operators/primitive_ops/circuit_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

""" CircuitOp Class """

from typing import Union, Optional, Set, List, cast
from typing import Union, Optional, Set, List, Dict, cast
import logging
import numpy as np

Expand Down Expand Up @@ -186,8 +186,8 @@ def assign_parameters(self, param_dict: dict) -> OperatorBase:
return self.__class__(qc, coeff=param_value)

def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None
) -> Union[OperatorBase, float, complex]:
# pylint: disable=import-outside-toplevel
from ..state_fns import CircuitStateFn
from ..list_ops import ListOp
Expand Down
6 changes: 3 additions & 3 deletions qiskit/aqua/operators/primitive_ops/matrix_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

""" MatrixOp Class """

from typing import Union, Optional, Set
from typing import Union, Optional, Set, Dict
import logging
import numpy as np
from scipy.sparse import spmatrix
Expand Down Expand Up @@ -134,8 +134,8 @@ def __str__(self) -> str:
return "{} * {}".format(self.coeff, prim_str)

def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None
) -> Union[OperatorBase, float, complex]:
# For other ops' eval we return self.to_matrix_op() here, but that's unnecessary here.
if front is None:
return self
Expand Down
6 changes: 3 additions & 3 deletions qiskit/aqua/operators/primitive_ops/pauli_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

""" PauliOp Class """

from typing import Union, Set, Dict, cast
from typing import Union, Set, Dict, Optional, cast
import logging
import numpy as np
from scipy.sparse import spmatrix
Expand Down Expand Up @@ -145,8 +145,8 @@ def __str__(self) -> str:
return "{} * {}".format(self.coeff, prim_str)

def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None
) -> Union[OperatorBase, float, complex]:
if front is None:
return self.to_matrix_op()

Expand Down
6 changes: 3 additions & 3 deletions qiskit/aqua/operators/primitive_ops/primitive_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

""" PrimitiveOp Class """

from typing import Optional, Union, Set, List
from typing import Optional, Union, Set, List, Dict
import logging
import numpy as np
from scipy.sparse import spmatrix
Expand Down Expand Up @@ -213,8 +213,8 @@ def __repr__(self) -> str:
return "{}({}, coeff={})".format(type(self).__name__, repr(self.primitive), self.coeff)

def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None
) -> Union[OperatorBase, float, complex]:
raise NotImplementedError

@property
Expand Down
11 changes: 8 additions & 3 deletions qiskit/aqua/operators/state_fns/circuit_state_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
""" CircuitStateFn Class """


from typing import Union, Set, List, cast
from typing import Union, Set, List, Optional, Dict, cast
import numpy as np

from qiskit import QuantumCircuit, BasicAer, execute, ClassicalRegister
Expand Down Expand Up @@ -275,8 +275,13 @@ def assign_parameters(self, param_dict: dict) -> OperatorBase:
return self.__class__(qc, coeff=param_value, is_measurement=self.is_measurement)

def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None
) -> Union[OperatorBase, float, complex]:
if front is None:
vector_state_fn = self.to_matrix_op().eval()
vector_state_fn = cast(OperatorBase, vector_state_fn)
return vector_state_fn

if not self.is_measurement and isinstance(front, OperatorBase):
raise ValueError(
'Cannot compute overlap with StateFn or Operator if not Measurement. Try taking '
Expand Down
10 changes: 7 additions & 3 deletions qiskit/aqua/operators/state_fns/dict_state_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

""" DictStateFn Class """

from typing import Union, Set, cast
from typing import Optional, Union, Set, cast
woodsp-ibm marked this conversation as resolved.
Show resolved Hide resolved
import itertools
import numpy as np
from scipy import sparse
Expand Down Expand Up @@ -182,8 +182,12 @@ def __str__(self) -> str:

# pylint: disable=too-many-return-statements
def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None
) -> Union[OperatorBase, float, complex]:
if front is None:
vector_state_fn = self.to_matrix_op().eval()
vector_state_fn = cast(OperatorBase, vector_state_fn)
return vector_state_fn

if not self.is_measurement and isinstance(front, OperatorBase):
raise ValueError(
Expand Down
5 changes: 5 additions & 0 deletions qiskit/aqua/operators/state_fns/operator_state_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from ..operator_base import OperatorBase
from .state_fn import StateFn
from .vector_state_fn import VectorStateFn
from ..list_ops.list_op import ListOp
from ..list_ops.summed_op import SummedOp

Expand Down Expand Up @@ -181,6 +182,10 @@ def __str__(self) -> str:
def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
if front is None:
matrix = self.primitive.to_matrix_op().primitive.data
return VectorStateFn(matrix[0, :])

if not self.is_measurement and isinstance(front, OperatorBase):
raise ValueError(
'Cannot compute overlap with StateFn or Operator if not Measurement. Try taking '
Expand Down
4 changes: 2 additions & 2 deletions qiskit/aqua/operators/state_fns/state_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ def __repr__(self) -> str:
self.coeff, self.is_measurement)

def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None
) -> Union[OperatorBase, float, complex]:
raise NotImplementedError

@property
Expand Down
9 changes: 6 additions & 3 deletions qiskit/aqua/operators/state_fns/vector_state_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
""" VectorStateFn Class """


from typing import Union, Set
from typing import Union, Set, Optional, Dict
import numpy as np

from qiskit.quantum_info import Statevector
Expand Down Expand Up @@ -128,8 +128,11 @@ def __str__(self) -> str:

# pylint: disable=too-many-return-statements
def eval(self,
front: Union[str, dict, np.ndarray,
OperatorBase] = None) -> Union[OperatorBase, float, complex]:
front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None
) -> Union[OperatorBase, float, complex]:
if front is None: # this object is already a VectorStateFn
return self

if not self.is_measurement and isinstance(front, OperatorBase):
raise ValueError(
'Cannot compute overlap with StateFn or Operator if not Measurement. Try taking '
Expand Down
8 changes: 8 additions & 0 deletions releasenotes/notes/statefn-eval-51ecf38a7a3cc087.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
features:
- |
Allow calling ``eval`` on state function objects with no argument, which returns the
``VectorStateFn`` representation of the state function.
This is consistent behaviour with ``OperatorBase.eval``, which returns the
``MatrixOp`` representation, if no argument is passed.
woodsp-ibm marked this conversation as resolved.
Show resolved Hide resolved

19 changes: 17 additions & 2 deletions test/aqua/operators/test_op_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
from qiskit.circuit import QuantumCircuit, QuantumRegister, Instruction, Parameter, ParameterVector

from qiskit.extensions.exceptions import ExtensionError
from qiskit.quantum_info.operators import Operator, Pauli
from qiskit.quantum_info import Operator, Pauli, Statevector
from qiskit.circuit.library import CZGate, ZGate

from qiskit.aqua.operators import (
X, Y, Z, I, CX, T, H, PrimitiveOp, SummedOp, PauliOp, Minus, CircuitOp, MatrixOp, ListOp,
ComposedOp, StateFn
ComposedOp, StateFn, VectorStateFn, OperatorStateFn, CircuitStateFn, DictStateFn,
)


Expand Down Expand Up @@ -619,6 +619,21 @@ def test_list_op_parameters(self):
params.append(lam)
self.assertEqual(list_op.parameters, set(params))

def test_statefn_eval(self):
"""Test calling eval on StateFn returns the statevector."""
qc = QuantumCircuit(1)
ops = [VectorStateFn([1, 0]),
woodsp-ibm marked this conversation as resolved.
Show resolved Hide resolved
DictStateFn({'0': 1}),
CircuitStateFn(qc),
OperatorStateFn(I),
OperatorStateFn(MatrixOp([[1, 0], [0, 1]])),
OperatorStateFn(CircuitOp(qc))]

expected = Statevector([1, 0])
for op in ops:
with self.subTest(op):
self.assertEqual(op.eval().primitive, expected)


class TestOpMethods(QiskitAquaTestCase):
"""Basic method tests."""
Expand Down