Skip to content

Commit

Permalink
Cast initial state into complex dtype for default.qubit + Update `S…
Browse files Browse the repository at this point in the history
…um.label` + fix `qml.compile` (#6198)

* Upgrading `qml` to numpy 1.25 caused a bunch of `ComplexWarning`s to
show up because we start out with a float initial state, which most
likely gets cast into complex during execution. Then, during backprop,
the complex state would get cast into float, which would raise the
aforementioned warnings.

* `Sum.label` was also overly detailed, causing `qml.draw` to look
terrible. I changed the label to be consistent with `LinearCombination`
and `Hamiltonian`.

* `qml.compile` was not decomposing state prep ops, which is now
changed.
  • Loading branch information
mudit2812 authored Sep 3, 2024
1 parent f521ada commit 3d39e18
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 20 deletions.
9 changes: 7 additions & 2 deletions pennylane/devices/qubit/initialize_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,13 @@ def create_initial_state(
"""
if not prep_operation:
num_wires = len(wires)
state = np.zeros((2,) * num_wires)
state = np.zeros((2,) * num_wires, dtype=complex)
state[(0,) * num_wires] = 1
return qml.math.asarray(state, like=like)

return qml.math.asarray(prep_operation.state_vector(wire_order=list(wires)), like=like)
state_vector = prep_operation.state_vector(wire_order=list(wires))
dtype = str(state_vector.dtype)
floating_single = "float32" in dtype or "complex64" in dtype
dtype = "complex64" if floating_single else "complex128"
dtype = "complex128" if like == "tensorflow" else dtype
return qml.math.cast(qml.math.asarray(state_vector, like=like), dtype)
4 changes: 0 additions & 4 deletions pennylane/ops/op_math/linear_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,6 @@ def _build_pauli_rep_static(coeffs, observables):
def _check_batching(self):
"""Override for LinearCombination, batching is not yet supported."""

def label(self, decimals=None, base_label=None, cache=None):
decimals = None if (len(self.parameters) > 3) else decimals
return Operator.label(self, decimals=decimals, base_label=base_label or "𝓗", cache=cache)

@property
def coeffs(self):
"""Return the coefficients defining the LinearCombination.
Expand Down
6 changes: 5 additions & 1 deletion pennylane/ops/op_math/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ def is_hermitian(self):

return all(s.is_hermitian for s in self)

def label(self, decimals=None, base_label=None, cache=None):
decimals = None if (len(self.parameters) > 3) else decimals
return Operator.label(self, decimals=decimals, base_label=base_label or "𝓗", cache=cache)

def matrix(self, wire_order=None):
r"""Representation of the operator as a matrix in the computational basis.
Expand Down Expand Up @@ -466,7 +470,7 @@ def terms(self):
ops.append(factor)
return coeffs, ops

def compute_grouping(self, grouping_type="qwc", method="rlf"):
def compute_grouping(self, grouping_type="qwc", method="lf"):
"""
Compute groups of operators and coefficients corresponding to commuting
observables of this Sum.
Expand Down
4 changes: 2 additions & 2 deletions pennylane/ops/qubit/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
def _compute_grouping_indices(
observables: list[Observable],
grouping_type: Literal["qwc", "commuting", "anticommuting"] = "qwc",
method: Literal["lf", "rlf"] = "rlf",
method: Literal["lf", "rlf"] = "lf",
):
# todo: directly compute the
# indices, instead of extracting groups of observables first
Expand Down Expand Up @@ -467,7 +467,7 @@ def grouping_indices(self, value: Iterable[Iterable[int]]):
def compute_grouping(
self,
grouping_type: Literal["qwc", "commuting", "anticommuting"] = "qwc",
method: Literal["lf", "rlf"] = "rlf",
method: Literal["lf", "rlf"] = "lf",
):
"""
Compute groups of indices corresponding to commuting observables of this
Expand Down
1 change: 1 addition & 0 deletions pennylane/transforms/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def stop_at(obj):
max_expansion=expand_depth,
name="compile",
error=qml.operation.DecompositionUndefinedError,
skip_initial_state_prep=False,
)

# Apply the full set of compilation transforms num_passes times
Expand Down
28 changes: 27 additions & 1 deletion tests/devices/qubit/test_initialize_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,15 @@ class DefaultPrep(StatePrepBase):

num_wires = qml.operation.AllWires

def __init__(self, *args, **kwargs):
self.dtype = kwargs.pop("dtype", None)
super().__init__(*args, **kwargs)

def state_vector(self, wire_order=None):
return self.parameters[0]
sv = self.parameters[0]
if self.dtype is not None:
sv = qml.math.cast(sv, self.dtype)
return sv

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
Expand All @@ -40,6 +47,7 @@ def test_create_initial_state_no_state_prep(self, interface):
state = create_initial_state([0, 1], like=interface)
assert qml.math.allequal(state, [[1, 0], [0, 0]])
assert qml.math.get_interface(state) == interface
assert "complex" in str(state.dtype)

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
Expand Down Expand Up @@ -84,7 +92,25 @@ def test_create_initial_state_casts_to_like_with_prep_op(self):
state = create_initial_state([0, 1], prep_operation=prep_op, like="torch")
assert qml.math.get_interface(state) == "torch"

@pytest.mark.torch
@pytest.mark.parametrize("dtype", ["float32", "float64"])
def test_create_initial_state_with_stateprep_casts_to_complex(self, dtype):
"""Test that the state gets cast to complex with the correct precision"""
expected_dtype = "complex128" if dtype == "float64" else "complex64"
prep_op = self.DefaultPrep([0, 0, 0, 1], wires=[0, 1], dtype=dtype)
res_dtype = create_initial_state([0, 1], prep_operation=prep_op, like="torch").dtype
assert expected_dtype in str(res_dtype)

@pytest.mark.tf
@pytest.mark.parametrize("dtype", ["float32", "float64"])
def test_create_initial_state_with_stateprep_casts_to_complex128_with_tf(self, dtype):
"""Test that the state gets cast to complex128 with tensorflow"""
prep_op = self.DefaultPrep([0, 0, 0, 1], wires=[0, 1], dtype=dtype)
res_dtype = create_initial_state([0, 1], prep_operation=prep_op, like="tensorflow").dtype
assert "complex128" in str(res_dtype)

def test_create_initial_state_defaults_to_numpy(self):
"""Tests that the default interface is vanilla numpy."""
state = qml.devices.qubit.create_initial_state((0, 1))
assert qml.math.get_interface(state) == "numpy"
assert state.dtype == np.complex128
2 changes: 1 addition & 1 deletion tests/drawer/test_tape_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def test_setting_max_length(self, ml):
qml.expval(
0.1 * qml.PauliX(0) + 0.2 * qml.PauliY(1) + 0.3 * qml.PauliZ(0) + 0.4 * qml.PauliZ(1)
),
"0: ───┤ ╭<(0.10*X)+(0.20*Y)+(0.30*Z)+(0.40*Z)>\n1: ───┤ ╰<(0.10*X)+(0.20*Y)+(0.30*Z)+(0.40*Z)>",
"0: ───┤ ╭<𝓗>\n1: ───┤ ╰<𝓗>",
),
# Operations (both regular and controlled) and nested multi-valued controls
(qml.ctrl(qml.PauliX(wires=2), control=[0, 1]), "0: ─╭●─┤ \n1: ─├●─┤ \n2: ─╰X─┤ "),
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/op_math/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def test_label(self):

base = qml.S(0) + qml.T(0)
op = Adjoint(base)
assert op.label() == "(S+T)†"
assert op.label() == "𝓗†"

def test_adjoint_of_adjoint(self):
"""Test that the adjoint of an adjoint is the original operation."""
Expand Down
22 changes: 17 additions & 5 deletions tests/ops/op_math/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,18 @@ def test_grouping_indices_setter_error(self):
):
H.grouping_indices = [[0, 1, 3], [2]]

def test_label(self):
"""Tests the label method of Sum when <=3 coefficients."""
H = qml.ops.Sum(-0.8 * Z(0))
assert H.label() == "𝓗"
assert H.label(decimals=2) == "𝓗\n(-0.80)"

def test_label_many_coefficients(self):
"""Tests the label method of Sum when >3 coefficients."""
H = qml.ops.Sum(*(0.1 * qml.Z(0) for _ in range(5)))
assert H.label() == "𝓗"
assert H.label(decimals=2) == "𝓗"


class TestSimplify:
"""Test Sum simplify method and depth property."""
Expand Down Expand Up @@ -1464,7 +1476,7 @@ def test_grouping_method_can_be_set(self):

@pytest.mark.parametrize(
"grouping_type, grouping_indices",
[("commuting", ((0, 1), (2,))), ("anticommuting", ((1,), (0, 2)))],
[("commuting", {(0, 1), (2,)}), ("anticommuting", {(1,), (0, 2)})],
)
def test_grouping_type_can_be_set(self, grouping_type, grouping_indices):
"""Tests that the grouping type can be controlled by kwargs.
Expand All @@ -1478,21 +1490,21 @@ def test_grouping_type_can_be_set(self, grouping_type, grouping_indices):

# compute grouping during construction with qml.dot
op1 = qml.dot(coeffs, obs, grouping_type=grouping_type)
assert op1.grouping_indices == grouping_indices
assert set(op1.grouping_indices) == grouping_indices

# compute grouping during construction with qml.sum
sprods = [qml.s_prod(c, o) for c, o in zip(coeffs, obs)]
op2 = qml.sum(*sprods, grouping_type=grouping_type)
assert op2.grouping_indices == grouping_indices
assert set(op2.grouping_indices) == grouping_indices

# compute grouping during construction with Sum
op3 = Sum(*sprods, grouping_type=grouping_type)
assert op3.grouping_indices == grouping_indices
assert set(op3.grouping_indices) == grouping_indices

# compute grouping separately
op4 = qml.dot(coeffs, obs, grouping_type=None)
op4.compute_grouping(grouping_type=grouping_type)
assert op4.grouping_indices == grouping_indices
assert set(op4.grouping_indices) == grouping_indices

@pytest.mark.parametrize("shots", [None, 1000])
def test_grouping_integration(self, shots):
Expand Down
19 changes: 19 additions & 0 deletions tests/transforms/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,25 @@ def test_compile_pipeline_with_non_default_arguments(self, wires):

compare_operation_lists(transformed_qnode.qtape.operations, names_expected, wires_expected)

def test_compile_decomposes_state_prep(self):
"""Test that compile decomposes state prep operations"""

class DummyStatePrep(qml.operation.StatePrepBase):
"""Dummy state prep operation for unit testing"""

def decomposition(self):
return [qml.Hadamard(i) for i in self.wires]

def state_vector(self, wire_order=None): # pylint: disable=unused-argument
return self.parameters[0]

state_prep_op = DummyStatePrep([1, 1], wires=[0, 1])
tape = qml.tape.QuantumScript([state_prep_op])

[compiled_tape], _ = qml.compile(tape)
expected = qml.tape.QuantumScript(state_prep_op.decomposition())
qml.assert_equal(compiled_tape, expected)

@pytest.mark.parametrize(("wires"), [["a", "b", "c"], [0, 1, 2], [3, 1, 2], [0, "a", 4]])
def test_compile_multiple_passes(self, wires):
"""Test that running multiple passes produces the correct results."""
Expand Down
6 changes: 3 additions & 3 deletions tests/transforms/test_hamiltonian_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,17 +381,17 @@ def test_constant_offset_grouping(self):

assert len(batch) == 2

tape_0 = qml.tape.QuantumScript([], [qml.expval(qml.Z(0))], shots=50)
tape_1 = qml.tape.QuantumScript(
tape_0 = qml.tape.QuantumScript(
[qml.RY(-np.pi / 2, 0), qml.RX(np.pi / 2, 1)],
[qml.expval(qml.Z(0)), qml.expval(qml.Z(0) @ qml.Z(1))],
shots=50,
)
tape_1 = qml.tape.QuantumScript([], [qml.expval(qml.Z(0))], shots=50)

qml.assert_equal(batch[0], tape_0)
qml.assert_equal(batch[1], tape_1)

dummy_res = (1.0, (1.0, 1.0))
dummy_res = ((1.0, 1.0), 1.0)
processed_res = fn(dummy_res)
assert qml.math.allclose(processed_res, 10.0)

Expand Down

0 comments on commit 3d39e18

Please sign in to comment.