Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
01bf96a
Frontend: allow FieldDescriptor to be used when defining temporaries …
FlorianDeconinck Oct 13, 2025
5615911
Make `debug` backend respect temporaries cartesian sizes
FlorianDeconinck Oct 14, 2025
79a0102
Update common `empty` method to pass down dims for temporaries
FlorianDeconinck Oct 14, 2025
3fb1a02
Make `npir` backend respect temporaries cartesian sizes
FlorianDeconinck Oct 14, 2025
c62d5c6
Unit test
FlorianDeconinck Oct 14, 2025
0f3b8dc
Solve type annotation against `gtscript`
FlorianDeconinck Oct 16, 2025
684d8fa
Flag as experimental feature (lack of `gt:X` support)
FlorianDeconinck Oct 16, 2025
5c3f84b
Fix `scalars to temps` pass in `numpy`
FlorianDeconinck Oct 16, 2025
430de2c
Fix to `npir` dimensions resolution in slicing
FlorianDeconinck Oct 16, 2025
1e0ca8e
Fix NPIR test to add `dimensions`
FlorianDeconinck Oct 16, 2025
a53ee1e
Fix proper calculation of `dimensions` when data dimensions are involved
FlorianDeconinck Oct 16, 2025
5fc038f
Inject `dimensions` in `gtcpp.Temporary` with a default to 3D for now…
FlorianDeconinck Oct 16, 2025
15e2278
Merge branch 'main' into cartesian/feature/2d_temporaries
FlorianDeconinck Oct 16, 2025
b1c595f
Cleaner appending dimensions
FlorianDeconinck Oct 17, 2025
f8bad86
Fix bad read on `from_gtscript`, better docs
FlorianDeconinck Oct 17, 2025
991e85e
Merge remote-tracking branch 'fdeconinck/cartesian/feature/2d_tempora…
FlorianDeconinck Oct 17, 2025
a2f1e7e
Swap `Self` for `Domain` + __future__
FlorianDeconinck Oct 17, 2025
90ed8c9
Fix `Domain.from_gtscript` proper translation of parallel axes
FlorianDeconinck Oct 17, 2025
ae53356
ADR
FlorianDeconinck Oct 20, 2025
3b04477
Correcting ADR
FlorianDeconinck Oct 21, 2025
bc6e274
Merge branch 'main' into cartesian/feature/2d_temporaries
FlorianDeconinck Oct 21, 2025
9dbf30e
Use `warn_experimental_feature`
FlorianDeconinck Oct 21, 2025
43803f4
Move the Domain convert from gtscript.Axis into the frontend processor
FlorianDeconinck Oct 21, 2025
cbe5862
Add fv3 dynamics example for 3D temps that behave like 2D on read
FlorianDeconinck Oct 21, 2025
411b591
Lint
FlorianDeconinck Oct 21, 2025
a92af2f
Typo
FlorianDeconinck Oct 21, 2025
5fe3303
Typo on ADR
FlorianDeconinck Oct 22, 2025
2d9772f
Merge branch 'main' into cartesian/feature/2d_temporaries
FlorianDeconinck Oct 22, 2025
2c265c6
Merge `main` into `cartesian/feature/2d_temporaries`
FlorianDeconinck Oct 22, 2025
752530c
Merge branch 'main' into cartesian/feature/2d_temporaries
FlorianDeconinck Oct 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions docs/development/ADRs/cartesian/experimental/2d-temporaries.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# ⚠️ 2 Dimenions Temporaries
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

In the context of porting physics parametrizations, we have encoutered multiple examples of computation requiring a temporary 2D storage within a stencil.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

We thus decided to expand the DLS be able to port these stencils. We accept the increase in DSL surface and test it as an [experimental feature](../experimental-features.md).

## Context

Porting physics parametrizations, we found temporary 2D field used, mostly when computing particular level (surface, phase change, etc.). Previous workaround was to define the temporaries as an argument. The goal is to undo this workaround that complexify calling code.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

Use case in GEOS's shallow convection parametrization, `total_pwo_solid_phase` will be used later in the stencil but not out of it.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

