Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
18c9d33
Temporaries base dataclass
FlorianDeconinck Oct 15, 2025
1a241f0
Lint
FlorianDeconinck Oct 15, 2025
207964c
Hide `_transient` flag of Quantity away
FlorianDeconinck Oct 16, 2025
6360031
`QuantityFactory` has now a `is_local` option
FlorianDeconinck Oct 16, 2025
fab9490
Update wording to `Local`, trash `Temporaries` state idea
FlorianDeconinck Oct 16, 2025
d0f36b2
lint
FlorianDeconinck Oct 16, 2025
c86f87e
Public API clean up
FlorianDeconinck Oct 16, 2025
2351c17
Simplify and fix test
FlorianDeconinck Oct 16, 2025
0bd632a
Oops, restore code for `from_array` allocator
FlorianDeconinck Oct 16, 2025
bc1dd39
Merge branch 'develop' into feature/Temporaries
twicki Oct 17, 2025
df4de93
Merge branch 'develop' into feature/Temporaries
FlorianDeconinck Oct 21, 2025
b3eaea1
Remove keyword allocation
FlorianDeconinck Oct 21, 2025
e5d8abd
Introduce `Local` and `NDSLRuntime`
FlorianDeconinck Oct 21, 2025
f4a03d0
Lint new files
FlorianDeconinck Oct 21, 2025
2f79ef7
Remove the odd `_transient` and tag transientness correctly in `Local`
FlorianDeconinck Oct 21, 2025
83f09fd
Repeat of the Quantity trick to get a proper type hint
FlorianDeconinck Oct 21, 2025
2b50c20
Lint `local.py`
FlorianDeconinck Oct 21, 2025
7975e99
Protect against bad init for orchestration
FlorianDeconinck Oct 21, 2025
0589786
Lint
FlorianDeconinck Oct 21, 2025
e56531e
Revert uneeded change to `Quantity`
FlorianDeconinck Oct 22, 2025
27e408a
Correct type hint for Callable
FlorianDeconinck Oct 22, 2025
675f162
Revert orthogonal changes to this PR
FlorianDeconinck Oct 22, 2025
75853d8
Lint
FlorianDeconinck Oct 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ndsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
StorageReport,
)
from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater
from .dsl.ndsl_runtime import NDSLRuntime
from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector
from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig
from .exceptions import OutOfBoundsError
Expand All @@ -27,7 +28,7 @@
from .performance.collector import NullPerformanceCollector, PerformanceCollector
from .performance.profiler import NullProfiler, Profiler
from .performance.report import Experiment, Report, TimeReport
from .quantity import Quantity, State
from .quantity import Local, Quantity, State
from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import
from .testing.dummy_comm import DummyComm
from .types import Allocator
Expand Down Expand Up @@ -86,4 +87,6 @@
"Allocator",
"MetaEnumStr",
"State",
"NDSLRuntime",
"Local",
]
5 changes: 5 additions & 0 deletions ndsl/dsl/dace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .dace_config import DaceConfig
from .orchestration import orchestrate, orchestrate_function


__all__ = ["DaceConfig", "orchestrate", "orchestrate_function"]
127 changes: 127 additions & 0 deletions ndsl/dsl/ndsl_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import inspect
import warnings
from collections.abc import Callable
from typing import Any

from ndsl.dsl.dace import DaceConfig, orchestrate
from ndsl.dsl.typing import Float
from ndsl.initialization.allocator import QuantityFactory
from ndsl.quantity import Local, Quantity


_TOP_LEVEL: object | None = None


class NDSLRuntime:
"""Base class to tool runtime code, allows use of Locals, orchestration and
debug tools.

The __call__ function will automatically be orchestrated."""

def __init__(self, dace_config: DaceConfig) -> None:
self._dace_config = dace_config
# Use this flag to detect that the init wasn't done properly
self._base_class_was_properly_super_init = True

def __init_subclass__(cls: type[NDSLRuntime], **kwargs: dict[str, Any]) -> None:
# WARNING: no code outside the `init_decorator` this is cls
# function, it will be called ONLY ONCE for monkey-patching the
# Class - not the instance !

def init_decorator(previous_init: Callable) -> Callable:
def new_init(
self: NDSLRuntime,
*args: list[Any],
**kwargs: dict[str, Any],
) -> None:
global _TOP_LEVEL
if _TOP_LEVEL is None:
_TOP_LEVEL = self
previous_init(self, *args, **kwargs)
self.__post_init__()

return new_init

cls.__init__ = init_decorator(cls.__init__) # type: ignore[method-assign]

def __post_init__(self) -> None:
if not hasattr(self, "_base_class_was_properly_super_init"):
raise RuntimeError(
f"Class {type(self).__name__} inherit from NDSLRuntime but didn't call super().__init__."
)

# Check quantity allocation of NDSLRuntime supervised code
if _TOP_LEVEL == self:

def check_for_quantity(object_: object) -> None:
for key, value in object_.__dict__.items():
if isinstance(value, Quantity) and not isinstance(value, Local):
warnings.warn(
f"{type(self).__name__}.{key} is a Quantity instead of a Locals"
" on a NDSLRuntime - our eyebrows are frowned."
)
elif isinstance(value, NDSLRuntime):
check_for_quantity(value)

check_for_quantity(self)

