diff --git a/paddlemix/examples/aria/run_predict.py b/paddlemix/examples/aria/run_predict.py
new file mode 100644
index 000000000..d9d54adbe
--- /dev/null
+++ b/paddlemix/examples/aria/run_predict.py
@@ -0,0 +1,27 @@
+import paddle
+from PIL import Image
+
+from paddlemix.models.aria.modeling_aria import AriaForConditionalGeneration
+from paddlemix.processors.processing_aria import AriaProcessor
+
+model_id_or_path = 'rhymes-ai/Aria'
+
+model = AriaForConditionalGeneration.from_pretrained(model_id_or_path)
+processor = AriaProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)
+
+image = Image.open('paddlemix/demo_images/examples_image1.jpg').convert('RGB')
+messages = [{'role': 'user', 'content': [{'text': None, 'type': 'image'},
+    {'text': 'what is the image?', 'type': 'text'}]}]
+text = processor.apply_chat_template(messages)
+
+inputs = processor(text=text, images=image, return_tensors='pd')
+inputs['pixel_values'] = inputs['pixel_values'].to(model.dtype)
+inputs = {k: v.to(model.place) for k, v in inputs.items()}
+
+with paddle.no_grad(), paddle.amp.auto_cast(dtype='bfloat16'):
+    output = model.generate(**inputs, max_new_tokens=500, stop_strings=['<|im_end|>'],
+        tokenizer=processor.tokenizer, do_sample=True, temperature=0.9)
+
+output_ids = output[0][tuple(inputs['input_ids'].shape)[1]:]
+result = processor.decode(output_ids, skip_special_tokens=True)
+print(result)
diff --git a/paddlemix/models/__init__.py b/paddlemix/models/__init__.py
index 5fe525794..58f781ca6 100644
--- a/paddlemix/models/__init__.py
+++ b/paddlemix/models/__init__.py
@@ -13,6 +13,7 @@
 # see the license for the specific language governing permissions and
 # limitations under the license.
 
+from .aria import *
 from .audioldm2.configuration import *
 from .audioldm2.modeling import *
 from .blip2.modeling import *
diff --git a/paddlemix/models/aria/__init__.py b/paddlemix/models/aria/__init__.py
new file mode 100644
index 000000000..ac6c585cd
--- /dev/null
+++ b/paddlemix/models/aria/__init__.py
@@ -0,0 +1,7 @@
+from .configuration_aria import *
+from .modeling_aria import *
+from .moe_lm import *
+from .projector import *
+from .vision_encoder import *
+from ...processors.processing_aria import *
+from ...processors.aria_vision_processor import *
diff --git a/paddlemix/models/aria/configuration_aria.py b/paddlemix/models/aria/configuration_aria.py
new file mode 100644
index 000000000..668d8194e
--- /dev/null
+++ b/paddlemix/models/aria/configuration_aria.py
@@ -0,0 +1,69 @@
+import paddlenlp
+import logging
+from .moe_lm import AriaMoELMConfig
+from .vision_encoder import AriaVisionConfig
+from paddlenlp.transformers.configuration_utils import PretrainedConfig
+logger = logging.getLogger(__name__)
+
+
+class AriaConfig(PretrainedConfig):
+    """
+    Configuration class for Aria model.
+ + This class handles the configuration for both vision and text components of the Aria model, + as well as additional parameters for image token handling and projector mapping. + + Args: + vision_config (AriaVisionConfig or dict): Configuration for the vision component. + text_config (AriaMoELMConfig or dict): Configuration for the text component. + projector_patch_to_query_dict (dict): Mapping of patch sizes to query dimensions. + ignore_index (int): Index to ignore in loss calculation. + image_token_index (int): Index used to represent image tokens. + **kwargs: Additional keyword arguments passed to the parent class. + + Attributes: + model_type (str): Type of the model, set to "aria". + is_composition (bool): Whether the model is a composition of multiple components. + ignore_index (int): Index to ignore in loss calculation. + image_token_index (int): Index used to represent image tokens. + projector_patch_to_query_dict (dict): Mapping of patch sizes to query dimensions. + vision_config (AriaVisionConfig): Configuration for the vision component. + text_config (AriaMoELMConfig): Configuration for the text component. + """ + model_type = 'aria' + is_composition = False + + def __init__(self, vision_config=AriaVisionConfig(), text_config= + AriaMoELMConfig(), projector_patch_to_query_dict={(1225): 128, ( + 4900): 256}, ignore_index=-100, image_token_index=32000, + tie_word_embeddings=False, **kwargs): + super().__init__(**kwargs) + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.tie_word_embeddings = tie_word_embeddings + attn_implementation = kwargs.pop('attn_implementation', None) + self._attn_implementation = ('flash_attention_2' if + attn_implementation is None else attn_implementation) + self.projector_patch_to_query_dict = {int(k): int(v) for k, v in + projector_patch_to_query_dict.items()} + if isinstance(vision_config, dict) and 'model_type' in vision_config: + vision_config = AriaVisionConfig(**vision_config) + if attn_implementation is None: + vision_attn_implementation = 'flash_attention_2' + elif attn_implementation == 'sdpa': + logger.warning( + 'SDPA is not supported for vit, using flash_attention_2 instead' + ) + vision_attn_implementation = 'flash_attention_2' + else: + vision_attn_implementation = attn_implementation + vision_config._attn_implementation = vision_attn_implementation + self.vision_config = vision_config + if isinstance(text_config, dict) and 'model_type' in text_config: + text_attn_implementation = ('sdpa' if attn_implementation is + None else attn_implementation) + + text_config = AriaMoELMConfig(**text_config) + text_config._attn_implementation = text_attn_implementation + self.text_config = text_config + self.num_hidden_layers = self.text_config.num_hidden_layers diff --git a/paddlemix/models/aria/modeling_aria.py b/paddlemix/models/aria/modeling_aria.py new file mode 100644 index 000000000..1adab254e --- /dev/null +++ b/paddlemix/models/aria/modeling_aria.py @@ -0,0 +1,273 @@ +import paddle +import paddlenlp +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +from .configuration_aria import AriaConfig +from .moe_lm import AriaMoELMForCausalLM +from .projector import AriaProjector +from .vision_encoder import AriaVisionModel +from paddlenlp.transformers.model_outputs import ModelOutput +from paddlenlp.transformers.model_utils import PretrainedModel +from paddlenlp.generation.utils import GenerationMixin +logger = paddle.utils.try_import('logging').getLogger(name=__name__) + + +class 
AriaPretrainedModel(PretrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + """ + config_class = AriaConfig + base_model_prefix = 'model' + _no_split_modules = [] + supports_gradient_checkpointing = True + _skip_keys_device_placement = 'past_key_values' + _supports_flash_attn_2 = True + _supports_cache_class = True + _supports_static_cache = True + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA (Scaled Dot Product Attention) or not. + """ + + return self.language_model._supports_sdpa + + +@dataclass +class AriaCausalLMOutputWithPast(ModelOutput): + """ + Base class for Aria causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + loss: Optional[paddle.Tensor] = None + logits: paddle.float32 = None + past_key_values: Optional[List[paddle.Tensor]] = None + hidden_states: Optional[Tuple[paddle.Tensor]] = None + attentions: Optional[Tuple[paddle.Tensor]] = None + image_hidden_states: Optional[Tuple[paddle.Tensor]] = None + + +def build_mm_projector(config: AriaConfig): + """ + Builds and returns an AriaProjector instance based on the provided configuration. + + Args: + config (AriaConfig): The configuration object containing necessary parameters. + + Returns: + AriaProjector: An instance of the AriaProjector class. + """ + return AriaProjector(patch_to_query_dict=config. + projector_patch_to_query_dict, embed_dim=config.vision_config. 
+ hidden_size, num_heads=config.vision_config.num_attention_heads, + kv_dim=config.vision_config.hidden_size, ff_dim=config.text_config. + hidden_size, output_dim=config.text_config.hidden_size) + + +class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin): + """ + Aria model for conditional generation tasks. + + This model combines a vision tower, a multi-modal projector, and a language model + to perform tasks that involve both image and text inputs. + """ + + def __init__(self, config: AriaConfig): + super().__init__(config) + + self.vision_tower = AriaVisionModel(config.vision_config) + print(2) + self.multi_modal_projector = build_mm_projector(config) + print(3) + self.vocab_size = config.text_config.vocab_size + self.language_model = AriaMoELMForCausalLM(config.text_config) + print(4) + self.pad_token_id = (self.config.pad_token_id if self.config. + pad_token_id is not None else -1) + self.post_init() + print(5) + + def freeze_vit(self): + """Freeze the parameters of the vision tower.""" + for param in self.vision_tower.parameters(): + param.stop_gradient = not False + + def freeze_projector(self): + """Freeze the parameters of the multi-modal projector.""" + for param in self.multi_modal_projector.parameters(): + param.stop_gradient = not False + + def freeze_llm(self): + """Freeze the parameters of the language model.""" + for param in self.language_model.parameters(): + param.stop_gradient = not False + + def get_input_embeddings(self) ->paddle.nn.Layer: + """Retrieve the input embeddings from the language model.""" + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + """Set the input embeddings for the language model.""" + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + """Retrieve the output embeddings from the language model.""" + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, value): + """Set the output embeddings for the language model.""" + self.language_model.set_output_embeddings(value) + + def set_moe_z_loss_coeff(self, value): + """ + Set the z-loss coefficient for Mixture of Experts (MoE) models. + + Args: + value: The z-loss coefficient value to set. + """ + self.language_model.set_z_loss_coeff(value) + + def set_moe_aux_loss_coeff(self, value): + """ + Set the auxiliary loss coefficient for Mixture of Experts (MoE) models. + + Args: + value: The auxiliary loss coefficient value to set. + """ + self.language_model.set_aux_loss_coeff(value) + + def forward(self, input_ids: paddle.Tensor=None, pixel_values: paddle. + Tensor=None, pixel_mask: paddle.Tensor=None, attention_mask: + Optional[paddle.Tensor]=None, position_ids: Optional[paddle.Tensor] + =None, past_key_values: Optional[List[paddle.Tensor]]=None, + inputs_embeds: Optional[paddle.Tensor]=None, labels: Optional[ + paddle.Tensor]=None, use_cache: Optional[bool]=None, + output_attentions: Optional[bool]=None, output_hidden_states: + Optional[bool]=None, return_dict: Optional[bool]=None, + cache_position: Optional[paddle.Tensor]=None, num_logits_to_keep: int=0 + ) ->Union[Tuple, AriaCausalLMOutputWithPast]: + """ + Forward pass of the AriaForConditionalGeneration model. + + This method processes both text and image inputs, merges them if necessary, + and generates output using the language model. + + Args: + input_ids (torch.LongTensor, optional): Input token ids. + pixel_values (torch.FloatTensor, optional): Pixel values of the images. 
+            pixel_mask (torch.LongTensor, optional): Mask for the pixel values.
+            attention_mask (torch.Tensor, optional): Attention mask.
+            position_ids (torch.LongTensor, optional): Position ids.
+            past_key_values (List[torch.FloatTensor], optional): Past key values for efficient processing.
+            inputs_embeds (torch.FloatTensor, optional): Input embeddings.
+            labels (torch.LongTensor, optional): Labels for computing the language modeling loss.
+            use_cache (bool, optional): Whether to use the model's cache mechanism.
+            output_attentions (bool, optional): Whether to output attention weights.
+            output_hidden_states (bool, optional): Whether to output hidden states.
+            return_dict (bool, optional): Whether to return a ModelOutput object.
+
+        Returns:
+            Union[Tuple, AriaCausalLMOutputWithPast]: Model outputs.
+        """
+        output_attentions = (output_attentions if output_attentions is not
+            None else self.config.output_attentions)
+        output_hidden_states = (output_hidden_states if
+            output_hidden_states is not None else self.config.
+            output_hidden_states)
+        return_dict = (return_dict if return_dict is not None else self.
+            config.use_return_dict)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+        image_features = None
+        if pixel_values is not None:
+            image_outputs, image_attn_mask = self.vision_tower(pixel_values,
+                pixel_mask=pixel_mask)
+            selected_image_feature = image_outputs.last_hidden_state
+            image_features = self.multi_modal_projector(selected_image_feature,
+                attn_mask=image_attn_mask)
+        if image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum(
+                ).item()
+            n_image_features = tuple(image_features.shape)[0] * tuple(
+                image_features.shape)[1]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f'Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}'
+                    )
+            special_image_mask = (input_ids == self.config.image_token_index
+                ).unsqueeze(axis=-1).expand_as(y=inputs_embeds).to(
+                inputs_embeds.place)
+            image_features = image_features.to(inputs_embeds.place,
+                inputs_embeds.dtype)
+            # Scatter the projected image features into the image-token positions of
+            # inputs_embeds (the FuseModule in
+            # ppdiffusers/examples/PhotoMaker/photomaker/model.py uses the same approach).
+            inputs_embeds = inputs_embeds.masked_scatter_(special_image_mask, image_features)
+        outputs = self.language_model(attention_mask=attention_mask,
+            position_ids=position_ids, past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds, use_cache=use_cache,
+            output_attentions=output_attentions, output_hidden_states=
+            output_hidden_states, return_dict=return_dict, cache_position=
+            cache_position, num_logits_to_keep=num_logits_to_keep)
+        logits = outputs[0]
+        loss = None
+        if labels is not None:
+            if attention_mask is not None:
+                shift_attention_mask = attention_mask[:, -(tuple(logits.
+ shape)[1] - 1):].to(logits.place) + shift_logits = logits[..., :-1, :][shift_attention_mask.to( + logits.place) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to( + labels.place) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), + shift_labels.view(-1).to(shift_logits.place)) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + return AriaCausalLMOutputWithPast(loss=loss, logits=logits, + past_key_values=outputs.past_key_values, hidden_states=outputs. + hidden_states, attentions=outputs.attentions) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, pixel_values=None, pixel_mask=None, + attention_mask=None, cache_position=None, num_logits_to_keep=None, + **kwargs): + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds= + inputs_embeds, attention_mask=attention_mask, cache_position= + cache_position, num_logits_to_keep=num_logits_to_keep, **kwargs) + if cache_position[0] == 0: + model_inputs['pixel_values'] = pixel_values + model_inputs['pixel_mask'] = pixel_mask + return model_inputs diff --git a/paddlemix/models/aria/moe_lm.py b/paddlemix/models/aria/moe_lm.py new file mode 100644 index 000000000..a3f62983c --- /dev/null +++ b/paddlemix/models/aria/moe_lm.py @@ -0,0 +1,629 @@ +import os +import paddle +import paddlenlp +import logging +from typing import Tuple +logger = logging.getLogger(__name__) +from paddlenlp.transformers.llama.configuration import LlamaConfig +from paddlenlp.generation.utils import GenerationMixin +from paddlenlp.transformers.llama.modeling import LlamaMLP, LlamaDecoderLayer, LlamaRMSNorm, LlamaModel, LlamaRotaryEmbedding, LlamaForCausalLM +from paddlemix.activations import ACT2FN +class AriaMoELMConfig(LlamaConfig): + """ + Configuration class for AriaMoE language model. + + This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture. + """ + model_type = 'aria_moe_lm' + + def __init__(self, moe_intermediate_size: int=4096, moe_num_experts: + int=8, moe_topk: int=2, moe_z_loss_coeff: float=1e-05, + moe_aux_loss_coeff: float=0.001, moe_num_shared_experts: int=2, ** + kwargs): + """ + Initialize the AriaMoELMConfig. + + Args: + moe_intermediate_size (int): The intermediate size for MoE layers. Default is 4096. + moe_num_experts (int): The number of experts in the MoE layer. Default is 8. + moe_topk (int): The number of top experts to route to for each token. Default is 2. + moe_z_loss_coeff (float): The coefficient for the auxiliary z-loss. Default is 1e-5. + moe_aux_loss_coeff (float): The coefficient for the auxiliary load balancing loss. Default is 1e-3. + moe_num_shared_experts (int): The number of shared experts. Default is 2. + **kwargs: Additional keyword arguments to be passed to the parent LlamaConfig. 
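+
+        Example (an illustrative sketch, not taken from any released checkpoint config):
+
+            >>> config = AriaMoELMConfig(moe_num_experts=8, moe_topk=2)
+            >>> config.moe_intermediate_size
+            4096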
+ """ + super().__init__(**kwargs) + self.moe_intermediate_size = moe_intermediate_size + self.moe_num_experts = moe_num_experts + self.moe_topk = moe_topk + self.moe_z_loss_coeff = moe_z_loss_coeff + self.moe_aux_loss_coeff = moe_aux_loss_coeff + self.moe_num_shared_experts = moe_num_shared_experts + + +class MoEAuxLossAutoScaler(paddle.autograd.PyLayer): + """An AutoScaler that compute and scales the grad for auxiliary loss.""" + main_loss_backward_scale: paddle.Tensor = paddle.to_tensor(data=1.0) + + @staticmethod + def forward(ctx, output: paddle.Tensor, aux_loss: paddle.Tensor): + """Preserve the aux_loss by storing it in the context to avoid garbage collection. + + Args: + output (torch.Tensor): The output tensor. + aux_loss (torch.Tensor): The auxiliary loss tensor. + + Returns: + torch.Tensor: The output tensor. + """ + ctx.save_for_backward(aux_loss) + return output + + @staticmethod + def backward(ctx, grad_output: paddle.Tensor): + """Compute and scale the gradient for auxiliary loss.. + + Args: + grad_output (torch.Tensor): The gradient of the output. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient. + """ + """Class Attribute: torch.autograd.function.FunctionCtx.saved_tensors, can not convert, please check whether it is torch.Tensor.*/torch.autograd.function.FunctionCtx.*/torch.distributions.Distribution.* and convert manually""" + aux_loss, = ctx.saved_tensors # 这可能会报错,但是PaddleNLP/paddlenlp/transformers/bloom/modeling.py GeLUFunction 也那么干的 + aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale + scaled_aux_loss_grad = paddle.ones_like(x=aux_loss + ) * aux_loss_backward_scale + return grad_output, scaled_aux_loss_grad + + @staticmethod + def set_loss_scale(scale: paddle.Tensor): + """set the scale of the aux loss. + + Args: + scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss. + """ + MoEAuxLossAutoScaler.main_loss_backward_scale = scale + + +def z_loss_func(logits, z_loss_coeff): + """Encourages the router's logits to remain small to enhance stability. + Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details. + + Args: + logits (torch.Tensor): The logits of the router. + + Returns: + torch.Tensor: The logits after applying the z-loss. + """ + z_loss = paddle.mean(x=paddle.square(x=paddle.logsumexp(x=logits, axis=-1)) + ) * z_loss_coeff + return z_loss + + +def switch_load_balancing_loss_func(probs: paddle.Tensor, tokens_per_expert: + paddle.Tensor, topk: int, moe_aux_loss_coeff: float): + """Calculate the auxiliary loss for better load balancing. + Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details. + + Args: + probs (torch.Tensor): The softmax probs output by the router for each token. [num_tokens, num_experts] + tokens_per_expert (torch.Tensor): The number of assigned tokens for each expert. [num_experts] + + Returns: + torch.Tensor: The auxiliary loss for load balancing. + """ + num_tokens = tuple(probs.shape)[0] * topk + num_experts = tuple(probs.shape)[1] + probs_mean_per_expert = probs.mean(axis=0) + aux_loss = paddle.sum(x=probs_mean_per_expert * tokens_per_expert) * ( + num_experts / num_tokens * moe_aux_loss_coeff) + return aux_loss + + +class TopKRouter(paddle.nn.Layer): + """ + Top-K Router for Mixture of Experts (MoE) models. + + This router determines which experts should process each token based on the top-k scoring experts. 
+ It also applies auxiliary losses to encourage load balancing among experts. + + Args: + config (AriaMoELMConfig): Configuration object containing MoE-related parameters. + """ + + def __init__(self, config: AriaMoELMConfig): + super().__init__() + self.config = config + self.weight = paddle.base.framework.EagerParamBase.from_tensor(tensor + =paddle.empty(shape=(self.config.moe_num_experts, self.config. + hidden_size))) + + def gating(self, input: paddle.Tensor) ->paddle.Tensor: + """ + Compute the gating logits for each token-expert pair. + + Args: + input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size]. + + Returns: + torch.Tensor: Logits tensor of shape [batch_size * seq_len, num_experts]. + """ + logits = paddle.nn.functional.linear(x=input, weight=self.weight.T) + return logits + + def apply_z_loss(self, logits: paddle.Tensor) ->paddle.Tensor: + """ + Apply z-loss to encourage router logits to remain small for enhanced stability. + + Args: + logits (torch.Tensor): Router logits. + + Returns: + torch.Tensor: Logits with z-loss applied. + """ + z_loss = z_loss_func(logits, self.config.moe_z_loss_coeff) + logits = MoEAuxLossAutoScaler.apply(logits, z_loss) + return logits + + def apply_aux_loss(self, logits: paddle.Tensor, tokens_per_expert: + paddle.Tensor, activation: paddle.Tensor) ->paddle.Tensor: + """ + Apply auxiliary loss for load balancing among experts. + + Args: + logits (torch.Tensor): Router logits. + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + activation (torch.Tensor): Activation values. + + Returns: + torch.Tensor: Activation with auxiliary loss applied. + """ + probs = paddle.nn.functional.softmax(x=logits, axis=-1, dtype='float32' + ) + aux_loss = switch_load_balancing_loss_func(probs, tokens_per_expert, + self.config.moe_topk, self.config.moe_aux_loss_coeff) + return MoEAuxLossAutoScaler.apply(activation, aux_loss) + + def routing(self, logits: paddle.Tensor) ->Tuple[paddle.Tensor, paddle. + Tensor, paddle.Tensor]: + """ + Perform the routing operation to determine expert assignments. + + Args: + logits (torch.Tensor): Router logits. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - scores: Softmax probabilities for top-k experts. + - top_indices: Indices of top-k experts for each token. + - tokens_per_expert: Number of tokens assigned to each expert. + """ + if self.training: + logits = self.apply_z_loss(logits) + top_logits, top_indices = paddle.topk(k=self.config.moe_topk, x= + logits, axis=1) + scores = paddle.nn.functional.softmax(x=top_logits, axis=-1, dtype= + 'float32').astype(dtype=logits.dtype) + tokens_per_expert = paddle.histogram(input=top_indices.flatten(), + bins=self.config.moe_num_experts, min=0, max=self.config. + moe_num_experts - 1).astype(top_indices.flatten().dtype) + if self.training: + scores = self.apply_aux_loss(logits, tokens_per_expert, scores) + return scores, top_indices, tokens_per_expert + + def forward(self, input: paddle.Tensor) ->Tuple[paddle.Tensor, paddle. + Tensor, paddle.Tensor]: + """ + Forward pass of the TopKRouter. + + Args: + input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size]. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - scores: Softmax probabilities for top-k experts. + - top_indices: Indices of top-k experts for each token. + - tokens_per_expert: Number of tokens assigned to each expert. 
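+
+        Shape sketch (illustrative): for an input of shape [batch_size * seq_len, hidden_size]
+        routed over moe_num_experts experts with moe_topk = k, scores and top_indices have shape
+        [batch_size * seq_len, k] and tokens_per_expert has shape [moe_num_experts].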
+ """ + logits = self.gating(input) + logits = logits.view(-1, self.config.moe_num_experts) + scores, top_indices, tokens_per_expert = self.routing(logits) + return scores, top_indices, tokens_per_expert + + +class TokenDispatcher: + """ + Handles the dispatching and gathering of tokens to and from experts. + + This class is responsible for permuting tokens based on expert assignments and + unpermuting them after expert processing. + + Args: + config (AriaMoELMConfig): Configuration object containing MoE-related parameters. + """ + + def __init__(self, config: AriaMoELMConfig): + self.config = config + self.hidden_states_shape = None + self.reversed_input_permutation_mapping = None + + def token_permutation(self, hidden_states: paddle.Tensor, indices: + paddle.Tensor) ->paddle.Tensor: + """ + Permute tokens based on expert assignments. + + Args: + hidden_states (torch.Tensor): Input hidden states. + indices (torch.Tensor): Expert assignment indices. + + Returns: + torch.Tensor: Permuted tokens. + """ + self.hidden_states_shape = tuple(hidden_states.shape) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flatten_indices = indices.flatten() + sorted_indices = paddle.argsort(x=flatten_indices, stable=True) + permuted_tokens = hidden_states.index_select(axis=0, index= + sorted_indices // self.config.moe_topk) + self.reversed_input_permutation_mapping = sorted_indices + return permuted_tokens + + def token_unpermutation(self, permuted_tokens: paddle.Tensor, scores: + paddle.Tensor) ->paddle.Tensor: + """ + Unpermute tokens and combine expert outputs. + + Args: + permuted_tokens (torch.Tensor): Tokens after expert processing. + scores (torch.Tensor): Expert assignment scores. + + Returns: + torch.Tensor: Unpermuted and combined output. + """ + num_unpermuted_tokens = scores.size + unpermuted_tokens = paddle.zeros(shape=(num_unpermuted_tokens, + permuted_tokens.shape[1]), dtype=permuted_tokens.dtype) + unpermuted_tokens.scatter_(self.reversed_input_permutation_mapping, + permuted_tokens) + unpermuted_tokens = unpermuted_tokens.reshape(-1, self.config. + moe_topk, permuted_tokens.shape[1]) + unpermuted_tokens = unpermuted_tokens * scores.unsqueeze(axis=-1) + unpermuted_tokens = unpermuted_tokens.sum(axis=1).astype(dtype= + permuted_tokens.dtype) + output = unpermuted_tokens.view(self.hidden_states_shape) + return output + + +class SharedExpertMLP(LlamaMLP): + """ + Shared Expert MLP for shared experts. + + Unlike routed experts, shared experts process all tokens without routing. + This class reconfigures the intermediate size in comparison to the LlamaMLP. + + Args: + config (AriaMoELMConfig): Configuration object for the AriaMoE language model. + """ + + def __init__(self, config: AriaMoELMConfig): + super().__init__() # 这个怎么弄 + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = (config.moe_intermediate_size * config. + moe_num_shared_experts) + self.gate_proj = paddle.nn.Linear(in_features=self.hidden_size, + out_features=self.intermediate_size, bias_attr=config.mlp_bias) + self.up_proj = paddle.nn.Linear(in_features=self.hidden_size, + out_features=self.intermediate_size, bias_attr=config.mlp_bias) + self.down_proj = paddle.nn.Linear(in_features=self. + intermediate_size, out_features=self.hidden_size, bias_attr= + config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + +def sequential_gemm(input, weight, tokens_per_expert): + """ + Compute the matrix multiplication (GEMM) for each expert sequentially. 
This approach is computationally inefficient, especially when dealing with a large number of experts. + + Args: + input (torch.Tensor): Input tensor of shape (num_tokens, in_features). + weight (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features). + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor of shape (num_tokens, out_features). + """ + num_tokens = tuple(input.shape)[0] + out_features = tuple(weight.shape)[-1] + output = paddle.zeros(shape=[num_tokens, out_features], dtype=input.dtype) + cumsum_num_tokens = paddle.cumsum(x=tokens_per_expert, axis=0) + zero_tensor = paddle.zeros(shape=[1], dtype='int64') + cumsum_num_tokens = paddle.concat(x=(zero_tensor, cumsum_num_tokens)) + for expert_num in range(tuple(weight.shape)[0]): + start = cumsum_num_tokens[expert_num] + end = cumsum_num_tokens[expert_num + 1] + tokens = input[start:end] + out = paddle.matmul(x=tokens, y=weight[expert_num]) + output[start:end] = out + return output + + +try: + from grouped_gemm.ops import gmm as experts_gemm + if os.environ.get('USE_GROUPED_GEMM', '1') == '0': + logger.warning( + 'environment variable USE_GROUPED_GEMM is set to 0, using sequential GEMM instead.' + ) + experts_gemm = sequential_gemm +except ImportError: + logger.warning( + '`grouped_gemm` is not installed, using sequential GEMM, which is slower.' + ) + experts_gemm = sequential_gemm + + +class GroupedGEMM(paddle.nn.Layer): + """ + Grouped GEMM (General Matrix Multiplication) module for efficient expert computation. + This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm) + for optimized performance. If the grouped_gemm library is not installed, it gracefully + falls back to a sequential GEMM implementation, which may be slower but ensures + functionality. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + groups (int): Number of expert groups. + """ + + def __init__(self, in_features, out_features, groups): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.groups = groups + self.weight = paddle.base.framework.EagerParamBase.from_tensor(tensor + =paddle.empty(shape=[groups, in_features, out_features])) + + def forward(self, input, tokens_per_expert): + """ + Perform grouped matrix multiplication. + + Args: + input (torch.Tensor): Input tensor of shape (num_tokens, in_features). + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor of shape (num_tokens, out_features). + """ + tokens_per_expert = tokens_per_expert.cpu() + paddle.device.set_device(device='gpu:' + str(input.place) if + isinstance(input.place, int) else str(input.place).replace( + 'cuda', 'gpu')) + return experts_gemm(input, self.weight, tokens_per_expert) + + +class GroupedMLP(paddle.nn.Layer): + """ + Grouped MLP module for Mixture of Experts. + + Args: + config (AriaMoELMConfig): Configuration object for the model. + """ + + def __init__(self, config: AriaMoELMConfig) ->None: + super().__init__() + self.config = config + self.fc1 = GroupedGEMM(config.hidden_size, config. + moe_intermediate_size * 2, config.moe_num_experts) + self.fc2 = GroupedGEMM(config.moe_intermediate_size, config. 
+ hidden_size, config.moe_num_experts) + + def glu(x): + x = paddle.chunk(x=x, chunks=2, axis=-1) + return paddle.nn.functional.silu(x=x[0]) * x[1] + self.activation_func = glu + + def forward(self, permuted_tokens, tokens_per_expert): + """ + Forward pass of the Grouped MLP. + + Args: + permuted_tokens (torch.Tensor): Permuted input tokens. + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor after passing through the MLP. + """ + fc1_output = self.fc1(permuted_tokens, tokens_per_expert) + fc1_output = self.activation_func(fc1_output) + fc2_output = self.fc2(fc1_output, tokens_per_expert) + return fc2_output + + +class MoELayer(paddle.nn.Layer): + """ + Mixture of Experts (MoE) Layer for the AriaMoE model. + + This layer implements the MoE mechanism, which routes input tokens to different experts + based on a routing algorithm, processes them through the experts, and then combines + the outputs. + + Args: + config (AriaMoELMConfig): Configuration object for the MoE layer. + """ + + def __init__(self, config: AriaMoELMConfig): + super().__init__() + self.router = TopKRouter(config) + self.token_dispatcher = TokenDispatcher(config) + self.experts = GroupedMLP(config) + self.shared_experts = SharedExpertMLP(config) + + def forward(self, hidden_states: paddle.Tensor) ->paddle.Tensor: + """ + Forward pass of the MoE Layer. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch_size, sequence_length, hidden_size). + + Returns: + torch.Tensor: Output tensor after passing through the MoE layer. + + Process: + 1. Route tokens to experts using the router. + 2. Permute tokens based on routing decisions. + 3. Process tokens through experts. + 4. Unpermute and combine expert outputs. + 5. Add shared expert output to the final result. 
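+
+        Note: both the routed expert output and the shared-expert output keep the shape of
+        hidden_states, so the returned tensor matches the input shape.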
+ """ + scores, indices, tokens_per_expert = self.router(hidden_states) + permuted_tokens = self.token_dispatcher.token_permutation(hidden_states + , indices) + expert_output = self.experts(permuted_tokens, tokens_per_expert) + output = self.token_dispatcher.token_unpermutation(expert_output, + scores) + shared_expert_output = self.shared_experts(hidden_states) + output += shared_expert_output + return output + +import paddle.nn as nn +import paddle.nn.functional as F + +class BaseAttention(nn.Layer): + def __init__(self, config, layer_idx): + super(BaseAttention, self).__init__() + self.config = config + self.layer_idx = layer_idx + + def forward(self, hidden_states, attention_mask=None, **kwargs): + raise NotImplementedError("This method should be overridden by subclasses.") + +class SelfAttention(BaseAttention): + def __init__(self, config, layer_idx): + super(SelfAttention, self).__init__(config, layer_idx) + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.scale = self.head_dim ** -0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size) + + def forward(self, hidden_states, attention_mask=None, **kwargs): + batch_size, seq_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states).reshape( + [batch_size, seq_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + key_states = self.k_proj(hidden_states).reshape( + [batch_size, seq_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + value_states = self.v_proj(hidden_states).reshape( + [batch_size, seq_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + + attn_weights = paddle.matmul(query_states, key_states, transpose_y=True) * self.scale + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_probs = F.softmax(attn_weights, axis=-1) + attn_output = paddle.matmul(attn_probs, value_states).transpose([0, 2, 1, 3]).reshape( + [batch_size, seq_len, self.hidden_size]) + + attn_output = self.out_proj(attn_output) + return attn_output + + + +LLAMA_ATTENTION_CLASSES = { + "self_attention": SelfAttention, +} + + +class MoEDecoderLayer(LlamaDecoderLayer): + """ + Custom Decoder Layer for the AriaMoE model which modifies the standard `LlamaDecoderLayer` by + replacing the traditional MLP with a Mixture of Experts (MoE) Layer. + + Args: + config (LlamaConfig): Configuration object for the layer. + layer_idx (int): Index of the current layer in the model. + """ + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__(config) + self.hidden_size = config.hidden_size + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx) + self.mlp = MoELayer(config) + self.input_layernorm = (LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)) + self.post_attention_layernorm = (LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)) + + +class AriaMoELMModel(LlamaModel): + """ + Custom LlamaModel for the AriaMoE model which modifies the standard LlamaModel by + replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. + + This model implements a Mixture of Experts (MoE) approach, where each layer contains + multiple expert networks that specialize in different aspects of the input. 
+ + Args: + config (LlamaConfig): Configuration object for the model. + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = paddle.nn.Embedding(num_embeddings=config. + vocab_size, embedding_dim=config.hidden_size, padding_idx=self. + padding_idx) + self.layers = paddle.nn.LayerList(sublayers=[MoEDecoderLayer(config, + layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = (LlamaRotaryEmbedding(config=config)) + self.gradient_checkpointing = False + self.post_init() + + +class AriaMoELMForCausalLM(LlamaForCausalLM, GenerationMixin): + """ + AriaMoE model for causal language modeling tasks. + + This class extends LlamaForCausalLM to incorporate the Mixture of Experts (MoE) approach, + allowing for more efficient and scalable language modeling. + + Args: + config (AriaMoELMConfig): Configuration object for the model. + """ + _tied_weights_keys = ['lm_head.weight'] + config_class = AriaMoELMConfig + _no_split_modules = ['MoEDecoderLayer'] + + def __init__(self, config): + super().__init__(config) + self.model = AriaMoELMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = paddle.nn.Linear(in_features=config.hidden_size, + out_features=config.vocab_size, bias_attr=False) + self.post_init() + + def set_z_loss_coeff(self, z_loss_coeff: float): + """ + Set the coefficient for the z-loss in the MoE routing. + + Args: + z_loss_coeff (float): The coefficient for the z-loss. + """ + self.config.moe_z_loss_coeff = z_loss_coeff + + def set_aux_loss_coeff(self, aux_loss_coeff: float): + """ + Set the coefficient for the auxiliary loss in the MoE routing. + + Args: + aux_loss_coeff (float): The coefficient for the auxiliary loss. + """ + self.config.moe_aux_loss_coeff = aux_loss_coeff diff --git a/paddlemix/models/aria/projector.py b/paddlemix/models/aria/projector.py new file mode 100644 index 000000000..90bdcfe8e --- /dev/null +++ b/paddlemix/models/aria/projector.py @@ -0,0 +1,156 @@ +import paddle +import paddlenlp +from paddlemix.models.imagebind.transformer import MultiheadAttention +from paddlemix.activations import ACT2FN +class FFN(paddle.nn.Layer): + """ + Feed-Forward Network module. + + Args: + embed_dim (int): Input embedding dimension. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + """ + + def __init__(self, embed_dim, ff_dim, output_dim): + super().__init__() + self.linear_in = paddle.nn.Linear(in_features=embed_dim, + out_features=ff_dim, bias_attr=False) + self.linear_out = paddle.nn.Linear(in_features=ff_dim, out_features + =output_dim, bias_attr=False) + self.act = ACT2FN['gelu_new'] + + def forward(self, hidden_states): + hidden_states = self.act(self.linear_in(hidden_states)) + hidden_states = self.linear_out(hidden_states) + return hidden_states + + +class CrossAttention(paddle.nn.Layer): + """ + Cross-Attention module. + + Args: + kv_dim (int): Dimension of key and value. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + drop_out_rate (float): Dropout rate. Default is 0. 
+ """ + + def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): + super().__init__() + self.num_heads = num_heads + self.q_proj = paddle.nn.Linear(in_features=embed_dim, out_features= + embed_dim, bias_attr=False) + self.k_proj = paddle.nn.Linear(in_features=kv_dim, out_features= + embed_dim, bias_attr=False) + self.v_proj = paddle.nn.Linear(in_features=kv_dim, out_features= + embed_dim, bias_attr=False) + self.multihead_attn = MultiheadAttention(embed_dim, num_heads) # 报错要自己实现MultiheadAttention + self.linear = paddle.nn.Linear(in_features=embed_dim, out_features= + embed_dim) + self.dropout = paddle.nn.Dropout(p=drop_out_rate) + self.layer_norm = paddle.nn.LayerNorm(normalized_shape=embed_dim) + self.ln_kv = paddle.nn.LayerNorm(normalized_shape=kv_dim) + + def forward(self, x, hidden_states, attn_mask=None, add_residual=False): + """ + Forward pass of the CrossAttention module. + + Args: + x (torch.Tensor): Input tensor for key and value. + hidden_states (torch.Tensor): Input tensor for query. + attn_mask (torch.Tensor, optional): Attention mask. Default is None. + add_residual (bool): Whether to add residual connection. Default is False. + + Returns: + torch.Tensor: Output tensor after cross-attention. + """ + normed_hidden_states = self.layer_norm(hidden_states) + query = self.q_proj(normed_hidden_states).transpose(perm=[1, 0, 2]) + x = self.ln_kv(x) + key = self.k_proj(x).transpose(perm=[1, 0, 2]) + value = self.v_proj(x).transpose(perm=[1, 0, 2]) + attn_output, _ = self.multihead_attn(query, key, value, attn_mask= + attn_mask) + attn_output = attn_output.transpose(perm=[1, 0, 2]) + if add_residual: + attn_output = hidden_states + self.dropout(self.linear(attn_output) + ) + else: + attn_output = self.dropout(self.linear(attn_output)) + return attn_output + + +class AriaProjector(paddle.nn.Layer): + """ + A projection module with one cross attention layer and one FFN layer, which projects ViT's outputs into MoE's inputs. + + Args: + patch_to_query_dict (dict): Maps patch numbers to their corresponding query numbers, + e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on image resolution. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + kv_dim (int): Dimension of key and value. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. + + Outputs: + A tensor with the shape of (batch_size, query_number, output_dim) + """ + + def __init__(self, patch_to_query_dict, embed_dim, num_heads, kv_dim, + ff_dim, output_dim, norm_layer=paddle.nn.LayerNorm): + super().__init__() + self.patch_to_query_dict = patch_to_query_dict + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query = paddle.base.framework.EagerParamBase.from_tensor(tensor + =paddle.zeros(shape=[max(patch_to_query_dict.values()), self. 
+ embed_dim])) + init_TruncatedNormal = paddle.nn.initializer.TruncatedNormal(std=0.02) + init_TruncatedNormal(self.query) + self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads) + self.ln_ffn = norm_layer(embed_dim) + self.ffn = FFN(embed_dim, ff_dim, output_dim) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, paddle.nn.Linear): + init_TruncatedNormal = paddle.nn.initializer.TruncatedNormal(std + =0.02) + init_TruncatedNormal(m.weight) + if isinstance(m, paddle.nn.Linear) and m.bias is not None: + init_Constant = paddle.nn.initializer.Constant(value=0) + init_Constant(m.bias) + elif isinstance(m, paddle.nn.LayerNorm): + init_Constant = paddle.nn.initializer.Constant(value=0) + init_Constant(m.bias) + init_Constant = paddle.nn.initializer.Constant(value=1.0) + init_Constant(m.weight) + + def forward(self, x, attn_mask=None): + """ + Forward pass of the Projector module. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, num_patches, kv_dim). + attn_mask (torch.Tensor, optional): Attention mask. Default is None. + + Returns: + torch.Tensor: Output tensor of shape (batch_size, query_number, output_dim). + """ + bs = tuple(x.shape)[0] + queries = self.query.unsqueeze(axis=0).tile(repeat_times=[bs, 1, 1]) + query_num = self.patch_to_query_dict.get(tuple(x.shape)[1], None) + assert query_num is not None, f'Query number for {tuple(x.shape)[1]} patches is not provided' + queries = queries[:, :query_num, :] + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(repeats=self.num_heads, + axis=0) + attn_mask = attn_mask.unsqueeze(axis=1).expand(shape=[-1, + queries.shape[1], -1]) + attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) + out = self.ffn(self.ln_ffn(attention_out)) + return out diff --git a/paddlemix/models/aria/vision_encoder.py b/paddlemix/models/aria/vision_encoder.py new file mode 100644 index 000000000..f5952cb93 --- /dev/null +++ b/paddlemix/models/aria/vision_encoder.py @@ -0,0 +1,116 @@ +import paddle +import paddlenlp +from paddlenlp.transformers import PretrainedConfig +from ppdiffusers.transformers.model_utils import PretrainedModel +from typing import Optional, Tuple, Union +# from .modeling_aria import AriaPretrainedModel +from paddlenlp.transformers.model_outputs import BaseModelOutput, BaseModelOutputWithPooling +import os +from paddle.nn import Layer, Linear, Conv2D, LayerList +from paddle.nn import functional as F +from paddle.nn.layer.norm import LayerNorm +from paddlenlp.transformers.auto.configuration import AutoConfig +from paddlemix.models.minicpm_v.modeling_navit_siglip import SigLipAttention, SigLipVisionConfig + + + +class AriaVisionConfig(SigLipVisionConfig): + """Configuration class for AriaVisionModel.""" + model_type = 'aria_vision_model' + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class IdentityOp(paddle.nn.Layer): + """ + An identity operation that returns the input unchanged. + + This can be used as a placeholder or to maintain architectural consistency + when a specific operation is not needed. + """ + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + + +class AriaVisionTransformer(SigLipAttention): + """ + Aria Vision Transformer model based on Idefics2VisionTransformer. + + This class extends the original Idefics2VisionTransformer by removing the post-layernorm operation. 
+ """ + + def __init__(self, config: AriaVisionConfig): + + super().__init__(config) + self.post_layernorm = IdentityOp() + + + +class AriaVisionModel(PretrainedModel): + """ + Aria Vision Model extends SiglipVisionModel to support pixel_mask. + + The pixel_mask is a 2D boolean tensor that indicates which pixels in the input + image are actual content and which are padding. It has the same height and width + as the input image, where: + - True (1) values represent pixels from the original image + - False (0) values represent padding pixels + + This mask helps the model focus on the relevant parts of the image during processing. + """ + config_class = AriaVisionConfig + main_input_name = 'pixel_values' + _supports_sdpa = False + def __init__(self, config: AriaVisionConfig): + super().__init__(config) + self.vision_model = AriaVisionTransformer(config) + self.post_init() + + def forward(self, pixel_values: paddle.Tensor, pixel_mask: Optional[ + paddle.Tensor]=None, output_attentions: Optional[bool]=None, + output_hidden_states: Optional[bool]=None, return_dict: Optional[ + bool]=None) ->Union[Tuple, BaseModelOutputWithPooling]: + """ + Forward pass of the AriaVisionModel. + + Args: + pixel_values (torch.Tensor): The pixel values of the input images. + pixel_mask (Optional[torch.BoolTensor]): Mask for the pixel values. + output_attentions (Optional[bool]): Whether to output attentions. + output_hidden_states (Optional[bool]): Whether to output hidden states. + return_dict (Optional[bool]): Whether to return a ModelOutput object. + + Returns: + Union[Tuple, BaseModelOutputWithPooling]: The model's output. + """ + return_dict = (return_dict if return_dict is not None else self. + config.use_return_dict) + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + vit_oup = self.vision_model(pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, output_attentions= + output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict) + image_atts = self._create_image_attention_mask(patch_attention_mask) + return vit_oup, image_atts + + def _create_patch_attention_mask(self, pixel_mask): + if pixel_mask is None: + return None + patches_subgrid = pixel_mask.unfold(axis=1, size=self.vision_model. 
+ config.patch_size, step=self.vision_model.config.patch_size + ).unfold(axis=2, size=self.vision_model.config.patch_size, step + =self.vision_model.config.patch_size) + return (patches_subgrid.sum(axis=(-1, -2)) > 0).astype(dtype='bool') + + def _create_image_attention_mask(self, patch_attention_mask): + if patch_attention_mask is None: + return None + flattened_mask = patch_attention_mask.flatten(start_axis=1) + return paddle.logical_not(x=flattened_mask) + diff --git a/paddlemix/processors/__init__.py b/paddlemix/processors/__init__.py index 1d9ee61eb..47e43ff20 100644 --- a/paddlemix/processors/__init__.py +++ b/paddlemix/processors/__init__.py @@ -35,3 +35,5 @@ from .tokenizer import SimpleTokenizer, tokenize from .visualglm_image_processing import * from .visualglm_processing import * +from .aria_vision_processor import * +from .processing_aria import * \ No newline at end of file diff --git a/paddlemix/processors/aria_vision_processor.py b/paddlemix/processors/aria_vision_processor.py new file mode 100644 index 000000000..211c38659 --- /dev/null +++ b/paddlemix/processors/aria_vision_processor.py @@ -0,0 +1,208 @@ +import paddle +import paddlenlp +from typing import List, Optional, Union +import numpy as np +from PIL import Image, ImageOps +from paddlenlp.transformers.image_processing_utils import BaseImageProcessor +from paddlenlp.transformers.feature_extraction_utils import BatchFeature +from paddlenlp.taskflow.utils import Compose +from paddlemix.models.imagebind.helpers import Normalize +from paddlemix.processors.image_utils import TensorType +import paddle.vision.transforms as T + +def _select_best_resolution(img_width: int, img_height: int, target_ratios: + List[List[int]], patch_size: int): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + img_width: the original widths of images. + img_height: the original heights of images. + target_ratios (2d numpy array): dimension size (M,2) + patch_size (int): image patch size + + Returns: + tuple: The best fit resolution in the format (width, height). + """ + aspect_ratio = img_width / img_height + best_ratio_diff = float('inf') + best_ratio_w, best_ratio_h = 1, 1 + area = np.int32(img_width) * np.int32(img_height) + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio_w, best_ratio_h = ratio[0], ratio[1] + elif ratio_diff == best_ratio_diff and area > 0.5 * patch_size * patch_size * ratio[ + 0] * ratio[1]: + best_ratio_w, best_ratio_h = ratio[0], ratio[1] + return best_ratio_w, best_ratio_h + + +def _split_image(image: Image.Image, split_image: bool, split_ratio: List[ + List[int]], patch_size: int) ->List[Image.Image]: + """ + Split image into multiple patches + + Args: + image (PIL.Image): Input image. + split_image (bool): Whether to split the image into patches. + split_ratio (2d numpy array): dimension size (M,2) + patch_size (int): image patch size + + Returns: + List[PIL.Image]: List of splitted images. 
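+
+    Example (illustrative): with patch_size=980 and a best-fit ratio of (2, 2), the image is
+    resized to 1960x1960 and cut into 4 crops; since more than one crop is produced, the
+    original image is prepended, so 5 images are returned.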
+ """ + if split_image: + ratio_width, ratio_height = _select_best_resolution(image.width, + image.height, split_ratio, patch_size) + resize_width = patch_size * ratio_width + resize_height = patch_size * ratio_height + blocks = ratio_width * ratio_height + resized_img = image.resize((resize_width, resize_height)) + processed_images = [] + for i in range(blocks): + box = i % (resize_width // patch_size) * patch_size, i // ( + resize_width // patch_size) * patch_size, (i % ( + resize_width // patch_size) + 1) * patch_size, (i // ( + resize_width // patch_size) + 1) * patch_size + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if len(processed_images) != 1: + processed_images.insert(0, image) + return processed_images + else: + return [image] + + +def keep_ratio_resize_and_pixel_mask(img: Image.Image, max_size, min_size= + 336, padding_value=0): + """ + Resize an image while maintaining aspect ratio and create a pixel mask. + + Args: + img (PIL.Image): Input image. + max_size (int): Maximum size for the larger dimension of the image. + min_size (int, optional): Minimum size for the smaller dimension. Defaults to 336. + padding_value (int, optional): Value used for padding. Defaults to 0. + + Returns: + tuple: A tuple containing: + - PIL.Image: Resized and padded image. + - torch.Tensor: Boolean pixel mask. This mask is a 2D tensor of shape (max_size, max_size) where: + - True (1) values indicate pixels that belong to the original resized image. + - False (0) values indicate pixels that are part of the padding. + The mask helps distinguish between actual image content and padded areas in subsequent processing steps. + """ + img = img.convert('RGB') + scale = max_size / max(img.size) + w, h = img.size + if w >= h: + new_size = max_size, max(int(h * scale), min_size) + else: + new_size = max(int(w * scale), min_size), max_size + img_resized = img.resize(new_size, resample=Image.Resampling.BICUBIC) + padding_right, padding_bottom = max_size - new_size[0 + ], max_size - new_size[1] + img_padded = ImageOps.expand(img_resized, (0, 0, padding_right, + padding_bottom), fill=padding_value) + pixel_mask = paddle.zeros(shape=[max_size, max_size]) + pixel_mask[:new_size[1], :new_size[0]] = 1 + pixel_mask = pixel_mask.astype(dtype='bool') + return img_padded, pixel_mask + +class AriaVisionProcessor(BaseImageProcessor): + """ + A vision processor for the Aria model that handles image preprocessing. + """ + + def __init__(self, max_image_size=980, min_image_size=336, image_mean=[ + 0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5], **kwargs): + """ + Initialize the AriaVisionProcessor. + + Args: + max_image_size (int, optional): Maximum image size. Defaults to 980. + min_image_size (int, optional): Minimum image size. Defaults to 336. + mean (list, optional): Mean values for normalization. Defaults to [0.5, 0.5, 0.5]. + std (list, optional): Standard deviation values for normalization. Defaults to [0.5, 0.5, 0.5]. + """ + super().__init__(**kwargs) + self.max_image_size = max_image_size + self.min_image_size = min_image_size + self.image_mean = image_mean + self.image_std = image_std + self.auto_map = {'AutoProcessor': 'processing_aria.AriaProcessor', + 'AutoImageProcessor': 'vision_processor.AriaVisionProcessor'} + self._transform = None + self._set_processor_class('AriaProcessor') + + @property + def transform(self): + if self._transform is None: + self._transform = Compose([T.ToTensor(), Normalize( # 很可能Compose 和 Normalize 导错了,有可能都是T. 
+                self.image_mean, self.image_std)])
+        return self._transform
+
+    def __call__(self, images: Union[Image.Image, List[Image.Image]],
+            max_image_size: Optional[int]=980,
+            min_image_size: Optional[int]=336,
+            return_tensors: Optional[Union[str, TensorType]]='pd',
+            split_image: Optional[bool]=False,
+            split_ratio: Optional[List[List[int]]]=[[1, 2], [1, 3], [1, 4],
+                [1, 5], [1, 6], [1, 7], [1, 8], [2, 4], [2, 3], [2, 2],
+                [2, 1], [3, 1], [3, 2], [4, 1], [4, 2], [5, 1], [6, 1],
+                [7, 1], [8, 1]]):
+        """
+        Process a list of images.
+
+        Args:
+            images (list): List of PIL.Image objects.
+            max_image_size (int, optional): Override the default max image size. Defaults to 980.
+            min_image_size (int, optional): Override the default min image size. Defaults to 336.
+            return_tensors (str or TensorType, optional): The type of tensor to return. Defaults to "pd".
+            split_image (bool, optional): Whether to split the image. Defaults to False.
+            split_ratio (list, optional): The ratio for splitting the image. Defaults to a list of common split ratios.
+        Returns:
+            BatchFeature: A BatchFeature object containing:
+                - 'pixel_values': Tensor of processed image pixel values.
+                - 'pixel_mask': Boolean pixel mask. This mask is a 2D tensor of shape (max_size, max_size) where:
+                    - True (1) values indicate pixels that belong to the original resized image.
+                    - False (0) values indicate pixels that are part of the padding.
+                    The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
+                - 'num_crops': Tensor of the number of crops for each image.
+        """
+        max_size = (self.max_image_size if max_image_size is None else
+            max_image_size)
+        min_size = (self.min_image_size if min_image_size is None else
+            min_image_size)
+        if max_size not in [490, 980]:
+            raise ValueError('max_image_size must be either 490 or 980')
+        if isinstance(images, Image.Image):
+            images = [images]
+        pixel_values = []
+        pixel_masks = []
+        num_crops = []
+        for image in images:
+            crop_images = _split_image(image, split_image, split_ratio, max_size)
+            num_crops.append(paddle.to_tensor(data=len(crop_images)))
+            for crop_image in crop_images:
+                img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask(
+                    crop_image, max_size, min_size)
+                img_padded = self.transform(img_padded)
+                pixel_values.append(img_padded)
+                pixel_masks.append(pixel_mask)
+        return BatchFeature(data={
+            'pixel_values': paddle.stack(x=pixel_values),
+            'pixel_mask': paddle.stack(x=pixel_masks),
+            'num_crops': paddle.stack(x=num_crops)},
+            tensor_type=return_tensors)
+
+    def preprocess(self, images, max_image_size=None, min_image_size=None,
+            return_tensors: Optional[Union[str, TensorType]]=None,
+            split_image: Optional[bool]=False,
+            split_ratio: Optional[List[List[int]]]=[[1, 2], [1, 3], [1, 4],
+                [1, 5], [1, 6], [1, 7], [1, 8], [2, 4], [2, 3], [2, 2],
+                [2, 1], [3, 1], [3, 2], [4, 1], [4, 2], [5, 1], [6, 1],
+                [7, 1], [8, 1]]):
+        return self.__call__(images, max_image_size=max_image_size,
+            min_image_size=min_image_size, return_tensors=return_tensors,
+            split_image=split_image, split_ratio=split_ratio)
diff --git a/paddlemix/processors/processing_aria.py b/paddlemix/processors/processing_aria.py
new file mode 100644
index 000000000..0e0b5fe33
--- /dev/null
+++ b/paddlemix/processors/processing_aria.py
@@ -0,0 +1,258 @@
+import paddlenlp
+import inspect
+import logging
+import re
+from typing import List, Optional, Tuple, Union
+from .aria_vision_processor import AriaVisionProcessor
+from paddlenlp.transformers.processing_utils import ProcessorMixin
+from paddlenlp.transformers.auto.tokenizer import AutoTokenizer
+from paddlenlp.transformers.feature_extraction_utils import BatchFeature
+from paddlenlp.transformers.image_utils import ImageInput
+from paddlenlp.transformers.tokenizer_utils_base import PaddingStrategy, TensorType, TruncationStrategy
+
+logger = logging.getLogger(__name__)
+
+TextInput = str
+PreTokenizedInput = List[str]
+EncodedInput = List[int]
+TextInputPair = Tuple[str, str]
+PreTokenizedInputPair = Tuple[List[str], List[str]]
+EncodedInputPair = Tuple[List[int], List[int]]
+
+class AriaProcessor(ProcessorMixin):
+    """
+    AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLaMA slow tokenizer.
+
+    Args:
+        image_processor (AriaVisionProcessor): The AriaVisionProcessor to use for image preprocessing.
+        tokenizer (AutoTokenizer): The AutoTokenizer to use for tokenizing the text.
+        patch_size (int): The patch size to use for the image processor.
+        chat_template (str): The chat template to use for the tokenizer.
+        image_token (str): The image token to use for the tokenizer.
+    """
+    attributes = []
+    valid_kwargs = ['chat_template', 'patch_size', 'image_token']
+    image_processor_class = None
+    tokenizer_class = 'AutoTokenizer'
+
+    def __init__(self, image_processor: AriaVisionProcessor=None,
+            tokenizer: Union[AutoTokenizer, str]=None, patch_size: int=490,
+            chat_template: str=None, image_token: str='<|img|>'):
+        super().__init__(chat_template=chat_template)
+        if image_processor is None:
+            self.image_processor = AriaVisionProcessor(max_image_size=patch_size)
+        else:
+            self.image_processor = image_processor
+        if isinstance(tokenizer, str):
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+        else:
+            self.tokenizer = tokenizer
+        if self.tokenizer is not None and self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.unk_token
+        self.image_token = image_token
+
+    def __call__(self, text: Union[TextInput, PreTokenizedInput,
+            List[TextInput], List[PreTokenizedInput]],
+            images: ImageInput=None,
+            padding: Union[bool, str, PaddingStrategy]=False,
+            truncation: Union[bool, str, TruncationStrategy]=None,
+            max_length: Optional[int]=None,
+            max_image_size: Optional[int]=980,
+            split_image: Optional[bool]=False,
+            return_tensors: Optional[Union[str, TensorType]]=TensorType.PADDLE,
+            return_final_prompts: Optional[bool]=False) -> BatchFeature:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. Please refer to the docstrings
+        of the image processor and the tokenizer for more information.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle
+                tensor. Both channels-first and channels-last formats are supported.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            max_image_size (`int`, *optional*):
+                Maximum size of the image to be processed.
+            split_image (`bool`, *optional*):
+                Whether to split the image into patches before processing.
+            truncation (`bool`, *optional*):
+                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'pd'`: Return Paddle `paddle.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+            - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
+        """
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
+            raise ValueError(
+                'Invalid input text. Please provide a string, or a list of strings')
+        if images is not None:
+            image_inputs = self.image_processor(images,
+                return_tensors=return_tensors,
+                max_image_size=max_image_size,
+                split_image=split_image)
+            prompt_strings = []
+            crop_iter = iter(image_inputs.pop('num_crops'))
+            for prompt in text:
+                # repeat the image token once per crop produced for this image
+                prompt_strings.append(re.sub(re.escape(self.image_token),
+                    lambda _: int(next(crop_iter)) * self.image_token, prompt))
+            max_image_size = (max_image_size if max_image_size is not None
+                else self.image_processor.max_image_size)
+            if max_image_size == 490:
+                num_image_tokens = 128
+            elif max_image_size == 980:
+                num_image_tokens = 256
+            else:
+                raise ValueError(
+                    f'max_image_size must be either 490 or 980, got {max_image_size}')
+            prompt_strings = [sample.replace(self.image_token,
+                self.image_token * num_image_tokens) for sample in prompt_strings]
+        else:
+            image_inputs = {}
+            prompt_strings = text
+        text_inputs = self.tokenizer(prompt_strings,
+            return_tensors=return_tensors, padding=padding,
+            truncation=truncation, max_length=max_length)
+        if return_final_prompts:
+            return BatchFeature(data={**text_inputs, **image_inputs}), prompt_strings
+        else:
+            return BatchFeature(data={**text_inputs, **image_inputs})
+
+    @staticmethod
+    def _extract_kwargs(func: callable, **kwargs) -> dict:
+        """
+        Extract the kwargs that are valid for the given function.
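+
+        Example (illustrative, not part of the original docstring): if
+        ``kwargs == {'foo': 1, 'bar': 2}`` and ``func`` accepts only ``foo``,
+        the result is ``{'foo': 1}``; keys not named in the signature are
+        silently dropped.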
+        """
+        return {k: v for k, v in kwargs.items()
+            if k in inspect.signature(func).parameters}
+
+    def save_pretrained(self, save_directory, **kwargs):
+        """
+        Save both the image processor and tokenizer.
+        """
+        if self.image_processor is not None:
+            self.image_processor.save_pretrained(save_directory,
+                **self._extract_kwargs(self.image_processor.save_pretrained, **kwargs))
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(save_directory,
+                **self._extract_kwargs(self.tokenizer.save_pretrained, **kwargs))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, tokenizer_path=None,
+            image_processor_path=None, **kwargs):
+        """
+        Load both the image processor and tokenizer from a pretrained model path.
+        """
+        tokenizer_path = (tokenizer_path if tokenizer_path is not None else
+            pretrained_model_name_or_path)
+        image_processor_path = (image_processor_path if image_processor_path
+            is not None else pretrained_model_name_or_path)
+        image_processor = AriaVisionProcessor.from_pretrained(
+            image_processor_path,
+            **cls._extract_kwargs(AriaVisionProcessor.from_pretrained, **kwargs))
+        chat_template = kwargs.pop('chat_template', None)
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_path,
+                **cls._extract_kwargs(AutoTokenizer.from_pretrained, **kwargs))
+            if chat_template is not None:
+                tokenizer.chat_template = chat_template
+        except Exception as e:
+            logger.warning(f'Failed to load tokenizer from {tokenizer_path}: {e}')
+            tokenizer = None
+        return cls(image_processor=image_processor, tokenizer=tokenizer,
+            chat_template=chat_template)
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.batch_decode`].
+        Please refer to the docstring of that method for more information.
+        """
+        if self.tokenizer is None:
+            raise ValueError(
+                'Tokenizer is not initialized. Please provide a valid tokenizer.'
+                )
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def apply_chat_template(self, text: str, chat_template: Optional[str]=None) -> str:
+        """
+        Apply the chat template to the given text.
+
+        Args:
+            text (str): The input text to apply the chat template to.
+            chat_template (Optional[str]): The chat template to use. If not provided, the default chat template will be used.
+
+        Returns:
+            str: The text with the chat template applied.
+        """
+        if chat_template is None:
+            chat_template = self.chat_template
+        if chat_template is None:
+            raise ValueError('No chat template provided or set.')
+        # Assumes the chat template is a plain format string with a "{text}" placeholder.
+        return chat_template.format(text=text)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.decode`].
+        Please refer to the docstring of that method for more information.
+        """
+        if self.tokenizer is None:
+            raise ValueError(
+                'Tokenizer is not initialized. Please provide a valid tokenizer.'
+ ) + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + + image_processor_input_names))
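+
+
+# Minimal usage sketch (illustrative only; 'path/to/aria-model' is a
+# placeholder for a local directory or hub id holding the Aria weights, and
+# `image` is assumed to be a PIL.Image):
+#
+#   processor = AriaProcessor.from_pretrained('path/to/aria-model')
+#   prompt = '<|img|>what is in the image?'
+#   inputs = processor(text=prompt, images=image, return_tensors='pd',
+#                      max_image_size=980)
+#   # -> BatchFeature with input_ids (plus related tokenizer fields),
+#   #    pixel_values and pixel_mask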