Skip to content

Commit

Permalink
Merge pull request #453 from JuliaArrays/restructure_tracker
Browse files Browse the repository at this point in the history
Fix Tracker with restructure
  • Loading branch information
ChrisRackauckas authored Sep 1, 2024
2 parents 0042e23 + 00cc8c6 commit 7b44f9e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
4 changes: 4 additions & 0 deletions ext/ArrayInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N
end
end

function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray)
reshape(y, Base.size(x)...)
end

end # module
7 changes: 7 additions & 0 deletions ext/ArrayInterfaceTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,11 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false
ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x)

function ArrayInterface.restructure(x::Array, y::Tracker.TrackedArray)
reshape(y, Base.size(x)...)
end
function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal})
reshape(y, Base.size(x)...)
end

end # module
18 changes: 18 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,21 @@ x = Tracker.TrackedArray([4.0,4.0])
x = reduce(vcat, Tracker.TrackedArray([4.0,4.0]))
x = [x[1],x[2]]
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray

x = rand(4)
y = Tracker.TrackedReal.(rand(2,2))
@test ArrayInterface.restructure(x, y) isa Array
@test eltype(ArrayInterface.restructure(x, y)) <: Tracker.TrackedReal
@test size(ArrayInterface.restructure(x, y)) == (4,)
y = Tracker.TrackedArray(rand(2,2))
@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray
@test size(ArrayInterface.restructure(x, y)) == (4,)

x = rand(4)
y = ReverseDiff.track(rand(2,2))
@test ArrayInterface.restructure(x, y) isa ReverseDiff.TrackedArray
@test size(ArrayInterface.restructure(x, y)) == (4,)
y = ReverseDiff.track.(rand(2,2))
@test ArrayInterface.restructure(x, y) isa Array
@test eltype(ArrayInterface.restructure(x, y)) <: ReverseDiff.TrackedReal
@test size(ArrayInterface.restructure(x, y)) == (4,)

0 comments on commit 7b44f9e

Please sign in to comment.