diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 43b6c4f3..4d959d15 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -42,6 +42,7 @@ ) from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp +from ndsl.quantity import Quantity, State _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False @@ -543,12 +544,18 @@ def orchestrate( if dace_compiletime_args is None: dace_compiletime_args = [] - func = type.__getattribute__(type(obj), method_to_orchestrate) + func: Callable = type.__getattribute__(type(obj), method_to_orchestrate) # Flag argument as dace.constant for argument in dace_compiletime_args: func.__annotations__[argument] = DaceCompiletime + for arg_name, annotation in func.__annotations__.items(): + if annotation in [Quantity, State] or ( + isinstance(annotation, type) and issubclass(annotation, State) + ): + func.__annotations__[arg_name] = DaceCompiletime + # Build DaCe orchestrated wrapper # This is a JIT object, e.g. DaCe compilation will happen on call wrapped = _LazyComputepathMethod(func, config).__get__(obj) diff --git a/tests/dsl/orchestration/test_call.py b/tests/dsl/orchestration/test_call.py index d0c9f1b0..2ff2c16a 100644 --- a/tests/dsl/orchestration/test_call.py +++ b/tests/dsl/orchestration/test_call.py @@ -1,8 +1,11 @@ -from ndsl import QuantityFactory, StencilFactory +import dataclasses + +from ndsl import NDSLRuntime, QuantityFactory, StencilFactory from ndsl.boilerplate import get_factories_single_tile_orchestrated -from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval +from ndsl.quantity import Quantity, State def _stencil(out: Field[float]): @@ -40,3 +43,40 @@ def test_memory_reallocation(): code(qty_B) assert (qty_A.field[0, 0, :] == 3).all() assert (qty_B.field[0, 0, :] == 2).all() + + +@dataclasses.dataclass +class AState(State): + the_quantity: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +class DefaultTypeProgram(NDSLRuntime): + def __init__( + self, + stencil_factory: StencilFactory, + quantity_factory: QuantityFactory, + ): + super().__init__(stencil_factory.config.dace_config) + self.stencil = stencil_factory.from_dims_halo(_stencil, [X_DIM, Y_DIM, Z_DIM]) + + def __call__(self, a_quantity: Quantity, a_state: AState): + self.stencil(a_quantity) + self.stencil(a_state.the_quantity) + + +def test_default_types_are_compiletime(): + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + 5, 5, 2, 0 + ) + qty_A = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "A") + state_A = AState.zeros(quantity_factory) + code = DefaultTypeProgram(stencil_factory, quantity_factory) + code(qty_A, state_A)