csrc/transformer/inference/csrc/pt_binding.cpp (86 changes: 35 additions & 51 deletions)
@@ -1182,63 +1182,42 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
     return output;
 }
 
-void residual_add_bias(at::Tensor& output,
-                       at::Tensor& input,
-                       at::Tensor& attention_output,
-                       at::Tensor& output_b,
-                       at::Tensor& attention_b,
-                       int mp_size,
-                       bool mlp_after_attn,
-                       bool add_bias,
-                       bool preln)
+template <typename T>
+at::Tensor& residual_add_bias(at::Tensor& hidden_state,
+                              const at::Tensor& residual,
+                              const at::Tensor& attention_output,
+                              const at::Tensor& attention_bias,
+                              const at::Tensor& final_bias,
+                              const int mp_size,
+                              const bool mlp_after_attn,
+                              const bool add_bias,
+                              const bool preln)
 {
-    int bsz = input.size(0) * input.size(1);
-    int hidden_size = input.size(2);
-    // cudaStreamWaitEvent(
-    //     Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
-    if (input.scalar_type() == at::kFloat)
-        if (mlp_after_attn)
-            launch_bias_residual((float*)input.data_ptr(),
-                                 (float*)output.data_ptr(),
-                                 (float*)attention_output.data_ptr(),
-                                 (float*)output_b.data_ptr(),
-                                 (float*)attention_b.data_ptr(),
-                                 bsz,
-                                 hidden_size,
-                                 mp_size,
-                                 preln,
-                                 Context::Instance().GetCurrentStream());
-        else
-            launch_gptj_residual_add<float>((float*)input.data_ptr(),
-                                            (float*)output.data_ptr(),
-                                            (float*)attention_output.data_ptr(),
-                                            (float*)output_b.data_ptr(),
-                                            (float*)(add_bias ? attention_b.data_ptr() : nullptr),
-                                            hidden_size,
-                                            bsz,
-                                            mp_size,
-                                            Context::Instance().GetCurrentStream());
-    else if (mlp_after_attn)
-        launch_bias_residual((__half*)input.data_ptr(),
-                             (__half*)output.data_ptr(),
-                             (__half*)attention_output.data_ptr(),
-                             (__half*)output_b.data_ptr(),
-                             (__half*)attention_b.data_ptr(),
+    int bsz = residual.size(0) * residual.size(1);
+    int hidden_size = residual.size(2);
+    if (mlp_after_attn)
+        launch_bias_residual(static_cast<T*>(residual.data_ptr()),
+                             static_cast<T*>(hidden_state.data_ptr()),
+                             static_cast<T*>(attention_output.data_ptr()),
+                             static_cast<T*>(final_bias.data_ptr()),
+                             static_cast<T*>(attention_bias.data_ptr()),
                              bsz,
                              hidden_size,
                              mp_size,
                              preln,
                              Context::Instance().GetCurrentStream());
     else
-        launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
-                                         (__half*)output.data_ptr(),
-                                         (__half*)attention_output.data_ptr(),
-                                         (__half*)output_b.data_ptr(),
-                                         (__half*)(add_bias ? attention_b.data_ptr() : nullptr),
-                                         hidden_size,
-                                         bsz,
-                                         mp_size,
-                                         Context::Instance().GetCurrentStream());
+        launch_gptj_residual_add<T>(
+            static_cast<T*>(residual.data_ptr()),
+            static_cast<T*>(hidden_state.data_ptr()),
+            static_cast<T*>(attention_output.data_ptr()),
+            static_cast<T*>(final_bias.data_ptr()),
+            static_cast<T*>((add_bias ? attention_bias.data_ptr() : nullptr)),
+            hidden_size,
+            bsz,
+            mp_size,
+            Context::Instance().GetCurrentStream());
+    return hidden_state;
 }
 
 std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
@@ -1380,7 +1359,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
           "DeepSpeed linear_layer with int8 (CUDA)");
     m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
     m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
-    m.def("residual_add", &residual_add_bias, "DeepSpeed mlp with fp16 (CUDA)");
+    m.def("residual_add_bias_fp32",
+          &residual_add_bias<float>,
+          "DeepSpeed residual add with fp32 (CUDA)");
+    m.def("residual_add_bias_fp16",
+          &residual_add_bias<__half>,
+          "DeepSpeed residual add with fp16 (CUDA)");
     m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
     m.def("einsum_sec_sm_ecm_fp32",
           &einsum_sec_sm_ecm<float>,
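Since the template is instantiated separately per dtype, Python callers now pick the binding that matches their activation precision instead of the old untyped `residual_add` entry point. The sketch below is a minimal illustration of that dispatch, not DeepSpeed code: the wrapper name is hypothetical, and it assumes the built extension is importable as `inference_cuda_module` with tensors in the (batch, sequence, hidden) layout the kernel expects.

```python
import torch

# Hypothetical helper (illustration only) mirroring the dtype dispatch that
# the Python layer performs when it selects residual_add_bias_fp16/fp32.
def residual_add(module, hidden_state, residual, attn_output, attn_bias,
                 final_bias, mp_size=1, mlp_after_attn=True, add_bias=True,
                 preln=True):
    if hidden_state.dtype == torch.float16:
        fn = module.residual_add_bias_fp16
    elif hidden_state.dtype == torch.float32:
        fn = module.residual_add_bias_fp32
    else:
        raise ValueError(f"unsupported dtype: {hidden_state.dtype}")
    # Argument order follows the new C++ signature: hidden_state, residual,
    # attention_output, attention_bias, final_bias, mp_size, mlp_after_attn,
    # add_bias, preln.
    return fn(hidden_state, residual, attn_output, attn_bias, final_bias,
              mp_size, mlp_after_attn, add_bias, preln)
```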
deepspeed/ops/transformer/inference/transformer_inference.py (25 changes: 15 additions & 10 deletions)
@@ -601,6 +601,7 @@ def forward(ctx,
                 fused_gemm_gelu,
                 vector_matmul_func,
                 bias_residual_func,
+                residual_add_func,
                 activation_func_type=ActivationFuncType.GELU):
 
         if attn_nw is None:
@@ -630,16 +631,16 @@ False,
                                      False,
                                      output_w.scale,
                                      config.q_int8)
-        inference_cuda_module.residual_add(
-            output,
-            residual if config.pre_layer_norm else residual_add,
-            input,
-            output_b,
+        output = residual_add_func(
+            output,  # hidden state
+            residual if config.pre_layer_norm else residual_add,  # residual
+            input,  # attention output
             bias if bias is not None else output_b,
-            config.mp_size,
-            config.mlp_after_attn,
-            bias is not None,
-            config.pre_layer_norm)
+            output_b,
+            config.mp_size,  # model parallel size
+            config.mlp_after_attn,  # whether mlp is after attention (GPTJ model architecture runs the MLP layer in parallel with attention)
+            bias is not None,  # whether bias addition is fused
+            config.pre_layer_norm)  # whether the layer norm is applied before attention
         if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
             dist.all_reduce(output, group=mp_group)
         return output
@@ -710,6 +711,9 @@ def __init__(self,
         self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \
             inference_cuda_module.bias_residual_fp32
 
+        self.residual_add_func = inference_cuda_module.residual_add_bias_fp16 if config.fp16 or config.q_int8 else \
+            inference_cuda_module.residual_add_bias_fp32
+
     def forward(self, input, residual, residual_norm, bias):
         return DeepSpeedMLPFunction.apply(input,
                                           residual,
@@ -729,7 +733,8 @@ def forward(self, input, residual, residual_norm, bias):
                                           self.mlp_gemm_func,
                                           self.fused_gemm_gelu,
                                           self.vector_matmul_func,
-                                          self.bias_residual_func)
+                                          self.bias_residual_func,
+                                          self.residual_add_func)
 
 
 class DeepSpeedTransformerInference(nn.Module):
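As background for the `mlp_after_attn` flag documented in the call above: a GPT-2-style block folds the attention residual in before the MLP runs, whereas a GPT-J-style block feeds attention and MLP the same normalized input and adds both branches to the residual in one step, which is the case the `launch_gptj_residual_add` path fuses. A framework-agnostic sketch (plain Python with illustrative callables, not DeepSpeed code):

```python
def gpt2_style_block(x, attn, mlp, ln1, ln2):
    # MLP sees the hidden state after the attention residual has been added.
    x = x + attn(ln1(x))
    x = x + mlp(ln2(x))
    return x

def gptj_style_block(x, attn, mlp, ln):
    # Attention and MLP run off the same normalized input; a single residual
    # add combines both branches, which is what mlp_after_attn=False selects.
    h = ln(x)
    return x + attn(h) + mlp(h)
```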
tests/unit/ops/transformer/inference/test_residual_add.py (32 changes: 20 additions & 12 deletions)
@@ -26,8 +26,8 @@ def inference_module():
 def run_residual_add_reference(hidden_state,
                                residual,
                                attention_output,
-                               final_bias,
                                attention_output_bias,
+                               final_bias,
                                mlp_after_attn,
                                add_bias,
                                mp_size=1):
@@ -86,21 +86,29 @@ def test_residual_add(inference_module,
     ref_out = run_residual_add_reference(ref_out,
                                          residual,
                                          attention_output,
-                                         final_bias,
                                          attention_output_bias,
+                                         final_bias,
                                          mlp_after_attn,
                                          add_bias,
                                          mp_size)
 
-    inference_module.residual_add(
-        ds_out, # in-place update of ds_out. Needs reafactoring to be consistent with other kernels.
-        residual,
-        attention_output,
-        final_bias,
-        attention_output_bias,
-        mp_size,
-        mlp_after_attn,
-        add_bias,
-        preln)
+    res_add_args = [
+        ds_out,
+        residual,
+        attention_output,
+        attention_output_bias,
+        final_bias,
+        mp_size,
+        mlp_after_attn,
+        add_bias,
+        preln
+    ]
+
+    if dtype == torch.float16:
+        ds_out = inference_module.residual_add_bias_fp16(*res_add_args)
+    elif dtype == torch.float32:
+        ds_out = inference_module.residual_add_bias_fp32(*res_add_args)
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
 
     assert (allclose(ds_out, ref_out))
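As the removed comment notes, the kernel updates its first argument (`ds_out`) in place; the new binding additionally returns that tensor, which is why the test can assign the result back to `ds_out`. A small sanity check of that behavior, assuming the built extension is importable as `inference_module` (as in the fixture above), a CUDA device is available, and the shapes are arbitrary:

```python
import torch

hidden = torch.randn(1, 4, 8, dtype=torch.float16, device="cuda")
residual = torch.randn_like(hidden)
attn_out = torch.randn_like(hidden)
attn_bias = torch.randn(8, dtype=torch.float16, device="cuda")
final_bias = torch.randn(8, dtype=torch.float16, device="cuda")

out = inference_module.residual_add_bias_fp16(
    hidden, residual, attn_out, attn_bias, final_bias,
    1,      # mp_size
    True,   # mlp_after_attn
    True,   # add_bias
    True)   # preln

# The returned tensor should share storage with `hidden`, which is mutated in place.
assert out.data_ptr() == hidden.data_ptr()
```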