Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 81 additions & 80 deletions csrc/transformer/inference/csrc/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -262,114 +262,115 @@ template void launch_bias_residual<
template void launch_bias_residual<
__half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t);

/*
 * Fused GPT-J style residual add (float specialization).
 *
 * Computes, element-wise over float4-vectorized data:
 *     residual = hidden_state + attn + (residual [+ attn_bias] + bias) * mp_scale
 * and writes the result back into `residual` in place.
 *
 * Launch expectations (from the host wrapper): 1D grid/block where each thread
 * handles one float4 (4 floats); `total_count` counts float4 elements and
 * `intermediate_size` is hidden_dim / 4, so `offset % intermediate_size`
 * indexes the per-channel bias vectors.
 *
 * Preconditions: all pointers are device pointers 16-byte aligned (float4
 * loads); `attn_bias` may be nullptr, in which case the attention-bias term
 * is skipped.
 */
__global__ void gptj_residual_add(float* residual,
                                  const float* hidden_state,
                                  const float* attn,
                                  const float* bias,
                                  const float* attn_bias,
                                  const int total_count,
                                  const int intermediate_size,
                                  const float mp_scale)
{
    // Vectorized views: one float4 per thread keeps global loads coalesced.
    float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
    const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
    const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
    const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
    const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
    const int offset = blockIdx.x * blockDim.x + threadIdx.x;

    // Grid-tail guard: total_count is not necessarily a multiple of blockDim.x.
    if (offset < total_count) {
        float4 res_fl4 = res_fl4_ptr[offset];
        const float4 hs_fl4 = hs_fl4_ptr[offset];
        const float4 attn_fl4 = attn_fl4_ptr[offset];
        // Bias repeats every `intermediate_size` vectors (per-channel bias).
        const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];

        // Attention bias is optional (nullptr when the model has none).
        if (attn_bias) {
            const float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];
            // residual += attention_bias
            res_fl4.x += attn_bias_fl4.x;
            res_fl4.y += attn_bias_fl4.y;
            res_fl4.z += attn_bias_fl4.z;
            res_fl4.w += attn_bias_fl4.w;
        }

        // residual = hidden_state + attention + (residual + bias) * mp_scale
        // mp_scale (= 1 / model-parallel size) rescales the partitioned terms.
        res_fl4.x = hs_fl4.x + attn_fl4.x + (res_fl4.x + bias_fl4.x) * mp_scale;
        res_fl4.y = hs_fl4.y + attn_fl4.y + (res_fl4.y + bias_fl4.y) * mp_scale;
        res_fl4.z = hs_fl4.z + attn_fl4.z + (res_fl4.z + bias_fl4.z) * mp_scale;
        res_fl4.w = hs_fl4.w + attn_fl4.w + (res_fl4.w + bias_fl4.w) * mp_scale;

        res_fl4_ptr[offset] = res_fl4;
    }
}

/*
 * Fused GPT-J style residual add (__half specialization).
 *
 * Same contract as the float kernel:
 *     residual = hidden_state + attn + (residual [+ attn_bias] + bias) * mp_scale
 * written back into `residual` in place. Each thread processes one float2,
 * i.e. 4 halves packed as two __half2; arithmetic is performed in float for
 * accuracy and rounded back to half on store.
 *
 * Launch expectations: 1D grid/block; `total_count` counts float2 (4-half)
 * elements and `intermediate_size` is hidden_dim / 4. All pointers are
 * 8-byte-aligned device pointers; `attn_bias` may be nullptr.
 * Compiled only when HALF_PRECISION_AVAILABLE is defined (half intrinsics
 * require a suitable compute capability).
 */
__global__ void gptj_residual_add(__half* residual,
                                  const __half* hidden_state,
                                  const __half* attn,
                                  const __half* bias,
                                  const __half* attn_bias,
                                  const int total_count,
                                  const int intermediate_size,
                                  const float mp_scale)
{
#ifdef HALF_PRECISION_AVAILABLE

    // float2 views move 4 halves per thread with a single 8-byte transaction.
    float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
    const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
    const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
    const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
    const float2* attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
    const int offset = blockIdx.x * blockDim.x + threadIdx.x;

    // Grid-tail guard: total_count is not necessarily a multiple of blockDim.x.
    if (offset < total_count) {
        float2 res_fl2 = res_fl2_ptr[offset];
        const float2 hs_fl2 = hs_fl2_ptr[offset];
        const float2 attn_fl2 = attn_fl2_ptr[offset];
        // Bias repeats every `intermediate_size` vectors (per-channel bias).
        const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];

        // Reinterpret each float2 as two packed __half2 lanes.
        __half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
        const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
        const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
        const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);

        // Widen to float for the accumulation to avoid half round-off.
        float2 res_low = __half22float2(res_half2[0]);
        float2 res_high = __half22float2(res_half2[1]);

        const float2 hs_low = __half22float2(hs_half2[0]);
        const float2 hs_high = __half22float2(hs_half2[1]);

        const float2 attn_low = __half22float2(attn_half2[0]);
        const float2 attn_high = __half22float2(attn_half2[1]);

        const float2 bias_low = __half22float2(bias_half2[0]);
        const float2 bias_high = __half22float2(bias_half2[1]);

        // Attention bias is optional (nullptr when the model has none).
        if (attn_bias) {
            const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
            const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);
            const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
            const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);
            // residual += attention_bias
            res_low.x += attn_bias_low.x;
            res_low.y += attn_bias_low.y;
            res_high.x += attn_bias_high.x;
            res_high.y += attn_bias_high.y;
        }

        // residual = hidden_state + attention + (residual + bias) * mp_scale
        // mp_scale (= 1 / model-parallel size) rescales the partitioned terms.
        res_low.x = attn_low.x + hs_low.x + (res_low.x + bias_low.x) * mp_scale;
        res_low.y = attn_low.y + hs_low.y + (res_low.y + bias_low.y) * mp_scale;
        res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale;
        res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale;

        // Round back to half and store the packed result in place.
        res_half2[0] = __float22half2_rn(res_low);
        res_half2[1] = __float22half2_rn(res_high);

        res_fl2_ptr[offset] = res_fl2;
    }
#endif
}

template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
void launch_gptj_residual_add(T* residual,
T* hidden_state,
T* attn,
T* bias,
T* attn_bias,
Expand All @@ -383,7 +384,7 @@ void launch_gptj_residual_add(T* input,
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);

gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
residual, hidden_state, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}

template void launch_gptj_residual_add<float>(float*,
Expand Down