Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ The `pyproject.toml` file contains both the definition of the `gt4py` Python dis

### Development Tasks (`dev-tasks.py`)

Recurrent development tasks like bumping versions of used development tools or required third party dependencies have been collected as different subcommands in the [`dev-tasks.py`](./dev-tasks.py) script. Read the tool help for a brief description of every task and always use this tool to update the versions and sync the version configuration accross different files (e.g. `pyproject.toml` and `.pre-commit-config.yaml`).
Recurrent development tasks like bumping versions of used development tools or required third party dependencies have been collected as different subcommands in the [`dev-tasks.py`](./dev-tasks.py) script. Read the tool help for a brief description of every task and always use this tool to update the versions and sync the version configuration across different files (e.g. `pyproject.toml` and `.pre-commit-config.yaml`).

## 📖 Documentation

Expand Down
57 changes: 30 additions & 27 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,20 +316,20 @@ def freeze_origin_domain_sdfg(inner_sdfg, arg_names, field_info, *, origin, doma

class SDFGManager:
# Cache loaded SDFGs across all instances
_loaded_sdfgs: ClassVar[dict[str, dace.SDFG]] = dict()
_loaded_sdfgs: ClassVar[dict[str | pathlib.Path, dace.SDFG]] = dict()

def __init__(self, builder):
def __init__(self, builder: StencilBuilder) -> None:
self.builder = builder

@staticmethod
def _strip_history(sdfg):
def _strip_history(sdfg: dace.SDFG) -> None:
# strip history from SDFG for faster save/load
for tmp_sdfg in sdfg.all_sdfgs_recursive():
tmp_sdfg.transformation_hist = []
tmp_sdfg.orig_sdfg = None

@staticmethod
def _save_sdfg(sdfg, path):
def _save_sdfg(sdfg: dace.SDFG, path: str) -> None:
SDFGManager._strip_history(sdfg)
sdfg.save(path)

Expand Down Expand Up @@ -375,27 +375,26 @@ def expanded_sdfg(self):
return copy.deepcopy(self._expanded_sdfg())

def _frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]):
frozen_hash = shash(origin, domain)
basename = self.builder.module_path.with_suffix("")
path = f"{basename}_{shash(origin, domain)}.sdfg"

# check if same sdfg already cached on disk
path = self.builder.module_path
basename = os.path.splitext(path)[0]
path = basename + "_" + str(frozen_hash) + ".sdfg"
if path not in SDFGManager._loaded_sdfgs:
try:
sdfg = dace.SDFG.from_file(path)
except FileNotFoundError:
# otherwise, wrap and save sdfg from scratch
inner_sdfg = self.unexpanded_sdfg()

sdfg = freeze_origin_domain_sdfg(
inner_sdfg,
arg_names=[arg.name for arg in self.builder.gtir.api_signature],
field_info=make_args_data_from_gtir(self.builder.gtir_pipeline).field_info,
origin=origin,
domain=domain,
)
self._save_sdfg(sdfg, path)
SDFGManager._loaded_sdfgs[path] = sdfg
if path in SDFGManager._loaded_sdfgs:
return SDFGManager._loaded_sdfgs[path]

# otherwise, wrap and save sdfg from scratch
inner_sdfg = self.unexpanded_sdfg()

sdfg = freeze_origin_domain_sdfg(
inner_sdfg,
arg_names=[arg.name for arg in self.builder.gtir.api_signature],
field_info=make_args_data_from_gtir(self.builder.gtir_pipeline).field_info,
origin=origin,
domain=domain,
)
SDFGManager._loaded_sdfgs[path] = sdfg
self._save_sdfg(sdfg, path)

return SDFGManager._loaded_sdfgs[path]

def frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]):
Expand Down Expand Up @@ -631,6 +630,7 @@ def generate_dace_args(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> list[
symbols[name] = fmt.format(
name=name, ndim=len(array.shape), origin=",".join(str(o) for o in origin)
)

# the remaining arguments are variables and can be passed by name
for sym in sdfg.signature_arglist(with_types=False, for_call=True):
if sym not in symbols:
Expand Down Expand Up @@ -678,8 +678,7 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> li
)
)
elif name in sdfg.symbols and not name.startswith("__"):
assert name in sdfg.symbols
res[name] = "{dtype} {name}".format(dtype=sdfg.symbols[name].ctype, name=name)
res[name] = f"{sdfg.symbols[name].ctype} {name}"
return list(res[node.name] for node in stencil_ir.params if node.name in res)

