Rename DenseSVDAdjoint, update svd_wrapper test
pbrehmer committed Jul 9, 2024
1 parent b3a0726 commit 89ae0a4
Showing 3 changed files with 17 additions and 17 deletions.
src/PEPSKit.jl (2 changes: 1 addition & 1 deletion)

@@ -59,7 +59,7 @@ module Defaults
 const fpgrad_tol = 1e-6
 end
 
-export SVDrrule, IterSVD, OldSVD, CompleteSVDAdjoint, SparseSVDAdjoint, NonTruncSVDAdjoint
+export SVDrrule, IterSVD, OldSVD, DenseSVDAdjoint, SparseSVDAdjoint, NonTruncSVDAdjoint
 export FixedSpaceTruncation, ProjectorAlg, CTMRG, CTMRGEnv
 export LocalOperator
 export expectation_value, costfun
src/utility/svd.jl (16 changes: 8 additions & 8 deletions)

@@ -9,13 +9,13 @@ using TensorKit:
 CRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreExt)
 
 """
-    struct SVDrrule(; svd_alg = TensorKit.SVD(), rrule_alg = CompleteSVDAdjoint())
+    struct SVDrrule(; svd_alg = TensorKit.SVD(), rrule_alg = DenseSVDAdjoint())
 
 Wrapper for a SVD algorithm `svd_alg` with a defined reverse rule `rrule_alg`.
 """
 @kwdef struct SVDrrule{S,R}
     svd_alg::S = TensorKit.SVD()
-    rrule_alg::R = CompleteSVDAdjoint()  # TODO: should contain Lorentzian broadening eventually
+    rrule_alg::R = DenseSVDAdjoint()  # TODO: should contain Lorentzian broadening eventually
 end  # Keep truncation algorithm separate to be able to specify CTMRG dependent information
 
 """
@@ -105,18 +105,18 @@ function TensorKit._compute_svddata!(
 end
 
 """
-    struct CompleteSVDAdjoint(; lorentz_broadening = 0.0)
+    struct DenseSVDAdjoint(; lorentz_broadening = 0.0)
 
 Wrapper around the complete `TensorKit.tsvd!` rrule which requires computing the full SVD.
 """
-@kwdef struct CompleteSVDAdjoint
+@kwdef struct DenseSVDAdjoint
     lorentz_broadening::Float64 = 0.0
 end
 
 function ChainRulesCore.rrule(
     ::typeof(PEPSKit.tsvd!),
     t::AbstractTensorMap,
-    alg::SVDrrule{A,CompleteSVDAdjoint};
+    alg::SVDrrule{A,DenseSVDAdjoint};
     trunc::TruncationScheme=notrunc(),
     p::Real=2,
 ) where {A}
Expand All @@ -130,15 +130,15 @@ end
 Wrapper around the `KrylovKit.svdsolve` rrule where only the truncated decomposition is required.
 """
-@kwdef struct SparseSVDAdjoint
-    alg::Union{GMRES,BiCGStab,Arnoldi} = GMRES()
+@kwdef struct SparseSVDAdjoint{A}
+    alg::A = GMRES()
     lorentz_broadening::Float64 = 0.0
 end
 
 function ChainRulesCore.rrule(
     ::typeof(PEPSKit.tsvd!),
     t::AbstractTensorMap,
-    alg::SVDrrule{A,SparseSVDAdjoint};
+    alg::SVDrrule{A,<:SparseSVDAdjoint};
     trunc::TruncationScheme=notrunc(),
     p::Real=2,
 ) where {A}
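A note on the `<:SparseSVDAdjoint` change above: once the struct is parametric, a concrete instance has type `SparseSVDAdjoint{A}`, which is a proper subtype of the `SparseSVDAdjoint` UnionAll, so the rrule signature must accept subtypes. A self-contained dispatch sketch with generic stand-in names (not PEPSKit code) showing why:

struct Wrap{S,R}
    fwd::S
    adj::R
end
struct Sparse{A}
    alg::A
end

f(::Wrap{S,Sparse}) where {S} = :exact     # matches only adj of the UnionAll type `Sparse` itself
g(::Wrap{S,<:Sparse}) where {S} = :subtype # matches adj::Sparse{A} for any A

w = Wrap("svd", Sparse("gmres"))  # w isa Wrap{String,Sparse{String}}
g(w)   # returns :subtype
# f(w)  # MethodError: Sparse{String} is not the UnionAll `Sparse`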
test/ctmrg/svd_wrapper.jl (16 changes: 8 additions & 8 deletions)

@@ -22,7 +22,7 @@ rtol = 1e-9
 r = TensorMap(randn, dtype, ℂ^m, ℂ^n)
 R = TensorMap(randn, space(r))
 
-full_alg = SVDrrule(; svd_alg=TensorKit.SVD(), rrule_alg=CompleteSVDAdjoint())
+full_alg = SVDrrule(; svd_alg=TensorKit.SVD(), rrule_alg=DenseSVDAdjoint())
 old_alg = SVDrrule(; svd_alg=TensorKit.SVD(), rrule_alg=NonTruncSVDAdjoint())
 iter_alg = SVDrrule(; # Don't make adjoint tolerance too small, g_itersvd will be weird
     svd_alg=IterSVD(; alg=GKL(; krylovdim=50)),
@@ -35,8 +35,8 @@ iter_alg = SVDrrule(; # Don't make adjoint tolerance too small, g_itersvd will
 l_itersvd, g_itersvd = withgradient(A -> lossfun(A, iter_alg, R), r)
 
 @test l_oldsvd ≈ l_itersvd ≈ l_fullsvd
-@test norm(g_fullsvd[1] - g_oldsvd[1]) / norm(g_fullsvd[1]) < rtol
-@test norm(g_fullsvd[1] - g_itersvd[1]) / norm(g_fullsvd[1]) < rtol
+@test g_fullsvd[1] ≈ g_oldsvd[1] rtol = rtol
+@test g_fullsvd[1] ≈ g_itersvd[1] rtol = rtol
 end

@testset "Truncated SVD with χ=" begin
@@ -45,8 +45,8 @@ end
 l_itersvd, g_itersvd = withgradient(A -> lossfun(A, iter_alg, R, trunc), r)
 
 @test l_oldsvd ≈ l_itersvd ≈ l_fullsvd
-@test norm(g_fullsvd[1] - g_oldsvd[1]) / norm(g_fullsvd[1]) > rtol
-@test norm(g_fullsvd[1] - g_itersvd[1]) / norm(g_fullsvd[1]) < rtol
+@test !isapprox(g_fullsvd[1], g_oldsvd[1]; rtol)
+@test g_fullsvd[1] ≈ g_itersvd[1] rtol = rtol
 end
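The rewritten assertions above trade the hand-rolled relative-error checks for `isapprox` (`≈`): `isapprox(x, y; rtol)` tests `norm(x - y) <= rtol * max(norm(x), norm(y))`, which matches the old `norm(x - y) / norm(x) < rtol` up to the choice of normalization, and `@test !isapprox(...; rtol)` (where `; rtol` is shorthand for `; rtol=rtol`) expresses the expected mismatch of the OldSVD gradient under truncation. A standalone illustration with plain vectors in place of `TensorMap`s:

using LinearAlgebra, Test

rtol = 1e-9
x = randn(100)
y = x .+ 1e-12 .* randn(100)

@test norm(x - y) / norm(x) < rtol  # old style: manual relative error
@test x ≈ y rtol = rtol             # new style: isapprox with a relative tolerance
@test !isapprox(x, y .+ 1; rtol)    # negated form, as in the OldSVD check above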

# TODO: Add when Lorentzian broadening is implemented
@@ -74,7 +74,7 @@ symm_R = TensorMap(randn, dtype, space(symm_r))
 l_fullsvd, g_fullsvd = withgradient(A -> lossfun(A, full_alg, symm_R), symm_r)
 l_itersvd, g_itersvd = withgradient(A -> lossfun(A, iter_alg, symm_R), symm_r)
 @test l_itersvd ≈ l_fullsvd
-@test norm(g_fullsvd[1] - g_itersvd[1]) / norm(g_fullsvd[1]) < rtol
+@test g_fullsvd[1] ≈ g_itersvd[1] rtol = rtol
 
 l_fullsvd_tr, g_fullsvd_tr = withgradient(
     A -> lossfun(A, full_alg, symm_R, symm_trspace), symm_r
@@ -83,12 +83,12 @@ symm_R = TensorMap(randn, dtype, space(symm_r))
     A -> lossfun(A, iter_alg, symm_R, symm_trspace), symm_r
 )
 @test l_itersvd_tr ≈ l_fullsvd_tr
-@test norm(g_fullsvd_tr[1] - g_itersvd_tr[1]) / norm(g_fullsvd_tr[1]) < rtol
+@test g_fullsvd_tr[1] ≈ g_itersvd_tr[1] rtol = rtol
 
 iter_alg_fallback = @set iter_alg.svd_alg.fallback_threshold = 0.4 # Do dense SVD in one block, sparse SVD in the other
 l_itersvd_fb, g_itersvd_fb = withgradient(
     A -> lossfun(A, iter_alg_fallback, symm_R, symm_trspace), symm_r
 )
 @test l_itersvd_fb ≈ l_fullsvd_tr
-@test norm(g_fullsvd_tr[1] - g_itersvd_fb[1]) / norm(g_fullsvd_tr[1]) < rtol
+@test g_fullsvd_tr[1] ≈ g_itersvd_fb[1] rtol = rtol
 end
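The `@set` in the fallback check above produces an updated copy of an immutable struct rather than mutating it; the macro comes from the Setfield.jl/Accessors.jl family (whichever this test file imports). A generic sketch of the pattern, with hypothetical stand-in types:

using Accessors  # Setfield.jl exports an equivalent @set

struct FakeSVD
    fallback_threshold::Float64
end
struct FakeRule
    svd_alg::FakeSVD
end

rule = FakeRule(FakeSVD(0.0))
rule_fb = @set rule.svd_alg.fallback_threshold = 0.4

rule.svd_alg.fallback_threshold     # 0.0 (original untouched)
rule_fb.svd_alg.fallback_threshold  # 0.4 (nested field replaced in the copy)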
