Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change import to cirq_google from cirq.google #643

Merged
merged 5 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions tensorflow_quantum/core/ops/circuit_execution_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from absl.testing import parameterized
from scipy import stats
import cirq
import cirq_google

from tensorflow_quantum.core.ops import batch_util, circuit_execution_ops
from tensorflow_quantum.python import util
Expand Down Expand Up @@ -91,9 +92,9 @@ def test_get_expectation_inputs(self):
expected_regex='Sample-based'):
mock_engine = mock.Mock()
circuit_execution_ops.get_expectation_op(
cirq.google.QuantumEngineSampler(engine=mock_engine,
cirq_google.QuantumEngineSampler(engine=mock_engine,
processor_id='test',
gate_set=cirq.google.XMON))
gate_set=cirq_google.XMON))
with self.assertRaisesRegex(
TypeError,
expected_regex="cirq.sim.simulator.SimulatesExpectationValues"):
Expand All @@ -112,9 +113,9 @@ def test_get_sampled_expectation_inputs(self):
backend=cirq.DensityMatrixSimulator())
mock_engine = mock.Mock()
circuit_execution_ops.get_sampled_expectation_op(
cirq.google.QuantumEngineSampler(engine=mock_engine,
cirq_google.QuantumEngineSampler(engine=mock_engine,
processor_id='test',
gate_set=cirq.google.XMON))
gate_set=cirq_google.XMON))
with self.assertRaisesRegex(TypeError, expected_regex="a Cirq.Sampler"):
circuit_execution_ops.get_sampled_expectation_op(backend="junk")

Expand All @@ -131,9 +132,9 @@ def test_get_samples_inputs(self):
backend=cirq.DensityMatrixSimulator())
mock_engine = mock.Mock()
circuit_execution_ops.get_sampling_op(
backend=cirq.google.QuantumEngineSampler(engine=mock_engine,
backend=cirq_google.QuantumEngineSampler(engine=mock_engine,
processor_id='test',
gate_set=cirq.google.XMON))
gate_set=cirq_google.XMON))
with self.assertRaisesRegex(TypeError,
expected_regex="Expected a Cirq.Sampler"):
circuit_execution_ops.get_sampling_op(backend="junk")
Expand All @@ -155,10 +156,10 @@ def test_get_state_inputs(self):
expected_regex="Cirq.SimulatesFinalState"):
mock_engine = mock.Mock()
circuit_execution_ops.get_state_op(
backend=cirq.google.QuantumEngineSampler(
backend=cirq_google.QuantumEngineSampler(
engine=mock_engine,
processor_id='test',
gate_set=cirq.google.XMON))
gate_set=cirq_google.XMON))

with self.assertRaisesRegex(TypeError,
expected_regex="must be type bool."):
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_quantum/core/ops/cirq_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import tensorflow as tf
import cirq
import cirq_google

from tensorflow_quantum.core.ops import batch_util
from tensorflow_quantum.core.proto import pauli_sum_pb2
Expand Down Expand Up @@ -490,7 +491,7 @@ def _no_grad(grad):
]
max_n_qubits = max(len(p.all_qubits()) for p in programs)

if isinstance(sampler, cirq.google.QuantumEngineSampler):
if isinstance(sampler, cirq_google.QuantumEngineSampler):
# group samples from identical circuits to reduce communication
# overhead. Have to keep track of the order in which things came
# in to make sure the output is ordered correctly
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_quantum/core/ops/cirq_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tensorflow as tf
from absl.testing import parameterized
import cirq
import cirq_google

from tensorflow_quantum.core.ops import cirq_ops
from tensorflow_quantum.core.serialize import serializer
Expand Down Expand Up @@ -342,9 +343,9 @@ def test_get_cirq_sampling_op(self):
cirq_ops._get_cirq_samples(cirq.DensityMatrixSimulator())
mock_engine = mock.Mock()
cirq_ops._get_cirq_samples(
cirq.google.QuantumEngineSampler(engine=mock_engine,
cirq_google.QuantumEngineSampler(engine=mock_engine,
processor_id='test',
gate_set=cirq.google.XMON))
gate_set=cirq_google.XMON))

def test_cirq_sampling_op_inputs(self):
"""test input checking in the cirq sampling op."""
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_quantum/core/serialize/op_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(self,
self.op_wrapper = op_wrapper

def from_proto(self, proto, *, arg_function_language=''):
"""Turns a cirq.google.api.v2.Operation proto into a GateOperation."""
"""Turns a cirq_google.api.v2.Operation proto into a GateOperation."""
qubits = [qubit_from_proto(q.id) for q in proto.qubits]
args = self._args_from_proto(
proto, arg_function_language=arg_function_language)
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_quantum/core/serialize/op_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def to_proto(
*,
arg_function_language='',
):
"""Returns the cirq.google.api.v2.Operation message as a proto dict."""
"""Returns the cirq_google.api.v2.Operation message as a proto dict."""

gate = op.gate
if not isinstance(gate, self.gate_type):
Expand Down
18 changes: 9 additions & 9 deletions tensorflow_quantum/core/serialize/serializable_gate_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for serializing and deserializing cirq.google.api.v2 protos."""
"""Support for serializing and deserializing cirq_google.api.v2 protos."""

import cirq
from tensorflow_quantum.core.proto import program_pb2
Expand Down Expand Up @@ -56,7 +56,7 @@ def _function_languages_from_arg(arg_proto):
class SerializableGateSet:
"""A class for serializing and deserializing programs and operations.

This class is for cirq.google.api.v2. protos.
This class is for cirq_google.api.v2. protos.
"""

def __init__(self, gate_set_name, serializers, deserializers):
Expand Down Expand Up @@ -118,7 +118,7 @@ def is_supported_operation(self, op):
return False

def serialize(self, program, msg=None, *, arg_function_language=None):
"""Serialize a Circuit to cirq.google.api.v2.Program proto.
"""Serialize a Circuit to cirq_google.api.v2.Program proto.

Args:
program: The Circuit to serialize.
Expand Down Expand Up @@ -146,13 +146,13 @@ def serialize_op(
*,
arg_function_language='',
):
"""Serialize an Operation to cirq.google.api.v2.Operation proto.
"""Serialize an Operation to cirq_google.api.v2.Operation proto.

Args:
op: The operation to serialize.

Returns:
A dictionary corresponds to the cirq.google.api.v2.Operation proto.
A dictionary corresponds to the cirq_google.api.v2.Operation proto.
"""
gate_type = type(op.gate)
for gate_type_mro in gate_type.mro():
Expand All @@ -169,10 +169,10 @@ def serialize_op(
op, gate_type))

def deserialize(self, proto, device=None):
"""Deserialize a Circuit from a cirq.google.api.v2.Program.
"""Deserialize a Circuit from a cirq_google.api.v2.Program.

Args:
proto: A dictionary representing a cirq.google.api.v2.Program proto.
proto: A dictionary representing a cirq_google.api.v2.Program proto.
device: If the proto is for a schedule, a device is required
Otherwise optional.

Expand Down Expand Up @@ -200,11 +200,11 @@ def deserialize_op(
*,
arg_function_language='',
):
"""Deserialize an Operation from a cirq.google.api.v2.Operation.
"""Deserialize an Operation from a cirq_google.api.v2.Operation.

Args:
operation_proto: A dictionary representing a
cirq.google.api.v2.Operation proto.
cirq_google.api.v2.Operation proto.

Returns:
The deserialized Operation.
Expand Down