From 2cf4cb4e9fd65564a514fc2bb5c5ccdd7cd92ee5 Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 12:28:58 +0100 Subject: [PATCH] Corrects bug in the DiagonalTensorMap rrule, adds tests for the new code, adds a proper generator of random tangents for DiagonalTensorMap --- ext/TensorKitChainRulesCoreExt/constructors.jl | 7 ++++--- test/ad.jl | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl index af1f1ed9..caa58856 100644 --- a/ext/TensorKitChainRulesCoreExt/constructors.jl +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -12,9 +12,10 @@ function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwarg return TensorMap(d, args...; kwargs...), TensorMap_pullback end -function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...; kwargs...) - D=TensorMap(d, args...; kwargs...) - project_D=ProjectTo(D) +function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...; + kwargs...) + D = DiagonalTensorMap(d, args...; kwargs...) + project_D = ProjectTo(D) function DiagonalTensorMap_pullback(Δt) ∂d = project_D(unthunk(Δt)).data return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))... diff --git a/test/ad.jl b/test/ad.jl index a684c4f8..5d77b5be 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -15,6 +15,9 @@ end function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap) return randn!(similar(x)) end +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap) + return DiagonalTensorMap(randn(eltype(x), dim(x.domain)), x.domain) +end ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent() function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap, expected::AbstractTensorMap, msg=""; kwargs...) @@ -144,6 +147,20 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), fkwargs=(; tol=Inf)) end + @timedtestset "Basic utility (DiagonalTensor)" begin + for NumType in [Float64, ComplexF64] + for v in V + T1 = DiagonalTensorMap(randn(NumType, dim(v)), v) + T2 = TensorMap(T1) + + P1 = ProjectTo(T1) + @test P1(T2) == T1 + + test_rrule(DiagonalTensorMap, T1.data, T1.domain) + end + end + end + @timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64) A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = randn(T, space(A))