# Orchestrate __call__ by default
if hasattr(self, "__call__"):
orchestrate(
obj=self,
config=self._dace_config,
)
print(type(self))

def __getattribute__(self, name: str) -> Any:
attr = super().__getattribute__(name)
# We look at the direct caller frame for our own `self`
# in the locals.
# All other cases are forbidden.
if isinstance(attr, Local):
frame = inspect.currentframe()
if frame is None:
raise NotImplementedError(
"Locals check cannot locate frame. Talk to the team."
)
caller_frame = frame.f_back
if (
not caller_frame
or "self" not in caller_frame.f_locals
or not isinstance(caller_frame.f_locals["self"], type(self))
):
# We expect the original class to have been monkey-patched
# See `dace.dsl.orchestration.orchestrate`
unpatched_name = type(self).__name__[: -len("_patched")]
raise RuntimeError(
f"Forbidden Local access: {name} called outside of {unpatched_name}."
)

return attr

def make_local(
self,
quantity_factory: QuantityFactory,
dims: list[str],
dtype: type = Float,
units: str = "unspecified",
*,
allow_mismatch_float_precision: bool = False,
) -> Local:
quantity = quantity_factory.zeros(
dims,
units,
dtype,
allow_mismatch_float_precision=allow_mismatch_float_precision,
)
return Local(
data=quantity.data,
dims=quantity.dims,
units=quantity.units,
origin=quantity.origin,
extent=quantity.extent,
gt4py_backend=quantity.gt4py_backend,
allow_mismatch_float_precision=allow_mismatch_float_precision,
)
30 changes: 26 additions & 4 deletions ndsl/initialization/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,41 +77,56 @@ def empty(
dims: Sequence[str],
units: str,
dtype: type = Float,
*,
allow_mismatch_float_precision: bool = False,
) -> Quantity:
"""Allocate a Quantity - values are random.

Equivalent to `numpy.empty`"""
return self._allocate(
self._numpy.empty, dims, units, dtype, allow_mismatch_float_precision
self._numpy.empty,
dims,
units,
dtype,
allow_mismatch_float_precision,
)

def zeros(
self,
dims: Sequence[str],
units: str,
dtype: type = Float,
*,
allow_mismatch_float_precision: bool = False,
) -> Quantity:
"""Allocate a Quantity and fill it with the value 0.

Equivalent to `numpy.zeros`"""
return self._allocate(
self._numpy.zeros, dims, units, dtype, allow_mismatch_float_precision
self._numpy.zeros,
dims,
units,
dtype,
allow_mismatch_float_precision,
)

def ones(
self,
dims: Sequence[str],
units: str,
dtype: type = Float,
*,
allow_mismatch_float_precision: bool = False,
) -> Quantity:
"""Allocate a Quantity and fill it with the value 1.

Equivalent to `numpy.ones`"""
return self._allocate(
self._numpy.ones, dims, units, dtype, allow_mismatch_float_precision
self._numpy.ones,
dims,
units,
dtype,
allow_mismatch_float_precision,
)

def full(
Expand All @@ -120,13 +135,18 @@ def full(
units: str,
value: Any, # no type hint because it would be a TypeVar = type[dtype] and mypy says no
dtype: type = Float,
*,
allow_mismatch_float_precision: bool = False,
) -> Quantity:
"""Allocate a Quantity and fill it with the value.

Equivalent to `numpy.full`"""
quantity = self._allocate(
self._numpy.empty, dims, units, dtype, allow_mismatch_float_precision
self._numpy.empty,
dims,
units,
dtype,
allow_mismatch_float_precision,
)
quantity.data[:] = value
return quantity
Expand All @@ -136,6 +156,7 @@ def from_array(
data: np.ndarray,
dims: Sequence[str],
units: str,
*,
allow_mismatch_float_precision: bool = False,
) -> Quantity:
"""
Expand All @@ -158,6 +179,7 @@ def from_compute_array(
data: np.ndarray,
dims: Sequence[str],
units: str,
*,
allow_mismatch_float_precision: bool = False,
) -> Quantity:
"""
Expand Down
10 changes: 4 additions & 6 deletions ndsl/quantity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from .state import State


__all__ = [
"Quantity",
"QuantityMetadata",
"QuantityHaloSpec",
"State",
]
from .local import Local # isort: skip


__all__ = ["Local", "Quantity", "QuantityMetadata", "QuantityHaloSpec", "State"]
43 changes: 43 additions & 0 deletions ndsl/quantity/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Any, Sequence

import dace
import numpy as np

from ndsl.optional_imports import cupy
from ndsl.quantity import Quantity


if cupy is None:
import numpy as cupy


class Local(Quantity):
"""Local is a Quantity that cannot be used outside of the class
it was allocated in."""

def __init__(
self,
data: np.ndarray | cupy.ndarray,
dims: Sequence[str],
units: str,
origin: Sequence[int] | None = None,
extent: Sequence[int] | None = None,
gt4py_backend: str | None = None,
allow_mismatch_float_precision: bool = False,
):
super().__init__(
data,
dims,
units,
origin,
extent,
gt4py_backend,
allow_mismatch_float_precision,
)
self._transient = True

def __descriptor__(self) -> Any:
"""Locals uses `Quantity.__descriptor__` and flag itself as transient."""
data = dace.data.create_datadescriptor(self.data)
data.transient = True
return data
Loading