diff --git a/examples/simple-dg.py b/examples/simple-dg.py index 964e0a12d..84bea31a4 100644 --- a/examples/simple-dg.py +++ b/examples/simple-dg.py @@ -33,7 +33,7 @@ from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa from meshmode.dof_array import DOFArray, flat_norm from meshmode.array_context import (PyOpenCLArrayContext, - PytatoPyOpenCLArrayContext) + SingleGridWorkBalancingPytatoArrayContext as PytatoPyOpenCLArrayContext) from arraycontext import ( ArrayContainer, map_array_container, @@ -455,11 +455,10 @@ def main(lazy=False): cl_ctx = cl.create_some_context() queue = cl.CommandQueue(cl_ctx) - actx_outer = PyOpenCLArrayContext(queue, force_device_scalars=True) if lazy: - actx_rhs = PytatoPyOpenCLArrayContext(queue) + actx = PytatoPyOpenCLArrayContext(queue) else: - actx_rhs = actx_outer + actx = PyOpenCLArrayContext(queue, force_device_scalars=True) nel_1d = 16 from meshmode.mesh.generation import generate_regular_rect_mesh @@ -475,37 +474,34 @@ def main(lazy=False): logger.info("%d elements", mesh.nelements) - discr = DGDiscretization(actx_outer, mesh, order=order) + discr = DGDiscretization(actx, mesh, order=order) fields = WaveState( - u=bump(actx_outer, discr), - v=make_obj_array([discr.zeros(actx_outer) for i in range(discr.dim)]), + u=bump(actx, discr), + v=make_obj_array([discr.zeros(actx) for i in range(discr.dim)]), ) from meshmode.discretization.visualization import make_visualizer - vis = make_visualizer(actx_outer, discr.volume_discr) + vis = make_visualizer(actx, discr.volume_discr) def rhs(t, q): - return wave_operator(actx_rhs, discr, c=1, q=q) + return wave_operator(actx, discr, c=1, q=q) - compiled_rhs = actx_rhs.compile(rhs) - - def rhs_wrapper(t, q): - r = compiled_rhs(t, actx_rhs.thaw(actx_outer.freeze(q))) - return actx_outer.thaw(actx_rhs.freeze(r)) + compiled_rhs = actx.compile(rhs) t = np.float64(0) t_final = 3 istep = 0 while t < t_final: - fields = rk4_step(fields, t, dt, rhs_wrapper) + fields = actx.thaw(actx.freeze(fields,)) + fields = rk4_step(fields, t, dt, compiled_rhs) if istep % 10 == 0: # FIXME: Maybe an integral function to go with the # DOFArray would be nice? 
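            # (A sketch of what such a helper might look like -- the name
            # "integral" and its use of a mass-operator product are
            # hypothetical, not part of this example:
            #     def integral(actx, discr, vec):
            #         # weight each DOF by the mass matrix, then sum
            #         return sum(actx.to_numpy(grp_ary).sum()
            #                    for grp_ary in discr.mass(vec))
            # )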
assert len(fields.u) == 1 logger.info("[%05d] t %.5e / %.5e norm %.5e", - istep, t, t_final, actx_outer.to_numpy(flat_norm(fields.u, 2))) + istep, t, t_final, actx.to_numpy(flat_norm(fields.u, 2))) vis.write_vtk_file("fld-wave-min-%04d.vtu" % istep, [ ("q", fields), ]) @@ -513,7 +509,7 @@ def rhs_wrapper(t, q): t += dt istep += 1 - assert flat_norm(fields.u, 2) < 100 + assert actx.to_numpy(flat_norm(fields.u, 2)) < 100 if __name__ == "__main__": diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 259ceabc3..a1e9b0c8b 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -26,13 +26,35 @@ """ import sys +import logging +import numpy as np + from warnings import warn +from typing import Union, FrozenSet, Tuple, Any from arraycontext import PyOpenCLArrayContext as PyOpenCLArrayContextBase from arraycontext import PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase from arraycontext.pytest import ( _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory) +from loopy.translation_unit import for_each_kernel + +from loopy.tools import memoize_on_disk +from pytools import ProcessLogger, memoize_on_first_arg +from pytools.tag import UniqueTag, tag_dataclass + +from meshmode.transform_metadata import (DiscretizationElementAxisTag, + DiscretizationDOFAxisTag, + DiscretizationFaceAxisTag, + DiscretizationDimAxisTag, + DiscretizationTopologicalDimAxisTag, + DiscretizationAmbientDimAxisTag, + DiscretizationFlattenedDOFAxisTag, + DiscretizationEntityAxisTag) +from dataclasses import dataclass + +from pyrsistent import pmap +logger = logging.getLogger(__name__) def thaw(actx, ary): @@ -345,4 +367,1489 @@ def _import_names(): # }}} +@for_each_kernel +def _single_grid_work_group_transform(kernel, cl_device): + import loopy as lp + from meshmode.transform_metadata import (ConcurrentElementInameTag, + ConcurrentDOFInameTag) + + splayed_inames = set() + ngroups = cl_device.max_compute_units * 4 # '4' to overfill the device + l_one_size = 4 + l_zero_size = 16 + + for insn in kernel.instructions: + if insn.within_inames in splayed_inames: + continue + + if isinstance(insn, lp.CallInstruction): + # must be a callable kernel, don't touch. 
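+            # (The Assignment branch below splays loops onto the device
+            # grid; e.g., hypothetically, a lone iname "i" of sufficient
+            # length becomes
+            #     i -> i_outer               (sequential)
+            #          * i_inner_outer_outer -> g.0
+            #          * i_inner_outer_inner -> l.1  (size l_one_size)
+            #          * i_inner_inner       -> l.0  (size l_zero_size)
+            # via the three split_iname calls that follow.)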
+                pass
+        elif isinstance(insn, lp.Assignment):
+            bigger_loop = None
+            smaller_loop = None
+
+            if len(insn.within_inames) == 0:
+                continue
+
+            if len(insn.within_inames) == 1:
+                iname, = insn.within_inames
+
+                kernel = lp.split_iname(kernel, iname,
+                                        ngroups * l_zero_size * l_one_size)
+                kernel = lp.split_iname(kernel, f"{iname}_inner",
+                                        l_zero_size, inner_tag="l.0")
+                kernel = lp.split_iname(kernel, f"{iname}_inner_outer",
+                                        l_one_size, inner_tag="l.1",
+                                        outer_tag="g.0")
+
+                splayed_inames.add(insn.within_inames)
+                continue
+
+            for iname in insn.within_inames:
+                if kernel.iname_tags_of_type(iname,
+                                             ConcurrentElementInameTag):
+                    assert bigger_loop is None
+                    bigger_loop = iname
+                elif kernel.iname_tags_of_type(iname,
+                                               ConcurrentDOFInameTag):
+                    assert smaller_loop is None
+                    smaller_loop = iname
+                else:
+                    pass
+
+            if bigger_loop or smaller_loop:
+                assert (bigger_loop is not None
+                        and smaller_loop is not None)
+            else:
+                sorted_inames = sorted(tuple(insn.within_inames),
+                                       key=kernel.get_constant_iname_length)
+                smaller_loop = sorted_inames[0]
+                bigger_loop = sorted_inames[-1]
+
+            kernel = lp.split_iname(kernel, f"{bigger_loop}",
+                                    l_one_size * ngroups)
+            kernel = lp.split_iname(kernel, f"{bigger_loop}_inner",
+                                    l_one_size, inner_tag="l.1", outer_tag="g.0")
+            kernel = lp.split_iname(kernel, smaller_loop,
+                                    l_zero_size, inner_tag="l.0")
+            splayed_inames.add(insn.within_inames)
+        elif isinstance(insn, lp.BarrierInstruction):
+            pass
+        else:
+            raise NotImplementedError(type(insn))
+
+    return kernel
+
+
+def _alias_global_temporaries(t_unit):
+    """
+    Returns a copy of *t_unit* with temporaries that have disjoint live
+    intervals sharing the same :attr:`loopy.TemporaryVariable.base_storage`.
+    """
+    from loopy.kernel.data import AddressSpace
+    from loopy.kernel import KernelState
+    from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop,
+                                CallKernel, ReturnFromKernel, Barrier)
+    from loopy.schedule.tools import get_return_from_kernel_mapping
+    from pytools import UniqueNameGenerator
+    from collections import defaultdict
+
+    kernel = t_unit.default_entrypoint
+    assert kernel.state == KernelState.LINEARIZED
+    temp_vars = frozenset(tv.name
+                          for tv in kernel.temporary_variables.values()
+                          if tv.address_space == AddressSpace.GLOBAL)
+    temp_to_live_interval_start = {}
+    temp_to_live_interval_end = {}
+    return_from_kernel_idxs = get_return_from_kernel_mapping(kernel)
+
+    for sched_idx, sched_item in enumerate(kernel.linearization):
+        if isinstance(sched_item, RunInstruction):
+            for var in (kernel.id_to_insn[sched_item.insn_id].dependency_names()
+                        & temp_vars):
+                if var not in temp_to_live_interval_start:
+                    assert var not in temp_to_live_interval_end
+                    temp_to_live_interval_start[var] = sched_idx
+                assert var in temp_to_live_interval_start
+                temp_to_live_interval_end[var] = return_from_kernel_idxs[sched_idx]
+        elif isinstance(sched_item, (EnterLoop, LeaveLoop, CallKernel,
+                                     ReturnFromKernel, Barrier)):
+            # no variables are accessed within these schedule items => do
+            # nothing.
+            pass
+        else:
+            raise NotImplementedError(type(sched_item))
+
+    vng = UniqueNameGenerator()
+    # a mapping from allocation size (keyed by nbytes, despite the variable's
+    # name) to the base storages released by temporaries whose live intervals
+    # have ended.
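+    # (Worked example, hypothetical: if tmp_a's last use precedes tmp_b's
+    # first write and both have equal nbytes, tmp_a's base_storage is
+    # returned to this pool when tmp_a dies and handed back out when tmp_b
+    # is allocated, so the two share a single allocation.)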
+    shape_to_available_base_storage = defaultdict(set)
+
+    sched_idx_to_just_live_temp_vars = [set() for _ in kernel.linearization]
+    sched_idx_to_just_dead_temp_vars = [set() for _ in kernel.linearization]
+
+    for tv, just_alive_idx in temp_to_live_interval_start.items():
+        sched_idx_to_just_live_temp_vars[just_alive_idx].add(tv)
+
+    for tv, just_dead_idx in temp_to_live_interval_end.items():
+        sched_idx_to_just_dead_temp_vars[just_dead_idx].add(tv)
+
+    new_tvs = {}
+
+    for sched_idx, _ in enumerate(kernel.linearization):
+        just_dead_temps = sched_idx_to_just_dead_temp_vars[sched_idx]
+        to_be_allocated_temps = sched_idx_to_just_live_temp_vars[sched_idx]
+        for tv_name in sorted(just_dead_temps):
+            tv = new_tvs[tv_name]
+            assert tv.base_storage is not None
+            assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes]
+            shape_to_available_base_storage[tv.nbytes].add(tv.base_storage)
+
+        for tv_name in sorted(to_be_allocated_temps):
+            assert len(to_be_allocated_temps) <= 1
+            tv = kernel.temporary_variables[tv_name]
+            assert tv.name not in new_tvs
+            assert tv.base_storage is None
+            if shape_to_available_base_storage[tv.nbytes]:
+                base_storage = sorted(shape_to_available_base_storage[tv.nbytes])[0]
+                shape_to_available_base_storage[tv.nbytes].remove(base_storage)
+            else:
+                base_storage = vng("_msh_actx_tmp_base")
+
+            new_tvs[tv.name] = tv.copy(base_storage=base_storage)
+
+    for name, tv in kernel.temporary_variables.items():
+        if tv.address_space != AddressSpace.GLOBAL:
+            new_tvs[name] = tv
+        else:
+            # FIXME: Need tighter assertion condition (this doesn't work when
+            # zero-size arrays are present)
+            # assert name in new_tvs
+            pass
+
+    kernel = kernel.copy(temporary_variables=new_tvs)
+
+    return t_unit.with_kernel(kernel)
+
+
+def _can_be_eagerly_computed(ary) -> bool:
+    from pytato.transform import InputGatherer
+    from pytato.array import Placeholder
+    return all(not isinstance(inp, Placeholder)
+               for inp in InputGatherer()(ary))
+
+
+def deduplicate_data_wrappers(dag):
+    import pytato as pt
+    data_wrapper_cache = {}
+    data_wrappers_encountered = 0
+
+    def cached_data_wrapper_if_present(ary):
+        nonlocal data_wrappers_encountered
+
+        if isinstance(ary, pt.DataWrapper):
+
+            data_wrappers_encountered += 1
+            cache_key = (ary.data.base_data.int_ptr, ary.data.offset,
+                         ary.shape, ary.data.strides)
+            try:
+                result = data_wrapper_cache[cache_key]
+            except KeyError:
+                result = ary
+                data_wrapper_cache[cache_key] = result
+
+            return result
+        else:
+            return ary
+
+    dag = pt.transform.map_and_copy(dag, cached_data_wrapper_if_present)
+
+    if data_wrappers_encountered:
+        logger.info("data wrapper de-duplication: "
+                    "%d encountered, %d kept, %d eliminated",
+                    data_wrappers_encountered,
+                    len(data_wrapper_cache),
+                    data_wrappers_encountered - len(data_wrapper_cache))
+
+    return dag
+
+
+class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
+    """
+    A :class:`PytatoPyOpenCLArrayContext` that parallelizes work in an OpenCL
+    kernel so that the work is spread evenly across the work-groups of a
+    single kernel launch ("grid").
+    """
+    def transform_loopy_program(self, t_unit):
+        import loopy as lp
+
+        t_unit = _single_grid_work_group_transform(t_unit, self.queue.device)
+        t_unit = lp.set_options(t_unit, "insert_gbarriers")
+        t_unit = lp.linearize(lp.preprocess_kernel(t_unit))
+        t_unit = _alias_global_temporaries(t_unit)
+
+        return t_unit
+
+    def _get_fake_numpy_namespace(self):
+        from meshmode.pytato_utils import (
+            EagerReduceComputingPytatoFakeNumpyNamespace)
+        return EagerReduceComputingPytatoFakeNumpyNamespace(self)
+
+    def transform_dag(self, 
dag): + import pytato as pt + + # {{{ face_mass: materialize einsum args + + def materialize_face_mass_vec(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, "ifj,fej,fej->ei")): + mat, jac, vec = expr.args + return pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + else: + return expr + + dag = pt.transform.map_and_copy(dag, materialize_face_mass_vec) + + # }}} + + # {{{ materialize all einsums + + def materialize_einsums(ary: pt.Array) -> pt.Array: + if isinstance(ary, pt.Einsum): + return ary.tagged(pt.tags.ImplStored()) + + return ary + + dag = pt.transform.map_and_copy(dag, materialize_einsums) + + # }}} + + dag = pt.transform.materialize_with_mpms(dag) + dag = deduplicate_data_wrappers(dag) + + # {{{ /!\ Remove tags from Loopy call results. + # See + + def untag_loopy_call_results(expr): + from pytato.loopy import LoopyCallResult + if isinstance(expr, LoopyCallResult): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + return dag + + +def get_temps_not_to_contract(knl): + from functools import reduce + wmap = knl.writer_map() + rmap = knl.reader_map() + + temps_not_to_contract = set() + for tv in knl.temporary_variables: + if len(wmap.get(tv, set())) == 1: + writer_id, = wmap[tv] + writer_loop_nest = knl.id_to_insn[writer_id].within_inames + insns_in_writer_loop_nest = reduce(frozenset.union, + (knl.iname_to_insns()[iname] + for iname in writer_loop_nest), + frozenset()) + if ( + (not (rmap.get(tv, frozenset()) + <= insns_in_writer_loop_nest)) + or len(knl.id_to_insn[writer_id].reduction_inames()) != 0 + or any((len(knl.id_to_insn[reader_id].reduction_inames()) != 0) + for reader_id in rmap.get(tv, frozenset()))): + temps_not_to_contract.add(tv) + else: + temps_not_to_contract.add(tv) + return temps_not_to_contract + + # Better way to query it... 
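+    # (The commented-out relational sketch below encodes the same criterion:
+    # a temporary must stay materialized whenever its producer and one of
+    # its consumers sit in different loop nests, or a consumer is a
+    # reduction instruction.)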
+ # import loopy as lp + # from kanren.constraints import neq as kanren_neq + # + # tempo = lp.relations.get_tempo(knl) + # producero = lp.relations.get_producero(knl) + # consumero = lp.relations.get_consumero(knl) + # withino = lp.relations.get_withino(knl) + # reduce_insno = lp.relations.get_reduce_insno(knl) + # + # # temp_k: temporary variable that cannot be contracted + # temp_k = kanren.var() + # producer_insn_k = kanren.var() + # producer_loops_k = kanren.var() + # consumer_insn_k = kanren.var() + # consumer_loops_k = kanren.var() + + # temps_not_to_contract = kanren.run(0, + # temp_k, + # tempo(temp_k), + # producero(producer_insn_k, + # temp_k), + # consumero(consumer_insn_k, + # temp_k), + # withino(producer_insn_k, + # producer_loops_k), + # withino(consumer_insn_k, + # consumer_loops_k), + # kanren.lany( + # kanren_neq( + # producer_loops_k, + # consumer_loops_k), + # reduce_insno(consumer_insn_k)), + # results_filter=frozenset) + # return temps_not_to_contract + + +def _is_iel_loop_part_of_global_dof_loops(iel: str, knl) -> bool: + insn, = knl.iname_to_insns()[iel] + return any(iname + for iname in knl.id_to_insn[insn].within_inames + if knl.iname_tags_of_type(iname, DiscretizationDOFAxisTag)) + + +def _discr_entity_sort_key(discr_tag: DiscretizationEntityAxisTag + ) -> Tuple[Any, ...]: + + return type(discr_tag).__name__ + + +# {{{ define FEMEinsumTag + +@dataclass(frozen=True) +class EinsumIndex: + discr_entity: DiscretizationEntityAxisTag + length: int + + @classmethod + def from_iname(cls, iname, kernel): + discr_entity, = kernel.filter_iname_tags_by_type( + iname, DiscretizationEntityAxisTag) + length = kernel.get_constant_iname_length(iname) + return cls(discr_entity, length) + + +@dataclass(frozen=True) +class FreeEinsumIndex(EinsumIndex): + pass + + +@dataclass(frozen=True) +class SummationEinsumIndex(EinsumIndex): + pass + + +@dataclass(frozen=True) +class FEMEinsumTag(UniqueTag): + indices: Tuple[Tuple[EinsumIndex, ...], ...] 
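+# (Example, with hypothetical axis tags and lengths: for a mass-matrix-like
+# product "ij,ej->ei", an FEMEinsumTag would record one index tuple per
+# einsum argument, e.g.
+#     indices=((FreeEinsumIndex(dof_tag, 35), SummationEinsumIndex(dof_tag, 35)),
+#              (FreeEinsumIndex(el_tag, nel), SummationEinsumIndex(dof_tag, 35)))
+# )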
+
+
+class NotAnFEMEinsumError(ValueError):
+    """
+    Raised when an expression cannot be interpreted as an FEM einsum.
+    """
+
+# }}}
+
+
+@memoize_on_first_arg
+def _get_redn_iname_to_insns(kernel):
+    from immutables import Map
+    redn_iname_to_insns = {iname: set()
+                           for iname in kernel.all_inames()}
+
+    for insn in kernel.instructions:
+        for redn_iname in insn.reduction_inames():
+            redn_iname_to_insns[redn_iname].add(insn.id)
+
+    return Map({k: frozenset(v)
+                for k, v in redn_iname_to_insns.items()})
+
+
+def _do_inames_belong_to_different_einsum_types(iname1, iname2, kernel):
+    if kernel.iname_to_insns()[iname1]:
+        assert (len(kernel.iname_to_insns()[iname1])
+                == len(kernel.iname_to_insns()[iname2])
+                == 1)
+        insn1, = kernel.iname_to_insns()[iname1]
+        insn2, = kernel.iname_to_insns()[iname2]
+    else:
+        redn_iname_to_insns = _get_redn_iname_to_insns(kernel)
+        assert (len(redn_iname_to_insns[iname1])
+                == len(redn_iname_to_insns[iname2])
+                == 1)
+        insn1, = redn_iname_to_insns[iname1]
+        insn2, = redn_iname_to_insns[iname2]
+
+    var1_name, = kernel.id_to_insn[insn1].assignee_var_names()
+    var2_name, = kernel.id_to_insn[insn2].assignee_var_names()
+    var1 = kernel.get_var_descriptor(var1_name)
+    var2 = kernel.get_var_descriptor(var2_name)
+
+    ensm1, = var1.tags_of_type(FEMEinsumTag)
+    ensm2, = var2.tags_of_type(FEMEinsumTag)
+
+    return ensm1 != ensm2
+
+
+def _fuse_loops_over_a_discr_entity(knl,
+                                    mesh_entity,
+                                    fused_loop_prefix,
+                                    should_fuse_redn_loops,
+                                    orig_knl):
+    import loopy as lp
+    import kanren
+    from functools import reduce, partial
+    taggedo = lp.relations.get_taggedo_of_type(orig_knl, mesh_entity)
+
+    redn_loops = reduce(frozenset.union,
+                        (insn.reduction_inames()
+                         for insn in orig_knl.instructions),
+                        frozenset())
+
+    non_redn_loops = reduce(frozenset.union,
+                            (insn.within_inames
+                             for insn in orig_knl.instructions),
+                            frozenset())
+
+    # tag_k: tag of type 'mesh_entity'
+    tag_k = kanren.var()
+    tags = kanren.run(0,
+                      tag_k,
+                      taggedo(kanren.var(), tag_k),
+                      results_filter=frozenset)
+    for itag, tag in enumerate(
+            sorted(tags, key=lambda x: _discr_entity_sort_key(x))):
+        # iname_k: iname tagged with 'tag'
+        iname_k = kanren.var()
+        inames = kanren.run(0,
+                            iname_k,
+                            taggedo(iname_k, tag),
+                            results_filter=frozenset)
+        inames = frozenset(inames)
+        if should_fuse_redn_loops:
+            inames = inames & redn_loops
+        else:
+            inames = inames & non_redn_loops
+
+        length_to_inames = {}
+        for iname in inames:
+            length = knl.get_constant_iname_length(iname)
+            length_to_inames.setdefault(length, set()).add(iname)
+
+        for i, (_, inames_to_fuse) in enumerate(
+                sorted(length_to_inames.items())):
+
+            knl = lp.rename_inames_in_batch(
+                knl,
+                lp.get_kennedy_unweighted_fusion_candidates(
+                    knl, inames_to_fuse,
+                    prefix=f"{fused_loop_prefix}_{itag}_{i}_",
+                    force_infusible=partial(
+                        _do_inames_belong_to_different_einsum_types,
+                        kernel=orig_knl),
+                ))
+            knl = lp.tag_inames(knl, {f"{fused_loop_prefix}_{itag}_*": tag})
+
+    return knl
+
+
+@memoize_on_disk
+def fuse_same_discretization_entity_loops(knl):
+    # maintain an 'orig_knl' to keep the original inames and tags before
+    # transforming it.
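+    # (The calls below run in two rounds per entity type: first the
+    # non-reduction (within_inames) loops are fused, then the reduction
+    # loops; 'orig_knl' preserves the pre-fusion inames and tags that the
+    # relational queries in _fuse_loops_over_a_discr_entity rely on.)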
+ orig_knl = knl + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag, + "iface", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationElementAxisTag, + "iel", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag, + "idof", + False, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag, + "idim", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag, + "iface", + True, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag, + "idof", + True, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag, + "idim", + True, + orig_knl) + + return knl + + +@memoize_on_disk +def contract_arrays(knl, callables_table): + import loopy as lp + from loopy.transform.precompute import precompute_for_single_kernel + + temps_not_to_contract = get_temps_not_to_contract(knl) + all_temps = frozenset(knl.temporary_variables) + + logger.info("Array Contraction: Contracting " + f"{len(all_temps-frozenset(temps_not_to_contract))} temps") + + wmap = knl.writer_map() + + for temp in sorted(all_temps - frozenset(temps_not_to_contract)): + writer_id, = wmap[temp] + rmap = knl.reader_map() + ensm_tag, = knl.id_to_insn[writer_id].tags_of_type(EinsumTag) + + knl = lp.assignment_to_subst(knl, temp, + remove_newly_unused_inames=False) + if temp not in rmap: + # no one was reading 'temp' i.e. dead code got eliminated :) + assert f"{temp}_subst" not in knl.substitutions + continue + knl = precompute_for_single_kernel( + knl, callables_table, f"{temp}_subst", + sweep_inames=(), + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id=f"_mm_contract_{temp}", + ) + + knl = lp.map_instructions(knl, + f"id:_mm_contract_{temp}", + lambda x: x.tagged(ensm_tag)) + + return lp.remove_unused_inames(knl) + + +def _get_group_size_for_dof_array_loop(nunit_dofs): + """ + Returns the OpenCL workgroup size for a loop iterating over the global DOFs + of a discretization with *nunit_dofs* per cell. + """ + if nunit_dofs == {6}: + return 16, 6 + elif nunit_dofs == {10}: + return 16, 10 + elif nunit_dofs == {20}: + return 16, 10 + elif nunit_dofs == {1}: + return 32, 1 + elif nunit_dofs == {2}: + return 32, 2 + elif nunit_dofs == {4}: + return 16, 4 + elif nunit_dofs == {3}: + return 32, 3 + elif nunit_dofs == {35}: + return 9, 7 + elif nunit_dofs == {15}: + return 8, 8 + else: + raise NotImplementedError(nunit_dofs) + + +def _get_iel_to_idofs(kernel): + iel_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type((DiscretizationElementAxisTag, + DiscretizationFlattenedDOFAxisTag))) + } + idof_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationDOFAxisTag)) + } + iface_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationFaceAxisTag)) + } + idim_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationDimAxisTag)) + } + + iel_to_idofs = {iel: set() for iel in iel_inames} + + for insn in kernel.instructions: + if (len(insn.within_inames) == 1 + and (insn.within_inames) <= iel_inames): + iel, = insn.within_inames + if all(kernel.id_to_insn[el_insn].within_inames == insn.within_inames + for el_insn in kernel.iname_to_insns()[iel]): + # the iel here doesn't interfere with any idof i.e. we + # support parallelizing such loops. 
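+                # (e.g., hypothetically, the returned map could contain
+                # "iel_0" -> frozenset() for a purely element-wise loop and
+                # "iel_1" -> frozenset({"idof_1"}) for an (element, DOF)
+                # nest.)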
+                pass
+            else:
+                raise NotImplementedError(f"The loop {insn.within_inames}"
+                                          " does not appear as a singly nested"
+                                          " loop.")
+        elif ((len(insn.within_inames) == 2)
+                and (len(insn.within_inames & iel_inames) == 1)
+                and (len(insn.within_inames & idof_inames) == 1)):
+            iel, = insn.within_inames & iel_inames
+            idof, = insn.within_inames & idof_inames
+            iel_to_idofs[iel].add(idof)
+            if all((iel in kernel.id_to_insn[dof_insn].within_inames)
+                   for dof_insn in kernel.iname_to_insns()[idof]):
+                pass
+            else:
+                raise NotImplementedError("The loop "
+                                          f"'{insn.within_inames}' has an"
+                                          " idof-loop that is not nested"
+                                          " within the iel-loop.")
+        elif ((len(insn.within_inames) > 2)
+                and (len(insn.within_inames & iel_inames) == 1)
+                and (len(insn.within_inames & idof_inames) == 1)
+                and (len(insn.within_inames & (idim_inames | iface_inames))
+                     == (len(insn.within_inames) - 2))):
+            iel, = insn.within_inames & iel_inames
+            idof, = insn.within_inames & idof_inames
+            iel_to_idofs[iel].add(idof)
+            if all((all({iel, idof} <= kernel.id_to_insn[non_iel_insn].within_inames
+                        for non_iel_insn in kernel.iname_to_insns()[non_iel_iname]))
+                   for non_iel_iname in insn.within_inames - {iel}):
+                pass
+            else:
+                raise NotImplementedError("Could not fit into a known"
+                                          " loop nest pattern.")
+        else:
+            raise NotImplementedError(f"Cannot fit loop nest '{insn.within_inames}'"
+                                      " into known set of loop-nest patterns.")
+
+    return pmap({iel: frozenset(idofs)
+                 for iel, idofs in iel_to_idofs.items()})
+
+
+def _get_iel_loop_from_insn(insn, knl):
+    iel, = {iname
+            for iname in insn.within_inames
+            if knl.inames[iname].tags_of_type((DiscretizationElementAxisTag,
+                                               DiscretizationFlattenedDOFAxisTag))}
+    return iel
+
+
+def _get_element_loop_topo_sorted_order(knl):
+    dag = {iel: set()
+           for iel in knl.all_inames()
+           if knl.inames[iel].tags_of_type(DiscretizationElementAxisTag)}
+
+    for insn in knl.instructions:
+        succ_iel = _get_iel_loop_from_insn(insn, knl)
+        for dep_id in insn.depends_on:
+            pred_iel = _get_iel_loop_from_insn(knl.id_to_insn[dep_id], knl)
+            if pred_iel != succ_iel:
+                dag[pred_iel].add(succ_iel)
+
+    from pytools.graph import compute_topological_order
+    return compute_topological_order(dag, key=lambda x: x)
+
+
+@tag_dataclass
+class EinsumTag(UniqueTag):
+    orig_loop_nest: FrozenSet[str]
+
+
+def _prepare_kernel_for_parallelization(kernel):
+    discr_tag_to_prefix = {DiscretizationElementAxisTag: "iel",
+                           DiscretizationDOFAxisTag: "idof",
+                           DiscretizationDimAxisTag: "idim",
+                           DiscretizationAmbientDimAxisTag: "idim",
+                           DiscretizationTopologicalDimAxisTag: "idim",
+                           DiscretizationFlattenedDOFAxisTag: "imsh_nodes",
+                           DiscretizationFaceAxisTag: "iface"}
+    import loopy as lp
+    from loopy.match import ObjTagged
+
+    # A mapping from the loop nest (the sorted tuple of inames an
+    # instruction accesses) to the ids of the instructions within that
+    # loop nest.
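+    # (e.g., hypothetically, {("idof_3", "iel_3"): {"insn_a", "insn_b"}}
+    # buckets the two instructions sharing that loop nest into one einsum,
+    # which then receives its own duplicated inames such as
+    # iel_ensm0/idof_ensm0 below.)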
+ ensm_buckets = {} + vng = kernel.get_var_name_generator() + + for insn in kernel.instructions: + inames = insn.within_inames | insn.reduction_inames() + ensm_buckets.setdefault(tuple(sorted(inames)), set()).add(insn.id) + + # FIXME: Dependency violation is a big concern here + # Waiting on the loopy feature: https://github.com/inducer/loopy/issues/550 + + for ieinsm, (loop_nest, insns) in enumerate(sorted(ensm_buckets.items())): + new_insns = [insn.tagged(EinsumTag(frozenset(loop_nest))) + if insn.id in insns + else insn + for insn in kernel.instructions] + kernel = kernel.copy(instructions=new_insns) + + new_inames = [] + for iname in loop_nest: + discr_tag, = kernel.iname_tags_of_type(iname, + DiscretizationEntityAxisTag) + new_iname = vng(f"{discr_tag_to_prefix[type(discr_tag)]}_ensm{ieinsm}") + new_inames.append(new_iname) + + kernel = lp.duplicate_inames( + kernel, + loop_nest, + within=ObjTagged(EinsumTag(frozenset(loop_nest))), + new_inames=new_inames, + tags=kernel.iname_to_tags) + + return kernel + + +def _get_elementwise_einsum(t_unit, einsum_tag): + import loopy as lp + import feinsum as fnsm + from loopy.match import ObjTagged + from pymbolic.primitives import Variable, Subscript + + kernel = t_unit.default_entrypoint + + assert isinstance(einsum_tag, EinsumTag) + insn_match = ObjTagged(einsum_tag) + + global_vars = ({tv.name + for tv in kernel.temporary_variables.values() + if tv.address_space == lp.AddressSpace.GLOBAL} + | set(kernel.arg_dict.keys())) + insns = [insn + for insn in kernel.instructions + if insn_match(kernel, insn)] + idx_tuples = set() + + for insn in insns: + assert len(insn.assignees) == 1 + if isinstance(insn.assignee, Variable): + if insn.assignee.name in global_vars: + raise NotImplementedError(insn) + else: + assert (kernel.temporary_variables[insn.assignee.name].address_space + == lp.AddressSpace.PRIVATE) + elif isinstance(insn.assignee, Subscript): + assert insn.assignee_name in global_vars + idx_tuples.add(tuple(idx.name + for idx in insn.assignee.index_tuple)) + else: + raise NotImplementedError(insn) + + if len(idx_tuples) != 1: + raise NotImplementedError("Multiple einsums in the same loop nest =>" + " not allowed.") + idx_tuple, = idx_tuples + subscript = "{lhs}, {lhs}->{lhs}".format( + lhs="".join(chr(97+i) + for i in range(len(idx_tuple)))) + arg_shape = tuple(np.inf + if kernel.iname_tags_of_type(idx, DiscretizationElementAxisTag) + else kernel.get_constant_iname_length(idx) + for idx in idx_tuple) + return fnsm.einsum(subscript, + fnsm.array(arg_shape, "float64"), + fnsm.array(arg_shape, "float64")) + + +def _combine_einsum_domains(knl): + import islpy as isl + from functools import reduce + + new_domains = [] + einsum_tags = reduce( + frozenset.union, + (insn.tags_of_type(EinsumTag) + for insn in knl.instructions), + frozenset()) + + for tag in sorted(einsum_tags, + key=lambda x: sorted(x.orig_loop_nest)): + insns = [insn + for insn in knl.instructions + if tag in insn.tags] + inames = reduce(frozenset.union, + ((insn.within_inames | insn.reduction_inames()) + for insn in insns), + frozenset()) + domain = knl.get_inames_domain(frozenset(inames)) + new_domains.append(domain.project_out_except(sorted(inames), + [isl.dim_type.set])) + + return knl.copy(domains=new_domains) + + +def _rewrite_tvs_as_base_plus_offset(t_unit, device): + import loopy as lp + knl = t_unit.default_entrypoint + vng = knl.get_var_name_generator() + nbytes_to_base_storages = {} + for tv in knl.temporary_variables.values(): + if tv.address_space == 
lp.AddressSpace.GLOBAL: + nbytes_to_base_storages.setdefault(tv.nbytes, + set()).add(tv.base_storage) + + nbytes_to_new_storage_name = {nbytes: vng("_mm_base_storage") + for nbytes in sorted(nbytes_to_base_storages)} + + if any(nbytes > device.max_mem_alloc_size + for nbytes in nbytes_to_new_storage_name): + raise RuntimeError("Some of the variables " + "require more memory than the CL-device " + "allows.") + + old_storage_to_new_storage_plus_offset = {} + new_storage_to_alloc_nbytes = {} + for nbytes, old_storages in nbytes_to_base_storages.items(): + new_storage_name = nbytes_to_new_storage_name[nbytes] + offset = 0 + new_storage_to_alloc_nbytes[new_storage_name] = offset + for old_storage in sorted(old_storages): + assert (offset + nbytes) < device.max_mem_alloc_size + old_storage_to_new_storage_plus_offset[old_storage] = ( + (new_storage_name, offset)) + offset = offset + nbytes + new_storage_to_alloc_nbytes[new_storage_name] = offset + if (offset + nbytes) > device.max_mem_alloc_size: + new_storage_name = vng("_mm_base_storage") + offset = 0 + + del nbytes_to_new_storage_name + + new_tvs = {} + for name, tv in knl.temporary_variables.items(): + if tv.address_space == lp.AddressSpace.GLOBAL: + new_storage_name, offset_nbytes = ( + old_storage_to_new_storage_plus_offset[tv.base_storage]) + new_storage_size = ( + new_storage_to_alloc_nbytes[new_storage_name] + // tv.dtype.numpy_dtype.itemsize) + tv = tv.copy(base_storage=new_storage_name, + offset=offset_nbytes//tv.dtype.numpy_dtype.itemsize, + storage_shape=(new_storage_size,) + (1,)*(len(tv.shape)-1) + ) + + new_tvs[name] = tv + + knl = knl.copy(temporary_variables=new_tvs) + return t_unit.with_kernel(knl) + + +class FusionContractorArrayContext( + SingleGridWorkBalancingPytatoArrayContext): + + def transform_dag(self, dag): + import pytato as pt + + # {{{ Remove FEMEinsumTags that might have been propagated + + # TODO: Is this too hacky? 
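+        # (FEMEinsumTags attached during an earlier trace can end up
+        # propagated onto the inputs of this one; stale tags would mislead
+        # add_fem_einsum_tags further down, so they are stripped from input
+        # arguments here.)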
+ + def remove_fem_einsum_tags(expr): + if isinstance(expr, pt.Array): + try: + fem_ensm_tag = next(iter(expr.tags_of_type(FEMEinsumTag))) + except StopIteration: + return expr + else: + assert isinstance(expr, pt.InputArgumentBase) + return expr.without_tags(fem_ensm_tag) + else: + return expr + + dag = pt.transform.map_and_copy(dag, remove_fem_einsum_tags) + + # }}} + + # {{{ CSE + + with ProcessLogger(logger, "transform_dag.mpms_materialization"): + dag = pt.transform.materialize_with_mpms(dag) + + def mark_materialized_nodes_as_cse( + ary: Union[pt.Array, + pt.AbstractResultWithNamedArrays]) -> pt.Array: + if isinstance(ary, pt.AbstractResultWithNamedArrays): + return ary + + if ary.tags_of_type(pt.tags.ImplStored): + return ary.tagged(pt.tags.PrefixNamed("cse")) + else: + return ary + + with ProcessLogger(logger, "transform_dag.naming_cse"): + dag = pt.transform.map_and_copy(dag, mark_materialized_nodes_as_cse) + + # }}} + + # {{{ indirect addressing are non-negative + + indirection_maps = set() + + class _IndirectionMapRecorder(pt.transform.CachedWalkMapper): + def post_visit(self, expr): + if isinstance(expr, pt.IndexBase): + for idx in expr.indices: + if isinstance(idx, pt.Array): + indirection_maps.add(idx) + + _IndirectionMapRecorder()(dag) + + def tag_indices_as_non_negative(ary): + if ary in indirection_maps: + return ary.tagged(pt.tags.AssumeNonNegative()) + else: + return ary + + with ProcessLogger(logger, "transform_dag.tag_indices_as_non_negative"): + dag = pt.transform.map_and_copy(dag, tag_indices_as_non_negative) + + # }}} + + with ProcessLogger(logger, "transform_dag.deduplicate_data_wrappers"): + dag = pt.transform.deduplicate_data_wrappers(dag) + + # {{{ get rid of copies for different views of a cl-array + + def eliminate_reshapes_of_data_wrappers(ary): + if (isinstance(ary, pt.Reshape) + and isinstance(ary.array, pt.DataWrapper)): + return pt.make_data_wrapper(ary.array.data.reshape(ary.shape), + tags=ary.tags, + axes=ary.axes) + else: + return ary + + dag = pt.transform.map_and_copy(dag, + eliminate_reshapes_of_data_wrappers) + + # }}} + + # {{{ face_mass: materialize einsum args + + def materialize_face_mass_input_and_output(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ifj,fej,fej->ei")): + mat, jac, vec = expr.args + return (pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + .tagged((pt.tags.ImplStored(), + pt.tags.PrefixNamed("face_mass")))) + else: + return expr + + with ProcessLogger(logger, + "transform_dag.materialize_face_mass_ins_and_outs"): + dag = pt.transform.map_and_copy(dag, + materialize_face_mass_input_and_output) + + # }}} + + # {{{ materialize inverse mass inputs + + def materialize_inverse_mass_inputs(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ei,ij,ej->ei")): + arg1, arg2, arg3 = expr.args + if not arg3.tags_of_type(pt.tags.PrefixNamed): + arg3 = arg3.tagged(pt.tags.PrefixNamed("mass_inv_inp")) + if not arg3.tags_of_type(pt.tags.ImplStored): + arg3 = arg3.tagged(pt.tags.ImplStored()) + + return pt.Einsum(expr.access_descriptors, + (arg1, arg2, arg3), + expr.axes, + expr.redn_axis_to_redn_descr, + expr.index_to_access_descr, + expr.tags) + else: + return expr + + dag = pt.transform.map_and_copy(dag, materialize_inverse_mass_inputs) + + # }}} + + # {{{ materialize all einsums + + def materialize_all_einsums_or_reduces(expr): + from pytato.raising import (index_lambda_to_high_level_op, + ReduceOp) + + if 
isinstance(expr, pt.Einsum): + return expr.tagged(pt.tags.ImplStored()) + elif (isinstance(expr, pt.IndexLambda) + and isinstance(index_lambda_to_high_level_op(expr), ReduceOp)): + return expr.tagged(pt.tags.ImplStored()) + else: + return expr + + with ProcessLogger(logger, + "transform_dag.materialize_all_einsums_or_reduces"): + dag = pt.transform.map_and_copy(dag, materialize_all_einsums_or_reduces) + + # }}} + + # {{{ infer axis types + + from meshmode.pytato_utils import unify_discretization_entity_tags + + with ProcessLogger(logger, "transform_dag.infer_axes_tags"): + dag = unify_discretization_entity_tags(dag) + + # }}} + + # {{{ /!\ Remove tags from Loopy call results. + # See + + def untag_loopy_call_results(expr): + from pytato.loopy import LoopyCallResult + if isinstance(expr, LoopyCallResult): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + # {{{ remove broadcasts from einsums: help feinsum + + ensm_arg_rewrite_cache = {} + + def _get_rid_of_broadcasts_from_einsum(expr): + # Helpful for matching against the available expressions + # in feinsum. + + from pytato.utils import (are_shape_components_equal, + are_shapes_equal) + if isinstance(expr, pt.Einsum): + from pytato.array import EinsumElementwiseAxis + idx_to_len = expr._access_descr_to_axis_len() + new_access_descriptors = [] + new_args = [] + inp_gatherer = pt.transform.InputGatherer() + access_descr_to_axes = dict(expr.redn_axis_to_redn_descr) + for iax, axis in enumerate(expr.axes): + access_descr_to_axes[EinsumElementwiseAxis(iax)] = axis + + for access_descrs, arg in zip(expr.access_descriptors, + expr.args): + new_shape = [] + new_access_descrs = [] + new_axes = [] + for iaxis, (access_descr, axis_len) in enumerate( + zip(access_descrs, + arg.shape)): + if not are_shape_components_equal(axis_len, + idx_to_len[access_descr]): + assert are_shape_components_equal(axis_len, 1) + if any(isinstance(inp, pt.Placeholder) + for inp in inp_gatherer(arg)): + # do not get rid of broadcasts from parameteric + # data. 
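+                            # (i.e. keep the unit-length axis: placeholder
+                            # data is only bound at run time, so the
+                            # broadcast cannot be constant-folded by the
+                            # freeze/thaw round-trip applied to frozen data
+                            # below.)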
+ new_shape.append(axis_len) + new_access_descrs.append(access_descr) + new_axes.append(arg.axes[iaxis]) + else: + new_axes.append(arg.axes[iaxis]) + new_shape.append(axis_len) + new_access_descrs.append(access_descr) + + if not are_shapes_equal(new_shape, arg.shape): + assert len(new_axes) == len(new_shape) + arg_to_freeze = (arg.reshape(new_shape) + .copy(axes=tuple( + access_descr_to_axes[acc_descr] + for acc_descr in new_access_descrs))) + + try: + new_arg = ensm_arg_rewrite_cache[arg_to_freeze] + except KeyError: + new_arg = self.thaw(self.freeze(arg_to_freeze)) + ensm_arg_rewrite_cache[arg_to_freeze] = new_arg + + arg = new_arg + + assert arg.ndim == len(new_access_descrs) + new_args.append(arg) + new_access_descriptors.append(tuple(new_access_descrs)) + + return pt.Einsum(tuple(new_access_descriptors), + tuple(new_args), + tags=expr.tags, + axes=expr.axes, + redn_axis_to_redn_descr=(expr + .redn_axis_to_redn_descr), + index_to_access_descr=expr.index_to_access_descr) + else: + return expr + + dag = pt.transform.map_and_copy(dag, _get_rid_of_broadcasts_from_einsum) + + # }}} + + # {{{ remove any PartID tags + + from pytato.distributed import PartIDTag + + def remove_part_id_tags(expr): + if isinstance(expr, pt.Array) and expr.tags_of_type(PartIDTag): + tag, = expr.tags_of_type(PartIDTag) + return expr.without_tags(tag) + else: + return expr + + dag = pt.transform.map_and_copy(dag, remove_part_id_tags) + + # }}} + + # {{{ attach FEMEinsumTag tags + + dag_outputs = frozenset(dag._data.values()) + + def add_fem_einsum_tags(expr): + if isinstance(expr, pt.Einsum): + from pytato.array import (EinsumElementwiseAxis, + EinsumReductionAxis) + assert expr.tags_of_type(pt.tags.ImplStored) + ensm_indices = [] + for arg, access_descrs in zip(expr.args, + expr.access_descriptors): + arg_indices = [] + for iaxis, access_descr in enumerate(access_descrs): + try: + discr_tag = next( + iter(arg + .axes[iaxis] + .tags_of_type(DiscretizationEntityAxisTag))) + except StopIteration: + raise NotAnFEMEinsumError(expr) + else: + if isinstance(access_descr, EinsumElementwiseAxis): + arg_indices.append(FreeEinsumIndex(discr_tag, + arg.shape[iaxis])) + elif isinstance(access_descr, EinsumReductionAxis): + arg_indices.append(SummationEinsumIndex( + discr_tag, + arg.shape[iaxis])) + else: + raise NotImplementedError(access_descr) + ensm_indices.append(tuple(arg_indices)) + + return expr.tagged(FEMEinsumTag(tuple(ensm_indices))) + elif (isinstance(expr, pt.Array) + and (expr.tags_of_type(pt.tags.ImplStored) + or expr in dag_outputs)): + if (isinstance(expr, pt.IndexLambda) + and expr.var_to_reduction_descr): + raise NotImplementedError("pure reductions not implemented") + else: + discr_tags = [] + for axis in expr.axes: + try: + discr_tag = next( + iter(axis.tags_of_type(DiscretizationEntityAxisTag))) + except StopIteration: + raise NotAnFEMEinsumError(expr) + else: + discr_tags.append(discr_tag) + + fem_ensm_tag = FEMEinsumTag( + (tuple(FreeEinsumIndex(discr_tag, dim) + for dim, discr_tag in zip(expr.shape, + discr_tags)),) * 2 + ) + + return expr.tagged(fem_ensm_tag) + + else: + return expr + + try: + dag = pt.transform.map_and_copy(dag, add_fem_einsum_tags) + except NotAnFEMEinsumError: + pass + + # }}} + + # {{{ untag outputs tagged from being tagged ImplStored + + def _untag_impl_stored(expr): + if isinstance(expr, pt.InputArgumentBase): + return expr + else: + return expr.without_tags(pt.tags.ImplStored(), + verify_existence=False) + + dag = pt.make_dict_of_named_arrays({ + name: 
_untag_impl_stored(named_ary.expr) + for name, named_ary in dag.items()}) + + # }}} + + return dag + + def transform_loopy_program(self, t_unit): + import loopy as lp + from functools import reduce + from arraycontext.impl.pytato.compile import FromArrayContextCompile + + original_t_unit = t_unit + + # from loopy.transform.instruction import simplify_indices + # t_unit = simplify_indices(t_unit) + + knl = t_unit.default_entrypoint + + logger.info(f"Transforming kernel with {len(knl.instructions)} statements.") + + # {{{ fallback: if the inames are not inferred which mesh entity they + # iterate over. + + for iname in knl.all_inames(): + if not knl.iname_tags_of_type(iname, DiscretizationEntityAxisTag): + warn("Falling back to a slower transformation strategy as some" + " loops are uninferred which mesh entity they belong to.", + stacklevel=2) + + return super().transform_loopy_program(original_t_unit) + + for insn in knl.instructions: + for assignee in insn.assignee_var_names(): + var = knl.get_var_descriptor(assignee) + if not var.tags_of_type(FEMEinsumTag): + warn("Falling back to a slower transformation strategy as some" + " instructions couldn't be inferred as einsums", + stacklevel=2) + + return super().transform_loopy_program(original_t_unit) + + # }}} + + # {{{ hardcode offset to 0 (sorry humanity) + + knl = knl.copy(args=[arg.copy(offset=0) + for arg in knl.args]) + + # }}} + + # {{{ loop fusion + + with ProcessLogger(logger, "Loop Fusion"): + knl = fuse_same_discretization_entity_loops(knl) + + # }}} + + # {{{ align kernels for fused einsums + + knl = _prepare_kernel_for_parallelization(knl) + knl = _combine_einsum_domains(knl) + + # }}} + + # {{{ array contraction + + with ProcessLogger(logger, "Array Contraction"): + knl = contract_arrays(knl, t_unit.callables_table) + + # }}} + + # {{{ Stats Collection (Disabled) + + if 0: + with ProcessLogger(logger, "Counting Kernel Ops"): + from loopy.kernel.array import ArrayBase + from pytools import product + knl = knl.copy( + silenced_warnings=(knl.silenced_warnings + + ["insn_count_subgroups_upper_bound", + "summing_if_branches_ops"])) + + t_unit = t_unit.with_kernel(knl) + + op_map = lp.get_op_map(t_unit, subgroup_size=32) + + c64_ops = {op_type: (op_map.filter_by(dtype=[np.complex64], + name=op_type, + kernel_name=knl.name) + .eval_and_sum({})) + for op_type in ["add", "mul", "div"]} + c128_ops = {op_type: (op_map.filter_by(dtype=[np.complex128], + name=op_type, + kernel_name=knl.name) + .eval_and_sum({})) + for op_type in ["add", "mul", "div"]} + f32_ops = ((op_map.filter_by(dtype=[np.float32], + kernel_name=knl.name) + .eval_and_sum({})) + + (2 * c64_ops["add"] + + 6 * c64_ops["mul"] + + (6 + 3 + 2) * c64_ops["div"])) + f64_ops = ((op_map.filter_by(dtype=[np.float64], + kernel_name="_pt_kernel") + .eval_and_sum({})) + + (2 * c128_ops["add"] + + 6 * c128_ops["mul"] + + (6 + 3 + 2) * c128_ops["div"])) + + # {{{ footprint gathering + + nfootprint_bytes = 0 + + for ary in knl.args: + if (isinstance(ary, ArrayBase) + and ary.address_space == lp.AddressSpace.GLOBAL): + nfootprint_bytes += (product(ary.shape) + * ary.dtype.itemsize) + + for ary in knl.temporary_variables.values(): + if ary.address_space == lp.AddressSpace.GLOBAL: + # global temps would be written once and read once + nfootprint_bytes += (2 * product(ary.shape) + * ary.dtype.itemsize) + + # }}} + + if f32_ops: + logger.info(f"Single-prec. GFlOps: {f32_ops * 1e-9}") + if f64_ops: + logger.info(f"Double-prec. 
GFlOps: {f64_ops * 1e-9}") + logger.info(f"Footprint GBs: {nfootprint_bytes * 1e-9}") + + # }}} + + # {{{ check whether we can parallelize the kernel + + try: + iel_to_idofs = _get_iel_to_idofs(knl) + except NotImplementedError as err: + if knl.tags_of_type(FromArrayContextCompile): + raise err + else: + warn("FusionContractorArrayContext.transform_loopy_program not" + " broad enough (yet). Falling back to a possibly slower" + " transformation strategy.") + return super().transform_loopy_program(original_t_unit) + + # }}} + + # {{{ insert barriers between consecutive iel-loops + + toposorted_iels = _get_element_loop_topo_sorted_order(knl) + + for iel_pred, iel_succ in zip(toposorted_iels[:-1], + toposorted_iels[1:]): + knl = lp.add_barrier(knl, + insn_before=f"iname:{iel_pred}", + insn_after=f"iname:{iel_succ}") + + # }}} + + # {{{ Parallelization strategy: Use feinsum + + t_unit = t_unit.with_kernel(knl) + del knl + + if False and t_unit.default_entrypoint.tags_of_type(FromArrayContextCompile): + # FIXME: Enable this branch, WIP for now and hence disabled it. + from loopy.match import ObjTagged + import feinsum as fnsm + from meshmode.feinsum_transformations import FEINSUM_TO_TRANSFORMS + + assert all(insn.tags_of_type(EinsumTag) + for insn in t_unit.default_entrypoint.instructions + if isinstance(insn, lp.MultiAssignmentBase) + ) + + einsum_tags = reduce( + frozenset.union, + (insn.tags_of_type(EinsumTag) + for insn in t_unit.default_entrypoint.instructions), + frozenset()) + for ensm_tag in sorted(einsum_tags, + key=lambda x: sorted(x.orig_loop_nest)): + if reduce(frozenset.union, + (insn.reduction_inames() + for insn in (t_unit.default_entrypoint.instructions) + if ensm_tag in insn.tags), + frozenset()): + fused_einsum = fnsm.match_einsum(t_unit, ObjTagged(ensm_tag)) + else: + # elementwise loop + fused_einsum = _get_elementwise_einsum(t_unit, ensm_tag) + + try: + fnsm_transform = FEINSUM_TO_TRANSFORMS[ + fnsm.normalize_einsum(fused_einsum)] + except KeyError: + fnsm.query(fused_einsum, + self.queue.context, + err_if_no_results=True) + 1/0 + + t_unit = fnsm_transform(t_unit, + insn_match=ObjTagged(ensm_tag)) + else: + knl = t_unit.default_entrypoint + for iel, idofs in sorted(iel_to_idofs.items()): + if idofs: + nunit_dofs = {knl.get_constant_iname_length(idof) + for idof in idofs} + idof, = idofs + + l_one_size, l_zero_size = _get_group_size_for_dof_array_loop( + nunit_dofs) + + knl = lp.split_iname(knl, iel, l_one_size, + inner_tag="l.1", outer_tag="g.0") + knl = lp.split_iname(knl, idof, l_zero_size, + inner_tag="l.0", outer_tag="unr") + else: + knl = lp.split_iname(knl, iel, 32, + outer_tag="g.0", inner_tag="l.0") + + t_unit = t_unit.with_kernel(knl) + + # }}} + + t_unit = lp.linearize(lp.preprocess_kernel(t_unit)) + t_unit = _alias_global_temporaries(t_unit) + t_unit = _rewrite_tvs_as_base_plus_offset(t_unit, self.queue.device) + + return t_unit + # vim: foldmethod=marker diff --git a/meshmode/discretization/__init__.py b/meshmode/discretization/__init__.py index ff1d19e27..386ea50ca 100644 --- a/meshmode/discretization/__init__.py +++ b/meshmode/discretization/__init__.py @@ -40,9 +40,9 @@ from pytools import memoize_in, memoize_method, keyed_memoize_in from pytools.obj_array import make_obj_array from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag, - FirstAxisIsElementsTag, DiscretizationElementAxisTag, - DiscretizationDOFAxisTag) + ConcurrentElementInameTag, ConcurrentDOFInameTag, FirstAxisIsElementsTag, + IsDOFArray, IsOpArray, 
EinsumArgsTags, + DiscretizationElementAxisTag, DiscretizationDOFAxisTag) # underscored because it shouldn't be imported from here. from meshmode.dof_array import DOFArray as _DOFArray @@ -612,6 +612,10 @@ def prg(): t_unit = make_loopy_program( "{[iel,idof]: 0<=ielei", actx.tag_axis( 0, DiscretizationDOFAxisTag(), actx.from_numpy(grp.from_mesh_interp_matrix())), nodes, - tagged=( - FirstAxisIsElementsTag(), - NameHint(name_hint))) + tagged=(FirstAxisIsElementsTag(), NameHint(name_hint), kd_tag,)) result = make_obj_array([ _DOFArray(None, tuple([ @@ -730,15 +735,19 @@ def get_mat(grp, gref_axes): return actx.from_numpy(mat) + kd_tag = EinsumArgsTags({"arg0": (IsOpArray(),), + "arg1": (IsDOFArray(),), "out": (IsDOFArray(),)}) + return _DOFArray(actx, tuple( actx.einsum("ij,ej->ei", actx.tag_axis(0, DiscretizationDOFAxisTag(), get_mat(grp, ref_axes)), vec[igrp], - tagged=(FirstAxisIsElementsTag(),)) + tagged=(FirstAxisIsElementsTag(), kd_tag,)) for igrp, grp in enumerate(discr.groups))) + # }}} # vim: fdm=marker diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index e215bad56..7d5a4889b 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -30,8 +30,9 @@ import loopy as lp from meshmode.transform_metadata import ( ConcurrentElementInameTag, ConcurrentDOFInameTag, + IsDOFArray, ParameterValue, DiscretizationElementAxisTag, DiscretizationDOFAxisTag) -from pytools import memoize_in, keyed_memoize_method +from pytools import memoize_in, keyed_memoize_method, keyed_memoize_in from arraycontext import ( ArrayContext, ArrayT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container, make_loopy_program, @@ -634,7 +635,7 @@ def batch_mat_knl(): * ary[from_element_indices[iel], jdof]) if from_el_present[iel] else 0) """, - [ + kernel_data=[ lp.GlobalArg("ary", None, shape="nelements_vec, nunit_dofs_src", offset=lp.auto), @@ -660,7 +661,7 @@ def batch_pick_knl(): ary[from_element_indices[iel], pick_list[idof]] if from_el_present[iel] else 0) """, - [ + kernel_data=[ lp.GlobalArg("ary", None, shape="nelements_vec, nunit_dofs_src", offset=lp.auto), @@ -676,8 +677,10 @@ def batch_pick_knl(): @memoize_in(actx, (DirectDiscretizationConnection, "resample_by_picking_group_knl")) - def group_pick_knl(is_surjective: bool): - + def group_pick_knl(nelements, nelements_src, nunit_dofs_src, + nelements_tgt, nunit_dofs_tgt, result_dtype, ary_dtype, + from_el_ind_dtype, pick_lists_dtype, pick_list_ind_dtype, + is_surjective: bool): if is_surjective: if_present = "" else: @@ -693,27 +696,181 @@ def group_pick_knl(is_surjective: bool): ary[ from_element_indices[iel], dof_pick_lists[dof_pick_list_indices[iel], idof] - ] + ] \ { if_present }) """, [ - lp.GlobalArg("ary", None, + lp.GlobalArg("result", result_dtype, + shape="nelements, nunit_dofs_tgt", + offset=lp.auto, tags=[IsDOFArray()]), + lp.GlobalArg("ary", ary_dtype, shape="nelements_src, nunit_dofs_src", - offset=lp.auto), - lp.GlobalArg("dof_pick_lists", None, + offset=lp.auto, tags=[IsDOFArray()]), + lp.GlobalArg("dof_pick_lists", pick_lists_dtype, shape="nelements_tgt, nunit_dofs_tgt", - offset=lp.auto), - lp.ValueArg("nelements_tgt", np.int32), - lp.ValueArg("nelements_src", np.int32), - lp.ValueArg("nunit_dofs_src", np.int32), + offset=0), + lp.GlobalArg("from_element_indices", from_el_ind_dtype, + shape="nelements,", offset=lp.auto), + lp.GlobalArg("dof_pick_list_indices", pick_list_ind_dtype, + shape="nelements,", 
offset=lp.auto), + lp.ValueArg("nelements_tgt", np.int32, + tags=[ParameterValue(nelements_tgt)]), + lp.ValueArg("nelements_src", np.int32, + tags=[ParameterValue(nelements_src)]), + lp.ValueArg("nunit_dofs_src", np.int32, + tags=[ParameterValue(nunit_dofs_src)]), + lp.ValueArg("nunit_dofs_tgt", np.int32, + tags=[ParameterValue(nunit_dofs_tgt)]), + lp.ValueArg("nelements", np.int32, + tags=[ParameterValue(nelements)]), "...", ], name="resample_by_picking_group", ) + return lp.tag_inames(t_unit, { "iel": ConcurrentElementInameTag(), "idof": ConcurrentDOFInameTag()}) + + def calc_get_indices_key(dof_pick_lists, dof_pick_list_indices, from_element_indices, + ary_shape, ary_order, from_el_present): + from pyopencl.array import Array + key = (dof_pick_lists.data.int_ptr, + dof_pick_list_indices.data.int_ptr, + from_element_indices.data.int_ptr, + ary_shape, + ary_order, + from_el_present.data.int_ptr if isinstance(from_el_present, Array) else None,) + return key + + + # from_el_present should be set to None if the indexing is surjective + @keyed_memoize_in(actx, + (DirectDiscretizationConnection, "calc_indices_knl"), calc_get_indices_key) + def get_indices_loopy(dof_pick_lists, dof_pick_list_indices, + from_element_indices, ary_shape, ary_order, + from_el_present): + + nelements = from_element_indices.shape[0] + nelements_tgt, nunit_dofs_tgt = dof_pick_lists.shape + + if ary_order == "F": + row_stride = 1 + col_stride = ary_shape[0] + else: + row_stride = ary_shape[1] + col_stride = 1 + + if from_el_present is None: + if_present = "" + else: + if_present = "if from_el_present[iel] else -1" + + + t_unit = make_loopy_program( + [ + "{[iel]: 0 <= iel < nelements}", + "{[idof]: 0 <= idof < nunit_dofs_tgt}" + ], + f""" + indices[iel, idof] = from_element_indices[iel]*{row_stride} \ + + dof_pick_lists[dof_pick_list_indices[iel], idof]*{col_stride} \ + {if_present} + """, + [ + lp.GlobalArg("indices", np.int32, + shape="nelements, nunit_dofs_tgt", + offset=lp.auto, tags=[IsDOFArray()]), + lp.ValueArg("nunit_dofs_tgt", np.int32, + tags=[ParameterValue(nunit_dofs_tgt)]), + lp.ValueArg("nelements", np.int32, + tags=[ParameterValue(nelements)]), + lp.ValueArg("nelements_tgt", np.int32, + tags=[ParameterValue(nelements_tgt)]), + lp.GlobalArg("dof_pick_lists", dof_pick_lists.dtype, + shape="nelements_tgt, nunit_dofs_tgt", + offset=lp.auto), + lp.GlobalArg("from_element_indices", from_element_indices.dtype, + shape="nelements,", offset=lp.auto), + lp.GlobalArg("dof_pick_list_indices", dof_pick_list_indices.dtype, + shape="nelements,", offset=lp.auto), + "...", + ], + name="resample_by_picking_calc_indices_knl", + ) + + t_unit = lp.tag_inames(t_unit, { + "iel": ConcurrentElementInameTag(), + "idof": ConcurrentDOFInameTag()}) + + if from_el_present is None: + out = actx.call_loopy(t_unit, from_element_indices=from_element_indices, + dof_pick_lists=dof_pick_lists, dof_pick_list_indices=dof_pick_list_indices) + else: + out = actx.call_loopy(t_unit, from_element_indices=from_element_indices, + dof_pick_lists=dof_pick_lists, dof_pick_list_indices=dof_pick_list_indices, + from_el_present=from_el_present) + + + return out["indices"] + + + @memoize_in(actx, + (DirectDiscretizationConnection, "resample_by_picking_single_indirection_knl")) + def group_pick_knl_single_indirection(nelements, nunit_dofs_tgt, nelements_src, nunit_dofs_src, + ary_dtype, is_surjective: bool): + + if is_surjective: + if_present = "" + else: + if_present = "if indices[iel,idof] >= 0 else 0" + + + t_unit = make_loopy_program( + [ + "{[iel]: 0 
<= iel < nelements}", + "{[idof]: 0 <= idof < nunit_dofs_tgt}" + ], + f""" + result[iel, idof] = ary[indices[iel,idof]] \ + {if_present} + """, + [ + lp.GlobalArg("result", ary_dtype, + shape="nelements, nunit_dofs_tgt", + offset=lp.auto, tags=[IsDOFArray()]), + # Assuming np.int32 but could it be np.int64? + lp.GlobalArg("indices", np.int32, + shape="nelements, nunit_dofs_tgt", + offset=lp.auto, tags=[IsDOFArray()]), + lp.GlobalArg("ary", ary_dtype, offset=lp.auto, + shape="ary_size"), + # shape="nelements_src, nunit_dofs_src", + # offset=lp.auto)#, tags=[IsDOFArray()]), + lp.ValueArg("nunit_dofs_tgt", np.int32, + tags=[ParameterValue(nunit_dofs_tgt)]), + lp.ValueArg("nelements", np.int32, + tags=[ParameterValue(nelements)]), + lp.ValueArg("ary_size", np.int32, + tags=[ParameterValue(nelements_src*nunit_dofs_src)]), + "...", + ], + name="resample_by_picking_single_indirection", + ) + + t_unit = lp.tag_inames(t_unit, { + "iel": ConcurrentElementInameTag(), + "idof": ConcurrentDOFInameTag()}) + + return t_unit + #order = "F" if ary.flags.f_contiguous else "C" + #out = actx.call_loopy(t_unit, ary=ary.ravel(order=order), indices=indices) + # Raveling not needed? + #out = actx.call_loopy(t_unit, ary=ary, indices=indices) + #return = out["result"] + + # }}} group_arrays = [] @@ -763,16 +920,57 @@ def group_pick_knl(is_surjective: bool): group_knl_kwargs["from_el_present"] = \ fgpd.from_el_present - group_array_contributions.append( - actx.call_loopy( - group_pick_knl(fgpd.is_surjective), - dof_pick_lists=fgpd.dof_pick_lists, - dof_pick_list_indices=fgpd.dof_pick_list_indices, - ary=ary[fgpd.from_group_index], - from_element_indices=fgpd.from_element_indices, - nunit_dofs_tgt=( - self.to_discr.groups[i_tgrp].nunit_dofs), - **group_knl_kwargs)["result"]) + nelements = fgpd.from_element_indices.shape[0] + nelements_src, nunit_dofs_src = \ + ary[fgpd.from_group_index].shape + nelements_tgt = fgpd.dof_pick_lists.shape[0] + nunit_dofs_tgt = self.to_discr.groups[i_tgrp].nunit_dofs + ary_dtype = ary[fgpd.from_group_index].dtype + result_dtype = ary_dtype # Assume they are the same + from_el_ind_dtype = fgpd.from_element_indices.dtype + pick_lists_dtype = fgpd.dof_pick_lists.dtype + pick_list_ind_dtype = fgpd.dof_pick_list_indices.dtype + + if False: + + dof_pick_lists = fgpd.dof_pick_lists + dof_pick_list_indices = fgpd.dof_pick_list_indices + data_ary = ary[fgpd.from_group_index] + from_element_indices = fgpd.from_element_indices + + order = "F" if data_ary.flags.f_contiguous else "C" + lp_indices = get_indices_loopy(dof_pick_lists, dof_pick_list_indices, + from_element_indices, + data_ary.shape, order, None if fgpd.is_surjective else fgpd.from_el_present) + + cl_result = actx.call_loopy(group_pick_knl_single_indirection(lp_indices.shape[0], + lp_indices.shape[1], data_ary.shape[0], data_ary.shape[1], ary_dtype, + fgpd.is_surjective), + ary=data_ary.ravel(order=order), indices=lp_indices) + + group_array_contributions.append(cl_result["result"]) + + else: + group_array_contributions.append( + actx.call_loopy( + #group_pick_knl(fgpd.is_surjective), + group_pick_knl(nelements, nelements_src, + nunit_dofs_src, + nelements_tgt, + nunit_dofs_tgt, + result_dtype, + ary_dtype, + from_el_ind_dtype, + pick_lists_dtype, + pick_list_ind_dtype, + fgpd.is_surjective), + dof_pick_lists=fgpd.dof_pick_lists, + dof_pick_list_indices=fgpd.dof_pick_list_indices, + ary=ary[fgpd.from_group_index], + from_element_indices=fgpd.from_element_indices, + #nunit_dofs_tgt=( + # self.to_discr.groups[i_tgrp].nunit_dofs), + 
**group_knl_kwargs)["result"]) group_array = sum(group_array_contributions) elif cgrp.batches: @@ -903,10 +1101,10 @@ def knl(): isrc_base + from_element_indices[iel]*nunit_dofs_src + jdof] \ = resample_mat[idof, jdof] {dep=barrier} """, - [ + kernel_data=[ lp.GlobalArg("result", None, shape="nnodes_tgt, nnodes_src", - offset=lp.auto), + offset=lp.auto, tags=[IsDOFArray()]), lp.ValueArg("itgt_base, isrc_base", np.int32), lp.ValueArg("nnodes_tgt, nnodes_src", np.int32), ..., diff --git a/meshmode/discretization/connection/modal.py b/meshmode/discretization/connection/modal.py index 2fe813e5b..ccc3ef495 100644 --- a/meshmode/discretization/connection/modal.py +++ b/meshmode/discretization/connection/modal.py @@ -27,10 +27,10 @@ import numpy.linalg as la import modepy as mp +from meshmode.transform_metadata import (FirstAxisIsElementsTag, + DiscretizationDOFAxisTag, IsDOFArray, IsOpArray, EinsumArgsTags) from arraycontext import ( NotAnArrayContainerError, serialize_container, deserialize_container) -from meshmode.transform_metadata import (FirstAxisIsElementsTag, - DiscretizationDOFAxisTag) from meshmode.discretization import InterpolatoryElementGroupBase from meshmode.discretization.poly_element import QuadratureSimplexElementGroup from meshmode.discretization.connection.direct import DiscretizationConnection @@ -164,9 +164,13 @@ def vandermonde_inverse(grp): return actx.tag_axis(0, DiscretizationDOFAxisTag(), actx.from_numpy(vdm_inv)) + kd_tag = EinsumArgsTags({"arg0": (IsOpArray(),), + "arg1": (IsDOFArray(),), "out": (IsDOFArray(),)}) + return actx.einsum("ij,ej->ei", - vandermonde_inverse(grp), - ary, tagged=(FirstAxisIsElementsTag(),)) + vandermonde_inverse(grp), + ary, + tagged=(FirstAxisIsElementsTag(), kd_tag,)) def __call__(self, ary): """Computes modal coefficients data from a functions diff --git a/meshmode/discretization/connection/projection.py b/meshmode/discretization/connection/projection.py index 179dbf8ea..b8af6dd72 100644 --- a/meshmode/discretization/connection/projection.py +++ b/meshmode/discretization/connection/projection.py @@ -29,7 +29,7 @@ from arraycontext import ( NotAnArrayContainerError, make_loopy_program, serialize_container, deserialize_container) -from meshmode.transform_metadata import FirstAxisIsElementsTag +from meshmode.transform_metadata import FirstAxisIsElementsTag, IsDOFArray from meshmode.discretization.connection.direct import ( DiscretizationConnection, DirectDiscretizationConnection) @@ -172,7 +172,7 @@ def kproj(): shape=("n_from_elements", "n_from_nodes")), lp.GlobalArg("result", None, shape=("n_to_elements", "n_to_nodes"), - is_input=False), + is_input=False, tags=[IsDOFArray()]), lp.GlobalArg("basis_tabulation", None, shape=("n_to_nodes", "n_to_nodes")), lp.GlobalArg("weights", None, diff --git a/meshmode/distributed.py b/meshmode/distributed.py index d99e90dc8..161dc78f4 100644 --- a/meshmode/distributed.py +++ b/meshmode/distributed.py @@ -3,6 +3,7 @@ .. autoclass:: InterRankBoundaryInfo .. autoclass:: MPIBoundaryCommSetupHelper +.. autofunction:: mpi_distribute .. autofunction:: get_partition_by_pymetis .. autofunction:: membership_list_to_map .. 
+
         return actx.einsum("ij,ej->ei",
-                vandermonde_inverse(grp),
-                ary, tagged=(FirstAxisIsElementsTag(),))
+                vandermonde_inverse(grp),
+                ary,
+                tagged=(FirstAxisIsElementsTag(), kd_tag,))

     def __call__(self, ary):
         """Computes modal coefficients data from a functions
diff --git a/meshmode/discretization/connection/projection.py b/meshmode/discretization/connection/projection.py
index 179dbf8ea..b8af6dd72 100644
--- a/meshmode/discretization/connection/projection.py
+++ b/meshmode/discretization/connection/projection.py
@@ -29,7 +29,7 @@
 from arraycontext import (
     NotAnArrayContainerError, make_loopy_program,
     serialize_container, deserialize_container)
-from meshmode.transform_metadata import FirstAxisIsElementsTag
+from meshmode.transform_metadata import FirstAxisIsElementsTag, IsDOFArray
 from meshmode.discretization.connection.direct import (
     DiscretizationConnection,
     DirectDiscretizationConnection)
@@ -172,7 +172,7 @@ def kproj():
                     shape=("n_from_elements", "n_from_nodes")),
                 lp.GlobalArg("result", None,
                     shape=("n_to_elements", "n_to_nodes"),
-                    is_input=False),
+                    is_input=False, tags=[IsDOFArray()]),
                 lp.GlobalArg("basis_tabulation", None,
                     shape=("n_to_nodes", "n_to_nodes")),
                 lp.GlobalArg("weights", None,
diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index d99e90dc8..161dc78f4 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -3,6 +3,7 @@
 .. autoclass:: InterRankBoundaryInfo
 .. autoclass:: MPIBoundaryCommSetupHelper

+.. autofunction:: mpi_distribute
 .. autofunction:: get_partition_by_pymetis
 .. autofunction:: membership_list_to_map
 .. autofunction:: get_connected_parts
@@ -37,8 +38,11 @@
 """

 from dataclasses import dataclass
+from contextlib import contextmanager

 import numpy as np
-from typing import List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
+from typing import (
+    Any, Optional, List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
+)

 from arraycontext import ArrayContext
 from meshmode.discretization.connection import (
@@ -66,12 +70,73 @@
 import logging
 logger = logging.getLogger(__name__)

-TAG_BASE = 83411
-TAG_DISTRIBUTE_MESHES = TAG_BASE + 1
-

 # {{{ mesh distributor

+@contextmanager
+def _duplicate_mpi_comm(mpi_comm):
+    dup_comm = mpi_comm.Dup()
+    try:
+        yield dup_comm
+    finally:
+        dup_comm.Free()
+
+
+def mpi_distribute(
+        mpi_comm: "mpi4py.MPI.Intracomm",
+        source_data: Optional[Mapping[int, Any]] = None,
+        source_rank: int = 0) -> Optional[Any]:
+    """
+    Distribute data to a set of processes.
+
+    :arg mpi_comm: an ``MPI.Intracomm``.
+    :arg source_data: a :class:`dict` mapping destination ranks to the data
+        to be sent. Only to be provided on the source rank.
+    :arg source_rank: the rank from which the data is being sent.
+
+    :returns: the data local to the current process, if any, otherwise
+        *None*.
+    """
+    with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
+        num_proc = mpi_comm.Get_size()
+        rank = mpi_comm.Get_rank()
+
+        local_data = None
+
+        if rank == source_rank:
+            if source_data is None:
+                raise TypeError("source rank has no data.")
+
+            sending_to = [False] * num_proc
+            for dest_rank in source_data.keys():
+                sending_to[dest_rank] = True
+
+            mpi_comm.scatter(sending_to, root=source_rank)
+
+            reqs = []
+            for dest_rank, data in source_data.items():
+                if dest_rank == rank:
+                    local_data = data
+                    logger.info("rank %d: received data", rank)
+                else:
+                    reqs.append(mpi_comm.isend(data, dest=dest_rank))
+
+            logger.info("rank %d: sent all data", rank)
+
+            from mpi4py import MPI
+            MPI.Request.waitall(reqs)
+
+        else:
+            receiving = mpi_comm.scatter([], root=source_rank)
+
+            if receiving:
+                local_data = mpi_comm.recv(source=source_rank)
+                logger.info("rank %d: received data", rank)
+
+        return local_data
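+
+
+# The handshake in mpi_distribute: the source rank scatters one boolean per
+# rank ("data is coming your way"), then sends each payload point-to-point.
+# A minimal usage sketch (hypothetical payload, run under mpiexec):
+#
+#   from mpi4py import MPI
+#   comm = MPI.COMM_WORLD
+#   if comm.rank == 0:
+#       local = mpi_distribute(
+#           comm, source_data={r: r*r for r in range(comm.size)})
+#   else:
+#       local = mpi_distribute(comm)
+#   assert local == comm.rank * comm.rank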
""" - mpi_comm = self.mpi_comm - rank = mpi_comm.Get_rank() - assert not self.is_mananger_rank(), "Manager rank cannot receive mesh" - from mpi4py import MPI - status = MPI.Status() - result = self.mpi_comm.recv( - source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES, - status=status) - logger.info("rank %d: received local mesh (size = %d)", rank, status.count) - - return result + return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank) # }}} diff --git a/meshmode/dof_array.py b/meshmode/dof_array.py index 7e141113d..9dc146d01 100644 --- a/meshmode/dof_array.py +++ b/meshmode/dof_array.py @@ -35,7 +35,8 @@ from pytools import single_valued, memoize_in from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag) + ConcurrentElementInameTag, ConcurrentDOFInameTag, + IsDOFArray) from arraycontext import ( ArrayContext, ArrayOrContainerT, NotAnArrayContainerError, make_loopy_program, with_container_arithmetic, @@ -59,9 +60,9 @@ .. autofunction:: check_dofarray_against_discr """ - # {{{ DOFArray + @with_container_arithmetic( bcast_obj_array=True, bcast_numpy_array=True, @@ -459,9 +460,9 @@ def prg(): """, [ lp.GlobalArg("result", None, - shape="nelements * ndofs_per_element"), + shape="nelements * ndofs_per_element"), lp.GlobalArg("grp_ary", None, - shape=("nelements", "ndofs_per_element")), + shape=("nelements", "ndofs_per_element"), tags=[IsDOFArray()]), lp.ValueArg("nelements", np.int32), lp.ValueArg("ndofs_per_element", np.int32), "..." @@ -536,6 +537,10 @@ def prg(): t_unit = make_loopy_program( "{[iel,idof]: 0<=iel bool: + from pytato.transform import InputGatherer + from pytato.array import Placeholder + return all(not isinstance(inp, Placeholder) + for inp in InputGatherer()(ary)) + + +class EagerReduceComputingPytatoFakeNumpyNamespace(PytatoFakeNumpyNamespace): + """ + A Numpy-namespace that computes the reductions eagerly whenever possible. 
+ """ + def sum(self, a, axis=None, dtype=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + + def _pt_sum(ary): + return cl_array.sum(self._array_context.freeze(ary), + dtype=dtype, + queue=self._array_context.queue) + + return self._array_context.thaw(rec_map_reduce_array_container(sum, + _pt_sum, + a)) + else: + return super().sum(a, axis=axis, dtype=dtype) + + def min(self, a, axis=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.minimum, queue=queue)), + lambda ary: cl_array.min(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().min(a, axis=axis) + + def max(self, a, axis=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.maximum, queue=queue)), + lambda ary: cl_array.max(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().max(a, axis=axis) + + +# {{{ solve for discretization metadata for arrays' axes + +class AxesTagsEquationCollector(BaseAxesTagsEquationCollector): + def map_reshape(self, expr: pt.Reshape) -> None: + super().map_reshape(expr) + + if (expr.size > 0 + and (1 not in (expr.array.shape)) # leads to ambiguous newaxis + and (set(expr.shape) <= (set(expr.array.shape) | {1}))): + i_in_axis = 0 + for i_out_axis, dim in enumerate(expr.shape): + if dim != 1: + assert dim == expr.array.shape[i_in_axis] + self.record_equation( + self.get_var_for_axis(expr.array, + i_in_axis), + self.get_var_for_axis(expr, + i_out_axis) + ) + i_in_axis += 1 + else: + # print(f"Skipping: {expr.array.shape} -> {expr.shape}") + # Wacky reshape => bail. + pass + + +def unify_discretization_entity_tags(expr: Union[ArrayContainer, ArrayOrNames] + ) -> ArrayOrNames: + if not isinstance(expr, (pt.Array, pt.DictOfNamedArrays)): + return rec_map_array_container(unify_discretization_entity_tags, + expr) + + return pt.unify_axes_tags(expr, + tag_t=DiscretizationEntityAxisTag, + equations_collector_t=AxesTagsEquationCollector) + +# }}} + + +# vim: fdm=marker diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py index 0753e0dfe..8c032df54 100644 --- a/meshmode/transform_metadata.py +++ b/meshmode/transform_metadata.py @@ -2,6 +2,10 @@ .. autoclass:: FirstAxisIsElementsTag .. autoclass:: ConcurrentElementInameTag .. autoclass:: ConcurrentDOFInameTag +.. autoclass:: ParameterValue +.. autoclass:: IsDOFArray +.. autoclass:: IsOpArray +.. autoclass:: EinsumArgsTags .. autoclass:: DiscretizationEntityAxisTag .. autoclass:: DiscretizationElementAxisTag .. autoclass:: DiscretizationFaceAxisTag @@ -34,6 +38,8 @@ THE SOFTWARE. """ +from immutables import Map +from typing import Any from pytools.tag import Tag, tag_dataclass, UniqueTag @@ -65,6 +71,41 @@ class ConcurrentDOFInameTag(Tag): """ +@tag_dataclass +class ParameterValue(UniqueTag): + """A tag that applies to :class:`loopy.ValueArg`. Instances of this tag + are initialized with the value of the parameter and this value may be + later retrieved to fix the value of the parameter. 
+
+# vim: fdm=marker
diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py
index 0753e0dfe..8c032df54 100644
--- a/meshmode/transform_metadata.py
+++ b/meshmode/transform_metadata.py
@@ -2,6 +2,10 @@
 .. autoclass:: FirstAxisIsElementsTag
 .. autoclass:: ConcurrentElementInameTag
 .. autoclass:: ConcurrentDOFInameTag
+.. autoclass:: ParameterValue
+.. autoclass:: IsDOFArray
+.. autoclass:: IsOpArray
+.. autoclass:: EinsumArgsTags
 .. autoclass:: DiscretizationEntityAxisTag
 .. autoclass:: DiscretizationElementAxisTag
 .. autoclass:: DiscretizationFaceAxisTag
@@ -34,6 +38,8 @@
 THE SOFTWARE.
 """

+from immutables import Map
+from typing import Any
 from pytools.tag import Tag, tag_dataclass, UniqueTag


@@ -65,6 +71,41 @@ class ConcurrentDOFInameTag(Tag):
     """


+@tag_dataclass
+class ParameterValue(UniqueTag):
+    """A tag applicable to :class:`loopy.ValueArg`. Instances carry the
+    value of the parameter, which may later be retrieved to fix the
+    parameter. This allows calls to :func:`loopy.fix_parameters` to move
+    into ``transform_loopy_program``, so that all kernel transformations
+    can happen in one place.
+    """
+    value: Any
+
+
+class IsDOFArray(Tag):
+    """A tag applicable to :class:`loopy.ArrayArg` indicating that the
+    array's content comprises element DOFs.
+    """
+
+
+class IsOpArray(Tag):
+    """A tag applicable to arrays indicating that the array is an
+    operator (as opposed, for instance, to a DOF array).
+    """
+
+
+@tag_dataclass
+class EinsumArgsTags(Tag):
+    """A tag containing an :class:`immutables.Map` of tuples of tags,
+    indexed by einsum argument name.
+    """
+    tags_map: Map
+
+    def __init__(self, tags_map):
+        object.__setattr__(self, "tags_map", Map(tags_map))
+
+
 class DiscretizationEntityAxisTag(UniqueTag):
     """
     A tag applicable to an array's axis to describe which discretization entity
diff --git a/requirements.txt b/requirements.txt
index 7ad3c51b2..307d02058 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,18 +1,19 @@
 numpy
 recursivenodes
+immutables
 git+https://github.com/inducer/pytools.git#egg=pytools
 git+https://github.com/inducer/gmsh_interop.git#egg=gmsh_interop
 git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile
 git+https://github.com/inducer/modepy.git#egg=modepy
 git+https://github.com/inducer/pyopencl.git#egg=pyopencl
 git+https://github.com/inducer/islpy.git#egg=islpy
-git+https://github.com/inducer/pytato.git#egg=pytato
+git+https://github.com/kaushikcfd/pytato.git#egg=pytato

 # required by pytential, which is in turn needed for some tests
 git+https://github.com/inducer/pymbolic.git#egg=pymbolic

 # also depends on pymbolic, so should come after it
-git+https://github.com/inducer/loopy.git#egg=loopy
+git+https://github.com/kaushikcfd/loopy.git#egg=loopy

 # depends on loopy, so should come after it.
 git+https://github.com/inducer/arraycontext.git#egg=arraycontext
@@ -27,3 +28,8 @@
 # for examples/tp-lagrange-stl.py
 numpy-stl
+
+
+# for FusionContractorActx transforms
+git+https://github.com/kaushikcfd/feinsum.git#egg=feinsum
+git+https://github.com/pythological/kanren.git#egg=miniKanren
diff --git a/setup.py b/setup.py
index 51e6fda4e..34df70ffc 100644
--- a/setup.py
+++ b/setup.py
@@ -41,15 +41,17 @@ def main():
             "numpy",
             "modepy>=2020.2",
             "gmsh_interop",
-            "pytools>=2020.4.1",
+            "pytools>=2021.2.1",
             "pytest>=2.3",

             # 2019.1 is required for the Firedrake CIs, which use a very
             # specific version of Loopy.
-            "loopy>=2019.1",
+            "loopy>=2020.2.2",
             "arraycontext",
+            "immutables",
             "recursivenodes",
             "dataclasses; python_version<'3.7'",
             "typing_extensions; python_version<'3.8'",
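+            # NB: the "immutables" entry above provides the Map container
+            # used by EinsumArgsTags in meshmode/transform_metadata.py.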