Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support Mllama #777

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,7 @@ scripts/git-strip-merge
tests/backwards_compatibility/Ref_Out

# backwards compatibility
model_outputs
model_outputs

# TODO: remove after mllama dev
explore_mllama
39 changes: 39 additions & 0 deletions src/adapters/models/mllama/_init_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
"adapter_model": ["MllamaAdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import MllamaAdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
266 changes: 266 additions & 0 deletions src/adapters/models/mllama/adapter_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
import logging
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from adapters.heads.language_modeling import CausalLMOutputWithPast
from hf_transformers.build.lib.transformers.cache_utils import Cache
from hf_transformers.build.lib.transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mllama import (
MLLAMA_START_DOCSTRING,
MllamaPreTrainedModel,
MllamaVisionModel,
MllamaTextModel,
)
from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask
from transformers.utils import add_start_docstrings

from ...composition import adjust_tensors_for_parallel
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init


logger = logging.getLogger(__name__)


@add_start_docstrings(MLLAMA_START_DOCSTRING)
class MllamaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MllamaPreTrainedModel):

def __init__(self, config):
super().__init__(config)
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles
self.vision_output_dim = config.vision_config.vision_output_dim
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

self.vision_model = MllamaVisionModel._from_config(config.vision_config)
self.language_model = MllamaTextModel._from_config(config.text_config)
self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
bias=True,
)

init(self.vision_model)
init(self.language_model)
self._init_head_modules()
self.post_init()

def get_input_embeddings(self):
return self.language_model.get_input_embeddings()

def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def get_output_embeddings(self):
return self.language_model.get_output_embeddings()

def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)

def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)

def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
): # TODO -> output format

# TODO: incorporate adapter logic with Forwardcontext and heads

# Establish parameter values
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# Check invalid argument combinations
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

# If image is provided compute cross_attention_states
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
vision_outputs = self.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
-1, cross_attention_states.shape[-2], self.hidden_size
)

# Compute cross_attention_mask
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None and cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]

outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
labels=labels,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
)

# TODO: head logic, until now just copied!
hidden_states = outputs[0]
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
pixel_values=None,
aspect_ratio_ids=None,
aspect_ratio_mask=None,
cross_attention_mask=None,
past_key_values=None,
use_cache=False,
cache_position=None,
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

# TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"cross_attention_mask": cross_attention_mask,
}
)

# If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
# to compute image hidden states, otherwise they are cached within each cross attn layer
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
model_inputs["aspect_ratio_mask"] = aspect_ratio_mask

return model_inputs

def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)

# add cross-attn mask for new token
if cross_attention_mask_prev is not None:
model_kwargs["cross_attention_mask"] = torch.cat(
[cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
)
return model_kwargs
Loading
Loading