-
Notifications
You must be signed in to change notification settings - Fork 622
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extend jax.vmap
to the multidimensional case
#6422
Conversation
…o vmap_program_capture
…com/PennyLaneAI/pennylane into vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…/pennylane into experim_param_broad_capture
…o experim_param_broad_capture
…o vmap_program_capture
…o experim_param_broad_capture
…yLaneAI/pennylane into vmap_program_capture
…o vmap_program_capture
…/pennylane into vmap_multidimensional_input
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6422 +/- ##
==========================================
- Coverage 99.45% 99.45% -0.01%
==========================================
Files 450 450
Lines 42088 42085 -3
==========================================
- Hits 41857 41854 -3
Misses 231 231 ☔ View full report in Codecov by Sentry. 🚨 Try these New Features:
|
…o vmap_program_capture
…/pennylane into vmap_multidimensional_input
…/pennylane into vmap_multidimensional_input
**Context:** This PR is the first step to enable the usage of `jax.vmap` with quantum circuits when `qml.capture.enabled()` is `True`. **Description of the Change:** We implemented a batching rule for the captured `qnode` and modified the abstract evaluation to keep track of the batch dimension. **Benefits:** Allows vectorization (although with several limitations at this stage) with `qml.capture.enabled()` via `jax.vmap` **Possible Drawbacks:** Right now, there are 2 main limitations that I see. - Limitation 1: *Multidimensional arrays* Right now, the implementation does not work with multidimensional arrays. For example, the following does not work: ``` qml.capture.enable() @qml.qnode(qml.device("default.qubit", wires=2)) def circuit(x): qml.RX(x[1], 0) return qml.expval(qml.Z(0)) # jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) Generates an error ``` This limitation will be removed in the [following PR](#6422). - Limitation 2 *How to prevent the user from bypassing `jax.vmap` ?* ``` qml.capture.enable() @qml.qnode(qml.device("default.qubit", wires=2)) def circuit(x, y): qml.RX(x, 0) qml.RY(y, 0) return qml.expval(qml.Z(0)) x = jax.numpy.array([0.1, 0.2]) y = jax.numpy.array([0.1, 0.2]) jax.vmap(circuit, in_axes=(0, None))(x, y) ``` This works by accident, but the second argument is vectorized along 2 dimensions, although it shouldn't! With `in_axes=(0, None)` we should raise an error if `qml.RY` receives something that is not a scalar (as [it happens in Catalyst](https://github.com/PennyLaneAI/catalyst/blob/7c5b828d5173cdaa52073d30a5f3a7df660b37d6/frontend/catalyst/jax_primitives.py#L1206)). At this stage, we simply raise a warning because we don't have a way to check this inside the QNode batching rule. To fix this behavior, I think we need to add more properties to the `AbstractOperator` class and improve the integration with the captured `QNode`. This is most probably also necessary for capturing parameter broadcasting in PL. **Related GitHub Issues:** None. **Related Shortcut Stories:** [sc-73779] [sc-73782] [sc-73783] --------- Co-authored-by: Christina Lee <[email protected]>
…o vmap_multidimensional_input
**Context:** This PR is the first step to enable the usage of `jax.vmap` with quantum circuits when `qml.capture.enabled()` is `True`. **Description of the Change:** We implemented a batching rule for the captured `qnode` and modified the abstract evaluation to keep track of the batch dimension. **Benefits:** Allows vectorization (although with several limitations at this stage) with `qml.capture.enabled()` via `jax.vmap` **Possible Drawbacks:** Right now, there are 2 main limitations that I see. - Limitation 1: *Multidimensional arrays* Right now, the implementation does not work with multidimensional arrays. For example, the following does not work: ``` qml.capture.enable() @qml.qnode(qml.device("default.qubit", wires=2)) def circuit(x): qml.RX(x[1], 0) return qml.expval(qml.Z(0)) # jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) Generates an error ``` This limitation will be removed in the [following PR](#6422). - Limitation 2 *How to prevent the user from bypassing `jax.vmap` ?* ``` qml.capture.enable() @qml.qnode(qml.device("default.qubit", wires=2)) def circuit(x, y): qml.RX(x, 0) qml.RY(y, 0) return qml.expval(qml.Z(0)) x = jax.numpy.array([0.1, 0.2]) y = jax.numpy.array([0.1, 0.2]) jax.vmap(circuit, in_axes=(0, None))(x, y) ``` This works by accident, but the second argument is vectorized along 2 dimensions, although it shouldn't! With `in_axes=(0, None)` we should raise an error if `qml.RY` receives something that is not a scalar (as [it happens in Catalyst](https://github.com/PennyLaneAI/catalyst/blob/7c5b828d5173cdaa52073d30a5f3a7df660b37d6/frontend/catalyst/jax_primitives.py#L1206)). At this stage, we simply raise a warning because we don't have a way to check this inside the QNode batching rule. To fix this behavior, I think we need to add more properties to the `AbstractOperator` class and improve the integration with the captured `QNode`. This is most probably also necessary for capturing parameter broadcasting in PL. **Related GitHub Issues:** None. **Related Shortcut Stories:** [sc-73779] [sc-73782] [sc-73783] --------- Co-authored-by: Christina Lee <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @PietropaoloFrisoni !
Context: This PR extends the captured
jax.vmap
version to the multidimensional input case. For further details, we refer to the description of the first PR.Description of the Change: As above. For 'multidimensional input case' we mean something like the following:
Benefits: Now
jax.vmap
can be used with captured enabled if the input parameter is an array with a shape greater than 1.Possible Drawbacks: None that I can think of.
Related GitHub Issues: None.
Related Shortcut Stories: [sc-76381]