Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #480 from senior-zero/disable_bf16
Browse files Browse the repository at this point in the history
Add option to disable BF16 support
  • Loading branch information
gevtushenko authored Jun 4, 2022
2 parents 5334b27 + 0a30673 commit 1b61fc9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 20 deletions.
40 changes: 24 additions & 16 deletions cmake/CubHeaderTesting.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,27 @@ foreach (header IN LISTS headers)
list(APPEND headertest_srcs "${headertest_src}")
endforeach()

# Create one header-test object library for every entry in CUB_TARGETS.
foreach(cub_target IN LISTS CUB_TARGETS)
# NOTE(review): presumably stores the target's PREFIX property in
# config_prefix; the helper is defined outside this view -- confirm.
cub_get_target_property(config_prefix ${cub_target} PREFIX)

# The header-test target is named "<config_prefix>.headers".
set(headertest_target ${config_prefix}.headers)
# OBJECT library: the sources listed in headertest_srcs are only compiled.
add_library(${headertest_target} OBJECT ${headertest_srcs})
target_link_libraries(${headertest_target} PUBLIC ${cub_target})
# Wrap Thrust/CUB in a custom namespace to check proper use of ns macros:
target_compile_definitions(${headertest_target} PRIVATE
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub"
)
# NOTE(review): presumably copies the configuration target's custom
# properties onto the header-test target; helper not visible here -- confirm.
cub_clone_target_properties(${headertest_target} ${cub_target})

# Building either cub.all.headers or "<config_prefix>.all" also builds this
# header-test target.
add_dependencies(cub.all.headers ${headertest_target})
add_dependencies(${config_prefix}.all ${headertest_target})
endforeach()
# cub_add_header_test(<label> <definitions> [<definition>...])
#
# For every target in CUB_TARGETS, adds an OBJECT library named
# "<config_prefix>.headers.<label>" that compiles the sources listed in
# headertest_srcs with the given PRIVATE compile definitions, and makes both
# cub.all.headers and "<config_prefix>.all" depend on it.
#
# Arguments:
#   label       - Suffix appended to the header-test target name.
#   definitions - Compile definitions, normally passed as one quoted list
#                 ("A=1;B").
#   ARGN        - Any further arguments are treated as additional definitions,
#                 so a list that was accidentally passed unquoted is applied
#                 in full instead of being truncated to its first element.
#
# Note: CUB_TARGETS and headertest_srcs are not parameters; they are read from
# the calling scope and must be set before this function is called.
function(cub_add_header_test label definitions)
  # Merge the positional list with any trailing arguments.
  set(all_definitions "${definitions}")
  list(APPEND all_definitions ${ARGN})

  foreach(cub_target IN LISTS CUB_TARGETS)
    # NOTE(review): presumably stores the target's PREFIX property in
    # config_prefix; the helper is defined outside this view -- confirm.
    cub_get_target_property(config_prefix ${cub_target} PREFIX)

    set(headertest_target ${config_prefix}.headers.${label})
    # OBJECT library: the header-test sources are only compiled.
    add_library(${headertest_target} OBJECT ${headertest_srcs})
    target_link_libraries(${headertest_target} PUBLIC ${cub_target})
    target_compile_definitions(${headertest_target} PRIVATE ${all_definitions})
    cub_clone_target_properties(${headertest_target} ${cub_target})

    # Hook the new target into the aggregate build targets.
    add_dependencies(cub.all.headers ${headertest_target})
    add_dependencies(${config_prefix}.all ${headertest_target})
  endforeach()
endfunction()

# Wrap Thrust/CUB in a custom namespace to check proper use of ns macros:
set(header_definitions
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub")
# Baseline header tests; targets are named "<config_prefix>.headers.base".
# The list is quoted so it reaches the function as its single `definitions`
# argument.
cub_add_header_test(base "${header_definitions}")

# Second set of header tests. Despite the "bf16" label, this is the
# configuration compiled with CUB_DISABLE_BF16_SUPPORT defined. The list is
# extended in place, so the namespace definitions above still apply here.
list(APPEND header_definitions "CUB_DISABLE_BF16_SUPPORT")
cub_add_header_test(bf16 "${header_definitions}")

6 changes: 6 additions & 0 deletions cmake/header_test.in
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@
#define B0 CUB_MACRO_CHECK("B0", termios.h)

#include <cub/${header}>

// When the header test is built with CUB_DISABLE_BF16_SUPPORT, fail the
// compile if __CUDA_BF16_TYPES_EXIST__ is defined after including the CUB
// header. The macro is used here as the marker that cuda_bf16.h was included.
#if defined(CUB_DISABLE_BF16_SUPPORT) && defined(__CUDA_BF16_TYPES_EXIST__)
#error CUB should not include cuda_bf16.h when BF16 support is disabled
#endif
11 changes: 7 additions & 4 deletions cub/util_type.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !_NVHPC_CUDA
#include <cuda_fp16.h>
#endif
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !_NVHPC_CUDA
#include <cuda_bf16.h>
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !_NVHPC_CUDA && \
!defined(CUB_DISABLE_BF16_SUPPORT)
#include <cuda_bf16.h>
#endif

#include <cub/util_arch.cuh>
Expand Down Expand Up @@ -1105,7 +1106,8 @@ struct FpLimits<__half>
};
#endif

#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !_NVHPC_CUDA
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !_NVHPC_CUDA && \
!defined(CUB_DISABLE_BF16_SUPPORT)
template <>
struct FpLimits<__nv_bfloat16>
{
Expand Down Expand Up @@ -1188,7 +1190,8 @@ template <> struct NumericTraits<double> : BaseTraits<FLOATING_POIN
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !_NVHPC_CUDA
template <> struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
#endif
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !_NVHPC_CUDA
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !_NVHPC_CUDA && \
!defined(CUB_DISABLE_BF16_SUPPORT)
template <> struct NumericTraits<__nv_bfloat16> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __nv_bfloat16> {};
#endif

Expand Down

0 comments on commit 1b61fc9

Please sign in to comment.