diff --git a/asQ/complex_proxy/common.py b/asQ/complex_proxy/common.py index 1b1a7636..145fcd3f 100644 --- a/asQ/complex_proxy/common.py +++ b/asQ/complex_proxy/common.py @@ -15,7 +15,7 @@ def _flatten_tree(root, is_leaf, get_children, container=tuple): """ - Return the recursively flattened tree below root in the order that the leafs appear in the traversal. + Return the recursively flattened tree below root in the order that the leafs appear in the depth first traversal. :arg root: the current root node. :arg is_leaf: predicate on root returning True if root has no children. diff --git a/examples/serial/shallow_water/galewsky_gusto.py b/examples/serial/shallow_water/galewsky_gusto.py new file mode 100644 index 00000000..15a97753 --- /dev/null +++ b/examples/serial/shallow_water/galewsky_gusto.py @@ -0,0 +1,263 @@ + +import firedrake as fd +from firedrake.petsc import PETSc +import gusto + +from utils import units +from utils import mg +from utils.planets import earth +import utils.shallow_water as swe +from utils.shallow_water import galewsky +from utils import diagnostics + +from utils.serial import SerialMiniApp + +from functools import partial + +PETSc.Sys.popErrorHandler() + +# get command arguments +import argparse +parser = argparse.ArgumentParser( + description='Galewsky testcase using fully implicit SWE solver.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--ref_level', type=int, default=2, help='Refinement level of icosahedral grid.') +parser.add_argument('--nt', type=int, default=10, help='Number of time steps.') +parser.add_argument('--dt', type=float, default=0.5, help='Timestep in hours.') +parser.add_argument('--degree', type=float, default=swe.default_degree(), help='Degree of the depth function space.') +parser.add_argument('--theta', type=float, default=0.5, help='Parameter for implicit theta method. 0.5 for trapezium rule, 1 for backwards Euler.') +parser.add_argument('--filename', type=str, default='galewsky', help='Name of output vtk files') +parser.add_argument('--show_args', action='store_true', help='Output all the arguments.') + +args = parser.parse_known_args() +args = args[0] + +nt = args.nt +degree = args.degree + +if args.show_args: + PETSc.Sys.Print(args) + +PETSc.Sys.Print('') +PETSc.Sys.Print('### === --- Setting up --- === ###') +PETSc.Sys.Print('') + +# icosahedral mg mesh +mesh = swe.create_mg_globe_mesh(ref_level=args.ref_level, coords_degree=1) +x = fd.SpatialCoordinate(mesh) + +# time step +dt = args.dt*units.hour + +# shallow water equation function spaces (velocity and depth) +W = swe.default_function_space(mesh, degree=args.degree) + +# parameters +gravity = earth.Gravity + +topography = galewsky.topography_expression(*x) +coriolis = swe.earth_coriolis_expression(*x) + +# initial conditions +w_initial = fd.Function(W) +u_initial = w_initial.subfunctions[0] +h_initial = w_initial.subfunctions[1] + +u_initial.project(galewsky.velocity_expression(*x)) +h_initial.project(galewsky.depth_expression(*x)) + +# current and next timestep +w0 = fd.Function(W).assign(w_initial) +w1 = fd.Function(W).assign(w_initial) + + +# shallow water equation forms +def form_function(u, h, v, q, t): + return swe.nonlinear.form_function(mesh, + gravity, + topography, + coriolis, + u, h, v, q, t) + + +def form_mass(u, h, v, q): + return swe.nonlinear.form_mass(mesh, u, h, v, q) + +# gusto forms + +swe_parameters = gusto.ShallowWaterParameters(H=galewsky.H0, + g=earth.Gravity, + Omega=earth.Omega) + +domain = gusto.Domain(mesh, dt, 'BDM', degree=args.degree) + +eqn = gusto.ShallowWaterEquations(domain, + swe_parameters, + fexpr=galewsky.coriolis_expression(*x)) + +from gusto.labels import replace_subject, replace_test_function, time_derivative, prognostic +from gusto.fml.form_manipulation_labelling import all_terms, drop +from gusto import Term +from firedrake.formmanipulation import split_form + + +def extract_form_mass(u, h, v, q, residual=None): + M = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_false=drop) + + Mu = M.label_map(lambda t: t.get(prognostic) == "u", + lambda t: Term(split_form(t.form)[0].form, t.labels), + drop) + + Mh = M.label_map(lambda t: t.get(prognostic) == "D", + lambda t: Term(split_form(t.form)[1].form, t.labels), + drop) + + Mu = Mu.label_map(all_terms, replace_test_function(v)) + Mu = Mu.label_map(all_terms, replace_subject(u, idx=0)) + + Mh = Mh.label_map(all_terms, replace_test_function(q)) + Mh = Mh.label_map(all_terms, replace_subject(h, idx=1)) + + M = Mu + Mh + print("form_mass: ", M.form) + return M.form + + +def extract_form_function(u, h, v, q, t, residual=None): + K = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_true=drop) + + Ku = K.label_map(lambda t: t.get(prognostic) == "u", + lambda t: Term(split_form(t.form)[0].form, t.labels), + drop) + + Kh = K.label_map(lambda t: t.get(prognostic) == "D", + lambda t: Term(split_form(t.form)[1].form, t.labels), + drop) + + Ku = Ku.label_map(all_terms, replace_test_function(v)) + Ku = Ku.label_map(all_terms, replace_subject(u, idx=0)) + + Kh = Kh.label_map(all_terms, replace_test_function(q)) + Kh = Kh.label_map(all_terms, replace_subject(h, idx=1)) + + K = Ku + Kh + print("form_function: ", K.form) + return K.form + + +form_mass_gusto = partial(extract_form_mass, residual=eqn.residual) +form_function_gusto = partial(extract_form_function, residual=eqn.residual) + + +# solver parameters for the implicit solve +atol = 1e-12 +sparameters = { + 'snes': { + 'monitor': None, + 'converged_reason': None, + 'rtol': 1e-12, + 'atol': atol, + 'ksp_ew': None, + 'ksp_ew_version': 1, + }, + #'mat_type': 'matfree', + #'ksp_type': 'fgmres', + 'ksp': { + 'monitor': None, + 'converged_reason': None, + 'atol': atol, + 'rtol': 1e-5, + }, + 'ksp_type': 'preonly', + 'pc_type': 'lu', + 'pc_factor_mat_solver_type': 'mumps', +# 'pc_type': 'mg', +# 'pc_mg_cycle_type': 'w', +# 'pc_mg_type': 'multiplicative', +# 'mg': { +# 'levels': { +# 'ksp_type': 'gmres', +# 'ksp_max_it': 5, +# 'pc_type': 'python', +# 'pc_python_type': 'firedrake.PatchPC', +# 'patch': { +# 'pc_patch_save_operators': True, +# 'pc_patch_partition_of_unity': True, +# 'pc_patch_sub_mat_type': 'seqdense', +# 'pc_patch_construct_dim': 0, +# 'pc_patch_construct_type': 'vanka', +# 'pc_patch_local_type': 'additive', +# 'pc_patch_precompute_element_tensors': True, +# 'pc_patch_symmetrise_sweep': False, +# 'sub_ksp_type': 'preonly', +# 'sub_pc_type': 'lu', +# 'sub_pc_factor_shift_type': 'nonzero', +# }, +# }, +# 'coarse': { +# 'pc_type': 'python', +# 'pc_python_type': 'firedrake.AssembledPC', +# 'assembled_pc_type': 'lu', +# 'assembled_pc_factor_mat_solver_type': 'mumps', +# }, +# } +} + +# set up nonlinear solver +miniapp = SerialMiniApp(dt, args.theta, + w_initial, + form_mass, + form_function_gusto, + sparameters) + +miniapp.nlsolver.set_transfer_manager( + mg.manifold_transfer_manager(W)) + +potential_vorticity = diagnostics.potential_vorticity_calculator( + u_initial.function_space(), name='vorticity') + +uout = fd.Function(u_initial.function_space(), name='velocity') +hout = fd.Function(h_initial.function_space(), name='elevation') +ofile = fd.File(f"output/{args.filename}.pvd") +# save initial conditions +uout.assign(u_initial) +hout.assign(h_initial) +ofile.write(uout, hout, potential_vorticity(uout), time=0) + +PETSc.Sys.Print('### === --- Timestepping loop --- === ###') +linear_its = 0 +nonlinear_its = 0 + + +def preproc(app, step, t): + PETSc.Sys.Print('') + PETSc.Sys.Print(f'=== --- Timestep {step} --- ===') + PETSc.Sys.Print('') + + +def postproc(app, step, t): + global linear_its + global nonlinear_its + + linear_its += app.nlsolver.snes.getLinearSolveIterations() + nonlinear_its += app.nlsolver.snes.getIterationNumber() + + uout.assign(miniapp.w0.subfunctions[0]) + hout.assign(miniapp.w0.subfunctions[1]) + ofile.write(uout, hout, potential_vorticity(uout), time=float(t)) + + +miniapp.solve(args.nt, + preproc=preproc, + postproc=postproc) + +PETSc.Sys.Print('### === --- Iteration counts --- === ###') +PETSc.Sys.Print('') + +PETSc.Sys.Print(f'linear iterations: {linear_its} | iterations per timestep: {linear_its/args.nt}') +PETSc.Sys.Print(f'nonlinear iterations: {nonlinear_its} | iterations per timestep: {nonlinear_its/args.nt}') +PETSc.Sys.Print('') diff --git a/examples/shallow_water/galewsky.py b/examples/shallow_water/galewsky.py index d9a0cff6..bf681774 100644 --- a/examples/shallow_water/galewsky.py +++ b/examples/shallow_water/galewsky.py @@ -78,11 +78,11 @@ 'pc_type': 'python', 'pc_python_type': 'firedrake.AssembledPC', 'assembled_pc_type': 'lu', - 'assembled_pc_factor_mat_solver_type': 'mumps' - } + 'assembled_pc_factor_mat_solver_type': 'mumps', + }, } -sparameters = { +block_params = { 'mat_type': 'matfree', 'ksp_type': 'fgmres', 'ksp': { @@ -96,12 +96,13 @@ 'mg': mg_parameters } +atol = 1e0 sparameters_diag = { 'snes': { 'linesearch_type': 'basic', 'monitor': None, 'converged_reason': None, - 'atol': 1e-0, + 'atol': atol, 'rtol': 1e-10, 'stol': 1e-12, 'ksp_ew': None, @@ -114,7 +115,7 @@ 'monitor': None, 'converged_reason': None, 'rtol': 1e-5, - 'atol': 1e-0, + 'atol': atol, }, 'pc_type': 'python', 'pc_python_type': 'asQ.DiagFFTPC', @@ -123,8 +124,9 @@ 'aaos_jacobian_state': 'current' } +# sparameters_diag['diagfft_block_'] = block_params for i in range(window_length): - sparameters_diag['diagfft_block_'+str(i)+'_'] = sparameters + sparameters_diag['diagfft_block_'+str(i)+'_'] = block_params create_mesh = partial( swe.create_mg_globe_mesh, diff --git a/examples/shallow_water/williamson2_gusto.py b/examples/shallow_water/williamson2_gusto.py new file mode 100644 index 00000000..2a44d1b3 --- /dev/null +++ b/examples/shallow_water/williamson2_gusto.py @@ -0,0 +1,334 @@ + +import firedrake as fd +from petsc4py import PETSc +import asQ + +from utils import mg +from utils import units +from utils.planets import earth +import utils.shallow_water as swe +import utils.shallow_water.williamson1992.case2 as case2 + +from functools import partial + +import gusto + +PETSc.Sys.popErrorHandler() + +# get command arguments +import argparse +parser = argparse.ArgumentParser( + description='Williamson 2 testcase for ParaDiag solver using fully implicit SWE solver.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter +) +parser.add_argument('--base_level', type=int, default=1, help='Base refinement level of icosahedral grid for MG solve.') +parser.add_argument('--ref_level', type=int, default=2, help='Refinement level of icosahedral grid.') +parser.add_argument('--nwindows', type=int, default=1, help='Number of time-windows.') +parser.add_argument('--nslices', type=int, default=1, help='Number of time-slices per time-window.') +parser.add_argument('--slice_length', type=int, default=2, help='Number of timesteps per time-slice.') +parser.add_argument('--alpha', type=float, default=0.0001, help='Circulant coefficient.') +parser.add_argument('--dt', type=float, default=0.5, help='Timestep in hours.') +parser.add_argument('--filename', type=str, default='w5diag', help='Name of output vtk files') +parser.add_argument('--coords_degree', type=int, default=1, help='Degree of polynomials for sphere mesh approximation.') +parser.add_argument('--degree', type=int, default=1, help='Degree of finite element space (the DG space).') +parser.add_argument('--show_args', action='store_true', help='Output all the arguments.') + +args = parser.parse_known_args() +args = args[0] + +if args.show_args: + PETSc.Sys.Print(args) + +PETSc.Sys.Print('') +PETSc.Sys.Print('### === --- Setting up --- === ###') +PETSc.Sys.Print('') + +# time steps + +time_partition = tuple(args.slice_length for _ in range(args.nslices)) +window_length = sum(time_partition) +nsteps = args.nwindows*window_length + +dt = args.dt*units.hour + +# multigrid mesh set up + +ensemble = asQ.create_ensemble(time_partition) + +distribution_parameters = {"partition": True, "overlap_type": (fd.DistributedMeshOverlapType.VERTEX, 2)} + +# mesh set up +mesh = swe.create_mg_globe_mesh(base_level=args.base_level, + ref_level=args.ref_level, + coords_degree=args.coords_degree, + comm=ensemble.comm) +x = fd.SpatialCoordinate(mesh) + +# Mixed function space for velocity and depth +V1 = swe.default_velocity_function_space(mesh, degree=args.degree) +V2 = swe.default_depth_function_space(mesh, degree=args.degree) +W = fd.MixedFunctionSpace((V1, V2)) + +# initial conditions +w0 = fd.Function(W) +un, hn = w0.subfunctions[:] + +f = case2.coriolis_expression(*x) +b = case2.topography_function(*x, V2, name="Topography") +H = case2.H0 + +un.project(case2.velocity_expression(*x)) +etan = case2.elevation_function(*x, V2, name="Elevation") +hn.assign(H + etan - b) + + +# nonlinear swe forms + +# original method + +def form_function_asq(u, h, v, q, t): + return swe.nonlinear.form_function(mesh, earth.Gravity, b, f, u, h, v, q, t) + + +def form_mass_asq(u, h, v, q): + return swe.nonlinear.form_mass(mesh, u, h, v, q) + + +# gusto method + +swe_parameters = gusto.ShallowWaterParameters(H=H, + g=earth.Gravity, + Omega=earth.Omega) + +domain = gusto.Domain(mesh, dt, 'BDM', degree=args.degree) + +eqn = gusto.ShallowWaterEquations(domain, + swe_parameters, + fexpr=case2.coriolis_expression(*x)) + +from gusto.labels import replace_subject, replace_test_function, time_derivative, prognostic +from gusto.fml.form_manipulation_labelling import all_terms, drop +from gusto import Term +from firedrake.formmanipulation import split_form + + +def extract_form_mass(u, h, v, q, residual=None): + M = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_false=drop) + + Mu = M.label_map(lambda t: t.get(prognostic) == "u", + lambda t: Term(split_form(t.form)[0].form, t.labels), + drop) + + Mh = M.label_map(lambda t: t.get(prognostic) == "D", + lambda t: Term(split_form(t.form)[1].form, t.labels), + drop) + + Mu = Mu.label_map(all_terms, replace_test_function(v)) + Mu = Mu.label_map(all_terms, replace_subject(u, idx=0)) + + Mh = Mh.label_map(all_terms, replace_test_function(q)) + Mh = Mh.label_map(all_terms, replace_subject(h, idx=1)) + + M = Mu + Mh + return M.form + + +def extract_form_function(u, h, v, q, t, residual=None): + K = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_true=drop) + + Ku = K.label_map(lambda t: t.get(prognostic) == "u", + lambda t: Term(split_form(t.form)[0].form, t.labels), + drop) + + Kh = K.label_map(lambda t: t.get(prognostic) == "D", + lambda t: Term(split_form(t.form)[1].form, t.labels), + drop) + + Ku = Ku.label_map(all_terms, replace_test_function(v)) + Ku = Ku.label_map(all_terms, replace_subject(u, idx=0)) + + Kh = Kh.label_map(all_terms, replace_test_function(q)) + Kh = Kh.label_map(all_terms, replace_subject(h, idx=1)) + + K = Ku + Kh + return K.form + + +form_mass_gusto = partial(extract_form_mass, residual=eqn.residual) +form_function_gusto = partial(extract_form_function, residual=eqn.residual) + +# solver parameters + +patch_parameters = { + 'pc_patch': { + 'save_operators': True, + 'partition_of_unity': True, + 'sub_mat_type': 'seqdense', + 'construct_dim': 0, + 'construct_type': 'vanka', + 'local_type': 'additive', + 'precompute_element_tensors': True, + 'symmetrise_sweep': False + }, + 'sub': { + 'ksp_type': 'preonly', + 'pc_type': 'lu', + 'pc_factor_shift_type': 'nonzero', + } +} + +mg_params = { + 'levels': { + 'ksp_type': 'gmres', + 'ksp_max_it': 4, + 'pc_type': 'python', + 'pc_python_type': 'firedrake.PatchPC', + 'patch': patch_parameters, + }, + 'coarse': { + 'pc_type': 'python', + 'pc_python_type': 'firedrake.AssembledPC', + 'assembled_pc_type': 'lu', + 'assembled_pc_factor_mat_solver_type': 'mumps', + }, +} + +block_params = { + 'mat_type': 'matfree', + 'ksp_type': 'fgmres', + 'ksp': { + 'atol': 1e-5, + 'rtol': 1e-5, + }, + 'pc_type': 'mg', + 'pc_mg_cycle_type': 'v', + 'pc_mg_type': 'multiplicative', + 'mg': mg_params +} + +lu_params = { + 'ksp_type': 'preonly', + 'pc_type': 'lu', + 'pc_factor_mat_solver_type': 'mumps', +} + +sparameters_diag = { + 'snes': { + 'linesearch_type': 'basic', + 'monitor': None, + 'converged_reason': None, + 'atol': 1e-0, + 'rtol': 1e-12, + }, + 'mat_type': 'matfree', + 'ksp_type': 'fgmres', + 'ksp': { + 'monitor': None, + 'converged_reason': None, + }, + 'pc_type': 'python', + 'pc_python_type': 'asQ.DiagFFTPC' +} + +sparameters_diag['diagfft_block_'] = lu_params + +# non-petsc information for block solve +block_ctx = {} + +# mesh transfer operators +transfer_managers = [] +nlocal_timesteps = time_partition[ensemble.ensemble_comm.rank] +for _ in range(nlocal_timesteps): + tm = mg.manifold_transfer_manager(W) + transfer_managers.append(tm) + +block_ctx['diag_transfer_managers'] = transfer_managers + +# make a paradiag solver the old way + +PETSc.Sys.Print('### === --- Setting up old Paradiag --- === ###') +PETSc.Sys.Print('') + +pdg = asQ.paradiag(ensemble=ensemble, + form_function=form_function_asq, + form_mass=form_mass_asq, + w0=w0, dt=dt, theta=0.5, + alpha=args.alpha, + time_partition=time_partition, solver_parameters=sparameters_diag, + circ=None, tol=1.0e-6, maxits=None, + ctx={}, block_ctx=block_ctx, block_mat_type="aij") + +# make a paradiag solver with gusto forms + +PETSc.Sys.Print('### === --- Setting up gusto Paradiag --- === ###') +PETSc.Sys.Print('') + +w0g = w0.copy() +pdg_gusto = asQ.paradiag(ensemble=ensemble, + form_function=form_function_gusto, + form_mass=form_mass_gusto, + w0=w0g, dt=dt, theta=0.5, + alpha=args.alpha, + time_partition=time_partition, solver_parameters=sparameters_diag, + circ=None, tol=1.0e-6, maxits=None, + ctx={}, block_ctx=block_ctx, block_mat_type="aij") + + +def window_preproc(pdg, wndw): + PETSc.Sys.Print('') + PETSc.Sys.Print(f'### === --- Calculating time-window {wndw} --- === ###') + PETSc.Sys.Print('') + + +# check against initial conditions +wcheck = w0.copy(deepcopy=True) +ucheck, hcheck = wcheck.subfunctions[:] +hcheck.assign(hcheck - H + b) + + +def steady_state_test(w): + up = w.subfunctions[0] + hp = w.subfunctions[1] + hp.assign(hp - H + b) + + uerr = fd.errornorm(ucheck, up)/fd.norm(ucheck) + herr = fd.errornorm(hcheck, hp)/fd.norm(hcheck) + + return uerr, herr + + +# check each timestep against steady state +def window_postproc(pdg, wndw): + uerrors = asQ.SharedArray(time_partition, comm=ensemble.ensemble_comm) + herrors = asQ.SharedArray(time_partition, comm=ensemble.ensemble_comm) + + for i in range(pdg.nlocal_timesteps): + uerrors.dlocal[i], herrors.dlocal[i] = steady_state_test(pdg.aaos.get_field(i)) + + uerrors.synchronise() + herrors.synchronise() + + for window_index in range(pdg.ntimesteps): + timestep = wndw*pdg.ntimesteps + window_index + uerr = uerrors.dglobal[window_index] + herr = herrors.dglobal[window_index] + PETSc.Sys.Print(f"timestep={timestep}, uerr={uerr}, herr={herr}") + + +PETSc.Sys.Print('### === --- Calculating parallel solution --- === ###') +PETSc.Sys.Print('') + +PETSc.Sys.Print("Solving with gusto forms") + +pdg_gusto.solve(nwindows=args.nwindows, + preproc=window_preproc, + postproc=window_postproc) + +PETSc.Sys.Print('') +PETSc.Sys.Print("Solving with the old forms") + +pdg.solve(nwindows=args.nwindows, + preproc=window_preproc, + postproc=window_postproc) diff --git a/examples/test_gusto_advection.py b/examples/test_gusto_advection.py new file mode 100644 index 00000000..2272c82c --- /dev/null +++ b/examples/test_gusto_advection.py @@ -0,0 +1,51 @@ +from utils.planets import earth +import firedrake as fd +import gusto +import asQ.complex_proxy.mixed as cpx + +from gusto.labels import replace_subject, replace_test_function, time_derivative +from gusto.fml.form_manipulation_labelling import all_terms, drop + +dt = 0.1 +L =1.0 +nx = 8 + +# mesh = fd.UnitSquareMesh(8, 8) + +base_mesh = fd.PeriodicIntervalMesh(nx, L) +mesh = fd.ExtrudedMesh(base_mesh, nx, L/nx) + +x, y = fd.SpatialCoordinate(mesh) + +finit = fd.exp(-((x-0.5*L)**2 + (y-0.5*L)**2)) + +domain = gusto.Domain(mesh, dt, family='CG', degree=1) + +V = domain.spaces("DG") + +eqn = gusto.AdvectionEquation(domain, V, "f") + +residual = eqn.residual + +def form_mass(u, v): + M = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_false=drop) + + M = M.label_map(all_terms, replace_subject(u, idx=0)) + M = M.label_map(all_terms, replace_test_function(v, idx=0)) + return M.form + +V = eqn.function_space + +u = fd.TrialFunction(V) +v = fd.TestFunction(V) + +M = form_mass(u, v) + +C = cpx.FunctionSpace(V) + +# print(V) +# print() +# print(C) + +N = cpx.BilinearForm(C, 1+0j, form_mass) diff --git a/examples/test_gusto_replace.py b/examples/test_gusto_replace.py new file mode 100644 index 00000000..401bf204 --- /dev/null +++ b/examples/test_gusto_replace.py @@ -0,0 +1,43 @@ +import firedrake as fd +import gusto + +mesh = fd.UnitIcosahedralSphereMesh(refinement_level=1, degree=1) +swe_params = gusto.ShallowWaterParameters(H=1, g=1, Omega=1) +domain = gusto.Domain(mesh, dt=1, family='BDM', degree=1) +eqn = gusto.ShallowWaterEquations(domain, swe_params) + +residual = eqn.residual + +from gusto.fml.labels import replace_subject, replace_test_function, replace_trial_function, time_derivative +from gusto.fml.form_manipulation_labelling import all_terms, drop + +def form_mass(u, h, v, q): + M = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_false=drop) + + M = M.label_map(all_terms, replace_test_function(v, idx=0)) + M = M.label_map(all_terms, replace_test_function(q, idx=1)) + + M = M.label_map(all_terms, replace_subject(u, idx=0)) + M = M.label_map(all_terms, replace_subject(h, idx=1)) + return M.form + +V = eqn.function_space +Vu = V.subfunctions[0] +Vh = V.subfunctions[1] + +u, h = fd.TrialFunctions(V) +v, q = fd.TestFunctions(V) + +W = V*V +u0, h0, u_, h_ = fd.split(fd.Function(W)) +v0, q0, v_, q_ = fd.TestFunctions(W) + +W = fd.MixedFunctionSpace((Vu, Vu, Vh, Vh)) +u1, u_, h1, h_ = fd.split(fd.Function(W)) +v1, v_, q1, q_ = fd.TestFunctions(W) + + +M = form_mass(u, h, v, q) +M0 = form_mass(u0, h0, v0, q0) +M1 = form_mass(u1, h1, v1, q1) diff --git a/examples/test_replace.py b/examples/test_replace.py new file mode 100644 index 00000000..7be306ba --- /dev/null +++ b/examples/test_replace.py @@ -0,0 +1,71 @@ +import firedrake as fd +import gusto + +from gusto import Term +from gusto.fml import replace_subject, replace_test_function, replace_trial_function +from gusto.labels import time_derivative, prognostic +from gusto.fml import all_terms, drop + +from firedrake.formmanipulation import split_form + +# set up the equation +mesh = fd.UnitIcosahedralSphereMesh(refinement_level=1, degree=1) +swe_params = gusto.ShallowWaterParameters(H=1, g=1, Omega=1) +domain = gusto.Domain(mesh, dt=1, family='BDM', degree=1) +eqn = gusto.ShallowWaterEquations(domain, swe_params) + +residual = eqn.residual + +print() +print("Original mass matrix:") +print(residual.label_map(lambda t: t.has_label(time_derivative), + map_if_false=drop).form) +print() + +# function to replace the test function and subject of the time-derivative with different test function and subject. +def form_mass(u, h, v, q): + M = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_false=drop) + + M = M.label_map(all_terms, replace_subject((u, h))) + M = M.label_map(all_terms, replace_test_function((v, q))) + return M.form + +V = eqn.function_space +Vu = V.subfunctions[0] +Vh = V.subfunctions[1] + +# the usual FunctionSpace + +u, h = fd.split(fd.Function(V)) +v, q = fd.TestFunctions(V) + +M = form_mass(u, h, v, q) +print("Simple mass matrix replacement:") +print(u, h, v, q) +print(M) +print() + +# A FunctionSpace for multiple timesteps (u0, h0, u1, h1) + +W0 = V*V +u0, h0, u1, h1 = fd.split(fd.Function(W0)) +v0, q0, v1, q1 = fd.TestFunctions(W0) + +Mt = form_mass(u1, h1, v1, q1) +print("Multiple timesteps mass matrix replacement:") +print(u1, h1, v1, q1) +print(Mt) +print() + +# A FunctionSpace to proxy the complex problem (real u, imag u, real h, imag h) + +W1 = fd.MixedFunctionSpace((Vu, Vu, Vh, Vh)) +ur, ui, hr, hi = fd.split(fd.Function(W1)) +vr, vi, qr, qi = fd.TestFunctions(W1) + +Mc = form_mass(ui, hi, vi, qi) +print("Complex mass matrix replacement:") +print(ui, hi, vi, qi) +print(Mc) +print() diff --git a/examples/test_replace_mixed.py b/examples/test_replace_mixed.py new file mode 100644 index 00000000..c2d87639 --- /dev/null +++ b/examples/test_replace_mixed.py @@ -0,0 +1,42 @@ +import firedrake as fd +import ufl + +mesh = fd.UnitSquareMesh(8, 8) + +Vu = fd.FunctionSpace(mesh, "RT", 2) +Vh = fd.FunctionSpace(mesh, "DG", 1) + +V = Vu * Vh + +def form(sigma, u, tau, v): + return (fd.dot(sigma, tau) + fd.div(tau)*u + fd.div(sigma)*v)*fd.dx + # return (fd.inner(u, v) + fd.inner(h, q))*fd.dx + +sigma, u = fd.split(fd.Function(V)) +tau, v = fd.TestFunctions(V) + +eqn = form(sigma, u, tau, v) + +# W = Vu * Vh * Vu * Vh +# sigma0, u0, sigma1, u1 = fd.split(fd.Function(W)) +# tau0, v0, tau1, v1 = fd.TestFunctions(W) + +W = Vu * Vu * Vh * Vh +sigma0, sigma1, u0, u1 = fd.split(fd.Function(W)) +tau0, tau1, v0, v1 = fd.TestFunctions(W) + +# first eqn + +new0 = ufl.replace(eqn, {tau: tau0}) +new0 = ufl.replace(new0, {v: v0}) + +new0 = ufl.replace(new0, {sigma: sigma0}) +new0 = ufl.replace(new0, {u: u0}) + +# second eqn + +new1 = ufl.replace(eqn, {tau: tau1}) +new1 = ufl.replace(new1, {v: v1}) + +new1 = ufl.replace(new1, {sigma: sigma1}) +new1 = ufl.replace(new1, {u: u1}) diff --git a/examples/test_replace_primal.py b/examples/test_replace_primal.py new file mode 100644 index 00000000..a104fead --- /dev/null +++ b/examples/test_replace_primal.py @@ -0,0 +1,25 @@ +import firedrake as fd +import ufl + +mesh = fd.UnitIntervalMesh(8) + +V = fd.FunctionSpace(mesh, "CG", 1) + +def form(u, v): + return fd.inner(u, v)*fd.dx + +u = fd.TrialFunction(V) +v = fd.TestFunction(V) + +eqn = form(u, v) + +W = V*V + +u0, u1 = fd.TrialFunctions(W) +v0, v1 = fd.TestFunctions(W) + +new0 = ufl.replace(eqn, {u: u0}) +new0 = ufl.replace(new0, {v: v0}) + +new1 = ufl.replace(eqn, {u: u1}) +new1 = ufl.replace(new1, {v: v1}) diff --git a/examples/test_swe_replace.py b/examples/test_swe_replace.py new file mode 100644 index 00000000..b5546b3d --- /dev/null +++ b/examples/test_swe_replace.py @@ -0,0 +1,69 @@ +from utils import shallow_water as swe +from utils.planets import earth +import firedrake as fd +import ufl + +from functools import partial + +mesh = swe.create_mg_globe_mesh(coords_degree=1) + +W = swe.default_function_space(mesh) +Vu, Vh = W.subfunctions[:] + +g = earth.Gravity +b = fd.Constant(0) +f = fd.Constant(0) +t = fd.Constant(0) + +form_mass = lambda *args: swe.nonlinear.form_mass(mesh, *args) +form_function = lambda *args: swe.nonlinear.form_function(mesh, g, b, f, *args, t) + +form = form_function + +u, h = fd.TrialFunctions(W) +v, q = fd.TestFunctions(W) + +M = form(u, h, v, q) + +# replace from W + +u0, h0 = fd.TrialFunctions(W) +v0, q0 = fd.TestFunctions(W) + +# subject +M0 = ufl.replace(M, {u:u0}) +M0 = ufl.replace(M0, {h:h0}) + +# tests +M0 = ufl.replace(M0, {v:v0}) +M0 = ufl.replace(M0, {q:q0}) + +# replace from (W * W) + +WW = W * W + +u1, h1, u_, h_ = fd.TrialFunctions(WW) +v1, q1, v_, q_ = fd.TestFunctions(WW) + +# subject +M1 = ufl.replace(M, {u:u1}) +M1 = ufl.replace(M1, {h:h1}) + +# tests +M1 = ufl.replace(M1, {v:v1}) +M1 = ufl.replace(M1, {q:q1}) + +# replace from (Vu * Vu * Vh * Vh) + +VV = Vu * Vu * Vh * Vh + +u2, u_, h2, h_ = fd.TrialFunctions(VV) +v2, v_, q2, q_ = fd.TestFunctions(VV) + +# subject +M2 = ufl.replace(M, {u:u2}) +M2 = ufl.replace(M2, {h:h2}) + +# tests +M2 = ufl.replace(M2, {v:v2}) +M2 = ufl.replace(M2, {q:q2})