2 changes: 2 additions & 0 deletions numba_cuda/numba/cuda/cg.py
@@ -23,6 +23,7 @@ def _this_grid(typingctx):
sig = signature(grid_group)

def codegen(context, builder, sig, args):
context.active_code_library.use_cooperative = True
one = context.get_constant(types.int32, 1)
mod = builder.module
return builder.call(
@@ -45,6 +46,7 @@ def _grid_group_sync(typingctx, group):
sig = signature(types.int32, group)

def codegen(context, builder, sig, args):
context.active_code_library.use_cooperative = True
flags = context.get_constant(types.int32, 0)
mod = builder.module
return builder.call(
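For context, a minimal kernel that exercises both lowering paths instrumented above (a sketch, assuming a CC 6.0+ device with cudadevrt available; kernel name, data, and launch configuration are illustrative):

from numba import cuda
import numpy as np

@cuda.jit
def increment(x):
    grid = cuda.cg.this_grid()  # lowers _this_grid, marking the library cooperative
    i = cuda.grid(1)
    if i < x.size:
        x[i] += 1
    grid.sync()  # lowers _grid_group_sync, setting the same flag

x = np.zeros(128, dtype=np.int32)
increment[4, 32](x)
# Every compiled overload is now flagged as requiring cooperative launch:
assert all(kernel.cooperative for kernel in increment.overloads.values())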
8 changes: 8 additions & 0 deletions numba_cuda/numba/cuda/codegen.py
@@ -70,6 +70,8 @@ def __init__(self, codegen, name):
self._setup_functions = []
self._teardown_functions = []

self.use_cooperative = False

@property
def modules(self):
# There are no LLVM IR modules in an ExternalCodeLibrary
@@ -181,6 +183,8 @@ def __init__(
self._nvvm_options = nvvm_options
self._entry_name = entry_name

self.use_cooperative = False

@property
def llvm_strs(self):
if self._llvm_strs is None:
@@ -352,6 +356,7 @@ def add_linking_library(self, library):
self._linking_files.update(library._linking_files)
self._setup_functions.extend(library._setup_functions)
self._teardown_functions.extend(library._teardown_functions)
self.use_cooperative |= library.use_cooperative

def add_linking_file(self, path_or_obj):
if isinstance(path_or_obj, LinkableCode):
@@ -442,6 +447,7 @@ def _reduce_states(self):
nvvm_options=self._nvvm_options,
needs_cudadevrt=self.needs_cudadevrt,
nrt=nrt,
use_cooperative=self.use_cooperative,
)

@classmethod
@@ -458,6 +464,7 @@ def _rebuild(
nvvm_options,
needs_cudadevrt,
nrt,
use_cooperative,
):
"""
Rebuild an instance.
@@ -472,6 +479,7 @@
instance._max_registers = max_registers
instance._nvvm_options = nvvm_options
instance.needs_cudadevrt = needs_cudadevrt
instance.use_cooperative = use_cooperative

instance._finalized = True
if nrt:
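The propagation rule added to add_linking_library is monotone: once any linked library requires cooperative launch, every library that links it inherits the requirement. A standalone sketch of the pattern (plain Python stand-ins, not numba-cuda's real classes):

class Library:
    def __init__(self, name, use_cooperative=False):
        self.name = name
        self.use_cooperative = use_cooperative
        self._linked = []

    def add_linking_library(self, other):
        self._linked.append(other)
        # OR in the flag so the requirement flows up the link graph.
        self.use_cooperative |= other.use_cooperative

ext = Library("barrier_externals", use_cooperative=True)
kernel_lib = Library("kernel")
kernel_lib.add_linking_library(ext)
assert kernel_lib.use_cooperative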
3 changes: 2 additions & 1 deletion numba_cuda/numba/cuda/compiler.py
@@ -797,7 +797,7 @@ def compile_ptx_for_current_device(
)


def declare_device_function(name, restype, argtypes, link):
def declare_device_function(name, restype, argtypes, link, use_cooperative):
from .descriptor import cuda_target

typingctx = cuda_target.typing_context
@@ -816,6 +816,7 @@ def declare_device_function(name, restype, argtypes, link):
lib = ExternalCodeLibrary(f"{name}_externals", targetctx.codegen())
for file in link:
lib.add_linking_file(file)
lib.use_cooperative = use_cooperative

# ExternalFunctionDescriptor provides a lowering implementation for calling
# external functions
7 changes: 6 additions & 1 deletion numba_cuda/numba/cuda/cudadecl.py
@@ -423,7 +423,11 @@ def _genfp16_binary_operator(op):
def _resolve_wrapped_unary(fname):
link = tuple()
decl = declare_device_function(
f"__numba_wrapper_{fname}", types.float16, (types.float16,), link
f"__numba_wrapper_{fname}",
types.float16,
(types.float16,),
link,
use_cooperative=False,
)
return types.Function(decl)

@@ -438,6 +442,7 @@ def _resolve_wrapped_binary(fname):
types.float16,
),
link,
use_cooperative=False,
)
return types.Function(decl)

7 changes: 5 additions & 2 deletions numba_cuda/numba/cuda/decorators.py
@@ -229,7 +229,7 @@ def autojitwrapper(func):
return disp


def declare_device(name, sig, link=None):
def declare_device(name, sig, link=None, use_cooperative=False):
"""
Declare the signature of a foreign function. Returns a descriptor that can
be used to call the function from a Python kernel.
@@ -238,6 +238,7 @@ def declare_device(name, sig, link=None):
:type name: str
:param sig: The Numba signature of the function.
:param link: External code to link when calling the function.
:param use_cooperative: Whether the external code requires a cooperative launch.
"""
if link is None:
link = tuple()
@@ -250,6 +251,8 @@
msg = "Return type must be provided for device declarations"
raise TypeError(msg)

template = declare_device_function(name, restype, argtypes, link)
template = declare_device_function(
name, restype, argtypes, link, use_cooperative
)

return template.key
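A usage sketch for the extended API; the file name, function name, and launch configuration below are illustrative assumptions, not part of this diff. Per Numba's device-function ABI, the corresponding C function would be declared as extern "C" __device__ int grid_barrier(int *ret).

from numba import cuda

# Hypothetical external function; "grid_barrier.cu" is a placeholder path.
grid_barrier = cuda.declare_device(
    "grid_barrier",
    sig="int32()",
    link=["grid_barrier.cu"],
    use_cooperative=True,  # the external code performs a grid-wide sync
)

@cuda.jit
def kernel():
    grid_barrier()

kernel[64, 128]()

# The dispatcher picks the flag up from the linked library:
assert kernel.overloads[()].cooperative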
4 changes: 2 additions & 2 deletions numba_cuda/numba/cuda/dispatcher.py
@@ -151,8 +151,8 @@ def __init__(

asm = lib.get_asm_str()

# A kernel needs cooperative launch if grid_sync is being used.
self.cooperative = "cudaCGGetIntrinsicHandle" in asm
# A kernel needs cooperative launch if its code library contains functions that require it.
self.cooperative = lib.use_cooperative
# We need to link against cudadevrt if grid sync is being used.
if self.cooperative:
lib.needs_cudadevrt = True
12 changes: 0 additions & 12 deletions numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py
@@ -203,18 +203,6 @@ def simple_usecase_kernel(r, x):
simple_usecase_caller = CUDAUseCase(simple_usecase_kernel)


# Usecase with cooperative groups


@cuda.jit(cache=True)
def cg_usecase_kernel(r, x):
grid = cuda.cg.this_grid()
grid.sync()


cg_usecase = CUDAUseCase(cg_usecase_kernel)


class _TestModule(CUDATestCase):
"""
Tests for functionality of this module's functions.
33 changes: 33 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py
@@ -0,0 +1,33 @@
from numba import cuda
from numba.cuda.testing import CUDATestCase
import sys

from numba.cuda.tests.cudapy.cache_usecases import CUDAUseCase


# Usecase with cooperative groups


@cuda.jit(cache=True)
def cg_usecase_kernel(r, x):
grid = cuda.cg.this_grid()
grid.sync()


cg_usecase = CUDAUseCase(cg_usecase_kernel)


class _TestModule(CUDATestCase):
"""
Tests for functionality of this module's functions.
Note this does not define any "test_*" method; instead, check_module()
should be called by hand.
"""

def check_module(self, mod):
mod.cg_usecase(0)


def self_test():
mod = sys.modules[__name__]
_TestModule().check_module(mod)
85 changes: 34 additions & 51 deletions numba_cuda/numba/cuda/tests/cudapy/test_caching.py
@@ -1,8 +1,6 @@
import multiprocessing
import os
import shutil
import subprocess
import sys
import unittest
import warnings

@@ -163,55 +161,6 @@ def test_same_names(self):
f = mod.renamed_function2
self.assertPreciseEqual(f(2), 8)

@skip_unless_cc_60
@skip_if_cudadevrt_missing
@skip_if_mvc_enabled("CG not supported with MVC")
def test_cache_cg(self):
# Functions using cooperative groups should be cacheable. See Issue
# #8888: https://github.com/numba/numba/issues/8888
self.check_pycache(0)
mod = self.import_module()
self.check_pycache(0)

mod.cg_usecase(0)
self.check_pycache(2) # 1 index, 1 data

# Check the code runs ok from another process
self.run_in_separate_process()

@skip_unless_cc_60
@skip_if_cudadevrt_missing
@skip_if_mvc_enabled("CG not supported with MVC")
def test_cache_cg_clean_run(self):
# See Issue #9432: https://github.com/numba/numba/issues/9432
# If a cached function using CG sync was the first thing to compile,
# the compile would fail.
self.check_pycache(0)

# This logic is modelled on run_in_separate_process(), but executes the
# CG usecase directly in the subprocess.
code = """if 1:
import sys

sys.path.insert(0, %(tempdir)r)
mod = __import__(%(modname)r)
mod.cg_usecase(0)
""" % dict(tempdir=self.tempdir, modname=self.modname)

popen = subprocess.Popen(
[sys.executable, "-c", code],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
out, err = popen.communicate(timeout=60)
if popen.returncode != 0:
raise AssertionError(
"process failed with code %s: \n"
"stdout follows\n%s\n"
"stderr follows\n%s\n"
% (popen.returncode, out.decode(), err.decode()),
)

def _test_pycache_fallback(self):
"""
With a disabled __pycache__, test there is a working fallback
@@ -275,6 +224,40 @@ def f():
pass


@skip_on_cudasim("Simulator does not implement caching")
class CUDACooperativeGroupTest(SerialMixin, DispatcherCacheUsecasesTest):
# See Issue #9432: https://github.com/numba/numba/issues/9432
# If a cached function using CG sync was the first thing to compile,
# the compile would fail.
here = os.path.dirname(__file__)
usecases_file = os.path.join(here, "cg_cache_usecases.py")
modname = "cuda_cooperative_caching_test_fodder"

def setUp(self):
DispatcherCacheUsecasesTest.setUp(self)
CUDATestCase.setUp(self)

def tearDown(self):
CUDATestCase.tearDown(self)
DispatcherCacheUsecasesTest.tearDown(self)

@skip_unless_cc_60
@skip_if_cudadevrt_missing
@skip_if_mvc_enabled("CG not supported with MVC")
def test_cache_cg(self):
# Functions using cooperative groups should be cacheable. See Issue
# #8888: https://github.com/numba/numba/issues/8888
self.check_pycache(0)
mod = self.import_module()
self.check_pycache(0)

mod.cg_usecase(0)
self.check_pycache(2) # 1 index, 1 data

# Check the code runs ok from another process
self.run_in_separate_process()


@skip_on_cudasim("Simulator does not implement caching")
class CUDAAndCPUCachingTest(SerialMixin, DispatcherCacheUsecasesTest):
here = os.path.dirname(__file__)
34 changes: 34 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py
@@ -1,8 +1,13 @@
from __future__ import print_function

import os

import cffi

import numpy as np

from numba import config, cuda, int32
from numba.types import CPointer
from numba.cuda.testing import (
unittest,
CUDATestCase,
@@ -11,6 +16,9 @@
skip_if_cudadevrt_missing,
skip_if_mvc_enabled,
)
from numba.core.typing import signature

ffi = cffi.FFI()


@cuda.jit
@@ -149,6 +157,32 @@ def test_max_cooperative_grid_blocks(self):
self.assertEqual(blocks1d, blocks2d)
self.assertEqual(blocks1d, blocks3d)

@skip_unless_cc_60
def test_external_cooperative_func(self):
cudapy_test_path = os.path.dirname(__file__)
tests_path = os.path.dirname(cudapy_test_path)
data_path = os.path.join(tests_path, "data")
src = os.path.join(data_path, "cta_barrier.cu")

sig = signature(
CPointer(int32),
)
cta_barrier = cuda.declare_device(
"cta_barrier", sig=sig, link=[src], use_cooperative=True
)

@cuda.jit
def kernel():
cta_barrier()

block_size = 32
grid_size = 1024

kernel[grid_size, block_size]()

overload = kernel.overloads[()]
self.assertTrue(overload.cooperative)


if __name__ == "__main__":
unittest.main()
23 changes: 23 additions & 0 deletions numba_cuda/numba/cuda/tests/data/cta_barrier.cu
@@ -0,0 +1,23 @@
#include <cooperative_groups.h>
#include <cuda/barrier>

namespace cg = cooperative_groups;

__device__ void _wait_on_tile(cuda::barrier<cuda::thread_scope_block> &tile)
{
auto token = tile.arrive();
tile.wait(std::move(token));
}

extern "C"
__device__ int cta_barrier(int *ret) {
auto cta = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
__shared__ cuda::barrier<cuda::thread_scope_block> barrier;
if (threadIdx.x == 0) {
init(&barrier, blockDim.x);
}

_wait_on_tile(barrier);
return 0;
}