Merged
18 commits
f79a7f0
Fix typo: stream._handle -> stream.handle
oleksandr-pavlyk May 20, 2025
de7b3c9
Move definition of LaunchConfig class to separate file
oleksandr-pavlyk May 20, 2025
5bd64a7
Introduce _module.KernelOccupancy class
oleksandr-pavlyk May 20, 2025
b89c95f
Add occupancy tests, except for cluster-related queries
oleksandr-pavlyk May 20, 2025
9679e0e
Fix typo in querying handle from Stream argument
oleksandr-pavlyk May 22, 2025
ff322ec
Add tests for cluster-related occupancy descriptors
oleksandr-pavlyk May 22, 2025
fd8302f
Introduce MaxPotentialBlockSizeOccupancyResult named tuple
oleksandr-pavlyk May 22, 2025
40d799a
KernelOccupancy.max_potential_block_size support for CUoccupancyB2DSize
oleksandr-pavlyk May 22, 2025
5968ff0
Add test for B2DSize usage in max_potential_block_size
oleksandr-pavlyk May 22, 2025
fdbad93
Merge branch 'main' into feature/occupancy
oleksandr-pavlyk May 22, 2025
436f111
Merge branch 'main' into feature/occupancy
oleksandr-pavlyk May 29, 2025
428f4fa
Improved max_potential_block_size.__doc__
oleksandr-pavlyk May 30, 2025
f1ff0f5
Add test for dynamic_shared_memory_needed arg of invalid type
oleksandr-pavlyk May 30, 2025
39a08f6
Mention feature/occupancy in 0.3.0 release notes
oleksandr-pavlyk May 30, 2025
f74dcf1
Add symbols to api_private.rst
oleksandr-pavlyk Jun 3, 2025
e2adc57
Reduce test name verbosity
oleksandr-pavlyk Jun 3, 2025
496eb5b
Add doc-strings to KernelOccupancy methods.
oleksandr-pavlyk Jun 3, 2025
f74db2c
fix rendering
leofang Jun 4, 2025
3 changes: 2 additions & 1 deletion cuda_core/cuda/core/experimental/__init__.py
@@ -5,7 +5,8 @@
from cuda.core.experimental import utils
from cuda.core.experimental._device import Device
from cuda.core.experimental._event import Event, EventOptions
from cuda.core.experimental._launcher import LaunchConfig, launch
from cuda.core.experimental._launch_config import LaunchConfig
from cuda.core.experimental._launcher import launch
from cuda.core.experimental._linker import Linker, LinkerOptions
from cuda.core.experimental._module import ObjectCode
from cuda.core.experimental._program import Program, ProgramOptions
97 changes: 97 additions & 0 deletions cuda_core/cuda/core/experimental/_launch_config.py
@@ -0,0 +1,97 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Optional, Union

from cuda.core.experimental._device import Device
from cuda.core.experimental._utils.cuda_utils import (
CUDAError,
cast_to_3_tuple,
driver,
get_binding_version,
handle_return,
)

# TODO: revisit this treatment for py313t builds
_inited = False


def _lazy_init():
global _inited
if _inited:
return

global _use_ex
# binding availability depends on cuda-python version
_py_major_minor = get_binding_version()
_driver_ver = handle_return(driver.cuDriverGetVersion())
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
_inited = True


@dataclass
class LaunchConfig:
"""Customizable launch options.

Attributes
----------
grid : Union[tuple, int]
Collection of threads that will execute a kernel function.
cluster : Union[tuple, int]
Group of blocks (Thread Block Cluster) that will execute on the same
GPU Processing Cluster (GPC). Blocks within a cluster have access to
distributed shared memory and can be explicitly synchronized.
block : Union[tuple, int]
Group of threads (Thread Block) that will execute on the same
streaming multiprocessor (SM). Threads within a thread block have
access to shared memory and can be explicitly synchronized.
shmem_size : int, optional
Dynamic shared-memory size per thread block in bytes.
(Defaults to 0.)

"""

# TODO: expand LaunchConfig to include other attributes
grid: Union[tuple, int] = None
cluster: Union[tuple, int] = None
block: Union[tuple, int] = None
shmem_size: Optional[int] = None

def __post_init__(self):
_lazy_init()
self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid)
self.block = cast_to_3_tuple("LaunchConfig.block", self.block)
# thread block clusters are supported starting H100
if self.cluster is not None:
if not _use_ex:
err, drvers = driver.cuDriverGetVersion()
drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
raise CUDAError(f"thread block clusters require cuda.bindings & driver 11.8+{drvers_fmt}")
cc = Device().compute_capability
if cc < (9, 0):
raise CUDAError(
f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})"
)
self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster)
if self.shmem_size is None:
self.shmem_size = 0


def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:
_lazy_init()
drv_cfg = driver.CUlaunchConfig()
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
drv_cfg.sharedMemBytes = config.shmem_size
attrs = [] # TODO: support more attributes
if config.cluster:
attr = driver.CUlaunchAttribute()
attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
dim = attr.value.clusterDim
dim.x, dim.y, dim.z = config.cluster
attrs.append(attr)
drv_cfg.numAttrs = len(attrs)
drv_cfg.attrs = attrs
return drv_cfg
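
A minimal usage sketch of the relocated class (not part of the diff), assuming cuda.core.experimental is installed and a CUDA device is present; the commented-out cluster line additionally assumes driver 11.8+ and a compute capability 9.0+ device:

    from cuda.core.experimental import Device, LaunchConfig

    Device().set_current()

    # grid/cluster/block accept an int or a tuple of up to 3 ints;
    # __post_init__ normalizes each to a 3-tuple.
    config = LaunchConfig(grid=256, block=128, shmem_size=4096)
    print(config.grid, config.block, config.shmem_size)  # (256, 1, 1) (128, 1, 1) 4096

    # Thread block clusters raise CUDAError on an unsupported driver or device:
    # clustered = LaunchConfig(grid=256, cluster=2, block=128)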
71 changes: 4 additions & 67 deletions cuda_core/cuda/core/experimental/_launcher.py
@@ -2,17 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Optional, Union

from cuda.core.experimental._device import Device
from cuda.core.experimental._kernel_arg_handler import ParamHolder
from cuda.core.experimental._launch_config import LaunchConfig, _to_native_launch_config
from cuda.core.experimental._module import Kernel
from cuda.core.experimental._stream import Stream
from cuda.core.experimental._utils.clear_error_support import assert_type
from cuda.core.experimental._utils.cuda_utils import (
CUDAError,
cast_to_3_tuple,
check_or_create_options,
driver,
get_binding_version,
@@ -37,54 +37,6 @@ def _lazy_init():
_inited = True


@dataclass
class LaunchConfig:
"""Customizable launch options.

Attributes
----------
grid : Union[tuple, int]
Collection of threads that will execute a kernel function.
cluster : Union[tuple, int]
Group of blocks (Thread Block Cluster) that will execute on the same
GPU Processing Cluster (GPC). Blocks within a cluster have access to
distributed shared memory and can be explicitly synchronized.
block : Union[tuple, int]
Group of threads (Thread Block) that will execute on the same
streaming multiprocessor (SM). Threads within a thread blocks have
access to shared memory and can be explicitly synchronized.
shmem_size : int, optional
Dynamic shared-memory size per thread block in bytes.
(Default to size 0)

"""

# TODO: expand LaunchConfig to include other attributes
grid: Union[tuple, int] = None
cluster: Union[tuple, int] = None
block: Union[tuple, int] = None
shmem_size: Optional[int] = None

def __post_init__(self):
_lazy_init()
self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid)
self.block = cast_to_3_tuple("LaunchConfig.block", self.block)
# thread block clusters are supported starting H100
if self.cluster is not None:
if not _use_ex:
err, drvers = driver.cuDriverGetVersion()
drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
raise CUDAError(f"thread block clusters require cuda.bindings & driver 11.8+{drvers_fmt}")
cc = Device().compute_capability
if cc < (9, 0):
raise CUDAError(
f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})"
)
self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster)
if self.shmem_size is None:
self.shmem_size = 0


def launch(stream, config, kernel, *kernel_args):
"""Launches a :obj:`~_module.Kernel`
object with launch-time configuration.
@@ -114,6 +62,7 @@ def launch(stream, config, kernel, *kernel_args):
f"stream must either be a Stream object or support __cuda_stream__ (got {type(stream)})"
) from e
assert_type(kernel, Kernel)
_lazy_init()
config = check_or_create_options(LaunchConfig, config, "launch config")

# TODO: can we ensure kernel_args is valid/safe to use here?
@@ -127,25 +76,13 @@ def launch(stream, config, kernel, *kernel_args):
# mainly to see if the "Ex" API is available and if so we use it, as it's more feature
# rich.
if _use_ex:
drv_cfg = driver.CUlaunchConfig()
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
drv_cfg = _to_native_launch_config(config)
drv_cfg.hStream = stream.handle
drv_cfg.sharedMemBytes = config.shmem_size
attrs = [] # TODO: support more attributes
if config.cluster:
attr = driver.CUlaunchAttribute()
attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
dim = attr.value.clusterDim
dim.x, dim.y, dim.z = config.cluster
attrs.append(attr)
drv_cfg.numAttrs = len(attrs)
drv_cfg.attrs = attrs
handle_return(driver.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
else:
# TODO: check if config has any unsupported attrs
handle_return(
driver.cuLaunchKernel(
int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream._handle, args_ptr, 0
int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream.handle, args_ptr, 0
)
)
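
Call sites are unaffected by the refactor, since LaunchConfig is still re-exported from the package root. A self-contained sketch (not part of the diff; the no-op kernel is illustrative):

    from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

    dev = Device()
    dev.set_current()
    stream = dev.create_stream()

    # Compile a trivial kernel for the current device.
    code = 'extern "C" __global__ void noop() {}'
    arch = "".join(str(i) for i in dev.compute_capability)
    ker = Program(code, "c++", options=ProgramOptions(arch=f"sm_{arch}")).compile("cubin").get_kernel("noop")

    launch(stream, LaunchConfig(grid=4, block=256), ker)  # kernel arguments, if any, follow `ker`
    stream.sync()  # wait for the asynchronous launch to finish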
108 changes: 107 additions & 1 deletion cuda_core/cuda/core/experimental/_module.py
@@ -6,6 +6,8 @@
from typing import Optional, Union
from warnings import warn

from cuda.core.experimental._launch_config import LaunchConfig, _to_native_launch_config
from cuda.core.experimental._stream import Stream
from cuda.core.experimental._utils.clear_error_support import (
assert_type,
assert_type_str_or_bytes,
@@ -184,6 +186,102 @@ def cluster_scheduling_policy_preference(self, device_id: int = None) -> int:
)


MaxPotentialBlockSizeOccupancyResult = namedtuple("MaxPotentialBlockSizeOccupancyResult", ("min_grid_size", "max_block_size"))


class KernelOccupancy:
""" """

def __new__(self, *args, **kwargs):
raise RuntimeError("KernelOccupancy cannot be instantiated directly. Please use Kernel APIs.")

__slots__ = ("_handle",)

@classmethod
def _init(cls, handle):
self = super().__new__(cls)
self._handle = handle

return self

def max_active_blocks_per_multiprocessor(self, block_size: int, dynamic_shared_memory_size: int) -> int:
"""int : Occupancy of the kernel"""
return handle_return(
driver.cuOccupancyMaxActiveBlocksPerMultiprocessor(self._handle, block_size, dynamic_shared_memory_size)
)

def max_potential_block_size(
self, dynamic_shared_memory_needed: Union[int, driver.CUoccupancyB2DSize], block_size_limit: int
) -> MaxPotentialBlockSizeOccupancyResult:
"""MaxPotentialBlockSizeOccupancyResult: Suggested launch configuration for reasonable occupancy.

Returns the minimum grid size needed to achieve the maximum occupancy and
the maximum block size that can achieve the maximum occupancy.

Parameters
----------
dynamic_shared_memory_needed: Union[int, driver.CUoccupancyB2DSize]
The amount of dynamic shared memory in bytes needed by each block.
Use `0` if the blocks do not need shared memory. Use a C callable
represented by :obj:`~driver.CUoccupancyB2DSize` to encode the
amount of needed dynamic shared memory, which may vary depending
on the block size.
block_size_limit: int
Known upper limit on the kernel block size. Use `0` to indicate
the maximum block size permitted by the device / kernel instead.

Returns
-------
:obj:`~MaxPotentialBlockSizeOccupancyResult`
An object with `min_grid_size` and `max_block_size` attributes encoding
the suggested launch configuration.

Note
----
Please be advised that the use of a C callable that requires the
Python Global Interpreter Lock may lead to deadlocks.

"""
if isinstance(dynamic_shared_memory_needed, int):
min_grid_size, max_block_size = handle_return(
driver.cuOccupancyMaxPotentialBlockSize(
self._handle, None, dynamic_shared_memory_needed, block_size_limit
)
)
elif isinstance(dynamic_shared_memory_needed, driver.CUoccupancyB2DSize):
min_grid_size, max_block_size = handle_return(
driver.cuOccupancyMaxPotentialBlockSize(
self._handle, dynamic_shared_memory_needed.getPtr(), 0, block_size_limit
)
)
else:
raise TypeError(
"dynamic_shared_memory_needed expected to have type int, or CUoccupancyB2DSize, "
f"got {type(dynamic_shared_memory_needed)}"
)
return MaxPotentialBlockSizeOccupancyResult(min_grid_size=min_grid_size, max_block_size=max_block_size)

def available_dynamic_shared_memory_per_block(self, num_blocks_per_multiprocessor: int, block_size: int) -> int:
"""int: Dynamic shared memory available per block for given launch configuration."""
return handle_return(
driver.cuOccupancyAvailableDynamicSMemPerBlock(self._handle, num_blocks_per_multiprocessor, block_size)
)

def max_potential_cluster_size(self, config: LaunchConfig, stream: Optional[Stream] = None) -> int:
""" "int: The maximum cluster size that can be launched for this kernel and launch configuration"""
drv_cfg = _to_native_launch_config(config)
if stream is not None:
drv_cfg.hStream = stream.handle
return handle_return(driver.cuOccupancyMaxPotentialClusterSize(self._handle, drv_cfg))

def max_active_clusters(self, config: LaunchConfig, stream: Optional[Stream] = None) -> int:
""" "int: The maximum number of clusters that could co-exist on the target device"""
drv_cfg = _to_native_launch_config(config)
if stream is not None:
drv_cfg.hStream = stream.handle
return handle_return(driver.cuOccupancyMaxActiveClusters(self._handle, drv_cfg))


ParamInfo = namedtuple("ParamInfo", ["offset", "size"])


@@ -198,7 +296,7 @@ class Kernel:

"""

__slots__ = ("_handle", "_module", "_attributes")
__slots__ = ("_handle", "_module", "_attributes", "_occupancy")

def __new__(self, *args, **kwargs):
raise RuntimeError("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs.")
Expand All @@ -211,6 +309,7 @@ def _from_obj(cls, obj, mod):
ker._handle = obj
ker._module = mod
ker._attributes = None
ker._occupancy = None
return ker

@property
Expand Down Expand Up @@ -250,6 +349,13 @@ def arguments_info(self) -> list[ParamInfo]:
_, param_info = self._get_arguments_info(param_info=True)
return param_info

@property
def occupancy(self) -> KernelOccupancy:
"""Get the read-only attributes of this kernel."""
if self._occupancy is None:
self._occupancy = KernelOccupancy._init(self._handle)
return self._occupancy

# TODO: implement from_handle()


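An end-to-end sketch of the new occupancy queries (not part of the diff; the no-op kernel and the chosen block size are illustrative, and the cluster-related queries are omitted since they require a compute capability 9.0+ device):

    from cuda.core.experimental import Device, Program, ProgramOptions

    dev = Device()
    dev.set_current()

    # Compile a trivial kernel for the current device.
    code = 'extern "C" __global__ void noop() {}'
    arch = "".join(str(i) for i in dev.compute_capability)
    ker = Program(code, "c++", options=ProgramOptions(arch=f"sm_{arch}")).compile("cubin").get_kernel("noop")

    occ = ker.occupancy
    # Active blocks per SM for 128-thread blocks using no dynamic shared memory.
    print(occ.max_active_blocks_per_multiprocessor(128, 0))
    # Suggested launch configuration; block_size_limit=0 means "no limit".
    res = occ.max_potential_block_size(dynamic_shared_memory_needed=0, block_size_limit=0)
    print(res.min_grid_size, res.max_block_size)
    # Dynamic shared memory available per block when one block runs per SM.
    print(occ.available_dynamic_shared_memory_per_block(1, res.max_block_size))
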
1 change: 1 addition & 0 deletions cuda_core/docs/source/release/0.3.0-notes.rst
@@ -21,6 +21,7 @@
------------

- :class:`Kernel` adds :property:`Kernel.num_arguments` and :property:`Kernel.arguments_info` for introspection of kernel arguments. (#612)
- Add pythonic access to kernel occupancy calculation functions via :property:`Kernel.occupancy`. (#648)

New examples
------------