Skip to content

Commit

Permalink
refactor VOMPS
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Oct 30, 2023
1 parent 7225833 commit 228af5a
Showing 1 changed file with 60 additions and 43 deletions.
103 changes: 60 additions & 43 deletions src/algorithms/approximate/vomps.jl
Original file line number Diff line number Diff line change
@@ -1,64 +1,81 @@
function approximate(
state::InfiniteMPS,
ψ::InfiniteMPS,
toapprox::Tuple{<:Union{SparseMPO,DenseMPO},<:InfiniteMPS},
alg,
envs=environments(state, toapprox),
algorithm,
envs=environments(ψ, toapprox),
)
# PeriodicMPO's always act on MPSMultiline's. I therefore convert the imps to multilines, approximate and convert back
(multi, envs) = approximate(
convert(MPSMultiline, state),
(convert(MPOMultiline, envs.opp), convert(MPSMultiline, envs.above)),
alg,
# PeriodicMPO's always act on MPSMultiline's. To avoid code duplication, define everything in terms of MPSMultiline's.
multi, envs = approximate(
convert(MPSMultiline, ψ),
(convert(MPOMultiline, toapprox[1]), convert(MPSMultiline, toapprox[2])),
algorithm,
envs,
)
state = convert(InfiniteMPS, multi)
return (state, envs)
ψ = convert(InfiniteMPS, multi)
return ψ, envs
end
function approximate(
state::MPSMultiline,
ψ::MPSMultiline,
toapprox::Tuple{<:MPOMultiline,<:MPSMultiline},
alg::VUMPS,
envs=environments(state, toapprox),
envs=environments(ψ, toapprox),
)
galerkin = calc_galerkin(state, envs)
iter = 1
t₀ = Base.time_ns()
ϵ::Float64 = calc_galerkin(ψ, envs)
temp_ACs = similar.(ψ.AC)

temp_ACs = similar.(state.AC)
temp_Cs = similar.(state.CR)
for iter in 1:(alg.maxiter)
_, tol_gauge, tol_envs = updatetols(alg, iter, ϵ)
Δt = @elapsed begin
@static if Defaults.parallelize_sites
@sync for col in 1:size(ψ, 2)
Threads.@spawn _vomps_localupdate!(
temp_ACs[:, col], col, ψ, toapprox, envs
)
end
else
for col in 1:size(ψ, 2)
_vomps_localupdate!(temp_ACs[:, col], col, ψ, toapprox, envs)
end
end

while true
_, tol_gauge, tol_envs = updatetols(alg, iter, galerkin)
@sync for col in 1:size(state, 2)
Threads.@spawn $temp_ACs[:, col] = circshift(
[ac_proj(row, $col, $state, $envs) for row in 1:size($state, 1)], 1
)
Threads.@spawn $temp_Cs[:, col] = circshift(
[c_proj(row, $col, $state, $envs) for row in 1:size($state, 1)], 1
)
end
ψ = MPSMultiline(temp_ACs, ψ.CR[:, end]; tol=tol_gauge, maxiter=alg.orthmaxiter)
recalculate!(envs, ψ; tol=tol_envs)

ψ, envs = alg.finalize(iter, ψ, toapprox, envs)::Tuple{typeof(ψ),typeof(envs)}

for row in 1:size(state, 1), col in 1:size(state, 2)
QAc, _ = leftorth!(temp_ACs[row, col]; alg=TensorKit.QRpos())
Qc, _ = leftorth!(temp_Cs[row, col]; alg=TensorKit.QRpos())
temp_ACs[row, col] = QAc * adjoint(Qc)
ϵ = calc_galerkin(ψ, envs)
end

state = MPSMultiline(
temp_ACs, state.CR[:, end]; tol=tol_gauge, maxiter=alg.orthmaxiter
)
recalculate!(envs, state; tol=tol_envs)
alg.verbose && @info "VOMPS iteration:" iter ϵ Δt

(state, envs) =
alg.finalize(iter, state, toapprox, envs)::Tuple{typeof(state),typeof(envs)}
ϵ <= alg.tol_galerkin && break
iter == alg.maxiter && @warn "VOMPS maximum iterations" iter ϵ
end

galerkin = calc_galerkin(state, envs)
alg.verbose && @info "vomps @iteration $(iter) galerkin = $(galerkin)"
Δt = (Base.time_ns() - t₀) / 1.0e9
alg.verbose && @info "VOMPS summary:" ϵ Δt
return ψ, envs, ϵ
end

if (galerkin <= alg.tol_galerkin) || iter >= alg.maxiter
iter >= alg.maxiter && @warn "vomps didn't converge $(galerkin)"
return state, envs, galerkin
function _vomps_localupdate!(AC′, loc, ψ, (O, ψ₀), envs, factalg=QRpos())
local Q_AC, Q_C
@static if Defaults.parallelize_sites
@sync begin
Threads.@spawn begin
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
Q_AC = first.(leftorth!.(tmp_AC; alg=factalg))
end
Threads.@spawn begin
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
Q_C = first.(leftorth!.(tmp_C; alg=factalg))
end
end

iter += 1
else
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
Q_AC = first.(leftorth!.(tmp_AC; alg=factalg))
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
Q_C = first.(leftorth!.(tmp_C; alg=factalg))
end
return mul!.(AC′, Q_AC, adjoint.(Q_C))
end

0 comments on commit 228af5a

Please sign in to comment.