From df895f7a515a8bf5f44172d3e33c290ea712fd9b Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 29 Sep 2022 14:23:50 -0700 Subject: [PATCH 01/20] add initial sd policy --- deepspeed/inference/engine.py | 73 +++++++++++++---------- deepspeed/module_inject/replace_module.py | 24 +++++++- deepspeed/module_inject/replace_policy.py | 53 ++++++++++++---- deepspeed/module_inject/unet.py | 36 +++++++++++ deepspeed/moe/utils.py | 1 + 5 files changed, 141 insertions(+), 46 deletions(-) create mode 100644 deepspeed/module_inject/unet.py diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 89a8d8288455..2c44648a2091 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -14,7 +14,7 @@ from ..runtime.state_dict_factory import SDLoaderFactory from ..runtime.weight_quantizer import WeightQuantization -from ..module_inject.replace_module import replace_transformer_layer +from ..module_inject.replace_module import replace_transformer_layer, generic_injection from ..comm.comm import init_distributed from ..pipe import PipelineModule from ..moe.utils import has_moe_layers @@ -89,7 +89,7 @@ def __init__(self, self.injection_dict = injection_dict self.mp_group = None self.mpu = mpu - self._validate_args(mpu) + self._validate_args(mpu, replace_with_kernel_inject) self.replace_method = replace_method self.quantize_merge_count = 1 self.quantization_scales = None @@ -125,7 +125,8 @@ def __init__(self, elif self.mp_world_size > 1: self._create_model_parallel_group() - moe, _ = has_moe_layers(self.module) + if isinstance(self.module, torch.nn.Module): + moe, _ = has_moe_layers(self.module) if moe and dist.get_world_size() > 1: self._create_ep_parallel_group(moe_experts) @@ -251,8 +252,9 @@ def _init_quantization_setting(self, quantization_setting): f"quantize_groups = {self.quantize_groups}", [0]) - def _validate_args(self, mpu): - if not isinstance(self.module, Module): + def _validate_args(self, mpu, replace_with_kernel_inject): + # TODO: to support SD pipeline we need to avoid this check for now + if replace_with_kernel_inject and not isinstance(self.module, Module): raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}") if not isinstance(self.mp_world_size, int) or self.mp_world_size < 1: raise ValueError(f"mp_size must be an int >= 1, got {self.mp_world_size}") @@ -357,33 +359,37 @@ def _apply_injection_policy(self, checkpoint = SDLoaderFactory.get_sd_loader_json( checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None - replace_transformer_layer(client_module, - self.module, - triangular_masking=self.triangular_masking, - policy=injection_policy, - mp_size=self.mp_world_size, - mp_group=self.mp_group, - ep_group=self.ep_group, - expert_mp_group=self.expert_mp_group, - config=self.config, - fp16=(self.dtype == torch.half) - or (self.dtype == torch.int8), - training=False, - return_tuple=return_tuple, - quantize=(self.dtype == torch.int8), - quantize_settings=(self.quantization_scales, - self.quantize_merge_count, - self.mlp_extra_grouping, - self.quantize_groups), - replace_with_kernel_inject=replace_with_kernel_inject, - moe=moe, - moe_experts=moe_experts, - moe_type=moe_type, - training_mp_size=training_mp_size, - checkpoint_dict=checkpoint, - save_mp_checkpoint_path=save_mp_checkpoint_path, - base_dir=base_dir, - enable_cuda_graph=self.enable_cuda_graph) + + generic_injection(self.module) + + if isinstance(self.module, torch.nn.Module): + replace_transformer_layer( + client_module, + self.module, + 
triangular_masking=self.triangular_masking, + policy=injection_policy, + mp_size=self.mp_world_size, + mp_group=self.mp_group, + ep_group=self.ep_group, + expert_mp_group=self.expert_mp_group, + config=self.config, + fp16=(self.dtype == torch.half) or (self.dtype == torch.int8), + training=False, + return_tuple=return_tuple, + quantize=(self.dtype == torch.int8), + quantize_settings=(self.quantization_scales, + self.quantize_merge_count, + self.mlp_extra_grouping, + self.quantize_groups), + replace_with_kernel_inject=replace_with_kernel_inject, + moe=moe, + moe_experts=moe_experts, + moe_type=moe_type, + training_mp_size=training_mp_size, + checkpoint_dict=checkpoint, + save_mp_checkpoint_path=save_mp_checkpoint_path, + base_dir=base_dir, + enable_cuda_graph=self.enable_cuda_graph) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, @@ -478,6 +484,9 @@ def _choose_module_key(self, sd): return 'model' def _convert_to_dtype(self): + if not isinstance(self.module, torch.nn.Module): + return + if False: #self.dtype is torch.int8 and self.quantization_scales is None: quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping) model, self.quantization_scales = quantizer.model_quantize(self.module, diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index ea0e13726316..b730aabb210c 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -4,7 +4,7 @@ import deepspeed import deepspeed.ops.transformer as transformer_inference from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy -from .replace_policy import replace_policies +from .replace_policy import replace_policies, generic_policies #from ..runtime.weight_quantizer import WeightQuantization from deepspeed import comm as dist from torch import nn @@ -187,6 +187,27 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): return out +def _module_match(module): + for policy in generic_policies: + policy = policy() + if policy.match(module): + return policy + return None + + +def generic_injection(module): + if isinstance(module, torch.nn.Module): + pass + else: + for name in module.__dict__.keys(): + sub_module = getattr(module, name) + policy = _module_match(sub_module) + if policy is not None: + new_module = policy.apply(sub_module) + print(f"**** found and replaced {name} w. 
{type(new_module)}") + setattr(module, name, new_module) + + def replace_transformer_layer(orig_layer_impl, model, policy=None, @@ -251,6 +272,7 @@ def replace_transformer_layer(orig_layer_impl, Returns: Updated nn.module with replaced transformer layers """ + mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, mp_size=mp_size) #, out_dim=0, in_dim=1) diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index cb7c4818961a..5c47fa71213e 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -1,3 +1,6 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' from abc import ABC import torch @@ -10,6 +13,30 @@ class DSPolicy(ABC): + _orig_layer_class = None + + def __init__(self): + self.cuda_graph_supported = False + + +class UNetPolicy(DSPolicy): + def __init__(self): + super().__init__() + try: + import diffusers + self._orig_layer_class = diffusers.models.unet_2d_condition.UNet2DConditionModel + except ImportError: + self._orig_layer_class = None + + def match(self, module): + return isinstance(module, self._orig_layer_class) + + def apply(self, module): + from .unet import DSUNet + return DSUNet(module) + + +class TransformerPolicy(DSPolicy): # a static class variable containing the HuggingFace model configuration. # see e.g., transformers.models.opt.configuration_opt.OPTConfig hf_model_config = None @@ -24,7 +51,7 @@ def __init__( mlp_act_func_type=ActivationFuncType.GELU, # applies layer norm before attention if `pre_attn_norm` is set to True pre_attn_norm=True): - self.cuda_graph_supported = False + super().__init__() self.inference = inference self.linear_layer = linear_layer self.scale_attention = scale_attention @@ -63,9 +90,7 @@ def layerNorm(self): raise NotImplementedError -class HFBertLayerPolicy(DSPolicy): - _orig_layer_class = None - +class HFBertLayerPolicy(TransformerPolicy): def __init__(self, client_module, inference=False): super().__init__(inference, pre_attn_norm=False) self.client_module = client_module @@ -127,9 +152,7 @@ def layerNorm(self): transformer_layernorm.bias -class HFGPTNEOLayerPolicy(DSPolicy): - _orig_layer_class = None - +class HFGPTNEOLayerPolicy(TransformerPolicy): def __init__(self, client_module, inference=True): super().__init__(inference, scale_attention=False) self.client_module = client_module @@ -172,7 +195,7 @@ def layerNorm(self): self.client_module.ln_1.bias -class HFGPTJLayerPolicy(DSPolicy): +class HFGPTJLayerPolicy(TransformerPolicy): _orig_layer_class = None def __init__(self, client_module, inference=True): @@ -217,7 +240,7 @@ def layerNorm(self): self.client_module.ln_1.bias -class MegatronLayerPolicy(DSPolicy): +class MegatronLayerPolicy(TransformerPolicy): _orig_layer_class = None version = 0 moe_type = 'standard' @@ -297,7 +320,7 @@ def layerNorm(self): self.client_module.input_layernorm.bias -class HFGPT2LayerPolicy(DSPolicy): +class HFGPT2LayerPolicy(TransformerPolicy): _orig_layer_class = None def __init__(self, client_module, inference=True): @@ -337,7 +360,7 @@ def layerNorm(self): self.client_module.ln_1.bias -class BLOOMLayerPolicy(DSPolicy): +class BLOOMLayerPolicy(TransformerPolicy): _orig_layer_class = None def __init__(self, client_module, inference=True): @@ -379,7 +402,7 @@ def layerNorm(self): self.client_module.input_layernorm.bias -class GPTNEOXLayerPolicy(DSPolicy): +class GPTNEOXLayerPolicy(TransformerPolicy): _orig_layer_class = None version = 0 @@ -433,7 +456,7 @@ def layerNorm(self): 
self.client_module.input_layernorm.bias -class HFOPTLayerPolicy(DSPolicy): +class HFOPTLayerPolicy(TransformerPolicy): _orig_layer_class = None def __init__(self, client_module, inference=True): @@ -490,6 +513,7 @@ def layerNorm(self): self.client_module.self_attn_layer_norm.bias +# transformer-based policies replace_policies = [ HFBertLayerPolicy, HFGPTNEOLayerPolicy, @@ -500,3 +524,6 @@ def layerNorm(self): BLOOMLayerPolicy, HFOPTLayerPolicy, ] + +# non-transformer-based policies +generic_policies = [UNetPolicy] diff --git a/deepspeed/module_inject/unet.py b/deepspeed/module_inject/unet.py new file mode 100644 index 000000000000..7eb1e6a68f1b --- /dev/null +++ b/deepspeed/module_inject/unet.py @@ -0,0 +1,36 @@ +''' +Copyright 2022 The Microsoft DeepSpeed Team +''' +import torch +import diffusers + + +class DSUNet(torch.nn.Module): + def __init__(self, unet): + super().__init__() + self.unet = unet + # SD pipeline accesses this attribute + self.in_channels = unet.in_channels + self._traced_unet = None + self._enabled = True + self.device = self.unet.device + self.unet.requires_grad_(requires_grad=False) + + def forward(self, sample, timestamp, encoder_hidden_states, return_dict=True): + if self._enabled: + if self._traced_unet is None: + # boosts perf ~10% + self.unet.to(memory_format=torch.channels_last) + + # force return tuple instead of dict + self._traced_unet = torch.jit.trace( + lambda _sample, _timestamp, _encoder_hidden_states: self.unet(_sample, _timestamp, _encoder_hidden_states, return_dict=False), + (sample, timestamp, encoder_hidden_states) + ) + return self.unet(sample, timestamp, encoder_hidden_states) + else: + # convert return type to UNet2DConditionOutput + out_sample, *_ = self._traced_unet(sample, timestamp, encoder_hidden_states) + return diffusers.models.unet_2d_condition.UNet2DConditionOutput(out_sample) + else: + return self.unet(sample, timestamp, encoder_hidden_states, return_dict) \ No newline at end of file diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py index 043d2626d43c..16f59c4fe70b 100644 --- a/deepspeed/moe/utils.py +++ b/deepspeed/moe/utils.py @@ -6,6 +6,7 @@ def has_moe_layers(m): has_moe = False num_experts = 0 + for _, module in m.named_modules(): if isinstance(module, MoE): has_moe = True From c8163e28a1c2771ca4247071421a8e9154f016bb Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 4 Oct 2022 17:09:00 -0700 Subject: [PATCH 02/20] formatting --- deepspeed/module_inject/unet.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/deepspeed/module_inject/unet.py b/deepspeed/module_inject/unet.py index 7eb1e6a68f1b..8d3b1de2a4ab 100644 --- a/deepspeed/module_inject/unet.py +++ b/deepspeed/module_inject/unet.py @@ -24,13 +24,20 @@ def forward(self, sample, timestamp, encoder_hidden_states, return_dict=True): # force return tuple instead of dict self._traced_unet = torch.jit.trace( - lambda _sample, _timestamp, _encoder_hidden_states: self.unet(_sample, _timestamp, _encoder_hidden_states, return_dict=False), - (sample, timestamp, encoder_hidden_states) - ) + lambda _sample, + _timestamp, + _encoder_hidden_states: self.unet(_sample, + _timestamp, + _encoder_hidden_states, + return_dict=False), + (sample, + timestamp, + encoder_hidden_states)) return self.unet(sample, timestamp, encoder_hidden_states) else: # convert return type to UNet2DConditionOutput out_sample, *_ = self._traced_unet(sample, timestamp, encoder_hidden_states) - return diffusers.models.unet_2d_condition.UNet2DConditionOutput(out_sample) + 
return diffusers.models.unet_2d_condition.UNet2DConditionOutput( + out_sample) else: - return self.unet(sample, timestamp, encoder_hidden_states, return_dict) \ No newline at end of file + return self.unet(sample, timestamp, encoder_hidden_states, return_dict) From f34eb5bb4107a8334ec0f04e5f6b5ca447b4bf15 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 7 Oct 2022 23:53:37 +0500 Subject: [PATCH 03/20] add attention kernel and enable cuda-graph for SD models --- csrc/transformer/inference/csrc/gelu.cu | 36 ++++++++++ .../transformer/inference/csrc/pt_binding.cpp | 71 ++++++++++++++++-- .../inference/includes/inference_context.h | 4 +- .../includes/inference_cuda_layers.h | 3 + deepspeed/inference/engine.py | 2 +- deepspeed/module_inject/replace_module.py | 58 ++++++++++++++- deepspeed/module_inject/replace_policy.py | 72 +++++++++++++++++++ deepspeed/module_inject/unet.py | 48 +++++++++++-- deepspeed/ops/transformer/__init__.py | 1 + .../ops/transformer/inference/__init__.py | 1 + .../inference/transformer_inference.py | 6 +- 11 files changed, 287 insertions(+), 15 deletions(-) diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index e648e13095ef..949b736592df 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -475,3 +475,39 @@ template void launch_moe_res_matmul(__half* residual, int seq_len, int hidden_dim, cudaStream_t stream); + +__global__ void pad_data_kernel(__half* padded_output, __half* output, int head_size, int padded_head_size) +{ + float4 *padded_output_cast = reinterpret_cast(padded_output); + float4 *output_cast = reinterpret_cast(output); + int bid = blockIdx.x * (blockDim.y) + threadIdx.y; + int idx = threadIdx.x; + padded_output_cast += (bid * padded_head_size); + output_cast += (bid * head_size); + float4 ZERO; + const __half2 zero_h = __float2half2_rn(0.f); + __half2 *ZERO_h = reinterpret_cast<__half2*>(&ZERO); + #pragma unroll + for (int i = 0;i < 4;i++) + ZERO_h[i] = zero_h; + if (idx < head_size) + padded_output_cast[idx] = output_cast[idx]; + else + padded_output_cast[idx] = ZERO; +} +__global__ void pad_data_kernel(float* padded_output, float* output, int head_size, int padded_head_size) +{ +} +template +void pad_data(T* padded_output, T* output, int bsz, int head_size, int padded_head_size, + cudaStream_t stream) +{ + + dim3 grid_dim((bsz-1) / 16 + 1); + dim3 block_dim(padded_head_size / 8, 16); + pad_data_kernel<<>>(padded_output, output, head_size / 8, padded_head_size / 8); +} +template void pad_data(__half* padded_output, __half* output, int bsz, int head_size, int padded_head_size, + cudaStream_t stream); +template void pad_data(float* padded_output, float* output, int bsz, int head_size, int padded_head_size, + cudaStream_t stream); diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 65549cdcd71a..d7f1d7a76f3f 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -835,6 +835,10 @@ template at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias, + bool add_bias, + bool external_cache, + bool do_flash_attn, + int num_heads, unsigned num_layers) { auto input_cont = input.contiguous(); @@ -844,13 +848,14 @@ at::Tensor ds_linear_layer(at::Tensor& input, .device(at::kCUDA) .requires_grad(false); + int head_size = input_cont.size(2) / num_heads; int bsz = input.size(0) * input.size(1); T* workspace = 
(T*)Context::Instance().GetWorkSpace(); // Reallocate memory if we received a new prompt - if (!workspace || input.size(1) != 1) { + if (!workspace) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), input.size(0), num_layers); + allocate_workspace(input.size(2), input.size(0), num_layers, 1, external_cache); workspace = (T*)Context::Instance().GetWorkSpace(); } auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); @@ -875,14 +880,68 @@ at::Tensor ds_linear_layer(at::Tensor& input, #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif - - launch_bias_add((T*)output.data_ptr(), + if (add_bias) + launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz, Context::Instance().GetCurrentStream()); - - return output; + bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0); + if (do_flash_attn) + { + if (add_padding){ + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); + auto padded_output = workspace + output.numel(); + auto final_output = padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size); + pad_data(padded_output, workspace, 3*bsz * num_heads, head_size, padded_head_size, Context::Instance().GetCurrentStream()); + + launch_bias_add_transform_0213(final_output, + final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size), + final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size), + padded_output, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + (num_heads * padded_head_size), + num_heads, + -1, + false, + false, + Context::Instance().GetCurrentStream(), + 3, + input.size(1)); + return at::from_blob(final_output, {3, input.size(0), num_heads, input.size(1), padded_head_size}, options); + //return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads, padded_head_size}, options); + } + else + { + auto final_output = workspace + output.numel(); + launch_bias_add_transform_0213(final_output, + final_output + (input.size(0) * input.size(1) * input_cont.size(2)), + final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)), + workspace, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + input_cont.size(2), + num_heads, + -1, + false, + false, + Context::Instance().GetCurrentStream(), + 3, + input.size(1)); + return at::from_blob(final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options); + //return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads, head_size}, options); + } + + } + else + return output; } template diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index 6725fc72fb7c..e535972d2be3 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -98,15 +98,15 @@ class Context { size_t total_size; if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } - size_t activation_size = 16 * hidden_dim * batch_size; + size_t activation_size = 16 * hidden_dim * batch_size + MAX_OUT_TOKENS * 16 * batch_size; size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; _max_seq_len = (((_free_memory_size - (_free_memory_size > GIGABYTE ? 
500 : 100) * MEGABYTE) / elem_size)) / (activation_size + cache_size); + _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len); size_t workSpaceSize = (external_cache ? activation_size : (activation_size + cache_size)) * _max_seq_len * elem_size; - _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len); if (rank == 0 && !_workspace) printf( "Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total " diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 1f86e2d858d1..a10448d9d6e2 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -172,3 +172,6 @@ void launch_bias_add_transform_0213(T* outputs, cudaStream_t stream, int trans_count, int max_out_tokens); +template +void pad_data(T* padded_output, T* output, int bsz, int head_size, int padded_head_size, + cudaStream_t stream); \ No newline at end of file diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 2c44648a2091..b716a20d1cb0 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -360,7 +360,7 @@ def _apply_injection_policy(self, checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None - generic_injection(self.module) + generic_injection(self.module, fp16=(self.dtype == torch.half) or (self.dtype == torch.int8)) if isinstance(self.module, torch.nn.Module): replace_transformer_layer( diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index b730aabb210c..823533510784 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -195,14 +195,70 @@ def _module_match(module): return None -def generic_injection(module): +def generic_injection(module, fp16=False): + + def replace_attn(child, policy, layer_id): + policy_attn = policy.attention(child) + if policy_attn is None: + return child + if len(policy_attn) == 5: + qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn + else: + qw, kvw, attn_ow, attn_ob, hidden_size, heads = policy_attn + + config = transformer_inference.DeepSpeedInferenceConfig( + hidden_size=hidden_size, + heads=heads, + fp16=fp16, + triangular_masking=False,) + attn_module = transformer_inference.DeepSpeedAttention(config) + def transpose(data): + data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1)) + data = data.reshape(data.shape[-1], data.shape[-2]) + data.to(torch.cuda.current_device()) + return data + if len(policy_attn) == 5: + attn_module.attn_qkvw.data = transpose(qkvw.data) + else: + attn_module.attn_qkvw = None + attn_module.attn_qw.data = transpose(qw.data) + attn_module.attn_kvw.data = transpose(kvw.data) + + attn_module.attn_qkvb = None + attn_module.attn_ow.data = transpose(attn_ow.data) + attn_module.attn_ob.data.copy_(attn_ob.data.to(torch.cuda.current_device())) + return attn_module + if isinstance(module, torch.nn.Module): pass else: + try: + import diffusers + cross_attention = diffusers.models.attention.CrossAttention + new_policies = {cross_attention: replace_attn} + except ImportError: + new_policies = {} + + #replace_transformer_layer(None, module.text_encoder, training=False, + # replace_with_kernel_inject=True, + # triangular_masking=True) for name in module.__dict__.keys(): sub_module = getattr(module, name) policy = _module_match(sub_module) + if policy is not None: + def _replace_module(module, 
policy, layer_id=0): + for name, child in module.named_children(): + if child.__class__ in new_policies: + replaced_module = new_policies[child.__class__](child, + policy, + layer_id) + setattr(module, name, replaced_module) + layer_id += 1 + else: + layer_id = _replace_module(child, policy, layer_id=layer_id) + return layer_id + _replace_module(sub_module, policy) new_module = policy.apply(sub_module) print(f"**** found and replaced {name} w. {type(new_module)}") setattr(module, name, new_module) diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 5c47fa71213e..f02e81439274 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -18,6 +18,13 @@ class DSPolicy(ABC): def __init__(self): self.cuda_graph_supported = False + def attention(self): + """ + Returns attention qkv and dense parameters + weight: (3*hidden, hidden) and (hidden, hidden) + bias: (3*hidden) and (hidden) + """ + raise NotImplementedError class UNetPolicy(DSPolicy): def __init__(self): @@ -34,6 +41,21 @@ def match(self, module): def apply(self, module): from .unet import DSUNet return DSUNet(module) + def attention(self, client_module): + qw = client_module.to_q.weight + kw = client_module.to_k.weight + vw = client_module.to_v.weight + + if qw.shape[1] == kw.shape[1]: + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) + + return qkvw, \ + client_module.to_out[0].weight, \ + client_module.to_out[0].bias, \ + qw.shape[-1], \ + client_module.heads + else: + return None class TransformerPolicy(DSPolicy): @@ -151,6 +173,55 @@ def layerNorm(self): transformer_layernorm.weight, \ transformer_layernorm.bias +class HFCLIPLayerPolicy(TransformerPolicy): + def __init__(self, client_module, inference=False): + super().__init__(inference, pre_attn_norm=True, scale_attention=False) + self.client_module = client_module + self.cuda_graph_supported = True + + if HFCLIPLayerPolicy._orig_layer_class is None: + try: + import transformers + HFCLIPLayerPolicy._orig_layer_class = transformers.models.clip.modeling_clip.CLIPEncoderLayer + except: + HFCLIPLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.self_attn.q_proj.weight.shape[1], \ + self.client_module.self_attn.num_heads + + def attention(self): + qw = self.client_module.self_attn.q_proj.weight + qb = self.client_module.self_attn.q_proj.bias + kw = self.client_module.self_attn.k_proj.weight + kb = self.client_module.self_attn.k_proj.bias + vw = self.client_module.self_attn.v_proj.weight + vb = self.client_module.self_attn.v_proj.bias + + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) + qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False) + + return self.linear_layer, \ + qkvw, \ + qkvb, \ + self.client_module.self_attn.out_proj.weight, \ + self.client_module.self_attn.out_proj.bias, \ + self.scale_attention, \ + self.is_megatron_v2 + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.fc1.weight, \ + self.client_module.mlp.fc1.bias, \ + self.client_module.mlp.fc2.weight, \ + self.client_module.mlp.fc2.bias + + def layerNorm(self): + return self.client_module.layer_norm2.weight, \ + self.client_module.layer_norm2.bias, \ + self.client_module.layer_norm1.weight, \ + self.client_module.layer_norm1.bias + class HFGPTNEOLayerPolicy(TransformerPolicy): def __init__(self, client_module, inference=True): @@ -523,6 +594,7 @@ def layerNorm(self): HFGPT2LayerPolicy, 
BLOOMLayerPolicy, HFOPTLayerPolicy, + HFCLIPLayerPolicy, ] # non-transformer-based policies diff --git a/deepspeed/module_inject/unet.py b/deepspeed/module_inject/unet.py index 8d3b1de2a4ab..fb0fe3eb8a61 100644 --- a/deepspeed/module_inject/unet.py +++ b/deepspeed/module_inject/unet.py @@ -12,15 +12,55 @@ def __init__(self, unet): # SD pipeline accesses this attribute self.in_channels = unet.in_channels self._traced_unet = None - self._enabled = True + self._trace_enabled = False self.device = self.unet.device + self.fwd_count = 0 self.unet.requires_grad_(requires_grad=False) + self.unet.to(memory_format=torch.channels_last) + self.cuda_graph_created = False - def forward(self, sample, timestamp, encoder_hidden_states, return_dict=True): - if self._enabled: + def _graph_replay(self, *inputs, **kwargs): + for i in range(len(inputs)): + if torch.is_tensor(inputs[i]): + self.static_inputs[i].copy_(inputs[i]) + for k in kwargs: + if torch.is_tensor(kwargs[k]): + self.static_kwargs[k].copy_(kwargs[k]) + self._cuda_graphs.replay() + return self.static_output + + def forward(self, *inputs, **kwargs): + if self.cuda_graph_created: + outputs = self._graph_replay(*inputs, **kwargs) + else: + self._create_cuda_graph(*inputs, **kwargs) + outputs = self._graph_replay(*inputs, **kwargs) + return outputs + + def _create_cuda_graph(self, *inputs, **kwargs): + # warmup to create the workspace and cublas handle + cuda_stream = torch.cuda.Stream() + cuda_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(cuda_stream): + for i in range(3): + ret = self._forward(*inputs, **kwargs) + torch.cuda.current_stream().wait_stream(cuda_stream) + + # create cuda_graph and assign static_inputs and static_outputs + self._cuda_graphs = torch.cuda.CUDAGraph() + self.static_inputs = inputs + self.static_kwargs = kwargs + + with torch.cuda.graph(self._cuda_graphs): + self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) + + self.cuda_graph_created = True + + def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True): + if self._trace_enabled: if self._traced_unet is None: # boosts perf ~10% - self.unet.to(memory_format=torch.channels_last) + print("Unet: start tracing with Nvfuser") # force return tuple instead of dict self._traced_unet = torch.jit.trace( diff --git a/deepspeed/ops/transformer/__init__.py b/deepspeed/ops/transformer/__init__.py index 28c8de68dd8b..a2be76b4b7b2 100755 --- a/deepspeed/ops/transformer/__init__.py +++ b/deepspeed/ops/transformer/__init__.py @@ -1,3 +1,4 @@ from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from .inference.transformer_inference import DeepSpeedTransformerInference, DeepSpeedInferenceConfig from .inference.moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference +from .inference.attention import DeepSpeedAttention \ No newline at end of file diff --git a/deepspeed/ops/transformer/inference/__init__.py b/deepspeed/ops/transformer/inference/__init__.py index 66f8124015cb..a2f37b2347c5 100644 --- a/deepspeed/ops/transformer/inference/__init__.py +++ b/deepspeed/ops/transformer/inference/__init__.py @@ -1,2 +1,3 @@ from .transformer_inference import DeepSpeedTransformerInference, DeepSpeedInferenceConfig from .moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference +from .attention import DeepSpeedAttention \ No newline at end of file diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py 
b/deepspeed/ops/transformer/inference/transformer_inference.py index ff50d43d8518..cb9a8f964368 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -408,10 +408,14 @@ def selfAttention_fp(): if not config.pre_layer_norm: linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ inference_cuda_module.linear_layer_fp32 - + qkv_out = linear_func(input, attn_qkvw, attn_qkvb, + attn_qkvb is None, + False, + False, + num_attention_heads_per_partition, DeepSpeedTransformerInference.layer_id) else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ From 5aec40c270639b851af3753c79dd40375b3c3303 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 8 Oct 2022 00:23:03 +0500 Subject: [PATCH 04/20] add new files --- .../ops/transformer/inference/attention.py | 225 ++++++++++++++++++ .../ops/transformer/inference/triton_ops.py | 148 ++++++++++++ 2 files changed, 373 insertions(+) create mode 100644 deepspeed/ops/transformer/inference/attention.py create mode 100644 deepspeed/ops/transformer/inference/triton_ops.py diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py new file mode 100644 index 000000000000..ce470badbb99 --- /dev/null +++ b/deepspeed/ops/transformer/inference/attention.py @@ -0,0 +1,225 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' +import json +import math +import torch +from torch.autograd import Function +from ... import op_builder +import torch.nn as nn +from deepspeed import comm as dist +from deepspeed.utils.logging import log_dist +from deepspeed.utils.types import ActivationFuncType +from .triton_ops import triton_flash_attn +# Cuda modules will be imported if needed +inference_cuda_module = None +minus_inf = -10000.0 + +class DeepSpeedAttentionFunction(Function): + @staticmethod + def forward(ctx, + input, + input_mask, + config, + attn_qkvw, + attn_qkvb, + num_attention_heads_per_partition, + norm_factor, + hidden_size_per_partition, + attn_ow, + attn_ob, + score_context_func, + linear_func, + triton_flash_attn_kernel): + def _transpose_for_context(x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_layer_shape = x.size()[:-2] + \ + (hidden_size_per_partition,) + return x.view(*new_x_layer_shape).contiguous() + + def compute_attention(qkv_out, input_mask): + no_masking = input_mask is None + + head_size = (qkv_out.shape[-1] // 3 // num_attention_heads_per_partition) + if no_masking: + input_mask = torch.empty(1) + + context_layer, _, _ = score_context_func( + qkv_out, + ((1 - input_mask).to(qkv_out.dype) * + minus_inf) if input_mask.dtype == torch.int64 else input_mask, + config.rotary_dim, + config.rotate_half, + config.rotate_every_two, + num_attention_heads_per_partition, + (1 / norm_factor if config.scale_attention else 1.0), + config.triangular_masking, + config.local_attention, + config.window_size, + no_masking, + config.layer_id, + DeepSpeedAttention.layer_id, + torch.empty(1)) + return context_layer + + def selfAttention_fp(input, input_mask): + if config.fp16 and input.dtype == torch.float32: + input = input.half() + head_size = input.shape[-1] // config.heads + do_flash_attn = (input.shape[-2] % 128 == 0) and (head_size <= 128) + qkv_out = linear_func(input, + attn_qkvw, + attn_qkvb if attn_qkvb is not None else attn_qkvw, + attn_qkvb is not None, + True, + do_flash_attn, + config.heads, + DeepSpeedAttention.layer_id) + if do_flash_attn: + scale = (1 / norm_factor) * 1 
/ norm_factor + context_layer = triton_flash_attn_kernel(qkv_out[0], qkv_out[1], qkv_out[2], scale) + context_layer = _transpose_for_context(context_layer[:,:,:,:head_size]) + else: + context_layer = compute_attention(qkv_out, input_mask) + + output = linear_func(context_layer, + attn_ow, + attn_ob, + attn_ob is not None, + True, + False, + config.heads, + DeepSpeedAttention.layer_id) + return output + output = selfAttention_fp(input, input_mask) + + return output + + @staticmethod + def backward(ctx, grad_output, grad_output1, grad_output2, grad_output3): + raise RuntimeError('You are running with DeepSpeed Inference mode. \ + Please switch to Training mode for running backward!') + + +class DeepSpeedAttention(nn.Module): + """Initialize the DeepSpeed Transformer Layer. + Arguments: + layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers, + layer_id will be 0,1,2...23 when each layer object is instantiated + config: An object of DeepSpeedInferenceConfig + """ + layer_id = 0 + + def __init__(self, + config,): + super(DeepSpeedAttention, self).__init__() + + self.config = config + self.config.layer_id = DeepSpeedAttention.layer_id + DeepSpeedAttention.layer_id += 1 + device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 + + data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float + data_type_fp = torch.half if config.fp16 else torch.float + global inference_cuda_module + if inference_cuda_module is None: + builder = op_builder.InferenceBuilder() + inference_cuda_module = builder.load() + + if DeepSpeedAttention.layer_id == 1: + log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0]) + + self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, + qkv_size_per_partition, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, + dtype=data_type_fp, + device=device), + requires_grad=False) + out_size_per_partition = self.config.hidden_size // self.config.mp_size + self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + + self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) + self.triton_flash_attn_kernel = triton_flash_attn() + self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size + self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size + self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads + + self.norm_factor = math.sqrt( + math.sqrt(self.config.hidden_size // self.config.heads)) + + self.score_context_func = inference_cuda_module.softmax_context_fp32 if (not config.fp16) else \ + inference_cuda_module.softmax_context_fp16 + self.linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ + inference_cuda_module.linear_layer_fp32 + self.cuda_graph_created = False + + def _graph_replay(self, *inputs, **kwargs): + for i in range(len(inputs)): + if torch.is_tensor(inputs[i]): + self.static_inputs[i].copy_(inputs[i]) + for k in kwargs: + if torch.is_tensor(kwargs[k]): + self.static_kwargs[k].copy_(kwargs[k]) + self._cuda_graphs.replay() + return self.static_output + + def _create_cuda_graph(self, *inputs, **kwargs): + # warmup to create the workspace and cublas handle + 
cuda_stream = torch.cuda.Stream() + cuda_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(cuda_stream): + for i in range(3): + ret = self._forward(*inputs, **kwargs) + torch.cuda.current_stream().wait_stream(cuda_stream) + + # create cuda_graph and assign static_inputs and static_outputs + self._cuda_graphs = torch.cuda.CUDAGraph() + self.static_inputs = inputs + self.static_kwargs = kwargs + + with torch.cuda.graph(self._cuda_graphs): + self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) + + self.cuda_graph_created = True + + def forward(self, *inputs, **kwargs): + if False: + if self.cuda_graph_created: + outputs = self._graph_replay(*inputs, **kwargs) + else: + self._create_cuda_graph(*inputs, **kwargs) + outputs = self._graph_replay(*inputs, **kwargs) + else: + outputs = self._forward(*inputs, **kwargs) + return outputs + + def _forward(self, + input, + input_mask=None): + output = DeepSpeedAttentionFunction.apply( + input, + input_mask, + self.config, + self.attn_qkvw, + self.attn_qkvb, + self.num_attention_heads_per_partition, + self.norm_factor, + self.hidden_size_per_partition, + self.attn_ow, + self.attn_ob, + self.score_context_func, + self.linear_func, + self.triton_flash_attn_kernel) + + return output diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py new file mode 100644 index 000000000000..20c3c6996274 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton_ops.py @@ -0,0 +1,148 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + TMP, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, N_CTX, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * 
p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +class triton_flash_attn(torch.nn.Module): + def __init__(self, ): + super(triton_flash_attn, self).__init__() + + def forward(self, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], + q.shape[2]), + device=q.device, + dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + tmp, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + num_warps=num_warps, + num_stages=1, + ) + return o \ No newline at end of file From 23674fd81ce13a92dbababc83beedf365f798bff Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 7 Oct 2022 14:56:48 -0700 Subject: [PATCH 05/20] formatting and add dtype to unet --- csrc/transformer/inference/csrc/gelu.cu | 53 +++++--- .../transformer/inference/csrc/pt_binding.cpp | 117 ++++++++++-------- .../includes/inference_cuda_layers.h | 8 +- deepspeed/inference/engine.py | 3 +- deepspeed/module_inject/replace_module.py | 16 ++- deepspeed/module_inject/replace_policy.py | 3 + deepspeed/module_inject/unet.py | 5 +- deepspeed/ops/transformer/__init__.py | 2 +- .../ops/transformer/inference/__init__.py | 2 +- .../ops/transformer/inference/attention.py | 51 ++++---- .../inference/transformer_inference.py | 2 +- .../ops/transformer/inference/triton_ops.py | 4 +- 12 files changed, 153 insertions(+), 113 deletions(-) diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 949b736592df..d386714b03da 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -476,38 +476,55 @@ template void launch_moe_res_matmul(__half* residual, int hidden_dim, cudaStream_t stream); -__global__ void pad_data_kernel(__half* padded_output, __half* output, int head_size, int padded_head_size) +__global__ void pad_data_kernel(__half* padded_output, + __half* output, + int head_size, + int padded_head_size) { - float4 *padded_output_cast = reinterpret_cast(padded_output); - float4 *output_cast = reinterpret_cast(output); + float4* padded_output_cast = reinterpret_cast(padded_output); + float4* output_cast = reinterpret_cast(output); int bid = blockIdx.x * (blockDim.y) + threadIdx.y; int idx = threadIdx.x; padded_output_cast += (bid * padded_head_size); output_cast += (bid * head_size); float4 ZERO; const __half2 zero_h = __float2half2_rn(0.f); - __half2 *ZERO_h = reinterpret_cast<__half2*>(&ZERO); - #pragma unroll - for (int i = 0;i < 4;i++) - ZERO_h[i] = zero_h; + __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); 
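+    // ZERO is one float4 that the loop below fills with eight fp16 zeros; threads
+    // whose index falls past head_size store it, zero-padding the tail of each head vector.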
+#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; if (idx < head_size) padded_output_cast[idx] = output_cast[idx]; else padded_output_cast[idx] = ZERO; } -__global__ void pad_data_kernel(float* padded_output, float* output, int head_size, int padded_head_size) +__global__ void pad_data_kernel(float* padded_output, + float* output, + int head_size, + int padded_head_size) { -} +} template -void pad_data(T* padded_output, T* output, int bsz, int head_size, int padded_head_size, - cudaStream_t stream) +void pad_data(T* padded_output, + T* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream) { - - dim3 grid_dim((bsz-1) / 16 + 1); + dim3 grid_dim((bsz - 1) / 16 + 1); dim3 block_dim(padded_head_size / 8, 16); - pad_data_kernel<<>>(padded_output, output, head_size / 8, padded_head_size / 8); + pad_data_kernel<<>>( + padded_output, output, head_size / 8, padded_head_size / 8); } -template void pad_data(__half* padded_output, __half* output, int bsz, int head_size, int padded_head_size, - cudaStream_t stream); -template void pad_data(float* padded_output, float* output, int bsz, int head_size, int padded_head_size, - cudaStream_t stream); +template void pad_data(__half* padded_output, + __half* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream); +template void pad_data(float* padded_output, + float* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream); diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index d7f1d7a76f3f..acb6fd604c2d 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -882,65 +882,74 @@ at::Tensor ds_linear_layer(at::Tensor& input, #endif if (add_bias) launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0); - if (do_flash_attn) - { - if (add_padding){ + if (do_flash_attn) { + if (add_padding) { int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128); auto padded_output = workspace + output.numel(); - auto final_output = padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size); - pad_data(padded_output, workspace, 3*bsz * num_heads, head_size, padded_head_size, Context::Instance().GetCurrentStream()); - - launch_bias_add_transform_0213(final_output, - final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size), - final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size), - padded_output, - nullptr, - input.size(0), - input.size(1), - 0, - input.size(1), - (num_heads * padded_head_size), - num_heads, - -1, - false, - false, - Context::Instance().GetCurrentStream(), - 3, - input.size(1)); - return at::from_blob(final_output, {3, input.size(0), num_heads, input.size(1), padded_head_size}, options); - //return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads, padded_head_size}, options); - } - else - { + auto final_output = + padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size); + pad_data(padded_output, + workspace, + 3 * bsz * num_heads, + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); + + launch_bias_add_transform_0213( + final_output, + final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size), + final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size), + padded_output, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + (num_heads * padded_head_size), + num_heads, + -1, + false, + false, + Context::Instance().GetCurrentStream(), + 3, + input.size(1)); + return at::from_blob(final_output, + {3, input.size(0), num_heads, input.size(1), padded_head_size}, + options); + // return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads, + // padded_head_size}, options); + } else { auto final_output = workspace + output.numel(); - launch_bias_add_transform_0213(final_output, - final_output + (input.size(0) * input.size(1) * input_cont.size(2)), - final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)), - workspace, - nullptr, - input.size(0), - input.size(1), - 0, - input.size(1), - input_cont.size(2), - num_heads, - -1, - false, - false, - Context::Instance().GetCurrentStream(), - 3, - input.size(1)); - return at::from_blob(final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options); - //return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads, head_size}, options); + launch_bias_add_transform_0213( + final_output, + final_output + (input.size(0) * input.size(1) * input_cont.size(2)), + final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)), + workspace, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + input_cont.size(2), + num_heads, + -1, + false, + false, + Context::Instance().GetCurrentStream(), + 3, + input.size(1)); + return at::from_blob( + final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options); + // return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads, + // head_size}, options); } - - } - else + + } else return output; } diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index a10448d9d6e2..0a502c1a2220 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -173,5 +173,9 @@ void 
launch_bias_add_transform_0213(T* outputs, int trans_count, int max_out_tokens); template -void pad_data(T* padded_output, T* output, int bsz, int head_size, int padded_head_size, - cudaStream_t stream); \ No newline at end of file +void pad_data(T* padded_output, + T* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream); diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index b716a20d1cb0..6f9d90e3f743 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -360,7 +360,8 @@ def _apply_injection_policy(self, checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None - generic_injection(self.module, fp16=(self.dtype == torch.half) or (self.dtype == torch.int8)) + generic_injection(self.module, + fp16=(self.dtype == torch.half) or (self.dtype == torch.int8)) if isinstance(self.module, torch.nn.Module): replace_transformer_layer( diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 823533510784..b6b22ad08d61 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -196,7 +196,6 @@ def _module_match(module): def generic_injection(module, fp16=False): - def replace_attn(child, policy, layer_id): policy_attn = policy.attention(child) if policy_attn is None: @@ -207,16 +206,19 @@ def replace_attn(child, policy, layer_id): qw, kvw, attn_ow, attn_ob, hidden_size, heads = policy_attn config = transformer_inference.DeepSpeedInferenceConfig( - hidden_size=hidden_size, - heads=heads, - fp16=fp16, - triangular_masking=False,) + hidden_size=hidden_size, + heads=heads, + fp16=fp16, + triangular_masking=False, + ) attn_module = transformer_inference.DeepSpeedAttention(config) + def transpose(data): data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1)) data = data.reshape(data.shape[-1], data.shape[-2]) data.to(torch.cuda.current_device()) return data + if len(policy_attn) == 5: attn_module.attn_qkvw.data = transpose(qkvw.data) else: @@ -245,8 +247,9 @@ def transpose(data): for name in module.__dict__.keys(): sub_module = getattr(module, name) policy = _module_match(sub_module) - + if policy is not None: + def _replace_module(module, policy, layer_id=0): for name, child in module.named_children(): if child.__class__ in new_policies: @@ -258,6 +261,7 @@ def _replace_module(module, policy, layer_id=0): else: layer_id = _replace_module(child, policy, layer_id=layer_id) return layer_id + _replace_module(sub_module, policy) new_module = policy.apply(sub_module) print(f"**** found and replaced {name} w. 
{type(new_module)}") diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index f02e81439274..3924ed9c68f6 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -26,6 +26,7 @@ def attention(self): """ raise NotImplementedError + class UNetPolicy(DSPolicy): def __init__(self): super().__init__() @@ -41,6 +42,7 @@ def match(self, module): def apply(self, module): from .unet import DSUNet return DSUNet(module) + def attention(self, client_module): qw = client_module.to_q.weight kw = client_module.to_k.weight @@ -173,6 +175,7 @@ def layerNorm(self): transformer_layernorm.weight, \ transformer_layernorm.bias + class HFCLIPLayerPolicy(TransformerPolicy): def __init__(self, client_module, inference=False): super().__init__(inference, pre_attn_norm=True, scale_attention=False) diff --git a/deepspeed/module_inject/unet.py b/deepspeed/module_inject/unet.py index fb0fe3eb8a61..9eb2f7303fec 100644 --- a/deepspeed/module_inject/unet.py +++ b/deepspeed/module_inject/unet.py @@ -14,6 +14,7 @@ def __init__(self, unet): self._traced_unet = None self._trace_enabled = False self.device = self.unet.device + self.dtype = self.unet.dtype self.fwd_count = 0 self.unet.requires_grad_(requires_grad=False) self.unet.to(memory_format=torch.channels_last) @@ -34,9 +35,9 @@ def forward(self, *inputs, **kwargs): outputs = self._graph_replay(*inputs, **kwargs) else: self._create_cuda_graph(*inputs, **kwargs) - outputs = self._graph_replay(*inputs, **kwargs) + outputs = self._graph_replay(*inputs, **kwargs) return outputs - + def _create_cuda_graph(self, *inputs, **kwargs): # warmup to create the workspace and cublas handle cuda_stream = torch.cuda.Stream() diff --git a/deepspeed/ops/transformer/__init__.py b/deepspeed/ops/transformer/__init__.py index a2be76b4b7b2..49b543551f4b 100755 --- a/deepspeed/ops/transformer/__init__.py +++ b/deepspeed/ops/transformer/__init__.py @@ -1,4 +1,4 @@ from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from .inference.transformer_inference import DeepSpeedTransformerInference, DeepSpeedInferenceConfig from .inference.moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference -from .inference.attention import DeepSpeedAttention \ No newline at end of file +from .inference.attention import DeepSpeedAttention diff --git a/deepspeed/ops/transformer/inference/__init__.py b/deepspeed/ops/transformer/inference/__init__.py index a2f37b2347c5..f5b042d7fdfb 100644 --- a/deepspeed/ops/transformer/inference/__init__.py +++ b/deepspeed/ops/transformer/inference/__init__.py @@ -1,3 +1,3 @@ from .transformer_inference import DeepSpeedTransformerInference, DeepSpeedInferenceConfig from .moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference -from .attention import DeepSpeedAttention \ No newline at end of file +from .attention import DeepSpeedAttention diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index ce470badbb99..a122c2608750 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ b/deepspeed/ops/transformer/inference/attention.py @@ -1,20 +1,18 @@ ''' Copyright 2020 The Microsoft DeepSpeed Team ''' -import json import math import torch from torch.autograd import Function from ... 
import op_builder import torch.nn as nn -from deepspeed import comm as dist from deepspeed.utils.logging import log_dist -from deepspeed.utils.types import ActivationFuncType from .triton_ops import triton_flash_attn # Cuda modules will be imported if needed inference_cuda_module = None minus_inf = -10000.0 + class DeepSpeedAttentionFunction(Function): @staticmethod def forward(ctx, @@ -77,7 +75,10 @@ def selfAttention_fp(input, input_mask): DeepSpeedAttention.layer_id) if do_flash_attn: scale = (1 / norm_factor) * 1 / norm_factor - context_layer = triton_flash_attn_kernel(qkv_out[0], qkv_out[1], qkv_out[2], scale) + context_layer = triton_flash_attn_kernel(qkv_out[0], + qkv_out[1], + qkv_out[2], + scale) context_layer = _transpose_for_context(context_layer[:,:,:,:head_size]) else: context_layer = compute_attention(qkv_out, input_mask) @@ -91,6 +92,7 @@ def selfAttention_fp(input, input_mask): config.heads, DeepSpeedAttention.layer_id) return output + output = selfAttention_fp(input, input_mask) return output @@ -110,8 +112,10 @@ class DeepSpeedAttention(nn.Module): """ layer_id = 0 - def __init__(self, - config,): + def __init__( + self, + config, + ): super(DeepSpeedAttention, self).__init__() self.config = config @@ -199,27 +203,24 @@ def forward(self, *inputs, **kwargs): outputs = self._graph_replay(*inputs, **kwargs) else: self._create_cuda_graph(*inputs, **kwargs) - outputs = self._graph_replay(*inputs, **kwargs) + outputs = self._graph_replay(*inputs, **kwargs) else: outputs = self._forward(*inputs, **kwargs) return outputs - - def _forward(self, - input, - input_mask=None): - output = DeepSpeedAttentionFunction.apply( - input, - input_mask, - self.config, - self.attn_qkvw, - self.attn_qkvb, - self.num_attention_heads_per_partition, - self.norm_factor, - self.hidden_size_per_partition, - self.attn_ow, - self.attn_ob, - self.score_context_func, - self.linear_func, - self.triton_flash_attn_kernel) + + def _forward(self, input, input_mask=None): + output = DeepSpeedAttentionFunction.apply(input, + input_mask, + self.config, + self.attn_qkvw, + self.attn_qkvb, + self.num_attention_heads_per_partition, + self.norm_factor, + self.hidden_size_per_partition, + self.attn_ow, + self.attn_ob, + self.score_context_func, + self.linear_func, + self.triton_flash_attn_kernel) return output diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index cb9a8f964368..ee68b055fd2a 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -408,7 +408,7 @@ def selfAttention_fp(): if not config.pre_layer_norm: linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ inference_cuda_module.linear_layer_fp32 - + qkv_out = linear_func(input, attn_qkvw, attn_qkvb, diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py index 20c3c6996274..b26108d0cec5 100644 --- a/deepspeed/ops/transformer/inference/triton_ops.py +++ b/deepspeed/ops/transformer/inference/triton_ops.py @@ -100,7 +100,7 @@ def _fwd_kernel( class triton_flash_attn(torch.nn.Module): def __init__(self, ): super(triton_flash_attn, self).__init__() - + def forward(self, q, k, v, sm_scale): BLOCK = 128 # shape constraints @@ -145,4 +145,4 @@ def forward(self, q, k, v, sm_scale): num_warps=num_warps, num_stages=1, ) - return o \ No newline at end of file + return o From 6bf0c738adfc8d62e2b6cf0138d827dde6d3e163 
Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Mon, 10 Oct 2022 23:05:44 +0500
Subject: [PATCH 06/20] adding more optimization by enabling ds-encoder with CUDA-Graph

---
 deepspeed/module_inject/replace_module.py         | 11 ++++++++---
 deepspeed/module_inject/unet.py                   |  2 --
 deepspeed/ops/transformer/inference/attention.py  |  8 +++++---
 deepspeed/ops/transformer/inference/triton_ops.py | 10 +++++-----
 4 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index b6b22ad08d61..b1186f1777f7 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -241,9 +241,14 @@ def transpose(data):
         except ImportError:
             new_policies = {}
 
-        #replace_transformer_layer(None, module.text_encoder, training=False,
-        #                           replace_with_kernel_inject=True,
-        #                           triangular_masking=True)
+        replace_transformer_layer(None,
+                                  module.text_encoder,
+                                  training=False,
+                                  replace_with_kernel_inject=True,
+                                  triangular_masking=True)
+        from .encoder import DSClipEncoder
+        cg_encoder = DSClipEncoder(module.text_encoder)
+        setattr(module, 'text_encoder', cg_encoder)
         for name in module.__dict__.keys():
             sub_module = getattr(module, name)
             policy = _module_match(sub_module)
diff --git a/deepspeed/module_inject/unet.py b/deepspeed/module_inject/unet.py
index 9eb2f7303fec..a667d94e6c83 100644
--- a/deepspeed/module_inject/unet.py
+++ b/deepspeed/module_inject/unet.py
@@ -60,9 +60,7 @@ def _create_cuda_graph(self, *inputs, **kwargs):
 
     def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
         if self._trace_enabled:
             if self._traced_unet is None:
-                # boosts perf ~10%
                 print("Unet: start tracing with Nvfuser")
-                # force return tuple instead of dict
                 self._traced_unet = torch.jit.trace(
                     lambda _sample,
diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py
index a122c2608750..1d1b377bc119 100644
--- a/deepspeed/ops/transformer/inference/attention.py
+++ b/deepspeed/ops/transformer/inference/attention.py
@@ -64,7 +64,7 @@ def selfAttention_fp(input, input_mask):
             if config.fp16 and input.dtype == torch.float32:
                 input = input.half()
             head_size = input.shape[-1] // config.heads
-            do_flash_attn = (input.shape[-2] % 128 == 0) and (head_size <= 128)
+            do_flash_attn = (head_size <= 128)
             qkv_out = linear_func(input,
                                   attn_qkvw,
                                   attn_qkvb if attn_qkvb is not None else attn_qkvw,
@@ -78,7 +78,8 @@ def selfAttention_fp(input, input_mask):
                 context_layer = triton_flash_attn_kernel(qkv_out[0],
                                                          qkv_out[1],
                                                          qkv_out[2],
-                                                         scale)
+                                                         scale,
+                                                         input.shape[-2] % 128 == 0)
                 context_layer = _transpose_for_context(context_layer[:,:,:,:head_size])
             else:
                 context_layer = compute_attention(qkv_out, input_mask)
@@ -167,6 +168,7 @@ def __init__(
         self.linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \
                     inference_cuda_module.linear_layer_fp32
         self.cuda_graph_created = False
+        self.enable_cuda_graph = False
 
     def _graph_replay(self, *inputs, **kwargs):
         for i in range(len(inputs)):
@@ -198,7 +200,7 @@ def _create_cuda_graph(self, *inputs, **kwargs):
         self.cuda_graph_created = True
 
     def forward(self, *inputs, **kwargs):
-        if False:
+        if self.enable_cuda_graph:
             if self.cuda_graph_created:
                 outputs = self._graph_replay(*inputs, **kwargs)
             else:
diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py
index b26108d0cec5..b0be8a1fefdc 100644
--- a/deepspeed/ops/transformer/inference/triton_ops.py
+++ 
b/deepspeed/ops/transformer/inference/triton_ops.py @@ -101,8 +101,8 @@ class triton_flash_attn(torch.nn.Module): def __init__(self, ): super(triton_flash_attn, self).__init__() - def forward(self, q, k, v, sm_scale): - BLOCK = 128 + def forward(self, q, k, v, sm_scale, block_128=True): + BLOCK = 128 if block_128 else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] o = torch.empty_like(q) @@ -136,9 +136,9 @@ def forward(self, q, k, v, sm_scale): o.stride(1), o.stride(2), o.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], + k.shape[0], + k.shape[1], + k.shape[2], BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, From ac29fdcc4604e6f1829fd16f806ea2b8075f0ee8 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Mon, 10 Oct 2022 23:31:11 +0500 Subject: [PATCH 07/20] add missing file --- deepspeed/module_inject/encoder.py | 60 ++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 deepspeed/module_inject/encoder.py diff --git a/deepspeed/module_inject/encoder.py b/deepspeed/module_inject/encoder.py new file mode 100644 index 000000000000..b9cd64e0744b --- /dev/null +++ b/deepspeed/module_inject/encoder.py @@ -0,0 +1,60 @@ +''' +Copyright 2022 The Microsoft DeepSpeed Team +''' +import torch +import diffusers + + +class DSClipEncoder(torch.nn.Module): + def __init__(self, enc): + super().__init__() + enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask + self.enc = enc + self.device = self.enc.device + self.dtype = self.enc.dtype + self.cuda_graph_created = False + + def _build_causal_attention_mask(self, bsz, seq_len, dtype): + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=torch.cuda.current_device()) + return mask + + def _graph_replay(self, *inputs, **kwargs): + for i in range(len(inputs)): + if torch.is_tensor(inputs[i]): + self.static_inputs[i].copy_(inputs[i]) + for k in kwargs: + if torch.is_tensor(kwargs[k]): + self.static_kwargs[k].copy_(kwargs[k]) + self._cuda_graphs.replay() + return self.static_output + + def forward(self, *inputs, **kwargs): + if self.cuda_graph_created: + outputs = self._graph_replay(*inputs, **kwargs) + else: + self._create_cuda_graph(*inputs, **kwargs) + outputs = self._graph_replay(*inputs, **kwargs) + return outputs + + def _create_cuda_graph(self, *inputs, **kwargs): + # warmup to create the workspace and cublas handle + cuda_stream = torch.cuda.Stream() + cuda_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(cuda_stream): + for i in range(3): + ret = self._forward(*inputs, **kwargs) + torch.cuda.current_stream().wait_stream(cuda_stream) + + # create cuda_graph and assign static_inputs and static_outputs + self._cuda_graphs = torch.cuda.CUDAGraph() + self.static_inputs = inputs + self.static_kwargs = kwargs + + with torch.cuda.graph(self._cuda_graphs): + self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) + + self.cuda_graph_created = True + + def _forward(self, *inputs, **kwargs): + + return self.enc(*inputs, **kwargs) From 27e752be563384e318d616d889bbadbf495247c4 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 11 Oct 2022 06:47:03 +0500 Subject: [PATCH 08/20] adapt the triton kernel to be used in more places --- csrc/transformer/inference/csrc/gelu.cu | 65 ++++++++++++++ .../transformer/inference/csrc/pt_binding.cpp | 45 ++++++++++ csrc/transformer/inference/csrc/softmax.cu | 5 +- .../includes/inference_cuda_layers.h | 10 +++ deepspeed/module_inject/replace_policy.py | 9 +- 
.../ops/transformer/inference/attention.py | 90 ++++++++++++++----- .../ops/transformer/inference/triton_ops.py | 6 +- 7 files changed, 201 insertions(+), 29 deletions(-) diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index d386714b03da..43a618bf7812 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -528,3 +528,68 @@ template void pad_data(float* padded_output, int head_size, int padded_head_size, cudaStream_t stream); + +__global__ void pad_head_seq_kernel(__half* padded_output, + __half* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) +{ + float4* padded_output_cast = reinterpret_cast(padded_output); + float4* output_cast = reinterpret_cast(output); + int bsz = blockIdx.x; + int bid = blockIdx.y * (blockDim.y) + threadIdx.y; + int idx = threadIdx.x; + padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size; + output_cast += (bsz * seq_len + bid) * head_size; + float4 ZERO; + const __half2 zero_h = __float2half2_rn(0.f); + __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); +#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + + if (idx < head_size && bid < seq_len) + padded_output_cast[idx] = output_cast[idx]; + else + padded_output_cast[idx] = ZERO; +} +__global__ void pad_head_seq_kernel(float* padded_output, + float* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) +{ +} +template +void pad_head_seq(T* padded_output, + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream) +{ + dim3 grid_dim(bsz, padded_seq_len / 16); + dim3 block_dim(padded_head_size / 8, 16); + pad_head_seq_kernel<<>>( + padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8); +} +template void pad_head_seq(__half* padded_output, + __half* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); +template void pad_head_seq(float* padded_output, + float* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); \ No newline at end of file diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 514cd7b3dc33..099b942abfd4 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -953,6 +953,45 @@ at::Tensor ds_linear_layer(at::Tensor& input, return output; } +template +std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tensor& value) +{ + int head_size = query.size(3); + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128); + T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2); + T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128; + pad_head_seq(workspace, + (T*)query.data_ptr(), + query.size(0) * query.size(1), + query.size(2), + query.size(2), + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); + pad_head_seq(key_pad_ptr, + (T*)key.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); + pad_head_seq(value_pad_ptr, + (T*)value.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); + return { + at::from_blob(workspace, {query.size(0), query.size(1), query.size(2), padded_head_size}, query.options()), + at::from_blob(key_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()), + at::from_blob(value_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()) + }; +} + template at::Tensor ds_linear_layer_int8(at::Tensor& input, at::Tensor& weight, @@ -1493,4 +1532,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) &einsum_sec_sm_ecm<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); + m.def("add_padding_fp32", + &add_padding, + "DeepSpeed residual add with fp32 (CUDA)"); + m.def("add_padding_fp16", + &add_padding<__half>, + "DeepSpeed residual add with fp16 (CUDA)"); } diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index ce7c2e77759d..b85ac1eb0be8 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -12,7 +12,7 @@ Copyright 2022 The Microsoft DeepSpeed Team #include #include -#define ATTN_THREADS 1024 +#define ATTN_THREADS 256 #define MAX_REG_SIZE 8 #define minus_infinity -10000.0 @@ -427,7 +427,8 @@ void launch_attn_softmax_v2(T* vals, cudaStream_t stream) { int total_count = batch_size * heads * num_seq; - dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); + int warp_num = ATTN_THREADS / WARP_SIZE; + dim3 grid_dim((total_count - 1) / (warp_num / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); dim3 block_dim(ATTN_THREADS); const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE; diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 0a502c1a2220..b41a5d618ff0 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -179,3 +179,13 @@ void pad_data(T* padded_output, int head_size, int padded_head_size, cudaStream_t stream); + +template +void pad_head_seq(T* padded_output, + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); \ No newline at end of file diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 3924ed9c68f6..ce8a2770f437 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -57,7 +57,14 @@ def attention(self, client_module): qw.shape[-1], \ client_module.heads else: - return None + #return None + kvw = 
Parameter(torch.cat((kw, vw), dim=0), requires_grad=False) + return qw, \ + kvw, \ + client_module.to_out[0].weight, \ + client_module.to_out[0].bias, \ + qw.shape[-1], \ + client_module.heads class TransformerPolicy(DSPolicy): diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index 1d1b377bc119..713224f3131c 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ b/deepspeed/ops/transformer/inference/attention.py @@ -17,9 +17,12 @@ class DeepSpeedAttentionFunction(Function): @staticmethod def forward(ctx, input, + context, input_mask, config, attn_qkvw, + attn_qw, + attn_kvw, attn_qkvb, num_attention_heads_per_partition, norm_factor, @@ -29,11 +32,20 @@ def forward(ctx, score_context_func, linear_func, triton_flash_attn_kernel): + def _transpose_for_context(x): - x = x.permute(0, 2, 1, 3).contiguous() + x = x.permute(0, 2, 1, 3) new_x_layer_shape = x.size()[:-2] + \ (hidden_size_per_partition,) - return x.view(*new_x_layer_shape).contiguous() + return x.reshape(*new_x_layer_shape) + + def _transpose_for_scores(x): + attention_head_size = x.shape[-1] // num_attention_heads_per_partition + new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, + attention_head_size) + x = x.reshape(*new_x_shape) + x = x.permute(0, 2, 1, 3) + return x.contiguous() def compute_attention(qkv_out, input_mask): no_masking = input_mask is None @@ -59,30 +71,49 @@ def compute_attention(qkv_out, input_mask): DeepSpeedAttention.layer_id, torch.empty(1)) return context_layer - - def selfAttention_fp(input, input_mask): + + def selfAttention_fp(input, context, input_mask): if config.fp16 and input.dtype == torch.float32: input = input.half() head_size = input.shape[-1] // config.heads do_flash_attn = (head_size <= 128) - qkv_out = linear_func(input, - attn_qkvw, - attn_qkvb if attn_qkvb is not None else attn_qkvw, - attn_qkvb is not None, - True, - do_flash_attn, - config.heads, - DeepSpeedAttention.layer_id) - if do_flash_attn: - scale = (1 / norm_factor) * 1 / norm_factor - context_layer = triton_flash_attn_kernel(qkv_out[0], - qkv_out[1], - qkv_out[2], - scale, - input.shape[-2] % 128 == 0) - context_layer = _transpose_for_context(context_layer[:,:,:,:head_size]) + scale = (1 / norm_factor) * (1 / norm_factor) + if context == None: + qkv_out = linear_func(input, + attn_qkvw, + attn_qkvb if attn_qkvb is not None else attn_qkvw, + attn_qkvb is not None, + True, + do_flash_attn, + config.heads, + DeepSpeedAttention.layer_id) + if do_flash_attn: + context_layer = triton_flash_attn_kernel(qkv_out[0], + qkv_out[1], + qkv_out[2], + scale, + input.shape[-2] % 128 == 0) + context_layer = _transpose_for_context(context_layer[:,:,:,:head_size]) + else: + context_layer = compute_attention(qkv_out, input_mask) else: - context_layer = compute_attention(qkv_out, input_mask) + query = torch.matmul(input, attn_qw) + key_value = torch.matmul(context, attn_kvw) + query = _transpose_for_scores(query) + key = _transpose_for_scores(key_value[:,:,:input.shape[-1]]) + value = _transpose_for_scores(key_value[:,:,input.shape[-1]:]) + + if do_flash_attn: + query, key, value = inference_cuda_module.add_padding_fp16(query, key, value) + context_layer = triton_flash_attn_kernel(query, + key, + value, + scale, + input.shape[-2] % 128 == 0) + context_layer = _transpose_for_context(context_layer[:,:,:,:head_size]) + else: + attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1) + context_layer = 
_transpose_for_context(torch.matmul(attention_scores, value)) output = linear_func(context_layer, attn_ow, @@ -94,7 +125,7 @@ def selfAttention_fp(input, input_mask): DeepSpeedAttention.layer_id) return output - output = selfAttention_fp(input, input_mask) + output = selfAttention_fp(input, context, input_mask) return output @@ -140,6 +171,16 @@ def __init__( dtype=data_type, device=device), requires_grad=False) + self.attn_kvw = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.hidden_size * 2, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qw = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device), @@ -210,11 +251,14 @@ def forward(self, *inputs, **kwargs): outputs = self._forward(*inputs, **kwargs) return outputs - def _forward(self, input, input_mask=None): + def _forward(self, input, context=None, input_mask=None): output = DeepSpeedAttentionFunction.apply(input, + context, input_mask, self.config, self.attn_qkvw, + self.attn_qw, + self.attn_kvw, self.attn_qkvb, self.num_attention_heads_per_partition, self.norm_factor, diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py index b0be8a1fefdc..cac9c1aa3476 100644 --- a/deepspeed/ops/transformer/inference/triton_ops.py +++ b/deepspeed/ops/transformer/inference/triton_ops.py @@ -37,14 +37,13 @@ def _fwd_kernel( ): start_m = tl.program_id(0) off_hz = tl.program_id(1) - off_z = off_hz // H # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk # Initialize pointers to Q, K, V q_ptrs = Q + off_q k_ptrs = K + off_k @@ -97,6 +96,7 @@ def _fwd_kernel( tl.store(out_ptrs, acc) + class triton_flash_attn(torch.nn.Module): def __init__(self, ): super(triton_flash_attn, self).__init__() From a40780b525199c934b3fdad0d941ec8c2d8379ed Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 11 Oct 2022 21:36:21 +0500 Subject: [PATCH 09/20] add more fusion --- .../transformer/inference/csrc/pt_binding.cpp | 48 ++++++++ csrc/transformer/inference/csrc/transform.cu | 109 ++++++++++++++++++ .../includes/inference_cuda_layers.h | 13 ++- deepspeed/module_inject/replace_module.py | 5 +- deepspeed/module_inject/replace_policy.py | 4 +- .../ops/transformer/inference/attention.py | 25 ++-- 6 files changed, 189 insertions(+), 15 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 099b942abfd4..f8b15d3f1a58 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -992,6 +992,48 @@ std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tens }; } +template +std::vector padd_add_transform(at::Tensor& query, at::Tensor& key, at::Tensor& value, int heads, bool add_padding) +{ + int head_size = query.size(2) / heads; + int 
key_value_length = add_padding ? 128 : key.size(1); + int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128)) : head_size; + T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1); + T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length; + launch_pad_add_transform_0213(workspace, + (T*)query.data_ptr(), + query.size(0), + query.size(2), + query.size(1), + query.size(1), + heads, + padded_head_size, + Context::Instance().GetCurrentStream()); + launch_pad_add_transform_0213(key_pad_ptr, + (T*)key.data_ptr(), + key.size(0), + key.size(2), + key.size(1), + key_value_length, + heads, + padded_head_size, + Context::Instance().GetCurrentStream()); + launch_pad_add_transform_0213(value_pad_ptr, + (T*)value.data_ptr(), + value.size(0), + value.size(2), + value.size(1), + key_value_length, + heads, + padded_head_size, + Context::Instance().GetCurrentStream()); + return { + at::from_blob(workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()), + at::from_blob(key_pad_ptr, {query.size(0), heads, key_value_length, padded_head_size}, query.options()), + at::from_blob(value_pad_ptr, {query.size(0), heads, key_value_length, padded_head_size}, query.options()) + }; +} template at::Tensor ds_linear_layer_int8(at::Tensor& input, at::Tensor& weight, @@ -1538,4 +1580,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)"); + m.def("pad_transform_fp32", + &padd_add_transform, + "DeepSpeed residual add with fp32 (CUDA)"); + m.def("pad_transform_fp16", + &padd_add_transform<__half>, + "DeepSpeed residual add with fp16 (CUDA)"); } diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 32d2df95be63..93461c766b88 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -249,6 +249,115 @@ void launch_bias_add_transform_0213<__half>(__half* output, max_out_tokens); } + + +// Bias add + +__global__ void pad_add_transform_0213(float* output, + const float* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) +{ + +} + + +__global__ void pad_add_transform_0213(__half* output, + const __half* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) +{ +#if __CUDA_ARCH__ >= 700 + float4 ZERO; + const __half2 zero_h = __float2half2_rn(0.f); + __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); +#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y * blockDim.z + threadIdx.z; // Sequence ID (0-127) + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + int d2_out_stride = padded_head_size * padded_seq_len; + int d0_out_stride = heads * d2_out_stride; + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride); + vals_vec += (d1 * d1_stride); + vals_vec += (d2 * d2_stride); + + output_vec += (d1 * padded_head_size); + output_vec += (d0 * d0_out_stride); + output_vec += (d2 * d2_out_stride); + + if (d3 < d2_stride && d1 < seq_length) + output_vec[d3] = vals_vec[d3]; + else + 
output_vec[d3] = ZERO; + +#endif +} + +template +void launch_pad_add_transform_0213(T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream); + +// [B S C*H] - > C * [B A S N] +template <> +void launch_pad_add_transform_0213(float* output, + const float* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) +{ +} +template <> +void launch_pad_add_transform_0213<__half>(__half* output, + const __half* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) +{ + hidden_dim >>= 3; + dim3 block_dim((padded_head_size >> 3), heads, 2); + dim3 grid_dim(batch_size, padded_seq_len / 2); + pad_add_transform_0213<<>>(output, + vals, + hidden_dim, + seq_length, + padded_seq_len, + heads, + padded_head_size >> 3); +} + // Bias add template __global__ void bias_add_transform_0213(T* output, diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index b41a5d618ff0..540db0800954 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -188,4 +188,15 @@ void pad_head_seq(T* padded_output, int padded_seq_len, int head_size, int padded_head_size, - cudaStream_t stream); \ No newline at end of file + cudaStream_t stream); + +template +void launch_pad_add_transform_0213(T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream); \ No newline at end of file diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 193474e24af6..e652ce0049df 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -203,7 +203,7 @@ def replace_attn(child, policy, layer_id): if len(policy_attn) == 5: qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn else: - qw, kvw, attn_ow, attn_ob, hidden_size, heads = policy_attn + qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn config = transformer_inference.DeepSpeedInferenceConfig( hidden_size=hidden_size, @@ -224,7 +224,8 @@ def transpose(data): else: attn_module.attn_qkvw = None attn_module.attn_qw.data = transpose(qw.data) - attn_module.attn_kvw.data = transpose(kvw.data) + attn_module.attn_kw.data = transpose(kw.data) + attn_module.attn_vw.data = transpose(vw.data) attn_module.attn_qkvb = None attn_module.attn_ow.data = transpose(attn_ow.data) diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index ce8a2770f437..6d72e9e46468 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -58,9 +58,9 @@ def attention(self, client_module): client_module.heads else: #return None - kvw = Parameter(torch.cat((kw, vw), dim=0), requires_grad=False) + #kvw = Parameter(torch.cat((kw, vw), dim=0), requires_grad=False) return qw, \ - kvw, \ + kw, vw, \ client_module.to_out[0].weight, \ client_module.to_out[0].bias, \ qw.shape[-1], \ diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index 713224f3131c..a91126e88057 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ 
b/deepspeed/ops/transformer/inference/attention.py @@ -22,7 +22,8 @@ def forward(ctx, config, attn_qkvw, attn_qw, - attn_kvw, + attn_kw, + attn_vw, attn_qkvb, num_attention_heads_per_partition, norm_factor, @@ -98,13 +99,11 @@ def selfAttention_fp(input, context, input_mask): context_layer = compute_attention(qkv_out, input_mask) else: query = torch.matmul(input, attn_qw) - key_value = torch.matmul(context, attn_kvw) - query = _transpose_for_scores(query) - key = _transpose_for_scores(key_value[:,:,:input.shape[-1]]) - value = _transpose_for_scores(key_value[:,:,input.shape[-1]:]) - + key = torch.matmul(context, attn_kw) + value = torch.matmul(context, attn_vw) + #do_flash_attn = False + query, key, value = inference_cuda_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn) if do_flash_attn: - query, key, value = inference_cuda_module.add_padding_fp16(query, key, value) context_layer = triton_flash_attn_kernel(query, key, value, @@ -171,8 +170,13 @@ def __init__( dtype=data_type, device=device), requires_grad=False) - self.attn_kvw = nn.Parameter(torch.empty(self.config.hidden_size, - self.config.hidden_size * 2, + self.attn_kw = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_vw = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.hidden_size, dtype=data_type, device=device), requires_grad=False) @@ -258,7 +262,8 @@ def _forward(self, input, context=None, input_mask=None): self.config, self.attn_qkvw, self.attn_qw, - self.attn_kvw, + self.attn_kw, + self.attn_vw, self.attn_qkvb, self.num_attention_heads_per_partition, self.norm_factor, From 923ddd2f6ff9152248ebf1edfc6617d10e0e23db Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 11 Oct 2022 23:27:02 +0500 Subject: [PATCH 10/20] allocate workspace using the padded hidden_size --- csrc/transformer/inference/csrc/pt_binding.cpp | 7 ++++--- csrc/transformer/inference/includes/inference_context.h | 9 ++++++--- deepspeed/ops/transformer/inference/attention.py | 1 - 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index f8b15d3f1a58..7a83f2a02c8f 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -106,10 +106,11 @@ void allocate_workspace(size_t hidden_dim, unsigned num_layers, unsigned mp_size = 1, bool external_cache = false, - unsigned rank = 0) + unsigned rank = 0, + unsigned num_heads=128) { Context::Instance().GenWorkSpace( - num_layers, batch_size, hidden_dim, mp_size, external_cache, sizeof(T), rank); + num_layers, batch_size, hidden_dim, mp_size, external_cache, sizeof(T), rank, num_heads); } template @@ -855,7 +856,7 @@ at::Tensor ds_linear_layer(at::Tensor& input, if (!workspace) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), input.size(0), num_layers, 1, external_cache); + allocate_workspace(input.size(2), input.size(0), num_layers, 1, external_cache, 0, num_heads); workspace = (T*)Context::Instance().GetWorkSpace(); } auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index e535972d2be3..721c76571899 100644 --- a/csrc/transformer/inference/includes/inference_context.h 
+++ b/csrc/transformer/inference/includes/inference_context.h @@ -93,13 +93,16 @@ class Context { const unsigned& mp_size, const bool& external_cache, const size_t& elem_size, - const unsigned& rank) + const unsigned& rank, + const unsigned& num_heads) { size_t total_size; if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } - size_t activation_size = 16 * hidden_dim * batch_size + MAX_OUT_TOKENS * 16 * batch_size; - size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; + int head_size = hidden_dim / num_heads; + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); + size_t activation_size = 32 * (head_size * padded_head_size) * batch_size + MAX_OUT_TOKENS * num_heads * batch_size * 2; + size_t cache_size = num_layers * batch_size * ((head_size * padded_head_size) / mp_size) * 2; _max_seq_len = (((_free_memory_size - (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE) / elem_size)) / diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index a91126e88057..ae99b456095c 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ b/deepspeed/ops/transformer/inference/attention.py @@ -101,7 +101,6 @@ def selfAttention_fp(input, context, input_mask): query = torch.matmul(input, attn_qw) key = torch.matmul(context, attn_kw) value = torch.matmul(context, attn_vw) - #do_flash_attn = False query, key, value = inference_cuda_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn) if do_flash_attn: context_layer = triton_flash_attn_kernel(query, From 2dca3acd5182eaeb14a89e807d0fa6242aa797c2 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 12 Oct 2022 00:33:48 +0500 Subject: [PATCH 11/20] skip the clip-encoder injection for now --- deepspeed/module_inject/encoder.py | 3 +++ deepspeed/module_inject/replace_module.py | 10 +++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/deepspeed/module_inject/encoder.py b/deepspeed/module_inject/encoder.py index b9cd64e0744b..e49805626f33 100644 --- a/deepspeed/module_inject/encoder.py +++ b/deepspeed/module_inject/encoder.py @@ -16,6 +16,9 @@ def __init__(self, enc): def _build_causal_attention_mask(self, bsz, seq_len, dtype): mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=torch.cuda.current_device()) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) + mask = mask.unsqueeze(1) return mask def _graph_replay(self, *inputs, **kwargs): diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e652ce0049df..21925798ddb5 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -242,11 +242,11 @@ def transpose(data): except ImportError: new_policies = {} - replace_transformer_layer(None, - module.text_encoder, - training=False, - replace_with_kernel_inject=True, - triangular_masking=True) + #replace_transformer_layer(None, + # module.text_encoder, + # training=False, + # replace_with_kernel_inject=True, + # triangular_masking=True) from .encoder import DSClipEncoder cg_encoder = DSClipEncoder(module.text_encoder) setattr(module, 'text_encoder', cg_encoder) From 6079065f563923aaa4215f4f4f7abb5a15ccbcbc Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 11 Oct 2022 15:56:17 -0700 Subject: [PATCH 12/20] add triton to new extra --- requirements/requirements-sd.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 requirements/requirements-sd.txt diff --git 
a/requirements/requirements-sd.txt b/requirements/requirements-sd.txt new file mode 100644 index 000000000000..e393460a3639 --- /dev/null +++ b/requirements/requirements-sd.txt @@ -0,0 +1 @@ +triton==2.0.0.dev20221005 From fb1605f333c0b81c24b54ace285a0d70022fbc0c Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 11 Oct 2022 16:15:26 -0700 Subject: [PATCH 13/20] lazy import triton, add sd extra, formatting --- csrc/transformer/inference/csrc/gelu.cu | 62 ++++++------- .../transformer/inference/csrc/pt_binding.cpp | 87 ++++++++++--------- csrc/transformer/inference/csrc/transform.cu | 61 ++++++------- .../inference/includes/inference_context.h | 6 +- .../includes/inference_cuda_layers.h | 30 +++---- deepspeed/module_inject/encoder.py | 13 +-- .../ops/transformer/inference/attention.py | 46 +++++----- .../ops/transformer/inference/triton_ops.py | 21 ++++- requirements/requirements-sd.txt | 1 + 9 files changed, 176 insertions(+), 151 deletions(-) diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 43a618bf7812..cab8eb3fe63f 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -530,11 +530,11 @@ template void pad_data(float* padded_output, cudaStream_t stream); __global__ void pad_head_seq_kernel(__half* padded_output, - __half* output, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size) + __half* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) { float4* padded_output_cast = reinterpret_cast(padded_output); float4* output_cast = reinterpret_cast(output); @@ -555,22 +555,22 @@ __global__ void pad_head_seq_kernel(__half* padded_output, padded_output_cast[idx] = ZERO; } __global__ void pad_head_seq_kernel(float* padded_output, - float* output, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size) + float* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) { } template void pad_head_seq(T* padded_output, - T* output, - int bsz, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size, - cudaStream_t stream) + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream) { dim3 grid_dim(bsz, padded_seq_len / 16); dim3 block_dim(padded_head_size / 8, 16); @@ -578,18 +578,18 @@ void pad_head_seq(T* padded_output, padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8); } template void pad_head_seq(__half* padded_output, - __half* output, - int bsz, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size, - cudaStream_t stream); + __half* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); template void pad_head_seq(float* padded_output, - float* output, - int bsz, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size, - cudaStream_t stream); \ No newline at end of file + float* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 7a83f2a02c8f..130870467994 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -107,7 +107,7 @@ void allocate_workspace(size_t hidden_dim, unsigned mp_size = 1, bool external_cache = false, 
unsigned rank = 0, - unsigned num_heads=128) + unsigned num_heads = 128) { Context::Instance().GenWorkSpace( num_layers, batch_size, hidden_dim, mp_size, external_cache, sizeof(T), rank, num_heads); @@ -856,7 +856,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, if (!workspace) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), input.size(0), num_layers, 1, external_cache, 0, num_heads); + allocate_workspace( + input.size(2), input.size(0), num_layers, 1, external_cache, 0, num_heads); workspace = (T*)Context::Instance().GetWorkSpace(); } auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); @@ -963,42 +964,50 @@ std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tens T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2); T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128; pad_head_seq(workspace, - (T*)query.data_ptr(), - query.size(0) * query.size(1), - query.size(2), - query.size(2), - head_size, - padded_head_size, - Context::Instance().GetCurrentStream()); + (T*)query.data_ptr(), + query.size(0) * query.size(1), + query.size(2), + query.size(2), + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); pad_head_seq(key_pad_ptr, - (T*)key.data_ptr(), - query.size(0) * query.size(1), - key.size(2), - 128, - head_size, - padded_head_size, - Context::Instance().GetCurrentStream()); + (T*)key.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); pad_head_seq(value_pad_ptr, - (T*)value.data_ptr(), - query.size(0) * query.size(1), - key.size(2), - 128, - head_size, - padded_head_size, - Context::Instance().GetCurrentStream()); + (T*)value.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); return { - at::from_blob(workspace, {query.size(0), query.size(1), query.size(2), padded_head_size}, query.options()), - at::from_blob(key_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()), - at::from_blob(value_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()) - }; + at::from_blob(workspace, + {query.size(0), query.size(1), query.size(2), padded_head_size}, + query.options()), + at::from_blob( + key_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()), + at::from_blob( + value_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options())}; } template -std::vector padd_add_transform(at::Tensor& query, at::Tensor& key, at::Tensor& value, int heads, bool add_padding) +std::vector padd_add_transform(at::Tensor& query, + at::Tensor& key, + at::Tensor& value, + int heads, + bool add_padding) { int head_size = query.size(2) / heads; int key_value_length = add_padding ? 128 : key.size(1); - int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128)) : head_size; + int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128)) + : head_size; T* workspace = (T*)Context::Instance().GetWorkSpace(); T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1); T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length; @@ -1030,10 +1039,14 @@ std::vector padd_add_transform(at::Tensor& query, at::Tensor& key, a padded_head_size, Context::Instance().GetCurrentStream()); return { - at::from_blob(workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()), - at::from_blob(key_pad_ptr, {query.size(0), heads, key_value_length, padded_head_size}, query.options()), - at::from_blob(value_pad_ptr, {query.size(0), heads, key_value_length, padded_head_size}, query.options()) - }; + at::from_blob( + workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()), + at::from_blob(key_pad_ptr, + {query.size(0), heads, key_value_length, padded_head_size}, + query.options()), + at::from_blob(value_pad_ptr, + {query.size(0), heads, key_value_length, padded_head_size}, + query.options())}; } template at::Tensor ds_linear_layer_int8(at::Tensor& input, @@ -1575,12 +1588,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) &einsum_sec_sm_ecm<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); - m.def("add_padding_fp32", - &add_padding, - "DeepSpeed residual add with fp32 (CUDA)"); - m.def("add_padding_fp16", - &add_padding<__half>, - "DeepSpeed residual add with fp16 (CUDA)"); + m.def("add_padding_fp32", &add_padding, "DeepSpeed residual add with fp32 (CUDA)"); + m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)"); m.def("pad_transform_fp32", &padd_add_transform, "DeepSpeed residual add with fp32 (CUDA)"); diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 93461c766b88..a5a43c364ed6 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -249,29 +249,25 @@ void launch_bias_add_transform_0213<__half>(__half* output, max_out_tokens); } - - // Bias add __global__ void pad_add_transform_0213(float* output, - const float* vals, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size) + const float* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) { - } - __global__ void pad_add_transform_0213(__half* output, - const __half* vals, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size) + const __half* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) { #if __CUDA_ARCH__ >= 700 float4 ZERO; @@ -284,10 +280,10 @@ __global__ void pad_add_transform_0213(__half* output, int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y * blockDim.z + threadIdx.z; // Sequence ID (0-127) - int d2 = threadIdx.y; // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y * blockDim.z + threadIdx.z; // Sequence ID (0-127) + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) int d2_out_stride = padded_head_size * padded_seq_len; int d0_out_stride = heads * d2_out_stride; @@ -325,14 +321,14 @@ void launch_pad_add_transform_0213(T* output, // [B S C*H] - > C * [B A S N] template <> void 
launch_pad_add_transform_0213(float* output, - const float* vals, - int batch_size, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size, - cudaStream_t stream) + const float* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) { } template <> @@ -349,13 +345,8 @@ void launch_pad_add_transform_0213<__half>(__half* output, hidden_dim >>= 3; dim3 block_dim((padded_head_size >> 3), heads, 2); dim3 grid_dim(batch_size, padded_seq_len / 2); - pad_add_transform_0213<<>>(output, - vals, - hidden_dim, - seq_length, - padded_seq_len, - heads, - padded_head_size >> 3); + pad_add_transform_0213<<>>( + output, vals, hidden_dim, seq_length, padded_seq_len, heads, padded_head_size >> 3); } // Bias add diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index 721c76571899..bb1d472907d7 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -101,8 +101,10 @@ class Context { int head_size = hidden_dim / num_heads; int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); - size_t activation_size = 32 * (head_size * padded_head_size) * batch_size + MAX_OUT_TOKENS * num_heads * batch_size * 2; - size_t cache_size = num_layers * batch_size * ((head_size * padded_head_size) / mp_size) * 2; + size_t activation_size = 32 * (head_size * padded_head_size) * batch_size + + MAX_OUT_TOKENS * num_heads * batch_size * 2; + size_t cache_size = + num_layers * batch_size * ((head_size * padded_head_size) / mp_size) * 2; _max_seq_len = (((_free_memory_size - (_free_memory_size > GIGABYTE ? 
500 : 100) * MEGABYTE) / elem_size)) / diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 540db0800954..67479bbc0e50 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -182,21 +182,21 @@ void pad_data(T* padded_output, template void pad_head_seq(T* padded_output, - T* output, - int bsz, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size, - cudaStream_t stream); + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); template void launch_pad_add_transform_0213(T* output, - const T* vals, - int batch_size, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size, - cudaStream_t stream); \ No newline at end of file + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream); diff --git a/deepspeed/module_inject/encoder.py b/deepspeed/module_inject/encoder.py index e49805626f33..41d593daed92 100644 --- a/deepspeed/module_inject/encoder.py +++ b/deepspeed/module_inject/encoder.py @@ -2,7 +2,6 @@ Copyright 2022 The Microsoft DeepSpeed Team ''' import torch -import diffusers class DSClipEncoder(torch.nn.Module): @@ -15,12 +14,16 @@ def __init__(self, enc): self.cuda_graph_created = False def _build_causal_attention_mask(self, bsz, seq_len, dtype): - mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=torch.cuda.current_device()) + mask = torch.empty(bsz, + seq_len, + seq_len, + dtype=dtype, + device=torch.cuda.current_device()) mask.fill_(torch.tensor(torch.finfo(dtype).min)) mask.triu_(1) mask = mask.unsqueeze(1) return mask - + def _graph_replay(self, *inputs, **kwargs): for i in range(len(inputs)): if torch.is_tensor(inputs[i]): @@ -31,7 +34,7 @@ def _graph_replay(self, *inputs, **kwargs): self._cuda_graphs.replay() return self.static_output - def forward(self, *inputs, **kwargs): + def forward(self, *inputs, **kwargs): if self.cuda_graph_created: outputs = self._graph_replay(*inputs, **kwargs) else: @@ -59,5 +62,5 @@ def _create_cuda_graph(self, *inputs, **kwargs): self.cuda_graph_created = True def _forward(self, *inputs, **kwargs): - + return self.enc(*inputs, **kwargs) diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index ae99b456095c..cac3cc53e524 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ b/deepspeed/ops/transformer/inference/attention.py @@ -6,8 +6,8 @@ from torch.autograd import Function from ... 
import op_builder import torch.nn as nn +from .triton_ops import load_triton_flash_attn from deepspeed.utils.logging import log_dist -from .triton_ops import triton_flash_attn # Cuda modules will be imported if needed inference_cuda_module = None minus_inf = -10000.0 @@ -33,13 +33,12 @@ def forward(ctx, score_context_func, linear_func, triton_flash_attn_kernel): - def _transpose_for_context(x): x = x.permute(0, 2, 1, 3) new_x_layer_shape = x.size()[:-2] + \ (hidden_size_per_partition,) return x.reshape(*new_x_layer_shape) - + def _transpose_for_scores(x): attention_head_size = x.shape[-1] // num_attention_heads_per_partition new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, @@ -72,7 +71,7 @@ def compute_attention(qkv_out, input_mask): DeepSpeedAttention.layer_id, torch.empty(1)) return context_layer - + def selfAttention_fp(input, context, input_mask): if config.fp16 and input.dtype == torch.float32: input = input.half() @@ -103,15 +102,20 @@ def selfAttention_fp(input, context, input_mask): value = torch.matmul(context, attn_vw) query, key, value = inference_cuda_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn) if do_flash_attn: - context_layer = triton_flash_attn_kernel(query, - key, + context_layer = triton_flash_attn_kernel(query, + key, value, scale, input.shape[-2] % 128 == 0) context_layer = _transpose_for_context(context_layer[:,:,:,:head_size]) else: - attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1) - context_layer = _transpose_for_context(torch.matmul(attention_scores, value)) + attention_scores = (torch.matmul(query, + key.transpose(-1, + -2)) * + scale).softmax(dim=-1) + context_layer = _transpose_for_context( + torch.matmul(attention_scores, + value)) output = linear_func(context_layer, attn_ow, @@ -170,20 +174,20 @@ def __init__( device=device), requires_grad=False) self.attn_kw = nn.Parameter(torch.empty(self.config.hidden_size, - self.config.hidden_size, - dtype=data_type, - device=device), - requires_grad=False) + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) self.attn_vw = nn.Parameter(torch.empty(self.config.hidden_size, - self.config.hidden_size, - dtype=data_type, - device=device), - requires_grad=False) + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) self.attn_qw = nn.Parameter(torch.empty(self.config.hidden_size, - self.config.hidden_size, - dtype=data_type, - device=device), - requires_grad=False) + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device), @@ -199,7 +203,7 @@ def __init__( dtype=data_type_fp, device=device), requires_grad=False) - self.triton_flash_attn_kernel = triton_flash_attn() + self.triton_flash_attn_kernel = load_triton_flash_attn() self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py index cac9c1aa3476..4e6d7d18e343 100644 --- a/deepspeed/ops/transformer/inference/triton_ops.py +++ b/deepspeed/ops/transformer/inference/triton_ops.py @@ -1,7 +1,23 @@ import torch +from packaging import version as pkg_version -import triton -import triton.language as tl +# 
lazy load if/when needed +triton = None +tl = None + + +def load_triton_flash_attn(): + global triton, tl + try: + import triton + import triton.language as tl + except ImportError: + raise ImportError("Please install triton 2.0+ or `pip install deepspeed[sd]`") + + if pkg_version.parse(triton.__version__) < pkg_version.parse("2.0"): + raise ImportError("Please install triton 2.0+ or `pip install deepspeed[sd]`") + + return triton_flash_attn @triton.jit @@ -96,7 +112,6 @@ def _fwd_kernel( tl.store(out_ptrs, acc) - class triton_flash_attn(torch.nn.Module): def __init__(self, ): super(triton_flash_attn, self).__init__() diff --git a/requirements/requirements-sd.txt b/requirements/requirements-sd.txt index e393460a3639..c9026206a737 100644 --- a/requirements/requirements-sd.txt +++ b/requirements/requirements-sd.txt @@ -1 +1,2 @@ +diffusers triton==2.0.0.dev20221005 From 16bcef13cf3ecbcb0316d85e670e37032fa9f64c Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 12 Oct 2022 15:49:07 -0700 Subject: [PATCH 14/20] delay import --- deepspeed/ops/transformer/inference/attention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index cac3cc53e524..d0881c8e1854 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ b/deepspeed/ops/transformer/inference/attention.py @@ -6,7 +6,6 @@ from torch.autograd import Function from ... import op_builder import torch.nn as nn -from .triton_ops import load_triton_flash_attn from deepspeed.utils.logging import log_dist # Cuda modules will be imported if needed inference_cuda_module = None @@ -72,7 +71,6 @@ def compute_attention(qkv_out, input_mask): torch.empty(1)) return context_layer - def selfAttention_fp(input, context, input_mask): if config.fp16 and input.dtype == torch.float32: input = input.half() head_size = input.shape[-1] // config.heads @@ -203,6 +201,7 @@ def __init__( dtype=data_type_fp, device=device), requires_grad=False) + from .triton_ops import load_triton_flash_attn self.triton_flash_attn_kernel = load_triton_flash_attn() self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size From 691fd3495dad0f687d4b1e535541350c8123316f Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 12 Oct 2022 17:54:24 -0700 Subject: [PATCH 15/20] fix previous issue i added --- deepspeed/ops/transformer/inference/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index d0881c8e1854..ebcd1ea8458d 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ b/deepspeed/ops/transformer/inference/attention.py @@ -71,6 +71,7 @@ def compute_attention(qkv_out, input_mask): torch.empty(1)) return context_layer + def selfAttention_fp(input, context, input_mask): if config.fp16 and input.dtype == torch.float32: input = input.half() head_size = input.shape[-1] // config.heads From c758b03e7e82fa4a4f50b4034492247c8a34fa8b Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 13 Oct 2022 21:40:25 +0500 Subject: [PATCH 16/20] fix bug with adding bias --- deepspeed/ops/transformer/inference/transformer_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 
ee68b055fd2a..8d7f2f46c261 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -408,11 +408,10 @@ def selfAttention_fp(): if not config.pre_layer_norm: linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ inference_cuda_module.linear_layer_fp32 - qkv_out = linear_func(input, attn_qkvw, attn_qkvb, - attn_qkvb is None, + attn_qkvb is not None, False, False, num_attention_heads_per_partition, From 406832c5f82b89dd838264d1de7ae58798d5abce Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 13 Oct 2022 12:21:18 -0700 Subject: [PATCH 17/20] fixes for triton import and add acks to triton-ops file --- .../ops/transformer/inference/attention.py | 20 ++++++++++++-- .../ops/transformer/inference/triton_ops.py | 26 +++++-------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index cac3cc53e524..e85b44f51ee0 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ b/deepspeed/ops/transformer/inference/attention.py @@ -6,11 +6,25 @@ from torch.autograd import Function from ... import op_builder import torch.nn as nn -from .triton_ops import load_triton_flash_attn +from packaging import version as pkg_version from deepspeed.utils.logging import log_dist # Cuda modules will be imported if needed inference_cuda_module = None minus_inf = -10000.0 +triton_flash_attn = None + + +def load_triton_flash_attn(): + global triton_flash_attn + try: + import triton + except ImportError: + raise ImportError("Please install triton 2.0+ or `pip install deepspeed[sd]`") + + if pkg_version.parse(triton.__version__) < pkg_version.parse("2.0"): + raise ImportError("Please install triton 2.0+ or `pip install deepspeed[sd]`") + + from .triton_ops import triton_flash_attn class DeepSpeedAttentionFunction(Function): @@ -203,7 +217,9 @@ def __init__( dtype=data_type_fp, device=device), requires_grad=False) - self.triton_flash_attn_kernel = load_triton_flash_attn() + if triton_flash_attn is None: + load_triton_flash_attn() + self.triton_flash_attn_kernel = triton_flash_attn self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py index 4e6d7d18e343..423a2ff1134d 100644 --- a/deepspeed/ops/transformer/inference/triton_ops.py +++ b/deepspeed/ops/transformer/inference/triton_ops.py @@ -1,23 +1,11 @@ -import torch -from packaging import version as pkg_version - -# lazy load if/when needed -triton = None -tl = None - +""" +Inspired by original Triton implementation: +https://github.com/openai/triton/blob/b244db06da24a87453a40ad35b085ee37dac3705/python/tutorials/06-fused-attention.py +""" -def load_triton_flash_attn(): - global triton, tl - try: - import triton - import triton.language as tl - except ImportError: - raise ImportError("Please install triton 2.0+ or `pip install deepspeed[sd]`") - - if pkg_version.parse(triton.__version__) < pkg_version.parse("2.0"): - raise ImportError("Please install triton 2.0+ or `pip install deepspeed[sd]`") - - return triton_flash_attn +import torch +import triton +import triton.language as tl @triton.jit From 75fbcfe7d8fc56248237da355369ea08393f955e Mon Sep 
From 75fbcfe7d8fc56248237da355369ea08393f955e Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Fri, 14 Oct 2022 00:55:42 +0500
Subject: [PATCH 18/20] merge fix & formatting

---
 csrc/transformer/inference/csrc/pt_binding.cpp | 10 ++++++++--
 .../inference/includes/inference_context.h     | 18 +++++++++---------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index d40117e96252..f09dfc569ed2 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -851,8 +851,14 @@ at::Tensor ds_linear_layer(at::Tensor& input,
     if (!workspace) {
         cublasSetStream(Context::Instance().GetCublasHandle(),
                         Context::Instance().GetCurrentStream());
-        allocate_workspace(
-            input.size(2), input.size(0), input.size(1), num_layers, num_heads, 1, external_cache, 0);
+        allocate_workspace(input.size(2),
+                           input.size(0),
+                           input.size(1),
+                           num_layers,
+                           num_heads,
+                           1,
+                           external_cache,
+                           0);
         workspace = (T*)Context::Instance().GetWorkSpace();
     }
     auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h
index 4981bbb81c5c..64e490ef47fc 100644
--- a/csrc/transformer/inference/includes/inference_context.h
+++ b/csrc/transformer/inference/includes/inference_context.h
@@ -96,8 +96,7 @@ class Context {
                            const unsigned& mp_size,
                            const bool& external_cache,
                            const size_t& elem_size,
-                           const unsigned& rank,
-                           const unsigned& num_heads)
+                           const unsigned& rank)
     {
         size_t total_size;
         if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); }
@@ -105,8 +104,9 @@ class Context {
         int head_size = hidden_dim / num_heads;
         int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
         size_t activation_size = 32 * (head_size * padded_head_size) * batch_size;
-        size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size / mp_size;
-        size_t cache_size = num_layers * batch_size * ((head_size * padded_head_size) / mp_size) * 2;
+        size_t temp_size = batch_size * num_heads * MAX_OUT_TOKENS * 2;
+        size_t cache_size =
+            num_layers * batch_size * ((head_size * padded_head_size) / mp_size) * 2;
         size_t minimal_requirements =
             temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE;
         if (_free_memory_size < minimal_requirements) {
@@ -118,12 +118,12 @@ class Context {
         }
 
         _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) /
-                       (activation_size + cache_size);
+                       (activation_size + temp_size + cache_size);
         _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len);
-        size_t workSpaceSize =
-            ((external_cache ? activation_size : (activation_size + cache_size))) * _max_seq_len *
-                elem_size +
-            temp_size;
+        size_t workSpaceSize = ((external_cache ? (activation_size + temp_size)
+                                                : (activation_size + temp_size + cache_size))) *
+                               _max_seq_len * elem_size;
+        temp_size *= _max_seq_len * elem_size;
         if (rank == 0 && !_workspace)
            printf(
                "Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total "
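
Patch 18 reworks the workspace sizing in the inference context: the temporary buffer is now sized from MAX_OUT_TOKENS instead of the prompt length, and it appears both in the maximum-sequence-length estimate and in the final workspace product. A rough Python transcription of that arithmetic follows; the value of MAX_OUT_TOKENS and the function name are assumptions for illustration, and the C++ above remains the authoritative logic:

    MEGABYTE = 1 << 20
    GIGABYTE = 1 << 30
    MAX_OUT_TOKENS = 1024  # assumed value of the C++ constant


    def estimate_workspace(free_bytes, hidden_dim, num_heads, num_layers,
                           batch_size, mp_size, external_cache, elem_size):
        # Mirrors the sizing math in inference_context.h (illustrative only).
        head_size = hidden_dim // num_heads
        padded_head_size = 32 if head_size < 32 else (64 if head_size < 64 else 128)

        activation_size = 32 * (head_size * padded_head_size) * batch_size
        temp_size = batch_size * num_heads * MAX_OUT_TOKENS * 2
        cache_size = num_layers * batch_size * ((head_size * padded_head_size) // mp_size) * 2

        minimal = temp_size + (500 if free_bytes > GIGABYTE else 100) * MEGABYTE
        max_seq_len = ((free_bytes - minimal) // elem_size) // (activation_size + temp_size + cache_size)
        max_seq_len = min(MAX_OUT_TOKENS, max_seq_len)

        per_token = (activation_size + temp_size) if external_cache else (
            activation_size + temp_size + cache_size)
        return max_seq_len, per_token * max_seq_len * elem_size

Note that temp_size now enters the denominator of the _max_seq_len estimate as well as the workspace product, whereas before it was only added on at the end.
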
From eff95e7e04cb686047fc2330f6d2774e885cfe07 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Fri, 14 Oct 2022 01:33:56 +0500
Subject: [PATCH 19/20] fix small issue

---
 deepspeed/ops/transformer/inference/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py
index e85b44f51ee0..f8bad34d5bff 100644
--- a/deepspeed/ops/transformer/inference/attention.py
+++ b/deepspeed/ops/transformer/inference/attention.py
@@ -219,7 +219,7 @@ def __init__(
             requires_grad=False)
         if triton_flash_attn is None:
             load_triton_flash_attn()
-        self.triton_flash_attn_kernel = triton_flash_attn
+        self.triton_flash_attn_kernel = triton_flash_attn()
         self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
         self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
         self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads

From 770e88b11ec3973ebb822215963cb9557eb4b135 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Fri, 14 Oct 2022 03:05:35 +0500
Subject: [PATCH 20/20] skip cuda-graph for clip-encoder for now (it has issues with larger batch sizes)

---
 deepspeed/module_inject/replace_module.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index e90fff5f4cee..d7fa50eca4ce 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -247,9 +247,9 @@ def transpose(data):
     #                           training=False,
     #                           replace_with_kernel_inject=True,
     #                           triangular_masking=True)
-    from .encoder import DSClipEncoder
-    cg_encoder = DSClipEncoder(module.text_encoder)
-    setattr(module, 'text_encoder', cg_encoder)
+    #from .encoder import DSClipEncoder
+    #cg_encoder = DSClipEncoder(module.text_encoder)
+    #setattr(module, 'text_encoder', cg_encoder)
     for name in module.__dict__.keys():
         sub_module = getattr(module, name)
         policy = _module_match(sub_module)
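
Patch 19 stores an instance of the Triton kernel rather than the class itself, and patch 20 leaves the CLIP text encoder unwrapped while the generic injection loop still walks the pipeline's attributes. For reference, a usage sketch of the path these patches target is given below; the diffusers model id, the init_inference keyword arguments, and the exact return behavior are assumptions for illustration and may differ across DeepSpeed versions:

    import torch
    import deepspeed
    from diffusers import StableDiffusionPipeline

    # Model id and dtype below are illustrative assumptions.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                   torch_dtype=torch.half).to("cuda")

    # Kernel injection walks the pipeline's attributes and swaps in optimized
    # sub-modules where a policy matches; per patch 20, the CLIP text encoder
    # is currently left as-is rather than wrapped for cuda-graph capture.
    pipe = deepspeed.init_inference(pipe,
                                    mp_size=1,
                                    dtype=torch.half,
                                    replace_with_kernel_inject=True)

    image = pipe("a photo of an astronaut riding a horse").images[0]
    image.save("astronaut.png")
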