```python
with computation(FORWARD), interval(...):
if MELT_GLAC == 1 and cumulus == 1:
if K <= ktf-1:
if ierr == 0:
dp = 100.*(po_cup-po_cup[0,0,1])\
pwo_eff = 0.5*(pwo+pwo[0,0,1])
pwo_solid_phase = (1.-p_liq_ice)*pwo_eff
total_pwo_solid_phase = total_pwo_solid_phase+pwo_solid_phase*dp/constants.MAPL_GRAV
```

Use case in GEOS's microphysics. Vertical interpolation to a given level requires to compound pressure around it. Here `pb`, `pt` and `boolean_2d_mask` are all temporary 2D fields.

```python
def vertical_interpolation(
field: FloatField,
interpolated_field: FloatFieldIJ,
p_interface_mb: FloatField,
target_pressure: Float,
pb: FloatFieldIJ,
pt: FloatFieldIJ,
boolean_2d_mask: BoolFieldIJ,
):
"""
Interpolate to a specific vertical level.

Only works for non-interface fields. Must be constructed using Z_DIM.

Arguments:
field (in): three dimensional field to be interpolated to a specific pressure
interpolated_field (out): output two dimension field of interpolated values
p_interface_mb (in): interface pressure in mb
target_pressure (in): target pressure for interpolation in Pascals
pb (in): placeholder 2d quantity, can be removed onces 2d temporaries are available
pt (in): placeholder 2d quantity, can be removed onces 2d temporaries are available
boolean_2d_mask (in): boolean mask to track when each cell is modified
"""
# mask tracks which points have been touched. check later on ensures that every point has been touched
with computation(FORWARD), interval(0, 1):
boolean_2d_mask = False

with computation(FORWARD), interval(-1, None):
pb = 0.5 * (log(p_interface_mb * 100) + log(p_interface_mb[0, 0, 1] * 100))

with computation(BACKWARD), interval(1, None):
pt = 0.5 * (log(p_interface_mb[0, 0, -1] * 100) + log(p_interface_mb * 100))
if log(target_pressure) > pt and log(target_pressure) <= pb and not boolean_2d_mask:
al = (pb - log(target_pressure)) / (pb - pt)
interpolated_field = field[0, 0, -1] * al + field * (1.0 - al)
boolean_2d_mask = True
pb = pt

with computation(FORWARD), interval(-1, None):
pt = 0.5 * (log(p_interface_mb * 100) + log(p_interface_mb[0, 0, -1] * 100))
pb = 0.5 * (log(p_interface_mb * 100) + log(p_interface_mb[0, 0, 1] * 100))
if (
log(target_pressure) > pb
and log(target_pressure) <= log(p_interface_mb[0, 0, 1] * 100)
and not boolean_2d_mask
):
interpolated_field = field
boolean_2d_mask = True

# ensure every point was actually touched
if boolean_2d_mask == False: # noqa
interpolated_field = field

# reset masks and temporaries for later use
with computation(FORWARD), interval(0, 1):
boolean_2d_mask = False
pb = 0
pt = 0
```

2D temporaries would clean those code and allow for futher optimization (scalarization) in all performance backend.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

## Decision

All the guardraisls for 2D temporaries are already in place:
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

- allocation can only be done under `interval == 1` and `computation` in `FORWARD` or `BACKWARD`,
- use of 2D fields are fully defined, using the above rules again.

