diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index b3bcd30f5718..bff9e424efd9 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -262,114 +262,115 @@ template void launch_bias_residual<
 template void launch_bias_residual<
     __half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t);
 
-__global__ void gptj_residual_add(float* input,
-                                  float* output,
-                                  float* attn,
-                                  float* bias,
-                                  float* attnbias,
-                                  int total_count,
-                                  int intermediate_size,
-                                  float mp_scale)
+__global__ void gptj_residual_add(float* residual,
+                                  const float* hidden_state,
+                                  const float* attn,
+                                  const float* bias,
+                                  const float* attn_bias,
+                                  const int total_count,
+                                  const int intermediate_size,
+                                  const float mp_scale)
 {
-    float4* input_cast = reinterpret_cast<float4*>(input);
-    float4* output_cast = reinterpret_cast<float4*>(output);
-    float4* attn_cast = reinterpret_cast<float4*>(attn);
-    float4* bias_cast = reinterpret_cast<float4*>(bias);
-    float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
+    const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
+    const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
+    const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
+    const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
+    const int offset = blockIdx.x * blockDim.x + threadIdx.x;
 
     if (offset < total_count) {
-        float4 data = input_cast[offset];
-        float4 out = output_cast[offset];
-        float4 res_vec = attn_cast[offset];
-        float4 bias_data = bias_cast[offset % intermediate_size];
+        float4 res_fl4 = res_fl4_ptr[offset];
+        const float4 hs_fl4 = hs_fl4_ptr[offset];
+        const float4 attn_fl4 = attn_fl4_ptr[offset];
+        const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];
 
-        if (attnbias) {
-            float4 attn_bias = attnbias_cast[offset % intermediate_size];
-            data.x += attn_bias.x;
-            data.y += attn_bias.y;
-            
data.z += attn_bias.z;
-            data.w += attn_bias.w;
+        if (attn_bias) {
+            float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];
+            // residual += attention_bias
+            res_fl4.x += attn_bias_fl4.x;
+            res_fl4.y += attn_bias_fl4.y;
+            res_fl4.z += attn_bias_fl4.z;
+            res_fl4.w += attn_bias_fl4.w;
         }
 
-        data.x = out.x + res_vec.x + (data.x + bias_data.x) * mp_scale;
-        data.y = out.y + res_vec.y + (data.y + bias_data.y) * mp_scale;
-        data.z = out.z + res_vec.z + (data.z + bias_data.z) * mp_scale;
-        data.w = out.w + res_vec.w + (data.w + bias_data.w) * mp_scale;
+        // residual = hidden_state + attention + (residual + bias) * mp_scale
+        res_fl4.x = hs_fl4.x + attn_fl4.x + (res_fl4.x + bias_fl4.x) * mp_scale;
+        res_fl4.y = hs_fl4.y + attn_fl4.y + (res_fl4.y + bias_fl4.y) * mp_scale;
+        res_fl4.z = hs_fl4.z + attn_fl4.z + (res_fl4.z + bias_fl4.z) * mp_scale;
+        res_fl4.w = hs_fl4.w + attn_fl4.w + (res_fl4.w + bias_fl4.w) * mp_scale;
 
-        input_cast[offset] = data;
+        res_fl4_ptr[offset] = res_fl4;
     }
 }
 
