feat[next]: SDFGConvertible Program for dace_fieldview backend #1742
Open
DropD wants to merge 12 commits into GridTools:main from DropD:gtir-sdfg-convertible
Changes from all commits (12 commits, all authored by DropD):

4c41f02  [wip] start adding sdfg convertible Program for dace-fieldview
481af34  add SDFGConvertible Program replacement to dace_fieldview
ad0a7b2  improve generate_sdfg args, remove old Program replacement
307b537  Merge branch 'main' into gtir-sdfg-convertible
7c4d197  turn auto_optimize back on in __sdfg__
3c56151  disable `auto_opt` once again in `__sdfg__`
c5b4c43  support only CUDA device type in dace_fieldview Program
835c115  [wip] bring back extra sdfg attributes for halo placement
c126cb2  partially add halo exchange helper attrs with tests
0b90583  add dace/gt4py type parsing crosscheck
8c1f9f4  Merge branch 'main' into gtir-sdfg-convertible
77c250b  fix dace_fieldview.program tests
src/gt4py/next/program_processors/runners/dace_fieldview/program.py (298 additions, 0 deletions)

@@ -0,0 +1,298 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import collections
import dataclasses
import itertools
import typing
from typing import Any, ClassVar, Optional, Sequence

import dace
import numpy as np

from gt4py import eve
from gt4py.next import backend as next_backend, common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import arguments, recipes, toolchain
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass(frozen=True)
class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible):
    """Extension of GT4Py Program implementing the SDFGConvertible interface via GTIR."""

    sdfg_closure_cache: dict[str, Any] = dataclasses.field(default_factory=dict)
    # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs,
    # there is no name mangling of the connectivity tables used across the nested SDFGs
    # since they share the same memory address.
    connectivity_tables_data_descriptors: ClassVar[
        dict[str, dace.data.Array]
    ] = {}  # symbolically defined

    def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG:
        if (self.backend is None) or "dace" not in self.backend.name.lower():
            raise ValueError("The SDFG can be generated only for the DaCe backend.")

        offset_provider: common.OffsetProvider = {
            **(self.connectivities or {}),
            **self._implicit_offset_provider,
        }
        column_axis = kwargs.get("column_axis", None)

        gtir_stage = typing.cast(next_backend.Transforms, self.backend.transforms).past_to_itir(
            toolchain.CompilableProgram(
                data=self.past_stage,
                args=arguments.CompileTimeArgs(
                    args=tuple(p.type for p in self.past_stage.past_node.params),
                    kwargs={},
                    column_axis=column_axis,
                    offset_provider=offset_provider,
                ),
            )
        )
        program = typing.cast(
            itir.Program, gtir_stage.data
        )  # we already checked that our backend uses GTIR

        _crosscheck_dace_parsing(
            dace_parsed_args=[*args, *kwargs.values()],
            gt4py_program_args=[p.type for p in program.params],
        )

        compile_workflow = typing.cast(
            recipes.OTFCompileWorkflow,
            self.backend.executor
            if not hasattr(self.backend.executor, "step")
            else self.backend.executor.step,
        )  # We know which backend we are using, but we don't know if the compile workflow is cached.
        sdfg = dace.SDFG.from_json(compile_workflow.translation(gtir_stage).source_code)

        self.sdfg_closure_cache["arrays"] = sdfg.arrays

        # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields,
        # offset_providers_per_input_field. Add them as dynamic attributes to the SDFG.
        field_params = {
            str(param.id): param for param in program.params if isinstance(param.type, ts.FieldType)
        }

        def single_horizontal_dim_per_field(
            fields: typing.Iterable[itir.Sym],
        ) -> typing.Iterator[tuple[str, common.Dimension]]:
            for field in fields:
                assert isinstance(field.type, ts.FieldType)
                horizontal_dims = [
                    dim for dim in field.type.dims if dim.kind is common.DimensionKind.HORIZONTAL
                ]
                # Do nothing for fields with multiple horizontal dimensions
                # or without horizontal dimensions;
                # this is only meant for use with unstructured grids.
                if len(horizontal_dims) == 1:
                    yield str(field.id), horizontal_dims[0]

        input_fields = (field_params[name] for name in InputNamesExtractor.only_fields(program))
        sdfg.gt4py_program_input_fields = dict(single_horizontal_dim_per_field(input_fields))

        output_fields = (field_params[name] for name in OutputNamesExtractor.only_fields(program))
        sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields))

        # TODO (ricoh): bring back sdfg.offset_providers_per_input_field.
        # A starting point would be to use the "trace_shifts" pass on GTIR
        # and associate the extracted shifts with each input field.
        # Analogous to the version in `runners.dace_iterator.__init__`, which
        # was removed when merging #1742.

        return sdfg

    def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]:
        """
        Return the closure arrays of the SDFG represented by this object
        as a mapping between array name and the corresponding value.

        The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols.
        The need to define the connectivity tables in `__sdfg_closure__` arises from the fact that
        the offset providers are not part of the GT4Py Program's arguments.
        Keep in mind that `__sdfg_closure__` is called after the `__sdfg__` method.
        """
        closure_dict: dict[str, Any] = {}

        if self.connectivities:
            symbols = {}
            with_table = [
                name for name, conn in self.connectivities.items() if common.is_neighbor_table(conn)
            ]
            in_arrays_with_id = [
                (name, conn_id)
                for name in with_table
                if (conn_id := dace_utils.connectivity_identifier(name))
                in self.sdfg_closure_cache["arrays"]
            ]
            in_arrays = (name for name, _ in in_arrays_with_id)
            name_axis = list(itertools.product(in_arrays, [0, 1]))

            def size_symbol_name(name: str, axis: int) -> str:
                return dace_utils.field_size_symbol_name(
                    dace_utils.connectivity_identifier(name), axis
                )

            connectivity_tables_size_symbols = {
                (sname := size_symbol_name(name, axis)): dace.symbol(sname)
                for name, axis in name_axis
            }

            def stride_symbol_name(name: str, axis: int) -> str:
                return dace_utils.field_stride_symbol_name(
                    dace_utils.connectivity_identifier(name), axis
                )

            connectivity_table_stride_symbols = {
                (sname := stride_symbol_name(name, axis)): dace.symbol(sname)
                for name, axis in name_axis
            }

            symbols = connectivity_tables_size_symbols | connectivity_table_stride_symbols

            # Define the storage location (e.g. CPU, GPU) of the connectivity tables.
            if "storage" not in self.connectivity_tables_data_descriptors:
                for _, conn_id in in_arrays_with_id:
                    self.connectivity_tables_data_descriptors["storage"] = self.sdfg_closure_cache[
                        "arrays"
                    ][conn_id].storage
                    break

            # Build the closure dictionary.
            for name, conn_id in in_arrays_with_id:
                if conn_id not in self.connectivity_tables_data_descriptors:
                    conn = self.connectivities[name]
                    assert common.is_neighbor_table(conn)
                    self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array(
                        dtype=dace.int64 if conn.dtype == np.int64 else dace.int32,
                        shape=[
                            symbols[dace_utils.field_size_symbol_name(conn_id, 0)],
                            symbols[dace_utils.field_size_symbol_name(conn_id, 1)],
                        ],
                        strides=[
                            symbols[dace_utils.field_stride_symbol_name(conn_id, 0)],
                            symbols[dace_utils.field_stride_symbol_name(conn_id, 1)],
                        ],
                        storage=Program.connectivity_tables_data_descriptors["storage"],
                    )
                closure_dict[conn_id] = self.connectivity_tables_data_descriptors[conn_id]

        return closure_dict

    def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]:
        return [p.id for p in self.past_stage.past_node.params], []


class SymbolNameSetExtractor(eve.NodeVisitor):
    """Extract a set of symbol names."""

    def generic_visitor(self, node: itir.Node) -> set[str]:
        input_fields: set[str] = set()
        for child in eve.trees.iter_children_values(node):
            input_fields |= self.visit(child)
        return input_fields

    @classmethod
    def only_fields(cls, program: itir.Program) -> set[str]:
        field_param_names = [
            str(param.id) for param in program.params if isinstance(param.type, ts.FieldType)
        ]
        return {name for name in cls().visit(program) if name in field_param_names}


[Inline review comment on the class below: "Could you move these utilities to some module in …"]

class InputNamesExtractor(SymbolNameSetExtractor):
    """Extract the set of symbol names passed into field operators within a program."""

    def visit_Program(self, node: itir.Program) -> set[str]:
        input_fields = set()
        for stmt in node.body:
            input_fields |= self.visit(stmt)
        return input_fields

    def visit_IfStmt(self, node: itir.IfStmt) -> set[str]:
        input_fields = set()
        for stmt in node.true_branch + node.false_branch:
            input_fields |= self.visit(stmt)
        return input_fields

    def visit_Temporary(self, node: itir.Temporary) -> set[str]:
        return set()

    def visit_SetAt(self, node: itir.SetAt) -> set[str]:
        return self.visit(node.expr)

    def visit_FunCall(self, node: itir.FunCall) -> set[str]:
        input_fields = set()
        for arg in node.args:
            input_fields |= self.visit(arg)
        return input_fields

    def visit_SymRef(self, node: itir.SymRef) -> set[str]:
        return {str(node.id)}


class OutputNamesExtractor(SymbolNameSetExtractor):
    """Extract the set of symbol names written to within a program."""

    def visit_Program(self, node: itir.Program) -> set[str]:
        output_fields = set()
        for stmt in node.body:
            output_fields |= self.visit(stmt)
        return output_fields

    def visit_IfStmt(self, node: itir.IfStmt) -> set[str]:
        output_fields = set()
        for stmt in node.true_branch + node.false_branch:
            output_fields |= self.visit(stmt)
        return output_fields

    def visit_Temporary(self, node: itir.Temporary) -> set[str]:
        return set()

    def visit_SetAt(self, node: itir.SetAt) -> set[str]:
        return self.visit(node.target)

    def visit_SymRef(self, node: itir.SymRef) -> set[str]:
        return {str(node.id)}


def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> None:
    for dace_parsed_arg, gt4py_program_arg in zip(
        dace_parsed_args,
        gt4py_program_args,
        strict=False,  # dace does not see implicit size args
    ):
        match dace_parsed_arg:
            case dace.data.Scalar():
                assert dace_parsed_arg.type == dace_utils.as_dace_type(gt4py_program_arg)
            case bool() | np.bool_():
                assert isinstance(gt4py_program_arg, ts.ScalarType)
                assert gt4py_program_arg.kind == ts.ScalarKind.BOOL
            case int() | np.integer():
                assert isinstance(gt4py_program_arg, ts.ScalarType)
                assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64]
            case float() | np.floating():
                assert isinstance(gt4py_program_arg, ts.ScalarType)
                assert gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64]
            case str() | np.str_():
                assert isinstance(gt4py_program_arg, ts.ScalarType)
                assert gt4py_program_arg.kind == ts.ScalarKind.STRING
            case dace.data.Array():
                assert isinstance(gt4py_program_arg, ts.FieldType)
                assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims)
                assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype)
            case dace.data.Structure() | dict() | collections.OrderedDict():
                # offset provider
                pass
            case _:
                raise ValueError(
                    f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}"
                )
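For readers unfamiliar with DaCe's embedding protocol: the three methods implemented above (`__sdfg__`, `__sdfg_closure__`, `__sdfg_signature__`) are what allow a `@dace.program` to call a GT4Py `Program` as if it were a native DaCe function. Below is a minimal, GT4Py-independent sketch of the same protocol, with a trivial hand-written SDFG standing in for the generated one; the `Doubler` class and all names are illustrative, not part of this PR, and frontend details may vary across DaCe versions.

import dace
import numpy as np
from dace.frontend.python.common import SDFGConvertible


class Doubler(SDFGConvertible):
    """Toy SDFGConvertible: provides an SDFG computing out = 2 * inp."""

    def __sdfg__(self, *args, **kwargs) -> dace.SDFG:
        # Called by the DaCe frontend to obtain the SDFG to nest.
        @dace.program
        def doubler(inp: dace.float64[10], out: dace.float64[10]):
            out[:] = 2.0 * inp

        return doubler.to_sdfg()

    def __sdfg_closure__(self, reevaluate=None) -> dict:
        # No closure arrays in this toy; the Program above returns its
        # symbolically defined connectivity tables at this point.
        return {}

    def __sdfg_signature__(self):
        # (argument names, names of compile-time constant arguments)
        return (["inp", "out"], [])


d = Doubler()


@dace.program
def caller(x: dace.float64[10], y: dace.float64[10]):
    d(x, y)  # the frontend nests Doubler's SDFG here instead of calling d

x = np.arange(10, dtype=np.float64)
y = np.zeros(10, dtype=np.float64)
caller(x, y)
assert np.allclose(y, 2.0 * x)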
Review comment: @havogt I had to change this to satisfy mypy. Now it is also consistent with `CompileTimeArgs`.

Reply: makes sense
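For context on this exchange: `CompileTimeArgs` lets `__sdfg__` describe the call by argument *types* rather than concrete buffers when lowering to GTIR, which is why the diff passes `p.type` for each parameter. A hypothetical stand-in that mirrors that construction (the real definition lives in `gt4py.next.otf.arguments` and may differ):

import dataclasses
from typing import Any, Optional

# Hypothetical stand-in for gt4py.next.otf.arguments.CompileTimeArgs,
# inferred from its use in __sdfg__ above; the real definition may differ.
@dataclasses.dataclass(frozen=True)
class CompileTimeArgsSketch:
    args: tuple[Any, ...]            # positional argument *types*, not values
    kwargs: dict[str, Any]           # keyword argument types
    column_axis: Optional[Any]       # vertical dimension, if any
    offset_provider: dict[str, Any]  # connectivities and implicit offsets


# Passing type specifications (as __sdfg__ does with `p.type`) keeps the
# tuple homogeneous, which is what satisfies mypy here.
example = CompileTimeArgsSketch(
    args=("FieldType(dims=[Cell, K], dtype=float64)",),  # illustrative only
    kwargs={},
    column_axis=None,
    offset_provider={},
)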