Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from ndsl.logging import ndsl_log
from ndsl.optional_imports import cupy as cp
from ndsl.quantity import Quantity, State


_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False
Expand Down Expand Up @@ -543,12 +544,18 @@ def orchestrate(
if dace_compiletime_args is None:
dace_compiletime_args = []

func = type.__getattribute__(type(obj), method_to_orchestrate)
func: Callable = type.__getattribute__(type(obj), method_to_orchestrate)

# Flag argument as dace.constant
for argument in dace_compiletime_args:
func.__annotations__[argument] = DaceCompiletime

for arg_name, annotation in func.__annotations__.items():
if annotation in [Quantity, State] or (
isinstance(annotation, type) and issubclass(annotation, State)
):
func.__annotations__[arg_name] = DaceCompiletime

# Build DaCe orchestrated wrapper
# This is a JIT object, e.g. DaCe compilation will happen on call
wrapped = _LazyComputepathMethod(func, config).__get__(obj)
Expand Down
44 changes: 42 additions & 2 deletions tests/dsl/orchestration/test_call.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from ndsl import QuantityFactory, StencilFactory
import dataclasses

from ndsl import NDSLRuntime, QuantityFactory, StencilFactory
from ndsl.boilerplate import get_factories_single_tile_orchestrated
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float
from ndsl.dsl.dace.orchestration import orchestrate
from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval
from ndsl.quantity import Quantity, State


def _stencil(out: Field[float]):
Expand Down Expand Up @@ -40,3 +43,40 @@ def test_memory_reallocation():
code(qty_B)
assert (qty_A.field[0, 0, :] == 3).all()
assert (qty_B.field[0, 0, :] == 2).all()


@dataclasses.dataclass
class AState(State):
the_quantity: Quantity = dataclasses.field(
metadata={
"name": "A",
"dims": [X_DIM, Y_DIM, Z_DIM],
"units": "kg kg-1",
"intent": "?",
"dtype": Float,
}
)


class DefaultTypeProgram(NDSLRuntime):
def __init__(
self,
stencil_factory: StencilFactory,
quantity_factory: QuantityFactory,
):
super().__init__(stencil_factory.config.dace_config)
self.stencil = stencil_factory.from_dims_halo(_stencil, [X_DIM, Y_DIM, Z_DIM])

def __call__(self, a_quantity: Quantity, a_state: AState):
self.stencil(a_quantity)
self.stencil(a_state.the_quantity)


def test_default_types_are_compiletime():
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
5, 5, 2, 0
)
qty_A = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "A")
state_A = AState.zeros(quantity_factory)
code = DefaultTypeProgram(stencil_factory, quantity_factory)
code(qty_A, state_A)