From 18c9d336750345efbef308d089a2b763ffe25cdd Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 15 Oct 2025 13:40:41 -0400 Subject: [PATCH 01/21] Temporaries base dataclass Allow `units` to not be specified + unit test --- ndsl/__init__.py | 3 +- ndsl/quantity/__init__.py | 8 +-- ndsl/quantity/quantity.py | 10 ++- ndsl/quantity/state.py | 6 +- ndsl/quantity/temporaries.py | 24 +++++++ tests/quantity/test_temporaries.py | 102 +++++++++++++++++++++++++++++ 6 files changed, 144 insertions(+), 9 deletions(-) create mode 100644 ndsl/quantity/temporaries.py create mode 100644 tests/quantity/test_temporaries.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 097c5dca..6f37e548 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -27,7 +27,7 @@ from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .quantity import Quantity, State +from .quantity import Quantity, State, Temporaries from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .testing.dummy_comm import DummyComm from .types import Allocator @@ -86,4 +86,5 @@ "Allocator", "MetaEnumStr", "State", + "Temporaries", ] diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py index ee7e4d78..2f1d2d70 100644 --- a/ndsl/quantity/__init__.py +++ b/ndsl/quantity/__init__.py @@ -1,11 +1,7 @@ from .metadata import QuantityHaloSpec, QuantityMetadata from .quantity import Quantity from .state import State +from .temporaries import Temporaries -__all__ = [ - "Quantity", - "QuantityMetadata", - "QuantityHaloSpec", - "State", -] +__all__ = ["Quantity", "QuantityMetadata", "QuantityHaloSpec", "State", "Temporaries"] diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index a0ba12be..006b5f4f 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -35,6 +35,8 @@ def __init__( extent: Optional[Sequence[int]] = None, gt4py_backend: Union[str, None] = None, allow_mismatch_float_precision: bool = False, + *, + transient: bool = False, ): """ Initialize a Quantity. @@ -49,6 +51,8 @@ def __init__( be derived from a Storage if given as the data argument, otherwise the storage attribute is disabled and will raise an exception. Will raise a TypeError if this is given with a gt4py storage type as data + transient: [DEV ONLY] Flag the quantity for Transient use within DaCe + orchestration """ if ( @@ -70,6 +74,8 @@ def __init__( else: extent = tuple(extent) + self.transient = transient + if isinstance(data, (int, float, list)): # If converting basic data, use a numpy ndarray. data = np.asarray(data) @@ -322,7 +328,9 @@ def __descriptor__(self) -> Any: If the internal data given doesn't follow the protocol it will most likely fail. """ - return dace.data.create_datadescriptor(self.data) + desc = dace.data.create_datadescriptor(self.data) + desc.transient = self.transient + return desc def transpose( self, diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py index 035b2f6f..c99f6ad9 100644 --- a/ndsl/quantity/state.py +++ b/ndsl/quantity/state.py @@ -58,7 +58,11 @@ def _init_recursive(cls): initial_quantities[_field.name] = quantity_factory_allocator( _field.metadata["dims"], - _field.metadata["units"], + ( + _field.metadata["units"] + if "units" in _field.metadata.keys() + else "unspecified" + ), dtype=_field.metadata["dtype"], allow_mismatch_float_precision=True, ) diff --git a/ndsl/quantity/temporaries.py b/ndsl/quantity/temporaries.py new file mode 100644 index 00000000..d8845fd1 --- /dev/null +++ b/ndsl/quantity/temporaries.py @@ -0,0 +1,24 @@ +import dataclasses + +from ndsl.quantity import Quantity, State + + +@dataclasses.dataclass +class Temporaries(State): + """Base class to collect temporaries Quantities. + + You _cannot_ expect the temporaries memory to be available outside of + the class it has been defined in. + + Shares the `ndsl.quantity.State` API, see `State` docs. + """ + + def __post_init__(self): + def _post_init_recursive(dataclass: Temporaries): + for _field in dataclasses.fields(dataclass): + if dataclasses.is_dataclass(_field.type): + _post_init_recursive(dataclass.__getattribute__(_field.name)) + elif _field.type == Quantity: + dataclass.__getattribute__(_field.name).transient = True + + _post_init_recursive(self) diff --git a/tests/quantity/test_temporaries.py b/tests/quantity/test_temporaries.py new file mode 100644 index 00000000..5899dfbb --- /dev/null +++ b/tests/quantity/test_temporaries.py @@ -0,0 +1,102 @@ +import dataclasses + +from ndsl import ( + Quantity, + State, + StencilFactory, + Temporaries, + orchestrate, + QuantityFactory, +) +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +@dataclasses.dataclass +class CodeTmps(Temporaries): + @dataclasses.dataclass + class Inner: + TmpA: Quantity = dataclasses.field( + metadata={ + "dims": [X_DIM, Y_DIM, Z_DIM], + "dtype": Float, + } + ) + + inner: Inner + TmpC: Quantity = dataclasses.field( + metadata={ + "dims": [X_DIM, Y_DIM, Z_DIM], + "dtype": Float, + } + ) + + +@dataclasses.dataclass +class CodeState(State): + @dataclasses.dataclass + class Inner: + A: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + inner: Inner + C: Quantity = dataclasses.field( + metadata={ + "name": "C", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +def the_copy_stencil(from_: FloatField, to: FloatField): + with computation(PARALLEL), interval(...): + to = from_ + + +class Code: + def __init__( + self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory + ) -> None: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + dace_compiletime_args=["state", "tmps"], + ) + + self.copy = stencil_factory.from_dims_halo( + the_copy_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM] + ) + self.tmps = CodeTmps.zeros(quantity_factory) + + def __call__(self, state: CodeState): + self.copy(state.inner.A, self.tmps.inner.TmpA) + self.copy(self.tmps.inner.TmpA, state.C) + + +def test_temporaries(): + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + + state = CodeState.full(quantity_factory, 42.42) + + c = Code(stencil_factory, quantity_factory) + c(state) + + assert c.tmps.inner.TmpA.transient + assert c.tmps.TmpC.transient + assert not state.inner.A.transient + assert not state.C.transient + assert (state.inner.A.data[:] == state.C.data[:]).all() From 1a241f0f21dfddf303245b9234baf1fb69a817c2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 15 Oct 2025 13:48:54 -0400 Subject: [PATCH 02/21] Lint --- tests/quantity/test_temporaries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/quantity/test_temporaries.py b/tests/quantity/test_temporaries.py index 5899dfbb..4ee5e110 100644 --- a/tests/quantity/test_temporaries.py +++ b/tests/quantity/test_temporaries.py @@ -2,11 +2,11 @@ from ndsl import ( Quantity, + QuantityFactory, State, StencilFactory, Temporaries, orchestrate, - QuantityFactory, ) from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float From 207964c16fbb352a4ebe35d4e818c981541365b3 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 16 Oct 2025 14:33:00 -0400 Subject: [PATCH 03/21] Hide `_transient` flag of Quantity away --- ndsl/quantity/quantity.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 006b5f4f..03433b29 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -35,8 +35,6 @@ def __init__( extent: Optional[Sequence[int]] = None, gt4py_backend: Union[str, None] = None, allow_mismatch_float_precision: bool = False, - *, - transient: bool = False, ): """ Initialize a Quantity. @@ -51,8 +49,6 @@ def __init__( be derived from a Storage if given as the data argument, otherwise the storage attribute is disabled and will raise an exception. Will raise a TypeError if this is given with a gt4py storage type as data - transient: [DEV ONLY] Flag the quantity for Transient use within DaCe - orchestration """ if ( @@ -74,7 +70,10 @@ def __init__( else: extent = tuple(extent) - self.transient = transient + # Dev note: this is a hidden flag to be able to give DaCe the context + # of use for further optimization. It's the equivqlent to `Local` + # in NDSL lingo + self._transient = False if isinstance(data, (int, float, list)): # If converting basic data, use a numpy ndarray. @@ -329,7 +328,7 @@ def __descriptor__(self) -> Any: fail. """ desc = dace.data.create_datadescriptor(self.data) - desc.transient = self.transient + desc.transient = self._transient return desc def transpose( From 63600319ed8c14a3e00a3f0a50ae5516c1d71adf Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 16 Oct 2025 14:34:12 -0400 Subject: [PATCH 04/21] `QuantityFactory` has now a `is_local` option --- ndsl/initialization/allocator.py | 48 ++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index a4b55b7f..b1db48fc 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -76,13 +76,20 @@ def empty( dims: Sequence[str], units: str, dtype: type = Float, + *, allow_mismatch_float_precision: bool = False, + is_local: bool = False, ) -> Quantity: """Allocate a Quantity - values are random. Equivalent to `numpy.empty`""" return self._allocate( - self._numpy.empty, dims, units, dtype, allow_mismatch_float_precision + self._numpy.empty, + dims, + units, + dtype, + allow_mismatch_float_precision, + is_local, ) def zeros( @@ -90,13 +97,20 @@ def zeros( dims: Sequence[str], units: str, dtype: type = Float, + *, allow_mismatch_float_precision: bool = False, + is_local: bool = False, ) -> Quantity: """Allocate a Quantity and fill it with the value 0. Equivalent to `numpy.zeros`""" return self._allocate( - self._numpy.zeros, dims, units, dtype, allow_mismatch_float_precision + self._numpy.zeros, + dims, + units, + dtype, + allow_mismatch_float_precision, + is_local, ) def ones( @@ -104,13 +118,20 @@ def ones( dims: Sequence[str], units: str, dtype: type = Float, + *, allow_mismatch_float_precision: bool = False, + is_local: bool = False, ) -> Quantity: """Allocate a Quantity and fill it with the value 1. Equivalent to `numpy.ones`""" return self._allocate( - self._numpy.ones, dims, units, dtype, allow_mismatch_float_precision + self._numpy.ones, + dims, + units, + dtype, + allow_mismatch_float_precision, + is_local, ) def full( @@ -119,13 +140,20 @@ def full( units: str, value, # no type hint because it would be a TypeVar = Type[dtype] and mypy says no dtype: type = Float, + *, allow_mismatch_float_precision: bool = False, + is_local: bool = False, ) -> Quantity: """Allocate a Quantity and fill it with the value. Equivalent to `numpy.full`""" quantity = self._allocate( - self._numpy.empty, dims, units, dtype, allow_mismatch_float_precision + self._numpy.empty, + dims, + units, + dtype, + allow_mismatch_float_precision, + is_local, ) quantity.data[:] = value return quantity @@ -135,7 +163,9 @@ def from_array( data: np.ndarray, dims: Sequence[str], units: str, + *, allow_mismatch_float_precision: bool = False, + is_local: bool = False, ) -> Quantity: """ Create a Quantity from a numpy array. @@ -148,8 +178,8 @@ def from_array( units=units, dtype=data.dtype, allow_mismatch_float_precision=allow_mismatch_float_precision, + is_local=is_local, ) - base.data[:] = base.np.asarray(data) return base def from_compute_array( @@ -157,7 +187,9 @@ def from_compute_array( data: np.ndarray, dims: Sequence[str], units: str, + *, allow_mismatch_float_precision: bool = False, + is_local: bool = False, ) -> Quantity: """ Create a Quantity from a numpy array. @@ -170,6 +202,7 @@ def from_compute_array( units=units, dtype=data.dtype, allow_mismatch_float_precision=allow_mismatch_float_precision, + is_local=is_local, ) base.view[:] = base.np.asarray(data) return base @@ -181,6 +214,7 @@ def _allocate( units: str, dtype: type = Float, allow_mismatch_float_precision: bool = False, + is_local: bool = False, ) -> Quantity: origin = self.sizer.get_origin(dims) extent = self.sizer.get_extent(dims) @@ -201,7 +235,7 @@ def _allocate( ) except TypeError: data = allocator(shape, dtype=dtype) - return Quantity( + quantity = Quantity( data, dims=dims, units=units, @@ -210,6 +244,8 @@ def _allocate( gt4py_backend=self._backend(), allow_mismatch_float_precision=allow_mismatch_float_precision, ) + quantity._transient = is_local + return quantity def get_quantity_halo_spec( self, From fab94909cec9ffd7fc8d4a1c05c0dc47d1f4e49f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 16 Oct 2025 14:36:10 -0400 Subject: [PATCH 05/21] Update wording to `Local`, trash `Temporaries` state idea --- ndsl/quantity/temporaries.py | 24 ------------ .../{test_temporaries.py => test_locals.py} | 38 +++++-------------- 2 files changed, 9 insertions(+), 53 deletions(-) delete mode 100644 ndsl/quantity/temporaries.py rename tests/quantity/{test_temporaries.py => test_locals.py} (70%) diff --git a/ndsl/quantity/temporaries.py b/ndsl/quantity/temporaries.py deleted file mode 100644 index d8845fd1..00000000 --- a/ndsl/quantity/temporaries.py +++ /dev/null @@ -1,24 +0,0 @@ -import dataclasses - -from ndsl.quantity import Quantity, State - - -@dataclasses.dataclass -class Temporaries(State): - """Base class to collect temporaries Quantities. - - You _cannot_ expect the temporaries memory to be available outside of - the class it has been defined in. - - Shares the `ndsl.quantity.State` API, see `State` docs. - """ - - def __post_init__(self): - def _post_init_recursive(dataclass: Temporaries): - for _field in dataclasses.fields(dataclass): - if dataclasses.is_dataclass(_field.type): - _post_init_recursive(dataclass.__getattribute__(_field.name)) - elif _field.type == Quantity: - dataclass.__getattribute__(_field.name).transient = True - - _post_init_recursive(self) diff --git a/tests/quantity/test_temporaries.py b/tests/quantity/test_locals.py similarity index 70% rename from tests/quantity/test_temporaries.py rename to tests/quantity/test_locals.py index 4ee5e110..57f6e578 100644 --- a/tests/quantity/test_temporaries.py +++ b/tests/quantity/test_locals.py @@ -5,7 +5,6 @@ QuantityFactory, State, StencilFactory, - Temporaries, orchestrate, ) from ndsl.boilerplate import get_factories_single_tile_orchestrated @@ -14,26 +13,6 @@ from ndsl.dsl.typing import FloatField -@dataclasses.dataclass -class CodeTmps(Temporaries): - @dataclasses.dataclass - class Inner: - TmpA: Quantity = dataclasses.field( - metadata={ - "dims": [X_DIM, Y_DIM, Z_DIM], - "dtype": Float, - } - ) - - inner: Inner - TmpC: Quantity = dataclasses.field( - metadata={ - "dims": [X_DIM, Y_DIM, Z_DIM], - "dtype": Float, - } - ) - - @dataclasses.dataclass class CodeState(State): @dataclasses.dataclass @@ -78,14 +57,16 @@ def __init__( self.copy = stencil_factory.from_dims_halo( the_copy_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM] ) - self.tmps = CodeTmps.zeros(quantity_factory) + self.local = quantity_factory.empty( + [X_DIM, Y_DIM, Z_DIM], units="n/a", is_local=True + ) def __call__(self, state: CodeState): - self.copy(state.inner.A, self.tmps.inner.TmpA) - self.copy(self.tmps.inner.TmpA, state.C) + self.copy(state.inner.A, self.local) + self.copy(self.local, state.C) -def test_temporaries(): +def test_locals(): stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( 5, 5, 3, 0, backend="dace:cpu_kfirst" ) @@ -95,8 +76,7 @@ def test_temporaries(): c = Code(stencil_factory, quantity_factory) c(state) - assert c.tmps.inner.TmpA.transient - assert c.tmps.TmpC.transient - assert not state.inner.A.transient - assert not state.C.transient + assert c.local._transient + assert not state.inner.A._transient + assert not state.C._transient assert (state.inner.A.data[:] == state.C.data[:]).all() From d0f36b23a064878aff706586d1cc40a97c8d402c Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 16 Oct 2025 14:37:28 -0400 Subject: [PATCH 06/21] lint --- tests/quantity/test_locals.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/quantity/test_locals.py b/tests/quantity/test_locals.py index 57f6e578..4a452147 100644 --- a/tests/quantity/test_locals.py +++ b/tests/quantity/test_locals.py @@ -1,12 +1,6 @@ import dataclasses -from ndsl import ( - Quantity, - QuantityFactory, - State, - StencilFactory, - orchestrate, -) +from ndsl import Quantity, QuantityFactory, State, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float from ndsl.dsl.gt4py import PARALLEL, computation, interval From c86f87e8168230f402030329bad742394a3033d4 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 16 Oct 2025 14:46:06 -0400 Subject: [PATCH 07/21] Public API clean up --- ndsl/__init__.py | 3 +-- ndsl/quantity/__init__.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 6f37e548..097c5dca 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -27,7 +27,7 @@ from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .quantity import Quantity, State, Temporaries +from .quantity import Quantity, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .testing.dummy_comm import DummyComm from .types import Allocator @@ -86,5 +86,4 @@ "Allocator", "MetaEnumStr", "State", - "Temporaries", ] diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py index 2f1d2d70..e6b6fe47 100644 --- a/ndsl/quantity/__init__.py +++ b/ndsl/quantity/__init__.py @@ -1,7 +1,6 @@ from .metadata import QuantityHaloSpec, QuantityMetadata from .quantity import Quantity from .state import State -from .temporaries import Temporaries -__all__ = ["Quantity", "QuantityMetadata", "QuantityHaloSpec", "State", "Temporaries"] +__all__ = ["Quantity", "QuantityMetadata", "QuantityHaloSpec", "State"] From 2351c17f1f8f3c69076ac5fd7908283dc733e6d4 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 16 Oct 2025 15:08:34 -0400 Subject: [PATCH 08/21] Simplify and fix test --- tests/quantity/test_locals.py | 51 +++++++++-------------------------- 1 file changed, 12 insertions(+), 39 deletions(-) diff --git a/tests/quantity/test_locals.py b/tests/quantity/test_locals.py index 4a452147..163da731 100644 --- a/tests/quantity/test_locals.py +++ b/tests/quantity/test_locals.py @@ -1,38 +1,10 @@ -import dataclasses - -from ndsl import Quantity, QuantityFactory, State, StencilFactory, orchestrate +from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated -from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float +from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField -@dataclasses.dataclass -class CodeState(State): - @dataclasses.dataclass - class Inner: - A: Quantity = dataclasses.field( - metadata={ - "name": "A", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg kg-1", - "intent": "?", - "dtype": Float, - } - ) - - inner: Inner - C: Quantity = dataclasses.field( - metadata={ - "name": "C", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg kg-1", - "intent": "?", - "dtype": Float, - } - ) - - def the_copy_stencil(from_: FloatField, to: FloatField): with computation(PARALLEL), interval(...): to = from_ @@ -45,7 +17,7 @@ def __init__( orchestrate( obj=self, config=stencil_factory.config.dace_config, - dace_compiletime_args=["state", "tmps"], + dace_compiletime_args=["A", "B"], ) self.copy = stencil_factory.from_dims_halo( @@ -55,9 +27,9 @@ def __init__( [X_DIM, Y_DIM, Z_DIM], units="n/a", is_local=True ) - def __call__(self, state: CodeState): - self.copy(state.inner.A, self.local) - self.copy(self.local, state.C) + def __call__(self, A, B): + self.copy(A, self.local) + self.copy(self.local, B) def test_locals(): @@ -65,12 +37,13 @@ def test_locals(): 5, 5, 3, 0, backend="dace:cpu_kfirst" ) - state = CodeState.full(quantity_factory, 42.42) + A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") c = Code(stencil_factory, quantity_factory) - c(state) + c(A_, B_) assert c.local._transient - assert not state.inner.A._transient - assert not state.C._transient - assert (state.inner.A.data[:] == state.C.data[:]).all() + assert not A_._transient + assert not B_._transient + assert (A_.field[:] == B_.field[:]).all() From 0bd632a44d3c82e576ceec95fc9afede0aeb1d4c Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 16 Oct 2025 15:25:17 -0400 Subject: [PATCH 09/21] Oops, restore code for `from_array` allocator --- ndsl/initialization/allocator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index b1db48fc..e10634e3 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -180,6 +180,7 @@ def from_array( allow_mismatch_float_precision=allow_mismatch_float_precision, is_local=is_local, ) + base.data[:] = base.np.asarray(data) return base def from_compute_array( From b3eaea16e00618d9c122f95ec26a68c4b0113080 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 15:35:52 -0400 Subject: [PATCH 10/21] Remove keyword allocation --- ndsl/initialization/allocator.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 06be0b90..afff527b 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -79,7 +79,6 @@ def empty( dtype: type = Float, *, allow_mismatch_float_precision: bool = False, - is_local: bool = False, ) -> Quantity: """Allocate a Quantity - values are random. @@ -90,7 +89,6 @@ def empty( units, dtype, allow_mismatch_float_precision, - is_local, ) def zeros( @@ -100,7 +98,6 @@ def zeros( dtype: type = Float, *, allow_mismatch_float_precision: bool = False, - is_local: bool = False, ) -> Quantity: """Allocate a Quantity and fill it with the value 0. @@ -111,7 +108,6 @@ def zeros( units, dtype, allow_mismatch_float_precision, - is_local, ) def ones( @@ -121,7 +117,6 @@ def ones( dtype: type = Float, *, allow_mismatch_float_precision: bool = False, - is_local: bool = False, ) -> Quantity: """Allocate a Quantity and fill it with the value 1. @@ -132,7 +127,6 @@ def ones( units, dtype, allow_mismatch_float_precision, - is_local, ) def full( @@ -143,7 +137,6 @@ def full( dtype: type = Float, *, allow_mismatch_float_precision: bool = False, - is_local: bool = False, ) -> Quantity: """Allocate a Quantity and fill it with the value. @@ -154,7 +147,6 @@ def full( units, dtype, allow_mismatch_float_precision, - is_local, ) quantity.data[:] = value return quantity @@ -166,7 +158,6 @@ def from_array( units: str, *, allow_mismatch_float_precision: bool = False, - is_local: bool = False, ) -> Quantity: """ Create a Quantity from a numpy array. @@ -179,7 +170,6 @@ def from_array( units=units, dtype=data.dtype, allow_mismatch_float_precision=allow_mismatch_float_precision, - is_local=is_local, ) base.data[:] = base.np.asarray(data) return base @@ -191,7 +181,6 @@ def from_compute_array( units: str, *, allow_mismatch_float_precision: bool = False, - is_local: bool = False, ) -> Quantity: """ Create a Quantity from a numpy array. @@ -204,7 +193,6 @@ def from_compute_array( units=units, dtype=data.dtype, allow_mismatch_float_precision=allow_mismatch_float_precision, - is_local=is_local, ) base.view[:] = base.np.asarray(data) return base @@ -216,7 +204,6 @@ def _allocate( units: str, dtype: type = Float, allow_mismatch_float_precision: bool = False, - is_local: bool = False, ) -> Quantity: origin = self.sizer.get_origin(dims) extent = self.sizer.get_extent(dims) @@ -246,7 +233,6 @@ def _allocate( gt4py_backend=self._backend(), allow_mismatch_float_precision=allow_mismatch_float_precision, ) - quantity._transient = is_local return quantity def get_quantity_halo_spec( From e5d8abd1844460823eb830570bad04814b1ccc44 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 16:21:58 -0400 Subject: [PATCH 11/21] Introduce `Local` and `NDSLRuntime` --- ndsl/__init__.py | 5 +- ndsl/dsl/dace/__init__.py | 5 ++ ndsl/dsl/ndsl_runtime.py | 115 ++++++++++++++++++++++++++++++++++ ndsl/quantity/__init__.py | 5 +- ndsl/quantity/local.py | 30 +++++++++ tests/quantity/test_locals.py | 37 ++++++----- 6 files changed, 180 insertions(+), 17 deletions(-) create mode 100644 ndsl/dsl/ndsl_runtime.py create mode 100644 ndsl/quantity/local.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 097c5dca..aa275711 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -15,6 +15,7 @@ StorageReport, ) from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater +from .dsl.ndsl_runtime import NDSLRuntime from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from .exceptions import OutOfBoundsError @@ -27,7 +28,7 @@ from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .quantity import Quantity, State +from .quantity import Local, Quantity, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .testing.dummy_comm import DummyComm from .types import Allocator @@ -86,4 +87,6 @@ "Allocator", "MetaEnumStr", "State", + "NDSLRuntime", + "Local", ] diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index e69de29b..597c2a31 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -0,0 +1,5 @@ +from .dace_config import DaceConfig +from .orchestration import orchestrate, orchestrate_function + + +__all__ = ["DaceConfig", "orchestrate", "orchestrate_function"] diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py new file mode 100644 index 00000000..5038d455 --- /dev/null +++ b/ndsl/dsl/ndsl_runtime.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from ndsl.quantity import Quantity, Local +from ndsl.initialization.allocator import QuantityFactory +from ndsl.dsl.dace import orchestrate, DaceConfig +from ndsl.dsl.typing import Float +import warnings +import inspect +from typing import Any, Callable + +_TOP_LEVEL: object | None = None + + +class NDSLRuntime: + """Base class to tool runtime code, allows use of Locals, orchestration and + debug tools. + + The __call__ function will automatically orchestrated.""" + + def __init__(self, dace_config: DaceConfig) -> None: + self._dace_config = dace_config + + def __init_subclass__(cls: type[NDSLRuntime], **kwargs: dict[str, Any]) -> None: + # WARNING: no code outside the `init_decorator` this is cls + # function, it will be called ONLY ONCE for monkey-patching the + # Class - not the instance ! + + def init_decorator(previous_init: Callable) -> Callable: + def new_init( + self: NDSLRuntime, + *args: list[Any], + **kwargs: dict[str, Any], + ) -> None: + global _TOP_LEVEL + if _TOP_LEVEL is None: + _TOP_LEVEL = self + previous_init(self, *args, **kwargs) + self.__post_init__() + + return new_init + + cls.__init__ = init_decorator(cls.__init__) # type: ignore[method-assign] + + def __post_init__(self) -> None: + # Check quantity allocation of NDSLRuntime supervised code + if _TOP_LEVEL == self: + + def check_for_quantity(object_: object) -> None: + for key, value in object_.__dict__.items(): + if isinstance(value, Quantity) and not isinstance(value, Local): + warnings.warn( + f"{type(self).__name__}.{key} is a Quantity instead of a Locals" + " on a NDSLRuntime - our eyebrows are frowned." + ) + elif isinstance(value, NDSLRuntime): + check_for_quantity(value) + + check_for_quantity(self) + + # Orchestrate __call__ by default + orchestrate( + obj=self, + config=self._dace_config, + ) + + def __getattribute__(self, name: str) -> Any: + attr = super().__getattribute__(name) + # We look at the direct caller frame for our own `self` + # in the locals. + # Any other case are forbidden. + if isinstance(attr, Local): + frame = inspect.currentframe() + if frame is None: + raise NotImplementedError( + "Locals check cannot locate frame. Talk to the team." + ) + caller_frame = frame.f_back + if ( + not caller_frame + or "self" not in caller_frame.f_locals + or not isinstance(caller_frame.f_locals["self"], type(self)) + ): + # We expect the original class to have been monkey-patched + # See `dace.dsl.orchestration.orchestrate` + unpatched_name = self.__name__[: -len("_patched")] + raise RuntimeError( + f"Forbidden Local access: {name} called outside of {unpatched_name}." + ) + + return attr + + def make_local( + self, + quantity_factory: QuantityFactory, + dims: list[str], + dtype: type = Float, + units: str = "unspecified", + *, + allow_mismatch_float_precision: bool = False, + ) -> Local: + quantity = quantity_factory.zeros( + dims, + units, + dtype, + allow_mismatch_float_precision=allow_mismatch_float_precision, + ) + return Local( + data=quantity.data, + dims=quantity.dims, + units=quantity.units, + origin=quantity.origin, + extent=quantity.extent, + gt4py_backend=quantity.gt4py_backend, + allow_mismatch_float_precision=allow_mismatch_float_precision, + ) diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py index e6b6fe47..26120596 100644 --- a/ndsl/quantity/__init__.py +++ b/ndsl/quantity/__init__.py @@ -3,4 +3,7 @@ from .state import State -__all__ = ["Quantity", "QuantityMetadata", "QuantityHaloSpec", "State"] +from .local import Local # isort: skip + + +__all__ = ["Local", "Quantity", "QuantityMetadata", "QuantityHaloSpec", "State"] diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py new file mode 100644 index 00000000..04eb6d6a --- /dev/null +++ b/ndsl/quantity/local.py @@ -0,0 +1,30 @@ +from ndsl.quantity import Quantity +import numpy as np +from ndsl.optional_imports import cupy +from typing import Sequence + + +class Local(Quantity): + """Local is a Quantity that cannot be used outside of the class + it was allocated in.""" + + def __init__( + self, + data: np.ndarray | cupy.ndarray, + dims: Sequence[str], + units: str, + origin: Sequence[int] | None = None, + extent: Sequence[int] | None = None, + gt4py_backend: str | None = None, + allow_mismatch_float_precision: bool = False, + ): + super().__init__( + data, + dims, + units, + origin, + extent, + gt4py_backend, + allow_mismatch_float_precision, + ) + self._transient = True diff --git a/tests/quantity/test_locals.py b/tests/quantity/test_locals.py index 163da731..d1c11aa6 100644 --- a/tests/quantity/test_locals.py +++ b/tests/quantity/test_locals.py @@ -1,4 +1,6 @@ -from ndsl import QuantityFactory, StencilFactory, orchestrate +import pytest + +from ndsl import NDSLRuntime, QuantityFactory, StencilFactory from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval @@ -10,29 +12,25 @@ def the_copy_stencil(from_: FloatField, to: FloatField): to = from_ -class Code: +class Code(NDSLRuntime): def __init__( self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory ) -> None: - orchestrate( - obj=self, - config=stencil_factory.config.dace_config, - dace_compiletime_args=["A", "B"], - ) - + super().__init__(dace_config=stencil_factory.config.dace_config) self.copy = stencil_factory.from_dims_halo( the_copy_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM] ) - self.local = quantity_factory.empty( - [X_DIM, Y_DIM, Z_DIM], units="n/a", is_local=True - ) + self.local = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) + + def test_check(self): + assert self.local._transient def __call__(self, A, B): self.copy(A, self.local) self.copy(self.local, B) -def test_locals(): +def test_local_and_transient_flags(): stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( 5, 5, 3, 0, backend="dace:cpu_kfirst" ) @@ -40,10 +38,19 @@ def test_locals(): A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") - c = Code(stencil_factory, quantity_factory) - c(A_, B_) + code = Code(stencil_factory, quantity_factory) + code(A_, B_) + + # Check that local is not reachable outside of Code + with pytest.raises(RuntimeError, match="Forbidden Local access:"): + assert code.local._transient - assert c.local._transient + # Check the local is properly transient - with access in Code + code.test_check() + + # Check regular quantity are not transient assert not A_._transient assert not B_._transient + + # Check numerics assert (A_.field[:] == B_.field[:]).all() From f4a03d0833869e2f019c06b7495aeed8f66cbdd6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 16:23:08 -0400 Subject: [PATCH 12/21] Lint new files --- ndsl/dsl/ndsl_runtime.py | 12 +++++++----- ndsl/quantity/local.py | 6 ++++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 5038d455..b08019f3 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -1,13 +1,15 @@ from __future__ import annotations -from ndsl.quantity import Quantity, Local -from ndsl.initialization.allocator import QuantityFactory -from ndsl.dsl.dace import orchestrate, DaceConfig -from ndsl.dsl.typing import Float -import warnings import inspect +import warnings from typing import Any, Callable +from ndsl.dsl.dace import DaceConfig, orchestrate +from ndsl.dsl.typing import Float +from ndsl.initialization.allocator import QuantityFactory +from ndsl.quantity import Local, Quantity + + _TOP_LEVEL: object | None = None diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index 04eb6d6a..c492f9e1 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -1,7 +1,9 @@ -from ndsl.quantity import Quantity +from typing import Sequence + import numpy as np + from ndsl.optional_imports import cupy -from typing import Sequence +from ndsl.quantity import Quantity class Local(Quantity): From 2f79ef766041a9c021dc5551582505d13c4f8afb Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 16:28:02 -0400 Subject: [PATCH 13/21] Remove the odd `_transient` and tag transientness correctly in `Local` --- ndsl/quantity/local.py | 9 ++++++++- ndsl/quantity/quantity.py | 9 +-------- tests/quantity/test_locals.py | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index c492f9e1..214fe977 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -1,5 +1,6 @@ -from typing import Sequence +from typing import Any, Sequence +import dace import numpy as np from ndsl.optional_imports import cupy @@ -30,3 +31,9 @@ def __init__( allow_mismatch_float_precision, ) self._transient = True + + def __descriptor__(self) -> Any: + """Locals uses `Quantity.__descriptor__` and flag itself as transient.""" + data = dace.data.create_datadescriptor(self.data) + data.transient = True + return data diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 8359240d..6eba3572 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -71,11 +71,6 @@ def __init__( else: extent = tuple(extent) - # Dev note: this is a hidden flag to be able to give DaCe the context - # of use for further optimization. It's the equivqlent to `Local` - # in NDSL lingo - self._transient = False - if isinstance(data, (int, float, list)): # If converting basic data, use a numpy ndarray. data = np.asarray(data) @@ -330,9 +325,7 @@ def __descriptor__(self) -> Any: If the internal data given doesn't follow the protocol it will most likely fail. """ - desc = dace.data.create_datadescriptor(self.data) - desc.transient = self._transient - return desc + return dace.data.create_datadescriptor(self.data) def transpose( self, diff --git a/tests/quantity/test_locals.py b/tests/quantity/test_locals.py index d1c11aa6..a0cc3883 100644 --- a/tests/quantity/test_locals.py +++ b/tests/quantity/test_locals.py @@ -23,7 +23,7 @@ def __init__( self.local = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) def test_check(self): - assert self.local._transient + assert self.local.__descriptor__().transient def __call__(self, A, B): self.copy(A, self.local) @@ -49,8 +49,8 @@ def test_local_and_transient_flags(): code.test_check() # Check regular quantity are not transient - assert not A_._transient - assert not B_._transient + assert not A_.__descriptor__().transient + assert not B_.__descriptor__().transient # Check numerics assert (A_.field[:] == B_.field[:]).all() From 83f09fd022bfa0dbd2bdd932ec27eeb8340b0c69 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 16:33:26 -0400 Subject: [PATCH 14/21] Repeat of the Quantity trick to get a proper type hint --- ndsl/quantity/local.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index 214fe977..1c00739d 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -6,6 +6,9 @@ from ndsl.optional_imports import cupy from ndsl.quantity import Quantity +if cupy is None: + import numpy as cupy + class Local(Quantity): """Local is a Quantity that cannot be used outside of the class From 2b50c2000f009e924cbe854d1ec433e9e57a7839 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 16:34:24 -0400 Subject: [PATCH 15/21] Lint `local.py` --- ndsl/quantity/local.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index 1c00739d..910999ab 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -6,6 +6,7 @@ from ndsl.optional_imports import cupy from ndsl.quantity import Quantity + if cupy is None: import numpy as cupy From 7975e9915e7530391e1c17bbbf88464d7cc6d09e Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 16:57:39 -0400 Subject: [PATCH 16/21] Protect against bad init for orchestration Move all unit test into a `test_ndsl_runtime` --- ndsl/dsl/ndsl_runtime.py | 19 +++++-- tests/quantity/test_locals.py | 56 ------------------ tests/test_ndsl_runtime.py | 104 ++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 61 deletions(-) delete mode 100644 tests/quantity/test_locals.py create mode 100644 tests/test_ndsl_runtime.py diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index b08019f3..a45480fc 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -21,6 +21,8 @@ class NDSLRuntime: def __init__(self, dace_config: DaceConfig) -> None: self._dace_config = dace_config + # Use this flag to detect that the init wasn't done properly + self._base_class_was_properly_super_init = True def __init_subclass__(cls: type[NDSLRuntime], **kwargs: dict[str, Any]) -> None: # WARNING: no code outside the `init_decorator` this is cls @@ -44,6 +46,11 @@ def new_init( cls.__init__ = init_decorator(cls.__init__) # type: ignore[method-assign] def __post_init__(self) -> None: + if not hasattr(self, "_base_class_was_properly_super_init"): + raise RuntimeError( + f"Class {type(self).__name__} inherit from NDSLRuntime but didn't call super().__init__." + ) + # Check quantity allocation of NDSLRuntime supervised code if _TOP_LEVEL == self: @@ -60,10 +67,12 @@ def check_for_quantity(object_: object) -> None: check_for_quantity(self) # Orchestrate __call__ by default - orchestrate( - obj=self, - config=self._dace_config, - ) + if hasattr(self, "__call__"): + orchestrate( + obj=self, + config=self._dace_config, + ) + print(type(self)) def __getattribute__(self, name: str) -> Any: attr = super().__getattribute__(name) @@ -84,7 +93,7 @@ def __getattribute__(self, name: str) -> Any: ): # We expect the original class to have been monkey-patched # See `dace.dsl.orchestration.orchestrate` - unpatched_name = self.__name__[: -len("_patched")] + unpatched_name = type(self).__name__[: -len("_patched")] raise RuntimeError( f"Forbidden Local access: {name} called outside of {unpatched_name}." ) diff --git a/tests/quantity/test_locals.py b/tests/quantity/test_locals.py deleted file mode 100644 index a0cc3883..00000000 --- a/tests/quantity/test_locals.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest - -from ndsl import NDSLRuntime, QuantityFactory, StencilFactory -from ndsl.boilerplate import get_factories_single_tile_orchestrated -from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.gt4py import PARALLEL, computation, interval -from ndsl.dsl.typing import FloatField - - -def the_copy_stencil(from_: FloatField, to: FloatField): - with computation(PARALLEL), interval(...): - to = from_ - - -class Code(NDSLRuntime): - def __init__( - self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory - ) -> None: - super().__init__(dace_config=stencil_factory.config.dace_config) - self.copy = stencil_factory.from_dims_halo( - the_copy_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM] - ) - self.local = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) - - def test_check(self): - assert self.local.__descriptor__().transient - - def __call__(self, A, B): - self.copy(A, self.local) - self.copy(self.local, B) - - -def test_local_and_transient_flags(): - stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - 5, 5, 3, 0, backend="dace:cpu_kfirst" - ) - - A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") - B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") - - code = Code(stencil_factory, quantity_factory) - code(A_, B_) - - # Check that local is not reachable outside of Code - with pytest.raises(RuntimeError, match="Forbidden Local access:"): - assert code.local._transient - - # Check the local is properly transient - with access in Code - code.test_check() - - # Check regular quantity are not transient - assert not A_.__descriptor__().transient - assert not B_.__descriptor__().transient - - # Check numerics - assert (A_.field[:] == B_.field[:]).all() diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py new file mode 100644 index 00000000..6c763d0b --- /dev/null +++ b/tests/test_ndsl_runtime.py @@ -0,0 +1,104 @@ +import pytest + +from ndsl import NDSLRuntime, QuantityFactory, StencilFactory +from ndsl.boilerplate import ( + get_factories_single_tile_orchestrated, + get_factories_single_tile, +) +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +def the_copy_stencil(from_: FloatField, to: FloatField): + with computation(PARALLEL), interval(...): + to = from_ + + +class Code(NDSLRuntime): + def __init__( + self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory + ) -> None: + super().__init__(dace_config=stencil_factory.config.dace_config) + self.copy = stencil_factory.from_dims_halo( + the_copy_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM] + ) + self.local = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) + + def test_check(self): + assert self.local.__descriptor__().transient + + def __call__(self, A, B): + self.copy(A, self.local) + self.copy(self.local, B) + + +class BadCode_NoSuperInit(NDSLRuntime): + def __init__(self) -> None: + # Forget to init + pass + + +class Code_NoCall(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + super().__init__(dace_config=stencil_factory.config.dace_config) + pass + + def run(self, A, B): + pass + + +def test_runtime_make_local(): + stencil_factory, quantity_factory = get_factories_single_tile( + 5, 5, 3, 0, backend="numpy" + ) + A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + + code = Code(stencil_factory, quantity_factory) + + # Check that local is not reachable outside of Code + with pytest.raises(RuntimeError, match="Forbidden Local access:"): + assert code.local.__descriptor__().transient + + # Check the local is properly transient - with access in Code + code.test_check() + + # Check regular quantity are not transient + assert not A_.__descriptor__().transient + assert not B_.__descriptor__().transient + + +def test_runtime_has_orchestracted_call(): + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + code = Code(stencil_factory, quantity_factory) + code(A_, B_) + + # We monkey patch the class, a __name__ attribute is now available + # and the original Class name is postfixed with "_patched" + assert hasattr(code, "__name__") + assert code.__name__ == "Code_patched" + assert (A_.field[:] == B_.field[:]).all() + + +def test_runtime_does_not_orchestrate_when_call_is_not_present(): + stencil_factory, _ = get_factories_single_tile_orchestrated( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + code = Code_NoCall(stencil_factory) + + # We didn't monkey patch the class, no __name__ on object + # and the original Class name is intact + assert not hasattr(code, "__name__") + assert type(code).__name__ == "Code_NoCall" + + +def test_runtime_fail_when_not_super_init(): + with pytest.raises( + RuntimeError, match="inherit from NDSLRuntime but didn't call super()" + ): + bad_code = BadCode_NoSuperInit() From 0589786de844c2ee8ee55b5fd7db7be79c50b855 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 17:01:15 -0400 Subject: [PATCH 17/21] Lint --- tests/test_ndsl_runtime.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py index 6c763d0b..ad334842 100644 --- a/tests/test_ndsl_runtime.py +++ b/tests/test_ndsl_runtime.py @@ -1,16 +1,18 @@ +from typing import Any + import pytest from ndsl import NDSLRuntime, QuantityFactory, StencilFactory from ndsl.boilerplate import ( - get_factories_single_tile_orchestrated, get_factories_single_tile, + get_factories_single_tile_orchestrated, ) from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField -def the_copy_stencil(from_: FloatField, to: FloatField): +def the_copy_stencil(from_: FloatField, to: FloatField) -> None: with computation(PARALLEL), interval(...): to = from_ @@ -25,10 +27,10 @@ def __init__( ) self.local = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) - def test_check(self): + def test_check(self) -> None: assert self.local.__descriptor__().transient - def __call__(self, A, B): + def __call__(self, A, B) -> None: # type: ignore[no-untyped-def] self.copy(A, self.local) self.copy(self.local, B) @@ -44,11 +46,11 @@ def __init__(self, stencil_factory: StencilFactory) -> None: super().__init__(dace_config=stencil_factory.config.dace_config) pass - def run(self, A, B): + def run(self, A: Any, B: Any) -> None: pass -def test_runtime_make_local(): +def test_runtime_make_local() -> None: stencil_factory, quantity_factory = get_factories_single_tile( 5, 5, 3, 0, backend="numpy" ) @@ -69,7 +71,7 @@ def test_runtime_make_local(): assert not B_.__descriptor__().transient -def test_runtime_has_orchestracted_call(): +def test_runtime_has_orchestracted_call() -> None: stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( 5, 5, 3, 0, backend="dace:cpu_kfirst" ) @@ -85,7 +87,7 @@ def test_runtime_has_orchestracted_call(): assert (A_.field[:] == B_.field[:]).all() -def test_runtime_does_not_orchestrate_when_call_is_not_present(): +def test_runtime_does_not_orchestrate_when_call_is_not_present() -> None: stencil_factory, _ = get_factories_single_tile_orchestrated( 5, 5, 3, 0, backend="dace:cpu_kfirst" ) @@ -97,7 +99,7 @@ def test_runtime_does_not_orchestrate_when_call_is_not_present(): assert type(code).__name__ == "Code_NoCall" -def test_runtime_fail_when_not_super_init(): +def test_runtime_fail_when_not_super_init() -> None: with pytest.raises( RuntimeError, match="inherit from NDSLRuntime but didn't call super()" ): From e56531e7522b27755d927e1a0c4ba0e9f8e0d134 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 22 Oct 2025 09:22:03 -0400 Subject: [PATCH 18/21] Revert uneeded change to `Quantity` --- ndsl/initialization/allocator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index afff527b..cc59daea 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -224,7 +224,7 @@ def _allocate( ) except TypeError: data = allocator(shape, dtype=dtype) - quantity = Quantity( + return Quantity( data, dims=dims, units=units, @@ -233,7 +233,6 @@ def _allocate( gt4py_backend=self._backend(), allow_mismatch_float_precision=allow_mismatch_float_precision, ) - return quantity def get_quantity_halo_spec( self, From 27e408aa62a8d932fc8179c18c66586c5c76c5ff Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 22 Oct 2025 09:22:14 -0400 Subject: [PATCH 19/21] Correct type hint for Callable --- ndsl/dsl/ndsl_runtime.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index a45480fc..2272d413 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -2,7 +2,8 @@ import inspect import warnings -from typing import Any, Callable +from typing import Any +from collections.abc import Callable from ndsl.dsl.dace import DaceConfig, orchestrate from ndsl.dsl.typing import Float @@ -17,7 +18,7 @@ class NDSLRuntime: """Base class to tool runtime code, allows use of Locals, orchestration and debug tools. - The __call__ function will automatically orchestrated.""" + The __call__ function will automatically be orchestrated.""" def __init__(self, dace_config: DaceConfig) -> None: self._dace_config = dace_config @@ -78,7 +79,7 @@ def __getattribute__(self, name: str) -> Any: attr = super().__getattribute__(name) # We look at the direct caller frame for our own `self` # in the locals. - # Any other case are forbidden. + # All other cases are forbidden. if isinstance(attr, Local): frame = inspect.currentframe() if frame is None: From 675f16262a7521be6c21df32cd6110efafb7e538 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 22 Oct 2025 10:54:08 -0400 Subject: [PATCH 20/21] Revert orthogonal changes to this PR --- ndsl/quantity/state.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py index 59f81107..b1c776a1 100644 --- a/ndsl/quantity/state.py +++ b/ndsl/quantity/state.py @@ -60,11 +60,7 @@ def _init_recursive(cls: Any) -> dict: initial_quantities[_field.name] = quantity_factory_allocator( _field.metadata["dims"], - ( - _field.metadata["units"] - if "units" in _field.metadata.keys() - else "unspecified" - ), + _field.metadata["units"], dtype=_field.metadata["dtype"], allow_mismatch_float_precision=True, ) From 75853d8af5445eba352e0923749fc9a2e5465b55 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 22 Oct 2025 10:58:22 -0400 Subject: [PATCH 21/21] Lint --- ndsl/dsl/ndsl_runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 2272d413..d7d67c27 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -2,8 +2,8 @@ import inspect import warnings -from typing import Any from collections.abc import Callable +from typing import Any from ndsl.dsl.dace import DaceConfig, orchestrate from ndsl.dsl.typing import Float