Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,10 @@ def _parse_sdfg(
return None

with DaCeProgress(config, f"Parse code of {dace_program.name} to SDFG"):
closure = dace_program.__sdfg_closure__()
sdfg = dace_program.to_sdfg(
*args,
**dace_program.__sdfg_closure__(),
**closure,
**kwargs,
save=False,
simplify=False,
Expand Down
59 changes: 37 additions & 22 deletions ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ndsl.logging import ndsl_log
from ndsl.quantity import Quantity
from ndsl.quantity.field_bundle import FieldBundleType, MarkupFieldBundleType
from ndsl.quantity.tracer_bundle_type import MarkupTracerBundleType
from ndsl.testing.comparison import LegacyMetric


Expand Down Expand Up @@ -334,6 +335,11 @@ def __init__(
types.name, do_markup=False
)

if isinstance(types, MarkupTracerBundleType):
raise NotImplementedError(
"TracerBundle markup types can't be resolved yet."
)

self.stencil_object = gtscript.stencil(
definition=func,
externals=externals,
Expand Down Expand Up @@ -379,9 +385,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> None:
ndsl_log.debug(f"Running {self._func_name}")

# Marshal arguments
args_list = list(args)
_convert_quantities_to_storage(args_list, kwargs)
args = tuple(args_list)
args, kwargs = _convert_NDSL_concepts_to_storage(args, kwargs)
args_as_kwargs = dict(zip(self._argument_names, args))

# Ranks comparison tool
Expand Down Expand Up @@ -524,25 +528,36 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t
)


def _convert_quantities_to_storage(args, kwargs): # type: ignore[no-untyped-def]
for i, arg in enumerate(args):
try:
# Check that 'dims' is an attribute of arg. If so,
# this means it's a Quantity, so we need
# to pull off the ndarray.
arg.dims
args[i] = arg.data
except AttributeError:
pass
for name, arg in kwargs.items():
try:
# Check that 'dims' is an attribute of arg. If so,
# this means it's a Quantity, so we need
# to pull off the ndarray.
arg.dims
kwargs[name] = arg.data
except AttributeError:
pass
def _convert_NDSL_concepts_to_storage(args: tuple, kwargs: dict) -> tuple[tuple, dict]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch.

Alternative is to have a ABC class and/or implement a __to_gt4py_storage__ functions that would be tested, like DaCe or numpy does for it's interface systems

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is really just to make it work. I could see something like a GT4PyStorageConvertible base class or so that would expose a function convert the concept to something that GT4Py understands. I'll come back to that once we have a working prototype (including orchestration).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can also be a "refactor later" tasks, it's in the back enough to be "safe"

"""Go through the list of args and kwargs to replace NDSL concepts.

This function replaces NDSL concepts (like a `Quantity`) in the argument list
with things that a GT4Py stencils understands (e.g. ndarrays).
"""

# The `args` tuple is immutable (all tuples are), so let's build a temporary
# argument list, such that we can replace arguments in that list.
arg_list = list(args)

for index, argument in enumerate(args):
# if isinstance(argument, TracerBundle):
# # Reduce the TracerBundle to a Quantity (which is handled below)
# arg_list[index] = argument.data.data

if isinstance(argument, Quantity):
# For Quantities, we need to pass on the underlying ndarray
arg_list[index] = argument.data

for name, argument in kwargs.items():
# if isinstance(argument, TracerBundle):
# # Reduce the TracerBundle to a Quantity (which is handled below)
# kwargs[name] = argument.data

if isinstance(argument, Quantity):
# For Quantities, we need to pass on the underlying ndarray
kwargs[name] = argument.data

return (tuple(arg_list), kwargs)


class GridIndexing:
Expand Down
29 changes: 20 additions & 9 deletions ndsl/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
gt4py_backend: str | None = None,
allow_mismatch_float_precision: bool = False,
number_of_halo_points: int = 0,
raise_on_data_copy: bool = False,
):
"""Initialize a Quantity.

Expand All @@ -54,6 +55,7 @@ def __init__(
allow_mismatch_float_precision (bool, optional): allow for precision that is
not the simulation-wide default configuration. Defaults to False.
number_of_halo_points (int, optional): Number of halo points used. Defaults to 0.
raise_on_data_copy: raise if `data` is copied into this quantity

Raises:
ValueError: Data-type mismatch between configuration and input-data
Expand All @@ -68,10 +70,13 @@ def __init__(
f"Floating-point data type mismatch, asked for {data.dtype}, "
f"Pace configured for {Float}"
)
if origin is None:
origin = (0,) * len(dims) # default origin at origin of array
else:
origin = tuple(origin)

# Track if we copied data and if so, raise at the end in case
# `raise_on_data_copy` is True. The initialization errors on the safe side.
did_copy_data: bool = True

# default origin at origin of array
origin = (0,) * len(dims) if origin is None else tuple(origin)

if extent is None:
extent = tuple(length - start for length, start in zip(data.shape, origin))
Expand Down Expand Up @@ -104,19 +109,20 @@ def __init__(
]
)

self._data = (
data
if is_optimal_layout(data, dimensions)
else self._initialize_data(
if is_optimal_layout(data, dimensions):
self._data = data
did_copy_data = False
else:
self._data = self._initialize_data(
data,
origin=origin,
gt4py_backend=gt4py_backend,
dimensions=dimensions,
)
)
else:
# We have no info about the gt4py_backend, so just assign it.
self._data = data
did_copy_data = False

_validate_quantity_property_lengths(data.shape, dims, origin, extent)
self._metadata = QuantityMetadata(
Expand All @@ -134,6 +140,11 @@ def __init__(
self.data, self.dims, self.origin, self.extent
)

if raise_on_data_copy and did_copy_data:
raise RuntimeError(
"Data was copied into this quantity despite `raise_on_copy_data=True`."
)

@classmethod
def from_data_array(
cls,
Expand Down
194 changes: 194 additions & 0 deletions ndsl/quantity/tracer_bundle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import copy
from enum import Enum, auto
from typing import Any

from dace import SDFG, SDFGState
from dace.data import create_datadescriptor
from dace.frontend.common import op_repository as oprepo
from dace.frontend.python.newast import ProgramVisitor

from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.initialization.allocator import Quantity, QuantityFactory
from ndsl.quantity.tracer_bundle_type import TracerBundleTypeRegistry


@oprepo.replaces_method("ndsl.quantity.tracer_bundle.TracerBundle", "size")
def _tracer_bundle_fill_tracer(
pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args: Any, **kwargs: Any
) -> None:
raise NotImplementedError("let's just see if we get here")


@oprepo.replaces_method("ndsl.quantity.TracerBundle", "size")
def _tracer_bundle_fill_tracer_2(
pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args: Any, **kwargs: Any
) -> None:
raise NotImplementedError("let's just see if we get here 2")


@oprepo.replaces_method("tracers", "size")
def _tracer_bundle_fill_tracer_3(
pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args: Any, **kwargs: Any
) -> None:
raise NotImplementedError("let's just see if we get here 3")


@oprepo.replaces("fill_tracer_by_name")
def _fill_tracer_by_name(
pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args: Any, **kwargs: Any
) -> None:
bundle = args[0]
tracer_name = args[1]
fill_value = args[2]

array_name = f"bundle_{bundle.type_name}"
if array_name not in sdfg.arrays:
sdfg.arrays[array_name] = create_datadescriptor(bundle.data.data)

# insert tasklet to assign the value

# connect tasklet. add missing inputs if necessary

raise NotImplementedError("let's see if we get here")


class Region(Enum):
compute_domain = auto()


class Tracer(Quantity):
"""A Tracer is a specialized Quantity, grouped together in a TracerBundle."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

def fill(self, value: Any, *, restrict_to: Region | None = None) -> None:
if restrict_to is None:
super().data[:] = value
elif restrict_to is Region.compute_domain:
super().field[:] = value
else:
raise NotImplementedError(f"Unknown restriction {restrict_to}.")


_TracerName = str
_TracerIndex = int
_TracerMapping = dict[_TracerName, _TracerIndex]
_TracerDataMapping = dict[_TracerIndex, Tracer]


class TracerBundle:
"""A TracerBundle groups a given set of named/nameless tracers into a single
four-dimensional Quantity.

All tracers can be accessed by index, e.g. `tracer[1]`. Named tracers can be
accessed by name too, e.g. `tracer.vapor` assuming `vapor` is defined in the
`mapping` of names to tracer indices. `len(tracers)` returns the size of this
TracerBundle.
"""

def __init__(
self,
*,
type_name: str,
quantity_factory: QuantityFactory,
mapping: _TracerMapping = {},
unit: str = "g/kg",
) -> None:
"""
Initialize a TracerBundle of a given size.

Args:
type_name (str): name under which this bundle's type is registered.
quantity_factory: QuantityFactory to build tracers with.
mapping: Optional mapping of names to tracer ids, e.g. `{"vapor": 3}`.
unit: Optional unit of the tracers (one for all).
"""
types: Any = TracerBundleTypeRegistry.T(type_name, do_markup=False)

size = types[0].data_dims[0]
factory = _tracer_quantity_factory(quantity_factory, size)

# TODO: zeros() or empty()? should this be an option?
self.data = factory.zeros(
[X_DIM, Y_DIM, Z_DIM, "tracers"], dtype=types[0].dtype, units=unit
)
self._size = size
self._name_mapping = mapping
self._data_mapping: _TracerDataMapping = {}
self.type_name = type_name

def __len__(self) -> int:
"""Number of tracers in this bundle."""
return self._size

def size(self) -> int:
return self._size

def __getattr__(self, name: _TracerName) -> Tracer | None:
"""Access tracers by name, e.g. `tracers.ice`."""

index = self._name_mapping.get(name, None)
if index is None:
# This replicates as close possible the default behavior of getattr
# without breaking orchestration
return None

return self._by_index(index)

def __getitem__(self, index: _TracerIndex) -> Tracer:
"""Access tracers by index, e.g. `tracers[i]`."""
return self._by_index(index)

def _by_index(self, index: _TracerIndex) -> Tracer:
if index < 0 or index >= self._size:
# Note: it is important to raise an IndexError to support iterations of
# the form `for tracer in tracers`.
raise IndexError(f"You can only select tracers in range [0, {self._size}).")

# Memoize tracers accessed such that we always return the same instance
# regardless of whether users access through __getattr__() or __getitem__().
if index not in self._data_mapping:
self._data_mapping[index] = Tracer(
data=self.data.data[:, :, :, index],
dims=self.data.dims[:-1],
origin=self.data.origin[:-1],
extent=self.data.extent[:-1],
units=self.data.units,
# Ensure we never copy data into a tracer
raise_on_data_copy=True,
)

return self._data_mapping[index]

def fill_tracer(
self, index: _TracerIndex, *, value: Any, compute_domain_only: bool = False
) -> None:
if compute_domain_only:
self.data.field[:, :, :, index] = value
else:
self.data.data[:, :, :, index] = value

def fill_tracer_by_name(
self, name: str, *, value: Any, compute_domain_only: bool = False
) -> None:
index = self._name_mapping[name]

if compute_domain_only:
self.data.field[:, :, :, index] = value
else:
self.data.data[:, :, :, index] = value


def _tracer_quantity_factory(
quantity_factory: QuantityFactory, number_of_tracers: int
) -> QuantityFactory:
"""Create a tracer factory from a given cartesian quantity factory.

Args:
quantity_factory: Cartesian 3D factory to start from.
number_of_tracers: number of tracers in this bundle.
"""
tracer_factory = copy.copy(quantity_factory)
tracer_factory.set_extra_dim_lengths(tracers=number_of_tracers)
return tracer_factory
Loading
Loading