diff --git a/examples/test_svd_adjoint.jl b/examples/test_svd_adjoint.jl index fcd50f47..69124870 100644 --- a/examples/test_svd_adjoint.jl +++ b/examples/test_svd_adjoint.jl @@ -3,7 +3,7 @@ using TensorKit using ChainRulesCore, ChainRulesTestUtils, Zygote using PEPSKit -# Non-proper truncated SVD with outdated adjoint +# Truncated SVD with outdated adjoint oldsvd(t::AbstractTensorMap, χ::Int; kwargs...) = itersvd(t, χ; kwargs...) # Outdated adjoint not taking truncated part into account @@ -77,27 +77,33 @@ function oldsvd_rev( end # Gauge-invariant loss function -function lossfun(A, svdfunc) +function lossfun(A, R=TensorMap(randn, space(A)), svdfunc=tsvd) U, _, V = svdfunc(A) - # return real(sum((U * V).data)) # TODO: code up sum for AbstractTensorMap with rrule - return real(tr(U * V)) # trace only allows for m=n + return real(dot(R, U * V)) # Overlap with random tensor R is gauge-invariant and differentiable, also for m≠n end -m, n = 30, 30 +m, n = 20, 30 dtype = ComplexF64 -χ = 20 +χ = 15 r = TensorMap(randn, dtype, ℂ^m ← ℂ^n) +R = TensorMap(randn, space(r)) -println("Non-truncated SVD") -ltensorkit, gtensorkit = withgradient(A -> lossfun(A, x -> oldsvd(x, min(m, n))), r) -litersvd, gitersvd = withgradient(A -> lossfun(A, x -> itersvd(x, min(m, n))), r) -@show ltensorkit ≈ litersvd +println("Non-truncated SVD:") +loldsvd, goldsvd = withgradient(A -> lossfun(A, R, x -> oldsvd(x, min(m, n))), r) +ltensorkit, gtensorkit = withgradient( + A -> lossfun(A, R, x -> tsvd(x; trunc=truncdim(min(m, n)))), r +) +litersvd, gitersvd = withgradient(A -> lossfun(A, R, x -> itersvd(x, min(m, n))), r) +@show loldsvd ≈ ltensorkit ≈ litersvd +@show norm(gtensorkit[1] - goldsvd[1]) @show norm(gtensorkit[1] - gitersvd[1]) -println("\nTruncated SVD to χ=$χ:") -ltensorkit, gtensorkit = withgradient(A -> lossfun(A, x -> oldsvd(x, χ)), r) -litersvd, gitersvd = withgradient(A -> lossfun(A, x -> itersvd(x, χ)), r) -@show ltensorkit ≈ litersvd +println("\nTruncated SVD with χ=$χ:") +loldsvd, goldsvd = withgradient(A -> lossfun(A, R, x -> oldsvd(x, χ)), r) +ltensorkit, gtensorkit = withgradient( + A -> lossfun(A, R, x -> tsvd(x; trunc=truncdim(χ))), r +) +litersvd, gitersvd = withgradient(A -> lossfun(A, R, x -> itersvd(x, χ)), r) +@show loldsvd ≈ ltensorkit ≈ litersvd +@show norm(gtensorkit[1] - goldsvd[1]) @show norm(gtensorkit[1] - gitersvd[1]) - -# TODO: Finite-difference check via test_rrule diff --git a/src/utility/svd.jl b/src/utility/svd.jl index ee3cbd4b..975c72cb 100644 --- a/src/utility/svd.jl +++ b/src/utility/svd.jl @@ -117,6 +117,7 @@ function itersvd_rev( # Truncation contribution from dU₂ and dV₂ function svdlinprob(v) # Left-preconditioned linear problem + # TODO: make v a Tuple instead of concatening two vectors γ1 = reshape(@view(v[1:dimγ]), (k, m)) γ2 = reshape(@view(v[(dimγ + 1):end]), (k, n)) Γ1 = γ1 - S⁻¹ * γ2 * Vproj * Ad