-__global__ void gptj_residual_add(__half* input,
-                                  __half* output,
-                                  __half* attn,
-                                  __half* bias,
-                                  __half* attn_bias,
-                                  int total_count,
-                                  int intermediate_size,
-                                  float mp_scale)
+__global__ void gptj_residual_add(__half* residual,
+                                  const __half* hidden_state,
+                                  const __half* attn,
+                                  const __half* bias,
+                                  const __half* attn_bias,
+                                  const int total_count,
+                                  const int intermediate_size,
+                                  const float mp_scale)
 {
 #ifdef HALF_PRECISION_AVAILABLE
-    float2* input_cast = reinterpret_cast<float2*>(input);
-    float2* output_cast = reinterpret_cast<float2*>(output);
-    float2* attn_cast = reinterpret_cast<float2*>(attn);
-
-    float2* bias_cast = reinterpret_cast<float2*>(bias);
-    float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
-
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
+    const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
+    const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
+    const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
+    const float2* 
attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
+    const int offset = blockIdx.x * blockDim.x + threadIdx.x;
 
     if (offset < total_count) {
-        float2 vals_vec = input_cast[offset];
-        float2 out_vec = output_cast[offset];
-        float2 res_vec = attn_cast[offset];
+        float2 res_fl2 = res_fl2_ptr[offset];
+        const float2 hs_fl2 = hs_fl2_ptr[offset];
+        const float2 attn_fl2 = attn_fl2_ptr[offset];
+        const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
 
-        float2 bias_vec = bias_cast[offset % intermediate_size];
+        __half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
+        const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
+        const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
+        const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
 
-        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
-        __half2* out_half = reinterpret_cast<__half2*>(&out_vec);
-        __half2* res_half = reinterpret_cast<__half2*>(&res_vec);
-        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
+        float2 res_low = __half22float2(res_half2[0]);
+        float2 res_high = __half22float2(res_half2[1]);
 
-        float2 low_data = __half22float2(vals_half[0]);
-        float2 high_data = __half22float2(vals_half[1]);
+        const float2 hs_low = __half22float2(hs_half2[0]);
+        const float2 hs_high = __half22float2(hs_half2[1]);
 
-        float2 low_out = __half22float2(out_half[0]);
-        float2 high_out = __half22float2(out_half[1]);
+        const float2 attn_low = __half22float2(attn_half2[0]);
+        const float2 attn_high = __half22float2(attn_half2[1]);
 
-        float2 low_res = __half22float2(res_half[0]);
-        float2 high_res = __half22float2(res_half[1]);
+        const float2 bias_low = __half22float2(bias_half2[0]);
+        const float2 bias_high = __half22float2(bias_half2[1]);
 
-        float2 low_bias = __half22float2(bias_half[0]);
-        float2 high_bias = __half22float2(bias_half[1]);
-
         if (attn_bias) {
-            float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
-            __half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
-            float2 
attn_low_bias = __half22float2(attnbias_half[0]);
-            float2 attn_high_bias = __half22float2(attnbias_half[1]);
-            low_data.x += attn_low_bias.x;
-            low_data.y += attn_low_bias.y;
-            high_data.x += attn_high_bias.x;
-            high_data.y += attn_high_bias.y;
+            const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
+            const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);
+            const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
+            const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);
+            // residual += attention_bias
+            res_low.x += attn_bias_low.x;
+            res_low.y += attn_bias_low.y;
+            res_high.x += attn_bias_high.x;
+            res_high.y += attn_bias_high.y;
         }
+        // residual = hidden_state + attention + (residual + bias) * mp_scale
+        res_low.x = attn_low.x + hs_low.x + (res_low.x + bias_low.x) * mp_scale;
+        res_low.y = attn_low.y + hs_low.y + (res_low.y + bias_low.y) * mp_scale;
+        res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale;
+        res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale;
 
-        low_data.x = low_res.x + low_out.x + (low_data.x + low_bias.x) * mp_scale;
-        low_data.y = low_res.y + low_out.y + (low_data.y + low_bias.y) * mp_scale;
-        high_data.x = high_res.x + high_out.x + (high_data.x + high_bias.x) * mp_scale;
-        high_data.y = high_res.y + high_out.y + (high_data.y + high_bias.y) * mp_scale;
+        res_half2[0] = __float22half2_rn(res_low);
+        res_half2[1] = __float22half2_rn(res_high);
 
-        vals_half[0] = __float22half2_rn(low_data);
-        vals_half[1] = __float22half2_rn(high_data);
-
-        input_cast[offset] = vals_vec;
+        res_fl2_ptr[offset] = res_fl2;
     }
 #endif
 }
 
 template <typename T>
-void launch_gptj_residual_add(T* input,
-                              T* output,
+void launch_gptj_residual_add(T* residual,
+                              T* hidden_state,
                               T* attn,
                               T* bias,
                               T* attn_bias,
@@ -383,7 +384,7 @@ void launch_gptj_residual_add(T* input,
     dim3 grid_dims((total_count - 1) / 1024 + 1);  // (batch_size);
 
     gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
-        input, output, attn, bias, 
attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
+        residual, hidden_state, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
 }
 
 template void launch_gptj_residual_add<float>(float*,