diff --git a/c/experimental/stf/include/cccl/c/experimental/stf/stf.h b/c/experimental/stf/include/cccl/c/experimental/stf/stf.h index 903b71cd878..848f0f1d5db 100644 --- a/c/experimental/stf/include/cccl/c/experimental/stf/stf.h +++ b/c/experimental/stf/include/cccl/c/experimental/stf/stf.h @@ -456,36 +456,87 @@ cudaStream_t stf_fence(stf_ctx_handle ctx); //! //! \brief Create logical data from existing memory buffer //! -//! Creates logical data handle from an existing host memory buffer. -//! STF takes ownership of data management during task execution. +//! Creates logical data handle from existing memory buffer, assuming host data place. +//! This is a convenience wrapper around stf_logical_data_with_place() with host placement. //! //! \param ctx Context handle //! \param[out] ld Pointer to receive logical data handle -//! \param addr Pointer to existing data buffer +//! \param addr Pointer to existing data buffer (assumed to be host memory) //! \param sz Size of data in bytes //! //! \pre ctx must be valid context handle //! \pre ld must not be NULL -//! \pre addr must not be NULL +//! \pre addr must not be NULL and point to host-accessible memory //! \pre sz must be greater than 0 //! \post *ld contains valid logical data handle //! -//! \note Original data pointer should not be accessed during task execution -//! \note Data will be written back when logical data is destroyed or context finalized +//! \note This function assumes host memory. For device/managed memory, use stf_logical_data_with_place() +//! \note Equivalent to: stf_logical_data_with_place(ctx, ld, addr, sz, make_host_data_place()) //! //! \par Example: //! \code //! float data[1024]; //! stf_logical_data_handle ld; -//! stf_logical_data(ctx, &ld, data, sizeof(data)); +//! stf_logical_data(ctx, &ld, data, sizeof(data)); // Assumes host memory //! // ... use in tasks ... //! stf_logical_data_destroy(ld); //! \endcode //! -//! \see stf_logical_data_empty(), stf_logical_data_destroy() +//! \see stf_logical_data_with_place(), stf_logical_data_empty(), stf_logical_data_destroy() void stf_logical_data(stf_ctx_handle ctx, stf_logical_data_handle* ld, void* addr, size_t sz); +//! +//! \brief Create logical data handle from address with data place specification +//! +//! Creates logical data handle from existing memory buffer, explicitly specifying where +//! the memory is located (host, device, managed, etc.). This is the primary and recommended +//! logical data creation function as it provides STF with essential memory location information +//! for optimal data movement and placement strategies. +//! +//! \param ctx Context handle +//! \param[out] ld Pointer to receive logical data handle +//! \param addr Pointer to existing memory buffer +//! \param sz Size of buffer in bytes +//! \param dplace Data place specifying memory location +//! +//! \pre ctx must be valid context handle +//! \pre ld must be valid pointer to logical data handle pointer +//! \pre addr must point to valid memory of at least sz bytes +//! \pre sz must be greater than 0 +//! \pre dplace must be valid data place (not invalid) +//! +//! \post *ld contains valid logical data handle on success +//! \post Caller owns returned handle (must call stf_logical_data_destroy()) +//! +//! \par Examples: +//! \code +//! // GPU device memory (recommended for CUDA arrays) +//! float* device_ptr; +//! cudaMalloc(&device_ptr, 1000 * sizeof(float)); +//! stf_data_place dplace = make_device_data_place(0); +//! stf_logical_data_handle ld; +//! 
stf_logical_data_with_place(ctx, &ld, device_ptr, 1000 * sizeof(float), dplace); +//! +//! // Host memory +//! float* host_data = new float[1000]; +//! stf_data_place host_place = make_host_data_place(); +//! stf_logical_data_handle ld_host; +//! stf_logical_data_with_place(ctx, &ld_host, host_data, 1000 * sizeof(float), host_place); +//! +//! // Managed memory +//! float* managed_ptr; +//! cudaMallocManaged(&managed_ptr, 1000 * sizeof(float)); +//! stf_data_place managed_place = make_managed_data_place(); +//! stf_logical_data_handle ld_managed; +//! stf_logical_data_with_place(ctx, &ld_managed, managed_ptr, 1000 * sizeof(float), managed_place); +//! \endcode +//! +//! \see make_device_data_place(), make_host_data_place(), make_managed_data_place() + +void stf_logical_data_with_place( + stf_ctx_handle ctx, stf_logical_data_handle* ld, void* addr, size_t sz, stf_data_place dplace); + //! //! \brief Set symbolic name for logical data //! diff --git a/c/experimental/stf/src/stf.cu b/c/experimental/stf/src/stf.cu index c08a88b77e1..c601be20e26 100644 --- a/c/experimental/stf/src/stf.cu +++ b/c/experimental/stf/src/stf.cu @@ -44,12 +44,44 @@ cudaStream_t stf_fence(stf_ctx_handle ctx) } void stf_logical_data(stf_ctx_handle ctx, stf_logical_data_handle* ld, void* addr, size_t sz) +{ + // Convenience wrapper: assume host memory + stf_logical_data_with_place(ctx, ld, addr, sz, make_host_data_place()); +} + +void stf_logical_data_with_place( + stf_ctx_handle ctx, stf_logical_data_handle* ld, void* addr, size_t sz, stf_data_place dplace) { assert(ctx); assert(ld); auto* context_ptr = static_cast(ctx); - auto ld_typed = context_ptr->logical_data(make_slice((char*) addr, sz)); + + // Convert C data_place to C++ data_place + cuda::experimental::stf::data_place cpp_dplace; + switch (dplace.kind) + { + case STF_DATA_PLACE_HOST: + cpp_dplace = cuda::experimental::stf::data_place::host(); + break; + case STF_DATA_PLACE_DEVICE: + cpp_dplace = cuda::experimental::stf::data_place::device(dplace.u.device.dev_id); + break; + case STF_DATA_PLACE_MANAGED: + cpp_dplace = cuda::experimental::stf::data_place::managed(); + break; + case STF_DATA_PLACE_AFFINE: + cpp_dplace = cuda::experimental::stf::data_place::affine(); + break; + default: + // Invalid data place - this should not happen with valid input + assert(false && "Invalid data_place kind"); + cpp_dplace = cuda::experimental::stf::data_place::host(); // fallback + break; + } + + // Create logical data with the specified data place + auto ld_typed = context_ptr->logical_data(make_slice((char*) addr, sz), cpp_dplace); // Store the logical_data_untyped directly as opaque pointer *ld = new logical_data_untyped{ld_typed}; diff --git a/cudax/examples/stf/void_data_interface.cu b/cudax/examples/stf/void_data_interface.cu index 0340b16bf4a..bf429f23dfe 100644 --- a/cudax/examples/stf/void_data_interface.cu +++ b/cudax/examples/stf/void_data_interface.cu @@ -49,5 +49,9 @@ int main() return cuda_kernel_desc{dummy_kernel, 16, 128, 0}; }; + EXPECT(token.is_void_interface()); + EXPECT(token2.is_void_interface()); + EXPECT(token3.is_void_interface()); + ctx.finalize(); } diff --git a/cudax/include/cuda/experimental/__stf/graph/graph_task.cuh b/cudax/include/cuda/experimental/__stf/graph/graph_task.cuh index 167140afefc..23a106fad47 100644 --- a/cudax/include/cuda/experimental/__stf/graph/graph_task.cuh +++ b/cudax/include/cuda/experimental/__stf/graph/graph_task.cuh @@ -104,7 +104,9 @@ public: { // Select a stream from the pool capture_stream = 
get_exec_place().getStream(ctx.async_resources(), true).stream; - cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal)); + // Use relaxed capture mode to allow capturing workloads that lazily initialize + // resources (e.g., set up memory pools) + cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeRelaxed)); } auto& dot = *ctx.get_dot(); @@ -365,7 +367,9 @@ public: capture_stream = get_exec_place().getStream(ctx.async_resources(), true).stream; cudaGraph_t childGraph = nullptr; - cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal)); + // Use relaxed capture mode to allow capturing workloads that lazily initialize + // resources (e.g., set up memory pools) + cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeRelaxed)); // Launch the user provided function f(capture_stream); @@ -625,7 +629,9 @@ public: cudaStream_t capture_stream = get_exec_place().getStream(ctx.async_resources(), true).stream; cudaGraph_t childGraph = nullptr; - cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal)); + // Use relaxed capture mode to allow capturing workloads that lazily initialize + // resources (e.g., set up memory pools) + cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeRelaxed)); // Launch the user provided function if constexpr (fun_invocable_stream_deps) diff --git a/cudax/test/stf/cpp/task_get_stream.cu b/cudax/test/stf/cpp/task_get_stream.cu index 89fa74e7490..2d6509e5a73 100644 --- a/cudax/test/stf/cpp/task_get_stream.cu +++ b/cudax/test/stf/cpp/task_get_stream.cu @@ -24,7 +24,8 @@ void test_stream() context ctx; auto token = ctx.token(); - auto t = ctx.task(token.write()); + EXPECT(token.is_void_interface()); + auto t = ctx.task(token.write()); t.start(); cudaStream_t s = t.get_stream(); EXPECT(s != nullptr); diff --git a/python/cuda_cccl/CMakeLists.txt b/python/cuda_cccl/CMakeLists.txt index 922458c02e2..bcfb3b084f9 100644 --- a/python/cuda_cccl/CMakeLists.txt +++ b/python/cuda_cccl/CMakeLists.txt @@ -17,9 +17,27 @@ message( # Build cccl.c.parallel and add CCCL's install rules set(_cccl_root ../..) 
+ +# Build and install C++ library first set(CCCL_TOPLEVEL_PROJECT ON) # Enable the developer builds -set(CCCL_ENABLE_C_PARALLEL ON) # Build the cccl.c.parallel library +set(CCCL_ENABLE_C_PARALLEL ON) +set(CCCL_ENABLE_C_EXPERIMENTAL_STF ON) # Enable C experimental STF library (triggers c/ directory) +set(CCCL_ENABLE_UNSTABLE ON) # Enable unstable features + +# Disable all testing, examples, and benchmarks - we only want the libraries +set(CCCL_ENABLE_TESTING OFF) +set(CCCL_ENABLE_EXAMPLES OFF) +set(CCCL_ENABLE_BENCHMARKS OFF) +set(CCCL_C_PARALLEL_ENABLE_TESTING OFF) +set(CCCL_C_EXPERIMENTAL_STF_ENABLE_TESTING OFF) +# Note: CCCL_ENABLE_CUDAX must be ON because STF depends on it (via CCCL_ENABLE_UNSTABLE) +# But disable cudax tests, examples, and header testing +set(cudax_ENABLE_TESTING OFF) +set(cudax_ENABLE_EXAMPLES OFF) +set(cudax_ENABLE_HEADER_TESTING OFF) set(CCCL_C_PARALLEL_LIBRARY_OUTPUT_DIRECTORY ${SKBUILD_PROJECT_NAME}) +set(CCCL_C_EXPERIMENTAL_STF_LIBRARY_OUTPUT_DIRECTORY ${SKBUILD_PROJECT_NAME}) + # Just install the rest: set(libcudacxx_ENABLE_INSTALL_RULES ON) set(CUB_ENABLE_INSTALL_RULES ON) @@ -34,10 +52,21 @@ add_subdirectory(${_cccl_root} _parent_cccl) set(CMAKE_INSTALL_LIBDIR "${old_libdir}") # pop set(CMAKE_INSTALL_INCLUDEDIR "${old_includedir}") # pop +# Create CCCL::cudax alias for STF (normally created by cccl-config.cmake) +if (TARGET cudax::cudax AND NOT TARGET CCCL::cudax) + add_library(CCCL::cudax ALIAS cudax::cudax) +endif() + # ensure the destination directory exists +file(MAKE_DIRECTORY "cuda/stf/${CUDA_VERSION_DIR}/cccl") file(MAKE_DIRECTORY "cuda/compute/${CUDA_VERSION_DIR}/cccl") # Install version-specific binaries +install( + TARGETS cccl.c.experimental.stf + DESTINATION cuda/stf/${CUDA_VERSION_DIR}/cccl +) + install( TARGETS cccl.c.parallel DESTINATION cuda/compute/${CUDA_VERSION_DIR}/cccl @@ -110,12 +139,40 @@ add_custom_target( DEPENDS "${_generated_extension_src}" ) +message(STATUS "STF Using Cython ${CYTHON_VERSION}") +set( + stf_pyx_source_file + "${cuda_cccl_SOURCE_DIR}/cuda/stf/_stf_bindings_impl.pyx" +) +set(_stf_generated_extension_src "${cuda_cccl_BINARY_DIR}/_stf_bindings_impl.c") +set(_stf_depfile "${cuda_cccl_BINARY_DIR}/_stf_bindings_impl.c.dep") +add_custom_command( + OUTPUT "${_stf_generated_extension_src}" + COMMAND "${Python3_EXECUTABLE}" -m cython + ARGS + ${CYTHON_FLAGS_LIST} "${stf_pyx_source_file}" --output-file + ${_stf_generated_extension_src} + DEPENDS "${stf_pyx_source_file}" + DEPFILE "${_stf_depfile}" + COMMENT "Cythonizing ${pyx_source_file} for CUDA ${CUDA_VERSION_MAJOR}" +) +set_source_files_properties( + "${_stf_generated_extension_src}" + PROPERTIES GENERATED TRUE +) +add_custom_target( + cythonize_stf_bindings_impl + ALL + DEPENDS "${_stf_generated_extension_src}" +) + python3_add_library( _bindings_impl MODULE WITH_SOABI "${_generated_extension_src}" ) + add_dependencies(_bindings_impl cythonize_bindings_impl) target_link_libraries( _bindings_impl @@ -125,4 +182,21 @@ target_link_libraries( ) set_target_properties(_bindings_impl PROPERTIES INSTALL_RPATH "$ORIGIN/cccl") +python3_add_library( + _stf_bindings_impl + MODULE + WITH_SOABI + "${_stf_generated_extension_src}" +) +add_dependencies(_stf_bindings_impl cythonize_stf_bindings_impl) +target_link_libraries( + _stf_bindings_impl + PRIVATE cccl.c.experimental.stf CUDA::cuda_driver +) +set_target_properties( + _stf_bindings_impl + PROPERTIES INSTALL_RPATH "$ORIGIN/cccl" +) + install(TARGETS _bindings_impl DESTINATION cuda/compute/${CUDA_VERSION_DIR}) +install(TARGETS 
_stf_bindings_impl DESTINATION cuda/stf/${CUDA_VERSION_DIR})
diff --git a/python/cuda_cccl/cuda/stf/__init__.py b/python/cuda_cccl/cuda/stf/__init__.py
new file mode 100644
index 00000000000..6ca687dfcb3
--- /dev/null
+++ b/python/cuda_cccl/cuda/stf/__init__.py
@@ -0,0 +1,27 @@
+from ._stf_bindings import (
+    context,
+    data_place,
+    dep,
+    exec_place,
+)
+from .decorator import jit  # Python-side kernel launcher
+
+__all__ = [
+    "context",
+    "dep",
+    "exec_place",
+    "data_place",
+    "jit",
+]
+
+
+def has_torch() -> bool:
+    import importlib.util
+
+    return importlib.util.find_spec("torch") is not None
+
+
+def has_numba() -> bool:
+    import importlib.util
+
+    return importlib.util.find_spec("numba") is not None
diff --git a/python/cuda_cccl/cuda/stf/_adapters/numba_bridge.py b/python/cuda_cccl/cuda/stf/_adapters/numba_bridge.py
new file mode 100644
index 00000000000..32b160ba879
--- /dev/null
+++ b/python/cuda_cccl/cuda/stf/_adapters/numba_bridge.py
@@ -0,0 +1,4 @@
+def cai_to_numba(cai: dict):
+    from numba import cuda
+
+    return cuda.from_cuda_array_interface(cai, owner=None, sync=False)
diff --git a/python/cuda_cccl/cuda/stf/_adapters/numba_utils.py b/python/cuda_cccl/cuda/stf/_adapters/numba_utils.py
new file mode 100644
index 00000000000..280d8f3a55d
--- /dev/null
+++ b/python/cuda_cccl/cuda/stf/_adapters/numba_utils.py
@@ -0,0 +1,86 @@
+"""
+Utilities for NUMBA-based STF operations.
+"""
+
+from numba import cuda
+
+
+def init_logical_data(ctx, ld, value, data_place=None, exec_place=None):
+    """
+    Initialize a logical data with a constant value using CuPy's optimized fill.
+
+    Parameters
+    ----------
+    ctx : context
+        STF context
+    ld : logical_data
+        Logical data to initialize
+    value : scalar
+        Value to fill the array with
+    data_place : data_place, optional
+        Data place for the initialization task
+    exec_place : exec_place, optional
+        Execution place for the fill operation
+    """
+    # Create write dependency with optional data place
+    dep_arg = ld.write(data_place) if data_place else ld.write()
+
+    # Create task arguments - include exec_place if provided
+    task_args = []
+    if exec_place is not None:
+        task_args.append(exec_place)
+    task_args.append(dep_arg)
+
+    with ctx.task(*task_args) as t:
+        # Get the array as a numba device array
+        nb_stream = cuda.external_stream(t.stream_ptr())
+        array = t.numba_arguments()
+
+        try:
+            # Use CuPy's optimized operations (much faster than custom kernels),
+            # running on the task's stream via ExternalStream
+            import cupy as cp
+
+            with cp.cuda.ExternalStream(t.stream_ptr()):
+                cp_view = cp.asarray(array)
+                if value == 0 or value == 0.0:
+                    # Use CuPy's potentially optimized zero operation
+                    cp_view.fill(0)  # CuPy may have special optimizations for zero
+                else:
+                    # Use generic fill for non-zero values
+                    cp_view.fill(value)
+        except ImportError:
+            # Fallback to simple kernel if CuPy not available
+            _fill_with_simple_kernel(array, value, nb_stream)
+
+
+@cuda.jit
+def _fill_kernel_fallback(array, value):
+    """Fallback 1D kernel when CuPy is not available."""
+    idx = cuda.grid(1)
+    if idx < array.size:
+        array.flat[idx] = value
+
+
+@cuda.jit
+def _zero_kernel_fallback(array):
+    """Optimized fallback kernel for zero-filling when CuPy is not available."""
+    idx = cuda.grid(1)
+    if idx < array.size:
+        array.flat[idx] = 0
+
+
+def _fill_with_simple_kernel(array, value, stream):
+    """Fallback method using simple NUMBA kernel when CuPy unavailable."""
+    total_size = array.size
+    threads_per_block = 256
+    blocks_per_grid = (total_size + threads_per_block - 1) // threads_per_block
+
+    if value == 0 or value == 0.0:
+        # Use the
specialized zero kernel for potentially better performance + _zero_kernel_fallback[blocks_per_grid, threads_per_block, stream](array) + else: + # Use generic fill kernel for non-zero values + typed_value = array.dtype.type(value) + _fill_kernel_fallback[blocks_per_grid, threads_per_block, stream]( + array, typed_value + ) diff --git a/python/cuda_cccl/cuda/stf/_adapters/torch_bridge.py b/python/cuda_cccl/cuda/stf/_adapters/torch_bridge.py new file mode 100644 index 00000000000..0e7686ea363 --- /dev/null +++ b/python/cuda_cccl/cuda/stf/_adapters/torch_bridge.py @@ -0,0 +1,16 @@ +from __future__ import annotations + + +def cai_to_torch(cai: dict): + """ + Convert a __cuda_array_interface__ dict to a torch.Tensor + without making PyTorch a hard dependency of the core extension. + + Uses Numba (a required dependency) to create a DeviceNDArray, + which torch.as_tensor can consume directly via __cuda_array_interface__. + """ + import torch + from numba import cuda as _cuda + + dev_array = _cuda.from_cuda_array_interface(cai, owner=None, sync=False) + return torch.as_tensor(dev_array) diff --git a/python/cuda_cccl/cuda/stf/_stf_bindings.py b/python/cuda_cccl/cuda/stf/_stf_bindings.py new file mode 100644 index 00000000000..169490739ee --- /dev/null +++ b/python/cuda_cccl/cuda/stf/_stf_bindings.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# _bindings.py is a shim module that imports symbols from a +# _bindings_impl extension module. The shim serves two purposes: +# +# 1. Import a CUDA-specific extension. The cuda.cccl wheel ships with multiple +# extensions, one for each CUDA version. At runtime, this shim chooses the +# appropriate extension based on the detected CUDA version, and imports all +# symbols from it. +# +# 2. Preload `nvrtc` and `nvJitLink` before importing the extension. +# These shared libraries are indirect dependencies, pulled in via the direct +# dependency `cccl.c.parallel`. To ensure reliable symbol resolution at +# runtime, we explicitly load them first using `cuda.pathfinder`. +# Without this step, importing the Cython extension directly may fail or behave +# inconsistently depending on environment setup and dynamic linker behavior. +# This indirection ensures the right loading order, regardless of how +# `_bindings` is first imported across the codebase. + +import importlib + +from cuda.cccl._cuda_version_utils import detect_cuda_version, get_recommended_extra +from cuda.pathfinder import ( # type: ignore[import-not-found] + load_nvidia_dynamic_lib, +) + + +def _load_cuda_libraries(): + """ + Preload CUDA libraries to ensure proper symbol resolution. + + These libraries are indirect dependencies pulled in via cccl.c.parallel. + Preloading ensures reliable symbol resolution regardless of dynamic linker behavior. + """ + import warnings + + for libname in ("nvrtc", "nvJitLink"): + try: + load_nvidia_dynamic_lib(libname) + except Exception as e: + # Log warning but don't fail - the extension might still work + # if the libraries are already loaded or available through other means + warnings.warn( + f"Failed to preload CUDA library '{libname}': {e}. 
" + f"STF bindings may fail to load if {libname} is not available.", + RuntimeWarning, + stacklevel=2, + ) + + +_load_cuda_libraries() + + +# Import the appropriate bindings implementation depending on what +# CUDA version is available: +cuda_version = detect_cuda_version() +if cuda_version not in [12, 13]: + raise RuntimeError( + f"Unsupported CUDA version: {cuda_version}. Only CUDA 12 and 13 are supported." + ) + +try: + extra_name = get_recommended_extra(cuda_version) + bindings_module = importlib.import_module( + f".{extra_name}._stf_bindings_impl", __package__ + ) + # Import all symbols from the module + globals().update(bindings_module.__dict__) +except ImportError as e: + raise ImportError( + f"Failed to import CUDA STF bindings for CUDA {cuda_version}. " + f"Ensure cuda-cccl is properly installed with: pip install cuda-cccl[cu{cuda_version}]" + ) from e diff --git a/python/cuda_cccl/cuda/stf/_stf_bindings_impl.pyx b/python/cuda_cccl/cuda/stf/_stf_bindings_impl.pyx new file mode 100644 index 00000000000..12f8fba3114 --- /dev/null +++ b/python/cuda_cccl/cuda/stf/_stf_bindings_impl.pyx @@ -0,0 +1,911 @@ +# distutils: language = c++ +# cython: language_level=3 +# cython: linetrace=True + +# Python signatures are declared in the companion Python stub file _bindings.pyi +# Make sure to update PYI with change to Python API to ensure that Python +# static type checker tools like mypy green-lights cuda.cccl.parallel + +from cpython.buffer cimport ( + Py_buffer, PyBUF_FORMAT, PyBUF_ND, PyBUF_SIMPLE, PyBUF_ANY_CONTIGUOUS, + PyObject_GetBuffer, PyBuffer_Release, PyObject_CheckBuffer +) +from cpython.bytes cimport PyBytes_FromStringAndSize +from cpython.pycapsule cimport ( + PyCapsule_CheckExact, PyCapsule_IsValid, PyCapsule_GetPointer +) +from libc.stdint cimport uint8_t, uint32_t, uint64_t, int64_t, uintptr_t +from libc.string cimport memset, memcpy + +import numpy as np + +import ctypes +from enum import IntFlag + +cdef extern from "": + cdef struct OpaqueCUstream_st + cdef struct OpaqueCUkernel_st + cdef struct OpaqueCUlibrary_st + + ctypedef int CUresult + ctypedef OpaqueCUstream_st *CUstream + ctypedef OpaqueCUkernel_st *CUkernel + ctypedef OpaqueCUlibrary_st *CUlibrary + +cdef extern from "cccl/c/experimental/stf/stf.h": + # + # Contexts + # + ctypedef struct stf_ctx_handle_t + ctypedef stf_ctx_handle_t* stf_ctx_handle + void stf_ctx_create(stf_ctx_handle* ctx) + void stf_ctx_create_graph(stf_ctx_handle* ctx) + void stf_ctx_finalize(stf_ctx_handle ctx) + + # + # Exec places + # + ctypedef enum stf_exec_place_kind: + STF_EXEC_PLACE_DEVICE + STF_EXEC_PLACE_HOST + + ctypedef struct stf_exec_place_device: + int dev_id + + ctypedef struct stf_exec_place_host: + int dummy + + ctypedef union stf_exec_place_u: + stf_exec_place_device device + stf_exec_place_host host + + ctypedef struct stf_exec_place: + stf_exec_place_kind kind + stf_exec_place_u u + + stf_exec_place make_device_place(int dev_id) + stf_exec_place make_host_place() + + # + # Data places + # + ctypedef enum stf_data_place_kind: + STF_DATA_PLACE_DEVICE + STF_DATA_PLACE_HOST + STF_DATA_PLACE_MANAGED + STF_DATA_PLACE_AFFINE + + ctypedef struct stf_data_place_device: + int dev_id + + ctypedef struct stf_data_place_host: + int dummy + + ctypedef struct stf_data_place_managed: + int dummy + + ctypedef struct stf_data_place_affine: + int dummy + + ctypedef union stf_data_place_u: + stf_data_place_device device + stf_data_place_host host + stf_data_place_managed managed + stf_data_place_affine affine + + ctypedef struct stf_data_place: + 
stf_data_place_kind kind + stf_data_place_u u + + stf_data_place make_device_data_place(int dev_id) + stf_data_place make_host_data_place() + stf_data_place make_managed_data_place() + stf_data_place make_affine_data_place() + + ctypedef struct stf_logical_data_handle_t + ctypedef stf_logical_data_handle_t* stf_logical_data_handle + void stf_logical_data(stf_ctx_handle ctx, stf_logical_data_handle* ld, void* addr, size_t sz) + void stf_logical_data_with_place(stf_ctx_handle ctx, stf_logical_data_handle* ld, void* addr, size_t sz, stf_data_place dplace) + void stf_logical_data_set_symbol(stf_logical_data_handle ld, const char* symbol) + void stf_logical_data_destroy(stf_logical_data_handle ld) + void stf_logical_data_empty(stf_ctx_handle ctx, size_t length, stf_logical_data_handle *to) + + void stf_token(stf_ctx_handle ctx, stf_logical_data_handle* ld); + + ctypedef struct stf_task_handle_t + ctypedef stf_task_handle_t* stf_task_handle + void stf_task_create(stf_ctx_handle ctx, stf_task_handle* t) + void stf_task_set_exec_place(stf_task_handle t, stf_exec_place* exec_p) + void stf_task_set_symbol(stf_task_handle t, const char* symbol) + void stf_task_add_dep(stf_task_handle t, stf_logical_data_handle ld, stf_access_mode m) + void stf_task_add_dep_with_dplace(stf_task_handle t, stf_logical_data_handle ld, stf_access_mode m, stf_data_place* data_p) + void stf_task_start(stf_task_handle t) + void stf_task_end(stf_task_handle t) + void stf_task_enable_capture(stf_task_handle t) + CUstream stf_task_get_custream(stf_task_handle t) + void* stf_task_get(stf_task_handle t, int submitted_index) + void stf_task_destroy(stf_task_handle t) + + cdef enum stf_access_mode: + STF_NONE + STF_READ + STF_WRITE + STF_RW + +class AccessMode(IntFlag): + NONE = STF_NONE + READ = STF_READ + WRITE = STF_WRITE + RW = STF_RW + +class stf_arg_cai: + def __init__(self, ptr, tuple shape, dtype, stream=0): + self.ptr = ptr # integer device pointer + self.shape = shape + self.dtype = np.dtype(dtype) + self.stream = stream # CUDA stream handle (int or 0) + self.__cuda_array_interface__ = { + 'version': 2, + 'shape': self.shape, + 'typestr': self.dtype.str, # e.g., 'data_ptr, self._len, dplace._c_place) + return + + # Fallback to Python buffer protocol + cdef Py_buffer view + cdef int flags = PyBUF_FORMAT | PyBUF_ND # request dtype + shape + + if PyObject_GetBuffer(buf, &view, flags) != 0: + raise ValueError("object doesn't support the full buffer protocol or __cuda_array_interface__") + + try: + self._ndim = view.ndim + self._len = view.len + self._shape = tuple(view.shape[i] for i in range(view.ndim)) + self._dtype = np.dtype(view.format) + # For buffer protocol objects, use the specified data place (defaults to host) + stf_logical_data_with_place(ctx._ctx, &self._ld, view.buf, view.len, dplace._c_place) + + finally: + PyBuffer_Release(&view) + + + def set_symbol(self, str name): + stf_logical_data_set_symbol(self._ld, name.encode()) + self._symbol = name # Store locally for retrieval + + @property + def symbol(self): + """Get the symbol name of this logical data, if set.""" + return self._symbol + + def __dealloc__(self): + if self._ld != NULL: + stf_logical_data_destroy(self._ld) + self._ld = NULL + + def __repr__(self): + """Return a detailed string representation of the logical_data object.""" + return (f"logical_data(shape={self._shape}, dtype={self._dtype}, " + f"is_token={self._is_token}, symbol={self._symbol!r}, " + f"len={self._len}, ndim={self._ndim})") + + @property + def dtype(self): + """Return the dtype of 
the logical data.""" + return self._dtype + + @property + def shape(self): + """Return the shape of the logical data.""" + return self._shape + + def read(self, dplace=None): + return dep(self, AccessMode.READ.value, dplace) + + def write(self, dplace=None): + return dep(self, AccessMode.WRITE.value, dplace) + + def rw(self, dplace=None): + return dep(self, AccessMode.RW.value, dplace) + + def like_empty(self): + """ + Create a new logical_data with the same shape (and dtype metadata) + as this object. + """ + if self._ld == NULL: + raise RuntimeError("source logical_data handle is NULL") + + cdef logical_data out = logical_data.__new__(logical_data) + stf_logical_data_empty(self._ctx, self._len, &out._ld) + out._ctx = self._ctx + out._dtype = self._dtype + out._shape = self._shape + out._ndim = self._ndim + out._len = self._len + out._symbol = None # New object has no symbol initially + out._is_token = False + + return out + + @staticmethod + def token(context ctx): + cdef logical_data out = logical_data.__new__(logical_data) + out._ctx = ctx._ctx + out._dtype = None + out._shape = None + out._ndim = 0 + out._len = 0 + out._symbol = None # New object has no symbol initially + out._is_token = True + stf_token(ctx._ctx, &out._ld) + + return out + + @staticmethod + def init_by_shape(context ctx, shape, dtype): + """ + Create a new logical_data from a shape and a dtype + """ + cdef logical_data out = logical_data.__new__(logical_data) + out._ctx = ctx._ctx + out._dtype = np.dtype(dtype) + out._shape = shape + out._ndim = len(shape) + cdef size_t total_items = 1 + for dim in shape: + total_items *= dim + out._len = total_items * out._dtype.itemsize + out._symbol = None # New object has no symbol initially + out._is_token = False + stf_logical_data_empty(ctx._ctx, out._len, &out._ld) + + return out + + def borrow_ctx_handle(self): + ctx = context(borrowed=True) + ctx.borrow_from_handle(self._ctx) + return ctx + +class dep: + __slots__ = ("ld", "mode", "dplace") + def __init__(self, logical_data ld, int mode, dplace=None): + self.ld = ld + self.mode = mode + self.dplace = dplace # can be None or a data place + def __iter__(self): # nice unpacking support + yield self.ld + yield self.mode + yield self.dplace + def __repr__(self): + return f"dep({self.ld!r}, {self.mode}, {self.dplace!r})" + def get_ld(self): + return self.ld + +def read(ld, dplace=None): return dep(ld, AccessMode.READ.value, dplace) +def write(ld, dplace=None): return dep(ld, AccessMode.WRITE.value, dplace) +def rw(ld, dplace=None): return dep(ld, AccessMode.RW.value, dplace) + +cdef class exec_place: + cdef stf_exec_place _c_place + + def __cinit__(self): + # empty default constructor; never directly used + pass + + @staticmethod + def device(int dev_id): + cdef exec_place p = exec_place.__new__(exec_place) + p._c_place = make_device_place(dev_id) + return p + + @staticmethod + def host(): + cdef exec_place p = exec_place.__new__(exec_place) + p._c_place = make_host_place() + return p + + @property + def kind(self) -> str: + return ("device" if self._c_place.kind == STF_EXEC_PLACE_DEVICE + else "host") + + @property + def device_id(self) -> int: + if self._c_place.kind != STF_EXEC_PLACE_DEVICE: + raise AttributeError("not a device execution place") + return self._c_place.u.device.dev_id + +cdef class data_place: + cdef stf_data_place _c_place + + def __cinit__(self): + # empty default constructor; never directly used + pass + + @staticmethod + def device(int dev_id): + cdef data_place p = data_place.__new__(data_place) + 
p._c_place = make_device_data_place(dev_id) + return p + + @staticmethod + def host(): + cdef data_place p = data_place.__new__(data_place) + p._c_place = make_host_data_place() + return p + + @staticmethod + def managed(): + cdef data_place p = data_place.__new__(data_place) + p._c_place = make_managed_data_place() + return p + + @staticmethod + def affine(): + cdef data_place p = data_place.__new__(data_place) + p._c_place = make_affine_data_place() + return p + + @property + def kind(self) -> str: + cdef stf_data_place_kind k = self._c_place.kind + if k == STF_DATA_PLACE_DEVICE: + return "device" + elif k == STF_DATA_PLACE_HOST: + return "host" + elif k == STF_DATA_PLACE_MANAGED: + return "managed" + elif k == STF_DATA_PLACE_AFFINE: + return "affine" + else: + raise ValueError(f"Unknown data place kind: {k}") + + @property + def device_id(self) -> int: + if self._c_place.kind != STF_DATA_PLACE_DEVICE: + raise AttributeError("not a device data place") + return self._c_place.u.device.dev_id + + + +cdef class task: + cdef stf_task_handle _t + + # list of logical data in deps: we need this because we can't exchange + # dtype/shape easily through the C API of STF + cdef list _lds_args + + def __cinit__(self, context ctx): + stf_task_create(ctx._ctx, &self._t) + self._lds_args = [] + + def __dealloc__(self): + if self._t != NULL: + stf_task_destroy(self._t) + + def start(self): + # This is ignored if this is not a graph task + stf_task_enable_capture(self._t) + + stf_task_start(self._t) + + def end(self): + stf_task_end(self._t) + + def add_dep(self, object d): + """ + Accept a `dep` instance created with read(ld), write(ld), or rw(ld). + """ + if not isinstance(d, dep): + raise TypeError("add_dep expects read(ld), write(ld) or rw(ld)") + + cdef logical_data ldata = d.ld + cdef int mode_int = int(d.mode) + cdef stf_access_mode mode_ce = mode_int + cdef data_place dp + + if d.dplace is None: + stf_task_add_dep(self._t, ldata._ld, mode_ce) + else: + dp = d.dplace + stf_task_add_dep_with_dplace(self._t, ldata._ld, mode_ce, &dp._c_place) + + self._lds_args.append(ldata) + + def set_exec_place(self, object exec_p): + if not isinstance(exec_p, exec_place): + raise TypeError("set_exec_place expects and exec_place argument") + + cdef exec_place ep = exec_p + stf_task_set_exec_place(self._t, &ep._c_place) + + def stream_ptr(self) -> int: + """ + Return the raw CUstream pointer as a Python int + (memory address). Suitable for ctypes or PyCUDA. 
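+
+        Example (illustrative, using the Numba helper already relied on elsewhere
+        in these bindings):
+            nb_stream = numba.cuda.external_stream(t.stream_ptr())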
+ """ + cdef CUstream s = stf_task_get_custream(self._t) + return s # cast pointer -> Py int + + def get_arg(self, index) -> int: + if self._lds_args[index]._is_token: + raise RuntimeError("cannot materialize a token argument") + + cdef void *ptr = stf_task_get(self._t, index) + return ptr + + def get_arg_cai(self, index): + ptr = self.get_arg(index) + return stf_arg_cai(ptr, self._lds_args[index].shape, self._lds_args[index].dtype, stream=self.stream_ptr()).__cuda_array_interface__ + + def get_arg_numba(self, index): + cai = self.get_arg_cai(index) + try: + from cuda.stf._adapters.numba_bridge import cai_to_numba + except Exception as e: + raise RuntimeError("numba support is not available") from e + return cai_to_numba(cai) + + def numba_arguments(self): + # Only include non-token arguments in the tuple + non_token_args = [self.get_arg_numba(i) for i in range(len(self._lds_args)) + if not self._lds_args[i]._is_token] + + if len(non_token_args) == 0: + return None + elif len(non_token_args) == 1: + return non_token_args[0] + return tuple(non_token_args) + + def get_arg_as_tensor(self, index): + cai = self.get_arg_cai(index) + try: + from cuda.stf._adapters.torch_bridge import cai_to_torch + except Exception as e: + raise RuntimeError("PyTorch support is not available") from e + return cai_to_torch(cai) + + def tensor_arguments(self): + # Only include non-token arguments in the tuple + non_token_args = [self.get_arg_as_tensor(i) for i in range(len(self._lds_args)) + if not self._lds_args[i]._is_token] + + if len(non_token_args) == 0: + return None + elif len(non_token_args) == 1: + return non_token_args[0] + return tuple(non_token_args) + + # ---- context‑manager helpers ------------------------------- + def __enter__(self): + self.start() + return self + + def __exit__(self, object exc_type, object exc, object tb): + """ + Always called, even if an exception occurred inside the block. + """ + self.end() + return False + +cdef class pytorch_task_context: + """ + Context manager for PyTorch-integrated STF tasks. + + This class automatically handles: + - Task start/end + - PyTorch stream context + - Tensor argument conversion and unpacking + """ + cdef task _task + cdef object _torch_stream_context + + def __cinit__(self, task t): + self._task = t + self._torch_stream_context = None + + def __enter__(self): + # Import torch here since we know it's available (checked in pytorch_task) + import torch.cuda as tc + + # Start the underlying task + self._task.start() + + # Create torch stream context from task stream + torch_stream = tc.ExternalStream(self._task.stream_ptr()) + self._torch_stream_context = tc.stream(torch_stream) + self._torch_stream_context.__enter__() + + # Get tensor arguments and return them + tensors = self._task.tensor_arguments() + + # If only one tensor, return it directly; otherwise return tuple + if isinstance(tensors, tuple): + return tensors + else: + return (tensors,) + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + # Exit torch stream context first + if self._torch_stream_context is not None: + self._torch_stream_context.__exit__(exc_type, exc_val, exc_tb) + finally: + # Always end the task + self._task.end() + return False + +cdef class context: + cdef stf_ctx_handle _ctx + # Is this a context that we have borrowed ? 
+ cdef bint _borrowed + + def __cinit__(self, bint use_graph=False, bint borrowed=False): + self._ctx = NULL + self._borrowed = borrowed + if not borrowed: + if use_graph: + stf_ctx_create_graph(&self._ctx) + else: + stf_ctx_create(&self._ctx) + + cdef borrow_from_handle(self, stf_ctx_handle ctx_handle): + if self._ctx != NULL: + raise RuntimeError("context already initialized") + + if not self._borrowed: + raise RuntimeError("cannot call borrow_from_handle on this context") + + self._ctx = ctx_handle + + def __repr__(self): + return f"context(handle={self._ctx}, borrowed={self._borrowed})" + + def __dealloc__(self): + if not self._borrowed: + self.finalize() + + def finalize(self): + if self._borrowed: + raise RuntimeError("cannot finalize borrowed context") + + if self._ctx != NULL: + stf_ctx_finalize(self._ctx) + self._ctx = NULL + + def logical_data(self, object buf, data_place dplace=None): + """ + Create and return a `logical_data` object bound to this context [PRIMARY API]. + + This is the primary function for creating logical data from existing buffers. + It supports both Python buffer protocol objects and CUDA Array Interface objects, + with explicit data_place specification for optimal STF data movement strategies. + + Parameters + ---------- + buf : any buffer‑supporting Python object or __cuda_array_interface__ object + (NumPy array, Warp array, CuPy array, bytes, bytearray, memoryview, …) + dplace : data_place, optional + Specifies where the buffer is located (host, device, managed, affine). + Defaults to data_place.host() for backward compatibility. + Essential for GPU arrays - use data_place.device() for optimal performance. + + Examples + -------- + >>> # Host memory (explicit - recommended) + >>> host_place = data_place.host() + >>> ld = ctx.logical_data(numpy_array, host_place) + >>> + >>> # GPU device memory (recommended for CUDA arrays) + >>> device_place = data_place.device(0) + >>> ld = ctx.logical_data(warp_array, device_place) + >>> + >>> # Managed/unified memory + >>> managed_place = data_place.managed() + >>> ld = ctx.logical_data(unified_array, managed_place) + >>> + >>> # Backward compatibility (defaults to host) + >>> ld = ctx.logical_data(numpy_array) # Same as specifying host + + Note + ---- + For GPU arrays (Warp, CuPy, etc.), always specify data_place.device() + for zero-copy performance and correct memory management. + """ + return logical_data(self, buf, dplace) + + + def logical_data_empty(self, shape, dtype=None): + """ + Create logical data with uninitialized values. + + Equivalent to numpy.empty() but for STF logical data. + + Parameters + ---------- + shape : tuple + Shape of the array + dtype : numpy.dtype, optional + Data type. Defaults to np.float64. + + Returns + ------- + logical_data + New logical data with uninitialized values + + Examples + -------- + >>> # Create uninitialized array (fast but contains garbage) + >>> ld = ctx.logical_data_empty((100, 100), dtype=np.float32) + + >>> # Fast allocation without initialization + >>> ld = ctx.logical_data_empty((50, 50, 50)) + """ + if dtype is None: + dtype = np.float64 + return logical_data.init_by_shape(self, shape, dtype) + + def logical_data_full(self, shape, fill_value, dtype=None, where=None, exec_place=None): + """ + Create logical data initialized with a constant value. + + Similar to numpy.full(), this creates a new logical data with the given + shape and fills it with fill_value. 
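+        The fill itself runs inside an STF task, using CuPy's fill when CuPy
+        is available and a small Numba fallback kernel otherwise, so Numba
+        support is required.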
+ + Parameters + ---------- + shape : tuple + Shape of the array + fill_value : scalar + Value to fill the array with + dtype : numpy.dtype, optional + Data type. If None, infer from fill_value. + where : data_place, optional + Data placement for initialization. Defaults to current device. + exec_place : exec_place, optional + Execution place for the fill operation. Defaults to current device. + Note: exec_place.host() is not yet supported. + + Returns + ------- + logical_data + New logical data initialized with fill_value + + Examples + -------- + >>> # Create array filled with epsilon0 on current device + >>> ld = ctx.logical_data_full((100, 100), 8.85e-12, dtype=np.float64) + + >>> # Create array on host memory + >>> ld = ctx.logical_data_full((50, 50), 1.0, where=data_place.host()) + + >>> # Create array on specific device, execute on device 1 + >>> ld = ctx.logical_data_full((200, 200), 0.0, where=data_place.device(0), + ... exec_place=exec_place.device(1)) + """ + # Infer dtype from fill_value if not provided + if dtype is None: + dtype = np.array(fill_value).dtype + else: + dtype = np.dtype(dtype) + + # Validate exec_place - host execution not yet supported + if exec_place is not None: + if hasattr(exec_place, 'kind') and exec_place.kind == "host": + raise NotImplementedError( + "exec_place.host() is not yet supported for logical_data_full. " + "Use exec_place.device() or omit exec_place parameter." + ) + + # Create empty logical data + ld = self.logical_data_empty(shape, dtype) + + # Initialize with the specified value using NUMBA + # The numba code already handles None properly by calling ld.write() without data place + try: + from cuda.stf._adapters.numba_utils import init_logical_data + init_logical_data(self, ld, fill_value, where, exec_place) + except ImportError as e: + raise RuntimeError("NUMBA support is not available for logical_data_full") from e + + return ld + + def logical_data_zeros(self, shape, dtype=None, where=None, exec_place=None): + """ + Create logical data filled with zeros. + + Equivalent to numpy.zeros() but for STF logical data. + + Parameters + ---------- + shape : tuple + Shape of the array + dtype : numpy.dtype, optional + Data type. Defaults to np.float64. + where : data_place, optional + Data placement. Defaults to current device. + exec_place : exec_place, optional + Execution place for the fill operation. Defaults to current device. + + Returns + ------- + logical_data + New logical data filled with zeros + + Examples + -------- + >>> # Create zero-filled array + >>> ld = ctx.logical_data_zeros((100, 100), dtype=np.float32) + + >>> # Create on host memory + >>> ld = ctx.logical_data_zeros((50, 50), where=data_place.host()) + """ + if dtype is None: + dtype = np.float64 + return self.logical_data_full(shape, 0.0, dtype, where, exec_place) + + def logical_data_ones(self, shape, dtype=None, where=None, exec_place=None): + """ + Create logical data filled with ones. + + Equivalent to numpy.ones() but for STF logical data. + + Parameters + ---------- + shape : tuple + Shape of the array + dtype : numpy.dtype, optional + Data type. Defaults to np.float64. + where : data_place, optional + Data placement. Defaults to current device. + exec_place : exec_place, optional + Execution place for the fill operation. Defaults to current device. 
+ + Returns + ------- + logical_data + New logical data filled with ones + + Examples + -------- + >>> # Create ones-filled array + >>> ld = ctx.logical_data_ones((100, 100), dtype=np.float32) + + >>> # Create on specific device + >>> ld = ctx.logical_data_ones((50, 50), exec_place=exec_place.device(1)) + """ + if dtype is None: + dtype = np.float64 + return self.logical_data_full(shape, 1.0, dtype, where, exec_place) + + def token(self): + return logical_data.token(self) + + def task(self, *args): + """ + Create a `task` + + Example + ------- + >>> t = ctx.task(read(lX), rw(lY)) + >>> t.start() + >>> t.end() + """ + exec_place_set = False + t = task(self) # construct with this context + for d in args: + if isinstance(d, dep): + t.add_dep(d) + elif isinstance(d, exec_place): + if exec_place_set: + raise ValueError("Only one exec_place can be given") + t.set_exec_place(d) + exec_place_set = True + else: + raise TypeError( + "Arguments must be dependency objects or an exec_place" + ) + return t + + def pytorch_task(self, *args): + """ + Create a PyTorch-integrated task that returns tensors directly. + Only available if PyTorch is installed. + + This is a convenience method that combines task creation with automatic + PyTorch stream management and tensor conversion. + + Example + ------- + >>> with ctx.pytorch_task(read(lX), rw(lY)) as (x_tensor, y_tensor): + >>> # Automatic PyTorch stream context and tensor unpacking + >>> y_tensor[:] = x_tensor * 2 + + Returns + ------- + pytorch_task_context : Context manager that yields tensor arguments + """ + # Check if PyTorch is available + try: + import torch + except ImportError: + raise RuntimeError( + "pytorch_task requires PyTorch to be installed. " + "Install PyTorch or use the regular task() method." + ) + + # Create the underlying task + t = self.task(*args) + + # Return a PyTorch-specific context manager + return pytorch_task_context(t) diff --git a/python/cuda_cccl/cuda/stf/decorator.py b/python/cuda_cccl/cuda/stf/decorator.py new file mode 100644 index 00000000000..41bf71c6316 --- /dev/null +++ b/python/cuda_cccl/cuda/stf/decorator.py @@ -0,0 +1,100 @@ +from numba import cuda + +from cuda.stf import context, dep, exec_place + + +class stf_kernel_decorator: + def __init__(self, pyfunc, jit_args, jit_kwargs): + self._pyfunc = pyfunc + self._jit_args = jit_args + self._jit_kwargs = jit_kwargs + self._compiled_kernel = None + # (grid_dim, block_dim, exec_place_or_none, ctx_or_none) + self._launch_cfg = None + + def __getitem__(self, cfg): + # Normalize cfg into (grid_dim, block_dim, exec_pl, ctx) + if not (isinstance(cfg, tuple) or isinstance(cfg, list)): + raise TypeError("use kernel[grid, block ([, exec_place, ctx])]") + n = len(cfg) + if n not in (2, 3, 4): + raise TypeError( + "use kernel[grid, block], kernel[grid, block, exec_place], or kernel[grid, block, exec_place, ctx]" + ) + + grid_dim = cfg[0] + block_dim = cfg[1] + ctx = None + exec_pl = None + + if n >= 3: + exec_pl = cfg[2] + + if n == 4: + ctx = cfg[3] + + if exec_pl is not None and not isinstance(exec_pl, exec_place): + raise TypeError("3rd item must be an exec_place") + + # Type checks (ctx can be None; exec_pl can be None) + if ctx is not None and not isinstance(ctx, context): + raise TypeError("4th item must be an STF context (or None to infer)") + + self._launch_cfg = (grid_dim, block_dim, ctx, exec_pl) + + return self + + def __call__(self, *args, **kwargs): + if self._launch_cfg is None: + raise RuntimeError( + "launch configuration missing – use kernel[grid, block, 
ctx](...)" + ) + + gridDim, blockDim, ctx, exec_pl = self._launch_cfg + + dep_items = [] + for i, a in enumerate(args): + # print(f"got one arg {a} is dep ? {isinstance(a, dep)}") + if isinstance(a, dep): + if ctx is None: + ld = a.get_ld() + # This context will be used in the __call__ method itself + # so we can create a temporary object from the handle + ctx = ld.borrow_ctx_handle() + dep_items.append((i, a)) + + task_args = [exec_pl] if exec_pl else [] + task_args.extend(a for _, a in dep_items) + + with ctx.task(*task_args) as t: + dev_args = list(args) + # print(dev_args) + for dep_index, (pos, _) in enumerate(dep_items): + # print(f"set arg {dep_index} at position {pos}") + dev_args[pos] = t.get_arg_numba(dep_index) + + if self._compiled_kernel is None: + # print("compile kernel") + self._compiled_kernel = cuda.jit(*self._jit_args, **self._jit_kwargs)( + self._pyfunc + ) + + nb_stream = cuda.external_stream(t.stream_ptr()) + self._compiled_kernel[gridDim, blockDim, nb_stream](*dev_args, **kwargs) + + return None + + +def jit(*jit_args, **jit_kwargs): + if jit_args and callable(jit_args[0]): + pyfunc = jit_args[0] + return _build_kernel(pyfunc, (), **jit_kwargs) + + def _decorator(fn): + return _build_kernel(fn, jit_args, **jit_kwargs) + + return _decorator + + +def _build_kernel(pyfunc, jit_args, **jit_kwargs): + return stf_kernel_decorator(pyfunc, jit_args, jit_kwargs) diff --git a/python/cuda_cccl/tests/stf/example_cholesky.py b/python/cuda_cccl/tests/stf/example_cholesky.py new file mode 100755 index 00000000000..7eded4a20b7 --- /dev/null +++ b/python/cuda_cccl/tests/stf/example_cholesky.py @@ -0,0 +1,707 @@ +#!/usr/bin/env python3 +""" +Python implementation of Cholesky decomposition using CUDA STF and CuPy (CUBLAS/CUSOLVER). + +This example demonstrates: +- Tiled matrix operations with STF logical data +- Integration of CuPy's CUBLAS and CUSOLVER functions with STF tasks +- Multi-device execution with automatic data placement +- Task-based parallelism for linear algebra operations + +Note: CUDASTF automatically manages device context within tasks via exec_place.device(). +There's no need to manually set the current device in task bodies - just use the STF stream. +""" + +import sys + +import cupy as cp +import numpy as np +from cupyx.scipy import linalg as cp_linalg + +import cuda.stf as stf + + +class CAIWrapper: + """Wrapper to expose CUDA Array Interface dict as a proper CAI object.""" + + def __init__(self, cai_dict): + self.__cuda_array_interface__ = cai_dict + + +def get_cupy_arrays(task): + """ + Get all CuPy arrays from STF task arguments. 
+ + Usage: + d_a, d_b, d_c = get_cupy_arrays(t) + """ + arrays = [] + idx = 0 + while True: + try: + arrays.append(cp.asarray(CAIWrapper(task.get_arg_cai(idx)))) + idx += 1 + except Exception: + break + return tuple(arrays) if len(arrays) > 1 else arrays[0] if arrays else None + + +def cai_to_numpy(cai_dict): + """Convert CUDA Array Interface dict to NumPy array (for host memory).""" + import ctypes + + # Extract CAI fields + data_ptr, readonly = cai_dict["data"] + shape = cai_dict["shape"] + typestr = cai_dict["typestr"] + + # Convert typestr to NumPy dtype + dtype = np.dtype(typestr) + + # Calculate total size in bytes + itemsize = dtype.itemsize + size = np.prod(shape) * itemsize + + # Create ctypes buffer from pointer + buffer = (ctypes.c_byte * size).from_address(data_ptr) + + # Create NumPy array from buffer + arr = np.frombuffer(buffer, dtype=dtype).reshape(shape) + + return arr + + +class BlockRef: + """Reference to a specific block in a tiled matrix.""" + + def __init__(self, matrix, row, col): + self.matrix = matrix + self.row = row + self.col = col + self._handle = matrix.handle(row, col) + self._devid = matrix.get_preferred_devid(row, col) + + def handle(self): + """Get the STF logical data handle for this block.""" + return self._handle + + def devid(self): + """Get the preferred device ID for this block.""" + return self._devid + + def __repr__(self): + return f"BlockRef({self.matrix.symbol}[{self.row},{self.col}])" + + +class TiledMatrix: + """ + Tiled matrix class that splits a matrix into blocks for parallel processing. + Each block is managed as an STF logical data object. + """ + + def __init__( + self, + ctx, + nrows, + ncols, + block_rows, + block_cols, + is_symmetric=False, + symbol="matrix", + dtype=np.float64, + ): + """ + Initialize a tiled matrix. 
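+
+        Blocks are distributed across the available GPUs in a cyclic pattern
+        over a 2D device grid (see get_preferred_devid).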
+ + Args: + ctx: STF context + nrows: Total number of rows + ncols: Total number of columns + block_rows: Block size (rows) + block_cols: Block size (columns) + is_symmetric: If True, only stores lower triangular blocks + symbol: Name/symbol for the matrix + dtype: Data type (default: np.float64) + """ + self.ctx = ctx + self.symbol = symbol + self.dtype = dtype + + self.m = nrows + self.n = ncols + self.mb = block_rows + self.nb = block_cols + self.sym_matrix = is_symmetric + + assert self.m % self.mb == 0, ( + f"nrows ({self.m}) must be divisible by block_rows ({self.mb})" + ) + assert self.n % self.nb == 0, ( + f"ncols ({self.n}) must be divisible by block_cols ({self.nb})" + ) + + # Number of blocks + self.mt = self.m // self.mb + self.nt = self.n // self.nb + + # Allocate host memory (pinned for faster transfers) + self.h_array = cp.cuda.alloc_pinned_memory( + self.m * self.n * np.dtype(dtype).itemsize + ) + self.h_array_np = np.frombuffer(self.h_array, dtype=dtype).reshape( + self.m, self.n + ) + + # Create logical data handles for each block + self.handles = {} + + # Get available devices for mapping + self.ndevs = cp.cuda.runtime.getDeviceCount() + self.grid_p, self.grid_q = self._compute_device_grid(self.ndevs) + + print( + f"[{symbol}] {self.m}x{self.n} matrix, {self.mt}x{self.nt} blocks of {self.mb}x{self.nb}" + ) + print( + f"[{symbol}] Using {self.ndevs} devices in {self.grid_p}x{self.grid_q} grid" + ) + + # Note: We DON'T create logical data here yet - that happens in fill() + # after the host data is initialized + + def _compute_device_grid(self, ndevs): + """Compute 2D device grid dimensions (as close to square as possible)""" + grid_p = 1 + grid_q = ndevs + for a in range(1, int(np.sqrt(ndevs)) + 1): + if ndevs % a == 0: + grid_p = a + grid_q = ndevs // a + return grid_p, grid_q + + def get_preferred_devid(self, row, col): + """Get preferred device ID for a given block using cyclic distribution""" + return (row % self.grid_p) + (col % self.grid_q) * self.grid_p + + def handle(self, row, col): + """Get the logical data handle for block (row, col)""" + return self.handles[(row, col)] + + def block(self, row, col): + """Get a BlockRef for block (row, col)""" + return BlockRef(self, row, col) + + def _get_index(self, row, col): + """Convert (row, col) to linear index in tiled storage""" + # Find which tile contains this element + tile_row = row // self.mb + tile_col = col // self.nb + + tile_size = self.mb * self.nb + + # Index of the beginning of the tile + tile_start = (tile_row + self.mt * tile_col) * tile_size + + # Offset within the tile + offset = (row % self.mb) + (col % self.nb) * self.mb + + return tile_start + offset + + def _get_block_h(self, brow, bcol): + """Get a view of the host data for block (brow, bcol)""" + # For tiled storage, blocks are stored contiguously + start_idx = (brow + self.mt * bcol) * self.mb * self.nb + end_idx = start_idx + self.mb * self.nb + flat_view = self.h_array_np.ravel() + return flat_view[start_idx:end_idx].reshape(self.mb, self.nb) + + def fill(self, func): + """Fill matrix on host, then create STF logical data that will transfer automatically""" + print(f"[{self.symbol}] Filling matrix on host...") + + for colb in range(self.nt): + low_rowb = colb if self.sym_matrix else 0 + for rowb in range(low_rowb, self.mt): + # Fill host block + h_block = self._get_block_h(rowb, colb) + for lrow in range(self.mb): + for lcol in range(self.nb): + row = lrow + rowb * self.mb + col = lcol + colb * self.nb + h_block[lrow, lcol] = func(row, col) + + 
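+                # Register the filled host block as STF logical data; it is
+                # transferred to device memory automatically when tasks use it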
handle = self.ctx.logical_data(h_block) + handle.set_symbol(f"{self.symbol}_{rowb}_{colb}") + + self.handles[(rowb, colb)] = handle + + +# BLAS/LAPACK operations wrapped in STF tasks + + +def DPOTRF(ctx, a): + """Cholesky factorization of a diagonal block: A = L*L^T (lower triangular)""" + with ctx.task(stf.exec_place.device(a.devid()), a.handle().rw()) as t: + d_block = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + d_block[:] = cp.linalg.cholesky(d_block) + + +def DTRSM(ctx, a, b, side="L", uplo="L", transa="T", diag="N", alpha=1.0): + """Triangular solve: B = alpha * op(A)^{-1} @ B or B = alpha * B @ op(A)^{-1}""" + with ctx.task( + stf.exec_place.device(b.devid()), a.handle().read(), b.handle().rw() + ) as t: + d_a, d_b = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + if side == "L": + if transa == "N": + d_b[:] = cp_linalg.solve_triangular(d_a, d_b, lower=(uplo == "L")) + else: + d_b[:] = cp_linalg.solve_triangular(d_a.T, d_b, lower=(uplo != "L")) + if alpha != 1.0: + d_b *= alpha + else: + if transa == "N": + d_b[:] = cp_linalg.solve_triangular( + d_a.T, d_b.T, lower=(uplo != "L") + ).T + else: + d_b[:] = cp_linalg.solve_triangular( + d_a, d_b.T, lower=(uplo == "L") + ).T + if alpha != 1.0: + d_b *= alpha + + +def DGEMM(ctx, a, b, c, transa="N", transb="N", alpha=1.0, beta=1.0): + """Matrix multiplication: C = alpha * op(A) @ op(B) + beta * C""" + with ctx.task( + stf.exec_place.device(c.devid()), + a.handle().read(), + b.handle().read(), + c.handle().rw(), + ) as t: + d_a, d_b, d_c = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + op_a = d_a.T if transa != "N" else d_a + op_b = d_b.T if transb != "N" else d_b + + if beta == 0.0: + d_c[:] = alpha * (op_a @ op_b) + elif beta == 1.0: + d_c[:] += alpha * (op_a @ op_b) + else: + d_c[:] = alpha * (op_a @ op_b) + beta * d_c + + +def DSYRK(ctx, a, c, uplo="L", trans="N", alpha=1.0, beta=1.0): + """Symmetric rank-k update: C = alpha * op(A) @ op(A).T + beta * C""" + with ctx.task( + stf.exec_place.device(c.devid()), a.handle().read(), c.handle().rw() + ) as t: + d_a, d_c = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + op_a = d_a.T if trans != "N" else d_a + + if beta == 0.0: + d_c[:] = alpha * (op_a @ op_a.T) + elif beta == 1.0: + d_c[:] += alpha * (op_a @ op_a.T) + else: + d_c[:] = alpha * (op_a @ op_a.T) + beta * d_c + + +# High-level algorithms + + +def PDPOTRF(ctx, A): + """Parallel tiled Cholesky factorization (blocked algorithm)""" + print("\n[PDPOTRF] Starting Cholesky factorization...") + + assert A.m == A.n, "Matrix must be square" + assert A.mt == A.nt, "Block grid must be square" + assert A.sym_matrix, "Matrix must be symmetric" + + nblocks = A.mt + + for k in range(nblocks): + # Factor diagonal block + DPOTRF(ctx, A.block(k, k)) + + # Solve triangular systems for blocks in column k + for row in range(k + 1, nblocks): + DTRSM( + ctx, + A.block(k, k), + A.block(row, k), + side="R", + uplo="L", + transa="T", + diag="N", + alpha=1.0, + ) + + # Update trailing matrix + for col in range(k + 1, row): + DGEMM( + ctx, + A.block(row, k), + A.block(col, k), + A.block(row, col), + transa="N", + transb="T", + alpha=-1.0, + beta=1.0, + ) + + # Symmetric rank-k update of diagonal block + DSYRK( + ctx, + A.block(row, k), + A.block(row, row), + uplo="L", + trans="N", + alpha=-1.0, + beta=1.0, + ) + + print("[PDPOTRF] Completed") + + +def PDTRSM(ctx, A, B, side="L", uplo="L", trans="N", diag="N", alpha=1.0): + """Parallel tiled triangular solve""" + 
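+    # Note: only the side="L", uplo="L" cases required by PDPOTRS are implemented below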
print("\n[PDTRSM] Starting triangular solve...") + + if side == "L": + if uplo == "L": + if trans == "N": + # Forward substitution + for k in range(B.mt): + lalpha = alpha if k == 0 else 1.0 + for n in range(B.nt): + DTRSM( + ctx, + A.block(k, k), + B.block(k, n), + side="L", + uplo="L", + transa="N", + diag=diag, + alpha=lalpha, + ) + for m in range(k + 1, B.mt): + for n in range(B.nt): + DGEMM( + ctx, + A.block(m, k), + B.block(k, n), + B.block(m, n), + transa="N", + transb="N", + alpha=-1.0, + beta=lalpha, + ) + else: # trans == 'T' or 'C' + # Backward substitution + for k in range(B.mt): + lalpha = alpha if k == 0 else 1.0 + row_idx = B.mt - k - 1 + for n in range(B.nt): + DTRSM( + ctx, + A.block(row_idx, row_idx), + B.block(row_idx, n), + side="L", + uplo="L", + transa="T", + diag=diag, + alpha=lalpha, + ) + for m in range(k + 1, B.mt): + m_idx = B.mt - 1 - m + for n in range(B.nt): + DGEMM( + ctx, + A.block(row_idx, m_idx), + B.block(row_idx, n), + B.block(m_idx, n), + transa="T", + transb="N", + alpha=-1.0, + beta=lalpha, + ) + + print("[PDTRSM] Completed") + + +def PDPOTRS(ctx, A, B, uplo="L"): + """Solve A @ X = B where A is factored by Cholesky (A = L @ L.T)""" + print("\n[PDPOTRS] Solving linear system...") + + # First solve: L @ Y = B + PDTRSM( + ctx, + A, + B, + side="L", + uplo=uplo, + trans="N" if uplo == "L" else "T", + diag="N", + alpha=1.0, + ) + + # Second solve: L.T @ X = Y + PDTRSM( + ctx, + A, + B, + side="L", + uplo=uplo, + trans="T" if uplo == "L" else "N", + diag="N", + alpha=1.0, + ) + + print("[PDPOTRS] Completed") + + +def PDGEMM(ctx, A, B, C, transa="N", transb="N", alpha=1.0, beta=1.0): + """Parallel tiled matrix multiplication""" + print("\n[PDGEMM] Starting matrix multiplication...") + + for m in range(C.mt): + for n in range(C.nt): + inner_k = A.nt if transa == "N" else A.mt + + if alpha == 0.0 or inner_k == 0: + # Just scale C + DGEMM( + ctx, + A.block(0, 0), + B.block(0, 0), + C.block(m, n), + transa=transa, + transb=transb, + alpha=0.0, + beta=beta, + ) + elif transa == "N": + if transb == "N": + for k in range(A.nt): + zbeta = beta if k == 0 else 1.0 + DGEMM( + ctx, + A.block(m, k), + B.block(k, n), + C.block(m, n), + transa="N", + transb="N", + alpha=alpha, + beta=zbeta, + ) + else: + for k in range(A.nt): + zbeta = beta if k == 0 else 1.0 + DGEMM( + ctx, + A.block(m, k), + B.block(n, k), + C.block(m, n), + transa="N", + transb="T", + alpha=alpha, + beta=zbeta, + ) + else: + if transb == "N": + for k in range(A.mt): + zbeta = beta if k == 0 else 1.0 + DGEMM( + ctx, + A.block(k, m), + B.block(k, n), + C.block(m, n), + transa="T", + transb="N", + alpha=alpha, + beta=zbeta, + ) + else: + for k in range(A.mt): + zbeta = beta if k == 0 else 1.0 + DGEMM( + ctx, + A.block(k, m), + B.block(n, k), + C.block(m, n), + transa="T", + transb="T", + alpha=alpha, + beta=zbeta, + ) + + print("[PDGEMM] Completed") + + +def compute_norm(ctx, matrix): + """Compute Frobenius norm of matrix using host tasks""" + norm_sq = 0.0 + + for colb in range(matrix.nt): + low_rowb = colb if matrix.sym_matrix else 0 + for rowb in range(low_rowb, matrix.mt): + handle = matrix.handle(rowb, colb) + + # Host task to read the block and compute norm + def compute_block_norm(h_block): + nonlocal norm_sq + norm_sq += np.sum(h_block * h_block) + + with ctx.task(stf.exec_place.host(), handle.read()) as t: + # Synchronize the stream before reading data + cp.cuda.runtime.streamSynchronize(t.stream_ptr()) + + h_block = cai_to_numpy(t.get_arg_cai(0)) + compute_block_norm(h_block) + + return 
np.sqrt(norm_sq) + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Tiled Cholesky decomposition with CUDA STF" + ) + parser.add_argument( + "N", type=int, nargs="?", default=1024, help="Matrix size (default: 1024)" + ) + parser.add_argument( + "NB", type=int, nargs="?", default=128, help="Block size (default: 128)" + ) + parser.add_argument("--check", action="store_true", help="Check result (slower)") + args = parser.parse_args() + + N = args.N + NB = args.NB + check_result = args.check + + assert N % NB == 0, f"Matrix size {N} must be divisible by block size {NB}" + + print("=" * 60) + print("Tiled Cholesky Decomposition with CUDA STF + CuPy") + print("=" * 60) + print(f"Matrix size: {N}x{N}") + print(f"Block size: {NB}x{NB}") + print(f"Number of blocks: {N // NB}x{N // NB}") + print(f"Check result: {check_result}") + print("=" * 60) + + # Create STF context + ctx = stf.context() + + # Create matrices + A = TiledMatrix(ctx, N, N, NB, NB, is_symmetric=True, symbol="A") + + if check_result: + Aref = TiledMatrix(ctx, N, N, NB, NB, is_symmetric=False, symbol="Aref") + + # Fill with Hilbert matrix + diagonal dominance + # H_{i,j} = 1/(i+j+1) + 2*N if i==j + def hilbert(row, col): + return 1.0 / (row + col + 1.0) + (2.0 * N if row == col else 0.0) + + print("\n" + "=" * 60) + print("Initializing matrices...") + print("=" * 60) + + A.fill(hilbert) + if check_result: + Aref.fill(hilbert) + + # Create right-hand side + if check_result: + B = TiledMatrix(ctx, N, 1, NB, 1, is_symmetric=False, symbol="B") + Bref = TiledMatrix(ctx, N, 1, NB, 1, is_symmetric=False, symbol="Bref") + + def rhs_vals(row, col): + return 1.0 * (row + 1) + + B.fill(rhs_vals) + Bref.fill(rhs_vals) + + # Compute ||B|| for residual calculation + Bref_norm = compute_norm(ctx, Bref) + + # Synchronize before timing + cp.cuda.runtime.deviceSynchronize() + + # Record start time + start_event = cp.cuda.Event() + stop_event = cp.cuda.Event() + start_event.record() + + # Perform Cholesky factorization + print("\n" + "=" * 60) + print("Performing Cholesky factorization...") + print("=" * 60) + PDPOTRF(ctx, A) + + # Record stop time + stop_event.record() + + # Solve system if checking + if check_result: + print("\n" + "=" * 60) + print("Solving linear system...") + print("=" * 60) + PDPOTRS(ctx, A, B, uplo="L") + + print("\n" + "=" * 60) + print("Computing residual...") + print("=" * 60) + # Compute residual: Bref = Aref @ B - Bref + PDGEMM(ctx, Aref, B, Bref, transa="N", transb="N", alpha=1.0, beta=-1.0) + + # Compute ||residual|| + res_norm = compute_norm(ctx, Bref) + + # Finalize STF context + print("\n" + "=" * 60) + print("Finalizing STF context...") + print("=" * 60) + ctx.finalize() + + # Wait for completion + stop_event.synchronize() + + # Compute timing + elapsed_ms = cp.cuda.get_elapsed_time(start_event, stop_event) + gflops = (1.0 / 3.0 * N * N * N) / 1e9 + gflops_per_sec = gflops / (elapsed_ms / 1000.0) + + print("\n" + "=" * 60) + print("Results") + print("=" * 60) + print(f"[PDPOTRF] Elapsed time: {elapsed_ms:.2f} ms") + print(f"[PDPOTRF] Performance: {gflops_per_sec:.2f} GFLOPS") + + if check_result: + residual = res_norm / Bref_norm + print(f"\n[POTRS] ||AX - B||: {res_norm:.6e}") + print(f"[POTRS] ||B||: {Bref_norm:.6e}") + print(f"[POTRS] Residual (||AX - B||/||B||): {residual:.6e}") + + if residual >= 0.01: + print("\n❌ Algorithm did not converge (residual >= 0.01)") + return 1 + else: + print("\n✅ Algorithm converged successfully!") + + print("=" * 60) + return 0 + + +if __name__ 
== "__main__": + sys.exit(main()) diff --git a/python/cuda_cccl/tests/stf/example_fluid_warp.py b/python/cuda_cccl/tests/stf/example_fluid_warp.py new file mode 100644 index 00000000000..ab3fd406864 --- /dev/null +++ b/python/cuda_cccl/tests/stf/example_fluid_warp.py @@ -0,0 +1,399 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +########################################################################### +# Example Fluid +# +# Shows how to implement a simple 2D Stable Fluids solver using +# multidimensional arrays and launches. +# +########################################################################### + +import math + +import warp as wp +import warp.render + +import cuda.stf as cudastf + + +# Add a stf-specific decorator to the wp. namespace +def stf_kernel(pyfunc): + # let warp decorate normally + kernel = wp.kernel(pyfunc) + + # attach an STF-aware call operator + def _stf_call(*args, dim=None, stream=None, **kwargs): + return wp.stf.launch(kernel, dim=dim, inputs=args, stream=stream, **kwargs) + + # monkey-patch a method onto the kernel object + kernel.stf = _stf_call + + return kernel + + +def stf_launch(kernel, dim, inputs=None, stream=None, **kwargs): + # just forward to warp for now + return wp.launch( + kernel, + dim=dim, + inputs=inputs, + stream=stream, + **kwargs, + ) + + +# put it under wp.stf +if not hasattr(wp, "stf"): + + class _stf: + pass + + wp.stf = _stf() + + +wp.stf.kernel = stf_kernel +wp.stf.launch = stf_launch + +grid_width = wp.constant(256) +grid_height = wp.constant(128) + + +@wp.func +def lookup_float(f: wp.array2d(dtype=float), x: int, y: int): + x = wp.clamp(x, 0, grid_width - 1) + y = wp.clamp(y, 0, grid_height - 1) + + return f[x, y] + + +@wp.func +def sample_float(f: wp.array2d(dtype=float), x: float, y: float): + lx = int(wp.floor(x)) + ly = int(wp.floor(y)) + + tx = x - float(lx) + ty = y - float(ly) + + s0 = wp.lerp(lookup_float(f, lx, ly), lookup_float(f, lx + 1, ly), tx) + s1 = wp.lerp(lookup_float(f, lx, ly + 1), lookup_float(f, lx + 1, ly + 1), tx) + + s = wp.lerp(s0, s1, ty) + return s + + +@wp.func +def lookup_vel(f: wp.array2d(dtype=wp.vec2), x: int, y: int): + if x < 0 or x >= grid_width: + return wp.vec2() + if y < 0 or y >= grid_height: + return wp.vec2() + + return f[x, y] + + +@wp.func +def sample_vel(f: wp.array2d(dtype=wp.vec2), x: float, y: float): + lx = int(wp.floor(x)) + ly = int(wp.floor(y)) + + tx = x - float(lx) + ty = y - float(ly) + + s0 = wp.lerp(lookup_vel(f, lx, ly), lookup_vel(f, lx + 1, ly), tx) + s1 = wp.lerp(lookup_vel(f, lx, ly + 1), lookup_vel(f, lx + 1, ly + 1), tx) + + s = wp.lerp(s0, s1, ty) + return s + + +@wp.stf.kernel +def advect( + u0: wp.array2d(dtype=wp.vec2), + u1: wp.array2d(dtype=wp.vec2), + rho0: wp.array2d(dtype=float), + rho1: wp.array2d(dtype=float), + dt: float, +): + i, j = wp.tid() + + u = u0[i, j] + + # trace backward + p = wp.vec2(float(i), float(j)) + p = p 
- u * dt + + # advect + u1[i, j] = sample_vel(u0, p[0], p[1]) + rho1[i, j] = sample_float(rho0, p[0], p[1]) + + +@wp.stf.kernel +def divergence(u: wp.array2d(dtype=wp.vec2), div: wp.array2d(dtype=float)): + i, j = wp.tid() + + if i == grid_width - 1: + return + if j == grid_height - 1: + return + + dx = (u[i + 1, j][0] - u[i, j][0]) * 0.5 + dy = (u[i, j + 1][1] - u[i, j][1]) * 0.5 + + div[i, j] = dx + dy + + +@wp.stf.kernel +def pressure_solve( + p0: wp.array2d(dtype=float), + p1: wp.array2d(dtype=float), + div: wp.array2d(dtype=float), +): + i, j = wp.tid() + + s1 = lookup_float(p0, i - 1, j) + s2 = lookup_float(p0, i + 1, j) + s3 = lookup_float(p0, i, j - 1) + s4 = lookup_float(p0, i, j + 1) + + # Jacobi update + err = s1 + s2 + s3 + s4 - div[i, j] + + p1[i, j] = err * 0.25 + + +@wp.stf.kernel +def pressure_apply(p: wp.array2d(dtype=float), u: wp.array2d(dtype=wp.vec2)): + i, j = wp.tid() + + if i == 0 or i == grid_width - 1: + return + if j == 0 or j == grid_height - 1: + return + + # pressure gradient + f_p = wp.vec2(p[i + 1, j] - p[i - 1, j], p[i, j + 1] - p[i, j - 1]) * 0.5 + + u[i, j] = u[i, j] - f_p + + +@wp.stf.kernel +def integrate(u: wp.array2d(dtype=wp.vec2), rho: wp.array2d(dtype=float), dt: float): + i, j = wp.tid() + + # gravity + f_g = wp.vec2(-90.8, 0.0) * rho[i, j] + + # integrate + u[i, j] = u[i, j] + dt * f_g + + # fade + rho[i, j] = rho[i, j] * (1.0 - 0.1 * dt) + + +@wp.stf.kernel +def init( + rho: wp.array2d(dtype=float), + u: wp.array2d(dtype=wp.vec2), + radius: int, + dir: wp.vec2, +): + i, j = wp.tid() + + d = wp.length(wp.vec2(float(i - grid_width / 2), float(j - grid_height / 2))) + + if d < radius: + rho[i, j] = 1.0 + u[i, j] = dir + + +class Example: + def __init__(self): + fps = 60 + self.frame_dt = 1.0 / fps + self.sim_substeps = 2 + self.iterations = 100 # Number of pressure iterations + self.sim_dt = self.frame_dt / self.sim_substeps + self.sim_time = 0.0 + + self._stf_ctx = cudastf.context() + + shape = (grid_width, grid_height) + + self.u0 = wp.zeros(shape, dtype=wp.vec2) + self.u1 = wp.zeros(shape, dtype=wp.vec2) + + self.rho0 = wp.zeros(shape, dtype=float) + self.rho1 = wp.zeros(shape, dtype=float) + + self.p0 = wp.zeros(shape, dtype=float) + self.p1 = wp.zeros(shape, dtype=float) + self.div = wp.zeros(shape, dtype=float) + + # Create STF logical data from Warp arrays with explicit data place + # Warp arrays are on GPU device memory, so specify data_place.device() + + # For regular float arrays, specify device data place + device_place = cudastf.data_place.device(0) + + self.rho0._stf_ld = self._stf_ctx.logical_data(self.rho0, device_place) + self.rho1._stf_ld = self._stf_ctx.logical_data(self.rho1, device_place) + self.p0._stf_ld = self._stf_ctx.logical_data(self.p0, device_place) + self.p1._stf_ld = self._stf_ctx.logical_data(self.p1, device_place) + self.div._stf_ld = self._stf_ctx.logical_data(self.div, device_place) + + # vec2 arrays - STF now automatically handles vector type flattening + # Store STF logical data consistently with other arrays + self.u0._stf_ld = self._stf_ctx.logical_data(self.u0, device_place) + self.u1._stf_ld = self._stf_ctx.logical_data(self.u1, device_place) + + self.rho0._stf_ld.set_symbol("density_current") + self.rho1._stf_ld.set_symbol("density_next") + self.p0._stf_ld.set_symbol("pressure_current") + self.p1._stf_ld.set_symbol("pressure_next") + self.div._stf_ld.set_symbol("velocity_divergence") + self.u0._stf_ld.set_symbol("velocity_current") + self.u1._stf_ld.set_symbol("velocity_next") + + # Set Warp array names (for 
Warp tracing) + self.u0._name = "u0" + self.u1._name = "u1" + self.rho0._name = "rho0" + self.rho1._name = "rho1" + self.p0._name = "p0" + self.p1._name = "p1" + self.div._name = "div" + + # capture pressure solve as a CUDA graph + self.use_cuda_graph = wp.get_device().is_cuda + if self.use_cuda_graph: + with wp.ScopedCapture() as capture: + self.pressure_iterations() + self.graph = capture.graph + + def step(self): + with wp.ScopedTimer("step"): + for _ in range(self.sim_substeps): + shape = (grid_width, grid_height) + dt = self.sim_dt + + speed = 400.0 + angle = math.sin(self.sim_time * 4.0) * 1.5 + vel = wp.vec2(math.cos(angle) * speed, math.sin(angle) * speed) + + # update emitters + wp.stf.launch(init, dim=shape, inputs=[self.rho0, self.u0, 5, vel]) + + # force integrate + wp.stf.launch(integrate, dim=shape, inputs=[self.u0, self.rho0, dt]) + wp.stf.launch(divergence, dim=shape, inputs=[self.u0, self.div]) + + # pressure solve + self.p0.zero_() + self.p1.zero_() + + # TODO experiment with explicit capture at Warp level + # if self.use_cuda_graph: + # wp.capture_launch(self.graph) + # else: + # self.pressure_iterations() + self.pressure_iterations() + + # velocity update + wp.stf.launch(pressure_apply, dim=shape, inputs=[self.p0, self.u0]) + + # semi-Lagrangian advection + wp.stf.launch( + advect, + dim=shape, + inputs=[self.u0, self.u1, self.rho0, self.rho1, dt], + ) + + # swap buffers + (self.u0, self.u1) = (self.u1, self.u0) + (self.rho0, self.rho1) = (self.rho1, self.rho0) + + self.sim_time += dt + + def pressure_iterations(self): + for _ in range(self.iterations): + wp.stf.launch( + pressure_solve, dim=self.p0.shape, inputs=[self.p0, self.p1, self.div] + ) + + # swap pressure fields + (self.p0, self.p1) = (self.p1, self.p0) + + def step_and_render_frame(self, frame_num=None, img=None): + self.step() + + with wp.ScopedTimer("render"): + if img: + img.set_array(self.rho0.numpy()) + + return (img,) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--device", type=str, default=None, help="Override the default Warp device." + ) + parser.add_argument( + "--num_frames", type=int, default=100000, help="Total number of frames." + ) + parser.add_argument( + "--headless", + action="store_true", + help="Run in headless mode, suppressing the opening of any graphical windows.", + ) + + args = parser.parse_known_args()[0] + + with wp.ScopedDevice(args.device): + example = Example() + + if args.headless: + for _ in range(args.num_frames): + example.step() + else: + import matplotlib + import matplotlib.animation as anim + import matplotlib.pyplot as plt + + fig = plt.figure() + + img = plt.imshow( + example.rho0.numpy(), + origin="lower", + animated=True, + interpolation="antialiased", + ) + img.set_norm(matplotlib.colors.Normalize(0.0, 1.0)) + seq = anim.FuncAnimation( + fig, + example.step_and_render_frame, + fargs=(img,), + frames=args.num_frames, + blit=True, + interval=8, + repeat=False, + ) + + plt.show() diff --git a/python/cuda_cccl/tests/stf/example_potri.py b/python/cuda_cccl/tests/stf/example_potri.py new file mode 100644 index 00000000000..1e3c721a9c1 --- /dev/null +++ b/python/cuda_cccl/tests/stf/example_potri.py @@ -0,0 +1,832 @@ +#!/usr/bin/env python3 +""" +Python implementation of POTRI (matrix inversion via Cholesky) using CUDA STF and CuPy. + +POTRI computes the inverse of a symmetric positive definite matrix using its Cholesky factorization: +1. 
Cholesky factorization: A = L*L^T +2. Triangular inversion: L^(-1) +3. Compute A^(-1) = L^(-T) * L^(-1) + +This example demonstrates: +- Tiled matrix operations with STF logical data +- Integration of CuPy's CUBLAS and CUSOLVER functions with STF tasks +- Multi-device execution with automatic data placement +- Task-based parallelism for linear algebra operations +""" + +import sys + +import cupy as cp +import numpy as np +from cupyx.scipy import linalg as cp_linalg + +import cuda.stf as stf + + +class CAIWrapper: + """Wrapper to expose CUDA Array Interface dict as a proper CAI object.""" + + def __init__(self, cai_dict): + self.__cuda_array_interface__ = cai_dict + + +def get_cupy_arrays(task): + """ + Get all CuPy arrays from STF task arguments. + + Usage: + d_a, d_b, d_c = get_cupy_arrays(t) + """ + arrays = [] + idx = 0 + while True: + try: + arrays.append(cp.asarray(CAIWrapper(task.get_arg_cai(idx)))) + idx += 1 + except Exception: + break + return tuple(arrays) if len(arrays) > 1 else arrays[0] if arrays else None + + +def cai_to_numpy(cai_dict): + """Convert CUDA Array Interface dict to NumPy array (for host memory).""" + import ctypes + + # Extract CAI fields + data_ptr, readonly = cai_dict["data"] + shape = cai_dict["shape"] + typestr = cai_dict["typestr"] + + # Convert typestr to NumPy dtype + dtype = np.dtype(typestr) + + # Calculate total size in bytes + itemsize = dtype.itemsize + size = np.prod(shape) * itemsize + + # Create ctypes buffer from pointer + buffer = (ctypes.c_byte * size).from_address(data_ptr) + + # Create NumPy array from buffer + arr = np.frombuffer(buffer, dtype=dtype).reshape(shape) + + return arr + + +class BlockRef: + """Reference to a specific block in a tiled matrix.""" + + def __init__(self, matrix, row, col): + self.matrix = matrix + self.row = row + self.col = col + self._handle = matrix.handle(row, col) + self._devid = matrix.get_preferred_devid(row, col) + + def handle(self): + """Get the STF logical data handle for this block.""" + return self._handle + + def devid(self): + """Get the preferred device ID for this block.""" + return self._devid + + def __repr__(self): + return f"BlockRef({self.matrix.symbol}[{self.row},{self.col}])" + + +class TiledMatrix: + """ + Tiled matrix class that splits a matrix into blocks for parallel processing. + Each block is managed as an STF logical data object. + Uses tiled storage format for contiguous blocks. 
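+ Blocks are stored contiguously: block (row, col) starts at flat offset
+ (row + mt * col) * mb * nb in the host buffer (see _get_block_h()).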
+ """ + + def __init__( + self, + ctx, + nrows, + ncols, + blocksize_rows, + blocksize_cols, + is_symmetric=False, + symbol="matrix", + dtype=np.float64, + ): + self.ctx = ctx + self.symbol = symbol + self.dtype = dtype + self.sym_matrix = is_symmetric + + self.m = nrows + self.n = ncols + self.mb = blocksize_rows + self.nb = blocksize_cols + + assert self.m % self.mb == 0, ( + f"nrows {nrows} must be divisible by blocksize_rows {blocksize_rows}" + ) + assert self.n % self.nb == 0, ( + f"ncols {ncols} must be divisible by blocksize_cols {blocksize_cols}" + ) + + # Number of blocks + self.mt = self.m // self.mb + self.nt = self.n // self.nb + + # Allocate pinned host memory for faster transfers (in tiled format) + self.h_array = cp.cuda.alloc_pinned_memory( + self.m * self.n * np.dtype(dtype).itemsize + ) + self.h_array_np = np.frombuffer(self.h_array, dtype=dtype).reshape( + self.m, self.n + ) + + # Dictionary to store logical data handles for each block + self.handles = {} + + # Determine device layout + self.ndevs = cp.cuda.runtime.getDeviceCount() + self.grid_p, self.grid_q = self._compute_device_grid(self.ndevs) + + print( + f"[{self.symbol}] {self.m}x{self.n} matrix, {self.mt}x{self.nt} blocks of {self.mb}x{self.nb}" + ) + print( + f"[{self.symbol}] Using {self.ndevs} devices in {self.grid_p}x{self.grid_q} grid" + ) + + def _compute_device_grid(self, ndevs): + """Compute 2D device grid dimensions (as close to square as possible)""" + grid_p = 1 + grid_q = ndevs + for a in range(1, int(np.sqrt(ndevs)) + 1): + if ndevs % a == 0: + grid_p = a + grid_q = ndevs // a + return grid_p, grid_q + + def get_preferred_devid(self, row, col): + """Get preferred device ID for a given block using cyclic distribution""" + return (row % self.grid_p) + (col % self.grid_q) * self.grid_p + + def handle(self, row, col): + """Get the logical data handle for a block.""" + return self.handles[(row, col)] + + def block(self, row, col): + """Get a BlockRef for block (row, col)""" + return BlockRef(self, row, col) + + def _get_index(self, row, col): + """Convert (row, col) to linear index in tiled storage""" + tile_row = row // self.mb + tile_col = col // self.nb + tile_size = self.mb * self.nb + tile_start = (tile_row + self.mt * tile_col) * tile_size + offset = (row % self.mb) + (col % self.nb) * self.mb + return tile_start + offset + + def _get_block_h(self, brow, bcol): + """Get a view of the host data for block (brow, bcol)""" + # For tiled storage, blocks are stored contiguously + start_idx = (brow + self.mt * bcol) * self.mb * self.nb + end_idx = start_idx + self.mb * self.nb + flat_view = self.h_array_np.ravel() + return flat_view[start_idx:end_idx].reshape(self.mb, self.nb) + + def fill(self, func): + """ + Fill the matrix blocks using a function func(row, col) -> value. + Creates STF logical data from host arrays and lets STF handle transfers. 
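+ For symmetric matrices (is_symmetric=True), only blocks on or below
+ the diagonal are filled and registered.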
+ """ + print(f"[{self.symbol}] Filling matrix on host...") + for colb in range(self.nt): + low_rowb = colb if self.sym_matrix else 0 + for rowb in range(low_rowb, self.mt): + # Fill host block + h_block = self._get_block_h(rowb, colb) + for lrow in range(self.mb): + for lcol in range(self.nb): + row = lrow + rowb * self.mb + col = lcol + colb * self.nb + h_block[lrow, lcol] = func(row, col) + + handle = self.ctx.logical_data(h_block) + handle.set_symbol(f"{self.symbol}_{rowb}_{colb}") + + self.handles[(rowb, colb)] = handle + + +# ============================================================================ +# Block-level operations (BLAS/LAPACK) +# ============================================================================ + + +def DPOTRF(ctx, a): + """Cholesky factorization of a diagonal block: A = L*L^T (lower triangular)""" + with ctx.task(stf.exec_place.device(a.devid()), a.handle().rw()) as t: + d_block = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + d_block[:] = cp.linalg.cholesky(d_block) + + +def DTRSM(ctx, a, b, side="L", uplo="L", transa="N", diag="N", alpha=1.0): + """Triangular solve: B = alpha * op(A)^(-1) * B""" + with ctx.task( + stf.exec_place.device(b.devid()), a.handle().read(), b.handle().rw() + ) as t: + d_a, d_b = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + lower = uplo == "L" + trans = transa != "N" + result = cp_linalg.solve_triangular(d_a, d_b, lower=lower, trans=trans) + if alpha != 1.0: + d_b[:] = alpha * result + else: + d_b[:] = result + + +def DTRTRI(ctx, a, uplo="L", diag="N"): + """Triangular matrix inversion: A = A^(-1)""" + with ctx.task(stf.exec_place.device(a.devid()), a.handle().rw()) as t: + d_block = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + lower = uplo == "L" + unit_diagonal = diag == "U" + # CuPy doesn't have trtri directly, use solve with identity + n = d_block.shape[0] + identity = cp.eye(n, dtype=d_block.dtype) + d_block[:] = cp_linalg.solve_triangular( + d_block, identity, lower=lower, unit_diagonal=unit_diagonal + ) + + +def DGEMM(ctx, a, b, c, transa="N", transb="N", alpha=1.0, beta=1.0): + """General matrix multiplication: C = alpha * op(A) * op(B) + beta * C""" + with ctx.task( + stf.exec_place.device(c.devid()), + a.handle().read(), + b.handle().read(), + c.handle().rw(), + ) as t: + d_a, d_b, d_c = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + op_a = d_a.T if transa != "N" else d_a + op_b = d_b.T if transb != "N" else d_b + + if beta == 0.0: + d_c[:] = alpha * (op_a @ op_b) + elif beta == 1.0: + d_c[:] += alpha * (op_a @ op_b) + else: + d_c[:] = alpha * (op_a @ op_b) + beta * d_c + + +def DSYRK(ctx, a, c, uplo="L", trans="N", alpha=1.0, beta=1.0): + """Symmetric rank-k update: C = alpha * op(A) @ op(A).T + beta * C""" + with ctx.task( + stf.exec_place.device(c.devid()), a.handle().read(), c.handle().rw() + ) as t: + d_a, d_c = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + op_a = d_a.T if trans != "N" else d_a + + if beta == 0.0: + d_c[:] = alpha * (op_a @ op_a.T) + elif beta == 1.0: + d_c[:] += alpha * (op_a @ op_a.T) + else: + d_c[:] = alpha * (op_a @ op_a.T) + beta * d_c + + +def DTRMM(ctx, a, b, side="L", uplo="L", transa="N", diag="N", alpha=1.0): + """Triangular matrix multiplication: B = alpha * op(A) * B (side='L') or B = alpha * B * op(A) (side='R')""" + with ctx.task( + stf.exec_place.device(b.devid()), a.handle().read(), b.handle().rw() + ) as t: + d_a, d_b = get_cupy_arrays(t) + with 
cp.cuda.ExternalStream(t.stream_ptr()): + lower = uplo == "L" + trans = transa != "N" + + # Extract triangle from A + if lower: + tri_a = cp.tril(d_a) + else: + tri_a = cp.triu(d_a) + + if trans: + tri_a = tri_a.T + + if side == "L": + d_b[:] = alpha * (tri_a @ d_b) + else: # side == 'R' + d_b[:] = alpha * (d_b @ tri_a) + + +def DSYMM(ctx, a, b, c, side="L", uplo="L", alpha=1.0, beta=1.0): + """Symmetric matrix multiplication: C = alpha * A * B + beta * C (side='L') or C = alpha * B * A + beta * C (side='R') + where A is symmetric.""" + with ctx.task( + stf.exec_place.device(c.devid()), + a.handle().read(), + b.handle().read(), + c.handle().rw(), + ) as t: + d_a, d_b, d_c = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + # Reconstruct full symmetric matrix from lower/upper triangle + if uplo == "L": + # Lower triangle is stored + sym_a = cp.tril(d_a) + cp.tril(d_a, -1).T + else: + # Upper triangle is stored + sym_a = cp.triu(d_a) + cp.triu(d_a, 1).T + + if side == "L": + result = alpha * (sym_a @ d_b) + else: # side == 'R' + result = alpha * (d_b @ sym_a) + + if beta == 0.0: + d_c[:] = result + elif beta == 1.0: + d_c[:] += result + else: + d_c[:] = result + beta * d_c + + +# ============================================================================ +# Tiled operations +# ============================================================================ + + +def PDPOTRF(ctx, A, uplo="L"): + """Parallel tiled Cholesky factorization""" + print("\n[PDPOTRF] Starting Cholesky factorization...") + assert uplo == "L", "Only lower triangular factorization supported" + + for k in range(A.nt): + # Factorize diagonal block + DPOTRF(ctx, A.block(k, k)) + + # Update column below diagonal + for m in range(k + 1, A.mt): + DTRSM( + ctx, + A.block(k, k), + A.block(m, k), + side="R", + uplo="L", + transa="T", + diag="N", + alpha=1.0, + ) + + # Update trailing submatrix + for n in range(k + 1, A.nt): + DSYRK( + ctx, + A.block(n, k), + A.block(n, n), + uplo="L", + trans="N", + alpha=-1.0, + beta=1.0, + ) + + for m in range(n + 1, A.mt): + DGEMM( + ctx, + A.block(m, k), + A.block(n, k), + A.block(m, n), + transa="N", + transb="T", + alpha=-1.0, + beta=1.0, + ) + + print("[PDPOTRF] Completed") + + +def PDTRTRI(ctx, A, uplo="L", diag="N"): + """Parallel tiled triangular matrix inversion""" + print("\n[PDTRTRI] Starting triangular inversion...") + assert uplo == "L", "Only lower triangular inversion supported" + + for k in range(A.nt): + # Step 1: Update A[m,k] for m > k + for m in range(k + 1, A.mt): + DTRSM( + ctx, + A.block(k, k), + A.block(m, k), + side="R", + uplo="L", + transa="N", + diag=diag, + alpha=-1.0, + ) + + # Step 2: Update A[m,n] for m > k, n < k + for m in range(k + 1, A.mt): + for n in range(k): + DGEMM( + ctx, + A.block(m, k), + A.block(k, n), + A.block(m, n), + transa="N", + transb="N", + alpha=1.0, + beta=1.0, + ) + + # Step 3: Update A[k,n] for n < k + for n in range(k): + DTRSM( + ctx, + A.block(k, k), + A.block(k, n), + side="L", + uplo="L", + transa="N", + diag=diag, + alpha=1.0, + ) + + # Step 4: Invert diagonal block A[k,k] + DTRTRI(ctx, A.block(k, k), uplo=uplo, diag=diag) + + print("[PDTRTRI] Completed") + + +def DLAAUM(ctx, a, uplo="L"): + """Compute A^T * A for a triangular block (lauum operation)""" + with ctx.task(stf.exec_place.device(a.devid()), a.handle().rw()) as t: + d_block = get_cupy_arrays(t) + with cp.cuda.ExternalStream(t.stream_ptr()): + # lauum: compute L * L^T for lower triangular L + if uplo == "L": + L = cp.tril(d_block) + d_block[:] = L @ L.T + 
else: + U = cp.triu(d_block) + d_block[:] = U.T @ U + + +def PDLAUUM(ctx, A, uplo="L"): + """Parallel tiled computation of A^T * A for lower triangular A""" + print("\n[PDLAUUM] Starting LAUUM (A^T * A)...") + assert uplo == "L", "Only lower triangular LAUUM supported" + + for k in range(A.mt): + # Step 1: Update off-diagonal blocks + for n in range(k): + # Update A[n,n] with A[k,n]^T * A[k,n] + DSYRK( + ctx, + A.block(k, n), + A.block(n, n), + uplo="L", + trans="T", + alpha=1.0, + beta=1.0, + ) + + # Update A[m,n] with A[k,m]^T * A[k,n] + for m in range(n + 1, k): + DGEMM( + ctx, + A.block(k, m), + A.block(k, n), + A.block(m, n), + transa="T", + transb="N", + alpha=1.0, + beta=1.0, + ) + + # Step 2: Update A[k,n] = A[k,k]^T * A[k,n] + for n in range(k): + DTRMM( + ctx, + A.block(k, k), + A.block(k, n), + side="L", + uplo="L", + transa="T", + diag="N", + alpha=1.0, + ) + + # Step 3: Update diagonal block A[k,k] = A[k,k]^T * A[k,k] + DLAAUM(ctx, A.block(k, k), uplo=uplo) + + print("[PDLAUUM] Completed") + + +def PDGEMM(ctx, A, B, C, transa="N", transb="N", alpha=1.0, beta=1.0): + """Parallel tiled matrix multiplication""" + print("\n[PDGEMM] Starting matrix multiplication...") + + for m in range(C.mt): + for n in range(C.nt): + inner_k = A.nt if transa == "N" else A.mt + + if alpha == 0.0 or inner_k == 0: + # Just scale C + DGEMM( + ctx, + A.block(0, 0), + B.block(0, 0), + C.block(m, n), + transa=transa, + transb=transb, + alpha=0.0, + beta=beta, + ) + elif transa == "N": + if transb == "N": + for k in range(A.nt): + zbeta = beta if k == 0 else 1.0 + DGEMM( + ctx, + A.block(m, k), + B.block(k, n), + C.block(m, n), + transa="N", + transb="N", + alpha=alpha, + beta=zbeta, + ) + else: + for k in range(A.nt): + zbeta = beta if k == 0 else 1.0 + DGEMM( + ctx, + A.block(m, k), + B.block(n, k), + C.block(m, n), + transa="N", + transb="T", + alpha=alpha, + beta=zbeta, + ) + else: # transa in ['T', 'C'] + if transb == "N": + for k in range(A.mt): + zbeta = beta if k == 0 else 1.0 + DGEMM( + ctx, + A.block(k, m), + B.block(k, n), + C.block(m, n), + transa="T", + transb="N", + alpha=alpha, + beta=zbeta, + ) + else: + for k in range(A.mt): + zbeta = beta if k == 0 else 1.0 + DGEMM( + ctx, + A.block(k, m), + B.block(n, k), + C.block(m, n), + transa="T", + transb="T", + alpha=alpha, + beta=zbeta, + ) + + print("[PDGEMM] Completed") + + +def PDSYMM(ctx, A, B, C, side="L", uplo="L", alpha=1.0, beta=1.0): + """Parallel tiled symmetric matrix multiplication""" + print("\n[PDSYMM] Starting symmetric matrix multiplication...") + + for m in range(C.mt): + for n in range(C.nt): + if side == "L": + if uplo == "L": + for k in range(C.mt): + zbeta = beta if k == 0 else 1.0 + if k < m: + DGEMM( + ctx, + A.block(m, k), + B.block(k, n), + C.block(m, n), + transa="N", + transb="N", + alpha=alpha, + beta=zbeta, + ) + else: + if k == m: + DSYMM( + ctx, + A.block(k, k), + B.block(k, n), + C.block(m, n), + side=side, + uplo=uplo, + alpha=alpha, + beta=zbeta, + ) + else: + DGEMM( + ctx, + A.block(k, m), + B.block(k, n), + C.block(m, n), + transa="T", + transb="N", + alpha=alpha, + beta=zbeta, + ) + else: # side == 'R' + # Similar logic for right multiplication + pass + + print("[PDSYMM] Completed") + + +def compute_norm(ctx, matrix): + """Compute Frobenius norm of matrix using host tasks""" + norm_sq = 0.0 + + for colb in range(matrix.nt): + low_rowb = colb if matrix.sym_matrix else 0 + for rowb in range(low_rowb, matrix.mt): + handle = matrix.handle(rowb, colb) + + # Host task to read the block and compute norm + def 
compute_block_norm(h_block): + nonlocal norm_sq + norm_sq += np.sum(h_block * h_block) + + with ctx.task(stf.exec_place.host(), handle.read()) as t: + # Synchronize the stream before reading data + cp.cuda.runtime.streamSynchronize(t.stream_ptr()) + + h_block = cai_to_numpy(t.get_arg_cai(0)) + compute_block_norm(h_block) + + return np.sqrt(norm_sq) + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Tiled POTRI (matrix inversion via Cholesky) with CUDA STF" + ) + parser.add_argument( + "N", type=int, nargs="?", default=512, help="Matrix size (default: 512)" + ) + parser.add_argument( + "NB", type=int, nargs="?", default=128, help="Block size (default: 128)" + ) + parser.add_argument("--check", action="store_true", help="Check result (slower)") + args = parser.parse_args() + + N = args.N + NB = args.NB + check_result = args.check + + assert N % NB == 0, f"Matrix size {N} must be divisible by block size {NB}" + + print("=" * 60) + print("Tiled POTRI (Matrix Inversion) with CUDA STF + CuPy") + print("=" * 60) + print(f"Matrix size: {N}x{N}") + print(f"Block size: {NB}x{NB}") + print(f"Number of blocks: {N // NB}x{N // NB}") + print(f"Check result: {check_result}") + print("=" * 60) + + # Create STF context + ctx = stf.context() + + # Create matrices + A = TiledMatrix(ctx, N, N, NB, NB, is_symmetric=True, symbol="A") + + if check_result: + Aref = TiledMatrix(ctx, N, N, NB, NB, is_symmetric=False, symbol="Aref") + + print("\n" + "=" * 60) + print("Initializing matrices...") + print("=" * 60) + + # Hilbert matrix + diagonal dominance for numerical stability + def hilbert(row, col): + return 1.0 / (col + row + 1.0) + 2.0 * N * (col == row) + + A.fill(hilbert) + if check_result: + Aref.fill(hilbert) + + # Measure performance + import time + + start_time = time.time() + + print("\n" + "=" * 60) + print("Performing POTRI (inversion via Cholesky)...") + print("=" * 60) + + # Step 1: Cholesky factorization A = L*L^T + PDPOTRF(ctx, A, uplo="L") + + # Step 2: Triangular inversion L^(-1) + PDTRTRI(ctx, A, uplo="L", diag="N") + + # Step 3: Compute A^(-1) = L^(-T) * L^(-1) + PDLAUUM(ctx, A, uplo="L") + + if check_result: + print("\n" + "=" * 60) + print("Verifying result...") + print("=" * 60) + + # Create test vector B + B_potri = TiledMatrix(ctx, N, 1, NB, 1, is_symmetric=False, symbol="B_potri") + Bref_potri = TiledMatrix( + ctx, N, 1, NB, 1, is_symmetric=False, symbol="Bref_potri" + ) + + def rhs_vals(row, col): + return 1.0 * (row + 1) + + B_potri.fill(rhs_vals) + Bref_potri.fill(rhs_vals) + + # Compute norm of B + b_norm = compute_norm(ctx, Bref_potri) + + # Create temporary matrix for result + B_tmp = TiledMatrix(ctx, N, 1, NB, 1, is_symmetric=False, symbol="B_tmp") + + def zero_vals(row, col): + return 0.0 + + B_tmp.fill(zero_vals) + + # Compute B_tmp = A^(-1) * B + PDSYMM(ctx, A, B_potri, B_tmp, side="L", uplo="L", alpha=1.0, beta=0.0) + + # Compute residual: Bref = Aref * B_tmp - Bref + PDGEMM( + ctx, Aref, B_tmp, Bref_potri, transa="N", transb="N", alpha=1.0, beta=-1.0 + ) + + # Compute residual norm + res_norm = compute_norm(ctx, Bref_potri) + + print("\n" + "=" * 60) + print("Finalizing STF context...") + print("=" * 60) + ctx.finalize() + + end_time = time.time() + elapsed_ms = (end_time - start_time) * 1000.0 + + # Compute FLOPS for POTRI + # POTRF: (1/3) * N^3 + # TRTRI: (1/3) * N^3 + # LAUUM: (1/3) * N^3 + # Total: N^3 + flops = float(N) ** 3 + gflops = flops / (elapsed_ms / 1000.0) / 1e9 + + print("\n" + "=" * 60) + print("Results") + print("=" * 60) + 
print(f"[POTRI] Elapsed time: {elapsed_ms:.2f} ms") + print(f"[POTRI] Performance: {gflops:.2f} GFLOPS") + + if check_result: + residual = res_norm / b_norm + print(f"\n[POTRI] ||A * (A^(-1) * B) - B||: {res_norm:.6e}") + print(f"[POTRI] ||B||: {b_norm:.6e}") + print(f"[POTRI] Residual (||A * (A^(-1) * B) - B||/||B||): {residual:.6e}") + + if residual < 0.01: + print("\n✅ Algorithm converged successfully!") + return 0 + else: + print(f"\n❌ Algorithm did not converge (residual {residual:.6e} >= 0.01)") + return 1 + + print("=" * 60) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/python/cuda_cccl/tests/stf/test_context.py b/python/cuda_cccl/tests/stf/test_context.py new file mode 100644 index 00000000000..451c44aadb8 --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_context.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import numpy as np + +from cuda.stf._stf_bindings import context, read, rw + + +def test_ctx(): + ctx = context() + del ctx + + +def test_graph_ctx(): + ctx = context(use_graph=True) + ctx.finalize() + + +def test_ctx2(): + X = np.ones(16, dtype=np.float32) + Y = np.ones(16, dtype=np.float32) + Z = np.ones(16, dtype=np.float32) + + ctx = context() + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + lZ = ctx.logical_data(Z) + + t = ctx.task(rw(lX)) + t.start() + t.end() + + t2 = ctx.task(read(lX), rw(lY)) + t2.start() + t2.end() + + t3 = ctx.task(read(lX), rw(lZ)) + t3.start() + t3.end() + + t4 = ctx.task(read(lY), rw(lZ)) + t4.start() + t4.end() + + del ctx + + +def test_ctx3(): + X = np.ones(16, dtype=np.float32) + Y = np.ones(16, dtype=np.float32) + Z = np.ones(16, dtype=np.float32) + + ctx = context() + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + lZ = ctx.logical_data(Z) + + with ctx.task(rw(lX)): + pass + + with ctx.task(read(lX), rw(lY)): + pass + + with ctx.task(read(lX), rw(lZ)): + pass + + with ctx.task(read(lY), rw(lZ)): + pass + + del ctx + + +if __name__ == "__main__": + test_ctx3() diff --git a/python/cuda_cccl/tests/stf/test_decorator.py b/python/cuda_cccl/tests/stf/test_decorator.py new file mode 100644 index 00000000000..ce8fad1d69b --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_decorator.py @@ -0,0 +1,45 @@ +import numba +import numpy as np +import pytest +from numba import cuda + +import cuda.stf as stf + +numba.cuda.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0 + + +@stf.jit +def axpy(a, x, y): + i = cuda.grid(1) + if i < x.size: + y[i] = a * x[i] + y[i] + + +@stf.jit +def scale(a, x): + i = cuda.grid(1) + if i < x.size: + x[i] = a * x[i] + + +@pytest.mark.parametrize("use_graph", [True, False]) +def test_decorator(use_graph): + X, Y, Z = (np.ones(16, np.float32) for _ in range(3)) + + ctx = stf.context(use_graph=use_graph) + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + lZ = ctx.logical_data(Z) + + scale[32, 64](2.0, lX.rw()) + axpy[32, 64](2.0, lX.read(), lY.rw()) + axpy[32, 64, stf.exec_place.device(0)]( + 2.0, lX.read(), lZ.rw() + ) # explicit exec place + axpy[32, 64]( + 2.0, lY.read(), lZ.rw(stf.data_place.device(0)) + ) # per-dep placement override + + +if __name__ == "__main__": + test_decorator(False) diff --git a/python/cuda_cccl/tests/stf/test_fdtd_pytorch.py b/python/cuda_cccl/tests/stf/test_fdtd_pytorch.py new file mode 100644 index 00000000000..d550caba060 --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_fdtd_pytorch.py @@ -0,0 +1,229 @@ +import math +from typing import Literal, 
Optional, Tuple + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") +import torch.cuda as tc + +from cuda.stf._stf_bindings import ( + context, +) + +try: + import matplotlib.pyplot as plt + + has_matplotlib = True +except ImportError: + has_matplotlib = False + +Plane = Literal["xy", "xz", "yz"] + + +def show_slice(t3d, plane="xy", index=None): + """Display a 2D slice of a 3D tensor (requires matplotlib).""" + if not has_matplotlib: + return + + # grab a 2D view + if plane == "xy": + idx = t3d.shape[2] // 2 if index is None else index + slice2d = t3d[:, :, idx] + elif plane == "xz": + idx = t3d.shape[1] // 2 if index is None else index + slice2d = t3d[:, idx, :] + elif plane == "yz": + idx = t3d.shape[0] // 2 if index is None else index + slice2d = t3d[idx, :, :] + else: + raise ValueError("plane must be 'xy', 'xz' or 'yz'") + + # move to cpu numpy array + arr = slice2d.detach().cpu().numpy() + + # imshow = "imshow" not "imread" + plt.imshow( + arr, + origin="lower", + cmap="seismic", + vmin=-1e-2, + vmax=1e-2, + # norm=SymLogNorm(linthresh=1e-8, vmin=-1e-0, vmax=1e-0) + # norm=LogNorm(vmin=1e-12, vmax=1e-6) + ) + # plt.colorbar() + plt.show(block=False) + plt.pause(0.01) + + +def test_fdtd_3d_pytorch( + size_x: int = 150, + size_y: int = 150, + size_z: int = 150, + timesteps: int = 10, + output_freq: int = 0, + dx: float = 0.01, + dy: float = 0.01, + dz: float = 0.01, + epsilon0: float = 8.85e-12, + mu0: float = 1.256e-6, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float64, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: + ctx = context() + + # allocate and initialize fields + shape = (size_x, size_y, size_z) + + # Electric field components (initialized to zero) + lex = ctx.logical_data_zeros(shape, dtype=np.float64) + ley = ctx.logical_data_zeros(shape, dtype=np.float64) + lez = ctx.logical_data_zeros(shape, dtype=np.float64) + + # Magnetic field components (initialized to zero) + lhx = ctx.logical_data_zeros(shape, dtype=np.float64) + lhy = ctx.logical_data_zeros(shape, dtype=np.float64) + lhz = ctx.logical_data_zeros(shape, dtype=np.float64) + + # Material properties + lepsilon = ctx.logical_data_full(shape, float(epsilon0), dtype=np.float64) + lmu = ctx.logical_data_full(shape, float(mu0), dtype=np.float64) + + # CFL (same formula as example) + dt = 0.25 * min(dx, dy, dz) * math.sqrt(epsilon0 * mu0) + + # Es (interior) = [1..N-2] along all dims -> enables i-1, j-1, k-1 + i_es, j_es, k_es = slice(1, -1), slice(1, -1), slice(1, -1) + i_es_m, j_es_m, k_es_m = slice(0, -2), slice(0, -2), slice(0, -2) + + # Hs (base) = [0..N-2] along all dims -> enables i+1, j+1, k+1 + i_hs, j_hs, k_hs = slice(0, -1), slice(0, -1), slice(0, -1) + i_hs_p, j_hs_p, k_hs_p = slice(1, None), slice(1, None), slice(1, None) + + # source location (single cell at center) + cx, cy, cz = size_x // 2, size_y // 10, size_z // 2 + + def source(t: float, x: float, y: float, z: float) -> float: + # sin(k*x - omega*t) with f = 1e9 Hz + pi = math.pi + freq = 1.0e9 + omega = 2.0 * pi * freq + wavelength = 3.0e8 / freq + k = 2.0 * pi / wavelength + return math.sin(k * x - omega * t) + + for n in range(int(timesteps)): + # ------------------------- + # update electric fields (Es) + # Ex(i,j,k) += (dt/(ε*dx)) * [(Hz(i,j,k)-Hz(i,j-1,k)) - (Hy(i,j,k)-Hy(i,j,k-1))] + with ( + ctx.task(lex.rw(), lhy.read(), lhz.read(), lepsilon.read()) as t, + tc.stream(tc.ExternalStream(t.stream_ptr())), + ): + ex, hy, hz, epsilon = 
t.tensor_arguments() + ex[i_es, j_es, k_es] = ex[i_es, j_es, k_es] + ( + dt / (epsilon[i_es, j_es, k_es] * dx) + ) * ( + (hz[i_es, j_es, k_es] - hz[i_es, j_es_m, k_es]) + - (hy[i_es, j_es, k_es] - hy[i_es, j_es, k_es_m]) + ) + + # Ey(i,j,k) += (dt/(ε*dy)) * [(Hx(i,j,k)-Hx(i,j,k-1)) - (Hz(i,j,k)-Hz(i-1,j,k))] + with ( + ctx.task(ley.rw(), lhx.read(), lhz.read(), lepsilon.read()) as t, + tc.stream(tc.ExternalStream(t.stream_ptr())), + ): + ey, hx, hz, epsilon = t.tensor_arguments() + ey[i_es, j_es, k_es] = ey[i_es, j_es, k_es] + ( + dt / (epsilon[i_es, j_es, k_es] * dy) + ) * ( + (hx[i_es, j_es, k_es] - hx[i_es, j_es, k_es_m]) + - (hz[i_es, j_es, k_es] - hz[i_es_m, j_es, k_es]) + ) + + # Ez(i,j,k) += (dt/(ε*dz)) * [(Hy(i,j,k)-Hy(i-1,j,k)) - (Hx(i,j,k)-Hx(i,j-1,k))] + with ( + ctx.task(lez.rw(), lhx.read(), lhy.read(), lepsilon.read()) as t, + tc.stream(tc.ExternalStream(t.stream_ptr())), + ): + ez, hx, hy, epsilon = t.tensor_arguments() + ez[i_es, j_es, k_es] = ez[i_es, j_es, k_es] + ( + dt / (epsilon[i_es, j_es, k_es] * dz) + ) * ( + (hy[i_es, j_es, k_es] - hy[i_es_m, j_es, k_es]) + - (hx[i_es, j_es, k_es] - hx[i_es, j_es_m, k_es]) + ) + + # source at center cell + with ( + ctx.task(lez.rw()) as t, + tc.stream(tc.ExternalStream(t.stream_ptr())), + ): + ez = t.tensor_arguments() + ez[cx, cy, cz] = ez[cx, cy, cz] + source(n * dt, cx * dx, cy * dy, cz * dz) + + # ------------------------- + # update magnetic fields (Hs) + # Hx(i,j,k) -= (dt/(μ*dy)) * [(Ez(i,j+1,k)-Ez(i,j,k)) - (Ey(i,j,k+1)-Ey(i,j,k))] + with ( + ctx.task(lhx.rw(), ley.read(), lez.read(), lmu.read()) as t, + tc.stream(tc.ExternalStream(t.stream_ptr())), + ): + hx, ey, ez, mu = t.tensor_arguments() + hx[i_hs, j_hs, k_hs] = hx[i_hs, j_hs, k_hs] - ( + dt / (mu[i_hs, j_hs, k_hs] * dy) + ) * ( + (ez[i_hs, j_hs_p, k_hs] - ez[i_hs, j_hs, k_hs]) + - (ey[i_hs, j_hs, k_hs_p] - ey[i_hs, j_hs, k_hs]) + ) + + # Hy(i,j,k) -= (dt/(μ*dz)) * [(Ex(i,j,k+1)-Ex(i,j,k)) - (Ez(i+1,j,k)-Ez(i,j,k))] + with ( + ctx.task(lhy.rw(), lex.read(), lez.read(), lmu.read()) as t, + tc.stream(tc.ExternalStream(t.stream_ptr())), + ): + hy, ex, ez, mu = t.tensor_arguments() + hy[i_hs, j_hs, k_hs] = hy[i_hs, j_hs, k_hs] - ( + dt / (mu[i_hs, j_hs, k_hs] * dz) + ) * ( + (ex[i_hs, j_hs, k_hs_p] - ex[i_hs, j_hs, k_hs]) + - (ez[i_hs_p, j_hs, k_hs] - ez[i_hs, j_hs, k_hs]) + ) + + # Hz(i,j,k) -= (dt/(μ*dx)) * [(Ey(i+1,j,k)-Ey(i,j,k)) - (Ex(i,j+1,k)-Ex(i,j,k))] + with ( + ctx.task(lhz.rw(), lex.read(), ley.read(), lmu.read()) as t, + tc.stream(tc.ExternalStream(t.stream_ptr())), + ): + hz, ex, ey, mu = t.tensor_arguments() + hz[i_hs, j_hs, k_hs] = hz[i_hs, j_hs, k_hs] - ( + dt / (mu[i_hs, j_hs, k_hs] * dx) + ) * ( + (ey[i_hs_p, j_hs, k_hs] - ey[i_hs, j_hs, k_hs]) + - (ex[i_hs, j_hs_p, k_hs] - ex[i_hs, j_hs, k_hs]) + ) + + if output_freq > 0 and (n % output_freq) == 0: + with ( + ctx.task(lez.read()) as t, + tc.stream(tc.ExternalStream(t.stream_ptr())), + ): + ez = t.tensor_arguments() + print(f"{n}\t{ez[cx, cy, cz].item():.6e}") + if has_matplotlib: + show_slice(ez, plane="xy") + pass + + ctx.finalize() + + +if __name__ == "__main__": + # Run FDTD simulation + output_freq = 5 if has_matplotlib else 0 + if not has_matplotlib and output_freq > 0: + print("Warning: matplotlib not available, running without visualization") + output_freq = 0 + test_fdtd_3d_pytorch(timesteps=1000, output_freq=output_freq) diff --git a/python/cuda_cccl/tests/stf/test_fdtd_pytorch_simplified.py b/python/cuda_cccl/tests/stf/test_fdtd_pytorch_simplified.py new file mode 100644 index 
00000000000..b786552b6b3 --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_fdtd_pytorch_simplified.py @@ -0,0 +1,229 @@ +import math +from typing import Literal, Optional, Tuple + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +from cuda.stf._stf_bindings import ( + context, +) + +try: + import matplotlib.pyplot as plt + + has_matplotlib = True +except ImportError: + has_matplotlib = False + +Plane = Literal["xy", "xz", "yz"] + + +def show_slice(t3d, plane="xy", index=None): + """Display a 2D slice of a 3D tensor (requires matplotlib).""" + if not has_matplotlib: + return + + # grab a 2D view + if plane == "xy": + idx = t3d.shape[2] // 2 if index is None else index + slice2d = t3d[:, :, idx] + elif plane == "xz": + idx = t3d.shape[1] // 2 if index is None else index + slice2d = t3d[:, idx, :] + elif plane == "yz": + idx = t3d.shape[0] // 2 if index is None else index + slice2d = t3d[idx, :, :] + else: + raise ValueError("plane must be 'xy', 'xz' or 'yz'") + + # move to cpu numpy array + arr = slice2d.detach().cpu().numpy() + + # imshow = "imshow" not "imread" + plt.imshow( + arr, + origin="lower", + cmap="seismic", + vmin=-1e-2, + vmax=1e-2, + # norm=SymLogNorm(linthresh=1e-8, vmin=-1e-0, vmax=1e-0) + # norm=LogNorm(vmin=1e-12, vmax=1e-6) + ) + # plt.colorbar() + plt.show(block=False) + plt.pause(0.01) + + +def test_fdtd_3d_pytorch_simplified( + size_x: int = 150, + size_y: int = 150, + size_z: int = 150, + timesteps: int = 10, + output_freq: int = 0, + dx: float = 0.01, + dy: float = 0.01, + dz: float = 0.01, + epsilon0: float = 8.85e-12, + mu0: float = 1.256e-6, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float64, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: + """ + FDTD 3D implementation using pytorch_task for simplified syntax. + Demonstrates automatic stream and tensor management. 
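+ See test_fdtd_pytorch.py for the same solver written with explicit
+ stream handling and t.tensor_arguments().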
+ """ + ctx = context() + + # allocate and initialize fields + shape = (size_x, size_y, size_z) + + # Electric field components (initialized to zero) + lex = ctx.logical_data_zeros(shape, dtype=np.float64) + ley = ctx.logical_data_zeros(shape, dtype=np.float64) + lez = ctx.logical_data_zeros(shape, dtype=np.float64) + + # Magnetic field components (initialized to zero) + lhx = ctx.logical_data_zeros(shape, dtype=np.float64) + lhy = ctx.logical_data_zeros(shape, dtype=np.float64) + lhz = ctx.logical_data_zeros(shape, dtype=np.float64) + + # Material properties + lepsilon = ctx.logical_data_full(shape, float(epsilon0), dtype=np.float64) + lmu = ctx.logical_data_full(shape, float(mu0), dtype=np.float64) + + # CFL (same formula as example) + dt = 0.25 * min(dx, dy, dz) * math.sqrt(epsilon0 * mu0) + + # Es (interior) = [1..N-2] along all dims -> enables i-1, j-1, k-1 + i_es, j_es, k_es = slice(1, -1), slice(1, -1), slice(1, -1) + i_es_m, j_es_m, k_es_m = slice(0, -2), slice(0, -2), slice(0, -2) + + # Hs (base) = [0..N-2] along all dims -> enables i+1, j+1, k+1 + i_hs, j_hs, k_hs = slice(0, -1), slice(0, -1), slice(0, -1) + i_hs_p, j_hs_p, k_hs_p = slice(1, None), slice(1, None), slice(1, None) + + # source location (single cell at center) + cx, cy, cz = size_x // 2, size_y // 10, size_z // 2 + + def source(t: float, x: float, y: float, z: float) -> float: + # sin(k*x - omega*t) with f = 1e9 Hz + pi = math.pi + freq = 1.0e9 + omega = 2.0 * pi * freq + wavelength = 3.0e8 / freq + k = 2.0 * pi / wavelength + return math.sin(k * x - omega * t) + + for n in range(int(timesteps)): + # ------------------------- + # update electric fields (Es) + # Ex(i,j,k) += (dt/(ε*dx)) * [(Hz(i,j,k)-Hz(i,j-1,k)) - (Hy(i,j,k)-Hy(i,j,k-1))] + with ctx.pytorch_task(lex.rw(), lhy.read(), lhz.read(), lepsilon.read()) as ( + ex, + hy, + hz, + epsilon, + ): + ex[i_es, j_es, k_es] = ex[i_es, j_es, k_es] + ( + dt / (epsilon[i_es, j_es, k_es] * dx) + ) * ( + (hz[i_es, j_es, k_es] - hz[i_es, j_es_m, k_es]) + - (hy[i_es, j_es, k_es] - hy[i_es, j_es, k_es_m]) + ) + + # Ey(i,j,k) += (dt/(ε*dy)) * [(Hx(i,j,k)-Hx(i,j,k-1)) - (Hz(i,j,k)-Hz(i-1,j,k))] + with ctx.pytorch_task(ley.rw(), lhx.read(), lhz.read(), lepsilon.read()) as ( + ey, + hx, + hz, + epsilon, + ): + ey[i_es, j_es, k_es] = ey[i_es, j_es, k_es] + ( + dt / (epsilon[i_es, j_es, k_es] * dy) + ) * ( + (hx[i_es, j_es, k_es] - hx[i_es, j_es, k_es_m]) + - (hz[i_es, j_es, k_es] - hz[i_es_m, j_es, k_es]) + ) + + # Ez(i,j,k) += (dt/(ε*dz)) * [(Hy(i,j,k)-Hy(i-1,j,k)) - (Hx(i,j,k)-Hx(i,j-1,k))] + with ctx.pytorch_task(lez.rw(), lhx.read(), lhy.read(), lepsilon.read()) as ( + ez, + hx, + hy, + epsilon, + ): + ez[i_es, j_es, k_es] = ez[i_es, j_es, k_es] + ( + dt / (epsilon[i_es, j_es, k_es] * dz) + ) * ( + (hy[i_es, j_es, k_es] - hy[i_es_m, j_es, k_es]) + - (hx[i_es, j_es, k_es] - hx[i_es, j_es_m, k_es]) + ) + + # source at center cell + with ctx.pytorch_task(lez.rw()) as (ez,): + ez[cx, cy, cz] = ez[cx, cy, cz] + source(n * dt, cx * dx, cy * dy, cz * dz) + + # ------------------------- + # update magnetic fields (Hs) + # Hx(i,j,k) -= (dt/(μ*dy)) * [(Ez(i,j+1,k)-Ez(i,j,k)) - (Ey(i,j,k+1)-Ey(i,j,k))] + with ctx.pytorch_task(lhx.rw(), ley.read(), lez.read(), lmu.read()) as ( + hx, + ey, + ez, + mu, + ): + hx[i_hs, j_hs, k_hs] = hx[i_hs, j_hs, k_hs] - ( + dt / (mu[i_hs, j_hs, k_hs] * dy) + ) * ( + (ez[i_hs, j_hs_p, k_hs] - ez[i_hs, j_hs, k_hs]) + - (ey[i_hs, j_hs, k_hs_p] - ey[i_hs, j_hs, k_hs]) + ) + + # Hy(i,j,k) -= (dt/(μ*dz)) * [(Ex(i,j,k+1)-Ex(i,j,k)) - (Ez(i+1,j,k)-Ez(i,j,k))] + 
with ctx.pytorch_task(lhy.rw(), lex.read(), lez.read(), lmu.read()) as ( + hy, + ex, + ez, + mu, + ): + hy[i_hs, j_hs, k_hs] = hy[i_hs, j_hs, k_hs] - ( + dt / (mu[i_hs, j_hs, k_hs] * dz) + ) * ( + (ex[i_hs, j_hs, k_hs_p] - ex[i_hs, j_hs, k_hs]) + - (ez[i_hs_p, j_hs, k_hs] - ez[i_hs, j_hs, k_hs]) + ) + + # Hz(i,j,k) -= (dt/(μ*dx)) * [(Ey(i+1,j,k)-Ey(i,j,k)) - (Ex(i,j+1,k)-Ex(i,j,k))] + with ctx.pytorch_task(lhz.rw(), lex.read(), ley.read(), lmu.read()) as ( + hz, + ex, + ey, + mu, + ): + hz[i_hs, j_hs, k_hs] = hz[i_hs, j_hs, k_hs] - ( + dt / (mu[i_hs, j_hs, k_hs] * dx) + ) * ( + (ey[i_hs_p, j_hs, k_hs] - ey[i_hs, j_hs, k_hs]) + - (ex[i_hs, j_hs_p, k_hs] - ex[i_hs, j_hs, k_hs]) + ) + + if output_freq > 0 and (n % output_freq) == 0: + with ctx.pytorch_task(lez.read()) as (ez,): + print(f"{n}\t{ez[cx, cy, cz].item():.6e}") + if has_matplotlib: + show_slice(ez, plane="xy") + + ctx.finalize() + + +if __name__ == "__main__": + # Run simplified FDTD simulation using pytorch_task + output_freq = 5 if has_matplotlib else 0 + if not has_matplotlib and output_freq > 0: + print("Warning: matplotlib not available, running without visualization") + output_freq = 0 + test_fdtd_3d_pytorch_simplified(timesteps=1000, output_freq=output_freq) diff --git a/python/cuda_cccl/tests/stf/test_fhe.py b/python/cuda_cccl/tests/stf/test_fhe.py new file mode 100644 index 00000000000..b2bb9961b84 --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_fhe.py @@ -0,0 +1,155 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Toy Fully Homomorphic Encryption (FHE) example with addition encryption + +import numba +from numba import cuda + +import cuda.stf as stf + +numba.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0 + + +class Plaintext: + def __init__(self, ctx, values=None, ld=None, key=0x42): + self.ctx = ctx + self.key = key + if ld is not None: + self.l = ld + if values is not None: + self.values = bytearray(values) + self.l = ctx.logical_data(self.values) + self.symbol = None + + def set_symbol(self, symbol: str): + self.l.set_symbol(symbol) + self.symbol = symbol + + def encrypt(self) -> "Ciphertext": + encrypted = bytearray([(c + self.key) & 0xFF for c in self.values]) + return Ciphertext(self.ctx, values=encrypted, key=self.key) + + def print_values(self): + with ctx.task( + stf.exec_place.host(), self.l.read(stf.data_place.managed()) + ) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + nb_stream.synchronize() + hvalues = t.numba_arguments() + print([v for v in hvalues]) + + +@cuda.jit +def add_kernel(a, b, out): + i = cuda.grid(1) + if i < out.size: + out[i] = (a[i] + b[i]) & 0xFF + + +@cuda.jit +def sub_kernel(a, b, out): + i = cuda.grid(1) + if i < out.size: + out[i] = (a[i] - b[i]) & 0xFF + + +@cuda.jit +def sub_scalar_kernel(a, out, v): + i = cuda.grid(1) + if i < out.size: + out[i] = (a[i] - v) & 0xFF + + +class Ciphertext: + def __init__(self, ctx, values=None, ld=None, key=0x42): + self.ctx = ctx + self.key = key + if ld is not None: + self.l = ld + if values is not None: + self.values = bytearray(values) + self.l = ctx.logical_data(self.values) + self.symbol = None + + def __add__(self, other): + if not isinstance(other, Ciphertext): + return NotImplemented + result = self.like_empty() + with ctx.task(self.l.read(), other.l.read(), result.l.write()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + da, db, dresult = t.numba_arguments() + add_kernel[32, 16, nb_stream](da, db, dresult) + return result + + 
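+ # Subtracting two ciphertexts cancels their additive keys, so the
+ # difference carries no extra key offset.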
def __sub__(self, other): + if not isinstance(other, Ciphertext): + return NotImplemented + result = self.like_empty() + with ctx.task(self.l.read(), other.l.read(), result.l.write()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + da, db, dresult = t.numba_arguments() + sub_kernel[32, 16, nb_stream](da, db, dresult) + return result + + def set_symbol(self, symbol: str): + self.l.set_symbol(symbol) + self.symbol = symbol + + def decrypt(self, num_operands=2): + """Decrypt by subtracting num_operands * key""" + result = self.like_empty() + total_key = (num_operands * self.key) & 0xFF + with ctx.task(self.l.read(), result.l.write()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + da, dresult = t.numba_arguments() + sub_scalar_kernel[32, 16, nb_stream](da, dresult, total_key) + return Plaintext(self.ctx, ld=result.l, key=self.key) + + def like_empty(self): + return Ciphertext(self.ctx, ld=self.l.like_empty()) + + +def circuit(a, b): + """Circuit: (A + B) + (B - A) = 2*B""" + return (a + b) + (b - a) + + +def test_fhe(): + """Test FHE using manual task creation with addition encryption.""" + global ctx + ctx = stf.context(use_graph=False) + + vA = [3, 3, 2, 2, 17] + pA = Plaintext(ctx, vA) + pA.set_symbol("A") + + vB = [1, 7, 7, 7, 49] + pB = Plaintext(ctx, vB) + pB.set_symbol("B") + + expected = [circuit(a, b) & 0xFF for a, b in zip(vA, vB)] + + eA = pA.encrypt() + eB = pB.encrypt() + encrypted_out = circuit(eA, eB) + decrypted_out = encrypted_out.decrypt(num_operands=2) + + with ctx.task( + stf.exec_place.host(), decrypted_out.l.read(stf.data_place.managed()) + ) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + nb_stream.synchronize() + hvalues = t.numba_arguments() + actual = [int(v) for v in hvalues] + + ctx.finalize() + + assert actual == expected, ( + f"Decrypted result {actual} doesn't match expected {expected}" + ) + + +if __name__ == "__main__": + test_fhe() diff --git a/python/cuda_cccl/tests/stf/test_fhe_decorator.py b/python/cuda_cccl/tests/stf/test_fhe_decorator.py new file mode 100644 index 00000000000..980f7735ddc --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_fhe_decorator.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Toy Fully Homomorphic Encryption (FHE) example with addition encryption + +import numba +from numba import cuda + +import cuda.stf as cudastf + +numba.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0 + + +class Plaintext: + def __init__(self, ctx, values=None, ld=None, key=0x42): + self.ctx = ctx + self.key = key + if ld is not None: + self.l = ld + if values is not None: + self.values = bytearray(values) + self.l = ctx.logical_data(self.values) + self.symbol = None + + def set_symbol(self, symbol: str): + self.l.set_symbol(symbol) + self.symbol = symbol + + def encrypt(self) -> "Ciphertext": + encrypted = bytearray([(c + self.key) & 0xFF for c in self.values]) + return Ciphertext(self.ctx, values=encrypted, key=self.key) + + def print_values(self): + with ctx.task( + cudastf.exec_place.host(), self.l.read(cudastf.data_place.managed()) + ) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + nb_stream.synchronize() + hvalues = t.numba_arguments() + print([v for v in hvalues]) + + +@cudastf.jit +def add_kernel(a, b, out): + i = cuda.grid(1) + if i < out.size: + out[i] = (a[i] + b[i]) & 0xFF + + +@cudastf.jit +def sub_kernel(a, b, out): + i = cuda.grid(1) + if i < out.size: + out[i] = (a[i] - b[i]) & 0xFF + + +@cudastf.jit +def sub_scalar_kernel(a, out, v): + i = cuda.grid(1) + if i < out.size: + out[i] = (a[i] - v) & 0xFF + + +class Ciphertext: + def __init__(self, ctx, values=None, ld=None, key=0x42): + self.ctx = ctx + self.key = key + if ld is not None: + self.l = ld + if values is not None: + self.values = bytearray(values) + self.l = ctx.logical_data(self.values) + self.symbol = None + + def __add__(self, other): + if not isinstance(other, Ciphertext): + return NotImplemented + result = self.like_empty() + add_kernel[32, 16](self.l.read(), other.l.read(), result.l.write()) + return result + + def __sub__(self, other): + if not isinstance(other, Ciphertext): + return NotImplemented + result = self.like_empty() + sub_kernel[32, 16](self.l.read(), other.l.read(), result.l.write()) + return result + + def set_symbol(self, symbol: str): + self.l.set_symbol(symbol) + self.symbol = symbol + + def decrypt(self, num_operands=2): + """Decrypt by subtracting num_operands * key""" + result = self.like_empty() + total_key = (num_operands * self.key) & 0xFF + sub_scalar_kernel[32, 16](self.l.read(), result.l.write(), total_key) + return Plaintext(self.ctx, ld=result.l, key=self.key) + + def like_empty(self): + return Ciphertext(self.ctx, ld=self.l.like_empty()) + + +def circuit(a, b): + """Circuit: (A + B) + (B - A) = 2*B""" + return (a + b) + (b - a) + + +def test_fhe_decorator(): + """Test FHE using @cudastf.jit decorators with addition encryption.""" + global ctx + ctx = cudastf.context(use_graph=False) + + vA = [3, 3, 2, 2, 17] + pA = Plaintext(ctx, vA) + pA.set_symbol("A") + + vB = [1, 7, 7, 7, 49] + pB = Plaintext(ctx, vB) + pB.set_symbol("B") + + expected = [circuit(a, b) & 0xFF for a, b in zip(vA, vB)] + + eA = pA.encrypt() + eB = pB.encrypt() + encrypted_out = circuit(eA, eB) + decrypted_out = encrypted_out.decrypt(num_operands=2) + + with ctx.task( + cudastf.exec_place.host(), decrypted_out.l.read(cudastf.data_place.managed()) + ) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + nb_stream.synchronize() + hvalues = t.numba_arguments() + actual = [int(v) for v in hvalues] + + ctx.finalize() + + assert actual == expected, ( + f"Decrypted result {actual} doesn't match expected {expected}" + ) + + +if __name__ == 
"__main__": + test_fhe_decorator() diff --git a/python/cuda_cccl/tests/stf/test_numba.py b/python/cuda_cccl/tests/stf/test_numba.py new file mode 100644 index 00000000000..bd818e13894 --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_numba.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import numba +import numpy as np +import pytest +from numba import cuda + +import cuda.stf as stf + +numba.cuda.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0 + + +@cuda.jit +def axpy(a, x, y): + i = cuda.grid(1) + if i < x.size: + y[i] = a * x[i] + y[i] + + +@cuda.jit +def scale(a, x): + i = cuda.grid(1) + if i < x.size: + x[i] = a * x[i] + + +# One test with a single kernel in a CUDA graph +def test_numba_graph(): + X = np.ones(16, dtype=np.float32) + ctx = stf.context(use_graph=True) + lX = ctx.logical_data(X) + with ctx.task(lX.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = t.numba_arguments() + scale[32, 64, nb_stream](2.0, dX) + + ctx.finalize() + + # Verify results after finalize (data written back to host) + # Expected: scale(2.0, 1.0) = 2.0 + assert np.allclose(X, 2.0) + + +def test_numba(): + n = 1024 * 1024 + X = np.ones(n, dtype=np.float32) + Y = np.ones(n, dtype=np.float32) + Z = np.ones(n, dtype=np.float32) + + ctx = stf.context() + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + lZ = ctx.logical_data(Z) + + threads_per_block = 256 + blocks = (n + threads_per_block - 1) // threads_per_block + + with ctx.task(lX.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = t.numba_arguments() + scale[blocks, threads_per_block, nb_stream](2.0, dX) + + with ctx.task(lX.read(), lY.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = t.get_arg_numba(0) + dY = t.get_arg_numba(1) + axpy[blocks, threads_per_block, nb_stream](2.0, dX, dY) + + with ctx.task(lX.read(), lZ.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX, dZ = t.numba_arguments() + axpy[blocks, threads_per_block, nb_stream](2.0, dX, dZ) + + with ctx.task(lY.read(), lZ.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dY, dZ = t.numba_arguments() + axpy[blocks, threads_per_block, nb_stream](2.0, dY, dZ) + + ctx.finalize() + + # Verify results after finalize (data written back to host) + # Expected values: + # X: scale(2.0, 1.0) = 2.0 + # Y: axpy(2.0, X=2.0, Y=1.0) = 2.0*2.0 + 1.0 = 5.0 + # Z: axpy(2.0, X=2.0, Z=1.0) = 5.0, then axpy(2.0, Y=5.0, Z=5.0) = 15.0 + assert np.allclose(X, 2.0) + assert np.allclose(Y, 5.0) + assert np.allclose(Z, 15.0) + + +@cuda.jit +def laplacian_5pt_kernel(u_in, u_out, dx, dy): + """ + Compute a 5-point Laplacian on u_in and write the result to u_out. + + Grid-stride 2-D kernel. Assumes C-contiguous (row-major) inputs. + Boundary cells are copied unchanged. 
+ """ + coef_x = 1.0 / (dx * dx) + coef_y = 1.0 / (dy * dy) + + i, j = cuda.grid(2) # i <-> row (x-index), j <-> col (y-index) + nx, ny = u_in.shape + + if i >= nx or j >= ny: + return # out-of-bounds threads do nothing + + if 0 < i < nx - 1 and 0 < j < ny - 1: + u_out[i, j] = (u_in[i - 1, j] - 2.0 * u_in[i, j] + u_in[i + 1, j]) * coef_x + ( + u_in[i, j - 1] - 2.0 * u_in[i, j] + u_in[i, j + 1] + ) * coef_y + else: + # simple Dirichlet/Neumann placeholder: copy input to output + u_out[i, j] = u_in[i, j] + + +def test_numba2d(): + nx, ny = 1024, 1024 + dx = 2.0 * np.pi / (nx - 1) + dy = 2.0 * np.pi / (ny - 1) + + # a smooth test field: f(x,y) = sin(x) * cos(y) + x = np.linspace(0, 2 * np.pi, nx, dtype=np.float64) + y = np.linspace(0, 2 * np.pi, ny, dtype=np.float64) + + u = np.sin(x)[:, None] * np.cos(y)[None, :] # shape = (nx, ny) + u_out = np.zeros_like(u) + + ctx = stf.context() + lu = ctx.logical_data(u) + lu_out = ctx.logical_data(u_out) + + with ctx.task(lu.read(), lu_out.write()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + du = t.get_arg_numba(0) + du_out = t.get_arg_numba(1) + threads_per_block = (16, 16) # 256 threads per block is a solid starting point + blocks_per_grid = ( + (nx + threads_per_block[0] - 1) // threads_per_block[0], + (ny + threads_per_block[1] - 1) // threads_per_block[1], + ) + laplacian_5pt_kernel[blocks_per_grid, threads_per_block, nb_stream]( + du, du_out, dx, dy + ) + + ctx.finalize() + + u_out_ref = np.zeros_like(u) + + for i in range(1, nx - 1): # skip boundaries + for j in range(1, ny - 1): + u_out_ref[i, j] = (u[i - 1, j] - 2.0 * u[i, j] + u[i + 1, j]) / dx**2 + ( + u[i, j - 1] - 2.0 * u[i, j] + u[i, j + 1] + ) / dy**2 + + # copy boundaries + u_out_ref[0, :] = u[0, :] + u_out_ref[-1, :] = u[-1, :] + u_out_ref[:, 0] = u[:, 0] + u_out_ref[:, -1] = u[:, -1] + + # compare with the GPU result + assert np.allclose(u_out, u_out_ref, rtol=1e-6, atol=1e-6) + + +def test_numba_exec_place(): + X = np.ones(16, dtype=np.float32) + Y = np.ones(16, dtype=np.float32) + Z = np.ones(16, dtype=np.float32) + + ctx = stf.context() + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + lZ = ctx.logical_data(Z) + + with ctx.task(stf.exec_place.device(0), lX.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + # dX = t.get_arg_numba(0) + dX = cuda.from_cuda_array_interface(t.get_arg_cai(0), owner=None, sync=False) + scale[32, 64, nb_stream](2.0, dX) + + with ctx.task(stf.exec_place.device(0), lX.read(), lY.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = t.get_arg_numba(0) + dY = t.get_arg_numba(1) + axpy[32, 64, nb_stream](2.0, dX, dY) + + with ctx.task( + stf.exec_place.device(0), + lX.read(stf.data_place.managed()), + lZ.rw(stf.data_place.managed()), + ) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = t.get_arg_numba(0) + dZ = t.get_arg_numba(1) + axpy[32, 64, nb_stream](2.0, dX, dZ) + + with ctx.task(stf.exec_place.device(0), lY.read(), lZ.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dY = t.get_arg_numba(0) + dZ = t.get_arg_numba(1) + axpy[32, 64, nb_stream](2.0, dY, dZ) + + +def test_numba_places(): + if len(list(cuda.gpus)) < 2: + pytest.skip("Need at least 2 GPUs") + return + + X = np.ones(16, dtype=np.float32) + Y = np.ones(16, dtype=np.float32) + Z = np.ones(16, dtype=np.float32) + + ctx = stf.context() + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + lZ = ctx.logical_data(Z) + + with ctx.task(lX.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = 
t.numba_arguments() + scale[32, 64, nb_stream](2.0, dX) + + with ctx.task(lX.read(), lY.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = t.get_arg_numba(0) + dY = t.get_arg_numba(1) + axpy[32, 64, nb_stream](2.0, dX, dY) + + with ctx.task(stf.exec_place.device(1), lX.read(), lZ.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = t.get_arg_numba(0) + dZ = t.get_arg_numba(1) + axpy[32, 64, nb_stream](2.0, dX, dZ) + + with ctx.task(lY.read(), lZ.rw(stf.data_place.device(1))) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dY = t.get_arg_numba(0) + dZ = t.get_arg_numba(1) + axpy[32, 64, nb_stream](2.0, dY, dZ) + + +if __name__ == "__main__": + test_numba_graph() + # test_numba() diff --git a/python/cuda_cccl/tests/stf/test_pytorch.py b/python/cuda_cccl/tests/stf/test_pytorch.py new file mode 100644 index 00000000000..02a7bc1c1b3 --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_pytorch.py @@ -0,0 +1,118 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import numba +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +numba.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0 + +from cuda.stf._stf_bindings import ( # noqa: E402 + context, + rw, +) + + +def test_pytorch(): + n = 1024 * 1024 + X = np.ones(n, dtype=np.float32) + Y = np.ones(n, dtype=np.float32) + Z = np.ones(n, dtype=np.float32) + + ctx = context() + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + lZ = ctx.logical_data(Z) + + with ctx.task(rw(lX)) as t: + torch_stream = torch.cuda.ExternalStream(t.stream_ptr()) + with torch.cuda.stream(torch_stream): + tX = t.tensor_arguments() + tX[:] = tX * 2 # In-place multiplication + + with ctx.task(lX.read(), lY.write()) as t: + torch_stream = torch.cuda.ExternalStream(t.stream_ptr()) + with torch.cuda.stream(torch_stream): + tX = t.get_arg_as_tensor(0) + tY = t.get_arg_as_tensor(1) + tY[:] = tX * 2 # Copy result into tY tensor + + with ( + ctx.task(lX.read(), lZ.write()) as t, + torch.cuda.stream(torch.cuda.ExternalStream(t.stream_ptr())), + ): + tX, tZ = t.tensor_arguments() # Get tX and tZ tensors + tZ[:] = tX * 4 + 1 # Copy result into tZ tensor + + with ( + ctx.task(lY.read(), lZ.rw()) as t, + torch.cuda.stream(torch.cuda.ExternalStream(t.stream_ptr())), + ): + tY, tZ = t.tensor_arguments() # Get tY and tZ tensors + tZ[:] = tY * 2 - 3 # Copy result into tZ tensor + + ctx.finalize() + + # Verify results on host after finalize + # Expected values: + # X: 1.0 -> 2.0 (multiplied by 2) + # Y: 1.0 -> 4.0 (X * 2 = 2.0 * 2 = 4.0) + # Z: 1.0 -> 9.0 (X * 4 + 1 = 2.0 * 4 + 1 = 9.0) -> 5.0 (Y * 2 - 3 = 4.0 * 2 - 3 = 5.0) + assert np.allclose(X, 2.0) + assert np.allclose(Y, 4.0) + assert np.allclose(Z, 5.0) + + +def test_pytorch_task(): + """Test the pytorch_task functionality with simplified syntax""" + n = 1024 * 1024 + X = np.ones(n, dtype=np.float32) + Y = np.ones(n, dtype=np.float32) + Z = np.ones(n, dtype=np.float32) + + ctx = context() + + # Note: We could use ctx.logical_data_full instead of creating NumPy arrays first + # For example: lX = ctx.logical_data_full((n,), 1.0, dtype=np.float32) + # However, this would create logical data without underlying NumPy arrays, + # so we wouldn't be able to check results after ctx.finalize() in this test + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + lZ = ctx.logical_data(Z) + + # Equivalent operations to test_pytorch() but using pytorch_task syntax + + # In-place multiplication 
using pytorch_task (single tensor)
+    with ctx.pytorch_task(rw(lX)) as (tX,):
+        tX[:] = tX * 2
+
+    # Copy and multiply using pytorch_task (multiple tensors)
+    with ctx.pytorch_task(lX.read(), lY.write()) as (tX, tY):
+        tY[:] = tX * 2
+
+    # Another operation combining tensors
+    with ctx.pytorch_task(lX.read(), lZ.write()) as (tX, tZ):
+        tZ[:] = tX * 4 + 1
+
+    # Final operation with read-write access
+    with ctx.pytorch_task(lY.read(), lZ.rw()) as (tY, tZ):
+        tZ[:] = tY * 2 - 3
+
+    ctx.finalize()
+
+    # Verify results on host after finalize (same as original test)
+    # Expected values:
+    # X: 1.0 -> 2.0 (multiplied by 2)
+    # Y: 1.0 -> 4.0 (X * 2 = 2.0 * 2 = 4.0)
+    # Z: 1.0 -> 9.0 (X * 4 + 1 = 2.0 * 4 + 1 = 9.0) -> 5.0 (Y * 2 - 3 = 4.0 * 2 - 3 = 5.0)
+    assert np.allclose(X, 2.0)
+    assert np.allclose(Y, 4.0)
+    assert np.allclose(Z, 5.0)
+
+
+if __name__ == "__main__":
+    test_pytorch()
diff --git a/python/cuda_cccl/tests/stf/test_stencil_decorator.py b/python/cuda_cccl/tests/stf/test_stencil_decorator.py
new file mode 100644
index 00000000000..e8571edeae3
--- /dev/null
+++ b/python/cuda_cccl/tests/stf/test_stencil_decorator.py
@@ -0,0 +1,83 @@
+import numba
+import numpy as np
+from numba import cuda
+
+import cuda.stf as cudastf
+
+numba.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
+
+
+@cudastf.jit
+def laplacian_5pt_kernel(u_in, u_out, dx, dy):
+    """
+    Compute a 5-point Laplacian on u_in and write the result to u_out.
+
+    Grid-stride 2-D kernel. Assumes C-contiguous (row-major) inputs.
+    Boundary cells are copied unchanged.
+    """
+    coef_x = 1.0 / (dx * dx)
+    coef_y = 1.0 / (dy * dy)
+
+    i, j = cuda.grid(2)  # i <-> row (x-index), j <-> col (y-index)
+    nx, ny = u_in.shape
+
+    if i >= nx or j >= ny:
+        return  # out-of-bounds threads do nothing
+
+    if 0 < i < nx - 1 and 0 < j < ny - 1:
+        u_out[i, j] = (u_in[i - 1, j] - 2.0 * u_in[i, j] + u_in[i + 1, j]) * coef_x + (
+            u_in[i, j - 1] - 2.0 * u_in[i, j] + u_in[i, j + 1]
+        ) * coef_y
+    else:
+        # simple Dirichlet/Neumann placeholder: copy input to output
+        u_out[i, j] = u_in[i, j]
+
+
+def test_numba2d():
+    nx, ny = 1024, 1024
+    dx = 2.0 * np.pi / (nx - 1)
+    dy = 2.0 * np.pi / (ny - 1)
+
+    # a smooth test field: f(x,y) = sin(x) * cos(y)
+    x = np.linspace(0, 2 * np.pi, nx, dtype=np.float64)
+    y = np.linspace(0, 2 * np.pi, ny, dtype=np.float64)
+
+    u = np.sin(x)[:, None] * np.cos(y)[None, :]  # shape = (nx, ny)
+    u_out = np.zeros_like(u)
+
+    ctx = cudastf.context()
+    lu = ctx.logical_data(u)
+    lu_out = ctx.logical_data(u_out)
+
+    threads_per_block = (16, 16)  # 256 threads per block is a solid starting point
+    blocks_per_grid = (
+        (nx + threads_per_block[0] - 1) // threads_per_block[0],
+        (ny + threads_per_block[1] - 1) // threads_per_block[1],
+    )
+
+    laplacian_5pt_kernel[blocks_per_grid, threads_per_block](
+        lu.read(), lu_out.write(), dx, dy
+    )
+
+    ctx.finalize()
+
+    u_out_ref = np.zeros_like(u)
+
+    for i in range(1, nx - 1):  # skip boundaries
+        for j in range(1, ny - 1):
+            u_out_ref[i, j] = (u[i - 1, j] - 2.0 * u[i, j] + u[i + 1, j]) / dx**2 + (
+                u[i, j - 1] - 2.0 * u[i, j] + u[i, j + 1]
+            ) / dy**2
+
+    # copy boundaries
+    u_out_ref[0, :] = u[0, :]
+    u_out_ref[-1, :] = u[-1, :]
+    u_out_ref[:, 0] = u[:, 0]
+    u_out_ref[:, -1] = u[:, -1]
+
+    # compare with the GPU result
+    assert np.allclose(u_out, u_out_ref, rtol=1e-6, atol=1e-6)
+
+
+if __name__ == "__main__":
+    test_numba2d()
diff --git a/python/cuda_cccl/tests/stf/test_token.py b/python/cuda_cccl/tests/stf/test_token.py
new file mode 100644
index
00000000000..abadab8305d --- /dev/null +++ b/python/cuda_cccl/tests/stf/test_token.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import numba +import numpy as np +from numba import cuda + +import cuda.stf as stf + +numba.cuda.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0 + + +def test_token(): + ctx = stf.context() + lX = ctx.token() + lY = ctx.token() + lZ = ctx.token() + + with ctx.task(lX.rw()): + pass + + with ctx.task(lX.read(), lY.rw()): + pass + + with ctx.task(lX.read(), lZ.rw()): + pass + + with ctx.task(lY.read(), lZ.rw()): + pass + + ctx.finalize() + + +@cuda.jit +def axpy(a, x, y): + start = cuda.grid(1) + stride = cuda.gridsize(1) + for i in range(start, x.size, stride): + y[i] = a * x[i] + y[i] + + +def test_numba_token(): + n = 1024 * 1024 + X = np.ones(n, dtype=np.float32) + Y = np.ones(n, dtype=np.float32) + + ctx = stf.context() + lX = ctx.logical_data(X) + lY = ctx.logical_data(Y) + token = ctx.token() + + # Use a reasonable grid size - kernel loop will handle all elements + blocks = 32 + threads_per_block = 256 + + with ctx.task(lX.read(), lY.rw(), token.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + dX = t.get_arg_numba(0) + dY = t.get_arg_numba(1) + axpy[blocks, threads_per_block, nb_stream](2.0, dX, dY) + + with ctx.task(lX.read(), lY.rw(), token.rw()) as t: + nb_stream = cuda.external_stream(t.stream_ptr()) + print(nb_stream) + dX, dY = t.numba_arguments() + axpy[blocks, threads_per_block, nb_stream](2.0, dX, dY) + + ctx.finalize() + + # Sanity checks: verify the results after finalize + # First task: Y = 2.0 * X + Y = 2.0 * 1.0 + 1.0 = 3.0 + # Second task: Y = 2.0 * X + Y = 2.0 * 1.0 + 3.0 = 5.0 + assert np.allclose(X, 1.0), f"X should still be 1.0 (read-only), but got {X[0]}" + assert np.allclose(Y, 5.0), ( + f"Y should be 5.0 after two axpy operations, but got {Y[0]}" + ) + + +if __name__ == "__main__": + test_token()
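
Note on the toy FHE tests (test_fhe.py and test_fhe_decorator.py): encryption simply adds the key to every byte, a ciphertext-ciphertext addition carries the key copies of both operands through, and a ciphertext subtraction cancels them. For the circuit (A + B) + (B - A) exactly two key copies survive, which is why decrypt() is called with num_operands=2. The sketch below is a minimal host-side illustration of that bookkeeping in plain Python; it uses no GPU and no cuda.stf API, and the names KEY, enc, add, sub, dec and circuit are made up for this note, not part of the test files.

# Pure-Python sketch of the additive toy scheme used in the FHE tests.
# Assumption: byte-wise arithmetic modulo 256 with key 0x42, as in the tests.
KEY = 0x42

def enc(v):
    # encrypt one byte: add the key modulo 256 (mirrors Plaintext.encrypt)
    return (v + KEY) & 0xFF

def add(a, b):
    # ciphertext addition (mirrors add_kernel)
    return (a + b) & 0xFF

def sub(a, b):
    # ciphertext subtraction (mirrors sub_kernel)
    return (a - b) & 0xFF

def dec(c, num_operands=2):
    # remove the key copies that survived the circuit (mirrors sub_scalar_kernel)
    return (c - num_operands * KEY) & 0xFF

def circuit(x, y):
    # (x + y) + (y - x) == 2*y; the addition keeps 2 key copies, the subtraction keeps 0
    return add(add(x, y), sub(y, x))

# a few of the (A, B) pairs used in test_fhe()
for a, b in [(3, 1), (2, 7), (17, 49)]:
    assert dec(circuit(enc(a), enc(b)), num_operands=2) == circuit(a, b)

The same counting argument generalizes: each addition of two ciphertexts sums their embedded key counts, each subtraction takes the difference, and the final count is what decrypt() must remove via num_operands.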