csrc/transformer/inference/csrc/gelu.cu: 186 changes (100 additions, 86 deletions)
@@ -126,120 +126,127 @@ void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, c
template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);

__global__ void fused_bias_residual(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_scale,
bool preln)
__global__ void fused_bias_residual(float* residual,
@arashb (Contributor, Author) commented on Sep 26, 2022:

@RezaYazdaniAminabadi in these kernels the computation output is stored in residual rather than in the final hidden_state, and the residual value is then returned as the output of the MLP here. Wouldn't it make sense to store the output in hidden_state and return hidden_state instead?

Contributor reply:

The underlying storage for residual is actually the original input tensor to the Transformer layer (which we don't otherwise modify), whereas hidden_state is backed by an offset into our scratchpad memory. That scratchpad section would be overwritten by another part of the Transformer (the allocation for hidden states is here and conflicts with the QKV GEMM allocation here).

Long term, I don't know how safe it is to assume that we can freely modify the input tensor, so modifying our scratchpad structure might be the most correct option. I don't know whether anyone else sees risk in using the input tensor this way.

Contributor reply:

Thanks for bringing this up, @arashb. Also, thanks @cmikeh2 for explaining the main reason for not putting the data in hidden_state here. I think the safe way would be to create another tensor for the output if it is needed in other places of the pipeline; however, that imposes more memory usage on the inference system, which is not desired. So let's keep the default of writing into residual, and if users require it we can add a flag to create tensors for each layer's individual outputs. How does that sound?

const float* hidden_state,
const float* attn,
const float* bias,
const float* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;

if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
float4 res_fl4 = res_fl4_ptr[offset];
const float4 hs_fl4 = hs_fl4_ptr[offset];
const float4 attn_fl4 = attn_fl4_ptr[offset];
const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];
const float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];
if (preln) {
data.x = (data.x + res_vec.x + bias_data.x + attn_bias.x) * mp_scale + (out.x);
data.y = (data.y + res_vec.y + bias_data.y + attn_bias.y) * mp_scale + (out.y);
data.z = (data.z + res_vec.z + bias_data.z + attn_bias.z) * mp_scale + (out.z);
data.w = (data.w + res_vec.w + bias_data.w + attn_bias.w) * mp_scale + (out.w);
// residual = (residual + attention + bias + attention_bias) *
// mp_scale + hidden_state
res_fl4.x =
(res_fl4.x + attn_fl4.x + bias_fl4.x + attn_bias_fl4.x) * mp_scale + (hs_fl4.x);
res_fl4.y =
(res_fl4.y + attn_fl4.y + bias_fl4.y + attn_bias_fl4.y) * mp_scale + (hs_fl4.y);
res_fl4.z =
(res_fl4.z + attn_fl4.z + bias_fl4.z + attn_bias_fl4.z) * mp_scale + (hs_fl4.z);
res_fl4.w =
(res_fl4.w + attn_fl4.w + bias_fl4.w + attn_bias_fl4.w) * mp_scale + (hs_fl4.w);
} else {
data.x = data.x + out.x + bias_data.x;
data.y = data.y + out.y + bias_data.y;
data.z = data.z + out.z + bias_data.z;
data.w = data.w + out.w + bias_data.w;
// residual += hidden_state + bias
res_fl4.x = res_fl4.x + hs_fl4.x + bias_fl4.x;
res_fl4.y = res_fl4.y + hs_fl4.y + bias_fl4.y;
res_fl4.z = res_fl4.z + hs_fl4.z + bias_fl4.z;
res_fl4.w = res_fl4.w + hs_fl4.w + bias_fl4.w;
}
input_cast[offset] = data;
res_fl4_ptr[offset] = res_fl4;
}
}
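For readers following the review thread above, here is a minimal host-side sketch of the buffer relationships being discussed: residual aliases the layer's original input tensor, while hidden_state is carved out of scratch memory that other ops (such as the QKV GEMM) reuse. The struct and all names below are hypothetical illustrations, not the actual DeepSpeed workspace layout.

#include <cstddef>

// Hypothetical workspace sketch: results written through the scratchpad pointer may be
// clobbered by later ops, which is why the kernels above update `residual` in place.
struct InferenceWorkspaceSketch {
    float* layer_input;    // stable storage; backs the `residual` argument
    float* scratchpad;     // shared scratch memory, reused across ops
    size_t hidden_offset;  // hidden_state = scratchpad + hidden_offset
    size_t qkv_offset;     // QKV GEMM output is carved out of the same scratch region
};

inline float* hidden_state_ptr(const InferenceWorkspaceSketch& ws)
{
    return ws.scratchpad + ws.hidden_offset;  // may be overwritten by a later GEMM
}

inline float* residual_ptr(const InferenceWorkspaceSketch& ws)
{
    return ws.layer_input;  // survives the rest of the layer, so it can carry the output
}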

__global__ void fused_bias_residual(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_scale,
bool preln)
__global__ void fused_bias_residual(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln)
{
#ifdef HALF_PRECISION_AVAILABLE

float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);

float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);

int offset = blockIdx.x * blockDim.x + threadIdx.x;
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
const float2* attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;

if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];

float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
float2 res_fl2 = res_fl2_ptr[offset];
const float2 hs_fl2 = hs_fl2_ptr[offset];
const float2 attn_fl2 = attn_fl2_ptr[offset];
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];

__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);

float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 res_low = __half22float2(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]);

float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
const float2 hs_low = __half22float2(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]);

float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
const float2 attn_low = __half22float2(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]);

float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
const float2 bias_low = __half22float2(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]);

float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);

if (preln) {
low_data.x =
(low_data.x + low_res.x + (low_bias.x + attn_low_bias.x)) * mp_scale + low_out.x;
low_data.y =
(low_data.y + low_res.y + (low_bias.y + attn_low_bias.y)) * mp_scale + low_out.y;
high_data.x = (high_data.x + high_res.x + (high_bias.x + attn_high_bias.x)) * mp_scale +
high_out.x;
high_data.y = (high_data.y + high_res.y + (high_bias.y + attn_high_bias.y)) * mp_scale +
high_out.y;
// residual = (residual + attention + bias + attention_bias) *
// mp_scale + hidden_state
res_low.x =
(res_low.x + attn_low.x + bias_low.x + attn_bias_low.x) * mp_scale + hs_low.x;
res_low.y =
(res_low.y + attn_low.y + bias_low.y + attn_bias_low.y) * mp_scale + hs_low.y;
res_high.x =
(res_high.x + attn_high.x + bias_high.x + attn_bias_high.x) * mp_scale + hs_high.x;
res_high.y =
(res_high.y + attn_high.y + bias_high.y + attn_bias_high.y) * mp_scale + hs_high.y;
} else {
low_data.x = (low_data.x + low_out.x + low_bias.x);
low_data.y = (low_data.y + low_out.y + low_bias.y);
high_data.x = (high_data.x + high_out.x + high_bias.x);
high_data.y = (high_data.y + high_out.y + high_bias.y);
// residual += hidden_state + bias
res_low.x = (res_low.x + hs_low.x + bias_low.x);
res_low.y = (res_low.y + hs_low.y + bias_low.y);
res_high.x = (res_high.x + hs_high.x + bias_high.x);
res_high.y = (res_high.y + hs_high.y + bias_high.y);
}
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
res_half2[0] = __float22half2_rn(res_low);
res_half2[1] = __float22half2_rn(res_high);

input_cast[offset] = vals_vec;
res_fl2_ptr[offset] = res_fl2;
}
#endif
}
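As a reading aid, the following is a scalar reference of what both fused_bias_residual kernels above compute. The kernels operate on packed vectors (float4 for fp32, float2 holding four __half values for fp16) and index the biases with offset % intermediate_size, where intermediate_size is hidden_dim / 4; the sketch below performs the same arithmetic per scalar element, so it is illustrative only and not part of the source.

// Scalar reference of the fused bias + residual update (illustrative only).
void bias_residual_reference(float* residual,
                             const float* hidden_state,
                             const float* attn,
                             const float* bias,
                             const float* attn_bias,
                             const int total_elems,  // batch_tokens * hidden_dim
                             const int hidden_dim,
                             const float mp_scale,
                             const bool preln)
{
    for (int i = 0; i < total_elems; ++i) {
        const int c = i % hidden_dim;  // channel index selects the bias element
        if (preln) {
            // residual = (residual + attention + bias + attention_bias) * mp_scale + hidden_state
            residual[i] =
                (residual[i] + attn[i] + bias[c] + attn_bias[c]) * mp_scale + hidden_state[i];
        } else {
            // residual += hidden_state + bias
            residual[i] = residual[i] + hidden_state[i] + bias[c];
        }
    }
}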

template <typename T>
void launch_bias_residual(T* input,
T* output,
void launch_bias_residual(T* residual,
T* hidden_state,
T* attn,
T* bias,
T* attn_bias,
@@ -253,8 +260,15 @@ void launch_bias_residual(T* input,
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);

fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size, preln);
fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(residual,
hidden_state,
attn,
bias,
attn_bias,
total_count,
hidden_dim / 4,
1.0 / mp_size,
preln);
}

template void launch_bias_residual<
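Finally, a short sketch of the launch geometry implied by launch_bias_residual above: the kernels bounds-check against a vectorized element count, so total_count is the number of 4-wide vectors and the bias stride passed to the kernel is hidden_dim / 4. The helper below is a hypothetical illustration (batch_tokens is not a parameter of the real launcher) showing how such a grid would be sized.

#include <cuda_runtime.h>

// Hypothetical helper: size the grid for a 1024-thread block over the vectorized
// element count, matching dim3 grid_dims((total_count - 1) / 1024 + 1) above.
inline dim3 bias_residual_grid(int batch_tokens, int hidden_dim)
{
    const int total_count = batch_tokens * hidden_dim / 4;  // float4 (or 4x __half) vectors
    return dim3((total_count - 1) / 1024 + 1);               // ceiling division
}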