Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use reallinsolve for solving the CTMRG gradient linear problem #94

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/algorithms/peps_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,14 @@
)
@assert !isnothing(alg.projector_alg.svd_alg.rrule_alg)
envs = leading_boundary(envinit, state, alg)
envsconv, info = ctmrg_iteration(state, envs, alg)
envs_fixed, signs = gauge_fix(envs, envsconv)
envs_conv, info = ctmrg_iteration(state, envs, alg)
envs_fixed, signs = gauge_fix(envs, envs_conv)

Check warning on line 257 in src/algorithms/peps_opt.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/peps_opt.jl#L256-L257

Added lines #L256 - L257 were not covered by tests

# Fix SVD
Ufix, Vfix = fix_relative_phases(info.U, info.V, signs)
U_fixed, V_fixed = fix_relative_phases(info.U, info.V, signs)

Check warning on line 260 in src/algorithms/peps_opt.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/peps_opt.jl#L260

Added line #L260 was not covered by tests
svd_alg_fixed = SVDAdjoint(;
fwd_alg=FixedSVD(Ufix, info.S, Vfix), rrule_alg=alg.projector_alg.svd_alg.rrule_alg
fwd_alg=FixedSVD(U_fixed, info.S, V_fixed),
rrule_alg=alg.projector_alg.svd_alg.rrule_alg,
)
alg_fixed = @set alg.projector_alg.svd_alg = svd_alg_fixed
alg_fixed = @set alg_fixed.projector_alg.trscheme = notrunc()
Expand Down Expand Up @@ -346,7 +347,7 @@
end

function fpgrad(∂F∂x, ∂f∂x, ∂f∂A, y₀, alg::LinSolver)
y, info = linsolve(∂f∂x, ∂F∂x, y₀, alg.solver, 1, -1)
y, info = reallinsolve(∂f∂x, ∂F∂x, y₀, alg.solver, 1, -1)
if alg.solver.verbosity > 0 && info.converged != 1
@warn("gradient fixed-point iteration reached maximal number of iterations:", info)
end
Expand Down
64 changes: 64 additions & 0 deletions test/ctmrg/jacobian_real_linear.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using Test
using Random
using Accessors
using Zygote
using TensorKit, KrylovKit, PEPSKit
using PEPSKit: ctmrg_iteration, gauge_fix, fix_relative_phases, fix_global_phases

algs = [
(:fixed, SimultaneousCTMRG(; projector_alg=HalfInfiniteProjector)),
(:diffgauge, SequentialCTMRG(; projector_alg=HalfInfiniteProjector)),
(:diffgauge, SimultaneousCTMRG(; projector_alg=HalfInfiniteProjector)),
# TODO: FullInfiniteProjector errors since even real_err_∂A, real_err_∂x are finite?
# (:fixed, SimultaneousCTMRG(; projector_alg=FullInfiniteProjector)),
# (:diffgauge, SequentialCTMRG(; projector_alg=FullInfiniteProjector)),
# (:diffgauge, SimultaneousCTMRG(; projector_alg=FullInfiniteProjector)),
]
Dbond, χenv = 2, 16

@testset "$iterscheme and $ctm_alg" for (iterscheme, ctm_alg) in algs
Random.seed!(123521938519)
state = InfinitePEPS(2, Dbond)
envs = leading_boundary(CTMRGEnv(state, ComplexSpace(χenv)), state, ctm_alg)

# follow code of _rrule
if iterscheme == :fixed
envs_conv, info = ctmrg_iteration(state, envs, ctm_alg)
envs_fixed, signs = gauge_fix(envs, envs_conv)
U_fixed, V_fixed = fix_relative_phases(info.U, info.V, signs)
svd_alg_fixed = SVDAdjoint(;
fwd_alg=PEPSKit.FixedSVD(U_fixed, info.S, V_fixed),
rrule_alg=ctm_alg.projector_alg.svd_alg.rrule_alg,
)
alg_fixed = @set ctm_alg.projector_alg.svd_alg = svd_alg_fixed
alg_fixed = @set alg_fixed.projector_alg.trscheme = notrunc()

_, env_vjp = pullback(state, envs_fixed) do A, x
e, = PEPSKit.ctmrg_iteration(A, x, alg_fixed)
return PEPSKit.fix_global_phases(x, e)
end
elseif iterscheme == :diffgauge
_, env_vjp = pullback(state, envs) do A, x
return gauge_fix(x, ctmrg_iteration(A, x, ctm_alg)[1])[1]
end
end

# get Jacobians of single iteration
∂f∂A(x)::typeof(state) = env_vjp(x)[1]
∂f∂x(x)::typeof(envs) = env_vjp(x)[2]

# compute real and complex errors
env_in = CTMRGEnv(state, ComplexSpace(16))
α_real = randn(Float64)
α_complex = randn(ComplexF64)

real_err_∂A = norm(scale(∂f∂A(env_in), α_real) - ∂f∂A(scale(env_in, α_real)))
real_err_∂x = norm(scale(∂f∂x(env_in), α_real) - ∂f∂x(scale(env_in, α_real)))
complex_err_∂A = norm(scale(∂f∂A(env_in), α_complex) - ∂f∂A(scale(env_in, α_complex)))
complex_err_∂x = norm(scale(∂f∂x(env_in), α_complex) - ∂f∂x(scale(env_in, α_complex)))

@test real_err_∂A < 1e-9
@test real_err_∂x < 1e-9
@test complex_err_∂A > 1e-3
@test complex_err_∂x > 1e-3
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ end
@time @safetestset "Flavors" begin
include("ctmrg/flavors.jl")
end
@time @safetestset "CTMRG schemes" begin
include("ctmrg/jacobian_real_linear.jl")
end
end
if GROUP == "ALL" || GROUP == "BOUNDARYMPS"
@time @safetestset "VUMPS" begin
Expand Down
Loading