Skip to content
Merged
2 changes: 1 addition & 1 deletion ndsl/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _get_factories(

compilation_config = CompilationConfig(
backend=backend,
rebuild=True,
rebuild=False,
validate_args=True,
format_source=False,
device_sync=False,
Expand Down
4 changes: 2 additions & 2 deletions ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ndsl.dsl.caches.cache_location import identify_code_path
from ndsl.dsl.caches.codepath import FV3CodePath
from ndsl.dsl.gt4py_utils import is_gpu_backend
from ndsl.dsl.typing import floating_point_precision
from ndsl.dsl.typing import get_precision
from ndsl.optional_imports import cupy as cp


Expand Down Expand Up @@ -264,7 +264,7 @@ def __init__(
"compiler", "cuda", "syncdebug", value=dace_debug_env_var
)

if floating_point_precision() == 32:
if get_precision() == 32:
# When using 32-bit float, we flip the default dtypes to be all
# C, e.g. 32 bit.
dace.Config.set(
Expand Down
43 changes: 33 additions & 10 deletions ndsl/dsl/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,41 @@
DTypes = Union[bool, np.bool_, int, np.int32, np.int64, float, np.float32, np.float64]


# Depreciated version of get_precision, but retained for a PACE dependency
def floating_point_precision() -> int:
return int(os.getenv("PACE_FLOAT_PRECISION", "64"))


def get_precision() -> int:
return int(os.getenv("PACE_FLOAT_PRECISION", "64"))


# We redefine the type as a way to distinguish
# the model definition of a float to other usage of the
# common numpy type in the rest of the code.
NDSL_32BIT_FLOAT_TYPE = np.float32
NDSL_64BIT_FLOAT_TYPE = np.float64
NDSL_32BIT_INT_TYPE = np.int32
NDSL_64BIT_INT_TYPE = np.int64


def global_set_floating_point_precision():
"""Set the global floating point precision for all reference
to Float in the codebase. Defaults to 64 bit."""
global Float
precision_in_bit = floating_point_precision()
def global_set_precision() -> type:
"""Set the global precision for all references of
Float and Int in the codebase. Defaults to 64 bit."""
global Float, Int
precision_in_bit = get_precision()
if precision_in_bit == 64:
return NDSL_64BIT_FLOAT_TYPE
return NDSL_64BIT_FLOAT_TYPE, NDSL_64BIT_INT_TYPE
elif precision_in_bit == 32:
return NDSL_32BIT_FLOAT_TYPE
return NDSL_32BIT_FLOAT_TYPE, NDSL_32BIT_INT_TYPE
else:
NotImplementedError(
raise NotImplementedError(
f"{precision_in_bit} bit precision not implemented or tested"
)


# Default float and int types
Float = global_set_floating_point_precision()
Int = np.int_
Float, Int = global_set_precision()
Bool = np.bool_

FloatField = Field[gtscript.IJK, Float]
Expand All @@ -68,10 +74,27 @@ def global_set_floating_point_precision():
FloatFieldK = Field[gtscript.K, Float]
FloatFieldK64 = Field[gtscript.K, np.float64]
FloatFieldK32 = Field[gtscript.K, np.float32]

IntField = Field[gtscript.IJK, Int]
IntField64 = Field[gtscript.IJK, np.int64]
IntField32 = Field[gtscript.IJK, np.int32]
IntFieldI = Field[gtscript.I, Int]
IntFieldI64 = Field[gtscript.I, np.int64]
IntFieldI32 = Field[gtscript.I, np.int32]
IntFieldJ = Field[gtscript.J, Int]
IntFieldJ64 = Field[gtscript.J, np.int64]
IntFieldJ32 = Field[gtscript.J, np.int32]
IntFieldIJ = Field[gtscript.IJ, Int]
IntFieldIJ64 = Field[gtscript.IJ, np.int64]
IntFieldIJ32 = Field[gtscript.IJ, np.int32]
IntFieldK = Field[gtscript.K, Int]
IntFieldK64 = Field[gtscript.K, np.int64]
IntFieldK32 = Field[gtscript.K, np.int32]

BoolField = Field[gtscript.IJK, Bool]
BoolFieldI = Field[gtscript.I, Bool]
BoolFieldJ = Field[gtscript.J, Bool]
BoolFieldK = Field[gtscript.K, Bool]
BoolFieldIJ = Field[gtscript.IJ, Bool]

Index3D = Tuple[int, int, int]
Expand Down