Skip to content

Commit

Permalink
Fix Tracker with restructure
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 1, 2024
1 parent 0042e23 commit 0acd847
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ext/ArrayInterfaceTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,11 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false
ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x)

function ArrayInterface.restructure(x::Array, y::TrackedArray)
reshape(y, Base.size(x)...)
end
function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal})
reshape(y, Base.size(x)...)
end

end # module
9 changes: 9 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ x = Tracker.TrackedArray([4.0,4.0])
x = reduce(vcat, Tracker.TrackedArray([4.0,4.0]))
x = [x[1],x[2]]
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray

x = rand(4)
y = Tracker.TrackedReal.(rand(2,2))
@test ArrayInterface.restructure(x, y) isa Array
@test eltype(ArrayInterface.restructure(x, y)) <: Tracker.TrackedReal
@test size(ArrayInterface.restructure(x, y)) == (4,)
y = Tracker.TrackedArray(rand(2,2))
@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray
@test size(ArrayInterface.restructure(x, y)) == (4,)

0 comments on commit 0acd847

Please sign in to comment.