def generate_sid_params(self, sdfg: dace.SDFG) -> list[str]:
Expand All @@ -688,7 +687,11 @@ def generate_sid_params(self, sdfg: dace.SDFG) -> list[str]:
for name, array in sdfg.arrays.items():
if array.transient:
continue
domain_dim_flags = array_dimensions(array)

domain_dim_flags = tuple(array_dimensions(array))
if len(domain_dim_flags) != 3:
raise RuntimeError("Expected 3 cartesian array dimensions. Codegen error.")

data_ndim = len(array.shape) - sum(domain_dim_flags)
sid_def = pybuffer_to_sid(
name=name,
Expand Down
31 changes: 13 additions & 18 deletions src/gt4py/cartesian/backend/dace_stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,30 +100,25 @@ def freeze(
self: DaCeStencilObject, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...]
) -> DaCeFrozenStencil:
key = DaCeStencilObject._get_domain_origin_key(domain, origin)

# check if same sdfg already cached on disk
if key in self._frozen_cache:
return self._frozen_cache[key]

frozen_hash = shash(origin, domain)
# otherwise, wrap and save sdfg from scratch
frozen_sdfg = freeze_origin_domain_sdfg(
self.sdfg(),
arg_names=self.__sdfg_signature__()[0],
field_info=self.field_info,
origin=origin,
domain=domain,
)
self._frozen_cache[key] = DaCeFrozenStencil(self, origin, domain, frozen_sdfg)

# check if same sdfg already cached on disk
basename = os.path.splitext(self.SDFG_PATH)[0]
filename = basename + "_" + str(frozen_hash) + ".sdfg"
try:
frozen_sdfg = dace.SDFG.from_file(filename)
except FileNotFoundError:
# otherwise, wrap and save sdfg from scratch
inner_sdfg = self.sdfg()

frozen_sdfg = freeze_origin_domain_sdfg(
inner_sdfg,
arg_names=self.__sdfg_signature__()[0],
field_info=self.field_info,
origin=origin,
domain=domain,
)
frozen_sdfg.save(filename)
filename = f"{basename}_{shash(origin, domain)}.sdfg"
frozen_sdfg.save(filename)

self._frozen_cache[key] = DaCeFrozenStencil(self, origin, domain, frozen_sdfg)
return self._frozen_cache[key]

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def stencil_id(self) -> StencilID:
# typeignore because attrclass StencilID has generated constructor
return StencilID( # type: ignore
self.builder.options.qualified_name,
gt_utils.shashed_id(gt_utils.shashed_id(fingerprint), self.options_id),
gt_utils.shashed_id(fingerprint, self.options_id),
)

@property
Expand Down
7 changes: 2 additions & 5 deletions src/gt4py/cartesian/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,14 @@ class BuildOptions(AttributeClassLike):

@property
def qualified_name(self):
name = ".".join([self.module, self.name])
return name
return ".".join([self.module, self.name])

@property
def shashed_id(self):
result = gt_utils.shashed_id(
return gt_utils.shashed_id(
self.name, self.module, self.format_source, *tuple(sorted(self.backend_opts.items()))
)

return result


@attribclass(frozen=True)
class StencilID(AttributeClassLike):
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complex
if isinstance(value, ast.Name) and name not in assigned_symbols
}

call_id = gt_utils.shashed_id(call_name)[:3]
call_id = gt_utils.shashed_id(call_name, length=3)
call_id_suffix = f"{call_id}_{node.lineno}_{node.col_offset}"
template_fmt = "{name}__" + call_id_suffix

Expand Down
12 changes: 7 additions & 5 deletions src/gt4py/cartesian/gtc/dace/symbol_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from functools import lru_cache
from typing import TYPE_CHECKING

import dace
import numpy as np
from dace import dtypes, symbolic

