diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 5c9e47402408..9fcc88ef41bd 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -1,7 +1,8 @@ #include #include #include -#include + +#include "../cub_helpers.h" #include #include diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 108091efbefa..69ac131bfa39 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -5,9 +5,14 @@ #include #include #include "dispatch.h" -#include -#include -#include + +#include "../cub_helpers.h" +#ifndef USE_ROCM + #include + #include +#else +// hipcub includes these via hipcub.hpp +#endif #include "cutlass/numeric_size.h" #include "cutlass/array.h" diff --git a/csrc/sampler.cu b/csrc/sampler.cu index d458f8e4c1d0..a2bd42813723 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -3,11 +3,7 @@ #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif +#include "cub_helpers.h" namespace vllm {