diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index 360beaa10284..e323dd9c233a 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -8,6 +8,11 @@ std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
 
+// NOTE: This activation function type enum should always be kept in sync
+// with the python counterpart, otherwise the casting from the python binding
+// will be incorrect.
+enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2 };
+
 template <typename T>
 at::Tensor ds_softmax(at::Tensor& attn_scores,
                       at::Tensor& attn_mask,
@@ -464,9 +469,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                            1);
     if (layer_id == num_layers - 1) Context::Instance().advance_tokens();
-    auto prev_key = torch::from_blob(workspace + offset, {bsz, all_tokens, hidden_dim}, options);
+    auto prev_key = torch::from_blob(workspace + offset, {bsz, heads, all_tokens, k}, options);
     auto prev_value =
-        torch::from_blob(workspace + offset + value_offset, {bsz, all_tokens, hidden_dim}, options);
+        torch::from_blob(workspace + offset + value_offset, {bsz, heads, all_tokens, k}, options);
 
     return {output, prev_key, prev_value};
 }
 
@@ -486,6 +491,22 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
     return input_cont;
 }
 
+template <typename T>
+at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
+{
+    auto input_cont = input.contiguous();
+
+    int bsz = input_cont.size(0) * input_cont.size(1);
+    int intermediate_size = input_cont.size(2);
+
+    launch_bias_relu((T*)input_cont.data_ptr(),
+                     (T*)bias.data_ptr(),
+                     intermediate_size,
+                     bsz,
+                     Context::Instance().GetCurrentStream());
+    return input_cont;
+}
+
 template <typename T>
 at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
 {
@@ -840,7 +861,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
                               at::Tensor& beta,
                               const float epsilon,
                               bool preLayerNorm,
-                              bool mlp_after_attn)
+                              bool mlp_after_attn,
+                              ActivationFuncType act_func_type)
 {
     int bsz = input.size(0) * input.size(1);
     auto inp_norm = at::empty_like(input);
@@ -878,13 +900,24 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif
-    launch_bias_gelu((T*)output.data_ptr(),
-                     (T*)bias.data_ptr(),
-                     weight.size(1),
-                     bsz,
-                     Context::Instance().GetCurrentStream());
+
+    if (act_func_type == ActivationFuncType::GELU) {
+        launch_bias_gelu((T*)output.data_ptr(),
+                         (T*)bias.data_ptr(),
+                         weight.size(1),
+                         bsz,
+                         Context::Instance().GetCurrentStream());
+    } else if (act_func_type == ActivationFuncType::ReLU) {
+        launch_bias_relu((T*)output.data_ptr(),
+                         (T*)bias.data_ptr(),
+                         weight.size(1),
+                         bsz,
+                         Context::Instance().GetCurrentStream());
+    }
+
     return inp_norm;
 }
+
 template <typename T>
 std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                     at::Tensor& residual,
@@ -895,7 +928,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                     at::Tensor& beta,
                                     const float epsilon,
                                     bool preLayerNorm,
-                                    bool mlp_after_attn)
+                                    bool mlp_after_attn,
+                                    int activation_type)
 {
     auto input_cont = input.contiguous();
     auto options = at::TensorOptions()
@@ -907,6 +941,7 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
     auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
     int bsz = input_cont.size(0) * input_cont.size(1);
 
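+    // The Python binding passes the activation function as a plain int; cast it back to the
+    // ActivationFuncType enum, which must stay in sync with deepspeed/utils/types.py.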
+    auto act_func_type = static_cast<ActivationFuncType>(activation_type);
     auto res_add = mlp_unfused_cublas<T>(output,
                                          mlp_after_attn ? input : residual,
                                          residual,
@@ -917,7 +952,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                          beta,
                                          epsilon,
                                          preLayerNorm,
-                                         mlp_after_attn);
+                                         mlp_after_attn,
+                                         act_func_type);
     return {output, res_add};
 }
 
@@ -1205,7 +1241,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
           &ds_softmax_context1<__half>,
           "DeepSpeed attention with fp32 (CUDA)");
     m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
-    m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)");
+    m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
+    m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
+    m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
     m.def("bias_residual_fp32",
           &ds_bias_residual<float>,
           "DeepSpeed residual-bias add with fp32 (CUDA)");
diff --git a/csrc/transformer/inference/csrc/relu.cu b/csrc/transformer/inference/csrc/relu.cu
new file mode 100644
index 000000000000..87011f65ea92
--- /dev/null
+++ b/csrc/transformer/inference/csrc/relu.cu
@@ -0,0 +1,95 @@
+#include "custom_cuda_layers.h"
+
+#define MAX_CAP 4
+#define MAX_SEQ 2048
+
+inline __device__ float relu(const float x) { return x < 0 ? 0 : x; }
+
+__global__ void fused_bias_relu(float* input,
+                                const float* bias,
+                                int total_count,
+                                int intermediate_size)
+{
+    float4* input_cast = reinterpret_cast<float4*>(input);
+    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (offset < total_count) {
+        float4 data = input_cast[offset];
+        float4 bias_data = bias_cast[offset % intermediate_size];
+
+        data.x += bias_data.x;
+        data.y += bias_data.y;
+        data.z += bias_data.z;
+        data.w += bias_data.w;
+
+        data.x = relu(data.x);
+        data.y = relu(data.y);
+        data.z = relu(data.z);
+        data.w = relu(data.w);
+
+        input_cast[offset] = data;
+    }
+}
+
+__global__ void fused_bias_relu(__half* input,
+                                const __half* bias,
+                                int total_count,
+                                int intermediate_size)
+{
+#ifdef HALF_PRECISION_AVAILABLE
+
+    float2* input_cast = reinterpret_cast<float2*>(input);
+    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
+
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (offset < total_count) {
+        float2 vals_vec = input_cast[offset];
+        float2 bias_vec = bias_cast[offset % intermediate_size];
+
+        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
+        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
+
+        float2 low_data = __half22float2(vals_half[0]);
+        float2 high_data = __half22float2(vals_half[1]);
+
+        float2 low_bias = __half22float2(bias_half[0]);
+        float2 high_bias = __half22float2(bias_half[1]);
+
+        low_data.x += low_bias.x;
+        low_data.y += low_bias.y;
+        high_data.x += high_bias.x;
+        high_data.y += high_bias.y;
+
+        low_data.x = relu(low_data.x);
+        low_data.y = relu(low_data.y);
+        high_data.x = relu(high_data.x);
+        high_data.y = relu(high_data.y);
+
+        vals_half[0] = __float22half2_rn(low_data);
+        vals_half[1] = __float22half2_rn(high_data);
+
+        input_cast[offset] = vals_vec;
+    }
+#endif
+}
+
+template <typename T>
+void launch_bias_relu(T* input,
+                      const T* bias,
+                      int intermediate_size,
+                      int batch_size,
+                      cudaStream_t stream)
+{
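+    // The kernels above operate on packed groups of four elements (float4 for fp32,
+    // a pair of __half2 for fp16), so intermediate_size is assumed to be divisible by 4.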
+    int total_count = batch_size * (intermediate_size / 4);
+    int threads = 1024;  // intermediate_size / iterations / 4;
+    dim3 block_dims(threads);
+    dim3 grid_dims(((total_count - 1) / 1024 + 1));  // (batch_size);
+
+    fused_bias_relu<<<grid_dims, block_dims, 0, stream>>>(
+        input, bias, total_count, intermediate_size / 4);
+}
+
+template void launch_bias_relu<float>(float*, const float*, int, int, cudaStream_t);
+template void launch_bias_relu<__half>(__half*, const __half*, int, int, cudaStream_t);
diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h
index c2bb30126cd6..afa3c65aed7b 100644
--- a/csrc/transformer/inference/includes/custom_cuda_layers.h
+++ b/csrc/transformer/inference/includes/custom_cuda_layers.h
@@ -50,6 +50,15 @@ void launch_bias_gelu(T* input,
                       int intermediate_size,
                       int batch_size,
                       cudaStream_t stream);
+
+// Fused bias add with relu activation
+template <typename T>
+void launch_bias_relu(T* input,
+                      const T* bias,
+                      int intermediate_size,
+                      int batch_size,
+                      cudaStream_t stream);
+
 template <typename T>
 void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream);
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index db9efb19dcb1..1c22960b1e31 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -19,6 +19,7 @@ from ..moe.utils import has_moe_layers
 from ..runtime.zero import GatheredParameters
 from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
+from ..module_inject.replace_policy import DSPolicy
 
 DS_INFERENCE_ENABLED = False
 from torch import nn
@@ -77,6 +78,9 @@ def __init__(self,
 
         self._get_model_config_generate(config)
 
+        if hasattr(self.module, "config"):
+            DSPolicy.hf_model_config = self.module.config
+
         self.mp_world_size = mp_size
         self.checkpoint = checkpoint
         self.dtype = dtype
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 4ae9e5529d0e..ebfbbec8aa66 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -147,7 +147,6 @@ def replace_transformer_layer(orig_layer_impl,
                               mp_group=None,
                               ep_group=None,
                               expert_mp_group=None,
-                              preln=True,
                               fp16=True,
                               local_rank=-1,
                               stochastic_mode=True,
@@ -204,13 +203,8 @@ def replace_with_policy(child,
                             policy_cls,
                             triangular_masking,
                             inference=False,
-                            preln=True,
                             layer_id=0):
-        preln = False if policy_cls is HFBertLayerPolicy else preln
-        if policy_cls is HFBertLayerPolicy:
-            policy = policy_cls(child, inference=inference, preln=preln)
-        else:
-            policy = policy_cls(child, inference=inference)
+        policy = policy_cls(child, inference=inference)
 
         if inference:
             hidden_size, num_attention_heads = policy.get_hidden_heads()
@@ -275,7 +269,7 @@ def replace_with_policy(child,
                             config,
                             'layer_norm_eps') else 1e-12,
                         fp16=fp16,
-                        pre_layer_norm=preln,
+                        pre_layer_norm=policy.pre_attn_norm,
                         mp_size=mp_size,
                         q_int8=quantize,
                         moe_experts=local_ep_size,
@@ -297,7 +291,7 @@ def replace_with_policy(child,
                                              if hasattr(config,
                                                         'layernorm_epsilon') else 1.0e-12),
                         fp16=fp16,
-                        pre_layer_norm=preln,
+                        pre_layer_norm=policy.pre_attn_norm,
                         mp_size=mp_size,
                         q_int8=quantize,
                         return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
@@ -309,6 +303,7 @@ def replace_with_policy(child,
                             'window_size') else 1),
                         rotary_dim=rotary_dim,
                         mlp_after_attn=(rotary_dim is None or rotary_dim < 0),
+                        mlp_act_func_type=policy.mlp_act_func_type,
                         training_mp_size=training_mp_size,
                         bigscience_bloom=bigscience_bloom)
 
@@ -594,7 +589,7 @@ def _transpose(x):
                     'layer_norm_eps') else 1e-12,
                 seed=seed,
                 fp16=fp16,
-                pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln),
+                pre_layer_norm=policy.pre_attn_norm,
                 return_tuple=return_tuple,
                 local_rank=local_rank,
                 stochastic_mode=stochastic_mode,
@@ -758,10 +753,7 @@ def _replace_module(r_module, prev_name=''):
 
     def replace_fn(child, _policy, layer_id=0):
         if training:
             # copy relevant state from child -> new module
-            new_module = replace_with_policy(child,
-                                             _policy,
-                                             triangular_masking,
-                                             preln=preln)
+            new_module = replace_with_policy(child, _policy, triangular_masking)
 
         else:
             # copy relevant state from child -> new module
@@ -770,8 +762,6 @@ def replace_fn(child, _policy, layer_id=0):
                                                  _policy,
                                                  triangular_masking,
                                                  inference=True,
-                                                 preln=(_policy
-                                                        is not HFBertLayerPolicy),
                                                  layer_id=layer_id)
             else:
                 new_module = replace_wo_policy(child, _policy)
diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py
index 3d5c53275e33..b0f5be283b37 100755
--- a/deepspeed/module_inject/replace_policy.py
+++ b/deepspeed/module_inject/replace_policy.py
@@ -4,19 +4,33 @@ from torch.nn.parameter import Parameter
 from packaging import version as pkg_version
 
+from deepspeed.utils.types import ActivationFuncType
+
 supported_models = {None}
 
 
 class DSPolicy(ABC):
-    def __init__(self,
-                 inference=True,
-                 linear_layer=True,
-                 scale_attention=True,
-                 megatron_v2=False):
+    # a static class variable containing the HuggingFace model configuration.
+    # see e.g., transformers.models.opt.configuration_opt.OPTConfig
+    hf_model_config = None
+
+    def __init__(
+            self,
+            inference=True,
+            linear_layer=True,
+            scale_attention=True,
+            megatron_v2=False,
+            # the type of activation function used in MLP
+            mlp_act_func_type=ActivationFuncType.GELU,
+            # applies layer norm before attention if `pre_attn_norm` is set to True
+            pre_attn_norm=True):
+
         self.inference = inference
         self.linear_layer = linear_layer
         self.scale_attention = scale_attention
         self.is_megatron_v2 = megatron_v2
+        self.mlp_act_func_type = mlp_act_func_type
+        self.pre_attn_norm = pre_attn_norm
 
     def attention(self):
         """
@@ -52,10 +66,10 @@ def layerNorm(self):
 class HFBertLayerPolicy(DSPolicy):
     _orig_layer_class = None
 
-    def __init__(self, client_module, inference=False, preln=False):
-        super().__init__(inference)
+    def __init__(self, client_module, inference=False):
+        super().__init__(inference, pre_attn_norm=False)
         self.client_module = client_module
-        self.preln = preln
+
         if HFBertLayerPolicy._orig_layer_class is None:
             try:
                 import transformers
@@ -90,7 +104,7 @@ def attention(self):
                self.is_megatron_v2
 
     def mlp(self):
-        if self.preln:
+        if self.pre_attn_norm:
             intermediate_ff = self.client_module.intermediate.dense_act
         else:
             intermediate_ff = self.client_module.intermediate.dense
@@ -100,7 +114,7 @@ def mlp(self):
                self.client_module.output.dense.bias
 
     def layerNorm(self):
-        if self.preln:
+        if self.pre_attn_norm:
             attention_layernorm = self.client_module.PostAttentionLayerNorm
             transformer_layernorm = self.client_module.PreAttentionLayerNorm
         else:
@@ -181,12 +195,12 @@ def attention(self):
         qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
 
         return self.linear_layer, \
-                qkvw, \
-                None, \
-                self.client_module.attn.out_proj.weight, \
-                None, \
-                self.scale_attention, \
-                self.is_megatron_v2
+            qkvw, \
+            None, \
+            self.client_module.attn.out_proj.weight, \
+            None, \
+            self.scale_attention, \
+            self.is_megatron_v2
 
     def mlp(self):
         return self.linear_layer, \
@@ -418,6 +432,64 @@ def layerNorm(self):
                self.client_module.input_layernorm.bias
 
 
+class HFOPTLayerPolicy(DSPolicy):
+    _orig_layer_class = None
+
+    def __init__(self, client_module, inference=True):
+        super().__init__(inference,
+                         linear_layer=True,
+                         mlp_act_func_type=ActivationFuncType.ReLU,
+                         pre_attn_norm=True)
+        self.client_module = client_module
+        try:
+            import transformers
+            HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
+        except:
+            HFOPTLayerPolicy._orig_layer_class = None
+
+        if isinstance(DSPolicy.hf_model_config,
+                      transformers.models.opt.configuration_opt.OPTConfig):
+            self.pre_attn_norm = self.hf_model_config.do_layer_norm_before
+
+    def get_hidden_heads(self):
+        return self.client_module.self_attn.embed_dim, \
+               self.client_module.self_attn.num_heads
+
+    def attention(self):
+        qw = self.client_module.self_attn.q_proj.weight
+        qb = self.client_module.self_attn.q_proj.bias
+
+        kw = self.client_module.self_attn.k_proj.weight
+        kb = self.client_module.self_attn.k_proj.bias
+
+        vw = self.client_module.self_attn.v_proj.weight
+        vb = self.client_module.self_attn.v_proj.bias
+
+        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
+        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)
+
+        return self.linear_layer, \
+            qkvw, \
+            qkvb, \
+            self.client_module.self_attn.out_proj.weight, \
+            self.client_module.self_attn.out_proj.bias, \
+            self.scale_attention, \
+            self.is_megatron_v2
+
+    def mlp(self):
+        return self.linear_layer, \
+            self.client_module.fc1.weight, \
+            self.client_module.fc1.bias, \
+            self.client_module.fc2.weight, \
+            self.client_module.fc2.bias
+
+    def layerNorm(self):
+        return self.client_module.final_layer_norm.weight, \
+            self.client_module.final_layer_norm.bias, \
+            self.client_module.self_attn_layer_norm.weight, \
+            self.client_module.self_attn_layer_norm.bias
+
+
 replace_policies = [
     HFBertLayerPolicy,
     HFGPTNEOLayerPolicy,
@@ -426,4 +498,5 @@ replace_policies = [
     MegatronLayerPolicy,
     HFGPT2LayerPolicy,
     BLOOMLayerPolicy,
+    HFOPTLayerPolicy,
 ]
diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py
index d38cf8c3d395..f03cc1248578 100755
--- a/deepspeed/ops/transformer/inference/transformer_inference.py
+++ b/deepspeed/ops/transformer/inference/transformer_inference.py
@@ -9,6 +9,7 @@ import torch.nn as nn
 from deepspeed import comm as dist
 from deepspeed.utils.logging import log_dist
+from deepspeed.utils.types import ActivationFuncType
 
 # Cuda modules will be imported if needed
 inference_cuda_module = None
@@ -71,6 +72,7 @@ def __init__(self,
                  rotate_every_two=True,
                  return_tuple=True,
                  mlp_after_attn=True,
+                 mlp_act_func_type=ActivationFuncType.GELU,
                  training_mp_size=1,
                  bigscience_bloom=False):
         super(DeepSpeedInferenceConfig,
@@ -95,6 +97,7 @@ def __init__(self,
         self.rotate_every_two = rotate_every_two
         self.return_tuple = return_tuple
         self.mlp_after_attn = mlp_after_attn
+        self.mlp_act_func_type = mlp_act_func_type
         self.specialized_mode = False
         self.training_mp_size = training_mp_size
         self.bigscience_bloom = bigscience_bloom
@@ -589,7 +592,8 @@ def forward(ctx,
                 mlp_gemm_func,
                 fused_gemm_gelu,
                 vector_matmul_func,
-                bias_residual_func):
+                bias_residual_func,
+                activation_func_type=ActivationFuncType.GELU):
 
         if config.q_int8:
             (intermediate,
@@ -629,7 +633,8 @@ def forward(ctx,
                                                       attn_nb,
                                                       config.epsilon,
                                                       config.pre_layer_norm,
-                                                      config.mlp_after_attn)
+                                                      config.mlp_after_attn,
+                                                      config.mlp_act_func_type)
             output = vector_matmul_func(intermediate, output_w, False)
 
         inference_cuda_module.residual_add(
@@ -795,28 +800,35 @@ def __init__(self,
                              device=device))
         self.layer_past = None
 
-    def forward(self,
-                input,
-                input_mask=None,
-                attention_mask=None,
-                head_mask=None,
-                layer_past=None,
-                get_key_value=False,
-                get_present=False,
-                encoder_output=None,
-                enc_dec_attn_mask=None,
-                encoder_hidden_states=None,
-                encoder_attention_mask=None,
-                use_cache=False,
-                alibi=None,
-                output_attentions=False):
+    def forward(
+            self,
+            input,
+            input_mask=None,
+            attention_mask=None,
+            head_mask=None,
+            layer_past=None,
+            get_key_value=False,
+            get_present=False,
+            encoder_output=None,
+            enc_dec_attn_mask=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            use_cache=False,
+            alibi=None,
+            output_attentions=False,
+            # TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API.
+            # This needs to be redesigned later!
+            layer_head_mask=None,
+            past_key_value=None):
         get_present = (get_present or get_key_value or use_cache)
         input_mask = input_mask if attention_mask is None else attention_mask
 
         # We set the prev key/value to None when there is a prompt
         if input.shape[1] > 1: self.layer_past = None
+        layer_past = layer_past if layer_past is not None else self.layer_past
+        head_mask = layer_head_mask if layer_head_mask is not None else head_mask
 
         attn_mask = None
         if isinstance(input, tuple):
diff --git a/deepspeed/utils/types.py b/deepspeed/utils/types.py
new file mode 100644
index 000000000000..1e833c12b007
--- /dev/null
+++ b/deepspeed/utils/types.py
@@ -0,0 +1,7 @@
+from enum import IntEnum
+
+
+class ActivationFuncType(IntEnum):
+    UNKNOWN = 0
+    GELU = 1
+    ReLU = 2
diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py
index eb374c29ef77..b0b86225e97c 100755
--- a/op_builder/transformer_inference.py
+++ b/op_builder/transformer_inference.py
@@ -36,6 +36,7 @@ def sources(self):
         return [
             'csrc/transformer/inference/csrc/pt_binding.cpp',
             'csrc/transformer/inference/csrc/gelu.cu',
+            'csrc/transformer/inference/csrc/relu.cu',
            'csrc/transformer/inference/csrc/normalize.cu',
             'csrc/transformer/inference/csrc/softmax.cu',
             'csrc/transformer/inference/csrc/dequantize.cu',
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 09fcd0736af5..1294bb0d5895 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -54,9 +54,13 @@ def lm_eval_imports():
     "EleutherAI/gpt-j-6B",
     "bigscience/bloom-350m",
 ]
+_opt_models = [
+    "facebook/opt-125m",  # 125m, 1.7B, ..., 175B variants have the same model architecture.
+    "facebook/opt-350m",  # 350m applies layer norm after the attention layer, which differs from the other variants.
+]
 _all_models = HfApi().list_models()
 
-test_models = set(_bert_models + _roberta_models + _gpt_models)
+test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
 test_tasks = [
     "fill-mask",
     "question-answering",