Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/gt4py/cartesian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import typing

from . import (
backend,
caching,
cli,
config,
Expand All @@ -28,6 +29,7 @@

__all__ = [
"StencilObject",
"backend",
"caching",
"cli",
"config",
Expand Down
122 changes: 55 additions & 67 deletions src/gt4py/cartesian/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,7 @@
import pathlib
import time
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Protocol,
Tuple,
Type,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol

from typing_extensions import deprecated

Expand All @@ -45,32 +33,6 @@
from gt4py.cartesian.stencil_builder import StencilBuilder
from gt4py.cartesian.stencil_object import StencilObject

REGISTRY = gt_utils.Registry()


def from_name(name: str) -> Optional[Type[Backend]]:
backend = REGISTRY.get(name, None)
if not backend:
raise NotImplementedError(
f"Backend {name} is not implemented, options are: {REGISTRY.names}"
)
return backend


def register(backend_cls: Type[Backend]) -> Type[Backend]:
assert issubclass(backend_cls, Backend) and backend_cls.name is not None

if isinstance(backend_cls.name, str):
gt_storage.register(backend_cls.name, backend_cls.storage_info)
return REGISTRY.register(backend_cls.name, backend_cls)

else:
raise ValueError(
"Invalid 'name' attribute ('{name}') in backend class '{cls}'".format(
name=backend_cls.name, cls=backend_cls
)
)


class Backend(abc.ABC):
#: Backend name
Expand All @@ -82,7 +44,7 @@ class Backend(abc.ABC):
#: - versioning: is versioning on?
#: - description [optional]
#: - type
options: ClassVar[Dict[str, Any]]
options: ClassVar[dict[str, Any]]

#: Backend-specific storage parametrization:
#:
Expand All @@ -100,15 +62,15 @@ class Backend(abc.ABC):
#:
#: Languages should be spelled using the official spelling
#: but lower case ("python", "fortran", "rust").
languages: ClassVar[Optional[Dict[str, Any]]] = None
languages: ClassVar[dict[str, Any] | None] = None

# __impl_opts:
# "disable-code-generation": bool
# "disable-cache-validation": bool

builder: StencilBuilder

def __init__(self, builder: StencilBuilder):
def __init__(self, builder: StencilBuilder) -> None:
self.builder = builder

@classmethod
Expand All @@ -125,7 +87,7 @@ def filter_options_for_id(
return filtered_options

@abc.abstractmethod
def load(self) -> Optional[Type[StencilObject]]:
def load(self) -> type[StencilObject] | None:
"""
Load the stencil class from the generated python module.

Expand All @@ -140,7 +102,7 @@ def load(self) -> Optional[Type[StencilObject]]:
pass

@abc.abstractmethod
def generate(self) -> Type[StencilObject]:
def generate(self) -> type[StencilObject]:
"""
Generate the stencil class from GTScript's internal representation.

Expand All @@ -156,19 +118,45 @@ def generate(self) -> Type[StencilObject]:
pass

@property
def extra_cache_info(self) -> Dict[str, Any]:
def extra_cache_info(self) -> dict[str, Any]:
"""Provide additional data to be stored in cache info file (subclass hook)."""
return {}

@property
def extra_cache_validation_keys(self) -> List[str]:
def extra_cache_validation_keys(self) -> list[str]:
"""List keys from extra_cache_info to be validated during consistency check."""
return []


REGISTRY = gt_utils.Registry[type[Backend]]()


def from_name(name: str) -> type[Backend]:
"""Return a backend by name."""
backend_cls = REGISTRY.get(name, None)
if backend_cls is None:
raise NotImplementedError(
f"Backend '{name}' is not implemented. Options are: {REGISTRY.names}."
)
return backend_cls


def register(backend_cls: type[Backend]) -> type[Backend]:
"""Register a backend."""
assert issubclass(backend_cls, Backend) and backend_cls.name is not None

if isinstance(backend_cls.name, str):
gt_storage.register(backend_cls.name, backend_cls.storage_info)
return REGISTRY.register(backend_cls.name, backend_cls)

raise ValueError(
f"Invalid 'name' attribute ('{backend_cls.name}') in backend class '{backend_cls}'."
)


class CLIBackendMixin(Backend):
@abc.abstractmethod
def generate_computation(self) -> Dict[str, Union[str, Dict]]:
def generate_computation(self) -> dict[str, str | dict]:
"""
Generate the computation source code in a way agnostic of the way it is going to be used.

Expand Down Expand Up @@ -218,7 +206,7 @@ def mystencil(...):
raise NotImplementedError

@abc.abstractmethod
def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:
def generate_bindings(self, language_name: str) -> dict[str, str | dict]:
"""
Generate bindings source code from ``language_name`` to the target language of the backend.

Expand All @@ -243,9 +231,9 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:


class BaseBackend(Backend):
MODULE_GENERATOR_CLASS: ClassVar[Type[BaseModuleGenerator]]
MODULE_GENERATOR_CLASS: ClassVar[type[BaseModuleGenerator]]

def load(self) -> Optional[Type[StencilObject]]:
def load(self) -> type[StencilObject] | None:
build_info = self.builder.options.build_info
if build_info is not None:
start_time = time.perf_counter()
Expand All @@ -266,11 +254,11 @@ def load(self) -> Optional[Type[StencilObject]]:

return stencil_class

def generate(self) -> Type[StencilObject]:
def generate(self) -> type[StencilObject]:
self.check_options(self.builder.options)
return self.make_module()

def _load(self) -> Type[StencilObject]:
def _load(self) -> type[StencilObject]:
stencil_class_name = self.builder.class_name
file_name = str(self.builder.module_path)
stencil_module = gt_utils.make_module_from_file(stencil_class_name, file_name)
Expand All @@ -292,7 +280,7 @@ def check_options(self, options: gt_definitions.BuildOptions) -> None:
stacklevel=2,
)

def make_module(self, **kwargs: Any) -> Type[StencilObject]:
def make_module(self, **kwargs: Any) -> type[StencilObject]:
build_info = self.builder.options.build_info
if build_info is not None:
start_time = time.perf_counter()
Expand All @@ -312,15 +300,15 @@ def make_module(self, **kwargs: Any) -> Type[StencilObject]:

return module

def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str:
def make_module_source(self, *, args_data: ModuleData | None = None, **kwargs: Any) -> str:
"""Generate the module source code with or without stencil id."""
args_data = args_data or make_args_data_from_gtir(self.builder.gtir_pipeline)
source = self.MODULE_GENERATOR_CLASS()(args_data, self.builder, **kwargs)
return source


class MakeModuleSourceCallable(Protocol):
def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: ...
def __call__(self, *, args_data: ModuleData | None = None, **kwargs: Any) -> str: ...


class PurePythonBackendCLIMixin(CLIBackendMixin):
Expand All @@ -334,12 +322,12 @@ class PurePythonBackendCLIMixin(CLIBackendMixin):
#: :py:meth:`BaseBackend`.
make_module_source: MakeModuleSourceCallable

def generate_computation(self) -> Dict[str, Union[str, Dict]]:
def generate_computation(self) -> dict[str, str | dict]:
file_name = self.builder.module_path.name
source = self.make_module_source(ir=self.builder.gtir)
return {str(file_name): source}

def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:
def generate_bindings(self, language_name: str) -> dict[str, str | dict]:
"""Pure python backends typically will not support bindings."""
return super().generate_bindings(language_name)

Expand All @@ -362,7 +350,7 @@ def pyext_build_dir_path(self) -> pathlib.Path:
return self.builder.pkg_path.joinpath(self.pyext_module_name + "_BUILD")

@property
def extra_cache_info(self) -> Dict[str, Any]:
def extra_cache_info(self) -> dict[str, Any]:
pyext_file_path = self.builder.backend_data.get("pyext_file_path", None)
pyext_md5 = ""
if pyext_file_path:
Expand All @@ -374,23 +362,23 @@ def extra_cache_info(self) -> Dict[str, Any]:
}

@property
def extra_cache_validation_keys(self) -> List[str]:
def extra_cache_validation_keys(self) -> list[str]:
keys = super().extra_cache_validation_keys
if self.extra_cache_info["pyext_md5"]:
keys.append("pyext_md5")
return keys

@abc.abstractmethod
def generate(self) -> Type[StencilObject]:
def generate(self) -> type[StencilObject]:
pass

def build_extension_module(
self,
pyext_sources: Dict[str, Any],
pyext_build_opts: Dict[str, str],
pyext_sources: dict[str, Any],
pyext_build_opts: dict[str, str],
*,
uses_cuda: bool = False,
) -> Tuple[str, str]:
) -> tuple[str, str]:
# Build extension module
pyext_build_path = pathlib.Path(
os.path.relpath(self.pyext_build_dir_path, pathlib.Path.cwd())
Expand All @@ -409,7 +397,7 @@ def build_extension_module(
pyext_target_file_path = self.builder.pkg_path
qualified_pyext_name = self.pyext_module_path

pyext_build_args: Dict[str, Any] = dict(
pyext_build_args: dict[str, Any] = dict(
name=qualified_pyext_name,
sources=sources,
build_path=str(pyext_build_path),
Expand All @@ -431,15 +419,15 @@ def build_extension_module(
return module_name, file_path


def disabled(message: str, *, enabled_env_var: str) -> Callable[[Type[Backend]], Type[Backend]]:
def disabled(message: str, *, enabled_env_var: str) -> Callable[[type[Backend]], type[Backend]]:
# We push for hard deprecation here by raising by default and warning if enabling has been forced.
enabled = bool(int(os.environ.get(enabled_env_var, "0")))
if enabled:
return deprecated(message)
else:

def _decorator(cls: Type[Backend]) -> Type[Backend]:
def _no_generate(obj) -> Type[StencilObject]:
def _decorator(cls: type[Backend]) -> type[Backend]:
def _no_generate(obj) -> type[StencilObject]:
raise NotImplementedError(
f"Disabled '{cls.name}' backend: 'f{message}'\n",
f"You can still enable the backend by hand using the environment variable '{enabled_env_var}=1'",
Expand Down
13 changes: 5 additions & 8 deletions src/gt4py/cartesian/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, ClassVar

from gt4py import storage as gt_storage
from gt4py.cartesian.backend.base import CLIBackendMixin, disabled, register
Expand Down Expand Up @@ -37,12 +37,12 @@


class CudaExtGenerator(BackendCodegen):
def __init__(self, class_name, module_name, backend):
def __init__(self, class_name: str, module_name: str, backend: CudaBackend) -> None:
self.class_name = class_name
self.module_name = module_name
self.backend = backend

def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]:
def __call__(self, stencil_ir: gtir.Stencil) -> dict[str, dict[str, str]]:
stencil_ir = GtirPipeline(stencil_ir, self.backend.builder.stencil_id).full()
base_oir = GTIRToOIR().visit(stencil_ir)
oir_pipeline = self.backend.builder.options.backend_opts.get(
Expand Down Expand Up @@ -141,15 +141,12 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin):
MODULE_GENERATOR_CLASS = CUDAPyExtModuleGenerator
GT_BACKEND_T = "gpu"

def generate_extension(self, **kwargs: Any) -> Tuple[str, str]:
def generate_extension(self, **kwargs: Any) -> tuple[str, str]:
return self.make_extension(uses_cuda=True)

def generate(self) -> Type[StencilObject]:
def generate(self) -> type[StencilObject]:
self.check_options(self.builder.options)

pyext_module_name: Optional[str]
pyext_file_path: Optional[str]

# TODO(havogt) add bypass if computation has no effect
pyext_module_name, pyext_file_path = self.generate_extension()

Expand Down
Loading