3 changes: 3 additions & 0 deletions cuda_core/cuda/core/__init__.py
@@ -2,4 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings.path_finder import _load_nvidia_dynamic_library
from cuda.core._version import __version__

""" _load_nvidia_dynamic_library("nvrtc") """
2 changes: 2 additions & 0 deletions cuda_core/docs/source/api.rst
@@ -21,6 +21,8 @@ CUDA runtime
GraphBuilder
launch
Buffer
Stream
Event
MemoryResource
DeviceMemoryResource
LegacyPinnedMemoryResource
112 changes: 112 additions & 0 deletions cuda_core/docs/source/getting-started.md
@@ -0,0 +1,112 @@
# Overview

## What is `cuda.core`?

`cuda.core` provides a Pythonic interface to the CUDA runtime and other functionality,
including:

- Compiling and launching CUDA kernels
- Asynchronous concurrent execution with CUDA graphs, streams and events
- Coordinating work across multiple CUDA devices
- Allocating, transferring, and managing device memory
- Runtime linking of device code with Link-Time Optimization (LTO)
- and much more!

Rather than providing 1:1 equivalents of the CUDA driver and runtime APIs
(for that, see [`cuda.bindings`][bindings]), `cuda.core` provides high-level constructs such as:

- {class}`Device <cuda.core.experimental.Device>` class for GPU device operations and context management.
- {class}`Buffer <cuda.core.experimental.Buffer>` and {class}`MemoryResource <cuda.core.experimental.MemoryResource>` classes for memory allocation and management.
- {class}`Program <cuda.core.experimental.Program>` for JIT compilation of CUDA kernels.
- {class}`GraphBuilder <cuda.core.experimental.GraphBuilder>` for building and executing CUDA graphs.
- {class}`Stream <cuda.core.experimental.Stream>` and {class}`Event <cuda.core.experimental.Event>` for asynchronous execution and timing.

## Example: Compiling and Launching a CUDA kernel

To get a taste of `cuda.core`, let's walk through a simple example that compiles and launches a vector addition kernel.
You can find the complete example in [`vector_add.py`][vector_add_example].

First, we define a string containing the CUDA C++ kernel. Note that this is a templated kernel:

```python
# compute c = a + b
code = """
template<typename T>
__global__ void vector_add(const T* A,
                           const T* B,
                           T* C,
                           size_t N) {
    const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
    for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
        C[i] = A[i] + B[i];
    }
}
"""
```

Next, we create a {class}`Device <cuda.core.experimental.Device>` object
and a corresponding {class}`Stream <cuda.core.experimental.Stream>`.
Don't forget to use {meth}`Device.set_current() <cuda.core.experimental.Device.set_current>`!

```python
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

dev = Device()
dev.set_current()
s = dev.create_stream()
```

Next, we compile the CUDA C++ kernel from earlier using the {class}`Program <cuda.core.experimental.Program>` class.
The result of the compilation is saved as a CUBIN.
Note the use of the `name_expressions` parameter to the {meth}`Program.compile() <cuda.core.experimental.Program.compile>` method to specify which kernel template instantiations to compile:

```python
arch = "".join(f"{i}" for i in dev.compute_capability)
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile("cubin", name_expressions=("vector_add<float>",))
```

Next, we retrieve the compiled kernel from the CUBIN and prepare the arguments and kernel configuration.
We're using [CuPy][cupy] arrays as inputs for this example, but you can use PyTorch tensors too
(we show how to do this in one of our [examples][examples]).

```python
import cupy as cp

ker = mod.get_kernel("vector_add<float>")

# Prepare input/output arrays (using CuPy)
rng = cp.random.default_rng()
size = 50000
a = rng.random(size, dtype=cp.float32)
b = rng.random(size, dtype=cp.float32)
c = cp.empty_like(a)

# Configure launch parameters
block = 256
grid = (size + block - 1) // block
config = LaunchConfig(grid=grid, block=block)
```

Finally, we use the {func}`launch <cuda.core.experimental.launch>` function to execute our kernel on the specified stream with the given configuration and arguments. Note the use of `.data.ptr` to get the pointer to the array data.

```python
launch(s, config, ker, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size))
s.sync()
```
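
As a quick sanity check, we can compare the result against CuPy's own elementwise addition (this is plain CuPy, nothing `cuda.core`-specific):

```python
# Verify the kernel's output against CuPy's elementwise addition
assert cp.allclose(c, a + b)
print("Vector addition verified!")
```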

This example demonstrates one of the core workflows enabled by `cuda.core`: compiling and launching CUDA code.
Note the clean, Pythonic interface and the absence of any direct calls to the CUDA runtime/driver APIs!

## Examples and Recipes

As we mentioned before, `cuda.core` can do much more than just compile and launch kernels!

The best way to explore and learn the different features of `cuda.core` is through
our [`examples`][examples]. Find one that matches your use case, and modify it to fit your needs!


[bindings]: https://nvidia.github.io/cuda-python/cuda-bindings/latest/
[cai]: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html
[cupy]: https://cupy.dev/
[dlpack]: https://dmlc.github.io/dlpack/latest/
[examples]: https://github.com/NVIDIA/cuda-python/tree/main/cuda_core/examples
[vector_add_example]: https://github.com/NVIDIA/cuda-python/tree/main/cuda_core/examples/vector_add.py
21 changes: 14 additions & 7 deletions cuda_core/docs/source/index.rst
@@ -1,23 +1,30 @@
.. SPDX-License-Identifier: Apache-2.0

``cuda.core``: Pythonic access to CUDA core functionalities
===========================================================
``cuda.core``: Pythonic access to CUDA core functionality
=========================================================

The new Python module ``cuda.core`` offers idiomatic, pythonic access to CUDA runtime
and other functionalities.
Welcome to the documentation for ``cuda.core``.

.. toctree::
:maxdepth: 2
:caption: Contents:

release
install.md
getting-started
install
interoperability
api
contribute
conduct.md

.. toctree::
:maxdepth: 1

conduct
license

.. toctree::
:maxdepth: 2

release

Indices and tables
==================
172 changes: 172 additions & 0 deletions cuda_core/examples/cuda_graphs.py
cc @vzhurba01 for vis (a new sample for CUDA graphs)

@@ -0,0 +1,172 @@
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

# ################################################################################
#
# This demo illustrates how to use CUDA graphs to capture and execute
# multiple kernel launches with minimal overhead. The graph performs a
# sequence of vector operations: add, multiply, and subtract.
#
# ################################################################################

import time

import cupy as cp

from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch


def main():
    # CUDA kernels for vector operations
    code = """
    template<typename T>
    __global__ void vector_add(const T* A, const T* B, T* C, size_t N) {
        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
        for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
            C[i] = A[i] + B[i];
        }
    }

    template<typename T>
    __global__ void vector_multiply(const T* A, const T* B, T* C, size_t N) {
        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
        for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
            C[i] = A[i] * B[i];
        }
    }

    template<typename T>
    __global__ void vector_subtract(const T* A, const T* B, T* C, size_t N) {
        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
        for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
            C[i] = A[i] - B[i];
        }
    }
    """

    # Initialize device and stream
    dev = Device()
    dev.set_current()
    stream = dev.create_stream()
    # tell CuPy to use our stream as the current stream:
    cp.cuda.ExternalStream(int(stream.handle)).use()

    # Compile the program
    arch = "".join(f"{i}" for i in dev.compute_capability)
    program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
    prog = Program(code, code_type="c++", options=program_options)
    mod = prog.compile(
        "cubin", name_expressions=("vector_add<float>", "vector_multiply<float>", "vector_subtract<float>")
    )

    # Get kernel functions
    add_kernel = mod.get_kernel("vector_add<float>")
    multiply_kernel = mod.get_kernel("vector_multiply<float>")
    subtract_kernel = mod.get_kernel("vector_subtract<float>")

    # Prepare data
    size = 1000000
    dtype = cp.float32

    # Create input arrays
    rng = cp.random.default_rng(42)  # Fixed seed for reproducibility
    a = rng.random(size, dtype=dtype)
    b = rng.random(size, dtype=dtype)
    c = rng.random(size, dtype=dtype)

    # Create output arrays
    result1 = cp.empty_like(a)
    result2 = cp.empty_like(a)
    result3 = cp.empty_like(a)

    # Prepare launch configuration
    block_size = 256
    grid_size = (size + block_size - 1) // block_size
    config = LaunchConfig(grid=grid_size, block=block_size)

    # Sync before graph capture
    dev.sync()

    print("Building CUDA graph...")

    # Build the graph
    graph_builder = stream.create_graph_builder()
    graph_builder.begin_building()

    # Add multiple kernel launches to the graph
    # Kernel 1: result1 = a + b
    launch(graph_builder, config, add_kernel, a.data.ptr, b.data.ptr, result1.data.ptr, cp.uint64(size))

    # Kernel 2: result2 = result1 * c
    launch(graph_builder, config, multiply_kernel, result1.data.ptr, c.data.ptr, result2.data.ptr, cp.uint64(size))

    # Kernel 3: result3 = result2 - a
    launch(graph_builder, config, subtract_kernel, result2.data.ptr, a.data.ptr, result3.data.ptr, cp.uint64(size))

    # Complete the graph
    graph = graph_builder.end_building().complete()

    print("Graph built successfully!")

    # Upload the graph to the stream
    graph.upload(stream)

    # Execute the entire graph with a single launch
    print("Executing graph...")
    start_time = time.time()
    graph.launch(stream)
    stream.sync()
    end_time = time.time()

    graph_execution_time = end_time - start_time
    print(f"Graph execution time: {graph_execution_time:.6f} seconds")

    # Verify results
    expected_result1 = a + b
    expected_result2 = expected_result1 * c
    expected_result3 = expected_result2 - a

    print("Verifying results...")
    assert cp.allclose(result1, expected_result1, rtol=1e-5, atol=1e-5), "Result 1 mismatch"
    assert cp.allclose(result2, expected_result2, rtol=1e-5, atol=1e-5), "Result 2 mismatch"
    assert cp.allclose(result3, expected_result3, rtol=1e-5, atol=1e-5), "Result 3 mismatch"
    print("All results verified successfully!")

    # Demonstrate performance benefit by running the same operations without graph
    print("\nRunning same operations without graph for comparison...")

    # Reset results
    result1.fill(0)
    result2.fill(0)
    result3.fill(0)

    start_time = time.time()

    # Individual kernel launches
    launch(stream, config, add_kernel, a.data.ptr, b.data.ptr, result1.data.ptr, cp.uint64(size))
    launch(stream, config, multiply_kernel, result1.data.ptr, c.data.ptr, result2.data.ptr, cp.uint64(size))
    launch(stream, config, subtract_kernel, result2.data.ptr, a.data.ptr, result3.data.ptr, cp.uint64(size))

    stream.sync()
    end_time = time.time()

    individual_execution_time = end_time - start_time
    print(f"Individual kernel execution time: {individual_execution_time:.6f} seconds")

    # Calculate speedup
    speedup = individual_execution_time / graph_execution_time
    print(f"Graph provides {speedup:.2f}x speedup")

    # Verify results again
    assert cp.allclose(result1, expected_result1, rtol=1e-5, atol=1e-5), "Result 1 mismatch"
    assert cp.allclose(result2, expected_result2, rtol=1e-5, atol=1e-5), "Result 2 mismatch"
    assert cp.allclose(result3, expected_result3, rtol=1e-5, atol=1e-5), "Result 3 mismatch"

    cp.cuda.Stream.null.use()  # reset CuPy's current stream to the null stream

    print("\nExample completed successfully!")


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion cuda_core/examples/jit_lto_fractal.py
@@ -4,7 +4,7 @@

# ################################################################################
#
# This demo aims to illustrate a couple takeaways:
# This demo illustrates:
#
# 1. How to use the JIT LTO feature provided by the Linker class to link multiple objects together
# 2. That linking allows for libraries to modify workflows dynamically at runtime
12 changes: 10 additions & 2 deletions cuda_core/examples/pytorch_example.py
@@ -2,8 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

## Usage: pip install "cuda-core[cu12]"
## python python_example.py
# ################################################################################
#
# This demo illustrates how to use `cuda.core` to compile a CUDA kernel
# and launch it using PyTorch tensors as inputs.
#
# ## Usage: pip install "cuda-core[cu12]"
# ## python pytorch_example.py
#
# ################################################################################

import sys

import torch
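The core of the interop pattern is passing raw device pointers from `torch.Tensor.data_ptr()` to `launch`. Below is a minimal sketch of that pattern; the `scale` kernel and the NumPy scalar wrappers are illustrative assumptions, not the example's exact code — see `pytorch_example.py` for the tested version:

```python
import numpy as np
import torch

from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

# Hypothetical elementwise kernel for illustration
code = """
extern "C" __global__ void scale(float* x, float a, size_t n) {
    const size_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n) x[i] *= a;
}
"""

dev = Device()
dev.set_current()
s = dev.create_stream()

arch = "".join(f"{i}" for i in dev.compute_capability)
prog = Program(code, code_type="c++", options=ProgramOptions(arch=f"sm_{arch}"))
ker = prog.compile("cubin").get_kernel("scale")

x = torch.ones(1024, device="cuda", dtype=torch.float32)
torch.cuda.synchronize()  # ensure PyTorch's work on x is done before our stream uses it

n = x.numel()
config = LaunchConfig(grid=(n + 255) // 256, block=256)
# Pass the tensor's device pointer and typed scalars to the kernel
launch(s, config, ker, x.data_ptr(), np.float32(2.0), np.uint64(n))
s.sync()

assert torch.allclose(x, torch.full_like(x, 2.0))
```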
13 changes: 13 additions & 0 deletions cuda_core/examples/saxpy.py
@@ -2,6 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

# ################################################################################
#
# This demo illustrates how to use `cuda.core` to compile a templated CUDA kernel
# and launch it using `cupy` arrays as inputs. This is a simple example of a
# templated kernel, where the kernel is instantiated for both `float` and `double`
# data types.
#
# ################################################################################

import sys

import cupy as cp
@@ -32,6 +41,10 @@
arch = "".join(f"{i}" for i in dev.compute_capability)
program_options = ProgramOptions(std="c++11", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)

# Note the use of the `name_expressions` argument to specify the template
# instantiations of the kernel that we will use. For non-templated kernels,
# `name_expressions` will simply contain the names of the kernels.
mod = prog.compile(
"cubin",
logs=sys.stdout,
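The hunk above cuts off mid-call; for a templated kernel like saxpy, the compile call lists each instantiation by name expression. A sketch of the pattern (the `saxpy<float>`/`saxpy<double>` names follow from the demo description above; consult the full example for the exact arguments):

```python
mod = prog.compile(
    "cubin",
    logs=sys.stdout,
    name_expressions=("saxpy<float>", "saxpy<double>"),
)

# Each instantiation is retrieved with the same name expression:
ker_float = mod.get_kernel("saxpy<float>")
ker_double = mod.get_kernel("saxpy<double>")
```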