Skip to content

Commit 8fdef4c

Browse files
committed
update Zygote example
1 parent 88fcff4 commit 8fdef4c

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@ uuid = "65c24e16-9b0a-11e8-1353-efc5bc5f6586"
33
version = "0.1.0"
44

55
[deps]
6+
BitBasis = "50ba71b6-fa0f-514d-ae9a-0916efc90dcf"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8-
BitBasis = "50ba71b6-fa0f-514d-ae9a-0916efc90dcf"
99
Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c"
1010
YaoArrayRegister = "e600142f-9330-5003-8abb-0ebd767abc51"
1111
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
1212
YaoExtensions = "7a06699c-c960-11e9-3c98-9f78548b5f0f"
1313

1414
[compat]
15-
julia = "1"
16-
Yao = "0.6.0"
1715
BitBasis = "0.6"
16+
Yao = "0.6.0"
1817
YaoArrayRegister = "0.6"
1918
YaoBlocks = "0.8, 0.9, 0.10"
2019
YaoExtensions = "0.2"
20+
julia = "1"
2121

2222
[extras]
2323
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
include("zygote_patch.jl")
2+
3+
import YaoExtensions, Random
4+
5+
c = YaoExtensions.variational_circuit(5)
6+
h = YaoExtensions.heisenberg(5)
7+
8+
function loss(h, c, θ) where N
9+
# the assign is nessesary!
10+
c = dispatch!(c, fill(θ, nparameters(c)))
11+
reg = apply!(zero_state(nqubits(c)), c)
12+
real(expect(h, reg))
13+
end
14+
15+
reg0 = zero_state(5)
16+
zygote_grad = Zygote.gradient->loss(h, c, θ), 0.5)[1]
17+
18+
19+
# check gradients
20+
using Test
21+
dispatch!(c, fill(0.5, nparameters(c)))
22+
greg, gparams = expect'(h, zero_state(5)=>c)
23+
true_grad = sum(gparams)
24+
25+
@test true_grad true_grad

examples/PortZygote/zygote_patch.jl

+12
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ using Yao, Yao.AD
1010
end
1111
end
1212

13+
@adjoint function Yao.dispatch!(block::AbstractBlock, params)
14+
out = dispatch!(block, params)
15+
out, function (outδ)
16+
(nothing, outδ)
17+
end
18+
end
19+
1320
@adjoint function Matrix(block::AbstractBlock)
1421
out = Matrix(block)
1522
out, function (outδ)
@@ -40,3 +47,8 @@ end
4047
@adjoint statevec(reg::AdjointArrayReg) = statevec(reg), adjy->(ArrayReg(adjy')',)
4148
@adjoint parent(reg::AdjointArrayReg) = parent(reg), adjy->(adjy',)
4249
@adjoint Base.adjoint(reg::ArrayReg) = Base.adjoint(reg), adjy->(parent(adjy),)
50+
Zygote.@nograd Yao.nparameters
51+
Zygote.@nograd Yao.zero_state
52+
Zygote.@nograd Yao.rand_state
53+
Zygote.@nograd Yao.uniform_state
54+
Zygote.@nograd Yao.product_state

0 commit comments

Comments
 (0)