Locals #266

Merged
FlorianDeconinck merged 23 commits into NOAA-GFDL:develop from FlorianDeconinck:feature/Temporaries on Oct 22, 2025
@FlorianDeconinck (Collaborator) commented Oct 15, 2025

Locals

AS: All of the above is geared to, and only to, orch:dace:X - but the system will be deployed at NDSL level in a backend-blind way.

A limitation of our optimization pipeline is its inability to differentiate "global" fields from "transient" ones. Take the following code:

class PhysicsParam:

    def __init__(self):
        ...
        self._local_field = quantity_factory.zeros()

    def __call__(self, input_A, output_B):

        self.stencil(input_A, self._local_field)
        ...
        self.stencil(self._local_field, output_B)

class Driver:
    def __init__(self):
        ...
        self.physics = PhysicsParam()

    def __call__(self, input_A, output_B):
        self.physics(input_A, output_B)


if __name__ == "__main__":
    d = Driver()
    d(input_A, output_B)  # input_A and output_B are allocated elsewhere

In this configuration the field self._local_field is technically "temporary" to PhysicsParam code and should therefore be transient in DaCe SDFG lingo. Multiple problems arise:

  • How do we allow the user to flag a Quantity as transient, and should we? This is a highly technical feature, but its usage is explainable.
  • Behavior differs between backends. With orch:X backends, the contents of self._local_field become meaningless, since DaCe takes over that memory. All other backends still hold the last value written.

To address those, we introduce the concept of a Local: a field that must not be used outside of the module it was defined in. In the previous example, self._local_field would become a Local. This Local would pass the transient flag to the dace.Data descriptor and open the way for further optimization.
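
A minimal sketch of the idea, assuming a Local differs from a Quantity only by the transient flag it reports. `Descriptor` is a hypothetical stand-in for a DaCe data descriptor, and the `__descriptor__` hook name and the `_transient` attribute are assumptions made for illustration, not the NDSL implementation:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Descriptor:
    """Hypothetical stand-in for a DaCe data descriptor."""

    shape: tuple
    transient: bool


class Quantity:
    """Simplified field: nested-list storage plus a transient flag."""

    _transient = False  # "global": values must survive outside the module

    def __init__(self, rows: int, cols: int) -> None:
        self.data = [[0.0] * cols for _ in range(rows)]

    def __descriptor__(self) -> Descriptor:
        # Assumed hook: the optimizer asks each field to describe itself.
        shape = (len(self.data), len(self.data[0]))
        return Descriptor(shape=shape, transient=self._transient)


class Local(Quantity):
    """Same storage, but reported as transient to the optimizer."""

    _transient = True


global_field = Quantity(2, 2)
local_field = Local(2, 2)
print(global_field.__descriptor__().transient)
print(local_field.__descriptor__().transient)
```

Keeping the flag on the class rather than in the constructor means user code never has to pass (or forget) a transient argument.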

From this, we need to implement a few systems:

  • A way to forbid access to Locals outside of the module that allocated them. E.g.:
if __name__ == "__main__":
    d = Driver()
    d(input_A, output_B)
    print(d.physics._local_field) # <-- ILLEGAL, this is LOCAL to PhysicsParam
  • A way to detect Quantity allocation within "NDSL" code to warn about potentially under-optimized use.

Below is a proposal - which has been tested for orchestration:

from ndsl import orchestrate, DaceConfig
import inspect
import warnings
import os
import numpy as np

# ===============  NDSL OLD CONCEPTS (summarized) =========== #


class Quantity:
    def __init__(self) -> None:
        self.data = np.empty((2, 2))
        self.transient = False

    def __str__(self) -> str:
        return "I am quantity"


# ===============  NDSL NEW CONCEPTS =========== #


class Local(Quantity):
    """This class may not be needed: with `__post_init__`, we could
    flip the transient switch on quantities ourselves and make it all
    automagical, without a dedicated Local class."""

    def __init__(self) -> None:
        super().__init__()
        self.transient = True

    def __str__(self) -> str:
        return "I am a Local"


TOP_LEVEL: object | None = None


class NDSLRuntime:
    def __init__(self, dace_config: DaceConfig) -> None:
        self._dace_config = dace_config

    def __init_subclass__(cls, **kwargs):
        """WARNING: put no code outside of `init_decorator` here. This is
        a class-level function: it is called ONLY ONCE per subclass, to
        monkey-patch the class - not the instance!"""

        def init_decorator(previous_init):
            def new_init(self, *args, **kwargs):
                global TOP_LEVEL
                if TOP_LEVEL is None:
                    TOP_LEVEL = self
                previous_init(self, *args, **kwargs)
                self.__post_init__()

            return new_init

        cls.__init__ = init_decorator(cls.__init__)

    def __post_init__(self):
        # Check quantity allocation of NDSLRuntime supervised code
        if TOP_LEVEL is self:

            def check_for_quantity(object_: object):
                for key, value in object_.__dict__.items():
                    if isinstance(value, Quantity) and not isinstance(value, Local):
                        warnings.warn(
                            f"{type(object_).__name__}.{key} is a Quantity instead of a Local"
                            " on a NDSLRuntime - our eyebrows are frowned."
                        )
                    elif isinstance(value, NDSLRuntime):
                        check_for_quantity(value)

            check_for_quantity(self)

        # Orchestrate __call__ by default
        orchestrate(
            obj=self,
            config=self._dace_config,
        )

    def __getattribute__(self, name):
        attr = super().__getattribute__(name)
        # We look in the direct caller's frame for a `self` of our own
        # type among its local variables.
        # Any other case is forbidden.
        if isinstance(attr, Local):
            caller_frame = inspect.currentframe().f_back
            if (
                not caller_frame
                or "self" not in caller_frame.f_locals
                or not isinstance(caller_frame.f_locals["self"], type(self))
            ):
                raise RuntimeError("Locals called outside!")

        return attr


# ===============  USER CODE =========== #


class NestedCode(NDSLRuntime):
    def __init__(self, dace_config: DaceConfig) -> None:
        super().__init__(dace_config)
        self.tmp = Local()
        self.tmp_as_qty__baaaad = Quantity()

    def __call__(self):
        self.tmp.data[:] = 12


class Code(NDSLRuntime):
    def __init__(self, dace_config: DaceConfig) -> None:
        super().__init__(dace_config)
        self.nested_code = NestedCode(dace_config)

    def __call__(self):
        self.nested_code()


os.environ["FV3_DACEMODE"] = "BuildAndRun"
c = Code(DaceConfig(None, "dace:cpu_kfirst"))
c()


print("Illegal use of a Local: ", c.nested_code.tmp)  # ILLEGAL
c.nested_code.tmp.data[:] = 6  # ILLEGAL
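
The access check in the proposal above can be exercised without ndsl or DaCe installed. The following is a reduced sketch of the same frame-inspection technique; `Guarded` and `Physics` are hypothetical names used only for this illustration:

```python
import inspect


class Local:
    """Stand-in for the Local field of the proposal."""

    def __init__(self) -> None:
        self.data = [0, 0]


class Guarded:
    """Hands out Local attributes only to callers whose `self` is of our type."""

    def __getattribute__(self, name):
        attr = super().__getattribute__(name)
        if isinstance(attr, Local):
            caller = inspect.currentframe().f_back
            caller_self = caller.f_locals.get("self") if caller else None
            if not isinstance(caller_self, type(self)):
                raise RuntimeError(f"Local '{name}' accessed outside of its owner")
        return attr


class Physics(Guarded):
    def __init__(self) -> None:
        self.tmp = Local()

    def __call__(self) -> int:
        self.tmp.data[0] = 12  # legal: the caller frame's `self` is a Physics
        return self.tmp.data[0]


physics = Physics()
inside = physics()

try:
    physics.tmp  # illegal: the calling frame has no Physics `self`
    outside_blocked = False
except RuntimeError:
    outside_blocked = True
```

Note that the check compares types, not identities: any other instance of the same class also passes, and a caller that names its instance something other than `self` is rejected.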

Remaining to be solved:

  • API for Local allocation (LocalFactory? Method on NDSLRuntime taking a QuantityFactory ? ... ?)

PS: This also introduces an NDSLRuntime base class, which can be used for better orchestration and debugging, and which overall gives consistency and structure to NDSL code.
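
To illustrate how such a base class can attach behavior to every subclass, here is a dependency-free sketch of the `__init_subclass__` technique: each subclass `__init__` is wrapped once so that a post-init hook runs after it. `Runtime`, `_post_init` and `Physics` are hypothetical names; orchestration is left out:

```python
import warnings


class Quantity:
    pass


class Local(Quantity):
    pass


class Runtime:
    """Hypothetical base class: runs `_post_init` after each subclass `__init__`."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        original_init = cls.__init__

        def wrapped_init(self, *args, **kw):
            original_init(self, *args, **kw)
            self._post_init()

        # Runs once per subclass definition and patches the class, not instances.
        cls.__init__ = wrapped_init

    def _post_init(self) -> None:
        # Warn about plain Quantity attributes that could have been Locals.
        for name, value in vars(self).items():
            if isinstance(value, Quantity) and not isinstance(value, Local):
                warnings.warn(f"{type(self).__name__}.{name} is a Quantity, not a Local")


class Physics(Runtime):
    def __init__(self) -> None:
        self.tmp = Local()
        self.leftover = Quantity()


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Physics()

messages = [str(w.message) for w in caught]
```

One caveat of this sketch: a subclass that does not define its own `__init__` inherits an already wrapped one and wraps it again, so the hook would then run more than once per construction.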

Allow `units` to not be specified
+ unit test
@romanc (Collaborator) left a comment

Have you considered making a temporary quantity, i.e.

class Temporary(Quantity):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Note: I wouldn't expose `_transient` in the constructor
        # of `Quantity` because it will only raise questions.
        self._transient = True

That would still allow to group all temporaries in a State object, e.g.

from dataclasses import dataclass, field

@dataclass
class MyTemporaries(State):
    tmp_a: Temporary = field( ... )
    tmp_b: Temporary = field( ... )

It would open the possibility of State objects with both Quantity and Temporary members, which we can't do with the current proposal. That might or might not be a good idea. I'm not familiar enough with what atmospheric scientists expect from the "State" concept. With my CS background, I find it weird that I have to group my temps and non-temps in separate states.
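
A sketch of what such a mixed state could look like, using plain dataclasses instead of the NDSL State class. `MixedState`, `wind` and `transient_names` are hypothetical names, and the class-level `_transient` flag is a simplification of the constructor-based version above:

```python
from dataclasses import dataclass, field, fields


class Quantity:
    _transient = False


class Temporary(Quantity):
    _transient = True


@dataclass
class MixedState:
    """Hypothetical state mixing persistent and temporary members."""

    wind: Quantity = field(default_factory=Quantity)
    tmp_a: Temporary = field(default_factory=Temporary)
    tmp_b: Temporary = field(default_factory=Temporary)

    def transient_names(self) -> list[str]:
        # Members the optimizer would be free to treat as transient.
        return [f.name for f in fields(self) if getattr(self, f.name)._transient]


state = MixedState()
print(state.transient_names())
```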

@twicki (Collaborator) commented Oct 16, 2025

I share Roman's perspective. If I understand the goal correctly, we want a system that replaces the class-based temporary-initialization work done, for example, here:
https://github.com/NOAA-GFDL/PyFV3/blob/8b56c122e168afb7996e0d99d7f77e5290ba97eb/pyfv3/stencils/d_sw.py#L848-L868
and these are closer to quantities than to states. Now, whether we should be able to group all of them together into a temporary-state for d-sw, I am not 100% sure. It might make sense if we can then just call all_temporaries.reset() to make sure we don't miss anything. But maybe they are also unrelated enough to just keep them individually separate.
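
For illustration, a grouped reset could be as small as the sketch below. `TemporaryGroup` and the field names `ut` and `vt` are hypothetical, and list-backed storage stands in for real field data:

```python
class Temporary:
    """Stand-in temporary field backed by a flat list."""

    def __init__(self, size: int) -> None:
        self.data = [0.0] * size


class TemporaryGroup:
    """Hypothetical container so that one call resets every temporary."""

    def __init__(self, **temporaries: Temporary) -> None:
        self._temporaries = dict(temporaries)

    def __getitem__(self, name: str) -> Temporary:
        return self._temporaries[name]

    def reset(self, value: float = 0.0) -> None:
        # Overwrite in place so existing references to `data` stay valid.
        for tmp in self._temporaries.values():
            tmp.data[:] = [value] * len(tmp.data)


all_temporaries = TemporaryGroup(ut=Temporary(3), vt=Temporary(3))
all_temporaries["ut"].data[1] = 4.2
all_temporaries["vt"].data[0] = -1.0
all_temporaries.reset()
```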

@FlorianDeconinck changed the title from Temporaries to Locals on Oct 16, 2025

@FlorianDeconinck (Collaborator, Author) commented

Post NASA team discussion:

  • We decided to move away from the loaded Temporary concept to a Local concept, which clearly distinguishes between in-stencil and out-of-stencil definitions.
  • We will provide the building block for a single Local to be built. More complex Python structures remain available to the user as always (but without making a State out of Locals, so as not to mix concepts).
  • We keep Quantity as the building block for everything: the unit that retains data + metadata.

@FlorianDeconinck FlorianDeconinck marked this pull request as draft October 16, 2025 18:37
@twicki (Collaborator) left a comment

This looks really clean, I like this approach.

There is not really a need for a separate (inheriting) class Locals because we can create a frontend-api that makes it very obvious, right? The "problem" I see with an optional kwarg is that people might drop it. We can assume our users are clever enough to not do that, right?

@FlorianDeconinck (Collaborator, Author) commented

PR updated with code sample & explanation of the issue/remedy

@FlorianDeconinck FlorianDeconinck marked this pull request as ready for review October 21, 2025 20:22
@FlorianDeconinck (Collaborator, Author) commented

All right, the basics are in.

We have a Local, derived from Quantity, that flags itself as transient when asked by DaCe.

We have an NDSLRuntime base class that, when derived from and initialized via super().__init__(), knows how to:

  • orchestrate the __call__ function if it exists
  • make Locals (and detect their misuse)
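
As a rough sketch of the "make Locals" part, a factory method on the base class gives user code a single obvious way to allocate a Local, with no transient keyword to forget. The method name `make_local` and its signature are assumptions for illustration, not the merged API:

```python
class Quantity:
    def __init__(self, shape: tuple) -> None:
        self.shape = shape
        self.transient = False


class Local(Quantity):
    def __init__(self, shape: tuple) -> None:
        super().__init__(shape)
        self.transient = True


class Runtime:
    """Hypothetical base class exposing the only intended way to build a Local."""

    def make_local(self, shape: tuple) -> Local:
        # The transient flag is set here, never by user code.
        return Local(shape)


class Physics(Runtime):
    def __init__(self) -> None:
        self.tmp = self.make_local((4, 4))


physics = Physics()
print(type(physics.tmp).__name__, physics.tmp.transient)
```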

@romanc (Collaborator) left a comment

This looks nice to me. I'm glad we did a couple rounds of prototyping.

Comment thread ndsl/dsl/ndsl_runtime.py Outdated
Comment thread ndsl/dsl/ndsl_runtime.py Outdated
Comment thread ndsl/dsl/ndsl_runtime.py Outdated
Comment thread ndsl/initialization/allocator.py Outdated
Comment thread ndsl/quantity/state.py Outdated
@FlorianDeconinck FlorianDeconinck added this pull request to the merge queue Oct 22, 2025
Merged via the queue into NOAA-GFDL:develop with commit abda75f Oct 22, 2025
6 of 8 checks passed
Labels

Enhancement New feature or request

4 participants