Batch generation with Exllamav2_HF is weird #606

Closed
fahadh4ilyas opened this issue Aug 29, 2024 · 7 comments

@fahadh4ilyas
Contributor

Here is my code for Exllamav2_HF:

import torch, os
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, List, Union, Dict
from transformers import AutoConfig, PretrainedConfig
from transformers.generation.utils import GenerationMixin, GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoTokenizer

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora

class ExLlamaV2ForCausalLM(GenerationMixin):
    
    def __init__(
        self,
        config: PretrainedConfig,
        generation_config: GenerationConfig,
        exllama_config: ExLlamaV2Config,
        model: ExLlamaV2,
        loras: Dict[str, ExLlamaV2Lora] = {'': None},
        active_adapter: str = '',
        **kwargs
    ):
        self.config = config
        self.generation_config = generation_config
        self.exllama_config = exllama_config
        self.model = model
        self.loras = loras
        if '' not in self.loras:
            self.loras[''] = None
        self._active_adapter = active_adapter
        self._adapter_enabled = True
    
    def can_generate(self):
        return True

    @property
    def _supports_cache_class(self) -> bool:
        return False
    
    @property
    def device(self) -> torch.device:
        return torch.device(0)
    
    @property
    def main_input_name(self) -> str:
        return 'input_ids'
    
    @property
    def active_adapters(self) -> List[str]:
        return [self._active_adapter] if self._adapter_enabled else []

    @property
    def active_adapter(self) -> List[str]:
        return self._active_adapter if self._adapter_enabled else ''
    
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}
    
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
    
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_size: int = -1,
        **kwargs
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        loras = self.loras.get(self.active_adapter, None)
        loras = [loras] if loras else loras
        input_device = input_ids.device
        input_ids = input_ids.to('cpu')
        
        if labels is None:
            if past_key_values is None:
                past_key_values = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)
                self.model.forward(input_ids[...,:-1], past_key_values, preprocess_only=True, loras=loras, input_mask=attention_mask[...,:-1].to(torch.bool))

            logits = self.model.forward(input_ids[...,-1:], past_key_values, loras=loras, input_mask=attention_mask.to(torch.bool)).to(input_ids.device)
        else:
            if past_key_values is None:
                past_key_values = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)

            logits = self.model.forward(input_ids, past_key_values, loras=loras, input_mask=attention_mask.to(torch.bool))
        
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        
        if not return_dict:
            output = (logits, past_key_values if use_cache else None)
            return (loss,) + output if loss is not None else output
        
        return CausalLMOutputWithPast(logits=logits.to(input_device), past_key_values=past_key_values if use_cache else None, loss=loss)
    
    def load_adapter(self, lora_path: Union[str, os.PathLike], adapter_name: str):

        if adapter_name in self.loras:
            raise ValueError('This adapter already exists')
        
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)
        
        lora_model = ExLlamaV2Lora.from_directory(self.model, lora_path)

        self.loras[adapter_name] = lora_model

    def set_adapter(self, adapter_name: str):

        if adapter_name not in self.loras:
            raise ValueError('The adapter does not exist')
        
        self._active_adapter = adapter_name
    
    def enable_adapter_layers(self):

        self._adapter_enabled = True

    def disable_adapter_layers(self):

        self._adapter_enabled = False
    
    @contextmanager
    def disable_adapter(self):

        try:
            self.disable_adapter_layers()
            yield
        finally:
            self.enable_adapter_layers()
    
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        gpu_split: Optional[str] = None,
        lora_path: Optional[Union[str, os.PathLike]] = None,
        adapter_name: str = 'default',
        trust_remote_code: bool = False,
        use_flash_attention_2: bool = False
    ):
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)
        
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        
        try:
            generation_config = GenerationConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        except:
            generation_config = GenerationConfig()
        
        exllama_config = ExLlamaV2Config()
        exllama_config.model_dir = pretrained_model_name_or_path
        exllama_config.no_flash_attn = not use_flash_attention_2
        if getattr(config, 'rope_scaling', None) is not None:
            rope_type = config.rope_scaling.get('type', config.rope_scaling.get('rope_type', ''))
            if rope_type == 'linear':
                exllama_config.scale_pos_emb = config.rope_scaling['factor']
            elif rope_type == 'dynamic':
                exllama_config.scale_alpha_value = config.rope_scaling['factor']
            exllama_config.rope_config = config.rope_scaling
        exllama_config.prepare()
        
        model = ExLlamaV2(exllama_config)
        if gpu_split is not None:
            gpu_split = [float(d) for d in gpu_split.split(' ')]
        model.load(gpu_split=gpu_split)

        lora_model = None
        if lora_path is not None:
            lora_model = ExLlamaV2Lora.from_directory(model, lora_path)
        
        if lora_model is None:
            adapter_name = ''

        return cls(config, generation_config, exllama_config, model, {adapter_name: lora_model}, adapter_name)

    @staticmethod
    def _reorder_cache(past_key_values: ExLlamaV2Cache, beam_idx):

        for i in range(len(past_key_values.key_states)):
            past_key_values.key_states[i] = past_key_values.key_states[i].index_select(0, beam_idx.to(past_key_values.key_states[i].device))
            past_key_values.value_states[i] = past_key_values.value_states[i].index_select(0, beam_idx.to(past_key_values.value_states[i].device))
        
        return past_key_values


model = ExLlamaV2ForCausalLM.from_pretrained('my-llama-model')
tokenizer = AutoTokenizer.from_pretrained('my-llama-model', padding_side='left', truncation_side='left')
tokenizer.pad_token_id = tokenizer.eos_token_id

second_text = '''Here is a story:
Hi! Nice to meet you! My name is John Smith. I am 19 and a student in college. I go to college in New York. My favorite courses are Geometry, French, and History. English is my hardest course. My professors are very friendly and smart. It’s my second year in college now. I love it!

I live in a big house on Ivy Street. It’s near the college campus. I share the house with three other students. Their names are Bill, Tony, and Paul. We help each other with homework. On the weekend, we play football together.

I have a younger brother. He just started high school. He is 14 and lives with my parents. They live on Mulberry Street in Boston. Sometimes they visit me in New York. I am happy when they visit. My Mom always brings me sweets and candy when they come. I really miss them, too!

Question: Who are three students that live with John?'''

inputs = tokenizer.apply_chat_template([[{'role': 'user', 'content': 'What is AI?'}], [{'role': 'user', 'content': second_text}]], return_dict=True, return_tensors='pt', padding=True, add_generation_prompt=True)

with torch.inference_mode():
    result = model.generate(**inputs, max_new_tokens=1000)

decoded_result = tokenizer.batch_decode(result)

print(decoded_result[0])
print(decoded_result[1])

The result is this:

<|end_of_text|><|end_of_text|> ....... <|end_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>

What is AI?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

HomeQInHomeThe ( ( AIn

QHomeSup ( (

and

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Here is a story:
Hi! Nice to meet you! My name is John Smith. I am 19 and a student in college. I go to college in New York. My favorite courses are Geometry, French, and History. English is my hardest course. My professors are very friendly and smart. It’s my second year in college now. I love it!

I live in a big house on Ivy Street. It’s near the college campus. I share the house with three other students. Their names are Bill, Tony, and Paul. We help each other with homework. On the weekend, we play football together.

I have a younger brother. He just started high school. He is 14 and lives with my parents. They live on Mulberry Street in Boston. Sometimes they visit me in New York. I am happy when they visit. My Mom always brings me sweets and candy when they come. I really miss them, too!

Question: Who are three students that live with John?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The three students that live with John are Bill, Tony, and Paul.<|eot_id|>

Somehow the shorter input gets a random answer. Why is that?

@turboderp
Owner

I would check that the attention mask is correct. Input IDs should be a right-aligned tensor, and you should pass a position_offsets argument to forward to get correct positional embeddings for sequences of dissimilar length.

You could compare your tokenized IDs and offsets tensor to the output of ExLlamaV2Tokenizer.encode with return_offsets = True. The padding mask is a half tensor of the same shape as the input IDs, with a bias of -inf for padding tokens.
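
A minimal sketch, assuming the HF-style 0/1 attention_mask from the example above, of the two tensors this comment describes (the -inf/0 half-precision padding mask and the per-row offsets); the helper name is made up for illustration:

import torch

def hf_mask_to_exllama(attention_mask: torch.Tensor):
    # attention_mask: HF-style [batch, seq_len] tensor, 1 = real token, 0 = left padding
    bool_mask = attention_mask.to(torch.bool)
    # half tensor of the same shape as the input IDs, with a -inf bias on padding tokens
    input_mask = torch.where(bool_mask,
                             torch.tensor(0.0, dtype=torch.float16),
                             torch.tensor(-float('inf'), dtype=torch.float16))
    # one offset per row: minus the number of padding tokens
    position_offsets = -(~bool_mask).sum(dim=1, keepdim=True).to(torch.int)
    return input_mask, position_offsets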

@fahadh4ilyas
Contributor Author

I would check that the attention mask is correct. Input IDs should be a right-aligned tensor, and you should pass a position_offsets argument to forward to get correct positional embeddings for sequences of dissimilar length.

You could compare your tokenized IDs and offsets tensor to the output of ExLlamaV2Tokenizer.encode with return_offsets = True. The padding mask is a half tensor of the same shape as the input IDs, with a bias of -inf for padding tokens.

It seems that ExLlamaV2Tokenizer does not use an attention_mask and instead uses the offsets tensor. Is that correct?

@fahadh4ilyas
Contributor Author

My problem is that the offset value is kind of counter-intuitive: it is the negative of the padding length. So, if the max input token length is 100 and the padding length is 10 (the real input is 90 tokens), then the offset value is -10. Could you explain why?

@turboderp
Owner

It's the offset added to the position IDs. If you have two sequences in a batch, and one is 5 tokens and the other is 10 tokens, they make up a (2, 10) right-aligned tensor:

P P P P P A A A A A
B B B B B B B B B B

Where P is a padding token. The model, meanwhile, has no way to count padding tokens when applying positional encodings. Hence the offsets, so that RoPE can be applied for positions -5..4 and 0..9, respectively.

More importantly, attention natively doesn't have a way to prevent attending to padding tokens. They must have a -inf attention score (or zero score post-softmax), and you can't achieve that with any special embedding for those tokens, so you must have a mask that's applied on every attention layer. In the above case the mask would look like:

0 0 0 0 0 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
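
The same example as a small code sketch, assuming an HF-style 0/1 mask laid out as above:

import torch

# the right-aligned (2, 10) batch above: a 5-token row and a 10-token row
attention_mask = torch.tensor([
    [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],   # P P P P P A A A A A
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],   # B B B B B B B B B B
])

# offset added to each row's position IDs: minus the number of padding tokens
position_offsets = -(attention_mask == 0).sum(dim=1, keepdim=True)  # -> [[-5], [0]]

# attention bias form of the mask above: -inf where it is 0 (padding), 0.0 where it is 1
attn_bias = torch.where(attention_mask.bool(),
                        torch.tensor(0.0, dtype=torch.float16),
                        torch.tensor(-float('inf'), dtype=torch.float16))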

@fahadh4ilyas
Contributor Author

fahadh4ilyas commented Sep 2, 2024

Okay, I implemented position_offsets like this:

import torch, os
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, List, Union, Dict
from transformers import AutoConfig, PretrainedConfig
from transformers.generation.utils import GenerationMixin, GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora

class ExLlamaV2ForCausalLM(GenerationMixin):
    
    def __init__(
        self,
        config: PretrainedConfig,
        generation_config: GenerationConfig,
        exllama_config: ExLlamaV2Config,
        model: ExLlamaV2,
        loras: Dict[str, ExLlamaV2Lora] = {'': None},
        active_adapter: str = '',
        **kwargs
    ):
        self.config = config
        self.generation_config = generation_config
        self.exllama_config = exllama_config
        self.model = model
        self.loras = loras
        if '' not in self.loras:
            self.loras[''] = None
        self._active_adapter = active_adapter
        self._adapter_enabled = True
    
    def can_generate(self):
        return True

    @property
    def _supports_cache_class(self) -> bool:
        return False
    
    @property
    def device(self) -> torch.device:
        return torch.device(0)
    
    @property
    def main_input_name(self) -> str:
        return 'input_ids'
    
    @property
    def active_adapters(self) -> List[str]:
        return [self._active_adapter] if self._adapter_enabled else []

    @property
    def active_adapter(self) -> List[str]:
        return self._active_adapter if self._adapter_enabled else ''
    
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}
    
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
    
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_size: int = -1,
        **kwargs
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        loras = self.loras.get(self.active_adapter, None)
        loras = [loras] if loras else loras
        input_device = input_ids.device
        input_ids = input_ids.to('cpu')
        attention_mask = attention_mask.to(torch.bool)
        position_offsets = -(~(attention_mask)).sum(dim=1, keepdim=True).to(torch.int)
        
        if labels is None:
            if past_key_values is None:
                past_key_values = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)
                self.model.forward(input_ids[...,:-1], past_key_values, preprocess_only=True, loras=loras, input_mask=attention_mask[...,:-1], position_offsets=position_offsets)

            logits = self.model.forward(input_ids[...,-1:], past_key_values, loras=loras, input_mask=attention_mask, position_offsets=position_offsets)
        else:
            if past_key_values is None:
                past_key_values = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)

            logits = self.model.forward(input_ids, past_key_values, loras=loras, input_mask=attention_mask, position_offsets=position_offsets)
        
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        
        if not return_dict:
            output = (logits, past_key_values if use_cache else None)
            return (loss,) + output if loss is not None else output
        
        return CausalLMOutputWithPast(logits=logits.to(input_device), past_key_values=past_key_values if use_cache else None, loss=loss)
    
    def load_adapter(self, lora_path: Union[str, os.PathLike], adapter_name: str):

        if adapter_name in self.loras:
            raise ValueError('This adapter already exists')
        
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)
        
        lora_model = ExLlamaV2Lora.from_directory(self.model, lora_path)

        self.loras[adapter_name] = lora_model

    def set_adapter(self, adapter_name: str):

        if adapter_name not in self.loras:
            raise ValueError('The adapter does not exist')
        
        self._active_adapter = adapter_name
    
    def enable_adapter_layers(self):

        self._adapter_enabled = True

    def disable_adapter_layers(self):

        self._adapter_enabled = False
    
    @contextmanager
    def disable_adapter(self):

        try:
            self.disable_adapter_layers()
            yield
        finally:
            self.enable_adapter_layers()
    
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        gpu_split: Optional[str] = None,
        lora_path: Optional[Union[str, os.PathLike]] = None,
        adapter_name: str = 'default',
        trust_remote_code: bool = False,
        use_flash_attention_2: bool = False
    ):
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)
        
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        
        try:
            generation_config = GenerationConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        except:
            generation_config = GenerationConfig()
        
        exllama_config = ExLlamaV2Config()
        exllama_config.model_dir = pretrained_model_name_or_path
        exllama_config.no_flash_attn = not use_flash_attention_2
        if getattr(config, 'rope_scaling', None) is not None:
            rope_type = config.rope_scaling.get('type', config.rope_scaling.get('rope_type', ''))
            if rope_type == 'linear':
                exllama_config.scale_pos_emb = config.rope_scaling['factor']
            elif rope_type == 'dynamic':
                exllama_config.scale_alpha_value = config.rope_scaling['factor']
            exllama_config.rope_config = config.rope_scaling
        exllama_config.prepare()
        
        model = ExLlamaV2(exllama_config)
        if gpu_split is not None:
            gpu_split = [float(d) for d in gpu_split.split(' ')]
        model.load(gpu_split=gpu_split)

        lora_model = None
        if lora_path is not None:
            lora_model = ExLlamaV2Lora.from_directory(model, lora_path)
        
        if lora_model is None:
            adapter_name = ''

        return cls(config, generation_config, exllama_config, model, {adapter_name: lora_model}, adapter_name)

    @staticmethod
    def _reorder_cache(past_key_values: ExLlamaV2Cache, beam_idx):

        for i in range(len(past_key_values.key_states)):
            past_key_values.key_states[i] = past_key_values.key_states[i].index_select(0, beam_idx.to(past_key_values.key_states[i].device))
            past_key_values.value_states[i] = past_key_values.value_states[i].index_select(0, beam_idx.to(past_key_values.value_states[i].device))
        
        return past_key_values

But the answer didn't change; it's still wrong for the short input. The attention mask is left-padded, just like you said.

@turboderp
Owner

I think I expressed that poorly. The tensor you've got is the right layout, but it still needs to be a half tensor of -inf and 0 values. E.g. if I do this with your example:

inputs["attention_mask"] = torch.where(
    inputs["attention_mask"] == 0, 
    torch.tensor(-float('inf'), dtype = torch.float16), 
    torch.tensor(0.0, dtype = torch.float16)
)

I get sensible outputs from both prompts (Llama3-8B-instruct 5bpw):

<|end_of_text|><|end_of_text|> ....... <|end_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>

What is AI?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Artificial Intelligence (AI) is a wide-ranging field of research and development that is concerned with the use of machines to perform tasks that would normally be done by humans. AI is often associated with the use of machines to perform tasks that would normally be done by humans.<|eot_id|>

And:

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Here is a story:
Hi! Nice to meet you! My name is John Smith. I am 19 and a student in college. I go to college in New York. My favorite courses are Geometry, French, and History. English is my hardest course. My professors are very friendly and smart. It’s my second year in college now. I love it!

I live in a big house on Ivy Street. It’s near the college campus. I share the house with three other students. Their names are Bill, Tony, and Paul. We help each other with homework. On the weekend, we play football together.

I have a younger brother. He just started high school. He is 14 and lives with my parents. They live on Mulberry Street in Boston. Sometimes they visit me in New York. I am happy when they visit. My Mom always brings me sweets and candy when they come. I really miss them, too!

Question: Who are three students that live with John?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

According to the story, the three students who live with John are:

1. Bill
2. Tony
3. Paul<|eot_id|>

As for the offset, it turns out I may be misremembering; that parameter is used for a lot of different things depending on the inference mode. In any case, negative offsets don't seem to work, but positive offsets do, like this (accounting for the mask being inverted):

position_offsets = ((attention_mask)).sum(dim=1, keepdim=True).to(torch.int)

However, this shouldn't matter for RoPE to begin with, since it's designed to be relative: give or take some rounding error, attention should work out exactly the same whether you start each sequence at zero or at some offset. position_offsets = None also works for your example, so as long as the longest sequence in the batch doesn't exceed the model's total context length, you should be fine just omitting the offsets.

I'm not sure what the correct approach would be for a wrapper like this to get rid of the left-padding in the output.
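
A minimal sketch of one option, reusing the inputs, result and tokenizer names from the usage example earlier in the thread: generate keeps the left padding at the start of each returned row, so the original 0/1 attention_mask can be used to trim it before decoding.

# number of left-padding tokens in each input row
pad_lengths = (inputs['attention_mask'] == 0).sum(dim=1)
# drop the padding from the front of each generated row before decoding
trimmed = [row[pad:] for row, pad in zip(result, pad_lengths.tolist())]
decoded_result = tokenizer.batch_decode(trimmed)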

@fahadh4ilyas
Contributor Author

Okay, so here is the final code for the HF-format wrapper for exllamav2:

import torch, os, gc
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, List, Union, Dict
from transformers import AutoConfig, PretrainedConfig
from transformers.generation.utils import GenerationMixin, GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora

class ExLlamaV2ForCausalLM(GenerationMixin):
    
    def __init__(
        self,
        config: PretrainedConfig,
        generation_config: GenerationConfig,
        exllama_config: ExLlamaV2Config,
        model: ExLlamaV2,
        loras: Dict[str, ExLlamaV2Lora] = {'': None},
        active_adapter: str = '',
        **kwargs
    ):
        self.config = config
        self.generation_config = generation_config
        self.exllama_config = exllama_config
        self.model = model
        self.loras = loras
        if '' not in self.loras:
            self.loras[''] = None
        self._active_adapter = active_adapter
        self._adapter_enabled = True
        self._selfcache = None
    
    def can_generate(self):
        return True

    @property
    def _supports_cache_class(self) -> bool:
        return False
    
    @property
    def device(self) -> torch.device:
        return torch.device(0)
    
    @property
    def main_input_name(self) -> str:
        return 'input_ids'
    
    @property
    def active_adapters(self) -> List[str]:
        return [self._active_adapter] if self._adapter_enabled else []

    @property
    def active_adapter(self) -> List[str]:
        return self._active_adapter if self._adapter_enabled else ''
    
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}
    
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
    
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_size: int = -1,
        **kwargs
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        loras = self.loras.get(self.active_adapter, None)
        loras = [loras] if loras else loras
        input_device = input_ids.device
        input_ids = input_ids.to('cpu')
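        # Convert the HF-style 0/1 attention_mask into what ExLlamaV2's forward expects:
        # a per-row offset of minus the padding length, and a float16 bias mask with
        # 0.0 for real tokens and -inf for padding tokens.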
        attention_mask = attention_mask.to(torch.bool)
        position_offsets = -(~(attention_mask)).sum(dim=1, keepdim=True).to(torch.int)
        attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float16), torch.tensor(-float('inf'), dtype=torch.float16))
        
        if labels is None:
            if past_key_values is None:
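                # Free the cache from the previous generate() call (if any)
                # before allocating a new one for this batch.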
                if self._selfcache is not None:
                    self._selfcache.key_states = None
                    self._selfcache.value_states = None
                    gc.collect()
                    torch.cuda.empty_cache()
                self._selfcache = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)
                past_key_values = self._selfcache
                self.model.forward(input_ids[...,:-1], past_key_values, preprocess_only=True, loras=loras, input_mask=attention_mask[...,:-1], position_offsets=position_offsets)

            logits = self.model.forward(input_ids[...,-1:], past_key_values, loras=loras, input_mask=attention_mask, position_offsets=position_offsets)
        else:
            if past_key_values is None:
                if self._selfcache is not None:
                    self._selfcache.key_states = None
                    self._selfcache.value_states = None
                    gc.collect()
                    torch.cuda.empty_cache()
                self._selfcache = ExLlamaV2Cache(self.model, input_ids.shape[0], cache_size)
                past_key_values = self._selfcache

            logits = self.model.forward(input_ids, past_key_values, loras=loras, input_mask=attention_mask, position_offsets=position_offsets)
        
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        
        if not return_dict:
            output = (logits, past_key_values if use_cache else None)
            return (loss,) + output if loss is not None else output
        
        return CausalLMOutputWithPast(logits=logits.to(input_device), past_key_values=past_key_values if use_cache else None, loss=loss)
    
    def load_adapter(self, lora_path: Union[str, os.PathLike], adapter_name: str):

        if adapter_name in self.loras:
            raise ValueError('This adapter already exists')
        
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)
        
        lora_model = ExLlamaV2Lora.from_directory(self.model, lora_path)

        self.loras[adapter_name] = lora_model

    def set_adapter(self, adapter_name: str):

        if adapter_name not in self.loras:
            raise ValueError('The adapter does not exist')
        
        self._active_adapter = adapter_name
    
    def enable_adapter_layers(self):

        self._adapter_enabled = True

    def disable_adapter_layers(self):

        self._adapter_enabled = False
    
    @contextmanager
    def disable_adapter(self):

        try:
            self.disable_adapter_layers()
            yield
        finally:
            self.enable_adapter_layers()
    
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        gpu_split: Optional[str] = None,
        lora_path: Optional[Union[str, os.PathLike]] = None,
        adapter_name: str = 'default',
        trust_remote_code: bool = False,
        use_flash_attention_2: bool = False
    ):
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        
        if isinstance(lora_path, str):
            lora_path = Path(lora_path)
        
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        
        try:
            generation_config = GenerationConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        except:
            generation_config = GenerationConfig()
        
        exllama_config = ExLlamaV2Config()
        exllama_config.model_dir = pretrained_model_name_or_path
        exllama_config.no_flash_attn = not use_flash_attention_2
        if getattr(config, 'rope_scaling', None) is not None:
            rope_type = config.rope_scaling.get('type', config.rope_scaling.get('rope_type', ''))
            if rope_type == 'linear':
                exllama_config.scale_pos_emb = config.rope_scaling['factor']
            elif rope_type == 'dynamic':
                exllama_config.scale_alpha_value = config.rope_scaling['factor']
            exllama_config.rope_config = config.rope_scaling
        exllama_config.prepare()
        
        model = ExLlamaV2(exllama_config)
        if gpu_split is not None:
            gpu_split = [float(d) for d in gpu_split.split(' ')]
        model.load(gpu_split=gpu_split)

        lora_model = None
        if lora_path is not None:
            lora_model = ExLlamaV2Lora.from_directory(model, lora_path)
        
        if lora_model is None:
            adapter_name = ''

        return cls(config, generation_config, exllama_config, model, {adapter_name: lora_model}, adapter_name)

    @staticmethod
    def _reorder_cache(past_key_values: ExLlamaV2Cache, beam_idx):

        for i in range(len(past_key_values.key_states)):
            past_key_values.key_states[i] = past_key_values.key_states[i].index_select(0, beam_idx.to(past_key_values.key_states[i].device))
            past_key_values.value_states[i] = past_key_values.value_states[i].index_select(0, beam_idx.to(past_key_values.value_states[i].device))
        
        return past_key_values
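
A short usage sketch of the wrapper above, mirroring the earlier example in this thread ('my-llama-model' and the prompts are placeholders); with the mask conversion now done inside forward, the plain 0/1 attention_mask from the HF tokenizer can be passed straight to generate:

import torch
from transformers import AutoTokenizer

model = ExLlamaV2ForCausalLM.from_pretrained('my-llama-model')
tokenizer = AutoTokenizer.from_pretrained('my-llama-model', padding_side='left', truncation_side='left')
tokenizer.pad_token_id = tokenizer.eos_token_id

# two prompts of different lengths, left-padded into one batch
inputs = tokenizer.apply_chat_template(
    [[{'role': 'user', 'content': 'What is AI?'}],
     [{'role': 'user', 'content': 'Explain rotary position embeddings in two sentences.'}]],
    return_dict=True, return_tensors='pt', padding=True, add_generation_prompt=True)

with torch.inference_mode():
    result = model.generate(**inputs, max_new_tokens=200)

for text in tokenizer.batch_decode(result):
    print(text)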
