diff --git a/src/Integrator.jl b/src/Integrator.jl index 9531303..822052c 100644 --- a/src/Integrator.jl +++ b/src/Integrator.jl @@ -56,6 +56,33 @@ end return nothing end +@inline function RK4b(rhs!::F, dt::Real, reg1, reg2, reg3, reg4, params, t::Real, + indices::CartesianIndices, N) where {F} + rhs!(reg4, reg1, params.h, N, t) + @inbounds @fastmath @avx for i in indices + reg3[i] = dt * reg4[i] + reg1[i] = reg1[i] + reg3[i] / 2 + reg2[i] = reg3[i] + end + rhs!(reg4, reg1, params.h, N, t) + @inbounds @fastmath @avx for i in indices + reg3[i] = dt * reg4[i] + reg1[i] = reg1[i] + (reg3[i] - reg2[i]) / 2 + end + rhs!(reg4, reg1, params.h, N, t) + @inbounds @fastmath @avx for i in indices + reg3[i] = dt * reg4[i] - reg3[i] / 2 + reg1[i] = reg1[i] + reg3[i] + reg2[i] = reg2[i] / 6 - reg3[i] + end + rhs!(reg4, reg1, params.h, N, t) + @inbounds @fastmath @avx for i in indices + reg3[i] = dt * reg4[i] + reg3[i] + reg3[i] + reg1[i] = reg1[i] + reg2[i] + reg3[i] / 6 + end + return nothing +end + function solve(rhs, statevector, params, t, grid, save_every, folder) # istr = get_iter_str(0, t.ncells + 1) base_dir = base_path(folder) @@ -98,23 +125,25 @@ function solve(rhs, statevector, params, t, grid, save_every, folder) println("saving every=", save_every) - # nt = t.ncells + 1 - # dt = spacing(t) - # indices = CartesianIndices(statevector) - # params = (grid=grid, dt=dt, ti=coords(t), params...) + nt = t.ncells + 1 + dt = spacing(t) + indices = CartesianIndices(statevector) + params = (grid=grid, dt=dt, ti=coords(t), params...) # InputOutput.write_metadata(base_dir, params, save_every, sims_dir) - # for (i, ti) in enumerate(params.ti) - # i -= 1 - # if i == 0 - # continue - # end - # println("Iteration = ", i, "/", t.ncells) - # @time RK4(rhs, dt, statevector, reg2, reg3, reg4, params, ti, indices, N) - # if i % save_every == 0 - # InputOutput.writevtk(statevector, sims_dir, xcoord, ti, i, pvd) - # end - # end - + for (i, ti) in enumerate(coords(t)[2:end]) + # println("Iteration = ", i, "/", t.ncells) + if (i==1) global wtime0 = Base.time() end + RK4(rhs, dt, statevector, reg2, reg3, reg4, params, ti, indices, N) + # RK4(ODE.rhs_batch!, dt, statevector, reg2, reg3, reg4, params, ti, indices, N) + # if i % save_every == 0 + # InputOutput.writevtk(statevector, sims_dir, xcoord, ti, i, pvd) + # end + end + wtime = Base.time()-wtime0 + A_eff = (3*2)/1e9*N*N*sizeof(Data.Number) + wtime_it = wtime/(nt) # Execution time per iteration [s] + T_eff = A_eff/wtime_it # Effective memory throughput [GB/s] + println("Total steps=$nt, time=$wtime sec (@ T_eff = $(round(T_eff, sigdigits=2)) GB/s)") # InputOutput.save_pvd(pvd) return nothing end