diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py
index 66fc15ec0e3..aa664ea607a 100644
--- a/server/text_generation_server/layers/gptq/quantize.py
+++ b/server/text_generation_server/layers/gptq/quantize.py
@@ -956,15 +956,24 @@ def _unload():
     pack(model, quantizers, bits, groupsize)
     from safetensors.torch import save_file
-    from transformers.modeling_utils import shard_checkpoint
+    from huggingface_hub import split_torch_state_dict_into_shards

     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}

     max_shard_size = "10GB"
-    shards, index = shard_checkpoint(
-        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
+    state_dict_split = split_torch_state_dict_into_shards(
+        state_dict,
+        filename_pattern="model.safetensors",
+        max_shard_size=max_shard_size,
     )
+    index = None
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+    shards = state_dict_split.filename_to_tensors
     os.makedirs(output_dir, exist_ok=True)
     for shard_file, shard in shards.items():
         save_file(
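For context, the hunk above swaps `shard_checkpoint` (removed from recent `transformers` releases) for `huggingface_hub.split_torch_state_dict_into_shards`. A minimal standalone sketch of that replacement API, independent of the TGI code above (the output directory, filename pattern, and dummy state dict are illustrative):

```python
import json
import os

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file


def save_sharded(state_dict: dict, output_dir: str) -> None:
    # Plan the shards without materializing them: the helper only groups
    # tensor names per file and computes the index metadata.
    split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern="model{suffix}.safetensors",
        max_shard_size="10GB",
    )
    os.makedirs(output_dir, exist_ok=True)
    for filename, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name] for name in tensor_names}
        save_file(shard, os.path.join(output_dir, filename))
    if split.is_sharded:
        index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
        with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f, indent=2)


save_sharded({"weight": torch.zeros(4, 4)}, "out")
```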
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fcc79608645..5069fff6d66 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -16,10 +16,14 @@
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+import transformers

 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+from text_generation_server.models.transformers_flash_causal_lm import (
+    TransformersFlashCausalLM,
+)
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -372,6 +376,23 @@ def get_model(
         )
     model_type = config_dict.get("model_type", None)

+    transformers_causal_lm_class = CausalLM
+
+    # Fast transformers path
+    transformers_model_class = getattr(
+        transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+    )
+    if transformers_model_class._supports_flex_attn:
+        transformers_causal_lm_class = TransformersFlashCausalLM
+        if (
+            not FLASH_ATTENTION
+            and lora_adapter_ids is not None
+            and len(lora_adapter_ids) > 0
+        ):
+            raise ValueError(
+                "Transformers backend AutoModel do not support `lora_adapter_ids`."
+            )
+
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
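As an aside, the selection added above keys off Transformers' auto-mapping plus the `_supports_flex_attn` class flag. A condensed sketch of that check, outside of `get_model()` (the defensive `.get`/`getattr` fallbacks are mine; the hunk indexes the mapping directly):

```python
import transformers
from transformers.models.auto import modeling_auto


def supports_transformers_flash_backend(model_type: str) -> bool:
    # Resolve e.g. "llama" -> "LlamaForCausalLM", then check the flag that
    # Transformers sets on architectures usable with flex/flash attention.
    auto_class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type)
    if auto_class_name is None:
        return False
    model_class = getattr(transformers, auto_class_name, None)
    return bool(getattr(model_class, "_supports_flex_attn", False))


print(supports_transformers_flash_backend("llama"))
```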
@@ -615,7 +636,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -674,7 +695,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -722,7 +743,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -733,7 +754,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -758,7 +779,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -769,7 +790,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -806,7 +827,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -829,7 +850,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -853,7 +874,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -902,7 +923,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -928,7 +949,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -954,7 +975,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -979,7 +1000,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1007,7 +1028,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1057,7 +1078,7 @@ def get_model(
                 config_class=RWConfig,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1082,7 +1103,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1107,7 +1128,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1134,7 +1155,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1159,7 +1180,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1302,7 +1323,7 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+        return transformers_causal_lm_class.fallback(
            model_id,
            revision,
            quantize=quantize,
@@ -1323,7 +1344,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
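All of the call-site hunks above change only the callee: every non-flash fallback now goes through the `transformers_causal_lm_class` selected earlier, which is either `CausalLM` or `TransformersFlashCausalLM`. For reference, a hypothetical direct invocation mirroring those call sites (the model id is a placeholder; inside TGI this path is only ever reached through `get_model()`):

```python
import torch

from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,
)

model = TransformersFlashCausalLM.fallback(
    "org/some-causal-lm",  # placeholder model id, not taken from the diff
    revision=None,
    quantize=None,
    speculator=None,
    dtype=torch.float16,
    trust_remote_code=False,
)
```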
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
new file mode 100644
index 00000000000..30ea4c8fcce
--- /dev/null
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -0,0 +1,270 @@
+import math
+from typing import List, Optional
+
+import torch
+from opentelemetry import trace
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import transformers.modeling_utils
+
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
+from text_generation_server.utils import initialize_torch_distributed
+
+from text_generation_server.layers.attention import paged_attention, attention, Seqlen
+from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
+from text_generation_server.models.globals import ATTENTION
+
+
+tracer = trace.get_tracer(__name__)
+
+
+def tgi_flash_attention_forward(
+    module,
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers
+    kv_cache: List[KVCache],
+    kv_head_mapping: torch.Tensor,
+    slots: torch.Tensor,
+    cu_seqlen_prefill: Optional[torch.Tensor],
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
+    max_s: int,
+    kv_scales: KVScales,
+    softmax_scale: Optional[float] = None,
+    sliding_window: Optional[int] = None,
+    softcap: Optional[float] = None,
+    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
+):
+
+    kv_cache = kv_cache[module.layer_idx]
+
+    query_states = query_states.transpose(1, 2).squeeze(dim=0)
+    key_states = key_states.transpose(1, 2).squeeze(dim=0)
+    value_states = value_states.transpose(1, 2).squeeze(dim=0)
+
+    # Take care of updating the cache in-place
+    kv_cache.store(
+        key=key_states, value=value_states, slots=slots, kv_scales=kv_scales
+    )
+
+    _, num_heads, head_dim = query_states.shape
+    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
+    sliding_window = -1 if sliding_window is None else sliding_window
+
+    if cu_seqlen_prefill is not None:
+        attn_output = attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            kv_cache=kv_cache,
+            kv_scales=kv_scales,
+            seqlen=seqlen,
+            block_tables=block_tables,
+            softmax_scale=softmax_scale,
+            window_size_left=sliding_window,
+            softcap=softcap,
+        )
+    else:
+        attn_output = paged_attention(
+            query_states,
+            kv_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            seqlen,
+            max_s,
+            kv_scales=kv_scales,
+            softcap=softcap,
+        )
+
+    attn_output = attn_output.view(-1, num_heads * head_dim)
+
+    return attn_output, None
+
+
+transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
+
+
+class TransformersFlashCausalLM(FlashCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
+        trust_remote_code: bool = False,
+        tokenizer_class=AutoTokenizer,
+        config_class=AutoConfig,
+        kv_cache_dtype: Optional[torch.dtype] = None,
+    ):
+        self.quantize = quantize
+        self.process_group, rank, world_size = initialize_torch_distributed()
+
+        if speculator:
+            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
+
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+            dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map="auto",
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+            attn_implementation="tgi",
+            tp_plan="auto" if world_size > 1 else None,
+        )
+
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None and isinstance(
+                model.config.eos_token_id, int
+            ):
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.num_layers = model.config.num_hidden_layers
+        self.num_heads = model.config.num_attention_heads // self.process_group.size()
+        self.num_kv_heads = model.config.num_key_value_heads
+        self.num_kv_heads = (
+            self.num_kv_heads // self.process_group.size()
+            if self.num_kv_heads > 1
+            else self.num_kv_heads
+        )
+        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+
+        self.cuda_graphs = {}
+        self.kv_cache = []
+        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+
+        if ATTENTION == "flashinfer":
+            from text_generation_server.layers.attention.flashinfer import (
+                create_prefill_state,
+                create_decode_state,
+                create_prefill_with_paged_kv_state,
+            )
+
+            self.prefill_state = create_prefill_state(device=device)
+            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+                device=device
+            )
+
+            self.decode_state = create_decode_state(
+                device=device,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
+
+        self.num_groups = self.num_heads // self.num_kv_heads
+
+        # Those will never change and will be used in the forwards
+        self.kv_head_mapping = torch.arange(
+            0, self.num_kv_heads, dtype=torch.int32, device=device
+        ).repeat_interleave(self.num_groups)
+        # This means no scale
+        self.kv_scales = KVScales(
+            torch.tensor(1.0, device=device),
+            torch.tensor(1.0, device=device),
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        # Skip FlashCausalLM init.
+        super(FlashCausalLM, self).__init__(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+        )
+
+        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
+        # We first copy the original model.forward because we still need it in the monkey patch
+        self.model.original_forward = self.model.forward
+        self.model.forward = self._model_forward
+
+    @classmethod
+    def fallback(
+        cls,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        return cls(
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
+    def _model_forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[KVCache],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+        lm_head_indices: Optional[torch.Tensor],
+        prefill_cache_indices=None,  # not used, but passed to match original signature
+        adapter_data=None,  # not supported, but passed to match original signature
+    ):
+        hidden_states = self.model.model.forward(
+            input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
+            position_ids=position_ids.unsqueeze(
+                0
+            ),  # expand dim to easily fit transformers
+            past_key_values=None,  # we use self.kv_cache instead of transformers cache object
+            use_cache=False,  # we use self.kv_cache instead of transformers cache object
+            return_dict=True,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            seqlen=seqlen,
+            max_s=max_s,
+            kv_head_mapping=self.kv_head_mapping,
+            kv_scales=self.kv_scales,
+        )[0].squeeze(dim=0)
+
+        # And compute logits from the lm_head, slicing correctly the indices
+        # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
+        # To update with full Transformers support asap
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits = self.model.lm_head.forward(hidden_states)
+
+        return logits, None
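The core trick in the new backend is the pluggable attention registration: the file registers `tgi_flash_attention_forward` under the name `"tgi"` in `ALL_ATTENTION_FUNCTIONS` and then loads the model with `attn_implementation="tgi"`, so every decoder layer calls back into TGI's flash/paged kernels. A minimal, self-contained illustration of that mechanism (assumes a transformers version that exposes `ALL_ATTENTION_FUNCTIONS`; the `"my-sdpa"` name and the SDPA body are illustrative, not TGI's kernels; grouped-query head expansion and scaling details are omitted):

```python
import torch
import transformers.modeling_utils


def my_attention_forward(module, query, key, value, attention_mask, **kwargs):
    # query/key/value arrive as (batch, num_heads, seq_len, head_dim); the
    # attention interface expects (batch, seq_len, num_heads, head_dim) back.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )
    return attn_output.transpose(1, 2).contiguous(), None


transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["my-sdpa"] = my_attention_forward

# Any model loaded with attn_implementation="my-sdpa" now routes attention
# through the function above, exactly as the "tgi" entry does in this diff:
# AutoModelForCausalLM.from_pretrained("org/model-id", attn_implementation="my-sdpa")
```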
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 132e441be4f..64a285b93f8 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -5,13 +5,12 @@
 from typing import List, Optional, DefaultDict

 from loguru import logger
-from typing import Dict, Union
+from typing import Dict

 from text_generation_server.pb.generate_pb2 import GrammarType

 from outlines.fsm.guide import RegexGuide
 from transformers import (
-    LogitsWarper,
     LogitsProcessor,
     PreTrainedTokenizerBase,
     TemperatureLogitsWarper,
@@ -219,7 +218,7 @@ def filter(self, indices):
         return None


-class HeterogeneousTopPLogitsWarper(LogitsWarper):
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
     This version allows for a separate value for each sample and runs inplace when possible.
@@ -278,7 +277,7 @@ def filter(self, indices):
         return None


-class HeterogeneousTopKLogitsWarper(LogitsWarper):
+class HeterogeneousTopKLogitsWarper(LogitsProcessor):
     r"""
     [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
     This version allows for a separate value for each sample and runs inplace when possible.
@@ -359,7 +358,7 @@ def filter(self, indices):
         return None


-class HeterogeneousTypicalLogitsWarper(LogitsWarper):
+class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
     r"""
     [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
@@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
     r"""
     A wrapper for logit warpers or processors without heterogeneous parameter support.

     Args:
-        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
+        processors (`Dict[int, LogitsProcessor]`):
             A mapping of sample indices to logit warpers or processors, to be run sequentially.
     """

     def __init__(
         self,
-        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
+        processors: Dict[int, LogitsProcessor],
     ):
         self.processors = processors
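These last hunks track the removal of `LogitsWarper` from newer `transformers` releases: `LogitsProcessor` exposes the identical `__call__(input_ids, scores)` contract, so rebasing the heterogeneous warpers onto it is a drop-in change. A minimal sketch of that shared interface (the class below is illustrative, not part of the patch):

```python
import torch
from transformers import LogitsProcessor


class ExampleTemperatureProcessor(LogitsProcessor):
    """Toy warper showing the same __call__ contract the classes above implement."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # A "warper" is just a processor that rescales the scores tensor.
        return scores / self.temperature
```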