Skip to content

Commit

Permalink
Corrects bug in the DiagonalTensorMap rrule, adds
Browse files Browse the repository at this point in the history
tests for the new code, adds a proper generator of
random tangents for DiagonalTensorMap
  • Loading branch information
ebelnikola committed Jan 20, 2025
1 parent f275175 commit 2cf4cb4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
7 changes: 4 additions & 3 deletions ext/TensorKitChainRulesCoreExt/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwarg
return TensorMap(d, args...; kwargs...), TensorMap_pullback
end

function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...; kwargs...)
D=TensorMap(d, args...; kwargs...)
project_D=ProjectTo(D)
function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...;
kwargs...)
D = DiagonalTensorMap(d, args...; kwargs...)
project_D = ProjectTo(D)
function DiagonalTensorMap_pullback(Δt)
∂d = project_D(unthunk(Δt)).data
return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
Expand Down
17 changes: 17 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ end
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
return randn!(similar(x))
end
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap)
return DiagonalTensorMap(randn(eltype(x), dim(x.domain)), x.domain)
end
ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent()
function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap,
expected::AbstractTensorMap, msg=""; kwargs...)
Expand Down Expand Up @@ -144,6 +147,20 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
fkwargs=(; tol=Inf))
end

@timedtestset "Basic utility (DiagonalTensor)" begin
for NumType in [Float64, ComplexF64]
for v in V
T1 = DiagonalTensorMap(randn(NumType, dim(v)), v)
T2 = TensorMap(T1)

P1 = ProjectTo(T1)
@test P1(T2) == T1

test_rrule(DiagonalTensorMap, T1.data, T1.domain)
end
end
end

@timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)
A = randn(T, V[1] V[2] V[3] V[4] V[5])
B = randn(T, space(A))
Expand Down

0 comments on commit 2cf4cb4

Please sign in to comment.