Batch generation with Exllamav2_HF is weird #606
I would check that the attention mask is correct. Input IDs should be a right-aligned tensor, and you should pass a position offsets tensor as well. You could compare your tokenized IDs and offsets tensor to the output of the ExLlamaV2 tokenizer.
My problem is that the offset value is kind of counterintuitive. Its value is the negative of the padding length. So if the max input token length is 100 and the padding length is 10 (i.e. the real input is 90 tokens), then the offset value is -10. Could you explain why?
It's the offset added to the position IDs. If you have two sequences in a batch, and one is 5 tokens and the other is 10 tokens, they make up a (2, 10) right-aligned tensor:
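Concretely, a toy version of that batch and its offsets could look like this (the token values are made up, and pad_id just stands in for whatever fills the padding positions):

import torch

pad_id = 0  # placeholder; the value doesn't matter once those positions are masked out

input_ids = torch.tensor([
    [pad_id, pad_id, pad_id, pad_id, pad_id, 11, 12, 13, 14, 15],  # 5 real tokens, left-padded
    [21, 22, 23, 24, 25, 26, 27, 28, 29, 30],                      # 10 real tokens
])

# one offset per row: minus the number of padding tokens in that row
position_offsets = torch.tensor([[-5], [0]], dtype=torch.int)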
where the shorter sequence is left-padded. More importantly, attention natively doesn't have a way to prevent attending to padding tokens. They must have a -inf attention score (or a zero score post-softmax), and you can't achieve that with any special embedding for those tokens, so you must have a mask that's applied on every attention layer. In the above case the mask would look like:
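For the toy batch above, that per-layer mask would presumably be 0 where attention is allowed and -inf over the padding, roughly:

mask = torch.tensor([
    [-float('inf')] * 5 + [0.0] * 5,   # first row: the 5 padding positions are blocked
    [0.0] * 10,                        # second row: no padding, nothing blocked
], dtype=torch.float16)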
Okay, I implemented it like this:

import torch, os
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, List, Union, Dict

from transformers import AutoConfig, PretrainedConfig
from transformers.generation.utils import GenerationMixin, GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora


class ExLlamaV2ForCausalLM(GenerationMixin):
    def __init__(
        self,
        config: PretrainedConfig,
        generation_config: GenerationConfig,
        exllama_config: ExLlamaV2Config,
        model: ExLlamaV2,
        loras: Dict[str, ExLlamaV2Lora] = {'': None},
        active_adapter: str = '',
        **kwargs
    ):
        self.config = config
        self.generation_config = generation_config
        self.exllama_config = exllama_config
        self.model = model
        self.loras = loras
        if '' not in self.loras:
            self.loras[''] = None
        self._active_adapter = active_adapter
        self._adapter_enabled = True

    def can_generate(self):
        return True

    @property
    def _supports_cache_class(self) -> bool:
        return False

    @property
    def device(self) -> torch.device:
        return torch.device(0)

    @property
    def main_input_name(self) -> str:
        return 'input_ids'

    @property
    def active_adapters(self) -> List[str]:
        return [self._active_adapter] if self._adapter_enabled else []

    @property
    def active_adapter(self) -> str:
        return self._active_adapter if self._adapter_enabled else ''

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_size: int = -1,
        **kwargs
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        loras = self.loras.get(self.active_adapter, None)
        loras = [loras] if loras else loras

        input_device = input_ids.device
        input_ids = input_ids.to('cpu')

        # Boolean mask (True = real token); per-row offset = -(number of padding tokens)
        attention_mask = attention_mask.to(torch.bool)
        position_offsets = -(~attention_mask).sum(dim=1, keepdim=True).to(torch.int)

        if labels is None:
            if past_key_values is None:
                past_key_values = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)
            # Prefill everything except the last token, then get logits for the last position only
            self.model.forward(input_ids[..., :-1], past_key_values, preprocess_only=True, loras=loras, input_mask=attention_mask[..., :-1], position_offsets=position_offsets)
            logits = self.model.forward(input_ids[..., -1:], past_key_values, loras=loras, input_mask=attention_mask, position_offsets=position_offsets)
        else:
            if past_key_values is None:
                past_key_values = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)
            logits = self.model.forward(input_ids, past_key_values, loras=loras, input_mask=attention_mask, position_offsets=position_offsets)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, past_key_values if use_cache else None)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(logits=logits.to(input_device), past_key_values=past_key_values if use_cache else None, loss=loss)

    def load_adapter(self, lora_path: Union[str, os.PathLike], adapter_name: str):
        if adapter_name in self.loras:
            raise ValueError('This adapter already exists')
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)
        lora_model = ExLlamaV2Lora.from_directory(self.model, lora_path)
        self.loras[adapter_name] = lora_model

    def set_adapter(self, adapter_name: str):
        if adapter_name not in self.loras:
            raise ValueError('This adapter does not exist')
        self._active_adapter = adapter_name

    def enable_adapter_layers(self):
        self._adapter_enabled = True

    def disable_adapter_layers(self):
        self._adapter_enabled = False

    @contextmanager
    def disable_adapter(self):
        try:
            self.disable_adapter_layers()
            yield
        finally:
            self.enable_adapter_layers()

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        gpu_split: Optional[str] = None,
        lora_path: Optional[Union[str, os.PathLike]] = None,
        adapter_name: str = 'default',
        trust_remote_code: bool = False,
        use_flash_attention_2: bool = False
    ):
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        try:
            generation_config = GenerationConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        except Exception:
            generation_config = GenerationConfig()

        exllama_config = ExLlamaV2Config()
        exllama_config.model_dir = pretrained_model_name_or_path
        exllama_config.no_flash_attn = not use_flash_attention_2
        if getattr(config, 'rope_scaling', None) is not None:
            rope_type = config.rope_scaling.get('type', config.rope_scaling.get('rope_type', ''))
            if rope_type == 'linear':
                exllama_config.scale_pos_emb = config.rope_scaling['factor']
            elif rope_type == 'dynamic':
                exllama_config.scale_alpha_value = config.rope_scaling['factor']
            exllama_config.rope_config = config.rope_scaling
        exllama_config.prepare()

        model = ExLlamaV2(exllama_config)
        if gpu_split is not None:
            gpu_split = [float(d) for d in gpu_split.split(' ')]
        model.load(gpu_split=gpu_split)

        lora_model = None
        if lora_path is not None:
            lora_model = ExLlamaV2Lora.from_directory(model, lora_path)
        if lora_model is None:
            adapter_name = ''

        return cls(config, generation_config, exllama_config, model, {adapter_name: lora_model}, adapter_name)

    @staticmethod
    def _reorder_cache(past_key_values: ExLlamaV2Cache, beam_idx):
        for i in range(len(past_key_values.key_states)):
            past_key_values.key_states[i] = past_key_values.key_states[i].index_select(0, beam_idx.to(past_key_values.key_states[i].device))
            past_key_values.value_states[i] = past_key_values.value_states[i].index_select(0, beam_idx.to(past_key_values.value_states[i].device))
        return past_key_values

But the answer didn't change; it is still wrong for the short input. The attention mask is padded on the left, just like you say.
I think I expressed that poorly. The tensor you've got has the right layout, but it still needs to be a half tensor of attention bias values, i.e. 0 where the model may attend and -inf over the padding:

inputs["attention_mask"] = torch.where(
    inputs["attention_mask"] == 0,
    torch.tensor(-float('inf'), dtype = torch.float16),
    torch.tensor(0.0, dtype = torch.float16)
)

With that change I get sensible outputs from both prompts (Llama3-8B-instruct 5bpw).
As for the offset, it turns out I may be misremembering; I think that parameter ends up being used for a few different things depending on the inference mode. In any case, negative offsets don't seem to work, but positive offsets do, like this (accounting for the mask being inverted):

position_offsets = ((attention_mask)).sum(dim=1, keepdim=True).to(torch.int)

However, this shouldn't matter for RoPE to begin with, since it's designed to be relative, so give or take some rounding errors attention should work out exactly the same whether you start each sequence at zero or with some offset. And I'm not sure what the correct approach would be for a wrapper like this to get rid of the left-padding in the output.
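Putting the two pieces together, a minimal sketch of the input preparation could look like this (tok, prompts and model are placeholders for a loaded HF tokenizer, a list of strings and an ExLlamaV2 instance, and I'm reading the positive offset as the per-row padding length):

import torch
from exllamav2 import ExLlamaV2Cache

# left-pad the batch so the input IDs end up right-aligned (HF convention: 1 = real token, 0 = padding)
tok.padding_side = 'left'
inputs = tok(prompts, return_tensors='pt', padding=True)

# additive half-precision mask: 0 where the model may attend, -inf over the padding
bias_mask = torch.where(
    inputs['attention_mask'] == 0,
    torch.tensor(-float('inf'), dtype=torch.float16),
    torch.tensor(0.0, dtype=torch.float16),
)

# positive per-row offset = number of padding tokens in that row
position_offsets = (inputs['attention_mask'] == 0).sum(dim=1, keepdim=True).to(torch.int)

# cache sized for the whole batch (-1, as the wrapper below passes by default), then one forward pass
cache = ExLlamaV2Cache(model, inputs['input_ids'].shape[0], -1)
logits = model.forward(inputs['input_ids'], cache, input_mask=bias_mask, position_offsets=position_offsets)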
Okay, so here is the final code for the HF-format wrapper for exllamav2:

import torch, os, gc
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, List, Union, Dict

from transformers import AutoConfig, PretrainedConfig
from transformers.generation.utils import GenerationMixin, GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora


class ExLlamaV2ForCausalLM(GenerationMixin):
    def __init__(
        self,
        config: PretrainedConfig,
        generation_config: GenerationConfig,
        exllama_config: ExLlamaV2Config,
        model: ExLlamaV2,
        loras: Dict[str, ExLlamaV2Lora] = {'': None},
        active_adapter: str = '',
        **kwargs
    ):
        self.config = config
        self.generation_config = generation_config
        self.exllama_config = exllama_config
        self.model = model
        self.loras = loras
        if '' not in self.loras:
            self.loras[''] = None
        self._active_adapter = active_adapter
        self._adapter_enabled = True
        self._selfcache = None

    def can_generate(self):
        return True

    @property
    def _supports_cache_class(self) -> bool:
        return False

    @property
    def device(self) -> torch.device:
        return torch.device(0)

    @property
    def main_input_name(self) -> str:
        return 'input_ids'

    @property
    def active_adapters(self) -> List[str]:
        return [self._active_adapter] if self._adapter_enabled else []

    @property
    def active_adapter(self) -> str:
        return self._active_adapter if self._adapter_enabled else ''

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_size: int = -1,
        **kwargs
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        loras = self.loras.get(self.active_adapter, None)
        loras = [loras] if loras else loras

        input_device = input_ids.device
        input_ids = input_ids.to('cpu')

        # Boolean mask (True = real token); per-row offset = -(number of padding tokens)
        attention_mask = attention_mask.to(torch.bool)
        position_offsets = -(~attention_mask).sum(dim=1, keepdim=True).to(torch.int)
        # Convert to the additive half-precision mask exllamav2 expects: 0 to attend, -inf over the padding
        attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float16), torch.tensor(-float('inf'), dtype=torch.float16))

        if labels is None:
            if past_key_values is None:
                # Free the previous cache (if any) before allocating a new one for this batch
                if self._selfcache is not None:
                    self._selfcache.key_states = None
                    self._selfcache.value_states = None
                    gc.collect()
                    torch.cuda.empty_cache()
                self._selfcache = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)
                past_key_values = self._selfcache
            # Prefill everything except the last token, then get logits for the last position only
            self.model.forward(input_ids[..., :-1], past_key_values, preprocess_only=True, loras=loras, input_mask=attention_mask[..., :-1], position_offsets=position_offsets)
            logits = self.model.forward(input_ids[..., -1:], past_key_values, loras=loras, input_mask=attention_mask, position_offsets=position_offsets)
        else:
            if past_key_values is None:
                # Same cache-recycling logic as above
                if self._selfcache is not None:
                    self._selfcache.key_states = None
                    self._selfcache.value_states = None
                    gc.collect()
                    torch.cuda.empty_cache()
                self._selfcache = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)
                past_key_values = self._selfcache
            logits = self.model.forward(input_ids, past_key_values, loras=loras, input_mask=attention_mask, position_offsets=position_offsets)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, past_key_values if use_cache else None)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(logits=logits.to(input_device), past_key_values=past_key_values if use_cache else None, loss=loss)

    def load_adapter(self, lora_path: Union[str, os.PathLike], adapter_name: str):
        if adapter_name in self.loras:
            raise ValueError('This adapter already exists')
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)
        lora_model = ExLlamaV2Lora.from_directory(self.model, lora_path)
        self.loras[adapter_name] = lora_model

    def set_adapter(self, adapter_name: str):
        if adapter_name not in self.loras:
            raise ValueError('This adapter does not exist')
        self._active_adapter = adapter_name

    def enable_adapter_layers(self):
        self._adapter_enabled = True

    def disable_adapter_layers(self):
        self._adapter_enabled = False

    @contextmanager
    def disable_adapter(self):
        try:
            self.disable_adapter_layers()
            yield
        finally:
            self.enable_adapter_layers()

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        gpu_split: Optional[str] = None,
        lora_path: Optional[Union[str, os.PathLike]] = None,
        adapter_name: str = 'default',
        trust_remote_code: bool = False,
        use_flash_attention_2: bool = False
    ):
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        try:
            generation_config = GenerationConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        except Exception:
            generation_config = GenerationConfig()

        exllama_config = ExLlamaV2Config()
        exllama_config.model_dir = pretrained_model_name_or_path
        exllama_config.no_flash_attn = not use_flash_attention_2
        if getattr(config, 'rope_scaling', None) is not None:
            rope_type = config.rope_scaling.get('type', config.rope_scaling.get('rope_type', ''))
            if rope_type == 'linear':
                exllama_config.scale_pos_emb = config.rope_scaling['factor']
            elif rope_type == 'dynamic':
                exllama_config.scale_alpha_value = config.rope_scaling['factor']
            exllama_config.rope_config = config.rope_scaling
        exllama_config.prepare()

        model = ExLlamaV2(exllama_config)
        if gpu_split is not None:
            gpu_split = [float(d) for d in gpu_split.split(' ')]
        model.load(gpu_split=gpu_split)

        lora_model = None
        if lora_path is not None:
            lora_model = ExLlamaV2Lora.from_directory(model, lora_path)
        if lora_model is None:
            adapter_name = ''

        return cls(config, generation_config, exllama_config, model, {adapter_name: lora_model}, adapter_name)

    @staticmethod
    def _reorder_cache(past_key_values: ExLlamaV2Cache, beam_idx):
        for i in range(len(past_key_values.key_states)):
            past_key_values.key_states[i] = past_key_values.key_states[i].index_select(0, beam_idx.to(past_key_values.key_states[i].device))
            past_key_values.value_states[i] = past_key_values.value_states[i].index_select(0, beam_idx.to(past_key_values.value_states[i].device))
        return past_key_values
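For completeness, the adapter helpers above are meant to be used roughly like this (the paths and the adapter name are just placeholders):

model = ExLlamaV2ForCausalLM.from_pretrained('/path/to/exl2-model')
model.load_adapter('/path/to/lora', adapter_name='my_lora')
model.set_adapter('my_lora')

# temporarily run without the LoRA applied
with model.disable_adapter():
    pass  # calls made in here see no active adapter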
Here is my code for Exllamav2_HF, and these are the results I get for the two prompts. Somehow the shortest input gets random answers. Why is that?
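For reference, the batched call looks roughly like this (the model path, prompts and generation settings here are placeholder values, not my exact ones):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('/path/to/exl2-model')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'   # right-align the input IDs

model = ExLlamaV2ForCausalLM.from_pretrained('/path/to/exl2-model')

prompts = ['Write a long story about a dragon.', 'Hi']   # one long prompt, one short prompt
inputs = tokenizer(prompts, return_tensors='pt', padding=True)

outputs = model.generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    max_new_tokens=64,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))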