diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index 592812c7bea0..18f6767f85a5 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -1182,63 +1182,42 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
     return output;
 }
 
-void residual_add_bias(at::Tensor& output,
-                       at::Tensor& input,
-                       at::Tensor& attention_output,
-                       at::Tensor& output_b,
-                       at::Tensor& attention_b,
-                       int mp_size,
-                       bool mlp_after_attn,
-                       bool add_bias,
-                       bool preln)
+template <typename T>
+at::Tensor& residual_add_bias(at::Tensor& hidden_state,
+                              const at::Tensor& residual,
+                              const at::Tensor& attention_output,
+                              const at::Tensor& attention_bias,
+                              const at::Tensor& final_bias,
+                              const int mp_size,
+                              const bool mlp_after_attn,
+                              const bool add_bias,
+                              const bool preln)
 {
-    int bsz = input.size(0) * input.size(1);
-    int hidden_size = input.size(2);
-    // cudaStreamWaitEvent(
-    //    Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
-    if (input.scalar_type() == at::kFloat)
-        if (mlp_after_attn)
-            launch_bias_residual((float*)input.data_ptr(),
-                                 (float*)output.data_ptr(),
-                                 (float*)attention_output.data_ptr(),
-                                 (float*)output_b.data_ptr(),
-                                 (float*)attention_b.data_ptr(),
-                                 bsz,
-                                 hidden_size,
-                                 mp_size,
-                                 preln,
-                                 Context::Instance().GetCurrentStream());
-        else
-            launch_gptj_residual_add((float*)input.data_ptr(),
-                                     (float*)output.data_ptr(),
-                                     (float*)attention_output.data_ptr(),
-                                     (float*)output_b.data_ptr(),
-                                     (float*)(add_bias ? attention_b.data_ptr() : nullptr),
-                                     hidden_size,
-                                     bsz,
-                                     mp_size,
-                                     Context::Instance().GetCurrentStream());
-    else if (mlp_after_attn)
-        launch_bias_residual((__half*)input.data_ptr(),
-                             (__half*)output.data_ptr(),
-                             (__half*)attention_output.data_ptr(),
-                             (__half*)output_b.data_ptr(),
-                             (__half*)attention_b.data_ptr(),
+    int bsz = residual.size(0) * residual.size(1);
+    int hidden_size = residual.size(2);
+    if (mlp_after_attn)
+        launch_bias_residual(static_cast<T*>(residual.data_ptr()),
+                             static_cast<T*>(hidden_state.data_ptr()),
+                             static_cast<T*>(attention_output.data_ptr()),
+                             static_cast<T*>(final_bias.data_ptr()),
+                             static_cast<T*>(attention_bias.data_ptr()),
                              bsz,
                              hidden_size,
                              mp_size,
                              preln,
                              Context::Instance().GetCurrentStream());
     else
-        launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
-                                         (__half*)output.data_ptr(),
-                                         (__half*)attention_output.data_ptr(),
-                                         (__half*)output_b.data_ptr(),
-                                         (__half*)(add_bias ? attention_b.data_ptr() : nullptr),
-                                         hidden_size,
-                                         bsz,
-                                         mp_size,
-                                         Context::Instance().GetCurrentStream());
+        launch_gptj_residual_add<T>(
+            static_cast<T*>(residual.data_ptr()),
+            static_cast<T*>(hidden_state.data_ptr()),
+            static_cast<T*>(attention_output.data_ptr()),
+            static_cast<T*>(final_bias.data_ptr()),
+            static_cast<T*>((add_bias ? attention_bias.data_ptr() : nullptr)),
+            hidden_size,
+            bsz,
+            mp_size,
+            Context::Instance().GetCurrentStream());
+    return hidden_state;
 }
 
 std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
@@ -1380,7 +1359,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
           "DeepSpeed linear_layer with int8 (CUDA)");
     m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
     m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
-    m.def("residual_add", &residual_add_bias, "DeepSpeed mlp with fp16 (CUDA)");
+    m.def("residual_add_bias_fp32",
+          &residual_add_bias<float>,
+          "DeepSpeed residual add with fp32 (CUDA)");
+    m.def("residual_add_bias_fp16",
+          &residual_add_bias<__half>,
+          "DeepSpeed residual add with fp16 (CUDA)");
     m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
     m.def("einsum_sec_sm_ecm_fp32",
           &einsum_sec_sm_ecm<float>,
diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py
index fa28a34f04a2..4b73f071b9c5 100755
--- a/deepspeed/ops/transformer/inference/transformer_inference.py
+++ b/deepspeed/ops/transformer/inference/transformer_inference.py
@@ -601,6 +601,7 @@ def forward(ctx,
                 fused_gemm_gelu,
                 vector_matmul_func,
                 bias_residual_func,
+                residual_add_func,
                 activation_func_type=ActivationFuncType.GELU):
 
         if attn_nw is None:
@@ -630,16 +631,16 @@ def forward(ctx,
                                         False,
                                         output_w.scale,
                                         config.q_int8)
-        inference_cuda_module.residual_add(
-            output,
-            residual if config.pre_layer_norm else residual_add,
-            input,
-            output_b,
             bias if bias is not None else output_b,
-            config.mp_size,
-            config.mlp_after_attn,
-            bias is not None,
-            config.pre_layer_norm)
+        output = residual_add_func(
+            output,  # hidden state
+            residual if config.pre_layer_norm else residual_add,  # residual
+            input,  # attention output
+            bias if bias is not None else output_b,
+            output_b,
+            config.mp_size,  # model parallel size
+            config.mlp_after_attn,  # whether mlp is after attention (GPTJ model architecture runs the MLP layer in parallel with attention)
+            bias is not None,  # whether bias addition is fused
+            config.pre_layer_norm)  # whether the layer norm is applied before attention
         if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
             dist.all_reduce(output, group=mp_group)
         return output
@@ -710,6 +711,9 @@ def __init__(self,
         self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \
                                     inference_cuda_module.bias_residual_fp32
 
+        self.residual_add_func = inference_cuda_module.residual_add_bias_fp16 if config.fp16 or config.q_int8 else \
+                                    inference_cuda_module.residual_add_bias_fp32
+
     def forward(self, input, residual, residual_norm, bias):
         return DeepSpeedMLPFunction.apply(input,
                                           residual,
@@ -729,7 +733,8 @@ def forward(self, input, residual, residual_norm, bias):
                                           self.mlp_gemm_func,
                                           self.fused_gemm_gelu,
                                           self.vector_matmul_func,
-                                          self.bias_residual_func)
+                                          self.bias_residual_func,
+                                          self.residual_add_func)
 
 
 class DeepSpeedTransformerInference(nn.Module):
diff --git a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py
index b72fd25f1217..6f34437fa3aa 100644
--- a/tests/unit/ops/transformer/inference/test_residual_add.py
+++ b/tests/unit/ops/transformer/inference/test_residual_add.py
@@ -26,8 +26,8 @@ def inference_module():
 def run_residual_add_reference(hidden_state,
                                residual,
                                attention_output,
-                               final_bias,
                                attention_output_bias,
+                               final_bias,
                                mlp_after_attn,
                                add_bias,
                                mp_size=1):
@@ -86,21 +86,29 @@ def test_residual_add(inference_module,
     ref_out = run_residual_add_reference(ref_out,
                                          residual,
                                          attention_output,
-                                         final_bias,
                                          attention_output_bias,
+                                         final_bias,
                                          mlp_after_attn,
                                          add_bias,
                                          mp_size)
 
-    inference_module.residual_add(
-        ds_out,  # in-place update of ds_out. Needs reafactoring to be consistent with other kernels.
-        residual,
-        attention_output,
-        final_bias,
-        attention_output_bias,
-        mp_size,
-        mlp_after_attn,
-        add_bias,
-        preln)
+    res_add_args = [
+        ds_out,
+        residual,
+        attention_output,
+        attention_output_bias,
+        final_bias,
+        mp_size,
+        mlp_after_attn,
+        add_bias,
+        preln
+    ]
+
+    if dtype == torch.float16:
+        ds_out = inference_module.residual_add_bias_fp16(*res_add_args)
+    elif dtype == torch.float32:
+        ds_out = inference_module.residual_add_bias_fp32(*res_add_args)
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
 
     assert (allclose(ds_out, ref_out))
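
For reference (not part of the diff): a minimal sketch of how the renamed bindings might be called directly, mirroring the argument order used by the updated test. It assumes the inference extension is built and loaded through `InferenceBuilder().load()` as in the test's `inference_module` fixture, that a CUDA device is available, and the tensor names and shapes below are purely illustrative.

```python
# Illustrative usage sketch only -- not part of this change.
import torch
from deepspeed.ops.op_builder import InferenceBuilder

inference_module = InferenceBuilder().load()

batch, seq, hidden = 2, 8, 64
dtype = torch.float16
device = "cuda"

# Activations follow the [batch, seq, hidden] layout the kernel expects;
# the two biases are per-hidden-element vectors.
hidden_state = torch.randn(batch, seq, hidden, dtype=dtype, device=device)
residual = torch.randn(batch, seq, hidden, dtype=dtype, device=device)
attention_output = torch.randn(batch, seq, hidden, dtype=dtype, device=device)
attention_bias = torch.randn(hidden, dtype=dtype, device=device)
final_bias = torch.randn(hidden, dtype=dtype, device=device)

# Pick the binding that matches the tensor dtype, as the updated test does.
residual_add_func = (inference_module.residual_add_bias_fp16
                     if dtype == torch.float16 else
                     inference_module.residual_add_bias_fp32)

# Argument order matches the new templated C++ signature:
# (hidden_state, residual, attention_output, attention_bias, final_bias,
#  mp_size, mlp_after_attn, add_bias, preln)
out = residual_add_func(hidden_state,
                        residual,
                        attention_output,
                        attention_bias,
                        final_bias,
                        1,      # mp_size
                        True,   # mlp_after_attn
                        True,   # add_bias
                        True)   # preln

# The kernel writes its result into hidden_state and also returns it,
# so `out` aliases hidden_state's storage.
```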