Skip to content

Commit

Permalink
[Capture] higher order primitives store slices instead of integers (#…
Browse files Browse the repository at this point in the history
…6521)

**Context:**

As the higher order primitives `for`, `while`, and `cond` get more
feature-rich, especially with the incoming changes for dynamically
shaped arrays, it can be hard to keep track of what order the arguments
come in. This will get harder when we have to add in more positional
arguments for the dynamically shaped array dimensions.

**Description of the Change:**

Updates `for`, `while` and `cond` to store slices into the positionally
arguments instead of numbers of the different types of arguments. Now we
can simply consume:
```
args[provided_slice]
```
instead of having to do:
```
args[some_calcualtion_for_start: some_calculation_for_end]
```

I also rewrote some of the logic in the for conditional to make it
easier to construct all the relevant data.

**Benefits:**

**Possible Drawbacks:**

For the for loop, the slices into the arguments start after the `start,
stop, step` instead of including the offset by three. I wasn't quite
sure which one made more sense.

**Related GitHub Issues:**

[sc-77579]

---------

Co-authored-by: David Wierichs <[email protected]>
  • Loading branch information
albi3ro and dwierichs authored Nov 8, 2024
1 parent c1a7d3d commit 0d497ec
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 73 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
* `qml.BasisRotation` template is now JIT compatible.
[(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019)

* The Jaxpr primitives for `for_loop`, `while_loop` and `cond` now store slices instead of
numbers of args.
[(#6521)](https://github.com/PennyLaneAI/pennylane/pull/6521)

* Expand `ExecutionConfig.gradient_method` to store `TransformDispatcher` type.
[(#6455)](https://github.com/PennyLaneAI/pennylane/pull/6455)

Expand Down
4 changes: 2 additions & 2 deletions pennylane/capture/flatfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def __init__(self, f, in_tree=None):
self.out_tree = None
update_wrapper(self, f)

def __call__(self, *args):
def __call__(self, *args, **kwargs):
if self.in_tree is not None:
args = jax.tree_util.tree_unflatten(self.in_tree, args)
out = self.f(*args)
out = self.f(*args, **kwargs)
out_flat, out_tree = jax.tree_util.tree_flatten(out)
self.out_tree = out_tree
return out_flat
65 changes: 38 additions & 27 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,11 @@ def _get_while_loop_qfunc_prim():
while_loop_prim.multiple_results = True

@while_loop_prim.def_impl
def _(*jaxpr_args, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond):
def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice):

jaxpr_consts_body = jaxpr_args[:n_consts_body]
jaxpr_consts_cond = jaxpr_args[n_consts_body : n_consts_body + n_consts_cond]
init_state = jaxpr_args[n_consts_body + n_consts_cond :]
jaxpr_consts_body = args[body_slice]
jaxpr_consts_cond = args[cond_slice]
init_state = args[args_slice]

# If cond_fn(*init_state) is False, return the initial state
fn_res = init_state
Expand All @@ -425,9 +425,8 @@ def _(*jaxpr_args, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond):
return fn_res

@while_loop_prim.def_abstract_eval
def _(*_, jaxpr_body_fn, **__):

return [out.aval for out in jaxpr_body_fn.outvars]
def _(*args, args_slice, **__):
return args[args_slice]

return while_loop_prim

Expand Down Expand Up @@ -466,15 +465,22 @@ def _call_capture_enabled(self, *init_state):
jaxpr_body_fn = jax.make_jaxpr(flat_body_fn)(*init_state)
jaxpr_cond_fn = jax.make_jaxpr(self.cond_fn)(*init_state)

n_bf_c = len(jaxpr_body_fn.consts)
n_cf_c = len(jaxpr_cond_fn.consts)
body_consts = slice(0, n_bf_c)
cond_consts = slice(n_bf_c, n_bf_c + n_cf_c)
args_slice = slice(n_cf_c + n_bf_c, None)

flat_args, _ = jax.tree_util.tree_flatten(init_state)
results = while_loop_prim.bind(
*jaxpr_body_fn.consts,
*jaxpr_cond_fn.consts,
*flat_args,
jaxpr_body_fn=jaxpr_body_fn.jaxpr,
jaxpr_cond_fn=jaxpr_cond_fn.jaxpr,
n_consts_body=len(jaxpr_body_fn.consts),
n_consts_cond=len(jaxpr_cond_fn.consts),
body_slice=body_consts,
cond_slice=cond_consts,
args_slice=args_slice,
)
assert flat_body_fn.out_tree is not None, "Should be set when constructing the jaxpr"
return jax.tree_util.tree_unflatten(flat_body_fn.out_tree, results)
Expand Down Expand Up @@ -625,24 +631,25 @@ def _get_for_loop_qfunc_prim():
for_loop_prim = create_non_interpreted_prim()("for_loop")
for_loop_prim.multiple_results = True

# pylint: disable=too-many-arguments
@for_loop_prim.def_impl
def _(lower_bound, upper_bound, step, *jaxpr_consts_and_init_state, jaxpr_body_fn, n_consts):
def _(start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice):

jaxpr_consts = jaxpr_consts_and_init_state[:n_consts]
init_state = jaxpr_consts_and_init_state[n_consts:]
consts = args[consts_slice]
init_state = args[args_slice]

# in case lower_bound >= upper_bound, return the initial state
# in case start >= stop, return the initial state
fn_res = init_state

for i in range(lower_bound, upper_bound, step):
fn_res = jax.core.eval_jaxpr(jaxpr_body_fn, jaxpr_consts, i, *fn_res)
for i in range(start, stop, step):
fn_res = jax.core.eval_jaxpr(jaxpr_body_fn, consts, i, *fn_res)

return fn_res

# pylint: disable=unused-argument
@for_loop_prim.def_abstract_eval
def _(*_, jaxpr_body_fn, **__):

return [out.aval for out in jaxpr_body_fn.outvars]
def _(start, stop, step, *args, args_slice, **_):
return args[args_slice]

return for_loop_prim

Expand All @@ -653,8 +660,8 @@ class ForLoopCallable: # pylint:disable=too-few-public-methods
loop via the Python interpreter.
Args:
lower_bound (int): starting value of the iteration index
upper_bound (int): (exclusive) upper bound of the iteration index
start (int): starting value of the iteration index
stop (int): (exclusive) upper bound of the iteration index
step (int): increment applied to the iteration index at the end of each iteration
body_fn (Callable): The function called within the for loop. Note that the loop body
function must always have the iteration index as its first
Expand All @@ -663,17 +670,17 @@ class ForLoopCallable: # pylint:disable=too-few-public-methods
returned from the function.
"""

def __init__(self, lower_bound, upper_bound, step, body_fn):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
def __init__(self, start, stop, step, body_fn):
self.start = start
self.stop = stop
self.step = step
self.body_fn = body_fn

def _call_capture_disabled(self, *init_state):
args = init_state
fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None

for i in range(self.lower_bound, self.upper_bound, self.step):
for i in range(self.start, self.stop, self.step):
fn_res = self.body_fn(i, *args)
args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else ()

Expand All @@ -688,15 +695,19 @@ def _call_capture_enabled(self, *init_state):
flat_fn = FlatFn(self.body_fn)
jaxpr_body_fn = jax.make_jaxpr(flat_fn)(0, *init_state)

consts_slice = slice(0, len(jaxpr_body_fn.consts))
args_slice = slice(len(jaxpr_body_fn.consts), None)

flat_args, _ = jax.tree_util.tree_flatten(init_state)
results = for_loop_prim.bind(
self.lower_bound,
self.upper_bound,
self.start,
self.stop,
self.step,
*jaxpr_body_fn.consts,
*flat_args,
jaxpr_body_fn=jaxpr_body_fn.jaxpr,
n_consts=len(jaxpr_body_fn.consts),
consts_slice=consts_slice,
args_slice=args_slice,
)
assert flat_fn.out_tree is not None
return jax.tree_util.tree_unflatten(flat_fn.out_tree, results)
Expand Down
81 changes: 37 additions & 44 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,53 +220,47 @@ def __call_capture_enabled(self, *args, **kwargs):
cond_prim = _get_cond_qfunc_prim()

elifs = (
(self.orig_elifs,)
[self.orig_elifs]
if len(self.orig_elifs) > 0 and not isinstance(self.orig_elifs[0], tuple)
else self.orig_elifs
else list(self.orig_elifs)
)

flat_fn = FlatFn(functools.partial(self.true_fn, **kwargs))
jaxpr_true = jax.make_jaxpr(flat_fn)(*args)
jaxpr_false = (
jax.make_jaxpr(functools.partial(self.otherwise_fn, **kwargs))(*args)
if self.otherwise_fn
else None
)

# We extract each condition (or predicate) from the elifs argument list
# since these are traced by JAX and are passed as positional arguments to the primitive
elifs_conditions = []
jaxpr_elifs = []

for pred, elif_fn in elifs:
elifs_conditions.append(pred)
jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args))

conditions = [self.condition, *elifs_conditions, True]

jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false]
jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches]
jaxpr_branches = [j.jaxpr if j else None for j in jaxpr_branches]

# We need to flatten the constants since JAX does not allow
# to pass lists as positional arguments
consts_flat = [const for sublist in jaxpr_consts for const in sublist]
n_consts_per_branch = [len(consts) for consts in jaxpr_consts]
flat_true_fn = FlatFn(self.true_fn)
branches = [(self.preds[0], flat_true_fn), *elifs, (True, self.otherwise_fn)]

end_const_ind = len(
branches
) # consts go after the len(branches) conditions, first const at len(branches)
conditions = []
jaxpr_branches = []
consts = []
consts_slices = []

for pred, fn in branches:
conditions.append(pred)
if fn is None:
jaxpr_branches.append(None)
consts_slices.append(slice(0, 0))
else:
jaxpr = jax.make_jaxpr(functools.partial(fn, **kwargs))(*args)
jaxpr_branches.append(jaxpr.jaxpr)
consts_slices.append(slice(end_const_ind, end_const_ind + len(jaxpr.consts)))
consts += jaxpr.consts
end_const_ind += len(jaxpr.consts)

flat_args, _ = jax.tree_util.tree_flatten(args)
results = cond_prim.bind(
*conditions,
*consts,
*flat_args,
*consts_flat,
jaxpr_branches=jaxpr_branches,
n_consts_per_branch=n_consts_per_branch,
n_args=len(flat_args),
consts_slices=consts_slices,
args_slice=slice(end_const_ind, None),
)
assert flat_fn.out_tree is not None
if flat_fn.out_tree.num_leaves != len(results):
assert flat_true_fn.out_tree is not None
if flat_true_fn.out_tree.num_leaves != len(results):
# undefined false fn leads to empty results
return results
return jax.tree_util.tree_unflatten(flat_fn.out_tree, results)
return jax.tree_util.tree_unflatten(flat_true_fn.out_tree, results)

def __call__(self, *args, **kwargs):

Expand Down Expand Up @@ -694,15 +688,16 @@ def _get_cond_qfunc_prim():
cond_prim.multiple_results = True

@cond_prim.def_impl
def _(*all_args, jaxpr_branches, n_consts_per_branch, n_args):
def _(*all_args, jaxpr_branches, consts_slices, args_slice):
n_branches = len(jaxpr_branches)
conditions = all_args[:n_branches]
args = all_args[n_branches : n_branches + n_args]
consts_flat = all_args[n_branches + n_args :]
args = all_args[args_slice]

# Find predicates that use mid-circuit measurements. We don't check the last
# condition as that is always `True`.
mcm_conditions = [pred for pred in conditions[:-1] if isinstance(pred, MeasurementValue)]
mcm_conditions = tuple(
pred for pred in conditions[:-1] if isinstance(pred, MeasurementValue)
)
if len(mcm_conditions) != 0:
if len(mcm_conditions) != len(conditions) - 1:
raise ConditionalTransformError(
Expand All @@ -711,10 +706,8 @@ def _(*all_args, jaxpr_branches, n_consts_per_branch, n_args):
)
conditions = _get_mcm_predicates(mcm_conditions)

start = 0
for pred, jaxpr, n_consts in zip(conditions, jaxpr_branches, n_consts_per_branch):
consts = consts_flat[start : start + n_consts]
start += n_consts
for pred, jaxpr, const_slice in zip(conditions, jaxpr_branches, consts_slices):
consts = all_args[const_slice]
if jaxpr is None:
continue
if isinstance(pred, qml.measurements.MeasurementValue):
Expand Down

0 comments on commit 0d497ec

Please sign in to comment.