diff --git a/ndsl/dsl/typing.py b/ndsl/dsl/typing.py index 53c910dc..1cae1063 100644 --- a/ndsl/dsl/typing.py +++ b/ndsl/dsl/typing.py @@ -1,5 +1,5 @@ import os -from typing import Tuple, Union, cast +from typing import Tuple, TypeAlias, Union, cast import gt4py.cartesian.gtscript as gtscript import numpy as np @@ -34,13 +34,13 @@ def get_precision() -> int: # We redefine the type as a way to distinguish # the model definition of a float to other usage of the # common numpy type in the rest of the code. -NDSL_32BIT_FLOAT_TYPE = np.float32 -NDSL_64BIT_FLOAT_TYPE = np.float64 -NDSL_32BIT_INT_TYPE = np.int32 -NDSL_64BIT_INT_TYPE = np.int64 +NDSL_32BIT_FLOAT_TYPE: TypeAlias = np.float32 +NDSL_64BIT_FLOAT_TYPE: TypeAlias = np.float64 +NDSL_32BIT_INT_TYPE: TypeAlias = np.int32 +NDSL_64BIT_INT_TYPE: TypeAlias = np.int64 -def global_set_precision() -> type: +def global_set_precision() -> Tuple[TypeAlias, TypeAlias]: """Set the global precision for all references of Float and Int in the codebase. Defaults to 64 bit.""" global Float, Int