diff --git a/src/Diff.jl b/src/Diff.jl index 055a240..8fc811e 100644 --- a/src/Diff.jl +++ b/src/Diff.jl @@ -198,7 +198,7 @@ function backward_params!(st, block::Diff{<:DiffBlock}, collector) in, outδ = st Σ = generator(content(block)) g = dropdims(sum(conj.(statevec(in |> Σ)) .* statevec(outδ), dims=1), dims=1) - pushfirst!(collector, -g |> imag) + pushfirst!(collector, -g[1] |> imag) in |> Σ nothing end