diff --git a/ext/ArrayInterfaceTrackerExt.jl b/ext/ArrayInterfaceTrackerExt.jl index d2d4e2ce..6d26f410 100644 --- a/ext/ArrayInterfaceTrackerExt.jl +++ b/ext/ArrayInterfaceTrackerExt.jl @@ -9,4 +9,11 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x) +function ArrayInterface.restructure(x::Array, y::TrackedArray) + reshape(y, Base.size(x)...) +end +function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal}) + reshape(y, Base.size(x)...) +end + end # module diff --git a/test/ad.jl b/test/ad.jl index 7c29c8dd..c1a207b7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -18,3 +18,12 @@ x = Tracker.TrackedArray([4.0,4.0]) x = reduce(vcat, Tracker.TrackedArray([4.0,4.0])) x = [x[1],x[2]] @test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray + +x = rand(4) +y = Tracker.TrackedReal.(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Array +@test eltype(ArrayInterface.restructure(x, y)) <: Tracker.TrackedReal +@test size(ArrayInterface.restructure(x, y)) == (4,) +y = Tracker.TrackedArray(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray +@test size(ArrayInterface.restructure(x, y)) == (4,) \ No newline at end of file