From 51507ce14eae2fff8cfaa94089e58eb009e28b26 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Tue, 5 Mar 2024 11:16:41 +0100 Subject: [PATCH] Use KrylovKit.linsolve for truncation linear problem, make loss function differentiable --- Project.toml | 1 + examples/test_svd_adjoint.jl | 34 +++++++++++++++++++--------------- src/utility/svd.jl | 28 ++++++++++++---------------- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index d5e3a930..2826518e 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.1.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MPSKit = "bb1c41ca-d63c-52ed-829e-0820dda26502" diff --git a/examples/test_svd_adjoint.jl b/examples/test_svd_adjoint.jl index e11f0eff..fcd50f47 100644 --- a/examples/test_svd_adjoint.jl +++ b/examples/test_svd_adjoint.jl @@ -1,6 +1,6 @@ using LinearAlgebra using TensorKit -using ChainRulesCore, Zygote +using ChainRulesCore, ChainRulesTestUtils, Zygote using PEPSKit # Non-proper truncated SVD with outdated adjoint @@ -45,9 +45,6 @@ function oldsvd_rev( atol::Real=0, rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4), ) - S = diagm(S) - V = copy(V') - tol = atol > 0 ? 
atol : rtol * S[1, 1] F = PEPSKit.invert_S²(S, tol; εbroad) # Includes Lorentzian broadening S⁻¹ = pinv(S; atol=tol) @@ -74,26 +71,33 @@ function oldsvd_rev( VdV = V' * V Uproj = one(UUd) - UUd Vproj = one(VdV) - VdV - ΔA += Uproj * ΔU * S⁻¹ * V + U * S⁻¹ * ΔV * Vproj # Old wrong stuff + ΔA += Uproj * ΔU * S⁻¹ * V + U * S⁻¹ * ΔV * Vproj # Wrong truncation contribution return ΔA end -# Loss function taking the nfirst first singular vectors into account -function nfirst_loss(A, svdfunc; nfirst=1) +# Gauge-invariant loss function +function lossfun(A, svdfunc) U, _, V = svdfunc(A) - U = convert(Array, U) - V = convert(Array, V) - return real(sum([U[i, i] * V[i, i] for i in 1:nfirst])) + # return real(sum((U * V).data)) # TODO: code up sum for AbstractTensorMap with rrule + return real(tr(U * V)) # trace only allows for m=n end -m, n = 30, 20 +m, n = 30, 30 dtype = ComplexF64 -χ = 15 +χ = 20 r = TensorMap(randn, dtype, ℂ^m ← ℂ^n) -ltensorkit, gtensorkit = withgradient(A -> nfirst_loss(A, x -> oldsvd(x, χ); nfirst=3), r) -litersvd, gitersvd = withgradient(A -> nfirst_loss(A, x -> itersvd(x, χ); nfirst=3), r) +println("Non-truncated SVD") +ltensorkit, gtensorkit = withgradient(A -> lossfun(A, x -> oldsvd(x, min(m, n))), r) +litersvd, gitersvd = withgradient(A -> lossfun(A, x -> itersvd(x, min(m, n))), r) +@show ltensorkit ≈ litersvd +@show norm(gtensorkit[1] - gitersvd[1]) +println("\nTruncated SVD to χ=$χ:") +ltensorkit, gtensorkit = withgradient(A -> lossfun(A, x -> oldsvd(x, χ)), r) +litersvd, gitersvd = withgradient(A -> lossfun(A, x -> itersvd(x, χ)), r) @show ltensorkit ≈ litersvd -@show gtensorkit ≈ gitersvd \ No newline at end of file +@show norm(gtensorkit[1] - gitersvd[1]) + +# TODO: Finite-difference check via test_rrule diff --git a/src/utility/svd.jl b/src/utility/svd.jl index d3b149d0..ee3cbd4b 100644 --- a/src/utility/svd.jl +++ b/src/utility/svd.jl @@ -116,28 +116,24 @@ function itersvd_rev( dimγ = k * m # Vectorized dimension of γ-matrix # Truncation 
contribution from dU₂ and dV₂ - # TODO: Use KrylovKit instead of IterativeSolvers - Sop = LinearMap(k * m + k * n) do v # Left-preconditioned linear problem - γ = reshape(@view(v[1:dimγ]), (k, m)) - γd = reshape(@view(v[(dimγ + 1):end]), (k, n)) - Γ1 = γ - S⁻¹ * γd * Vproj * Ad - Γ2 = γd - S⁻¹ * γ * Uproj * A - vcat(reshape(Γ1, :), reshape(Γ2, :)) + function svdlinprob(v) # Left-preconditioned linear problem + γ1 = reshape(@view(v[1:dimγ]), (k, m)) + γ2 = reshape(@view(v[(dimγ + 1):end]), (k, n)) + Γ1 = γ1 - S⁻¹ * γ2 * Vproj * Ad + Γ2 = γ2 - S⁻¹ * γ1 * Uproj * A + return vcat(reshape(Γ1, :), reshape(Γ2, :)) end if ΔU isa ZeroTangent && ΔV isa ZeroTangent - γ = gmres(Sop, zeros(eltype(A), k * m + k * n)) + γ, = linsolve(svdlinprob, zeros(eltype(A), k * m + k * n)) else # Explicit left-preconditioning # Set relative tolerance to machine precision to converge SVD gradient error properly - γ = gmres( - Sop, - vcat(reshape(S⁻¹ * ΔU' * Uproj, :), reshape(S⁻¹ * ΔV * Vproj, :)); - reltol=eps(real(eltype(A))), - ) + y = vcat(reshape(S⁻¹ * ΔU' * Uproj, :), reshape(S⁻¹ * ΔV * Vproj, :)) + γ, = linsolve(svdlinprob, y; rtol=eps(real(eltype(A)))) end - γA = reshape(@view(γ[1:dimγ]), k, m) - γAd = reshape(@view(γ[(dimγ + 1):end]), k, n) - ΔA += Uproj * γA' * V + U * γAd * Vproj + γA1 = reshape(@view(γ[1:dimγ]), k, m) + γA2 = reshape(@view(γ[(dimγ + 1):end]), k, n) + ΔA += Uproj * γA1' * V + U * γA2 * Vproj return ΔA end