Skip to content

Commit

Permalink
add docstring and apply PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MarionQuandela committed Sep 19, 2024
1 parent 7b326c1 commit 787a5ac
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 13 deletions.
6 changes: 3 additions & 3 deletions perceval/components/core_catalog/mzi.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self):

def build_circuit(self, **kwargs) -> Circuit:
phi_a, phi_b, theta_a, theta_b = self._handle_params(**kwargs)
return (Circuit(2, name="MZI")
return (Circuit(2, name=_NAME_MZI_PHASE_FIRST)
// (0, PS(phi=phi_a)) // BS(theta=theta_a) // (0, PS(phi=phi_b)) // BS(theta=theta_b))


Expand All @@ -96,7 +96,7 @@ def __init__(self):

def build_circuit(self, **kwargs) -> Circuit:
phi_a, phi_b, theta_a, theta_b = self._handle_params(**kwargs)
return (Circuit(2, name="MZI")
return (Circuit(2, name=_NAME_MZI_PHASE_LAST)
// BS(theta=theta_a) // (1, PS(phi=phi_a)) // BS(theta=theta_b) // (1, PS(phi=phi_b)))


Expand All @@ -114,5 +114,5 @@ def __init__(self):

def build_circuit(self, **kwargs) -> Circuit:
phi_a, phi_b, theta_a, theta_b = self._handle_params(**kwargs)
return (Circuit(2, name="MZI")
return (Circuit(2, name=_NAME_SYMMETRIC_MZI)
// BS(theta=theta_a) // (0, PS(phi=phi_a)) // (1, PS(phi=phi_b))) // BS(theta=theta_b)
30 changes: 23 additions & 7 deletions perceval/components/generic_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,19 @@ class GenericInterferometer(Circuit):
:param m: number of modes
:param fun_gen: generator function for the building components, index is an integer allowing to generate
named parameters - for instance:
:code:`fun_gen=lambda idx: phys.BS()//(0, phys.PS(pcvl.P(f"phi_{idx}")))`
:code:`fun_gen=lambda idx: pcvl.BS()//(0, pcvl.PS(pcvl.P(f"phi_{idx}")))`
:param shape: The output interferometer shape (InterferometerShape.RECTANGLE or InterferometerShape.TRIANGLE)
:param depth: if None, maximal depth is :math:`m-1` for rectangular shape, :math:`m` for triangular shape.
Can be used with :math:`2*m` to reproduce :cite:`fldzhyan2020optimal`.
:param phase_shifter_fun_gen: a function generating a phase_shifter circuit.
:param phase_at_output: if True creates a layer of phase shifters at the output of the generated interferometer
else creates it in the input (default: False)
:param upper_component_gen fun_gen: generator function for the building the upper component, index is an integer allowing to generate
named parameters - for instance:
:code:`fun_gen=lambda idx: pcvl.PS(pcvl.P(f"phi_upper_{idx}"))`
:param lower_component_gen: generator function for the building the lower component, index is an integer allowing to generate
named parameters - for instance:
:code:`fun_gen=lambda idx: pcvl.PS(pcvl.P(f"phi_lower_{idx}"))`
See :cite:`fldzhyan2020optimal`, :cite:`clements2016optimal` and :cite:`reck1994experimental`
"""
Expand Down Expand Up @@ -114,20 +120,30 @@ def set_identity_mode(self):
p.set_value(math.pi)

def _add_component(self, mode: int, component: ACircuit) -> None:
"""Add a component to the circuit, check if it's a one mode circuit
:param mode: mode to add the component
:param component: component to add
"""
assert component.m == 1, f"Component should always be a one mode circuit, instead it's a {component.m} modes circuit"
self.add(mode, component)

def _add_upper_component(self, i_depth: int) -> bool:
def _add_upper_component(self, i_depth: int) -> None:
"""Add a component with upper_component_gen between the interferometer on the first mode
:param i_depth: depth index of the interferometer
"""
if self._upper_component_gen and i_depth % 2 == 1:
self._add_component(0, self._upper_component_gen(int(i_depth/2)))
return True
return False

def _add_lower_component(self, i_depth: int) -> bool:
def _add_lower_component(self, i_depth: None) -> None:
"""Add a component with lower_component_gen between the interferometer on the last mode
:param i_depth: depth index of the interferometer
"""
# If m is even, the component is added at even depth index, else it's added in at odd depth index
if self._lower_component_gen and (i_depth % 2 == 1 and self.m % 2 == 0 or i_depth % 2 == 0 and self.m % 2 == 1):
self._add_component(self.m-1, self._lower_component_gen(int(i_depth/2)))
return True
return False

def _build_rectangle(self):
max_depth = self.m if self._depth is None else self._depth
Expand Down
2 changes: 1 addition & 1 deletion perceval/components/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def log_resources(self, method: str, extra_parameters: Dict):
elif isinstance(self._input_state, SVDistribution):
my_dict['n'] = self._input_state.n_max
else:
get_logger().info(f"Cannot get n for type {type(self._input_state)}", channel.resource)
get_logger().info(f"Cannot get n for type {type(self._input_state)}", channel.general)
if extra_parameters:
my_dict.update(extra_parameters)
if self.noise: # TODO: PCVL-782
Expand Down
1 change: 0 additions & 1 deletion perceval/utils/algorithms/circuit_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def _gen_ps(i: int):
template_component_generator_func,
phase_shifter_fun_gen=_gen_ps,
phase_at_output=phase_at_output)
barrier_free_template = Circuit(template.m)
result_circuit, fidelity = self.optimize(target, template)
if fidelity < 1 - self._threshold:
if allow_error:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_generic_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_set_param_list():

interferometer = GenericInterferometer(size, mzi_generator_func)
interferometer.set_param_list(values, (0, 0), m=2)
# With m=2, online one row of phase shifters get impacted (on mode 1, given the mzi we used)
# With m=2, only one row of phase shifters get impacted (on mode 1, given the mzi we used)
# The 10 first phase shifter of this row get phi=values[idx]
indexes = [1, 3, 7, 9, 13, 15, 19, 21, 25, 27]
for idx, phase_shifter_pos in enumerate(indexes):
Expand Down

0 comments on commit 787a5ac

Please sign in to comment.