diff --git a/src/gt4py/cartesian/__init__.py b/src/gt4py/cartesian/__init__.py index 90df315d5c..4363476c96 100644 --- a/src/gt4py/cartesian/__init__.py +++ b/src/gt4py/cartesian/__init__.py @@ -11,6 +11,7 @@ import typing from . import ( + backend, caching, cli, config, @@ -28,6 +29,7 @@ __all__ = [ "StencilObject", + "backend", "caching", "cli", "config", diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 7cd250ff09..3decc75988 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -15,19 +15,7 @@ import pathlib import time import warnings -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - List, - Optional, - Protocol, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol from typing_extensions import deprecated @@ -45,32 +33,6 @@ from gt4py.cartesian.stencil_builder import StencilBuilder from gt4py.cartesian.stencil_object import StencilObject -REGISTRY = gt_utils.Registry() - - -def from_name(name: str) -> Optional[Type[Backend]]: - backend = REGISTRY.get(name, None) - if not backend: - raise NotImplementedError( - f"Backend {name} is not implemented, options are: {REGISTRY.names}" - ) - return backend - - -def register(backend_cls: Type[Backend]) -> Type[Backend]: - assert issubclass(backend_cls, Backend) and backend_cls.name is not None - - if isinstance(backend_cls.name, str): - gt_storage.register(backend_cls.name, backend_cls.storage_info) - return REGISTRY.register(backend_cls.name, backend_cls) - - else: - raise ValueError( - "Invalid 'name' attribute ('{name}') in backend class '{cls}'".format( - name=backend_cls.name, cls=backend_cls - ) - ) - class Backend(abc.ABC): #: Backend name @@ -82,7 +44,7 @@ class Backend(abc.ABC): #: - versioning: is versioning on? #: - description [optional] #: - type - options: ClassVar[Dict[str, Any]] + options: ClassVar[dict[str, Any]] #: Backend-specific storage parametrization: #: @@ -100,7 +62,7 @@ class Backend(abc.ABC): #: #: Languages should be spelled using the official spelling #: but lower case ("python", "fortran", "rust"). - languages: ClassVar[Optional[Dict[str, Any]]] = None + languages: ClassVar[dict[str, Any] | None] = None # __impl_opts: # "disable-code-generation": bool @@ -108,7 +70,7 @@ class Backend(abc.ABC): builder: StencilBuilder - def __init__(self, builder: StencilBuilder): + def __init__(self, builder: StencilBuilder) -> None: self.builder = builder @classmethod @@ -125,7 +87,7 @@ def filter_options_for_id( return filtered_options @abc.abstractmethod - def load(self) -> Optional[Type[StencilObject]]: + def load(self) -> type[StencilObject] | None: """ Load the stencil class from the generated python module. @@ -140,7 +102,7 @@ def load(self) -> Optional[Type[StencilObject]]: pass @abc.abstractmethod - def generate(self) -> Type[StencilObject]: + def generate(self) -> type[StencilObject]: """ Generate the stencil class from GTScript's internal representation. @@ -156,19 +118,45 @@ def generate(self) -> Type[StencilObject]: pass @property - def extra_cache_info(self) -> Dict[str, Any]: + def extra_cache_info(self) -> dict[str, Any]: """Provide additional data to be stored in cache info file (subclass hook).""" return {} @property - def extra_cache_validation_keys(self) -> List[str]: + def extra_cache_validation_keys(self) -> list[str]: """List keys from extra_cache_info to be validated during consistency check.""" return [] +REGISTRY = gt_utils.Registry[type[Backend]]() + + +def from_name(name: str) -> type[Backend]: + """Return a backend by name.""" + backend_cls = REGISTRY.get(name, None) + if backend_cls is None: + raise NotImplementedError( + f"Backend '{name}' is not implemented. Options are: {REGISTRY.names}." + ) + return backend_cls + + +def register(backend_cls: type[Backend]) -> type[Backend]: + """Register a backend.""" + assert issubclass(backend_cls, Backend) and backend_cls.name is not None + + if isinstance(backend_cls.name, str): + gt_storage.register(backend_cls.name, backend_cls.storage_info) + return REGISTRY.register(backend_cls.name, backend_cls) + + raise ValueError( + f"Invalid 'name' attribute ('{backend_cls.name}') in backend class '{backend_cls}'." + ) + + class CLIBackendMixin(Backend): @abc.abstractmethod - def generate_computation(self) -> Dict[str, Union[str, Dict]]: + def generate_computation(self) -> dict[str, str | dict]: """ Generate the computation source code in a way agnostic of the way it is going to be used. @@ -218,7 +206,7 @@ def mystencil(...): raise NotImplementedError @abc.abstractmethod - def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: + def generate_bindings(self, language_name: str) -> dict[str, str | dict]: """ Generate bindings source code from ``language_name`` to the target language of the backend. @@ -243,9 +231,9 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: class BaseBackend(Backend): - MODULE_GENERATOR_CLASS: ClassVar[Type[BaseModuleGenerator]] + MODULE_GENERATOR_CLASS: ClassVar[type[BaseModuleGenerator]] - def load(self) -> Optional[Type[StencilObject]]: + def load(self) -> type[StencilObject] | None: build_info = self.builder.options.build_info if build_info is not None: start_time = time.perf_counter() @@ -266,11 +254,11 @@ def load(self) -> Optional[Type[StencilObject]]: return stencil_class - def generate(self) -> Type[StencilObject]: + def generate(self) -> type[StencilObject]: self.check_options(self.builder.options) return self.make_module() - def _load(self) -> Type[StencilObject]: + def _load(self) -> type[StencilObject]: stencil_class_name = self.builder.class_name file_name = str(self.builder.module_path) stencil_module = gt_utils.make_module_from_file(stencil_class_name, file_name) @@ -292,7 +280,7 @@ def check_options(self, options: gt_definitions.BuildOptions) -> None: stacklevel=2, ) - def make_module(self, **kwargs: Any) -> Type[StencilObject]: + def make_module(self, **kwargs: Any) -> type[StencilObject]: build_info = self.builder.options.build_info if build_info is not None: start_time = time.perf_counter() @@ -312,7 +300,7 @@ def make_module(self, **kwargs: Any) -> Type[StencilObject]: return module - def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: + def make_module_source(self, *, args_data: ModuleData | None = None, **kwargs: Any) -> str: """Generate the module source code with or without stencil id.""" args_data = args_data or make_args_data_from_gtir(self.builder.gtir_pipeline) source = self.MODULE_GENERATOR_CLASS()(args_data, self.builder, **kwargs) @@ -320,7 +308,7 @@ def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs class MakeModuleSourceCallable(Protocol): - def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: ... + def __call__(self, *, args_data: ModuleData | None = None, **kwargs: Any) -> str: ... class PurePythonBackendCLIMixin(CLIBackendMixin): @@ -334,12 +322,12 @@ class PurePythonBackendCLIMixin(CLIBackendMixin): #: :py:meth:`BaseBackend`. make_module_source: MakeModuleSourceCallable - def generate_computation(self) -> Dict[str, Union[str, Dict]]: + def generate_computation(self) -> dict[str, str | dict]: file_name = self.builder.module_path.name source = self.make_module_source(ir=self.builder.gtir) return {str(file_name): source} - def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: + def generate_bindings(self, language_name: str) -> dict[str, str | dict]: """Pure python backends typically will not support bindings.""" return super().generate_bindings(language_name) @@ -362,7 +350,7 @@ def pyext_build_dir_path(self) -> pathlib.Path: return self.builder.pkg_path.joinpath(self.pyext_module_name + "_BUILD") @property - def extra_cache_info(self) -> Dict[str, Any]: + def extra_cache_info(self) -> dict[str, Any]: pyext_file_path = self.builder.backend_data.get("pyext_file_path", None) pyext_md5 = "" if pyext_file_path: @@ -374,23 +362,23 @@ def extra_cache_info(self) -> Dict[str, Any]: } @property - def extra_cache_validation_keys(self) -> List[str]: + def extra_cache_validation_keys(self) -> list[str]: keys = super().extra_cache_validation_keys if self.extra_cache_info["pyext_md5"]: keys.append("pyext_md5") return keys @abc.abstractmethod - def generate(self) -> Type[StencilObject]: + def generate(self) -> type[StencilObject]: pass def build_extension_module( self, - pyext_sources: Dict[str, Any], - pyext_build_opts: Dict[str, str], + pyext_sources: dict[str, Any], + pyext_build_opts: dict[str, str], *, uses_cuda: bool = False, - ) -> Tuple[str, str]: + ) -> tuple[str, str]: # Build extension module pyext_build_path = pathlib.Path( os.path.relpath(self.pyext_build_dir_path, pathlib.Path.cwd()) @@ -409,7 +397,7 @@ def build_extension_module( pyext_target_file_path = self.builder.pkg_path qualified_pyext_name = self.pyext_module_path - pyext_build_args: Dict[str, Any] = dict( + pyext_build_args: dict[str, Any] = dict( name=qualified_pyext_name, sources=sources, build_path=str(pyext_build_path), @@ -431,15 +419,15 @@ def build_extension_module( return module_name, file_path -def disabled(message: str, *, enabled_env_var: str) -> Callable[[Type[Backend]], Type[Backend]]: +def disabled(message: str, *, enabled_env_var: str) -> Callable[[type[Backend]], type[Backend]]: # We push for hard deprecation here by raising by default and warning if enabling has been forced. enabled = bool(int(os.environ.get(enabled_env_var, "0"))) if enabled: return deprecated(message) else: - def _decorator(cls: Type[Backend]) -> Type[Backend]: - def _no_generate(obj) -> Type[StencilObject]: + def _decorator(cls: type[Backend]) -> type[Backend]: + def _no_generate(obj) -> type[StencilObject]: raise NotImplementedError( f"Disabled '{cls.name}' backend: 'f{message}'\n", f"You can still enable the backend by hand using the environment variable '{enabled_env_var}=1'", diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index 7b77175750..3408d178b7 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar from gt4py import storage as gt_storage from gt4py.cartesian.backend.base import CLIBackendMixin, disabled, register @@ -37,12 +37,12 @@ class CudaExtGenerator(BackendCodegen): - def __init__(self, class_name, module_name, backend): + def __init__(self, class_name: str, module_name: str, backend: CudaBackend) -> None: self.class_name = class_name self.module_name = module_name self.backend = backend - def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: + def __call__(self, stencil_ir: gtir.Stencil) -> dict[str, dict[str, str]]: stencil_ir = GtirPipeline(stencil_ir, self.backend.builder.stencil_id).full() base_oir = GTIRToOIR().visit(stencil_ir) oir_pipeline = self.backend.builder.options.backend_opts.get( @@ -141,15 +141,12 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin): MODULE_GENERATOR_CLASS = CUDAPyExtModuleGenerator GT_BACKEND_T = "gpu" - def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: + def generate_extension(self, **kwargs: Any) -> tuple[str, str]: return self.make_extension(uses_cuda=True) - def generate(self) -> Type[StencilObject]: + def generate(self) -> type[StencilObject]: self.check_options(self.builder.options) - pyext_module_name: Optional[str] - pyext_file_path: Optional[str] - # TODO(havogt) add bypass if computation has no effect pyext_module_name, pyext_file_path = self.generate_extension() diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 5e4d33bf71..bc9c03b005 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -12,7 +12,7 @@ import os import pathlib import re -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar import dace import dace.data @@ -228,7 +228,7 @@ def _sdfg_add_arrays_and_edges( ) -def _sdfg_specialize_symbols(wrapper_sdfg, domain: Tuple[int, ...]): +def _sdfg_specialize_symbols(wrapper_sdfg, domain: tuple[int, ...]): ival, jval, kval = domain[0], domain[1], domain[2] for sdfg in wrapper_sdfg.all_sdfgs_recursive(): if sdfg.parent_nsdfg_node is not None: @@ -316,7 +316,7 @@ def freeze_origin_domain_sdfg(inner_sdfg, arg_names, field_info, *, origin, doma class SDFGManager: # Cache loaded SDFGs across all instances - _loaded_sdfgs: ClassVar[Dict[str, dace.SDFG]] = dict() + _loaded_sdfgs: ClassVar[dict[str, dace.SDFG]] = dict() def __init__(self, builder): self.builder = builder @@ -374,7 +374,7 @@ def _expanded_sdfg(self): def expanded_sdfg(self): return copy.deepcopy(self._expanded_sdfg()) - def _frozen_sdfg(self, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...]): + def _frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]): frozen_hash = shash(origin, domain) # check if same sdfg already cached on disk path = self.builder.module_path @@ -398,21 +398,20 @@ def _frozen_sdfg(self, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, SDFGManager._loaded_sdfgs[path] = sdfg return SDFGManager._loaded_sdfgs[path] - def frozen_sdfg(self, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...]): + def frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]): return copy.deepcopy(self._frozen_sdfg(origin=origin, domain=domain)) class DaCeExtGenerator(BackendCodegen): - def __init__(self, class_name, module_name, backend): + def __init__(self, class_name: str, module_name: str, backend: BaseDaceBackend) -> None: self.class_name = class_name self.module_name = module_name self.backend = backend - def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: + def __call__(self, stencil_ir: gtir.Stencil) -> dict[str, dict[str, str]]: manager = SDFGManager(self.backend.builder) sdfg = manager.expanded_sdfg() - sources: Dict[str, Dict[str, str]] implementation = DaCeComputationCodegen.apply(stencil_ir, self.backend.builder, sdfg) bindings = DaCeBindingsCodegen.apply( @@ -420,11 +419,10 @@ def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: ) bindings_ext = "cu" if self.backend.storage_info["device"] == "gpu" else "cpp" - sources = { + return { "computation": {"computation.hpp": implementation}, "bindings": {f"bindings.{bindings_ext}": bindings}, } - return sources class DaCeComputationCodegen: @@ -573,7 +571,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDF return generated_code - def generate_dace_args(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> List[str]: + def generate_dace_args(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> list[str]: oir = GTIRToOIR().visit(stencil_ir) field_extents = compute_fields_extents(oir, add_k=True) @@ -581,7 +579,7 @@ def generate_dace_args(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> List[ field_name: max(boundary[0], 0) for field_name, boundary in compute_k_boundary(stencil_ir).items() } - offset_dict: Dict[str, Tuple[int, int, int]] = { + offset_dict: dict[str, tuple[int, int, int]] = { k: (max(-v[0][0], 0), max(-v[1][0], 0), k_origins[k] if k in k_origins else 0) for k, v in field_extents.items() } @@ -663,8 +661,8 @@ def unique_index(self) -> int: mako_template = bindings_main_template() - def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> List[str]: - res: Dict[str, str] = {} + def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> list[str]: + res: dict[str, str] = {} for name in sdfg.signature_arglist(with_types=False, for_call=True): if name in sdfg.arrays: @@ -684,8 +682,8 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> Li res[name] = "{dtype} {name}".format(dtype=sdfg.symbols[name].ctype, name=name) return list(res[node.name] for node in stencil_ir.params if node.name in res) - def generate_sid_params(self, sdfg: dace.SDFG) -> List[str]: - res: List[str] = [] + def generate_sid_params(self, sdfg: dace.SDFG) -> list[str]: + res: list[str] = [] for name, array in sdfg.arrays.items(): if array.transient: @@ -756,12 +754,9 @@ class BaseDaceBackend(BaseGTBackend, CLIBackendMixin): GT_BACKEND_T = "dace" PYEXT_GENERATOR_CLASS = DaCeExtGenerator - def generate(self) -> Type[StencilObject]: + def generate(self) -> type[StencilObject]: self.check_options(self.builder.options) - pyext_module_name: Optional[str] - pyext_file_path: Optional[str] - # TODO(havogt) add bypass if computation has no effect pyext_module_name, pyext_file_path = self.generate_extension() @@ -787,7 +782,7 @@ class DaceCPUBackend(BaseDaceBackend): options = BaseGTBackend.GT_BACKEND_OPTS - def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: + def generate_extension(self, **kwargs: Any) -> tuple[str, str]: return self.make_extension(uses_cuda=False) @@ -811,5 +806,5 @@ class DaceGPUBackend(BaseDaceBackend): "device_sync": {"versioning": True, "type": bool}, } - def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: + def generate_extension(self, **kwargs: Any) -> tuple[str, str]: return self.make_extension(uses_cuda=True) diff --git a/src/gt4py/cartesian/backend/dace_stencil_object.py b/src/gt4py/cartesian/backend/dace_stencil_object.py index 687eae2738..a6cb2be565 100644 --- a/src/gt4py/cartesian/backend/dace_stencil_object.py +++ b/src/gt4py/cartesian/backend/dace_stencil_object.py @@ -181,7 +181,6 @@ def normalize_args( **kwargs, ): backend_cls = gt_backend.from_name(backend) - assert backend_cls is not None args_iter = iter(args) args_as_kwargs = { name: (kwargs[name] if name in kwargs else next(args_iter)) for name in arg_names diff --git a/src/gt4py/cartesian/backend/gtc_common.py b/src/gt4py/cartesian/backend/gtc_common.py index b7b6018eec..80dcf0b211 100644 --- a/src/gt4py/cartesian/backend/gtc_common.py +++ b/src/gt4py/cartesian/backend/gtc_common.py @@ -12,7 +12,7 @@ import os import textwrap import time -from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Final from gt4py.cartesian import backend as gt_backend, config as gt_config, utils as gt_utils from gt4py.cartesian.backend import Backend @@ -41,7 +41,7 @@ def pybuffer_to_sid( *, name: str, ctype: str, - domain_dim_flags: Tuple[bool, bool, bool], + domain_dim_flags: tuple[bool, bool, bool], data_ndim: int, stride_kind_index: int, backend: Backend, @@ -125,8 +125,8 @@ def gtir_has_effect(pipeline: GtirPipeline) -> bool: class PyExtModuleGenerator(BaseModuleGenerator): """Module Generator for use with backends that generate c++ python extensions.""" - pyext_module_name: Optional[str] - pyext_file_path: Optional[str] + pyext_module_name: str | None + pyext_file_path: str | None def __init__(self): super().__init__() @@ -134,7 +134,7 @@ def __init__(self): self.pyext_file_path = None def __call__( - self, args_data: ModuleData, builder: Optional[StencilBuilder] = None, **kwargs: Any + self, args_data: ModuleData, builder: StencilBuilder | None = None, **kwargs: Any ) -> str: self.pyext_module_name = kwargs["pyext_module_name"] self.pyext_file_path = kwargs["pyext_file_path"] @@ -175,7 +175,7 @@ def generate_implementation(self) -> str: ir = self.builder.gtir sources = gt_utils.text.TextBlock(indent_size=BaseModuleGenerator.TEMPLATE_INDENT_SIZE) - args: List[str] = [] + args: list[str] = [] for decl in ir.params: args.append(decl.name) if isinstance(decl, gtir.FieldDecl): @@ -199,19 +199,19 @@ def generate_implementation(self) -> str: class BackendCodegen: - TEMPLATE_FILES: Dict[str, str] + TEMPLATE_FILES: dict[str, str] @abc.abstractmethod - def __init__(self, class_name: str, module_name: str, backend: Any): + def __init__(self, class_name: str, module_name: str, backend: Backend) -> None: pass @abc.abstractmethod - def __call__(self, ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: + def __call__(self, ir: gtir.Stencil) -> dict[str, dict[str, str]]: """Return a dict with the keys 'computation' and 'bindings' to dicts of filenames to source.""" pass -GTBackendOptions = Dict[str, Dict[str, Any]] +GTBackendOptions = dict[str, dict[str, Any]] class BaseGTBackend(gt_backend.BasePyExtBackend, gt_backend.CLIBackendMixin): @@ -229,18 +229,18 @@ class BaseGTBackend(gt_backend.BasePyExtBackend, gt_backend.CLIBackendMixin): MODULE_GENERATOR_CLASS = PyExtModuleGenerator - PYEXT_GENERATOR_CLASS: Type[BackendCodegen] + PYEXT_GENERATOR_CLASS: type[BackendCodegen] @abc.abstractmethod - def generate(self) -> Type[StencilObject]: + def generate(self) -> type[StencilObject]: pass - def generate_computation(self) -> Dict[str, Union[str, Dict]]: + def generate_computation(self) -> dict[str, str | dict]: dir_name = f"{self.builder.options.name}_src" src_files = self._make_extension_sources() return {dir_name: src_files["computation"]} - def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: + def generate_bindings(self, language_name: str) -> dict[str, str | dict]: if language_name != "python": return super().generate_bindings(language_name) @@ -249,7 +249,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: return {dir_name: src_files["bindings"]} @abc.abstractmethod - def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: + def generate_extension(self, **kwargs: Any) -> tuple[str, str]: """ Generate and build a python extension for the stencil computation. @@ -257,14 +257,14 @@ def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: """ pass - def make_extension(self, *, uses_cuda: bool = False) -> Tuple[str, str]: + def make_extension(self, *, uses_cuda: bool = False) -> tuple[str, str]: build_info = self.builder.options.build_info if build_info is not None: start_time = time.perf_counter() # Generate source - gt_pyext_files: Dict[str, Any] - gt_pyext_sources: Dict[str, Any] + gt_pyext_files: dict[str, Any] + gt_pyext_sources: dict[str, Any] if self.builder.options._impl_opts.get("disable-code-generation", False): # Pass NOTHING to the self.builder means try to reuse the source code files gt_pyext_files = {} @@ -307,7 +307,7 @@ def make_extension(self, *, uses_cuda: bool = False) -> Tuple[str, str]: return result - def _make_extension_sources(self) -> Dict[str, Dict[str, str]]: + def _make_extension_sources(self) -> dict[str, dict[str, str]]: """Generate the source for the stencil independently from use case.""" if "computation_src" in self.builder.backend_data: return self.builder.backend_data["computation_src"] diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index 91f616891b..fe9604d8c9 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar from gt4py import storage as gt_storage from gt4py.cartesian.backend.base import CLIBackendMixin, register @@ -35,12 +35,12 @@ class GTExtGenerator(BackendCodegen): - def __init__(self, class_name, module_name, backend): + def __init__(self, class_name: str, module_name: str, backend: BaseGTBackend) -> None: self.class_name = class_name self.module_name = module_name self.backend = backend - def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: + def __call__(self, stencil_ir: gtir.Stencil) -> dict[str, dict[str, str]]: stencil_ir = GtirPipeline(stencil_ir, self.backend.builder.stencil_id).full() base_oir = GTIRToOIR().visit(stencil_ir) oir_pipeline = self.backend.builder.options.backend_opts.get( @@ -129,15 +129,12 @@ class GTBaseBackend(BaseGTBackend, CLIBackendMixin): options = BaseGTBackend.GT_BACKEND_OPTS PYEXT_GENERATOR_CLASS = GTExtGenerator - def _generate_extension(self, uses_cuda: bool) -> Tuple[str, str]: + def _generate_extension(self, uses_cuda: bool) -> tuple[str, str]: return self.make_extension(uses_cuda=uses_cuda) - def generate(self) -> Type[StencilObject]: + def generate(self) -> type[StencilObject]: self.check_options(self.builder.options) - pyext_module_name: Optional[str] - pyext_file_path: Optional[str] - # TODO(havogt) add bypass if computation has no effect pyext_module_name, pyext_file_path = self.generate_extension() @@ -156,7 +153,7 @@ class GTCpuIfirstBackend(GTBaseBackend): languages: ClassVar[dict] = {"computation": "c++", "bindings": ["python"]} storage_info = gt_storage.layout.CPUIFirstLayout - def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: + def generate_extension(self, **kwargs: Any) -> tuple[str, str]: return super()._generate_extension(uses_cuda=False) @@ -169,7 +166,7 @@ class GTCpuKfirstBackend(GTBaseBackend): languages: ClassVar[dict] = {"computation": "c++", "bindings": ["python"]} storage_info = gt_storage.layout.CPUKFirstLayout - def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: + def generate_extension(self, **kwargs: Any) -> tuple[str, str]: return super()._generate_extension(uses_cuda=False) @@ -187,5 +184,5 @@ class GTGpuBackend(GTBaseBackend): } storage_info = gt_storage.layout.CUDALayout - def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: + def generate_extension(self, **kwargs: Any) -> tuple[str, str]: return super()._generate_extension(uses_cuda=True) diff --git a/src/gt4py/cartesian/frontend/base.py b/src/gt4py/cartesian/frontend/base.py index 5e542cd36d..e0b2423cae 100644 --- a/src/gt4py/cartesian/frontend/base.py +++ b/src/gt4py/cartesian/frontend/base.py @@ -9,26 +9,12 @@ from __future__ import annotations import abc -from typing import Any, Dict, Optional, Type, Union +from typing import Any from gt4py.cartesian import utils as gt_utils from gt4py.cartesian.definitions import BuildOptions, StencilID from gt4py.cartesian.gtc import gtir -from gt4py.cartesian.type_hints import AnnotatedStencilFunc, StencilFunc - - -REGISTRY = gt_utils.Registry() -AnyStencilFunc = Union[StencilFunc, AnnotatedStencilFunc] - - -def from_name(name: str) -> Optional[Type[Frontend]]: - """Return frontend by name.""" - return REGISTRY.get(name, None) - - -def register(frontend_cls: Type[Frontend]) -> None: - """Register a new frontend.""" - return REGISTRY.register(frontend_cls.name, frontend_cls) +from gt4py.cartesian.type_hints import AnnotatedStencilFunc, AnyStencilFunc class Frontend(abc.ABC): @@ -41,7 +27,7 @@ def get_stencil_id( cls, qualified_name: str, definition: AnyStencilFunc, - externals: Dict[str, Any], + externals: dict[str, Any], options_id: str, ) -> StencilID: """ @@ -71,8 +57,8 @@ def get_stencil_id( def generate( cls, definition: AnyStencilFunc, - externals: Dict[str, Any], - dtypes: Dict[Type, Type], + externals: dict[str, Any], + dtypes: dict[type, type], options: BuildOptions, backend_name: str, ) -> gtir.Stencil: @@ -95,7 +81,7 @@ def generate( @classmethod @abc.abstractmethod def prepare_stencil_definition( - cls, definition: AnyStencilFunc, externals: Dict[str, Any] + cls, definition: AnyStencilFunc, externals: dict[str, Any] ) -> AnnotatedStencilFunc: """ Annotate the stencil function if not already done so. @@ -109,3 +95,21 @@ def prepare_stencil_definition( If there is a error resolving external types. """ pass + + +REGISTRY = gt_utils.Registry[type[Frontend]]() + + +def from_name(name: str) -> type[Frontend]: + """Return frontend by name.""" + frontend_cls = REGISTRY.get(name, None) + if frontend_cls is None: + raise NotImplementedError( + f"Frontend '{name} is not implemented. Options are: {REGISTRY.names}." + ) + return frontend_cls + + +def register(frontend_cls: type[Frontend]) -> type[Frontend]: + """Register a new frontend.""" + return REGISTRY.register(frontend_cls.name, frontend_cls) diff --git a/src/gt4py/cartesian/loader.py b/src/gt4py/cartesian/loader.py index 889618ef0d..8b597759c5 100644 --- a/src/gt4py/cartesian/loader.py +++ b/src/gt4py/cartesian/loader.py @@ -15,9 +15,9 @@ from __future__ import annotations import types -from typing import TYPE_CHECKING, Any, Dict, Type +from typing import TYPE_CHECKING, Any -from gt4py.cartesian import backend as gt_backend, frontend as gt_frontend +from gt4py.cartesian import frontend as gt_frontend from gt4py.cartesian.stencil_builder import StencilBuilder from gt4py.cartesian.type_hints import StencilFunc @@ -31,19 +31,13 @@ def load_stencil( frontend_name: str, backend_name: str, definition_func: StencilFunc, - externals: Dict[str, Any], - dtypes: Dict[Type, Type], + externals: dict[str, Any], + dtypes: dict[type, type], build_options: BuildOptions, -) -> Type[StencilObject]: +) -> type[StencilObject]: """Generate a new class object implementing the provided definition.""" # Load components - backend_cls = gt_backend.from_name(backend_name) - if backend_cls is None: - raise ValueError(f"Unknown backend name ({backend_name})") - frontend = gt_frontend.from_name(frontend_name) - if frontend is None: - raise ValueError(f"Invalid frontend name ({frontend_name})") builder = ( StencilBuilder( @@ -60,8 +54,8 @@ def gtscript_loader( definition_func: StencilFunc, backend: str, build_options: BuildOptions, - externals: Dict[str, Any], - dtypes: Dict[Type, Type], + externals: dict[str, Any], + dtypes: dict[type, type], ) -> StencilObject: if not isinstance(definition_func, types.FunctionType): raise ValueError("Invalid stencil definition object ({obj})".format(obj=definition_func)) diff --git a/src/gt4py/cartesian/stencil_builder.py b/src/gt4py/cartesian/stencil_builder.py index 69fff482f6..9ceba35fc4 100644 --- a/src/gt4py/cartesian/stencil_builder.py +++ b/src/gt4py/cartesian/stencil_builder.py @@ -9,13 +9,13 @@ from __future__ import annotations import pathlib -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any from gt4py import cartesian as gt4pyc from gt4py.cartesian.definitions import BuildOptions, StencilID from gt4py.cartesian.gtc import gtir from gt4py.cartesian.gtc.passes.gtir_pipeline import GtirPipeline -from gt4py.cartesian.type_hints import AnnotatedStencilFunc, StencilFunc +from gt4py.cartesian.type_hints import AnnotatedStencilFunc, AnyStencilFunc if TYPE_CHECKING: @@ -50,30 +50,25 @@ class StencilBuilder: def __init__( self, - definition_func: Union[StencilFunc, AnnotatedStencilFunc], + definition_func: AnyStencilFunc, *, - backend: Optional[Union[str, Type[BackendType]]] = None, - options: Optional[BuildOptions] = None, - frontend: Optional[Type[FrontendType]] = None, + backend: str | type[BackendType] | None = None, + options: BuildOptions | None = None, + frontend: type[FrontendType] | None = None, ): self._definition = definition_func self.options = options or BuildOptions(**self.default_options_dict(definition_func)) backend = backend or "numpy" backend = gt4pyc.backend.from_name(backend) if isinstance(backend, str) else backend - if backend is None: - raise RuntimeError(f"Unknown backend: {backend}") - frontend = frontend or gt4pyc.frontend.from_name("gtscript") - if frontend is None: - raise RuntimeError(f"Unknown frontend: {frontend}") self.backend: BackendType = backend(self) - self.frontend: Type[FrontendType] = frontend + self.frontend: type[FrontendType] = frontend self.with_caching("jit") - self._externals: Dict[str, Any] = {} - self._dtypes: Dict[Type, Type] = {} + self._externals: dict[str, Any] = {} + self._dtypes: dict[type, type] = {} - def build(self) -> Type[StencilObject]: + def build(self) -> type[StencilObject]: """Generate, compile and/or load everything necessary to provide a usable stencil class.""" # load or generate stencil_class = None if self.options.rebuild else self.backend.load() @@ -85,11 +80,11 @@ def build(self) -> Type[StencilObject]: stencil_class = self.backend.generate() return stencil_class - def generate_computation(self) -> Dict[str, Union[str, Dict]]: + def generate_computation(self) -> dict[str, str | dict]: """Generate the stencil source code, fail if backend does not support CLI.""" return self.cli_backend.generate_computation() - def generate_bindings(self, targe_language: str) -> Dict[str, Union[str, Dict]]: + def generate_bindings(self, targe_language: str) -> dict[str, str | dict]: """Generate ``target_language`` bindings source, fail if backend does not support CLI.""" return self.cli_backend.generate_bindings(targe_language) @@ -114,7 +109,7 @@ def with_caching( ----- Resets all cached build data. """ - self._build_data: Dict[str, Any] = {} + self._build_data: dict[str, Any] = {} kwargs = {**self.options.cache_settings, **kwargs} self.caching = gt4pyc.caching.strategy_factory(caching_strategy_name, self, *args, **kwargs) return self @@ -141,7 +136,7 @@ def with_options( self.options = BuildOptions(name=name, module=module, **kwargs) # type: ignore return self - def with_changed_options(self: StencilBuilder, **kwargs: Dict[str, Any]) -> StencilBuilder: + def with_changed_options(self: StencilBuilder, **kwargs: dict[str, Any]) -> StencilBuilder: old_options = self.options.as_dict() # BuildOptions constructor expects ``impl_opts`` keyword # but BuildOptions.as_dict outputs ``_impl_opts`` key @@ -163,18 +158,15 @@ def with_backend(self: StencilBuilder, backend_name: str) -> StencilBuilder: """ self._build_data = {} backend = gt4pyc.backend.from_name(backend_name) - assert backend is not None self.backend = backend(self) return self @classmethod - def default_options_dict( - cls, definition_func: Union[StencilFunc, AnnotatedStencilFunc] - ) -> Dict[str, Any]: + def default_options_dict(cls, definition_func: AnyStencilFunc) -> dict[str, Any]: return {"name": definition_func.__name__, "module": definition_func.__module__} @classmethod - def name_to_options_args(cls, name: Optional[str]) -> Dict[str, str]: + def name_to_options_args(cls, name: str | None) -> dict[str, str]: """Check for qualified name, extract also module option in that case.""" if not name: return {} @@ -185,7 +177,7 @@ def name_to_options_args(cls, name: Optional[str]) -> Dict[str, str]: return data @classmethod - def nest_impl_options(cls, options_dict: Dict[str, Any]) -> Dict[str, Any]: + def nest_impl_options(cls, options_dict: dict[str, Any]) -> dict[str, Any]: impl_opts = options_dict.setdefault("impl_opts", {}) # The following is not a dict comprehension because: # The backend-specific options (starting with ``_``) are nested under @@ -203,18 +195,18 @@ def definition(self) -> AnnotatedStencilFunc: ) @property - def externals(self) -> Dict[str, Any]: + def externals(self) -> dict[str, Any]: return self._build_data.get("externals") or self._build_data.setdefault( "externals", self._externals.copy() ) @property - def dtypes(self) -> Dict[Type, Type]: + def dtypes(self) -> dict[type, type]: return self._build_data.get("dtypes") or self._build_data.setdefault( "dtypes", self._dtypes.copy() ) - def with_externals(self: StencilBuilder, externals: Dict[str, Any]) -> StencilBuilder: + def with_externals(self: StencilBuilder, externals: dict[str, Any]) -> StencilBuilder: """ Fluidly set externals for this build. @@ -225,17 +217,17 @@ def with_externals(self: StencilBuilder, externals: Dict[str, Any]) -> StencilBu self.with_caching(self.caching.name) return self - def with_dtypes(self: StencilBuilder, dtypes: Dict[Type, Type]) -> StencilBuilder: + def with_dtypes(self: StencilBuilder, dtypes: dict[type, type]) -> StencilBuilder: self._build_data = {} self._dtypes = dtypes self.with_caching(self.caching.name) return self @property - def backend_data(self) -> Dict[str, Any]: + def backend_data(self) -> dict[str, Any]: return self._build_data.get("backend_data", {}).copy() - def with_backend_data(self: StencilBuilder, data: Dict[str, Any]) -> StencilBuilder: + def with_backend_data(self: StencilBuilder, data: dict[str, Any]) -> StencilBuilder: self._build_data["backend_data"] = {**self.backend_data, **data} return self diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 1cdb859eab..090c8c2e10 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from numbers import Number from pickle import dumps -from typing import Any, Callable, ClassVar, Dict, Literal, Optional, Tuple, Union, cast +from typing import Any, Callable, ClassVar, Literal, Union, cast import numpy as np @@ -32,14 +32,14 @@ cupy = None FieldType = Union["cp.ndarray", np.ndarray] -OriginType = Union[Tuple[int, int, int], Dict[str, Tuple[int, ...]]] +OriginType = Union[tuple[int, int, int], dict[str, tuple[int, ...]]] def _compute_domain_origin_cache_key( - field_args_info: Dict[str, Optional[ArgsInfo]], - parameter_args: Dict[str, Optional[Number]], - domain: Optional[Tuple[int, ...]], - origin: Optional[OriginType], + field_args_info: dict[str, ArgsInfo | None], + parameter_args: dict[str, Number | None], + domain: tuple[int, ...] | None, + origin: OriginType | None, ) -> int: field_data = tuple( (name, arg.array.shape, arg.origin or (0, 0, 0)) @@ -53,15 +53,15 @@ def _compute_domain_origin_cache_key( class ArgsInfo: device: str array: FieldType - original_object: Optional[Any] = None - origin: Optional[Tuple[int, ...]] = None - dimensions: Optional[Tuple[str, ...]] = None + original_object: Any | None = None + origin: tuple[int, ...] | None = None + dimensions: tuple[str, ...] | None = None def _extract_array_infos( - field_args: Dict[str, Optional[FieldType]], device: Literal["cpu", "gpu"] -) -> Dict[str, Optional[ArgsInfo]]: - array_infos: Dict[str, Optional[ArgsInfo]] = {} + field_args: dict[str, FieldType | None], device: Literal["cpu", "gpu"] +) -> dict[str, ArgsInfo | None]: + array_infos: dict[str, ArgsInfo | None] = {} for name, arg in field_args.items(): if arg is None: array_infos[name] = None @@ -86,8 +86,8 @@ def _extract_array_infos( def _extract_stencil_arrays( - array_infos: Dict[str, Optional[ArgsInfo]], -) -> Dict[str, Optional[FieldType]]: + array_infos: dict[str, ArgsInfo | None], +) -> dict[str, FieldType | None]: return {name: info.array if info is not None else None for name, info in array_infos.items()} @@ -96,8 +96,8 @@ class FrozenStencil: """Stencil with pre-computed domain and origin for each field argument.""" stencil_object: StencilObject - origin: Dict[str, Tuple[int, ...]] - domain: Tuple[int, ...] + origin: dict[str, tuple[int, ...]] + domain: tuple[int, ...] def __post_init__(self): for name, field_info in self.stencil_object.field_info.items(): @@ -188,7 +188,7 @@ class StencilObject(abc.ABC): _gt_id_: str definition_func: Callable[..., Any] - _domain_origin_cache: ClassVar[Dict[int, Tuple[Tuple[int, ...], Dict[str, Tuple[int, ...]]]]] + _domain_origin_cache: ClassVar[dict[int, tuple[tuple[int, ...], dict[str, tuple[int, ...]]]]] """Stores domain/origin pairs that have been used by hash.""" def __new__(cls, *args, **kwargs): @@ -248,22 +248,22 @@ def domain_info(self) -> DomainInfo: @property @abc.abstractmethod - def field_info(self) -> Dict[str, FieldInfo]: + def field_info(self) -> dict[str, FieldInfo]: pass @property @abc.abstractmethod - def parameter_info(self) -> Dict[str, ParameterInfo]: + def parameter_info(self) -> dict[str, ParameterInfo]: pass @property @abc.abstractmethod - def constants(self) -> Dict[str, Any]: + def constants(self) -> dict[str, Any]: pass @property @abc.abstractmethod - def options(self) -> Dict[str, Any]: + def options(self) -> dict[str, Any]: pass @abc.abstractmethod @@ -276,8 +276,8 @@ def __call__(self, *args, **kwargs) -> None: @staticmethod def _make_origin_dict( - origin: Union[Dict[str, Tuple[int, ...]], Tuple[int, ...], int, None], - ) -> Dict[str, Tuple[int, ...]]: + origin: dict[str, tuple[int, ...]] | tuple[int, ...] | int | None, + ) -> dict[str, tuple[int, ...]]: try: if isinstance(origin, dict): # This is needed because the keys in origin are StringLiteral as of DaCe v0.14, and they @@ -287,9 +287,9 @@ def _make_origin_dict( if origin is None: return {} if isinstance(origin, collections.abc.Iterable): - return {"_all_": cast(Tuple[int, ...], Index.from_value(origin))} + return {"_all_": cast(tuple[int, ...], Index.from_value(origin))} if isinstance(origin, int): - return {"_all_": cast(Tuple[int, ...], Index.from_k(origin))} + return {"_all_": cast(tuple[int, ...], Index.from_k(origin))} except Exception: pass @@ -297,10 +297,10 @@ def _make_origin_dict( @staticmethod def _get_max_domain( - array_infos: Dict[str, Optional[ArgsInfo]], + array_infos: dict[str, ArgsInfo | None], domain_infos: DomainInfo, - field_infos: Dict[str, FieldInfo], - origin: Dict[str, Tuple[int, ...]], + field_infos: dict[str, FieldInfo], + origin: dict[str, tuple[int, ...]], *, squeeze: bool = True, ) -> Shape: @@ -344,10 +344,10 @@ def _get_max_domain( def _validate_args( # Function is too complex self, - arg_infos: Dict[str, Optional[ArgsInfo]], - param_args: Dict[str, Any], - domain: Tuple[int, ...], - origin: Dict[str, Tuple[int, ...]], + arg_infos: dict[str, ArgsInfo | None], + param_args: dict[str, Any], + domain: tuple[int, ...], + origin: dict[str, tuple[int, ...]], ) -> None: """ Validate input arguments to _call_run. @@ -399,9 +399,6 @@ def _validate_args( # Function is too complex arg_info = arg_infos[name] assert arg_info is not None - backend_cls = gt_backend.from_name(self.backend) - assert backend_cls is not None - if not backend_cls.storage_info["is_optimal_layout"]( arg_info.array, tuple( @@ -485,10 +482,10 @@ def _validate_args( # Function is too complex @staticmethod def _normalize_origins( - array_infos: Dict[str, Optional[ArgsInfo]], - field_infos: Dict[str, FieldInfo], - origin: Optional[OriginType], - ) -> Dict[str, Tuple[int, ...]]: + array_infos: dict[str, ArgsInfo | None], + field_infos: dict[str, FieldInfo], + origin: OriginType | None, + ) -> dict[str, tuple[int, ...]]: origin = StencilObject._make_origin_dict(origin) all_origin = origin.get("_all_", None) # Set an appropriate origin for all fields @@ -520,13 +517,13 @@ def _normalize_origins( def _call_run( self, - field_args: Dict[str, FieldType], - parameter_args: Dict[str, Any], - domain: Optional[Tuple[int, ...]], - origin: Optional[OriginType], + field_args: dict[str, FieldType], + parameter_args: dict[str, Any], + domain: tuple[int, ...] | None, + origin: OriginType | None, *, validate_args: bool = True, - exec_info: Optional[Dict[str, Any]] = None, + exec_info: dict[str, Any] | None = None, ) -> None: """Check and preprocess the provided arguments (called by :class:`StencilObject` subclasses). @@ -560,7 +557,6 @@ def _call_run( if exec_info is not None: exec_info["call_run_start_time"] = time.perf_counter() backend_cls = gt_backend.from_name(self.backend) - assert backend_cls is not None device = backend_cls.storage_info["device"] array_infos = _extract_array_infos(field_args, device) @@ -593,7 +589,7 @@ def _call_run( exec_info["call_run_end_time"] = time.perf_counter() def freeze( - self: StencilObject, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...] + self: StencilObject, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...] ) -> FrozenStencil: """Return a StencilObject wrapper with a fixed domain and origin for each argument. diff --git a/src/gt4py/cartesian/type_hints.py b/src/gt4py/cartesian/type_hints.py index b94c2a0bbd..10967cecc0 100644 --- a/src/gt4py/cartesian/type_hints.py +++ b/src/gt4py/cartesian/type_hints.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict +from typing import Any from typing_extensions import Protocol @@ -15,8 +15,11 @@ class StencilFunc(Protocol): __name__: str __module__: str - def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ... + def __call__(self, *args: Any, **kwargs: dict[str, Any]) -> None: ... class AnnotatedStencilFunc(StencilFunc, Protocol): - _gtscript_: Dict[str, Any] + _gtscript_: dict[str, Any] + + +AnyStencilFunc = StencilFunc | AnnotatedStencilFunc diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index 084377827d..961cb4dc4c 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -19,6 +19,7 @@ import sys import time import types +from typing import Generic, TypeVar from gt4py.cartesian import config as gt_config @@ -333,20 +334,23 @@ def restore_module(patch, *, verify=True): current.__dict__[name] = original_value -class Registry(dict): +T = TypeVar("T") + + +class Registry(Generic[T], dict[str, T]): @property - def names(self): + def names(self) -> list[str]: return list(self.keys()) - def register(self, name, item=NOTHING): + def register(self, name: str, item: T) -> T: if name in self.keys(): - raise ValueError("Name already exists in registry") + raise ValueError(f"Name {name} already exists in registry.") def _wrapper(obj): self[name] = obj return obj - return _wrapper if item is NOTHING else _wrapper(item) + return _wrapper(item) class ClassProperty: diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 38cb6caca8..6f838d70c1 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -60,7 +60,6 @@ def id_version(): def get_array_library(backend: str): """Return device ready array maker library""" backend_cls = gt4pyc.backend.from_name(backend) - assert backend_cls is not None if backend_cls.storage_info["device"] == "gpu": assert cp is not None return cp diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py index c3bf40e456..cb5a508d20 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py @@ -39,7 +39,6 @@ def test_numpy_allocators(backend, order): def test_bad_layout_warns(backend): xp = get_array_library(backend) backend_cls = gt4pyc.backend.from_name(backend) - assert backend_cls is not None shape = (10, 10, 10) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_definitive_assignment_analysis.py b/tests/cartesian_tests/unit_tests/test_gtc/test_definitive_assignment_analysis.py index 148eb4897f..43ad8cb866 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_definitive_assignment_analysis.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_definitive_assignment_analysis.py @@ -7,18 +7,17 @@ # SPDX-License-Identifier: BSD-3-Clause import typing -from typing import Callable, List, Tuple, TypedDict +from typing import Callable import pytest -from gt4py.cartesian.backend import from_name from gt4py.cartesian.gtc.passes import gtir_definitive_assignment_analysis as daa from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval, stencil from gt4py.cartesian.stencil_builder import StencilBuilder # A list of dictionaries containing a stencil definition and the expected test case outputs -test_data: List[Tuple[Callable, bool]] = [] +test_data: list[tuple[Callable, bool]] = [] def register_test_case(*, valid):