The main change is on the frontend of stencils and proper forwarding of dimensions through to code genreation, except for `gt:X` backend (see [Consequences](#consequences)).
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

Type hint for temporaries have been introduced before for mixed precision e.g.:
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

```python
with computation(PARALLEL), inteval(...):
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
tmp_3D_as_f32: float32 = 0
```

We propose to extend this type hint to allow specification of the dimensions re-using the `FieldDescriptor`, e.g.

```python
with computation(FORWARD), inteval(...):
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
tmp_2D_as_f32: Field[IJ, np.float64] = 0
```

We also guard against any definition of other type of temporaries (e.g. 1D temporary) because they are ill-defined within `gtscript` which pre-suppose that the horizontal dimensions are always computed upon.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

The type hint can be a little verbose, so we offer to expand upon the `dtypes` dictionnary present in stencil configuration to give a `str: type` pair that can be swapped at parsing time, as long as the type is a derivative of `FieldDescriptor` so the relevant informations can be retrieved.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

```python
@stencil(backend=..., dtype={"My2DType": Field[IJ, np.float64]})
def the_stencil(...):
with computation(FORWARD), inteval(...):
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
tmp_2D: My2DType = 0
...
```

## Consequences

Users of GT4Py can now defined 2D temporaries in-stencil.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

There's one remaining hiccups: GridTools C++ does not natively offer 2D temporaries. GridTools pre-suppose that all computation of temporaries are done on the 3D grid at least - and won't allocate below that dimensionality.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
This is a fair limitation since GridTools doesn't provide the guardrails against race condition that GT4Py does. We can circumvent this by creating a pool of buffers passed as arguments for `gt:X` backends - see [issue](https://github.com/GridTools/gt4py/issues/2322).

## References

- [PR](https://github.com/GridTools/gt4py/pull/2314) PR where the proposal and the discussion occured.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
- [Issue](https://github.com/GridTools/gt4py/issues/2322) to extend support to the `gt:X` backends
52 changes: 47 additions & 5 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import textwrap
import time
import types
import warnings
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -733,7 +734,7 @@ def __init__(
domain: nodes.Domain,
options: gt_definitions.BuildOptions,
temp_decls: Optional[Dict[str, nodes.FieldDecl]] = None,
dtypes: Optional[Dict[Type, Type]] = None,
dtypes: Optional[Dict[Type | str, Type]] = None,
):
fields = fields or {}
parameters = parameters or {}
Expand Down Expand Up @@ -807,7 +808,14 @@ def __init__(
"round_away_from_zero": nodes.NativeFunction.ROUND_AWAY_FROM_ZERO,
} # Conversion table for functions to NativeFunctions

self.temporary_type_to_native_type = {
# Filter the field type from `dtypes`
self.temporary_field_type = {}
if self.dtypes:
for name, _type in self.dtypes.items():
if isinstance(_type, gtscript._FieldDescriptor):
self.temporary_field_type[name] = _type

self.temporary_type_as_str_to_native_type = {
"int32": nodes.DataType.INT32,
"int64": nodes.DataType.INT64,
"float32": nodes.DataType.FLOAT32,
Expand Down Expand Up @@ -1607,20 +1615,54 @@ def _resolve_assign(
loc=nodes.Location.from_ast_node(t, scope=self.stencil_name),
)
dtype = nodes.DataType.AUTO
axes = nodes.Domain.LatLonGrid().axes_names
if target_annotation is not None:
source = ast.unparse(target_annotation)
try:
dtype = eval(source, self.temporary_type_to_native_type)
dtype_or_field_desc = eval(
source,
self.temporary_type_as_str_to_native_type
| self.temporary_field_type
| gtscript.__dict__,
)
except NameError:
raise GTScriptSyntaxError(
message=f"Failed to recognize type {source} for local symbol {name}."
f"Available types are {self.temporary_type_to_native_type.keys()}",
f"Available types are {self.temporary_type_as_str_to_native_type.keys()}, {self.dtypes}, "
"or `Field[IJ, dtype]`.",
loc=nodes.Location.from_ast_node(t),
) from None
# If Field, we have to expand to resolve axes and true type
if isinstance(dtype_or_field_desc, gtscript._FieldDescriptor):
field_desc = dtype_or_field_desc
if field_desc.axes != gtscript.IJ:
raise GTScriptSyntaxError(
message=f"Typed temporaries must be IJ, temporaries for axes {field_desc.axes}"
" is not yet available. Contact the team.",
loc=nodes.Location.from_ast_node(t),
) from None
if self.backend_name.startswith("gt:"):
raise NotImplementedError(
"2D temporaries (e.g. `tmp: Field[IJ, float] = ...) is an experimental feature "
"and not yet implemented for the `gt:X` backends."
)
warnings.warn(
"2D temporaries is an experimental feature. Please read "
"<https://github.com/GridTools/gt4py/blob/main/docs/development/ADRs/cartesian/experimental-features.md> "
"to understand the consequences.",
category=UserWarning,
stacklevel=2,
)
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

axes = nodes.Domain.from_gtscript(field_desc.axes).axes_names
dtype = nodes.DataType.from_dtype(field_desc.dtype)
else:
dtype = dtype_or_field_desc

field_decl = nodes.FieldDecl(
name=name,
data_type=dtype,
axes=nodes.Domain.LatLonGrid().axes_names,
axes=axes,
is_api=False,
loc=nodes.Location.from_ast_node(t, scope=self.stencil_name),
)
Expand Down
15 changes: 15 additions & 0 deletions src/gt4py/cartesian/frontend/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@

import numpy as np

import gt4py.cartesian.gtscript as gt_script
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
from gt4py.cartesian.definitions import CartesianSpace
from gt4py.cartesian.utils.attrib import (
Any as Any,
Expand Down Expand Up @@ -199,6 +200,20 @@ class Domain(Node):
parallel_axes = attribute(of=ListOf[Axis])
sequential_axis = attribute(of=Axis, optional=True)

@classmethod
def from_gtscript(cls, gt_axis: list[gt_script.Axis]) -> Domain:
sequential_axis = None
parallel_axes = []
for axis in gt_axis:
if axis in (gt_script.I, gt_script.J):
parallel_axes.append(Axis(name=axis.name))
else:
sequential_axis = Axis(name=axis.name)
return cls(
parallel_axes=parallel_axes,
sequential_axis=sequential_axis,
)

@classmethod
def LatLonGrid(cls):
return cls(
Expand Down
26 changes: 18 additions & 8 deletions src/gt4py/cartesian/gtc/debug/debug_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,20 +202,30 @@ def visit_Interval(self, interval: oir.Interval, **kwargs) -> str:
return ",".join([self.visit(interval.start, **kwargs), self.visit(interval.end, **kwargs)])

def visit_Temporary(self, temporary_declaration: oir.Temporary, **kwargs) -> str:
# Cartesian IJ
field_extents = kwargs["field_extents"]
local_field_extent = field_extents[temporary_declaration.name]
i_padding: int = local_field_extent[0][1] - local_field_extent[0][0]
j_padding: int = local_field_extent[1][1] - local_field_extent[1][0]
shape: list[str] = [f"i_size + {i_padding}", f"j_size + {j_padding}", "k_size"]
shape: list[str] = [f"i_size + {i_padding}", f"j_size + {j_padding}"]
field_offset = tuple(-ext[0] for ext in local_field_extent)
offset = [str(off) for off in field_offset]

# Cartesian vertical dimension K
if temporary_declaration.dimensions[2]:
shape.append("k_size")
offset.append("0")

# Data dimensions
data_dimensions: list[str] = [str(dim) for dim in temporary_declaration.data_dims]
shape = shape + data_dimensions
shape_decl = ", ".join(shape)
offset += ["0"] * len(temporary_declaration.data_dims)

# All together to write the `Field.empty` call
dims = temporary_declaration.dimensions
shape_decl = ", ".join(shape + data_dimensions)
dtype: str = self.visit(temporary_declaration.dtype)
field_offset = tuple(-ext[0] for ext in local_field_extent)
offset = [str(off) for off in field_offset] + ["0"] * (
1 + len(temporary_declaration.data_dims)
)
return f"{temporary_declaration.name} = Field.empty(({shape_decl}), {dtype}, ({', '.join(offset)}))"

return f"{temporary_declaration.name} = Field.empty(({shape_decl}), {dtype}, ({', '.join(offset)}), {dims})"

def visit_DataType(self, data_type: gtc_common.DataType, **_) -> str:
if data_type in {gtc_common.DataType.BOOL}:
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/cartesian/gtc/gtcpp/gtcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class Cast(common.Cast[Expr], Expr):
class Temporary(LocNode):
name: eve.Coerced[eve.SymbolName]
dtype: common.DataType
dimensions: tuple[bool, bool, bool] = (True, True, True)
Comment thread
romanc marked this conversation as resolved.
data_dims: Tuple[int, ...] = eve.field(default_factory=tuple)


Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/cartesian/gtc/numpy/npir.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ class TemporaryDecl(Decl):

Parameters
----------
data_dims: Data dimensions sizes (static)
dimensions: Mask which cartesian dimensions are present
offset: Origin of the temporary field.
padding: Buffer added to compute domain as field size.
"""

data_dims: Tuple[int, ...] = eve.field(default_factory=tuple)
dimensions: Tuple[bool, bool, bool]
offset: Tuple[int, int]
padding: Tuple[int, int]

Expand Down
25 changes: 19 additions & 6 deletions src/gt4py/cartesian/gtc/numpy/npir_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,20 @@ def add_declared(self, *args):
def visit_TemporaryDecl(
self, node: npir.TemporaryDecl, **kwargs
) -> Union[str, Collection[str]]:
shape = [f"_dI_ + {node.padding[0]}", f"_dJ_ + {node.padding[1]}", "_dK_"] + [
str(dim) for dim in node.data_dims
]
offset = [str(off) for off in node.offset] + ["0"] * (1 + len(node.data_dims))
# Cartesian IJ
shape = [f"_dI_ + {node.padding[0]}", f"_dJ_ + {node.padding[1]}"]
offset = [str(off) for off in node.offset]
# Vertical dimension K
if node.dimensions[2]:
shape += ["_dK_"]
offset += ["0"]
# Data dimensions
shape += [str(dim) for dim in node.data_dims]
offset += ["0"] * len(node.data_dims)

dtype = self.visit(node.dtype, **kwargs)
return f"{node.name} = Field.empty(({', '.join(shape)}), {dtype}, ({', '.join(offset)}))"
dims = node.dimensions
return f"{node.name} = Field.empty(({', '.join(shape)}), {dtype}, ({', '.join(offset)}), {dims})"

LocalScalarDecl = as_fmt(
"{name} = Field.empty((_dI_ + {upper[0] + lower[0]}, _dJ_ + {upper[1] + lower[1]}, {ksize}), {dtype}, ({', '.join(str(l) for l in lower)}, 0))"
Expand All @@ -107,6 +115,7 @@ def visit_FieldSlice(self, node: npir.FieldSlice, **kwargs: Any) -> Union[str, C
if isinstance(node.k_offset, npir.VarKOffset)
else node.k_offset
)

offsets: Tuple[Optional[int], Optional[int], Union[str, int, None]] = (
node.i_offset,
node.j_offset,
Expand All @@ -116,7 +125,11 @@ def visit_FieldSlice(self, node: npir.FieldSlice, **kwargs: Any) -> Union[str, C
# To determine: when is the symbol name not in the symtable?
if node.name in kwargs.get("symtable", {}):
decl = kwargs["symtable"][node.name]
dimensions = decl.dimensions if isinstance(decl, npir.FieldDecl) else [True] * 3
dimensions = (
decl.dimensions
if isinstance(decl, npir.FieldDecl | npir.TemporaryDecl)
else [True] * 3
)
offsets = cast(
Tuple[Optional[int], Optional[int], Union[str, int, None]],
tuple(off if has_dim else None for has_dim, off in zip(dimensions, offsets)),
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/cartesian/gtc/numpy/oir_to_npir.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def visit_Temporary(
return npir.TemporaryDecl(
name=node.name,
dtype=node.dtype,
dimensions=node.dimensions,
data_dims=node.data_dims,
offset=offset,
padding=padding,
Expand Down
7 changes: 6 additions & 1 deletion src/gt4py/cartesian/gtc/numpy/scalars_to_temps.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Temporary:
name: str
dtype: common.DataType
extent: Extent
dimensions: tuple[bool, bool, bool]


def _all_local_scalars_are_unique_type(stencil: npir.Computation) -> bool:
Expand Down Expand Up @@ -59,7 +60,10 @@ def visit_HorizontalBlock(
for decl in node.declarations:
if decl.name not in temps_from_scalars:
temps_from_scalars[decl.name] = Temporary(
name=decl.name, dtype=decl.dtype, extent=node.extent
name=decl.name,
dtype=decl.dtype,
extent=node.extent,
dimensions=(True, True, True),
)
else:
temps_from_scalars[decl.name].extent |= node.extent
Expand All @@ -85,6 +89,7 @@ def visit_Computation(self, node: npir.Computation) -> npir.Computation:
name=d.name,
offset=(-d.extent[0][0], -d.extent[1][0]),
padding=(d.extent[0][1] - d.extent[0][0], d.extent[1][1] - d.extent[1][0]),
dimensions=d.dimensions,
dtype=d.dtype,
)
for d in temps_from_scalars.values()
Expand Down
Loading