
Commit

Update QJIT VQE demo to use value_and_grad and Catalyst printing (#1279)

When this demo was originally written, these features weren't available.
Using them makes the demo more efficient (the cost function is no longer
evaluated a second time just for logging), and also shows users how to
print from inside qjit.
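
As a minimal, self-contained sketch of the pattern this commit adopts (a hypothetical one-qubit circuit stands in for the demo's actual VQE cost function), the two features combine as follows; `catalyst.value_and_grad` and `catalyst.debug.print` are the same calls that appear in the diff below:

import pennylane as qml
import catalyst
import jax.numpy as jnp

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def toy_cost(params):
    # Hypothetical stand-in for the demo's VQE cost function.
    qml.RX(params[0], wires=0)
    return qml.expval(qml.PauliZ(0))

@qml.qjit
def evaluate(params):
    # One call returns both the cost value and its gradient, so the cost
    # is not evaluated a second time just for logging.
    energy, grads = catalyst.value_and_grad(toy_cost)(params)
    # catalyst.debug.print works inside qjit-compiled functions, where a
    # regular Python print would only see tracers, not runtime values.
    catalyst.debug.print("Energy = {energy:.8f}", energy=energy)
    return energy, grads

energy, grads = evaluate(jnp.array([0.54]))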
josh146 authored Dec 13, 2024
1 parent b54b61f commit 8765fb2
Showing 2 changed files with 8 additions and 11 deletions.
First changed file (the demo's metadata JSON):

@@ -9,7 +9,7 @@
         }
     ],
     "dateOfPublication": "2024-04-26T00:00:00+00:00",
-    "dateOfLastModification": "2024-11-08T00:00:00+00:00",
+    "dateOfLastModification": "2024-11-12T00:00:00+00:00",
     "categories": [
         "Quantum Machine Learning",
         "Optimization",
Second changed file (the demo's Python source):

@@ -141,8 +141,11 @@ def cost(params):
 # not JAX compatible.
 #
 # Instead, we can use `Optax <https://github.com/google-deepmind/optax>`__, a library designed for
-# optimization using JAX, as well as the :func:`~.catalyst.grad` function, which allows us to
-# differentiate through quantum just-in-time compiled workflows.
+# optimization using JAX, as well as the :func:`~.catalyst.value_and_grad` function, which allows us to
+# differentiate through quantum just-in-time compiled workflows while also returning the cost value.
+# Here we use :func:`~.catalyst.value_and_grad` as we want to be able to print out and track our
+# cost function during execution, but if this is not required the :func:`~.catalyst.grad` function
+# can be used instead.
 #

 import catalyst
@@ -153,23 +156,17 @@ def cost(params):
 @qml.qjit
 def update_step(i, params, opt_state):
     """Perform a single gradient update step"""
-    grads = catalyst.grad(cost)(params)
+    energy, grads = catalyst.value_and_grad(cost)(params)
     updates, opt_state = opt.update(grads, opt_state)
     params = optax.apply_updates(params, updates)
+    catalyst.debug.print("Step = {i}, Energy = {energy:.8f} Ha", i=i, energy=energy)
     return (params, opt_state)

-loss_history = []
-
 opt_state = opt.init(init_params)
 params = init_params

 for i in range(10):
     params, opt_state = update_step(i, params, opt_state)
-    loss_val = cost(params)
-
-    print(f"--- Step: {i}, Energy: {loss_val:.8f}")
-
-    loss_history.append(loss_val)

 ######################################################################
 # Step 4: QJIT-compile the optimization
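
For reference, here is how the updated optimization section reads once the patch above is applied. This is reconstructed from the `+` side of the diff; `cost`, `opt`, and `init_params` are defined earlier in the demo, outside this hunk, so their definitions are only indicated in comments:

import catalyst
import optax

# `cost` and `init_params` are defined earlier in the demo; `opt` is an
# Optax optimizer whose definition lies outside this diff.

@qml.qjit
def update_step(i, params, opt_state):
    """Perform a single gradient update step"""
    energy, grads = catalyst.value_and_grad(cost)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    catalyst.debug.print("Step = {i}, Energy = {energy:.8f} Ha", i=i, energy=energy)
    return (params, opt_state)

opt_state = opt.init(init_params)
params = init_params

for i in range(10):
    params, opt_state = update_step(i, params, opt_state)

Note that the `loss_history` list and the extra `cost(params)` call from the old loop are gone entirely: logging now happens inside the compiled `update_step` via `catalyst.debug.print`, which is what makes the single `value_and_grad` evaluation sufficient.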
