Skip to content

Commit

Permalink
Replace info NamedTuple by explicit return values
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrehmer committed Jan 20, 2025
1 parent 7459e29 commit 50c76f4
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 57 deletions.
1 change: 0 additions & 1 deletion src/algorithms/contractions/ctmrg_contractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ Right projector:
```
"""
function contract_projectors(U, S, V, Q, Q_next)
@show (space(U), space(V), space(Q), space(Q_next))
isqS = sdiag_pow(S, -0.5)
P_left = Q_next * V' * isqS # use * to respect fermionic case
P_right = isqS * U' * Q
Expand Down
9 changes: 5 additions & 4 deletions src/algorithms/ctmrg/ctmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ function MPSKit.leading_boundary(envinit, state, alg::CTMRGAlgorithm)

truncation_error = 0.0
condition_number = 1.0
U, S, V = _prealloc_svd(envinit.edges)
return LoggingExtras.withlevel(; alg.verbosity) do
ctmrg_loginit!(log, η, state, envinit)
for iter in 1:(alg.maxiter)
env, info = ctmrg_iteration(state, env, alg) # Grow and renormalize in all 4 directions
truncation_error = info.truncation_error
condition_number = info.condition_number
env, truncation_error, condition_number, U, S, V = ctmrg_iteration(
state, env, alg
) # Grow and renormalize in all 4 directions
η, CS, TS = calc_convergence(env, CS, TS)

if η alg.tol && iter alg.miniter
Expand All @@ -58,7 +59,7 @@ function MPSKit.leading_boundary(envinit, state, alg::CTMRGAlgorithm)
ctmrg_logiter!(log, iter, η, state, env)
end
end
return env, (; truncation_error, condition_number)
return env, truncation_error, condition_number, copy(U), copy(S), copy(V)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/algorithms/ctmrg/projectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function compute_projector(enlarged_corners, coordinate, alg::HalfInfiniteProjec
end
end
P_left, P_right = contract_projectors(U, S, V, enlarged_corners...)
return (P_left, P_right), (; err, U, S, V)
return (P_left, P_right), err, U, S, V
end
function compute_projector(enlarged_corners, coordinate, alg::FullInfiniteProjector)
halfinf_left = half_infinite_environment(enlarged_corners[1], enlarged_corners[2])
Expand All @@ -93,5 +93,5 @@ function compute_projector(enlarged_corners, coordinate, alg::FullInfiniteProjec
end
end
P_left, P_right = contract_projectors(U, S, V, halfinf_left, halfinf_right)
return (P_left, P_right), (; err, U, S, V)
return (P_left, P_right), err, U, S, V

Check warning on line 96 in src/algorithms/ctmrg/projectors.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/projectors.jl#L96

Added line #L96 was not covered by tests
end
35 changes: 24 additions & 11 deletions src/algorithms/ctmrg/sequential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,26 @@ end
function ctmrg_iteration(state, envs::CTMRGEnv, alg::SequentialCTMRG)
truncation_error = zero(real(scalartype(state)))
condition_number = zero(real(scalartype(state)))
for _ in 1:4 # rotate
U, S, V = _prealloc_svd(envs.edges)
for dir in 1:4 # rotate

Check warning on line 37 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L34-L37

Added lines #L34 - L37 were not covered by tests
for col in 1:size(state, 2) # left move column-wise
projectors, info = sequential_projectors(col, state, envs, alg.projector_alg)
projectors, err, cond, U′, S′, V′ = sequential_projectors(

Check warning on line 39 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L39

Added line #L39 was not covered by tests
col, state, envs, alg.projector_alg
)
envs = renormalize_sequentially(col, projectors, state, envs)
truncation_error = max(truncation_error, info.truncation_error)
condition_number = max(condition_number, info.condition_number)
truncation_error = max(truncation_error, err)
condition_number = max(condition_number, cond)
for row in 1:size(state, 1)
U[dir, row, col] = U′[row]
S[dir, row, col] = S′[row]
V[dir, row, col] = V′[row]
end

Check warning on line 49 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L43-L49

Added lines #L43 - L49 were not covered by tests
end
state = rotate_north(state, EAST)
envs = rotate_north(envs, EAST)
end

return envs, (; truncation_error, condition_number)
return envs, truncation_error, condition_number, U, S, V

Check warning on line 55 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L55

Added line #L55 was not covered by tests
end

"""
Expand All @@ -62,21 +70,26 @@ function sequential_projectors(
S = Zygote.Buffer(

Check warning on line 70 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
zeros(size(envs, 2)), tensormaptype(spacetype(T), 1, 1, real(scalartype(T)))
)
U, S, V = _prealloc_svd(@view(envs.edges[4, :, col]))

Check warning on line 73 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L73

Added line #L73 was not covered by tests
coordinates = eachcoordinate(envs)[:, col]
projectors = dtmap(coordinates) do (r, c)
trscheme = truncation_scheme(alg, envs.edges[WEST, _prev(r, size(envs, 2)), c])
proj, info = sequential_projectors(
proj, err, U′, S′, V′ = sequential_projectors(

Check warning on line 77 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L77

Added line #L77 was not covered by tests
(WEST, r, c), state, envs, @set(alg.trscheme = trscheme)
)
ϵ[r] = info.err / norm(info.S)
S[r] = info.S
U[r] = U′
S[r] = S′
V[r] = V′
ϵ[r] = err / norm(S′)

Check warning on line 83 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L80-L83

Added lines #L80 - L83 were not covered by tests
return proj
end

P_left = map(first, projectors)
P_right = map(last, projectors)
S = copy(S)
truncation_error = maximum(copy(ϵ))
condition_number = maximum(_condition_number, copy(S))
info = (; truncation_error, condition_number)
return (map(first, projectors), map(last, projectors)), info
condition_number = maximum(_condition_number, S)
return (P_left, P_right), truncation_error, condition_number, copy(U), S, copy(V)

Check warning on line 92 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L87-L92

Added lines #L87 - L92 were not covered by tests
end
function sequential_projectors(
coordinate::NTuple{3,Int},
Expand Down
23 changes: 12 additions & 11 deletions src/algorithms/ctmrg/simultaneous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ function ctmrg_iteration(state, envs::CTMRGEnv, alg::SimultaneousCTMRG)
enlarged_corners = dtmap(eachcoordinate(state, 1:4)) do idx
return TensorMap(EnlargedCorner(state, envs, idx), idx[1])
end # expand environment
projectors, info = simultaneous_projectors(enlarged_corners, envs, alg.projector_alg) # compute projectors on all coordinates
projectors, truncation_error, condition_number, U, S, V = simultaneous_projectors(
enlarged_corners, envs, alg.projector_alg
) # compute projectors on all coordinates
envs′ = renormalize_simultaneously(enlarged_corners, projectors, state, envs) # renormalize enlarged corners
return envs′, info
return envs′, truncation_error, condition_number, U, S, V
end

# Pre-allocate U, S, and V tensor as Zygote buffers to make it differentiable
function _prealloc_svd(edges::Array{E,N}) where {E,N}
function _prealloc_svd(edges::AbstractArray{E,N}) where {E,N}
Sc = scalartype(E)
U = Zygote.Buffer(map(e -> TensorMap(zeros, Sc, space(e)), edges))
V = Zygote.Buffer(map(e -> TensorMap(zeros, Sc, domain(e), codomain(e)), edges))
Expand Down Expand Up @@ -74,23 +76,22 @@ function simultaneous_projectors(
projectors = dtmap(eachcoordinate(envs, 1:4)) do coordinate
coordinate′ = _next_coordinate(coordinate, size(envs)[2:3]...)
trscheme = truncation_scheme(alg, envs.edges[coordinate[1], coordinate′[2:3]...])
proj, info = simultaneous_projectors(
proj, err, U′, S′, V′ = simultaneous_projectors(
coordinate, enlarged_corners, @set(alg.trscheme = trscheme)
)
U[coordinate...] = info.U
S[coordinate...] = info.S
V[coordinate...] = info.V
ϵ[coordinate...] = info.err / norm(info.S)
U[coordinate...] = U′
S[coordinate...] = S′
V[coordinate...] = V′
ϵ[coordinate...] = err / norm(S′)
return proj
end

P_left = map(first, projectors)
P_right = map(last, projectors)
S = copy(S)
truncation_error = maximum(copy(ϵ)) # TODO: This makes Zygote error on first execution?
truncation_error = maximum(copy(ϵ))
condition_number = maximum(_condition_number, S)
info = (; truncation_error, condition_number, U=copy(U), S, V=copy(V))
return (P_left, P_right), info
return (P_left, P_right), truncation_error, condition_number, copy(U), S, copy(V)
end
function simultaneous_projectors(
coordinate, enlarged_corners::Array{E,3}, alg::HalfInfiniteProjector
Expand Down
18 changes: 10 additions & 8 deletions src/algorithms/optimization/fixed_point_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ function _rrule(
state,
alg::CTMRGAlgorithm,
)
envs, info = leading_boundary(envinit, state, alg)
envs, truncation_error, condition_number = leading_boundary(envinit, state, alg)

Check warning on line 94 in src/algorithms/optimization/fixed_point_differentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/fixed_point_differentiation.jl#L94

Added line #L94 was not covered by tests

function leading_boundary_diffgauge_pullback((Δenvs′, Δinfo))
function leading_boundary_diffgauge_pullback((Δenvs′, Δtrunc_error, Δcond_number))

Check warning on line 96 in src/algorithms/optimization/fixed_point_differentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/fixed_point_differentiation.jl#L96

Added line #L96 was not covered by tests
Δenvs = unthunk(Δenvs′)

# find partial gradients of gauge_fixed single CTMRG iteration
Expand All @@ -108,7 +108,7 @@ function _rrule(
return NoTangent(), ZeroTangent(), ∂F∂envs, NoTangent()
end

return (envs, info), leading_boundary_diffgauge_pullback
return (envs, truncation_error, condition_number), leading_boundary_diffgauge_pullback

Check warning on line 111 in src/algorithms/optimization/fixed_point_differentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/fixed_point_differentiation.jl#L111

Added line #L111 was not covered by tests
end

# Here f is differentiated from an pre-computed SVD with fixed U, S and V
Expand All @@ -122,18 +122,20 @@ function _rrule(
)
@assert !isnothing(alg.projector_alg.svd_alg.rrule_alg)
envs, = leading_boundary(envinit, state, alg)
envs_conv, info = ctmrg_iteration(state, envs, alg)
envs_conv, truncation_error, condition_number, U, S, V = ctmrg_iteration(

Check warning on line 125 in src/algorithms/optimization/fixed_point_differentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/fixed_point_differentiation.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
state, envs, alg
)
envs_fixed, signs = gauge_fix(envs, envs_conv)

Check warning on line 128 in src/algorithms/optimization/fixed_point_differentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/fixed_point_differentiation.jl#L128

Added line #L128 was not covered by tests

# Fix SVD
Ufix, Vfix = fix_relative_phases(info.U, info.V, signs)
Ufix, Vfix = fix_relative_phases(U, V, signs)

Check warning on line 131 in src/algorithms/optimization/fixed_point_differentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/fixed_point_differentiation.jl#L131

Added line #L131 was not covered by tests
svd_alg_fixed = SVDAdjoint(;
fwd_alg=FixedSVD(Ufix, info.S, Vfix), rrule_alg=alg.projector_alg.svd_alg.rrule_alg
fwd_alg=FixedSVD(Ufix, S, Vfix), rrule_alg=alg.projector_alg.svd_alg.rrule_alg
)
alg_fixed = @set alg.projector_alg.svd_alg = svd_alg_fixed
alg_fixed = @set alg_fixed.projector_alg.trscheme = notrunc()

function leading_boundary_fixed_pullback((Δenvs′, Δinfo))
function leading_boundary_fixed_pullback((Δenvs′, Δtrunc_error, Δcond_number))

Check warning on line 138 in src/algorithms/optimization/fixed_point_differentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/fixed_point_differentiation.jl#L138

Added line #L138 was not covered by tests
Δenvs = unthunk(Δenvs′)

f(A, x) = fix_global_phases(x, ctmrg_iteration(A, x, alg_fixed)[1])
Expand All @@ -147,7 +149,7 @@ function _rrule(
return NoTangent(), ZeroTangent(), ∂F∂envs, NoTangent()
end

return (envs_fixed, info), leading_boundary_fixed_pullback
return (envs_fixed, truncation_error, condition_number), leading_boundary_fixed_pullback

Check warning on line 152 in src/algorithms/optimization/fixed_point_differentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/fixed_point_differentiation.jl#L152

Added line #L152 was not covered by tests
end

@doc """
Expand Down
4 changes: 2 additions & 2 deletions src/algorithms/optimization/manopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function (r::RecordTruncationError)(
p::AbstractManoptProblem, ::AbstractManoptSolverState, i::Int
)
cache = Manopt.get_cost_function(get_objective(p))
return Manopt.record_or_reset!(r, cache.env_info.truncation_error, i)
return Manopt.record_or_reset!(r, cache.truncation_error, i)

Check warning on line 25 in src/algorithms/optimization/manopt.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/manopt.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
end

"""
Expand All @@ -39,7 +39,7 @@ function (r::RecordConditionNumber)(
p::AbstractManoptProblem, ::AbstractManoptSolverState, i::Int
)
cache = Manopt.get_cost_function(get_objective(p))
return Manopt.record_or_reset!(r, cache.env_info.condition_number, i)
return Manopt.record_or_reset!(r, cache.condition_number, i)

Check warning on line 42 in src/algorithms/optimization/manopt.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/manopt.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
end

"""
Expand Down
17 changes: 6 additions & 11 deletions src/algorithms/optimization/peps_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ mutable struct PEPSCostFunctionCache{T}
peps_vec::Vector{T}
grad_vec::Vector{T}
cost::Float64
env_info::NamedTuple
truncation_error::Float64
condition_number::Float64
end

"""
Expand All @@ -115,14 +116,7 @@ function PEPSCostFunctionCache(
operator::LocalOperator, alg::PEPSOptimize, peps_vec::Vector, from_vec, env::CTMRGEnv
)
return PEPSCostFunctionCache(

Check warning on line 118 in src/algorithms/optimization/peps_optimization.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/peps_optimization.jl#L118

Added line #L118 was not covered by tests
operator,
alg,
env,
from_vec,
similar(peps_vec),
similar(peps_vec),
0.0,
(; truncation_error=0.0, condition_number=1.0),
operator, alg, env, from_vec, similar(peps_vec), similar(peps_vec), 0.0, 0.0, 1.0
)
end

Expand All @@ -140,7 +134,7 @@ function cost_and_grad!(cache::PEPSCostFunctionCache{T}, peps_vec::Vector{T}) wh

# compute cost and gradient
cost, grads = withgradient(peps) do ψ
env, info = hook_pullback(
env, truncation_error, condition_number = hook_pullback(

Check warning on line 137 in src/algorithms/optimization/peps_optimization.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/peps_optimization.jl#L136-L137

Added lines #L136 - L137 were not covered by tests
leading_boundary,
env₀,
ψ,
Expand All @@ -150,7 +144,8 @@ function cost_and_grad!(cache::PEPSCostFunctionCache{T}, peps_vec::Vector{T}) wh
cost = expectation_value(ψ, cache.operator, env)
ignore_derivatives() do
update!(cache.env, env) # update environment in-place
cache.env_info = info # update environment information (truncation error, ...)
cache.truncation_error = truncation_error # update environment information
cache.condition_number = condition_number
isapprox(imag(cost), 0; atol=sqrt(eps(real(cost)))) ||
@warn "Expectation value is not real: $cost."

Check warning on line 150 in src/algorithms/optimization/peps_optimization.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/optimization/peps_optimization.jl#L144-L150

Added lines #L144 - L150 were not covered by tests
end
Expand Down
8 changes: 5 additions & 3 deletions test/ctmrg/fixed_iterscheme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ atol = 1e-5
env_conv1, = leading_boundary(CTMRGEnv(psi, ComplexSpace(χenv)), psi, ctm_alg)

# do extra iteration to get SVD
env_conv2, info = ctmrg_iteration(psi, env_conv1, ctm_alg)
env_conv2, truncation_error, condition_number, U, S, V = ctmrg_iteration(
psi, env_conv1, ctm_alg
)
env_fix, signs = gauge_fix(env_conv1, env_conv2)
@test calc_elementwise_convergence(env_conv1, env_fix) 0 atol = atol

# fix gauge of SVD
U_fix, V_fix = fix_relative_phases(info.U, info.V, signs)
svd_alg_fix = SVDAdjoint(; fwd_alg=FixedSVD(U_fix, info.S, V_fix))
U_fix, V_fix = fix_relative_phases(U, V, signs)
svd_alg_fix = SVDAdjoint(; fwd_alg=FixedSVD(U_fix, S, V_fix))
ctm_alg_fix = SimultaneousCTMRG(;
projector_alg, svd_alg=svd_alg_fix, trscheme=notrunc()
)
Expand Down
6 changes: 2 additions & 4 deletions test/ctmrg/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ gradmodes = [
calgs = ctmrg_algs[i]
@testset "$ctmrg_alg and $gradient_alg" for (ctmrg_alg, gradient_alg) in
Iterators.product(calgs, gms)
@info "gradient check of $ctmrg_alg and $alg_rrule on $(names[i])"
@info "gradient check of $ctmrg_alg and $gradient_alg on $(names[i])"
Random.seed!(42039482030)
psi = InfinitePEPS(Pspace, Vspace)
env, = leading_boundary(CTMRGEnv(psi, Espace), psi, ctmrg_alg)
Expand All @@ -67,8 +67,6 @@ gradmodes = [
grad = gradient_function(cache)

M = Euclidean(length(psi_vec))
@test check_gradient(
M, cost, grad; N=10, exactness_tol=gradtol, limits=(-8, -3), io=stdout
)
@test check_gradient(M, cost, grad; N=10, exactness_tol=gradtol, limits=(-8, -3))
end
end

0 comments on commit 50c76f4

Please sign in to comment.