from gt4py import eve
from gt4py.cartesian.gtc import common
Expand All @@ -22,9 +22,9 @@
from gt4py.cartesian.gtc.dace import daceir as dcir


def data_type_to_dace_typeclass(data_type):
def data_type_to_dace_typeclass(data_type: common.DataType) -> dtypes.typeclass:
dtype = np.dtype(common.data_type_to_typestr(data_type))
return dace.dtypes.typeclass(dtype.type)
return dtypes.typeclass(dtype.type)


def get_axis_bound_str(axis_bound, var_name):
Expand Down Expand Up @@ -65,5 +65,7 @@ def get_axis_bound_diff_str(axis_bound1, axis_bound2, var_name: str):


@lru_cache(maxsize=None)
def get_dace_symbol(name: eve.SymbolRef, dtype: common.DataType = common.DataType.INT32):
return dace.symbol(name, dtype=data_type_to_dace_typeclass(dtype))
def get_dace_symbol(
name: eve.SymbolRef, dtype: common.DataType = common.DataType.INT32
) -> symbolic.symbol:
return symbolic.symbol(name, dtype=data_type_to_dace_typeclass(dtype))
28 changes: 13 additions & 15 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import dace
import dace.data
import numpy as np
from dace import data, dtypes, properties, subsets, symbolic

from gt4py import eve
from gt4py.cartesian.gtc import common, oir
Expand All @@ -23,33 +22,32 @@
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents


def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo:
def get_dace_debuginfo(node: common.LocNode) -> dtypes.DebugInfo:
if node.loc is None:
return dace.dtypes.DebugInfo(0)
return dtypes.DebugInfo(0)

return dace.dtypes.DebugInfo(
return dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)


def array_dimensions(array: dace.data.Array):
dims = [
def array_dimensions(array: data.Array) -> list[bool]:
return [
any(
re.match(f"__.*_{k}_stride", str(sym))
for st in array.strides
for sym in dace.symbolic.pystr_to_symbolic(st).free_symbols
for sym in symbolic.pystr_to_symbolic(st).free_symbols
)
or any(
re.match(f"__{k}", str(sym))
for sh in array.shape
for sym in dace.symbolic.pystr_to_symbolic(sh).free_symbols
for sym in symbolic.pystr_to_symbolic(sh).free_symbols
)
for k in "IJK"
]
return dims


def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, str]:
def replace_strides(arrays: list[data.Array], get_layout_map) -> dict[str, str]:
symbol_mapping = {}
for array in arrays:
dims = array_dimensions(array)
Expand All @@ -61,7 +59,7 @@ def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str,
for idx in reversed(np.argsort(layout)):
symbol = array.strides[idx]
if symbol.is_symbol:
symbol_mapping[str(symbol)] = dace.symbolic.pystr_to_symbolic(stride)
symbol_mapping[str(symbol)] = symbolic.pystr_to_symbolic(stride)
stride *= array.shape[idx]
return symbol_mapping

Expand Down Expand Up @@ -318,7 +316,7 @@ def compute_dcir_access_infos(
collect_write=True,
include_full_domain=False,
**kwargs,
) -> dace.properties.DictProperty:
) -> properties.DictProperty:
if block_extents is None:
assert isinstance(oir_node, oir.Stencil)
block_extents = compute_horizontal_block_extents(oir_node)
Expand Down Expand Up @@ -514,7 +512,7 @@ def make_dace_subset(
context_info: dcir.FieldAccessInfo,
access_info: dcir.FieldAccessInfo,
data_dims: Tuple[int, ...],
) -> dace.subsets.Range:
) -> subsets.Range:
clamped_access_info = access_info
clamped_context_info = context_info
for axis in access_info.axes():
Expand All @@ -531,7 +529,7 @@ def make_dace_subset(
].to_dace_symbolic()
res_ranges.append((subset_start - context_start, subset_end - context_start - 1, 1))
res_ranges.extend((0, dim - 1, 1) for dim in data_dims)
return dace.subsets.Range(res_ranges)
return subsets.Range(res_ranges)


def untile_memlets(memlets: Sequence[dcir.Memlet], axes: Sequence[dcir.Axis]) -> List[dcir.Memlet]:
Expand Down
Loading