From 21e8f6f01e12da0139b661e437b495d8368aaf09 Mon Sep 17 00:00:00 2001
From: Jutho
Date: Fri, 17 Nov 2023 11:43:51 +0100
Subject: [PATCH] further svd ad implementation and test fix

---
Review notes (free text after `---`; not part of the commit message):
- ext/TensorKitChainRulesCoreExt.jl: the added line previously read
  `ΔΣdc = (ΔΣdc isa AbstractZero) ? ΔΣdc : view(ΔΣdc, diagind(ΔΣdc))`, i.e. it
  used `ΔΣdc` inside its own first assignment. It now tests and views `ΔΣc`,
  the block extracted two lines earlier (and used by the line being replaced).
  The `index` line of that file was left untouched, so its post-image hash is
  stale.
- test/ad.jl: `remove_svdgauge_depence!` factors out the gauge-fixing code that
  was repeated four times in the tsvd tests. Per block it performs
  ΔU += -im * U * Diagonal(imag(diag(U' * ΔU + V * ΔV'))); ΔV is returned
  unmodified and the argument S is unused.
- Line breaks and indentation of this patch were reconstructed.

 ext/TensorKitChainRulesCoreExt.jl |  2 +-
 src/tensors/truncation.jl         |  3 +-
 test/ad.jl                        | 71 +++++++------------------------
 3 files changed, 17 insertions(+), 59 deletions(-)

diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl
index 12d10509..47e8fb25 100644
--- a/ext/TensorKitChainRulesCoreExt.jl
+++ b/ext/TensorKitChainRulesCoreExt.jl
@@ -166,7 +166,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
             Uc, Σc, Vc = block(U, c), block(Σ, c), block(V, c)
             ΔUc, ΔΣc, ΔVc = block(ΔU, c), block(ΔΣ, c), block(ΔV, c)
             Σdc = view(Σc, diagind(Σc))
-            ΔΣdc = view(ΔΣc, diagind(ΔΣc))
+            ΔΣdc = (ΔΣc isa AbstractZero) ? ΔΣc : view(ΔΣc, diagind(ΔΣc))
             copyto!(b, svd_pullback(Uc, Σdc, Vc, ΔUc, ΔΣdc, ΔVc))
         end
         return NoTangent(), Δt
diff --git a/src/tensors/truncation.jl b/src/tensors/truncation.jl
index f35efea2..61fe777d 100644
--- a/src/tensors/truncation.jl
+++ b/src/tensors/truncation.jl
@@ -164,8 +164,7 @@ function _truncate!(V::SectorVectorDict, trunc::TruncationSpace, p=2)
     end
     return V, truncerr
 end
-########################
-########################
+
 function _truncate!(V::SectorVectorDict, trunc::TruncationCutoff, p=2)
     I = keytype(V)
     S = real(eltype(valtype(V)))
diff --git a/test/ad.jl b/test/ad.jl
index 12e85970..a0060b3b 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -78,36 +78,14 @@ end
 
 # complex-valued svd?
 # -------------------
-# function _gaugefix!(U, V)
-#     s = LinearAlgebra.Diagonal(TensorKit._safesign.(diag(U)))
-#     rmul!(U, s)
-#     lmul!(s', V)
-#     return U, V
-# end
-
-# function _tsvd(t::AbstractTensorMap)
-#     U, S, V, ϵ = tsvd(t)
-#     for (c, b) in blocks(U)
-#         _gaugefix!(b, block(V, c))
-#     end
-#     return U, S, V, ϵ
-# end
-
-# svd_rev = Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt).svd_rev
-
-# function ChainRulesCore.rrule(::typeof(_tsvd), t::AbstractTensorMap)
-#     U, S, V, ϵ = _tsvd(t)
-#     function _tsvd_pullback((ΔU, ΔS, ΔV, Δϵ))
-#         ∂t = similar(t)
-#         for (c, b) in blocks(∂t)
-#             copyto!(b,
-#                     svd_rev(block(U, c), block(S, c), block(V, c),
-#                             block(ΔU, c), block(ΔS, c), block(ΔV, c)))
-#         end
-#         return NoTangent(), ∂t
-#     end
-#     return (U, S, V, ϵ), _tsvd_pullback
-# end
+function remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
+    # assumes no degenerate or zero singular values; only ΔU is modified (S is unused)
+    gaugepart = U' * ΔU + V * ΔV'
+    for (c, b) in blocks(gaugepart)
+        mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
+    end
+    return ΔU, ΔV
+end
 
 # Tests
 # -----
@@ -275,12 +253,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
                 ΔU = TensorMap(randn, scalartype(U), space(U))
                 ΔS = TensorMap(randn, scalartype(S), space(S))
                 ΔV = TensorMap(randn, scalartype(V), space(V))
-                if T <: Complex # remove gauge dependent components
-                    gaugepart = U' * ΔU + V * ΔV'
-                    for (c, b) in blocks(gaugepart)
-                        mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
-                    end
-                end
+                T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
                 test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
 
                 Vtrunc = spacetype(S)(TensorKit.SectorDict(c => ceil(Int, size(b, 1) / 2)
@@ -290,12 +263,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
                 ΔU = TensorMap(randn, scalartype(U), space(U))
                 ΔS = TensorMap(randn, scalartype(S), space(S))
                 ΔV = TensorMap(randn, scalartype(V), space(V))
-                if T <: Complex # remove gauge dependent components
-                    gaugepart = U' * ΔU + V * ΔV'
-                    for (c, b) in blocks(gaugepart)
-                        mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
-                    end
-                end
+                T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
                 test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
                            fkwargs=(; trunc=truncspace(Vtrunc)))
             end
@@ -304,26 +272,17 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
             ΔU = TensorMap(randn, scalartype(U), space(U))
             ΔS = TensorMap(randn, scalartype(S), space(S))
             ΔV = TensorMap(randn, scalartype(V), space(V))
-            if T <: Complex # remove gauge dependent components
-                gaugepart = U' * ΔU + V * ΔV'
-                for (c, b) in blocks(gaugepart)
-                    mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
-                end
-            end
+            T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
             test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
 
-            U, S, V, ϵ = tsvd(C; trunc=truncdim(2))
+            c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S))
+            U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c)))
             ΔU = TensorMap(randn, scalartype(U), space(U))
             ΔS = TensorMap(randn, scalartype(S), space(S))
             ΔV = TensorMap(randn, scalartype(V), space(V))
-            if T <: Complex # remove gauge dependent components
-                gaugepart = U' * ΔU + V * ΔV'
-                for (c, b) in blocks(gaugepart)
-                    mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
-                end
-            end
+            T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
             test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
-                       fkwargs=(; trunc=truncdim(2)))
+                       fkwargs=(; trunc=truncdim(2 * dim(c))))
         end
     end
 end