16 changes: 9 additions & 7 deletions docs/ContribOperators.md
@@ -3129,8 +3129,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float or float16 tensors.</dd>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


@@ -4543,19 +4543,19 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).</dd>
<dt><tt>fc1_scales</tt> : T</dt>
<dt><tt>fc1_scales</tt> : T2</dt>
<dd>2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
<dt><tt>fc2_scales</tt> : T</dt>
<dt><tt>fc2_scales</tt> : T2</dt>
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc3_scales</tt> (optional) : T</dt>
<dt><tt>fc3_scales</tt> (optional) : T2</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
@@ -4571,10 +4571,12 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output types to float or float16 tensors.</dd>
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>T1</tt> : tensor(uint8)</dt>
<dd>Constrain weights type to uint8 tensors.</dd>
<dt><tt>T2</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain scales type to float tensors.</dd>
</dl>
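
For the 4-bit weight layouts above, the trailing dimension is halved because each uint8 element packs two 4-bit values that are later expanded and rescaled with the per-column scales. A minimal sketch of that unpacking, assuming low-nibble-first ordering and an implicit zero-point of 8 (illustrative assumptions only, not a specification of this operator's exact quantization scheme):

```cpp
// Sketch: expanding one packed uint8 that holds two 4-bit quantized weights.
// Assumptions (for illustration only): low nibble = even column, high nibble =
// odd column, implicit zero-point of 8, one scale per output column.
#include <cstdint>

inline void dequantize_uint4_pair(uint8_t packed, float scale0, float scale1, float out[2]) {
  const int lo = packed & 0x0F;         // even column
  const int hi = (packed >> 4) & 0x0F;  // odd column
  out[0] = static_cast<float>(lo - 8) * scale0;
  out[1] = static_cast<float>(hi - 8) * scale1;
}
```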


4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -949,15 +949,15 @@ Do not modify directly.*
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PagedAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* cumulative_sequence_length:**S**<br> *in* past_seqlens:**S**<br> *in* block_table:**S**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* key_cache_out:**T**<br> *out* value_cache_out:**T**|1+|**S** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* attention_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -92,7 +92,9 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE);
class CUDA_MS_OP_CLASS_NAME(1, QMoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention);
@@ -307,7 +309,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention)>,
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4200)
#endif

#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"

#if defined(_MSC_VER)
#pragma warning(pop)
#endif

namespace ort_fastertransformer {
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>;
} // namespace ort_fastertransformer
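
Each of the new translation units (bf16×bf16 here, bf16×uint4 and bf16×uint8 below) compiles exactly one MoeGemmRunner specialization, which keeps the heavy CUTLASS instantiations isolated and lets them build in parallel. A generic, self-contained illustration of this explicit-instantiation pattern (the names are invented for the example, not taken from the repository):

```cpp
// Generic explicit-instantiation sketch: the template body lives in a header,
// and each .cu emits object code for one specialization only.
#include <cuda_bf16.h>
#include <cstdint>

template <typename T, typename WeightType>
struct ExampleGemmRunner {
  void run(const T* /*a*/, const WeightType* /*b*/, T* /*c*/, int /*n*/) {
    // stand-in for an expensive templated kernel launcher
  }
};

// One such line per translation unit forces the compiler to emit this specialization here.
template struct ExampleGemmRunner<__nv_bfloat16, uint8_t>;
```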
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4200)
#endif

#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"

#if defined(_MSC_VER)
#pragma warning(pop)
#endif

namespace ort_fastertransformer {
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
} // namespace ort_fastertransformer
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4200)
#endif

#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"

#if defined(_MSC_VER)
#pragma warning(pop)
#endif

namespace ort_fastertransformer {
template class MoeGemmRunner<__nv_bfloat16, uint8_t>;
} // namespace ort_fastertransformer
@@ -53,6 +53,8 @@
#include "cutlass_heuristic.h"
#include "moe_gemm_kernels.h"

#include <cuda_bf16.h>

#include <limits>
#include <math.h>
#include <sstream>
@@ -66,8 +68,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
CutlassGemmConfig gemm_config, const int multi_processor_count,
cudaStream_t stream, int* kernel_occupancy = nullptr) {
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
"Specialized for half, float");
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || cutlass::platform::is_same<T, __nv_bfloat16>::value,
"Specialized for half, float, bfloat16");

static_assert(cutlass::platform::is_same<T, WeightType>::value ||
cutlass::platform::is_same<WeightType, uint8_t>::value ||
@@ -76,12 +78,11 @@

// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType_ =
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, typename cutlass::platform::conditional<cutlass::platform::is_same<T, __nv_bfloat16>::value, cutlass::bfloat16_t, T>::type>::type;
using ElementType = ElementType_;

using CutlassWeightType_ =
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t,
WeightType>::type;
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t, typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, __nv_bfloat16>::value, cutlass::bfloat16_t, WeightType>::type>::type;

using CutlassWeightType = CutlassWeightType_;
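
The nested conditionals above are just a two-entry type map (half maps to cutlass::half_t, __nv_bfloat16 maps to cutlass::bfloat16_t, everything else is left unchanged). A standalone trait expressing the same mapping, shown only to make the intent easier to read (the trait name is illustrative, not part of this change):

```cpp
#include <cuda_fp16.h>            // half
#include <cuda_bf16.h>            // __nv_bfloat16
#include "cutlass/half.h"         // cutlass::half_t
#include "cutlass/bfloat16.h"     // cutlass::bfloat16_t

// Identity by default; remap the two CUDA scalar types to their CUTLASS equivalents.
template <typename T> struct ToCutlassType { using type = T; };
template <> struct ToCutlassType<half> { using type = cutlass::half_t; };
template <> struct ToCutlassType<__nv_bfloat16> { using type = cutlass::bfloat16_t; };

// With such a trait, the two aliases above would read:
//   using ElementType       = typename ToCutlassType<T>::type;
//   using CutlassWeightType = typename ToCutlassType<WeightType>::type;
```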

56 changes: 49 additions & 7 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -38,10 +38,13 @@

#include "moe_kernel.h"

#include <cuda_runtime_api.h>
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>

#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"

namespace ort_fastertransformer {
static constexpr int WARP_SIZE = 32;

@@ -103,11 +106,16 @@ void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows
dim3 block(std::min(intermediate_size, 1024));
dim3 grid(num_rows);

DUMP_TENSOR_INIT();
DUMP_TENSOR("swiglu input", input, num_rows, 2 * intermediate_size);

if constexpr (interleaved) {
swiglu_kernel_interleaved<T><<<grid, block, 0, stream>>>(output, input, intermediate_size, num_rows, swiglu_alpha);
} else {
swiglu_kernel_chunked<T><<<grid, block, 0, stream>>>(output, input, intermediate_size, num_rows, swiglu_alpha);
}

DUMP_TENSOR("swiglu output", output, num_rows, intermediate_size);
}
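
For readers unfamiliar with the activation, the two launch paths above differ only in how the gate/linear pair is laid out within each input row. A scalar reference, assuming the usual gated-SiLU form and the layouts implied by the kernel names (both are assumptions; the exact formula, clamping, and layout are defined by the CUDA kernels themselves):

```cpp
// Scalar SwiGLU reference (sketch; not a drop-in for the kernels above).
#include <cmath>

inline float swiglu_ref(float gate, float linear, float alpha) {
  const float sig = 1.0f / (1.0f + std::exp(-alpha * gate));  // sigmoid(alpha * gate)
  return gate * sig * linear;
}

// Assumed layouts for one input row of 2 * intermediate_size values:
//   interleaved: gate = in[2*i],  linear = in[2*i + 1]
//   chunked:     gate = in[i],    linear = in[i + intermediate_size]
```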

// ====================== Softmax things ===============================
@@ -838,11 +846,15 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::configure_ws_ptrs(char* ws_ptr,
}

namespace {

struct __align__(8) Half4 {
typedef struct __CUDA_ALIGN__(8) {
half2 x;
half2 y;
};
} half2_2;

typedef struct __CUDA_ALIGN__(8) {
__nv_bfloat162 x;
__nv_bfloat162 y;
} __nv_bfloat162_2;

// TODO(wy): move to common header
template <typename T>
@@ -853,7 +865,11 @@ struct T4<float> {
};
template <>
struct T4<half> {
using Type = Half4;
using Type = half2_2;
};
template <>
struct T4<__nv_bfloat16> {
using Type = __nv_bfloat162_2;
};

template <typename T>
Expand All @@ -866,6 +882,10 @@ template <>
struct T2<half> {
using Type = half2;
};
template <>
struct T2<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
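
The T2/T4 traits let the elementwise parts of the routing kernels pick a 2- or 4-wide vector type from the scalar type, so loads and multiplies go through half2/__nv_bfloat162 (or the paired variants) instead of scalars. A minimal usage sketch, assuming an elementwise operator* exists for the chosen vector type (the real kernels choose widths and indexing differently):

```cpp
// Sketch: vectorized elementwise multiply driven by a T4-style trait.
template <typename T>
__global__ void scale_elements(const T* input, const T* scale, T* output, int num_vec) {
  using Vec = typename T4<T>::Type;  // e.g. half2_2 or __nv_bfloat162_2
  const Vec* in = reinterpret_cast<const Vec*>(input);
  const Vec* sc = reinterpret_cast<const Vec*>(scale);
  Vec* out = reinterpret_cast<Vec*>(output);
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_vec; i += gridDim.x * blockDim.x) {
    out[i] = in[i] * sc[i];  // relies on the operator* overloads defined in this file
  }
}
```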

inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); }

@@ -882,15 +902,27 @@ inline __device__ half2 operator*(const half2 a, const half2 b) { return make_ha
#endif

// TODO(wy): use cuda common header and investigate pipeline build issue.
inline __device__ Half4 operator*(const Half4 a, const Half4 b) {
inline __device__ half2_2 operator*(const half2_2 a, const half2_2 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
Half4 result;
half2_2 result;
result.x = a.x * b.x;
result.y = a.y * b.y;
return result;
#else
return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
return half2_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
#endif
}

inline __device__ __nv_bfloat162_2 operator*(const __nv_bfloat162_2 a, const __nv_bfloat162_2 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && \
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
__nv_bfloat162_2 result;
result.x = a.x * b.x;
result.y = a.y * b.y;
return result;
#else
return __nv_bfloat162_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
#endif
}
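
As with the half2_2 overload above, the native __hmul2 path is only taken where packed bfloat16 math is available (SM80-class hardware or a new enough toolkit); otherwise the code falls back to per-half2 operator*. One way such an emulated bfloat16 multiply can be written, shown purely for illustration using documented cuda_bf16.h conversion helpers (this is not the code path added here):

```cpp
#include <cuda_bf16.h>

// Illustrative float-emulated multiply of two __nv_bfloat162 values.
__device__ inline __nv_bfloat162 mul_bf162_via_float(__nv_bfloat162 a, __nv_bfloat162 b) {
  const float2 fa = __bfloat1622float2(a);
  const float2 fb = __bfloat1622float2(b);
  return __floats2bfloat162_rn(fa.x * fb.x, fa.y * fb.y);
}
```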

@@ -1291,18 +1323,26 @@ template void topk_gating_softmax_kernelLauncher(const float*, const bool*, floa
int, bool, bool, cudaStream_t);
template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int,
int, bool, bool, cudaStream_t);
template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, const bool*, __nv_bfloat16*, __nv_bfloat16*, int*, int*, int, int,
int, bool, bool, cudaStream_t);

// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner<float, float>;
template class CutlassMoeFCRunner<half, half>;
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>;
// For qMoE:
template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
template class CutlassMoeFCRunner<half, uint8_t>;
template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>;
template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>;

// ===================== Specializations for init routing =========================
template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int,
cudaStream_t);
template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int,
cudaStream_t);
template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const int*, int*, int, int, int, int,
cudaStream_t);

// ==================== Specializations for final routing ===================================
template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*,
@@ -1317,6 +1357,8 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl
const float*, const int*, const int*, int, int, int, cudaStream_t);
template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*,
const half*, const int*, const int*, int, int, int, cudaStream_t);
template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*,
const __nv_bfloat16*, const int*, const int*, int, int, int, cudaStream_t);

template void invokeSwiGLU<float, true>(float*, float const*, int, int, float, cudaStream_t);
template void invokeSwiGLU<half, true>(half*, half const*, int, int, float, cudaStream_t);
10 changes: 0 additions & 10 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -178,14 +178,4 @@ class CutlassMoeFCRunner {
std::vector<int64_t> total_rows_before_expert_host_;
};

template <typename WeightType>
class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_same<float, WeightType>::value>> {
public:
CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);

size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
return 0;
}
};

} // namespace ort_fastertransformer