Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridTools Stage Extents #136

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
894ae41
Correct extent calculations to match gt4py passes
Jul 15, 2020
36517e0
Remove extra whitespace
Jul 17, 2020
bccf0d8
Another formatting fix
Jul 17, 2020
bff858a
Merge branch 'master' into dawn_extent_fix
Jul 29, 2020
709519b
Add 'dawn_test_fix' branch
Jul 30, 2020
813ae00
Do not pass OMP flags to NVCC
Jul 30, 2020
de29112
Merge with 'cuda_build_fix'
Jul 30, 2020
3088039
Enable NVCC/OMP cross-compile
Jul 30, 2020
2697653
Fix CLI tests for Dawn backends
Jul 30, 2020
27046fa
Revert previous change
Jul 30, 2020
530d378
Merge pyext_builder.py
Jul 30, 2020
6cf7843
Fix CLI tests for Dawn backends
Jul 30, 2020
8668c5b
Merge remote-tracking branch 'origin/dawn_cli_tests_fix' into gt_stag…
Jul 30, 2020
f6a1815
Merge with 'dawn_cli_tests_fix'
Jul 30, 2020
e656fbc
Add stage extents to GT backends
Jul 30, 2020
38627b0
Prepend 'Xcompiler' before OMP flags if CUDA
Jul 31, 2020
f643a98
Merge branch 'cuda_build_fix' into gt_stage_extents
Jul 31, 2020
5740bd6
Change -Xcompiler -> --compiler-options
Aug 3, 2020
16151ec
Merge branch 'cuda_build_fix' into gt_stage_extents
Aug 3, 2020
1b6da20
Merge with master
Aug 6, 2020
ab86a20
Merge branch 'gt_stage_extents' of github.com:eddie-c-davis/gt4py int…
Aug 6, 2020
bf158fe
Add TestQXWestEdge to stencil test suite
Aug 6, 2020
17b4ebe
Update TestQXWestEdge test
Aug 6, 2020
f3556b5
Merge branch 'master' into gt_stage_extents
Aug 25, 2020
978d126
Merge with master
Aug 25, 2020
0b9828d
Add type hint for 'extents'
Aug 25, 2020
9bf02c2
Merge branch 'master' into gt_stage_extents
Aug 31, 2020
e17e838
Merge changes from master
Aug 31, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/gt4py/backend/gt_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,18 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation) -> Dict
if name not in node.unreferenced
]

stage_extents = {}
stage_functors = {}
for multi_stage in node.multi_stages:
for group in multi_stage.groups:
for stage in group.stages:
compute_extent = stage.compute_extent
extents: List[int] = []
for i in range(compute_extent.ndims - 1):
extents.extend(
(compute_extent.lower_indices[i], compute_extent.upper_indices[i])
)
stage_extents[stage.name] = ", ".join([str(extent) for extent in extents])
stage_functors[stage.name] = self.visit(stage)

multi_stages = []
Expand All @@ -491,6 +499,7 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation) -> Dict
multi_stages=multi_stages,
parameters=parameters,
stage_functors=stage_functors,
stage_extents=stage_extents,
stencil_unique_name=self.class_name,
tmp_fields=tmp_fields,
)
Expand Down Expand Up @@ -653,7 +662,7 @@ def generate_post_run(self) -> str:
output_field_names = [
name
for name, info in self.args_data["field_info"].items()
if info.access == gt_definitions.AccessKind.READ_WRITE
if info and info.access == gt_definitions.AccessKind.READ_WRITE
]

return "\n".join([f + "._set_device_modified()" for f in output_field_names])
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/backend/templates/computation.src.in
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
]
}
}
- stage_extents: { str: str }
- stencil_unique_name: str
- tmp_fields: [{ "name": str, "dtype": str }]
#}
Expand Down Expand Up @@ -259,7 +260,7 @@ void run(const std::array<gt::uint_t, 3>& domain,
{%- for stage in step %}
{%- filter indent(width=extra_indent) %}
{{- stage_comma() }}
gt::make_stage<{{ stage }}_func>(
gt::make_stage_with_extent<{{ stage }}_func, gt::extent<{{ stage_extents[stage] }}>>(
p_{{ stage_functors[stage].args|map(attribute="name")|join("(), p_")}}()
)
{%- endfilter %}
Expand Down
7 changes: 7 additions & 0 deletions tests/test_integration/stencil_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,10 @@ def set_inner_as_kord(a4_1: Field3D, a4_2: Field3D, a4_3: Field3D, extm: Field3D
a4_3 = a4_1
else:
diff_23 = a4_2 - a4_3


@register
def write_after_read(field: Field3D):
with computation(PARALLEL), interval(...):
tmp = field
field = tmp[-1, 0, 0]