Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend jax.vmap to the multidimensional case #6422

Merged
merged 86 commits into from
Nov 18, 2024

Conversation

PietropaoloFrisoni
Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni commented Oct 21, 2024

Context: This PR extends the captured jax.vmap version to the multidimensional input case. For further details, we refer to the description of the first PR.

Description of the Change: As above. For 'multidimensional input case' we mean something like the following:

qml.capture.enable()

@qml.qnode(qml.device("default.qubit", wires=...))
...

jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) 

Benefits: Now jax.vmap can be used with captured enabled if the input parameter is an array with a shape greater than 1.

Possible Drawbacks: None that I can think of.

Related GitHub Issues: None.

Related Shortcut Stories: [sc-76381]

Copy link

codecov bot commented Nov 1, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.45%. Comparing base (9a3dbdd) to head (d983e21).
Report is 1 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #6422      +/-   ##
==========================================
- Coverage   99.45%   99.45%   -0.01%     
==========================================
  Files         450      450              
  Lines       42088    42085       -3     
==========================================
- Hits        41857    41854       -3     
  Misses        231      231              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


🚨 Try these New Features:

albi3ro added a commit that referenced this pull request Nov 8, 2024
**Context:** This PR is the first step to enable the usage of `jax.vmap`
with quantum circuits when `qml.capture.enabled()` is `True`.

**Description of the Change:** We implemented a batching rule for the
captured `qnode` and modified the abstract evaluation to keep track of
the batch dimension.

**Benefits:** Allows vectorization (although with several limitations at
this stage) with `qml.capture.enabled()` via `jax.vmap`

**Possible Drawbacks:** Right now, there are 2 main limitations that I
see.

- Limitation 1: *Multidimensional arrays*

Right now, the implementation does not work with multidimensional
arrays. For example, the following does not work:

```
qml.capture.enable()

@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(x):
    qml.RX(x[1], 0)
    return qml.expval(qml.Z(0))

# jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) Generates an error
```

This limitation will be removed in the [following
PR](#6422).

- Limitation 2 *How to prevent the user from bypassing `jax.vmap` ?*

```
qml.capture.enable()


@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(x, y):
    qml.RX(x, 0)
    qml.RY(y, 0)
    return qml.expval(qml.Z(0))

x = jax.numpy.array([0.1, 0.2])
y = jax.numpy.array([0.1, 0.2])

jax.vmap(circuit, in_axes=(0, None))(x, y) 

```

This works by accident, but the second argument is vectorized along 2
dimensions, although it shouldn't!

With `in_axes=(0, None)` we should raise an error if `qml.RY` receives
something that is not a scalar (as [it happens in
Catalyst](https://github.com/PennyLaneAI/catalyst/blob/7c5b828d5173cdaa52073d30a5f3a7df660b37d6/frontend/catalyst/jax_primitives.py#L1206)).

At this stage, we simply raise a warning because we don't have a way to
check this inside the QNode batching rule. To fix this behavior, I think
we need to add more properties to the `AbstractOperator` class and
improve the integration with the captured `QNode`. This is most probably
also necessary for capturing parameter broadcasting in PL.


**Related GitHub Issues:** None.

**Related Shortcut Stories:** [sc-73779] [sc-73782] [sc-73783]

---------

Co-authored-by: Christina Lee <[email protected]>
Base automatically changed from vmap_program_capture to master November 8, 2024 17:09
mudit2812 pushed a commit that referenced this pull request Nov 11, 2024
**Context:** This PR is the first step to enable the usage of `jax.vmap`
with quantum circuits when `qml.capture.enabled()` is `True`.

**Description of the Change:** We implemented a batching rule for the
captured `qnode` and modified the abstract evaluation to keep track of
the batch dimension.

**Benefits:** Allows vectorization (although with several limitations at
this stage) with `qml.capture.enabled()` via `jax.vmap`

**Possible Drawbacks:** Right now, there are 2 main limitations that I
see.

- Limitation 1: *Multidimensional arrays*

Right now, the implementation does not work with multidimensional
arrays. For example, the following does not work:

```
qml.capture.enable()

@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(x):
    qml.RX(x[1], 0)
    return qml.expval(qml.Z(0))

# jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) Generates an error
```

This limitation will be removed in the [following
PR](#6422).

- Limitation 2 *How to prevent the user from bypassing `jax.vmap` ?*

```
qml.capture.enable()


@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(x, y):
    qml.RX(x, 0)
    qml.RY(y, 0)
    return qml.expval(qml.Z(0))

x = jax.numpy.array([0.1, 0.2])
y = jax.numpy.array([0.1, 0.2])

jax.vmap(circuit, in_axes=(0, None))(x, y) 

```

This works by accident, but the second argument is vectorized along 2
dimensions, although it shouldn't!

With `in_axes=(0, None)` we should raise an error if `qml.RY` receives
something that is not a scalar (as [it happens in
Catalyst](https://github.com/PennyLaneAI/catalyst/blob/7c5b828d5173cdaa52073d30a5f3a7df660b37d6/frontend/catalyst/jax_primitives.py#L1206)).

At this stage, we simply raise a warning because we don't have a way to
check this inside the QNode batching rule. To fix this behavior, I think
we need to add more properties to the `AbstractOperator` class and
improve the integration with the captured `QNode`. This is most probably
also necessary for capturing parameter broadcasting in PL.


**Related GitHub Issues:** None.

**Related Shortcut Stories:** [sc-73779] [sc-73782] [sc-73783]

---------

Co-authored-by: Christina Lee <[email protected]>
@albi3ro albi3ro self-requested a review November 13, 2024 16:33
Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @PietropaoloFrisoni !

@PietropaoloFrisoni PietropaoloFrisoni merged commit 1a7db33 into master Nov 18, 2024
46 checks passed
@PietropaoloFrisoni PietropaoloFrisoni deleted the vmap_multidimensional_input branch November 18, 2024 22:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants