Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 26 additions & 63 deletions csrc/transformer/inference/csrc/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -432,71 +432,34 @@ template void launch_gptj_residual_add<__half>(__half*,
int,
int,
cudaStream_t);

// NOTE(review): this span reads like a diff rendering whose +/- markers were stripped, so
// lines from more than one version of moe_res_matmul are interleaved and the text is not
// compilable as written. The "old"/"new" labels below are inferred from the surrounding
// hunk text -- confirm against the repository.
//
// Every variant visible here computes, element by element and in place in mlp_out:
//     mlp_out = mlp_out * coef2 + residual * coef1
// where coef1 is read from `coef` at the element's offset and coef2 from the same offset
// plus hidden_dim. blockIdx.x selects the span of hidden_dim elements that a block works
// on; seq_len is not read by any of the kernel bodies shown here.

// Signature of the float-only kernel (appears to be the removed version).
__global__ void moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim)
// Signature of the templated kernel (appears to be the added version).
template <typename T>
__global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim)
{
// --- float body: every pointer is reinterpreted as float4, so all offsets and the loop
// bound below are counted in float4 units (4 floats each). ---
unsigned tid = threadIdx.x;
float4* residual_cast = reinterpret_cast<float4*>(residual);
float4* coef_cast = reinterpret_cast<float4*>(coef);
float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);

// Advance to the span owned by this block.
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;

// The second coefficient vector starts hidden_dim float4s after the first.
float4* coef_cast2 = coef_cast + hidden_dim;

// Each thread starts at its threadIdx.x and advances by blockDim.x until the span is covered.
while (tid < hidden_dim) {
float4 res = residual_cast[tid];
float4 mlp = mlp_out_cast[tid];
float4 coef1 = coef_cast[tid];
float4 coef2 = coef_cast2[tid];
mlp.x = mlp.x * coef2.x + res.x * coef1.x;
mlp.y = mlp.y * coef2.y + res.y * coef1.y;
mlp.z = mlp.z * coef2.z + res.z * coef1.z;
mlp.w = mlp.w * coef2.w + res.w * coef1.w;
mlp_out_cast[tid] = mlp;
tid += blockDim.x;
}
}
// --- templated body (first part): a 16-byte access granularity gives
// 16 / sizeof(T) elements per load or store. ---
constexpr int granularity = 16;
constexpr int vals_per_access = granularity / sizeof(T);

// Signature and body of the __half kernel (appears to be the removed version).
__global__ void moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim)
{
// The body below is compiled only when HALF_PRECISION_AVAILABLE is defined.
#ifdef HALF_PRECISION_AVAILABLE
unsigned tid = threadIdx.x;
// float2 is used purely as an 8-byte container, i.e. four __half values per access,
// so offsets and the loop bound are counted in float2 units.
float2* residual_cast = reinterpret_cast<float2*>(residual);
float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
float2* coef_cast = reinterpret_cast<float2*>(coef);
float2* coef_cast2 = coef_cast + hidden_dim;

// Advance to the span owned by this block.
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;

while (tid < hidden_dim) {
float2 res = residual_cast[tid];
float2 coef1 = coef_cast[tid];
float2 coef2 = coef_cast2[tid];
float2 data = mlp_out_cast[tid];
// View each 8-byte register value as four __half lanes for the arithmetic.
__half* data_h = reinterpret_cast<__half*>(&data);
__half* coef1_h = reinterpret_cast<__half*>(&coef1);
__half* coef2_h = reinterpret_cast<__half*>(&coef2);
__half* res_h = reinterpret_cast<__half*>(&res);
data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];

mlp_out_cast[tid] = data;
tid += blockDim.x;
// --- templated body (continued): offsets here are counted in elements of T,
// not in vector units. ---
T* residual_seq = residual + blockIdx.x * hidden_dim;
T* mlp_out_seq = mlp_out + blockIdx.x * hidden_dim;

// Each thread handles vals_per_access consecutive elements per iteration and then
// advances by blockDim.x * vals_per_access.
// NOTE(review): there is no tail handling; this looks like it assumes hidden_dim is a
// multiple of vals_per_access and that the pointers are 16-byte aligned -- confirm.
for (unsigned tid = threadIdx.x * vals_per_access; tid < hidden_dim;
tid += blockDim.x * vals_per_access) {
T mlp[vals_per_access];
T res[vals_per_access];
T coef1[vals_per_access];
T coef2[vals_per_access];

// mem_access::load_global / store_global are defined outside this view; presumably they
// move `granularity` bytes between the local array and the given address -- verify.
mem_access::load_global<granularity>(mlp, mlp_out_seq + tid);
mem_access::load_global<granularity>(res, residual_seq + tid);
mem_access::load_global<granularity>(coef1, coef + tid);
// coef2 is read hidden_dim elements past coef1.
mem_access::load_global<granularity>(coef2, coef + tid + hidden_dim);

#pragma unroll
for (int idx = 0; idx < vals_per_access; idx++) {
mlp[idx] = mlp[idx] * coef2[idx] + res[idx] * coef1[idx];
}

// Write the blended values back over the mlp_out elements they were loaded from.
mem_access::store_global<granularity>(mlp_out_seq + tid, mlp);
}
// Closes the HALF_PRECISION_AVAILABLE guard opened in the __half body above.
#endif
}

template <typename T>
Expand All @@ -510,7 +473,7 @@ void launch_moe_res_matmul(T* residual,
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
moe_res_matmul<<<grid_dim, block_dim, 0, stream>>>(
residual, coef, mlp_out, seq_len, hidden_dim / 4);
residual, coef, mlp_out, seq_len, hidden_dim);
}

template void launch_moe_res_matmul(float* residual,
Expand Down