Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions numbast/src/numbast/static/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ def _try_import_numba_type(cls, typ: str):
cls.Imports.add("from numba.cuda.types import bfloat16")
cls._imported_numba_types.add(typ)

elif typ == "__nv_bfloat16_raw":
cls.Imports.add(
"from numba.cuda._internal.cuda_bf16 import _type_unnamed1405307 as bfloat16_raw_type"
)
cls._imported_numba_types.add(typ)

elif typ in vector_types:
# CUDA target specific types
cls.Imports.add("from numba.cuda.vector_types import vector_types")
Expand Down
8 changes: 8 additions & 0 deletions numbast/src/numbast/static/tests/data/bf16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,12 @@ nv_bfloat16 inline __device__ add(nv_bfloat16 a, nv_bfloat16 b) {
return a + b;
}

__nv_bfloat16_raw inline __device__ bf16_to_raw(nv_bfloat16 a) {
return __nv_bfloat16_raw(a);
}

nv_bfloat16 inline __device__ bf16_from_raw(__nv_bfloat16_raw a) {
return __nv_bfloat16(a);
}

#endif
24 changes: 24 additions & 0 deletions numbast/src/numbast/static/tests/test_bf16_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,27 @@ def kernel(arr):

# Check that bfloat16 is imported
assert "from numba.cuda.types import bfloat16" in res1["src"]


def test_bindings_from_bf16raw(make_binding):
res = make_binding("bf16.cuh", {}, {})

binding = res["bindings"]

bf16_from_raw = binding["bf16_from_raw"]
bf16_to_raw = binding["bf16_to_raw"]

@cuda.jit
def kernel(arr):
x = bfloat16(3.14)

x_raw = bf16_to_raw(x)
x2 = bf16_from_raw(x_raw)

arr[0] = float32(x2)

arr = cuda.device_array((1,), dtype="float32")

kernel[1, 1](arr)

assert pytest.approx(arr[0], 1e-2) == 3.14
4 changes: 4 additions & 0 deletions numbast/src/numbast/static/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def to_numba_type_str(ty: str):
BaseRenderer._try_import_numba_type("__nv_bfloat16")
return "bfloat16"

if ty == "__nv_bfloat16_raw":
BaseRenderer._try_import_numba_type("__nv_bfloat16_raw")
return "bfloat16_raw_type"

if ty.endswith("*"):
base_ty = ty.rstrip("*").rstrip(" ")
ptr_ty_str = f"CPointer({to_numba_type_str(base_ty)})"
Expand Down
6 changes: 4 additions & 2 deletions numbast/src/numbast/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from numba.cuda.vector_types import vector_types
from numba.misc.special import typeof

from numba.cuda._internal.cuda_bf16 import _type_unnamed1405307


class FunctorType(nbtypes.Type):
def __init__(self, name):
Expand Down Expand Up @@ -83,8 +85,8 @@ def register_enum_type(cxx_name: str, e: IntEnum):


def to_numba_type(ty: str):
if ty == "__nv_bfloat16":
return bfloat16
if ty == "__nv_bfloat16_raw":
return _type_unnamed1405307

if "FunctorType" in ty:
return FunctorType(ty[:-11])
Expand Down
Loading