diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 47707c7c64a5..3f9bf4ca740a 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -432,71 +432,34 @@ template void launch_gptj_residual_add<__half>(__half*, int, int, cudaStream_t); - -__global__ void moe_res_matmul(float* residual, - float* coef, - float* mlp_out, - int seq_len, - int hidden_dim) +template +__global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim) { - unsigned tid = threadIdx.x; - float4* residual_cast = reinterpret_cast(residual); - float4* coef_cast = reinterpret_cast(coef); - float4* mlp_out_cast = reinterpret_cast(mlp_out); - - residual_cast += blockIdx.x * hidden_dim; - mlp_out_cast += blockIdx.x * hidden_dim; - - float4* coef_cast2 = coef_cast + hidden_dim; - - while (tid < hidden_dim) { - float4 res = residual_cast[tid]; - float4 mlp = mlp_out_cast[tid]; - float4 coef1 = coef_cast[tid]; - float4 coef2 = coef_cast2[tid]; - mlp.x = mlp.x * coef2.x + res.x * coef1.x; - mlp.y = mlp.y * coef2.y + res.y * coef1.y; - mlp.z = mlp.z * coef2.z + res.z * coef1.z; - mlp.w = mlp.w * coef2.w + res.w * coef1.w; - mlp_out_cast[tid] = mlp; - tid += blockDim.x; - } -} + constexpr int granularity = 16; + constexpr int vals_per_access = granularity / sizeof(T); -__global__ void moe_res_matmul(__half* residual, - __half* coef, - __half* mlp_out, - int seq_len, - int hidden_dim) -{ -#ifdef HALF_PRECISION_AVAILABLE - unsigned tid = threadIdx.x; - float2* residual_cast = reinterpret_cast(residual); - float2* mlp_out_cast = reinterpret_cast(mlp_out); - float2* coef_cast = reinterpret_cast(coef); - float2* coef_cast2 = coef_cast + hidden_dim; - - residual_cast += blockIdx.x * hidden_dim; - mlp_out_cast += blockIdx.x * hidden_dim; - - while (tid < hidden_dim) { - float2 res = residual_cast[tid]; - float2 coef1 = coef_cast[tid]; - float2 coef2 = coef_cast2[tid]; - float2 data = mlp_out_cast[tid]; - __half* data_h = reinterpret_cast<__half*>(&data); - __half* coef1_h = reinterpret_cast<__half*>(&coef1); - __half* coef2_h = reinterpret_cast<__half*>(&coef2); - __half* res_h = reinterpret_cast<__half*>(&res); - data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0]; - data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1]; - data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2]; - data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3]; - - mlp_out_cast[tid] = data; - tid += blockDim.x; + T* residual_seq = residual + blockIdx.x * hidden_dim; + T* mlp_out_seq = mlp_out + blockIdx.x * hidden_dim; + + for (unsigned tid = threadIdx.x * vals_per_access; tid < hidden_dim; + tid += blockDim.x * vals_per_access) { + T mlp[vals_per_access]; + T res[vals_per_access]; + T coef1[vals_per_access]; + T coef2[vals_per_access]; + + mem_access::load_global(mlp, mlp_out_seq + tid); + mem_access::load_global(res, residual_seq + tid); + mem_access::load_global(coef1, coef + tid); + mem_access::load_global(coef2, coef + tid + hidden_dim); + +#pragma unroll + for (int idx = 0; idx < vals_per_access; idx++) { + mlp[idx] = mlp[idx] * coef2[idx] + res[idx] * coef1[idx]; + } + + mem_access::store_global(mlp_out_seq + tid, mlp); } -#endif } template @@ -510,7 +473,7 @@ void launch_moe_res_matmul(T* residual, dim3 grid_dim(seq_len); dim3 block_dim(1024); moe_res_matmul<<>>( - residual, coef, mlp_out, seq_len, hidden_dim / 4); + residual, coef, mlp_out, seq_len, hidden_dim); } template void launch_moe_res_matmul(float* residual,