From 09e922158a71fbcf1c8cae743220bbf376c9ea87 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 5 Jul 2022 11:27:51 +0200 Subject: [PATCH 01/81] First commit --- src/transformers/__init__.py | 12 + src/transformers/models/bloom/__init__.py | 35 +- .../models/bloom/modeling_flax_bloom.py | 809 ++++++++++++++++++ .../models/bloom/test_modeling_flax_bloom.py | 214 +++++ 4 files changed, 1066 insertions(+), 4 deletions(-) create mode 100644 src/transformers/models/bloom/modeling_flax_bloom.py create mode 100644 tests/models/bloom/test_modeling_flax_bloom.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fe0a68e056a2..e09bc50a49bc 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3883,6 +3883,13 @@ "FlaxBlenderbotSmallPreTrainedModel", ] ) + _import_structure["models.bloom"].extend( + [ + "FlaxBloomForCausalLM", + "FlaxBloomModel", + "FlaxBloomPreTrainedModel", + ] + ) _import_structure["models.clip"].extend( [ "FlaxCLIPModel", @@ -7263,6 +7270,11 @@ FlaxBlenderbotSmallModel, FlaxBlenderbotSmallPreTrainedModel, ) + from .models.bloom import ( + FlaxBloomForCausalLM, + FlaxBloomModel, + FlaxBloomPreTrainedModel, + ) from .models.clip import ( FlaxCLIPModel, FlaxCLIPPreTrainedModel, diff --git a/src/transformers/models/bloom/__init__.py b/src/transformers/models/bloom/__init__.py index c91b3d0f385a..c14285d28a9a 100644 --- a/src/transformers/models/bloom/__init__.py +++ b/src/transformers/models/bloom/__init__.py @@ -14,11 +14,14 @@ from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available, is_flax_available _import_structure = { - "configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"], + "configuration_bloom": [ + "BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", + "BloomConfig", + ], } try: if not is_tokenizers_available(): @@ -44,8 +47,21 @@ "BloomForQuestionAnswering", ] +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_bloom"] = [ + "FlaxBloomForCausalLM", + "FlaxBloomModel", + "FlaxBloomPreTrainedModel", + ] + + if TYPE_CHECKING: - from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig + from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig try: if not is_tokenizers_available(): @@ -71,7 +87,18 @@ BloomPreTrainedModel, ) + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_bloom import ( + FlaxBloomForCausalLM, + FlaxBloomModel, + FlaxBloomPreTrainedModel, + ) else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) \ No newline at end of file diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py new file mode 100644 index 000000000000..da268997f087 --- /dev/null +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -0,0 +1,809 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. team and Bigscience Workshop. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax BLOOM model. """ +# TODO: see todos throughout this file +# TODO: check correctness against pytorch implementation +# TODO: add unit tests +# TODO: add documentation / check that documentation is correct + # TODO: BLOOM_INPUTS_DOCSTRING might be wrong still (position_ids) +# TODO: check that code is jit-able + +from functools import partial +from typing import Any, Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_bloom import BloomConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bigscience/bloom" +_CONFIG_FOR_DOC = "BloomConfig" +_TOKENIZER_FOR_DOC = "BloomTokenizer" + +def attention_mask_func(args): + # TODO: implement this helper fn. see pytorch impl. for reference + raise NotImplementedError + +BLOOM_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`BloomConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). 
+ + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +BLOOM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +class FlaxBloomScaledSoftmax(nn.Module): + config: BloomConfig + mask_func: Callable + softmax_in_fp32: bool + scale: float = None + """ + Scaled Softmax module. Also performs masking. + Args: + mask_func (`function`, *required*): + mask function to be applied. + softmax_in_fp32 (`bool`, *required*): + if true, softmax in performed at fp32 precision. + scale (`float`, *optional*): + scaling factor used in input tensor scaling. 
+ """ + + def setup(self): + if not (self.scale is None or self.softmax_in_fp32): + raise ValueError("softmax should be in fp32 if scale is not `None`") + + def __call__(self, input, mask, causal_mask): + input_dtype = input.dtype + softmax_dtype = jnp.float32 if self.softmax_in_fp32 else input_dtype + + if self.scale is not None: + input = input * self.scale + + if mask is not None: + mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask) + # TODO: ideally we could pass a dtype argument to nn.softmax like in PyTorch (see discussion on PR #17474 about fp32 softmax) + mask_output.astype(softmax_dtype) + probs = nn.softmax(mask_output, axis=-1) * (~padded_causal_mask) + else: + input.astype(softmax_dtype) + probs = nn.softmax(input, axis=-1) + + if input_dtype != softmax_dtype: + probs = probs.astype(input_dtype) + + return probs + + +class FlaxBloomAttention(nn.Module): + config: BloomConfig + layer_number: int = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # TODO: make sure these affect behavior correctly + self.pretraining_tp = self.config.pretraining_tp + self.slow_but_exact = self.config.slow_but_exact + + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.n_head + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + # TODO: deal with softmax + self.attention_softmax_in_fp32 = self.config.attention_softmax_in_fp32 + # TODO: deal with hidden dropout + self.hidden_dropout = self.config.hidden_dropout + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads})." + ) + + # Layer-wise attention scaling + self.norm_factor = jnp.sqrt(self.head_dim).astype(self.dtype) * self.layer_number + + + self.attn_dropout = nn.Dropout(self.config.attention_dropout) + + self.scale_mask_softmax = FlaxBloomScaledSoftmax( + self.config, + attention_mask_func, + self.attention_softmax_in_fp32, + self.layer_number, + ) + + dense = partial( + nn.Dense, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + + self.query_key_value = dense(self.hidden_size * 3) + self.dense = dense(self.hidden_size) + self.attention_dropout = nn.Dropout(self.config.attention_dropout) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + residual, + layer_past=None, + attention_mask=None, + alibi=None, + head_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + # TODO: this module __call__ needs checking for correctness of implementation. + fused_qkv = self.query_key_value(hidden_states) + + query, key, value = jnp.split(fused_qkv, 3, axis=-1) + + query = self._split_heads(query) + key = self._split_heads(key) + value = self._split_heads(value) + + query_length, key_length = query.shape[1], key.shape[1] + + # TODO: check size of hidden_states to confirm this is the right dim for causal mask to use + causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[0]), dtype="bool"), dtype="bool") + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + if attention_mask is not None: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + + dropout_rng = None + if not deterministic and self.config.attention_dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
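# A minimal standalone sketch of the cache-update pattern that `_concatenate_to_cache`
# above relies on: `lax.dynamic_update_slice` writes the freshly projected key/value
# slice for the current decoding step into a preallocated cache buffer. The shapes and
# index below are hypothetical; only `jax` is assumed.
import jax.numpy as jnp
from jax import lax

batch, max_length, num_heads, head_dim = 1, 8, 2, 4
cached_key = jnp.zeros((batch, max_length, num_heads, head_dim))   # preallocated cache
new_key = jnp.ones((batch, 1, num_heads, head_dim))                # key projection for one new token
cur_index = 3                                                      # tokens already written to the cache

# write the one-token slice into the cache along the length axis, starting at `cur_index`
cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
assert float(cached_key[0, cur_index, 0, 0]) == 1.0                # the new slice landed here
assert float(cached_key[0, cur_index + 1, 0, 0]) == 0.0            # later positions stay zero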
+ if self.has_variable("cache", "cached_key") or init_cache: + key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) + + attention_bias = None + + # transform boolean mask into float mask + if attention_mask is not None: + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), + ) + + + # usual dot product attention + attn_weights = dot_product_attention_weights( + query, + key, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_dropout, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) + attn_output = self._merge_heads(attn_output) + attn_output = self.dense(attn_output) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxBloomMLP(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + hidden_size = self.config.hidden_size + + self.pretraining_tp = self.config.pretraining_tp + self.slow_but_exact = self.config.slow_but_exact + + kernel_init = jax.nn.initializers.normal(self.config.initializer_range) + + self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init) + self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init) + self.hidden_dropout = nn.Dropout(self.config.hidden_dropout) + self.act = ACT2FN[self.config.activation_function] + + def __call__(self, hidden_states, residual, deterministic: bool = True): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + + # TODO: this code block is from the pytorch implementation. needs changing to work. 
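# The attention call above follows the usual Flax masking recipe: build a boolean causal
# mask, AND it with the padding mask via `combine_masks`, then turn the result into an
# additive bias with `lax.select`. A standalone distillation of that pattern (shapes and
# values here are illustrative only, not the exact tensors used in the diff):
import jax.numpy as jnp
from flax.linen import combine_masks, make_causal_mask
from jax import lax

causal_mask = make_causal_mask(jnp.ones((1, 3), dtype="bool"), dtype="bool")  # (1, 1, 3, 3), lower-triangular
padding_mask = jnp.array([[True, True, False]])                               # last position is padding
padding_mask = jnp.broadcast_to(padding_mask[:, None, None, :], causal_mask.shape)
mask = combine_masks(padding_mask, causal_mask)                               # logical AND of both masks

attention_bias = lax.select(
    mask > 0,
    jnp.full(mask.shape, 0.0),
    jnp.full(mask.shape, -1e9),
)
print(attention_bias[0, 0])  # 0.0 where attention is allowed, -1e9 where it is masked out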
+ if self.pretraining_tp > 1 and self.slow_but_exact: + intermediate_output = jnp.zeros_like(residual) + slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp + for i in range(self.pretraining_tp): + intermediate_output = intermediate_output + nn.functional.linear( + hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + intermediate_output = self.dense_4h_to_h(hidden_states) + + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + return hidden_states + + +class FlaxBloomBlock(nn.Module): + config: BloomConfig + layer_number: int = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + hidden_size = self.config.hidden_size + + self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + + self.self_attention = FlaxBloomAttention(self.config, layer_number=self.layer_number, dtype=self.dtype) + self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + + self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype) + + self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm + self.hidden_dropout = self.config.hidden_dropout + + def __call__( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + alibi=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + layernorm_output = self.input_layernorm(hidden_states) + # layer norm before saving residual if config calls for it + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # self-attention + attn_outputs = self.self_attention( + hidden_states, + residual, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + attention_output = attention_output + residual + + post_layernorm = self.post_attention_layernorm(attention_output) + + # set residual based on config + if self.apply_residual_connection_post_layernorm: + residual = post_layernorm + else: + residual = attention_output + + output = self.mlp(post_layernorm, residual, deterministic=deterministic) + + output = output + residual + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs + +# TODO: does this still require position_ids? +# TODO: gradient checkpointing +# TODO: _no_split_modules? +# TODO: check initialization +class FlaxBloomPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BloomConfig + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: BloomConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + # TODO: check whether this is correct (position ids might not be required) + @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) + def __call__( + self, + input_ids, + past_key_values: dict = None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, sequence_length = input_ids.shape + + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxBloomAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + # TODO: build alibi here? + # TODO: check the inputs and their order to this + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + +# TODO: haven't modified this block yet, remove if remains unused +class FlaxBloomBlockCollection(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # layer_number matters for attn scaling and should not be 0. see `FlaxBloomAttention` for its use. + self.blocks = [ + FlaxBloomBlock(self.config, layer_number=max(1, i), name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxBloomModule` will filter them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + +class FlaxBloomModule(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # TODO: check initialization correctness + self.embed_dim = self.config.hidden_size + + embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) + + # word embeddings (no positional embedding layer) TODO: confirm this statement correct + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.embed_dim, + embedding_init=embedding_init, + ) + # post-embedding layernorm + self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon) + + # transformer layers + self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype) + + # final layernorm + self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + # TODO: change how gradient checkpointing is done + self.gradient_checkpointing = False + + def __call__( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + if inputs_embeds is None: + inputs_embeds = 
self.word_embeddings(input_ids.astype("i4")) + # do post-embedding layernorm + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + past_key_values = () if use_cache else None # TODO: come back to this line + # TODO: how to handle alibi? build alibi tensor here? + + # TODO: fix inputs to this (and args to submodules in general) + # TODO: gradient checkpointing + outputs = self.h( + hidden_states, + attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + # TODO: don't think this return value / ordering is correct + return tuple(v for v in [outputs[0], past_key_values, outputs[-1]] if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +@add_start_docstrings( + "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.", + BLOOM_START_DOCSTRING, +) +# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom +class FlaxBloomModel(FlaxBloomPreTrainedModel): + module_class = FlaxBloomModule + + +append_call_sample_docstring( + FlaxBloomModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC +) + + +# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLMModule with GPTNeo->Bloom +class FlaxBloomForCausalLMModule(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.transformer = FlaxBloomModule(self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + outputs = self.transformer( + input_ids, + attention_mask, + position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T + lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +@add_start_docstrings( + """ + The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
+ """, + BLOOM_START_DOCSTRING, +) +# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM with GPTNeo->Bloom +class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel): + module_class = FlaxBloomForCausalLMModule + # TODO: check if this class is correct / take out position ids + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since Bloom uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxBloomForCausalLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC +) \ No newline at end of file diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py new file mode 100644 index 000000000000..e533fc7cc73f --- /dev/null +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -0,0 +1,214 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import numpy as np +import timeout_decorator # noqa + +from transformers import BloomConfig, is_flax_available +from transformers.testing_utils import require_flax, require_sentencepiece, slow + +from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin +from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor + + +if is_flax_available(): + import os + + # The slow tests are often failing with OOM error on GPU + # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed + # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" + + import jax + import jax.numpy as jnp + from transformers import FlaxBloomForCausalLM, FlaxBloomModel, BloomTokenizerFast + + +def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None): + if attention_mask is None: + attention_mask = np.where(input_ids != config.pad_token_id, 1, 0) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + +@require_flax +class FlaxBloomModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=16, + n_layer=2, + n_head=4, + hidden_act="gelu", + hidden_dropout=0.1, + attention_probs_dropout_prob=0.1, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + initializer_range=0.02, + apply_residual_connection_post_layernorm=False, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.n_layer = n_layer + self.n_head = n_head + self.hidden_act = hidden_act + self.hidden_dropout = hidden_dropout + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.initializer_range = initializer_range + self.is_encoder_decoder = False + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + + def prepare_config_and_inputs(self): + input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size) + input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1) + + config = BloomConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + n_layer=self.n_layer, + n_head=self.n_head, + hidden_dropout=self.hidden_dropout, + attention_dropout=self.attention_probs_dropout_prob, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=False, + use_cache=False, + ) + inputs_dict = prepare_opt_inputs_dict(config, input_ids) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def check_use_cache_forward(self, model_class_name, config, inputs_dict): + max_length = 20 + model = model_class_name(config) + + input_ids = inputs_dict["input_ids"] + attention_mask = inputs_dict["attention_mask"] + + past_key_values = model.init_cache(input_ids.shape[0], max_length) + attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4") + + position_ids = jnp.broadcast_to( + jnp.arange(input_ids.shape[-1] - 1)[None, :], + 
(input_ids.shape[0], input_ids.shape[-1] - 1), + ) + outputs_cache = model( + input_ids[:, :-1], + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model( + input_ids[:, -1:], + attention_mask=attention_mask, + past_key_values=outputs_cache.past_key_values, + position_ids=position_ids, + ) + + outputs = model(input_ids) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict): + max_length = 20 + model = model_class_name(config) + + input_ids, attention_mask = ( + inputs_dict["input_ids"], + inputs_dict["attention_mask"], + ) + + attention_mask_cache = jnp.concatenate( + [ + attention_mask, + jnp.zeros((attention_mask.shape[0], max_length - attention_mask.shape[1])), + ], + axis=-1, + ) + + past_key_values = model.init_cache(input_ids.shape[0], max_length) + position_ids = jnp.broadcast_to( + jnp.arange(input_ids.shape[-1] - 1)[None, :], + (input_ids.shape[0], input_ids.shape[-1] - 1), + ) + + outputs_cache = model( + input_ids[:, :-1], + attention_mask=attention_mask_cache, + past_key_values=past_key_values, + position_ids=position_ids, + ) + position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model( + input_ids[:, -1:], + past_key_values=outputs_cache.past_key_values, + attention_mask=attention_mask_cache, + position_ids=position_ids, + ) + + outputs = model(input_ids, attention_mask=attention_mask) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + +@require_flax +class FlaxBloomModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin): + all_model_classes = (FlaxBloomModel, FlaxBloomForCausalLM) if is_flax_available() else () + all_generative_model_classes = () if is_flax_available() else () + + def setUp(self): + self.model_tester = FlaxBloomModelTester(self) + + def test_use_cache_forward(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward(model_class, config, inputs_dict) + + def test_use_cache_forward_with_attn_mask(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("bigscience/bloom-350m") + input_ids = np.ones((1, 1)) * model.config.eos_token_id + outputs = model(input_ids) + self.assertIsNotNone(outputs) From 5c215279ba7dd2720922d23aeaf265ce935d959d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 5 Jul 2022 11:45:48 +0200 Subject: [PATCH 02/81] step 1 working --- .../models/bloom/modeling_flax_bloom.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index da268997f087..b2dbe8910df4 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ 
b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -30,6 +30,7 @@ from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict +from flax.linen.activation import tanh from jax import lax from ...modeling_flax_outputs import ( @@ -329,6 +330,11 @@ def __call__( outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs +class BloomGELU(nn.Module): + def setup(self): + self.dtype = jnp.float32 + def __call__(self, x): + return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x))) class FlaxBloomMLP(nn.Module): config: BloomConfig @@ -345,7 +351,7 @@ def setup(self): self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init) self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init) self.hidden_dropout = nn.Dropout(self.config.hidden_dropout) - self.act = ACT2FN[self.config.activation_function] + self.act = BloomGELU() def __call__(self, hidden_states, residual, deterministic: bool = True): hidden_states = self.dense_h_to_4h(hidden_states) @@ -363,7 +369,7 @@ def __call__(self, hidden_states, residual, deterministic: bool = True): else: intermediate_output = self.dense_4h_to_h(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic) return hidden_states @@ -396,6 +402,7 @@ def __call__( deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, + use_cache: bool = False, ): layernorm_output = self.input_layernorm(hidden_states) # layer norm before saving residual if config calls for it @@ -755,7 +762,7 @@ def __call__( hidden_states = outputs[0] if self.config.tie_word_embeddings: - shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T + shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) else: lm_logits = self.lm_head(hidden_states) From 794a50890500eb401f078bdbfea3241984b701e7 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 5 Jul 2022 12:49:04 +0200 Subject: [PATCH 03/81] add alibi --- .../models/bloom/modeling_flax_bloom.py | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index b2dbe8910df4..f7b2228d4ad6 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -22,6 +22,7 @@ from functools import partial from typing import Any, Callable, Optional, Tuple +import math import flax.linen as nn import jax @@ -120,6 +121,31 @@ def attention_mask_func(args): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" +def flax_unsqueeze(x, axis): + return jnp.expand_dims(x, axis) + +def build_alibi_tensor_flax(max_seq_len, n_head, dtype): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + slopes = jnp.array(get_slopes(n_head)) + slopes = flax_unsqueeze(flax_unsqueeze(slopes, 1), 1) + arange_tensor = flax_unsqueeze(flax_unsqueeze(jnp.arange(max_seq_len, dtype=dtype), 0), 0) + + alibi = slopes * jnp.broadcast_to(arange_tensor, (n_head, 1, arange_tensor.shape[-1])) + return alibi + class FlaxBloomScaledSoftmax(nn.Module): config: BloomConfig mask_func: Callable @@ -253,9 +279,9 @@ def __call__( self, hidden_states, residual, + alibi, layer_past=None, attention_mask=None, - alibi=None, head_mask=None, deterministic: bool = True, init_cache: bool = False, @@ -264,11 +290,10 @@ def __call__( # TODO: this module __call__ needs checking for correctness of implementation. fused_qkv = self.query_key_value(hidden_states) - query, key, value = jnp.split(fused_qkv, 3, axis=-1) + new_tensor_shape = fused_qkv.shape[:-1] + (self.num_heads, 3 * self.head_dim) + fused_qkv = fused_qkv.reshape(new_tensor_shape) - query = self._split_heads(query) - key = self._split_heads(key) - value = self._split_heads(value) + query, key, value = jnp.split(fused_qkv, 3, axis=-1) query_length, key_length = query.shape[1], key.shape[1] @@ -395,10 +420,10 @@ def setup(self): def __call__( self, hidden_states, + alibi, layer_past=None, attention_mask=None, head_mask=None, - alibi=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, @@ -413,11 +438,11 @@ def __call__( # self-attention attn_outputs = self.self_attention( - hidden_states, + layernorm_output, residual, + alibi, layer_past=layer_past, attention_mask=attention_mask, - alibi=alibi, head_mask=head_mask, deterministic=deterministic, init_cache=init_cache, @@ -595,6 +620,7 @@ def setup(self): def __call__( self, hidden_states, + alibi, attention_mask=None, deterministic: bool = True, init_cache: bool = False, @@ -612,6 +638,7 @@ def __call__( layer_outputs = block( hidden_states, attention_mask, + alibi, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, @@ -673,6 +700,9 @@ def __call__( # do post-embedding layernorm hidden_states = self.word_embeddings_layernorm(inputs_embeds) + curr_seq_len = hidden_states.shape[1] + alibi = build_alibi_tensor_flax(curr_seq_len, self.config.n_head, hidden_states.dtype) + past_key_values = () if use_cache else None # TODO: come back to this line # TODO: how to handle alibi? build alibi tensor here? 
@@ -680,7 +710,8 @@ def __call__( # TODO: gradient checkpointing outputs = self.h( hidden_states, - attention_mask, + alibi=alibi, + attention_mask=attention_mask, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, From 8416c3762cc38b6c193822777d0248f8c2503f36 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 5 Jul 2022 12:44:05 +0100 Subject: [PATCH 04/81] placeholder for `scan` --- .../models/bloom/modeling_flax_bloom.py | 62 ++++++++++++++----- 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index f7b2228d4ad6..78f1aaf07650 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -30,6 +30,7 @@ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.linen.partitioning import scan_with_axes from flax.traverse_util import flatten_dict, unflatten_dict from flax.linen.activation import tanh from jax import lax @@ -403,10 +404,9 @@ class FlaxBloomBlock(nn.Module): config: BloomConfig layer_number: int = None dtype: jnp.dtype = jnp.float32 + use_scan: bool = False def setup(self): - hidden_size = self.config.hidden_size - self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) self.self_attention = FlaxBloomAttention(self.config, layer_number=self.layer_number, dtype=self.dtype) @@ -429,6 +429,9 @@ def __call__( output_attentions: bool = False, use_cache: bool = False, ): + if self.use_scan: + hidden_states = hidden_states[0] + layernorm_output = self.input_layernorm(hidden_states) # layer norm before saving residual if config calls for it if self.apply_residual_connection_post_layernorm: @@ -472,6 +475,9 @@ def __call__( else: outputs = (output,) + outputs[1:] + if self.use_scan: + outputs = (outputs, None) + return outputs # TODO: does this still require position_ids? @@ -495,9 +501,10 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + use_scan: bool = False, **kwargs, ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, dtype=dtype, use_scan=use_scan, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: @@ -609,11 +616,12 @@ def __call__( class FlaxBloomBlockCollection(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 + use_scan: bool = False def setup(self): # layer_number matters for attn scaling and should not be 0. see `FlaxBloomAttention` for its use. 
self.blocks = [ - FlaxBloomBlock(self.config, layer_number=max(1, i), name=str(i), dtype=self.dtype) + FlaxBloomBlock(self.config, layer_number=max(1, i), name=str(i), dtype=self.dtype, use_scan=False) for i in range(self.config.num_hidden_layers) ] @@ -631,11 +639,20 @@ def __call__( all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - for block in self.blocks: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = block( + if self.use_scan: + # since all decoder layers are the same, we use nn.scan directly + assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`" + assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" + hidden_states = (hidden_states,) + + # TODO: add layerdrop in (possibly checkpointed) scan (note: default value for layerdrop in config is zero) + hidden_states, _ = scan_with_axes( + FlaxBloomBlock, + variable_axes={"params": 0, "cache": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), + length=self.config.num_hidden_layers, + )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( hidden_states, attention_mask, alibi, @@ -643,10 +660,25 @@ def __call__( init_cache=init_cache, output_attentions=output_attentions, ) - hidden_states = layer_outputs[0] + hidden_states = hidden_states[0] + + else: + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + alibi, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] - if output_attentions: - all_attentions += (layer_outputs[1],) + if output_attentions: + all_attentions += (layer_outputs[1],) # this contains possible `None` values - `FlaxBloomModule` will filter them out outputs = (hidden_states, all_hidden_states, all_attentions) @@ -656,6 +688,7 @@ def __call__( class FlaxBloomModule(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 + use_scan: bool = False def setup(self): # TODO: check initialization correctness @@ -673,7 +706,7 @@ def setup(self): self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon) # transformer layers - self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype) + self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype, use_scan=self.use_scan) # final layernorm self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) @@ -758,9 +791,10 @@ class FlaxBloomModel(FlaxBloomPreTrainedModel): class FlaxBloomForCausalLMModule(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 + use_scan: bool = self.use_scan def setup(self): - self.transformer = FlaxBloomModule(self.config, dtype=self.dtype) + self.transformer = FlaxBloomModule(self.config, dtype=self.dtype, use_scan=self.use_scan) self.lm_head = nn.Dense( self.config.vocab_size, use_bias=False, From 2eaa09298b6bfe25939c068f6e85a5c6de8845a2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 5 Jul 2022 14:03:02 +0200 Subject: [PATCH 05/81] add matrix mult alibi --- .../models/bloom/modeling_flax_bloom.py | 46 +++++++++++-------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 
78f1aaf07650..c3602d742942 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -299,7 +299,9 @@ def __call__( query_length, key_length = query.shape[1], key.shape[1] # TODO: check size of hidden_states to confirm this is the right dim for causal mask to use - causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[0]), dtype="bool"), dtype="bool") + causal_mask = make_causal_mask( + jnp.ones((1, hidden_states.shape[1]), dtype="bool"), dtype="bool" + ) if self.has_variable("cache", "cached_key"): mask_shift = self.variables["cache"]["cache_index"] @@ -313,6 +315,7 @@ def __call__( batch_size = hidden_states.shape[0] causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + # TODO: fix this if attention_mask is not None: attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) attention_mask = combine_masks(attention_mask, causal_mask) @@ -327,7 +330,6 @@ def __call__( key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) attention_bias = None - # transform boolean mask into float mask if attention_mask is not None: attention_bias = lax.select( @@ -336,18 +338,24 @@ def __call__( jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) + # Reshape input tensors + output_size = (query.shape[0], query.shape[2], query.shape[1], key.shape[1]) + + # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] + query = jnp.transpose(query, (1, 0, 2, 3)).reshape(output_size[2], output_size[0] * output_size[1], -1) + + # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim] + key = jnp.transpose(key, (1, 0, 2, 3)).reshape(output_size[3], output_size[0] * output_size[1], -1) + # Reshape according to batch size + query = jnp.transpose(query, (1, 0, 2)) + key = jnp.transpose(key, (1, 2, 0)) # usual dot product attention - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_dropout, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) + alpha = (1.0 / self.norm_factor) + attn_weights = alibi + jnp.matmul(query, key)*alpha + + # TODO: apply softmax to attention weights + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) @@ -442,8 +450,8 @@ def __call__( # self-attention attn_outputs = self.self_attention( layernorm_output, - residual, - alibi, + residual=residual, + alibi=alibi, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, @@ -654,8 +662,8 @@ def __call__( length=self.config.num_hidden_layers, )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( hidden_states, - attention_mask, - alibi, + alibi=alibi, + attention_mask=attention_mask, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, @@ -716,8 +724,8 @@ def setup(self): def __call__( self, input_ids=None, - past_key_values=None, attention_mask=None, + past_key_values=None, position_ids=None, head_mask=None, inputs_embeds=None, @@ -815,8 +823,8 @@ def __call__( ): outputs = self.transformer( input_ids, - attention_mask, - position_ids, + attention_mask=attention_mask, + position_ids=position_ids, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, From 
274767e9041d7071005ff168c5e9a72dd44683c8 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 5 Jul 2022 13:11:11 +0100 Subject: [PATCH 06/81] beta scaling factor for bmm --- src/transformers/models/bloom/modeling_flax_bloom.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index c3602d742942..d8292e282ea8 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -350,9 +350,13 @@ def __call__( # Reshape according to batch size query = jnp.transpose(query, (1, 0, 2)) key = jnp.transpose(key, (1, 2, 0)) - # usual dot product attention + + # scaling factors alpha = (1.0 / self.norm_factor) - attn_weights = alibi + jnp.matmul(query, key)*alpha + beta = 1.0 / self.layer_number + + # usual dot product attention + attn_weights = beta * alibi + alpha * jnp.matmul(query, key) # TODO: apply softmax to attention weights From f923d5196b9951a61cb967db54b04029aa9e7920 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 5 Jul 2022 16:46:37 +0200 Subject: [PATCH 07/81] working v1 - simple forward pass --- src/transformers/__init__.py | 6 +- src/transformers/models/bloom/__init__.py | 16 +- .../models/bloom/modeling_flax_bloom.py | 148 +++++++++++------- 3 files changed, 102 insertions(+), 68 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e09bc50a49bc..95a09842b70f 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -7270,11 +7270,7 @@ FlaxBlenderbotSmallModel, FlaxBlenderbotSmallPreTrainedModel, ) - from .models.bloom import ( - FlaxBloomForCausalLM, - FlaxBloomModel, - FlaxBloomPreTrainedModel, - ) + from .models.bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel from .models.clip import ( FlaxCLIPModel, FlaxCLIPPreTrainedModel, diff --git a/src/transformers/models/bloom/__init__.py b/src/transformers/models/bloom/__init__.py index c14285d28a9a..73c4410371b3 100644 --- a/src/transformers/models/bloom/__init__.py +++ b/src/transformers/models/bloom/__init__.py @@ -14,7 +14,13 @@ from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available, is_flax_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tokenizers_available, + is_torch_available, +) _import_structure = { @@ -93,12 +99,8 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_flax_bloom import ( - FlaxBloomForCausalLM, - FlaxBloomModel, - FlaxBloomPreTrainedModel, - ) + from .modeling_flax_bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) \ No newline at end of file + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index d8292e282ea8..c11f232e6bdd 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -17,22 +17,21 @@ # TODO: check correctness against pytorch implementation # TODO: add unit tests # TODO: add documentation / check that documentation is correct - # TODO: 
BLOOM_INPUTS_DOCSTRING might be wrong still (position_ids) +# TODO: BLOOM_INPUTS_DOCSTRING might be wrong still (position_ids) # TODO: check that code is jit-able -from functools import partial -from typing import Any, Callable, Optional, Tuple import math +from functools import partial +from typing import Callable, Optional, Tuple import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights +from flax.linen.activation import tanh from flax.linen.partitioning import scan_with_axes from flax.traverse_util import flatten_dict, unflatten_dict -from flax.linen.activation import tanh from jax import lax from ...modeling_flax_outputs import ( @@ -40,7 +39,7 @@ FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutput, ) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring +from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_bloom import BloomConfig @@ -51,9 +50,33 @@ _CONFIG_FOR_DOC = "BloomConfig" _TOKENIZER_FOR_DOC = "BloomTokenizer" -def attention_mask_func(args): - # TODO: implement this helper fn. see pytorch impl. for reference - raise NotImplementedError + +def masked_fill(mask, a, fill): + return jax.lax.select(mask, a, jax.lax.broadcast(fill, a.shape)) + + +def attention_mask_func(attention_scores, attention_mask, causal_mask): + attention_mask_bool = ~(attention_mask) + + query_length, key_length, n_heads = attention_scores.shape[2], attention_scores.shape[3], attention_scores.shape[1] + padded_causal_mask = jnp.logical_or( + attention_mask_bool[:, None, key_length - query_length : key_length, None], + ~(causal_mask[:, :, key_length - query_length : key_length, :key_length] == 1), + ) + padded_causal_mask = jnp.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length]) + # Make use of floats + return ( + masked_fill( + jnp.broadcast_to( + ~padded_causal_mask, + (padded_causal_mask.shape[0], n_heads, padded_causal_mask.shape[2], padded_causal_mask.shape[3]), + ), + attention_scores, + -1e4, + ), + ~padded_causal_mask, + ) + BLOOM_START_DOCSTRING = r""" @@ -122,9 +145,11 @@ def attention_mask_func(args): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ + def flax_unsqueeze(x, axis): return jnp.expand_dims(x, axis) + def build_alibi_tensor_flax(max_seq_len, n_head, dtype): def get_slopes(n): def get_slopes_power_of_2(n): @@ -140,6 +165,7 @@ def get_slopes_power_of_2(n): get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] ) + slopes = jnp.array(get_slopes(n_head)) slopes = flax_unsqueeze(flax_unsqueeze(slopes, 1), 1) arange_tensor = flax_unsqueeze(flax_unsqueeze(jnp.arange(max_seq_len, dtype=dtype), 0), 0) @@ -147,11 +173,11 @@ def get_slopes_power_of_2(n): alibi = slopes * jnp.broadcast_to(arange_tensor, (n_head, 1, arange_tensor.shape[-1])) return alibi + class FlaxBloomScaledSoftmax(nn.Module): config: BloomConfig mask_func: Callable softmax_in_fp32: bool - scale: float = None """ Scaled Softmax module. Also performs masking. 
Args: @@ -164,28 +190,26 @@ class FlaxBloomScaledSoftmax(nn.Module): """ def setup(self): - if not (self.scale is None or self.softmax_in_fp32): - raise ValueError("softmax should be in fp32 if scale is not `None`") + pass - def __call__(self, input, mask, causal_mask): + def __call__(self, input, mask, causal_mask, scale): input_dtype = input.dtype + input_in_16bit = input_dtype in [jnp.float16, jnp.bfloat16] softmax_dtype = jnp.float32 if self.softmax_in_fp32 else input_dtype - if self.scale is not None: - input = input * self.scale - - if mask is not None: - mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask) - # TODO: ideally we could pass a dtype argument to nn.softmax like in PyTorch (see discussion on PR #17474 about fp32 softmax) - mask_output.astype(softmax_dtype) - probs = nn.softmax(mask_output, axis=-1) * (~padded_causal_mask) - else: - input.astype(softmax_dtype) - probs = nn.softmax(input, axis=-1) - + if scale is not None: + input = input * scale + + if mask is None: + mask = jnp.ones((input.shape[0], input.shape[1]), dtype=bool) + + mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask) + mask_output.astype(softmax_dtype) + probs = nn.softmax(mask_output, axis=-1) * (padded_causal_mask) + if input_dtype != softmax_dtype: probs = probs.astype(input_dtype) - + return probs @@ -216,7 +240,6 @@ def setup(self): # Layer-wise attention scaling self.norm_factor = jnp.sqrt(self.head_dim).astype(self.dtype) * self.layer_number - self.attn_dropout = nn.Dropout(self.config.attention_dropout) @@ -224,7 +247,6 @@ def setup(self): self.config, attention_mask_func, self.attention_softmax_in_fp32, - self.layer_number, ) dense = partial( @@ -233,7 +255,6 @@ def setup(self): kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) - self.query_key_value = dense(self.hidden_size * 3) self.dense = dense(self.hidden_size) self.attention_dropout = nn.Dropout(self.config.attention_dropout) @@ -299,9 +320,7 @@ def __call__( query_length, key_length = query.shape[1], key.shape[1] # TODO: check size of hidden_states to confirm this is the right dim for causal mask to use - causal_mask = make_causal_mask( - jnp.ones((1, hidden_states.shape[1]), dtype="bool"), dtype="bool" - ) + causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[1]), dtype="bool"), dtype="bool") if self.has_variable("cache", "cached_key"): mask_shift = self.variables["cache"]["cache_index"] @@ -315,7 +334,6 @@ def __call__( batch_size = hidden_states.shape[0] causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - # TODO: fix this if attention_mask is not None: attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) attention_mask = combine_masks(attention_mask, causal_mask) @@ -337,10 +355,10 @@ def __call__( jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) - + # Reshape input tensors output_size = (query.shape[0], query.shape[2], query.shape[1], key.shape[1]) - + # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] query = jnp.transpose(query, (1, 0, 2, 3)).reshape(output_size[2], output_size[0] * output_size[1], -1) @@ -349,31 +367,49 @@ def __call__( # Reshape according to batch size query = jnp.transpose(query, (1, 0, 2)) - key = jnp.transpose(key, (1, 2, 0)) + key = jnp.transpose(key, (1, 2, 0)) # scaling factors - alpha = (1.0 / self.norm_factor) + alpha = 1.0 / self.norm_factor 
beta = 1.0 / self.layer_number # usual dot product attention attn_weights = beta * alibi + alpha * jnp.matmul(query, key) + attn_weights = attn_weights.reshape(output_size) # TODO: apply softmax to attention weights - + att_probs = self.scale_mask_softmax(attn_weights, attention_mask, causal_mask, self.layer_number) + if head_mask is not None: + attention_probs = attention_probs * head_mask + + output_size = (value.shape[0], value.shape[2], query.shape[1], value.shape[3]) - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.dense(attn_output) + value = jnp.transpose(value, (1, 0, 2, 3)).reshape(value.shape[1], output_size[0] * output_size[1], -1) + attention_probs_reshaped = jnp.reshape(att_probs, (output_size[0] * output_size[1], output_size[2], -1)) + + context = jnp.matmul(attention_probs_reshaped, jnp.transpose(value, (1, 0, 2))) + context = context.reshape(output_size) + context = jnp.transpose(context, (2, 0, 1, 3)) + + # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size] + new_context_layer_shape = context.shape[:-2] + (self.hidden_size,) + context = context.reshape(new_context_layer_shape) + + attn_output = self.dense(context) + attn_output = jnp.transpose(attn_output, (1, 0, 2)) + residual outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs + return outputs + class BloomGELU(nn.Module): def setup(self): self.dtype = jnp.float32 + def __call__(self, x): return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + class FlaxBloomMLP(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 @@ -407,6 +443,7 @@ def __call__(self, hidden_states, residual, deterministic: bool = True): else: intermediate_output = self.dense_4h_to_h(hidden_states) + intermediate_output = intermediate_output + residual hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic) return hidden_states @@ -450,7 +487,7 @@ def __call__( residual = layernorm_output else: residual = hidden_states - + # self-attention attn_outputs = self.self_attention( layernorm_output, @@ -468,8 +505,6 @@ def __call__( outputs = attn_outputs[1:] - attention_output = attention_output + residual - post_layernorm = self.post_attention_layernorm(attention_output) # set residual based on config @@ -477,11 +512,9 @@ def __call__( residual = post_layernorm else: residual = attention_output - + output = self.mlp(post_layernorm, residual, deterministic=deterministic) - output = output + residual - if use_cache: outputs = (output,) + outputs else: @@ -492,6 +525,7 @@ def __call__( return outputs + # TODO: does this still require position_ids? # TODO: gradient checkpointing # TODO: _no_split_modules? 
@@ -557,6 +591,7 @@ def init_cache(self, batch_size, max_length): jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True ) return unfreeze(init_variables["cache"]) + # TODO: check whether this is correct (position ids might not be required) @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) def __call__( @@ -581,7 +616,6 @@ def __call__( batch_size, sequence_length = input_ids.shape - if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -624,6 +658,7 @@ def __call__( return outputs + # TODO: haven't modified this block yet, remove if remains unused class FlaxBloomBlockCollection(nn.Module): config: BloomConfig @@ -681,8 +716,8 @@ def __call__( layer_outputs = block( hidden_states, - attention_mask, alibi, + attention_mask, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, @@ -697,6 +732,7 @@ def __call__( return outputs + class FlaxBloomModule(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 @@ -716,10 +752,10 @@ def setup(self): ) # post-embedding layernorm self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon) - + # transformer layers self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype, use_scan=self.use_scan) - + # final layernorm self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) # TODO: change how gradient checkpointing is done @@ -748,7 +784,7 @@ def __call__( curr_seq_len = hidden_states.shape[1] alibi = build_alibi_tensor_flax(curr_seq_len, self.config.n_head, hidden_states.dtype) - past_key_values = () if use_cache else None # TODO: come back to this line + past_key_values = () if use_cache else None # TODO: come back to this line # TODO: how to handle alibi? build alibi tensor here? 
# TODO: fix inputs to this (and args to submodules in general) @@ -775,7 +811,7 @@ def __call__( if not return_dict: # TODO: don't think this return value / ordering is correct - return tuple(v for v in [outputs[0], past_key_values, outputs[-1]] if v is not None) + return tuple(v for v in [outputs[0], past_key_values, outputs[-1]] if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, @@ -803,7 +839,7 @@ class FlaxBloomModel(FlaxBloomPreTrainedModel): class FlaxBloomForCausalLMModule(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = self.use_scan + use_scan: bool = False def setup(self): self.transformer = FlaxBloomModule(self.config, dtype=self.dtype, use_scan=self.use_scan) @@ -890,4 +926,4 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): append_call_sample_docstring( FlaxBloomForCausalLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC -) \ No newline at end of file +) From 2425e28aa1d58bca769a58913e4c1aab34a83e46 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 5 Jul 2022 16:02:58 +0100 Subject: [PATCH 08/81] move layer_number from attribute to arg in call --- .../models/bloom/modeling_flax_bloom.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index c11f232e6bdd..e61a22887ef3 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -215,7 +215,6 @@ def __call__(self, input, mask, causal_mask, scale): class FlaxBloomAttention(nn.Module): config: BloomConfig - layer_number: int = None dtype: jnp.dtype = jnp.float32 def setup(self): @@ -238,9 +237,6 @@ def setup(self): f"`num_heads`: {self.num_heads})." ) - # Layer-wise attention scaling - self.norm_factor = jnp.sqrt(self.head_dim).astype(self.dtype) * self.layer_number - self.attn_dropout = nn.Dropout(self.config.attention_dropout) self.scale_mask_softmax = FlaxBloomScaledSoftmax( @@ -308,6 +304,7 @@ def __call__( deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, + layer_number: int = None ): # TODO: this module __call__ needs checking for correctness of implementation. fused_qkv = self.query_key_value(hidden_states) @@ -369,16 +366,19 @@ def __call__( query = jnp.transpose(query, (1, 0, 2)) key = jnp.transpose(key, (1, 2, 0)) - # scaling factors - alpha = 1.0 / self.norm_factor - beta = 1.0 / self.layer_number + # Layer-wise attention scaling + # layer_number matters for attn scaling and should not be 0. see `FlaxBloomAttention` for its use. 
+ layer_number = jnp.where(layer_number < 1, 1, layer_number) + norm_factor = jnp.sqrt(self.head_dim).astype(self.dtype) * layer_number + alpha = 1.0 / norm_factor + beta = 1.0 / layer_number # usual dot product attention attn_weights = beta * alibi + alpha * jnp.matmul(query, key) attn_weights = attn_weights.reshape(output_size) # TODO: apply softmax to attention weights - att_probs = self.scale_mask_softmax(attn_weights, attention_mask, causal_mask, self.layer_number) + att_probs = self.scale_mask_softmax(attn_weights, attention_mask, causal_mask, layer_number) if head_mask is not None: attention_probs = attention_probs * head_mask @@ -451,14 +451,13 @@ def __call__(self, hidden_states, residual, deterministic: bool = True): class FlaxBloomBlock(nn.Module): config: BloomConfig - layer_number: int = None dtype: jnp.dtype = jnp.float32 use_scan: bool = False def setup(self): self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - self.self_attention = FlaxBloomAttention(self.config, layer_number=self.layer_number, dtype=self.dtype) + self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype) self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype) @@ -477,6 +476,7 @@ def __call__( init_cache: bool = False, output_attentions: bool = False, use_cache: bool = False, + layer_number: int = None, ): if self.use_scan: hidden_states = hidden_states[0] @@ -499,6 +499,7 @@ def __call__( deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, + layer_number=layer_number, ) attention_output = attn_outputs[0] @@ -666,9 +667,8 @@ class FlaxBloomBlockCollection(nn.Module): use_scan: bool = False def setup(self): - # layer_number matters for attn scaling and should not be 0. see `FlaxBloomAttention` for its use. self.blocks = [ - FlaxBloomBlock(self.config, layer_number=max(1, i), name=str(i), dtype=self.dtype, use_scan=False) + FlaxBloomBlock(self.config, name=str(i), dtype=self.dtype, use_scan=False) for i in range(self.config.num_hidden_layers) ] @@ -710,7 +710,7 @@ def __call__( hidden_states = hidden_states[0] else: - for block in self.blocks: + for layer_number, block in enumerate(self.blocks): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -721,6 +721,7 @@ def __call__( deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, + layer_number=layer_number, ) hidden_states = layer_outputs[0] From 8caac99ce41a07e7567687956855c1ad0fbb322e Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 5 Jul 2022 17:50:48 +0100 Subject: [PATCH 09/81] partial functioning scan --- .../models/bloom/modeling_flax_bloom.py | 28 ++++++------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index e61a22887ef3..8eb4ae6a017e 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -304,7 +304,7 @@ def __call__( deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, - layer_number: int = None + layer_number: int = None, ): # TODO: this module __call__ needs checking for correctness of implementation. 
fused_qkv = self.query_key_value(hidden_states) @@ -660,18 +660,12 @@ def __call__( return outputs -# TODO: haven't modified this block yet, remove if remains unused class FlaxBloomBlockCollection(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 use_scan: bool = False - def setup(self): - self.blocks = [ - FlaxBloomBlock(self.config, name=str(i), dtype=self.dtype, use_scan=False) - for i in range(self.config.num_hidden_layers) - ] - + @nn.compact def __call__( self, hidden_states, @@ -688,33 +682,29 @@ def __call__( if self.use_scan: # since all decoder layers are the same, we use nn.scan directly - assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`" - assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" + # assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`" + # assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" hidden_states = (hidden_states,) - # TODO: add layerdrop in (possibly checkpointed) scan (note: default value for layerdrop in config is zero) hidden_states, _ = scan_with_axes( FlaxBloomBlock, variable_axes={"params": 0, "cache": 0}, split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), + in_axes=(nn.broadcast, 0), length=self.config.num_hidden_layers, )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( hidden_states, - alibi=alibi, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, + alibi, + layer_number=jnp.arange(0, self.config.num_hidden_layers) ) hidden_states = hidden_states[0] else: - for layer_number, block in enumerate(self.blocks): + for layer_number in range(self.config.num_hidden_layers): if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = block( + layer_outputs = FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False)( hidden_states, alibi, attention_mask, From 91b7f757f8a266ba983f0439950c799d15faea2d Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 5 Jul 2022 18:20:30 +0100 Subject: [PATCH 10/81] hacky working scan --- src/transformers/models/bloom/modeling_flax_bloom.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 8eb4ae6a017e..bd055d6cb284 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -690,12 +690,19 @@ def __call__( FlaxBloomBlock, variable_axes={"params": 0, "cache": 0}, split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, 0), + in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, 0), length=self.config.num_hidden_layers, )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( hidden_states, alibi, - layer_number=jnp.arange(0, self.config.num_hidden_layers) + None, # kwargs not supported by scan + None, + None, + deterministic, + init_cache, + output_attentions, + False, + layer_number=jnp.arange(self.config.num_hidden_layers), ) hidden_states = hidden_states[0] From 2530051e3013a43b2a5cde003acdedeaa5bc45be Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 5 Jul 2022 19:09:33 +0200 Subject: [PATCH 
11/81] add more modifs --- .../models/bloom/modeling_flax_bloom.py | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index bd055d6cb284..f42296baab8c 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -60,10 +60,10 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask): query_length, key_length, n_heads = attention_scores.shape[2], attention_scores.shape[3], attention_scores.shape[1] padded_causal_mask = jnp.logical_or( - attention_mask_bool[:, None, key_length - query_length : key_length, None], + attention_mask_bool[:, :, key_length - query_length : key_length, :], ~(causal_mask[:, :, key_length - query_length : key_length, :key_length] == 1), ) - padded_causal_mask = jnp.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length]) + padded_causal_mask = jnp.logical_or(padded_causal_mask, attention_mask_bool[:, :, :, :key_length]) # Make use of floats return ( masked_fill( @@ -202,6 +202,8 @@ def __call__(self, input, mask, causal_mask, scale): if mask is None: mask = jnp.ones((input.shape[0], input.shape[1]), dtype=bool) + else: + mask = mask.astype(bool) mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask) mask_output.astype(softmax_dtype) @@ -306,7 +308,7 @@ def __call__( output_attentions: bool = False, layer_number: int = None, ): - # TODO: this module __call__ needs checking for correctness of implementation. + alibi = jnp.repeat(alibi, hidden_states.shape[0], axis=0) fused_qkv = self.query_key_value(hidden_states) new_tensor_shape = fused_qkv.shape[:-1] + (self.num_heads, 3 * self.head_dim) @@ -316,7 +318,6 @@ def __call__( query_length, key_length = query.shape[1], key.shape[1] - # TODO: check size of hidden_states to confirm this is the right dim for causal mask to use causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[1]), dtype="bool"), dtype="bool") if self.has_variable("cache", "cached_key"): @@ -469,8 +470,8 @@ def __call__( self, hidden_states, alibi, - layer_past=None, attention_mask=None, + layer_past=None, head_mask=None, deterministic: bool = True, init_cache: bool = False, @@ -598,8 +599,8 @@ def init_cache(self, batch_size, max_length): def __call__( self, input_ids, - past_key_values: dict = None, attention_mask=None, + past_key_values: dict = None, head_mask=None, inputs_embeds=None, params: dict = None, @@ -628,13 +629,13 @@ def __call__( inputs = {"params": params or self.params} # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxBloomAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - # TODO: build alibi here? 
- # TODO: check the inputs and their order to this + # TODO: check with patrick + # if isinstance(past_key_values, jnp.ndarray): + # inputs["cache"] = past_key_values + # mutable = ["cache"] + # else: + # mutable = False + mutable = False outputs = self.module.apply( inputs, jnp.array(input_ids, dtype="i4"), @@ -675,7 +676,6 @@ def __call__( init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, - return_dict: bool = True, ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -713,8 +713,8 @@ def __call__( layer_outputs = FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False)( hidden_states, - alibi, - attention_mask, + alibi=alibi, + attention_mask=attention_mask, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, @@ -774,8 +774,7 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) # do post-embedding layernorm hidden_states = self.word_embeddings_layernorm(inputs_embeds) @@ -793,9 +792,8 @@ def __call__( attention_mask=attention_mask, deterministic=deterministic, init_cache=init_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + output_attentions=output_attentions, ) hidden_states = outputs[0] @@ -905,20 +903,15 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) return { "past_key_values": past_key_values, "attention_mask": extended_attention_mask, - "position_ids": position_ids, } def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 return model_kwargs From 448c1e9e046ea6b2f730b7f3fe5fdfa23be248ea Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 5 Jul 2022 19:09:51 +0200 Subject: [PATCH 12/81] add test --- .../models/bloom/test_modeling_flax_bloom.py | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index e533fc7cc73f..2fcb4af60080 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -17,7 +17,7 @@ import timeout_decorator # noqa from transformers import BloomConfig, is_flax_available -from transformers.testing_utils import require_flax, require_sentencepiece, slow +from transformers.testing_utils import require_flax, slow from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor @@ -33,7 +33,7 @@ import jax import jax.numpy as jnp - from transformers import FlaxBloomForCausalLM, FlaxBloomModel, BloomTokenizerFast + from transformers import 
FlaxBloomForCausalLM, FlaxBloomModel def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None): @@ -74,7 +74,7 @@ def __init__( self.use_labels = use_labels self.vocab_size = vocab_size self.hidden_size = hidden_size - self.n_layer = n_layer + self.num_hidden_layers = n_layer self.n_head = n_head self.hidden_act = hidden_act self.hidden_dropout = hidden_dropout @@ -93,7 +93,7 @@ def prepare_config_and_inputs(self): config = BloomConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, - n_layer=self.n_layer, + n_layer=self.num_hidden_layers, n_head=self.n_head, hidden_dropout=self.hidden_dropout, attention_dropout=self.attention_probs_dropout_prob, @@ -120,23 +120,16 @@ def check_use_cache_forward(self, model_class_name, config, inputs_dict): past_key_values = model.init_cache(input_ids.shape[0], max_length) attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4") - position_ids = jnp.broadcast_to( - jnp.arange(input_ids.shape[-1] - 1)[None, :], - (input_ids.shape[0], input_ids.shape[-1] - 1), - ) outputs_cache = model( input_ids[:, :-1], attention_mask=attention_mask, past_key_values=past_key_values, - position_ids=position_ids, ) - position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") outputs_cache_next = model( input_ids[:, -1:], attention_mask=attention_mask, past_key_values=outputs_cache.past_key_values, - position_ids=position_ids, ) outputs = model(input_ids) @@ -162,23 +155,16 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input ) past_key_values = model.init_cache(input_ids.shape[0], max_length) - position_ids = jnp.broadcast_to( - jnp.arange(input_ids.shape[-1] - 1)[None, :], - (input_ids.shape[0], input_ids.shape[-1] - 1), - ) outputs_cache = model( input_ids[:, :-1], attention_mask=attention_mask_cache, past_key_values=past_key_values, - position_ids=position_ids, ) - position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") outputs_cache_next = model( input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, attention_mask=attention_mask_cache, - position_ids=position_ids, ) outputs = model(input_ids, attention_mask=attention_mask) From 3aed702125553e620448ee2dbdf047765783785e Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 5 Jul 2022 19:08:02 +0100 Subject: [PATCH 13/81] update scan for new kwarg order --- .../models/bloom/modeling_flax_bloom.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index f42296baab8c..a179d0893b5a 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -471,13 +471,13 @@ def __call__( hidden_states, alibi, attention_mask=None, + layer_number: int = None, layer_past=None, head_mask=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, use_cache: bool = False, - layer_number: int = None, ): if self.use_scan: hidden_states = hidden_states[0] @@ -666,6 +666,7 @@ class FlaxBloomBlockCollection(nn.Module): dtype: jnp.dtype = jnp.float32 use_scan: bool = False + # TODO (SG): re-write as a `setup` to conform to Transformers JAX/Flax conventions -> awaiting CG response on G Chat @nn.compact def __call__( self, @@ -690,19 +691,13 @@ def __call__( FlaxBloomBlock, variable_axes={"params": 0, "cache": 0}, split_rngs={"params": True, "dropout": True}, - 
in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, 0), + in_axes=(nn.broadcast, nn.broadcast, 0), length=self.config.num_hidden_layers, )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( hidden_states, alibi, - None, # kwargs not supported by scan - None, - None, - deterministic, - init_cache, - output_attentions, - False, - layer_number=jnp.arange(self.config.num_hidden_layers), + attention_mask, # kwargs not supported by scan + jnp.arange(self.config.num_hidden_layers), ) hidden_states = hidden_states[0] From 9ad912c3ab64c3d651cece8e0d76e8735e8686d3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 5 Jul 2022 20:30:03 +0000 Subject: [PATCH 14/81] fix position_ids problem --- .../modeling_flax_pytorch_utils.py | 55 +++++++------------ .../models/bloom/modeling_flax_bloom.py | 8 +-- 2 files changed, 24 insertions(+), 39 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 21d312647f9a..44605823d1ca 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -134,8 +134,20 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): + try: + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" + " instructions." + ) + raise + # convert pytorch tensor to numpy - pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + # numpy currently does not support bfloat16, need to go over float32 in this case to not loose precision + is_bfloat_16 = all(v.dtype == torch.bfloat16 for v in pt_state_dict.values()) + pt_state_dict = {k: v.numpy() if not is_bfloat_16 else v.float().numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix @@ -278,11 +290,11 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): continue # also add unexpected weight so that warning is thrown - flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) + flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) else: # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) return unflatten_dict(flax_state_dict) @@ -323,7 +335,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): raise # check if we have bf16 weights - is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() + is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() if any(is_type_bf16): # convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16 # and bf16 is not fully supported in PT yet. @@ -331,7 +343,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " "before loading those in PyTorch model." 
) - flax_state = jax.tree_util.tree_map( + flax_state = jax.tree_map( lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state ) @@ -339,10 +351,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): pt_model_dict = pt_model.state_dict() load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and ( - pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()} + pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()]) ) load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and ( - pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()} + pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()]) ) # keep track of unexpected & missing keys @@ -371,34 +383,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): elif flax_key_tuple[-1] in ["scale", "embedding"]: flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - # adding batch stats from flax batch norm to pt - elif "mean" in flax_key_tuple[-1]: - flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",) - elif "var" in flax_key_tuple[-1]: - flax_key_tuple = flax_key_tuple[:-1] + ("running_var",) - - if "batch_stats" in flax_state: - flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header - else: - flax_key = ".".join(flax_key_tuple) - - # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation. - special_pt_names = {} - # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 - for key in pt_model_dict: - key_components = key.split(".") - name = None - if key_components[-3::2] == ["parametrizations", "original0"]: - name = key_components[-2] + "_g" - elif key_components[-3::2] == ["parametrizations", "original1"]: - name = key_components[-2] + "_v" - if name is not None: - key_components = key_components[:-3] + [name] - key_to_check = ".".join(key_components) - special_pt_names[key_to_check] = key - - if flax_key in special_pt_names: - flax_key = special_pt_names[flax_key] + flax_key = ".".join(flax_key_tuple) if flax_key in pt_model_dict: if flax_tensor.shape != pt_model_dict[flax_key].shape: diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index a179d0893b5a..515b8905ceaa 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -433,7 +433,8 @@ def __call__(self, hidden_states, residual, deterministic: bool = True): hidden_states = self.act(hidden_states) # TODO: this code block is from the pytorch implementation. needs changing to work. 
- if self.pretraining_tp > 1 and self.slow_but_exact: +# if self.pretraining_tp > 1 and self.slow_but_exact: + if False: intermediate_output = jnp.zeros_like(residual) slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp for i in range(self.pretraining_tp): @@ -636,6 +637,7 @@ def __call__( # else: # mutable = False mutable = False +# import ipdb; ipdb.set_trace() outputs = self.module.apply( inputs, jnp.array(input_ids, dtype="i4"), @@ -759,7 +761,6 @@ def __call__( input_ids=None, attention_mask=None, past_key_values=None, - position_ids=None, head_mask=None, inputs_embeds=None, use_cache=None, @@ -845,17 +846,16 @@ def __call__( self, input_ids, attention_mask, - position_ids, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): +# import ipdb; ipdb.set_trace() outputs = self.transformer( input_ids, attention_mask=attention_mask, - position_ids=position_ids, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, From fc830793efac09a581ee3c80e4d263ad3a76609d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 5 Jul 2022 22:57:17 +0000 Subject: [PATCH 15/81] fix bug in attention layer --- .../models/bloom/modeling_flax_bloom.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 515b8905ceaa..94a94059be21 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -21,6 +21,7 @@ # TODO: check that code is jit-able import math +import numpy as np from functools import partial from typing import Callable, Optional, Tuple @@ -308,7 +309,10 @@ def __call__( output_attentions: bool = False, layer_number: int = None, ): - alibi = jnp.repeat(alibi, hidden_states.shape[0], axis=0) + batch_size, sequence, hidden_size = hidden_states.shape + + alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) + fused_qkv = self.query_key_value(hidden_states) new_tensor_shape = fused_qkv.shape[:-1] + (self.num_heads, 3 * self.head_dim) @@ -379,25 +383,18 @@ def __call__( attn_weights = attn_weights.reshape(output_size) # TODO: apply softmax to attention weights - att_probs = self.scale_mask_softmax(attn_weights, attention_mask, causal_mask, layer_number) + attention_probs = self.scale_mask_softmax(attn_weights, attention_mask, causal_mask, layer_number) + if head_mask is not None: attention_probs = attention_probs * head_mask output_size = (value.shape[0], value.shape[2], query.shape[1], value.shape[3]) - value = jnp.transpose(value, (1, 0, 2, 3)).reshape(value.shape[1], output_size[0] * output_size[1], -1) - attention_probs_reshaped = jnp.reshape(att_probs, (output_size[0] * output_size[1], output_size[2], -1)) - - context = jnp.matmul(attention_probs_reshaped, jnp.transpose(value, (1, 0, 2))) - context = context.reshape(output_size) - context = jnp.transpose(context, (2, 0, 1, 3)) - - # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size] - new_context_layer_shape = context.shape[:-2] + (self.hidden_size,) - context = context.reshape(new_context_layer_shape) + attn_output = jnp.einsum("...hqk,...khd->...qhd", attention_probs, value) + attn_output = self._merge_heads(attn_output) + attn_output = self.dense(attn_output ) - attn_output = self.dense(context) - 
attn_output = jnp.transpose(attn_output, (1, 0, 2)) + residual + attn_output = attn_output + residual outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -637,7 +634,6 @@ def __call__( # else: # mutable = False mutable = False -# import ipdb; ipdb.set_trace() outputs = self.module.apply( inputs, jnp.array(input_ids, dtype="i4"), @@ -708,6 +704,7 @@ def __call__( if output_hidden_states: all_hidden_states += (hidden_states,) + layer_outputs = FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False)( hidden_states, alibi=alibi, @@ -852,7 +849,6 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): -# import ipdb; ipdb.set_trace() outputs = self.transformer( input_ids, attention_mask=attention_mask, From 541429af9624c0d0124403ead8f5ce2538c23836 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 6 Jul 2022 09:57:14 +0200 Subject: [PATCH 16/81] small fix - do the alibi broadcasting only once --- src/transformers/models/bloom/modeling_flax_bloom.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 94a94059be21..451b9c65e56d 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -311,8 +311,6 @@ def __call__( ): batch_size, sequence, hidden_size = hidden_states.shape - alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) - fused_qkv = self.query_key_value(hidden_states) new_tensor_shape = fused_qkv.shape[:-1] + (self.num_heads, 3 * self.head_dim) @@ -771,8 +769,11 @@ def __call__( # do post-embedding layernorm hidden_states = self.word_embeddings_layernorm(inputs_embeds) - curr_seq_len = hidden_states.shape[1] + batch_size, curr_seq_len, _ = hidden_states.shape alibi = build_alibi_tensor_flax(curr_seq_len, self.config.n_head, hidden_states.dtype) + # TODO put repeat here + alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) + past_key_values = () if use_cache else None # TODO: come back to this line # TODO: how to handle alibi? build alibi tensor here? 
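The change in [PATCH 16/81] computes the ALiBi tensor once in `FlaxBloomModule` and broadcasts it over the batch there, rather than repeating that work inside every `FlaxBloomAttention` call. For orientation, below is a minimal sketch of that construction for a power-of-two head count; `alibi_bias` is an illustrative name, the patch's own helper is `build_alibi_tensor_flax`, which also covers non-power-of-two head counts.

    import math
    import jax.numpy as jnp

    def alibi_bias(seq_len, n_head, dtype=jnp.float32):
        # Geometric per-head slopes, 2 ** (-8 * (i + 1) / n_head),
        # assuming n_head is a power of two.
        start = 2 ** (-(2 ** -(math.log2(n_head) - 3)))
        slopes = jnp.array([start ** (i + 1) for i in range(n_head)], dtype=dtype)
        # Linear penalty over key positions, shape (n_head, 1, seq_len)
        return slopes[:, None, None] * jnp.arange(seq_len, dtype=dtype)[None, None, :]

    # Broadcast once per forward pass, as FlaxBloomModule now does:
    # (n_head, 1, seq_len) -> (batch_size * n_head, 1, seq_len)
    batch_size, seq_len, n_head = 2, 5, 16
    alibi = alibi_bias(seq_len, n_head)
    alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:])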
From 40ff6dfc7b6cbef8c0f042e71b34ceb95de0dfcd Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 6 Jul 2022 10:22:15 +0100 Subject: [PATCH 17/81] prelim refactor --- .../models/bloom/modeling_flax_bloom.py | 37 +++++++------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 451b9c65e56d..f85905aefe1c 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -221,18 +221,10 @@ class FlaxBloomAttention(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - # TODO: make sure these affect behavior correctly - self.pretraining_tp = self.config.pretraining_tp - self.slow_but_exact = self.config.slow_but_exact - self.hidden_size = self.config.hidden_size self.num_heads = self.config.n_head self.head_dim = self.hidden_size // self.num_heads - self.split_size = self.hidden_size - # TODO: deal with softmax self.attention_softmax_in_fp32 = self.config.attention_softmax_in_fp32 - # TODO: deal with hidden dropout - self.hidden_dropout = self.config.hidden_dropout if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( @@ -240,8 +232,7 @@ def setup(self): f"`num_heads`: {self.num_heads})." ) - self.attn_dropout = nn.Dropout(self.config.attention_dropout) - + # Scaled softmax self.scale_mask_softmax = FlaxBloomScaledSoftmax( self.config, attention_mask_func, @@ -259,12 +250,13 @@ def setup(self): self.attention_dropout = nn.Dropout(self.config.attention_dropout) def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, 3 * self.head_dim)) def _merge_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) @nn.compact + # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache def _concatenate_to_cache(self, key, value, query, attention_mask): """ This function takes projected key, value states from a single input token and concatenates the states to cached @@ -309,12 +301,9 @@ def __call__( output_attentions: bool = False, layer_number: int = None, ): - batch_size, sequence, hidden_size = hidden_states.shape - fused_qkv = self.query_key_value(hidden_states) - new_tensor_shape = fused_qkv.shape[:-1] + (self.num_heads, 3 * self.head_dim) - fused_qkv = fused_qkv.reshape(new_tensor_shape) + fused_qkv = self._split_heads(fused_qkv) query, key, value = jnp.split(fused_qkv, 3, axis=-1) @@ -337,6 +326,8 @@ def __call__( if attention_mask is not None: attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) attention_mask = combine_masks(attention_mask, causal_mask) + else: + attention_mask = causal_mask dropout_rng = None if not deterministic and self.config.attention_dropout > 0.0: @@ -347,14 +338,12 @@ def __call__( if self.has_variable("cache", "cached_key") or init_cache: key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - attention_bias = None # transform boolean mask into float mask - if attention_mask is not None: - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, -1e9).astype(self.dtype), - ) + attention_bias = lax.select( + attention_mask > 0, + 
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), + ) # Reshape input tensors output_size = (query.shape[0], query.shape[2], query.shape[1], key.shape[1]) @@ -371,7 +360,7 @@ def __call__( # Layer-wise attention scaling # layer_number matters for attn scaling and should not be 0. see `FlaxBloomAttention` for its use. - layer_number = jnp.where(layer_number < 1, 1, layer_number) + layer_number = jax.lax.max(1, layer_number) norm_factor = jnp.sqrt(self.head_dim).astype(self.dtype) * layer_number alpha = 1.0 / norm_factor beta = 1.0 / layer_number @@ -702,7 +691,7 @@ def __call__( if output_hidden_states: all_hidden_states += (hidden_states,) - + layer_outputs = FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False)( hidden_states, alibi=alibi, From 8d1f13741b7c9d7bd820a1b39d9e054f14a44ad8 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 6 Jul 2022 11:01:46 +0100 Subject: [PATCH 18/81] finish refactor --- .../models/bloom/modeling_flax_bloom.py | 128 +++--------------- 1 file changed, 16 insertions(+), 112 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index f85905aefe1c..c8c34f8c93aa 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -21,15 +21,14 @@ # TODO: check that code is jit-able import math -import numpy as np from functools import partial -from typing import Callable, Optional, Tuple +from typing import Optional, Tuple import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask +from flax.linen import combine_masks, make_causal_mask, dot_product_attention_weights from flax.linen.activation import tanh from flax.linen.partitioning import scan_with_axes from flax.traverse_util import flatten_dict, unflatten_dict @@ -52,33 +51,6 @@ _TOKENIZER_FOR_DOC = "BloomTokenizer" -def masked_fill(mask, a, fill): - return jax.lax.select(mask, a, jax.lax.broadcast(fill, a.shape)) - - -def attention_mask_func(attention_scores, attention_mask, causal_mask): - attention_mask_bool = ~(attention_mask) - - query_length, key_length, n_heads = attention_scores.shape[2], attention_scores.shape[3], attention_scores.shape[1] - padded_causal_mask = jnp.logical_or( - attention_mask_bool[:, :, key_length - query_length : key_length, :], - ~(causal_mask[:, :, key_length - query_length : key_length, :key_length] == 1), - ) - padded_causal_mask = jnp.logical_or(padded_causal_mask, attention_mask_bool[:, :, :, :key_length]) - # Make use of floats - return ( - masked_fill( - jnp.broadcast_to( - ~padded_causal_mask, - (padded_causal_mask.shape[0], n_heads, padded_causal_mask.shape[2], padded_causal_mask.shape[3]), - ), - attention_scores, - -1e4, - ), - ~padded_causal_mask, - ) - - BLOOM_START_DOCSTRING = r""" This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -175,47 +147,6 @@ def get_slopes_power_of_2(n): return alibi -class FlaxBloomScaledSoftmax(nn.Module): - config: BloomConfig - mask_func: Callable - softmax_in_fp32: bool - """ - Scaled Softmax module. Also performs masking. - Args: - mask_func (`function`, *required*): - mask function to be applied. - softmax_in_fp32 (`bool`, *required*): - if true, softmax in performed at fp32 precision. 
- scale (`float`, *optional*): - scaling factor used in input tensor scaling. - """ - - def setup(self): - pass - - def __call__(self, input, mask, causal_mask, scale): - input_dtype = input.dtype - input_in_16bit = input_dtype in [jnp.float16, jnp.bfloat16] - softmax_dtype = jnp.float32 if self.softmax_in_fp32 else input_dtype - - if scale is not None: - input = input * scale - - if mask is None: - mask = jnp.ones((input.shape[0], input.shape[1]), dtype=bool) - else: - mask = mask.astype(bool) - - mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask) - mask_output.astype(softmax_dtype) - probs = nn.softmax(mask_output, axis=-1) * (padded_causal_mask) - - if input_dtype != softmax_dtype: - probs = probs.astype(input_dtype) - - return probs - - class FlaxBloomAttention(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 @@ -232,13 +163,6 @@ def setup(self): f"`num_heads`: {self.num_heads})." ) - # Scaled softmax - self.scale_mask_softmax = FlaxBloomScaledSoftmax( - self.config, - attention_mask_func, - self.attention_softmax_in_fp32, - ) - dense = partial( nn.Dense, dtype=self.dtype, @@ -345,41 +269,24 @@ def __call__( jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) - # Reshape input tensors - output_size = (query.shape[0], query.shape[2], query.shape[1], key.shape[1]) - - # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] - query = jnp.transpose(query, (1, 0, 2, 3)).reshape(output_size[2], output_size[0] * output_size[1], -1) - - # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim] - key = jnp.transpose(key, (1, 0, 2, 3)).reshape(output_size[3], output_size[0] * output_size[1], -1) - - # Reshape according to batch size - query = jnp.transpose(query, (1, 0, 2)) - key = jnp.transpose(key, (1, 2, 0)) - - # Layer-wise attention scaling - # layer_number matters for attn scaling and should not be 0. see `FlaxBloomAttention` for its use. 
- layer_number = jax.lax.max(1, layer_number) - norm_factor = jnp.sqrt(self.head_dim).astype(self.dtype) * layer_number - alpha = 1.0 / norm_factor - beta = 1.0 / layer_number + attention_bias = attention_bias + alibi + # TODO: override softmax precision to fp32 if self.attention_softmax_in_fp32=True and self.dtype != fp32 # usual dot product attention - attn_weights = beta * alibi + alpha * jnp.matmul(query, key) - attn_weights = attn_weights.reshape(output_size) - - # TODO: apply softmax to attention weights - attention_probs = self.scale_mask_softmax(attn_weights, attention_mask, causal_mask, layer_number) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - output_size = (value.shape[0], value.shape[2], query.shape[1], value.shape[3]) + attn_weights = dot_product_attention_weights( + query, + key, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_dropout, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) - attn_output = jnp.einsum("...hqk,...khd->...qhd", attention_probs, value) + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) - attn_output = self.dense(attn_output ) + attn_output = self.dense(attn_output) attn_output = attn_output + residual @@ -760,9 +667,6 @@ def __call__( batch_size, curr_seq_len, _ = hidden_states.shape alibi = build_alibi_tensor_flax(curr_seq_len, self.config.n_head, hidden_states.dtype) - # TODO put repeat here - alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) - past_key_values = () if use_cache else None # TODO: come back to this line # TODO: how to handle alibi? build alibi tensor here? From 22796ae99cdc1ff00446cd57913484dc692b4051 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 6 Jul 2022 10:46:48 +0200 Subject: [PATCH 19/81] alibi shifting --- src/transformers/models/bloom/modeling_flax_bloom.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index c8c34f8c93aa..93b40a92de6a 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -666,7 +666,11 @@ def __call__( hidden_states = self.word_embeddings_layernorm(inputs_embeds) batch_size, curr_seq_len, _ = hidden_states.shape + alibi = build_alibi_tensor_flax(curr_seq_len, self.config.n_head, hidden_states.dtype) + # TODO put repeat here + alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) + past_key_values = () if use_cache else None # TODO: come back to this line # TODO: how to handle alibi? build alibi tensor here? 
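[PATCH 17/81] and [PATCH 18/81] then refactor the attention layer to go through Flax's `dot_product_attention_weights`, folding the causal mask and the ALiBi penalty into a single additive bias, and [PATCH 19/81] restores the per-batch broadcast of that penalty in `FlaxBloomModule`. Below is a self-contained toy sketch of the resulting pattern, with illustrative shapes and names rather than the module's actual wiring.

    import math
    import jax
    import jax.numpy as jnp
    from flax.linen import dot_product_attention_weights, make_causal_mask

    batch, seq, heads, head_dim = 2, 5, 4, 8
    q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
    query = jax.random.normal(q_rng, (batch, seq, heads, head_dim))
    key = jax.random.normal(k_rng, (batch, seq, heads, head_dim))
    value = jax.random.normal(v_rng, (batch, seq, heads, head_dim))

    # Boolean causal mask -> additive bias: 0.0 where attending is allowed, -1e9 where masked.
    causal_mask = make_causal_mask(jnp.ones((batch, seq), dtype="bool"), dtype="bool")
    attention_bias = jnp.where(causal_mask, 0.0, -1e9)  # (batch, 1, seq, seq)

    # Fold the ALiBi penalty into the same bias (power-of-two head count assumed).
    start = 2 ** (-(2 ** -(math.log2(heads) - 3)))
    slopes = jnp.array([start ** (i + 1) for i in range(heads)])
    attention_bias = attention_bias + slopes[None, :, None, None] * jnp.arange(seq)[None, None, None, :]

    attn_weights = dot_product_attention_weights(query, key, bias=attention_bias, deterministic=True)
    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)  # (batch, seq, heads, head_dim)

The module itself additionally combines the causal mask with any padding mask via `combine_masks`, handles the cached key/value path, and applies head masking and residual dropout on top of this.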
From 2917cac8cc8cbd18f29b36c4accb5911a2dac455 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 6 Jul 2022 11:31:21 +0100 Subject: [PATCH 20/81] incorporate dropout_add to attention module --- src/transformers/models/bloom/modeling_flax_bloom.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 93b40a92de6a..b9b0d54770e3 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -171,7 +171,7 @@ def setup(self): self.query_key_value = dense(self.hidden_size * 3) self.dense = dense(self.hidden_size) - self.attention_dropout = nn.Dropout(self.config.attention_dropout) + self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout) def _split_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, 3 * self.head_dim)) @@ -284,9 +284,13 @@ def __call__( precision=None, ) + if head_mask is not None: + attn_weights = attn_weights * head_mask + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) attn_output = self.dense(attn_output) + attn_output = self.resid_dropout(attn_output, deterministic=deterministic) attn_output = attn_output + residual From a6669b407357f6c76ba3973d5813c4cacaa9b0db Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Jul 2022 11:12:20 +0000 Subject: [PATCH 21/81] make style --- src/transformers/models/bloom/modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index b9b0d54770e3..7ae520958f3f 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -675,7 +675,6 @@ def __call__( # TODO put repeat here alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) - past_key_values = () if use_cache else None # TODO: come back to this line # TODO: how to handle alibi? build alibi tensor here? 
@@ -785,6 +784,7 @@ def __call__( # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM with GPTNeo->Bloom class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel): module_class = FlaxBloomForCausalLMModule + # TODO: check if this class is correct / take out position ids def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): # initializing the cache From 312a2f1ae75093e1edf133ffca454d62ea2f7ff8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Jul 2022 11:19:02 +0000 Subject: [PATCH 22/81] make padding work again --- src/transformers/modeling_flax_pytorch_utils.py | 8 ++++---- src/transformers/models/bloom/modeling_flax_bloom.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 44605823d1ca..956c262e3e93 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -209,11 +209,11 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): continue # also add unexpected weight so that warning is thrown - flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) + flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) else: # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) return unflatten_dict(flax_state_dict) @@ -290,11 +290,11 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): continue # also add unexpected weight so that warning is thrown - flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) else: # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) return unflatten_dict(flax_state_dict) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 7ae520958f3f..0010d3a8e2ad 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Flax BLOOM model. 
""" +""" Flax BLOOM model.""" # TODO: see todos throughout this file # TODO: check correctness against pytorch implementation # TODO: add unit tests @@ -28,7 +28,7 @@ import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask, dot_product_attention_weights +from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask from flax.linen.activation import tanh from flax.linen.partitioning import scan_with_axes from flax.traverse_util import flatten_dict, unflatten_dict @@ -328,7 +328,7 @@ def __call__(self, hidden_states, residual, deterministic: bool = True): hidden_states = self.act(hidden_states) # TODO: this code block is from the pytorch implementation. needs changing to work. -# if self.pretraining_tp > 1 and self.slow_but_exact: + # if self.pretraining_tp > 1 and self.slow_but_exact: if False: intermediate_output = jnp.zeros_like(residual) slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp From 143a13591fd5f26852020add03d11784a697536e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Jul 2022 12:35:32 +0000 Subject: [PATCH 23/81] update --- @ | 841 ++++++++++++++++++ .../models/bloom/modeling_flax_bloom.py | 30 +- 2 files changed, 863 insertions(+), 8 deletions(-) create mode 100644 @ diff --git a/@ b/@ new file mode 100644 index 000000000000..ffebb5743bfc --- /dev/null +++ b/@ @@ -0,0 +1,841 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. team and Bigscience Workshop. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Flax BLOOM model.""" +# TODO: see todos throughout this file +# TODO: check correctness against pytorch implementation +# TODO: add unit tests +# TODO: add documentation / check that documentation is correct +# TODO: BLOOM_INPUTS_DOCSTRING might be wrong still (position_ids) +# TODO: check that code is jit-able + +import math +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask +from flax.linen.activation import tanh +from flax.linen.partitioning import scan_with_axes +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutput, +) +from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_bloom import BloomConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bigscience/bloom" +_CONFIG_FOR_DOC = "BloomConfig" +_TOKENIZER_FOR_DOC = "BloomTokenizer" + + +def shift_alibi_for_padding(alibi, attention_mask, batch_size): + if jnp.isin(0, attention_mask): + unpadded_indices = nn.relu(lax.cumsum(attention_mask, axis=1) - 1) + reshaped_alibi = jnp.reshape(alibi, (batch_size, int(alibi.shape[0] / batch_size), alibi.shape[-1])) + + final_alibi = [] + for i in range(reshaped_alibi.shape[0]): + final_alibi.append(jnp.take_along_axis(reshaped_alibi[i], jnp.expand_dims(unpadded_indices[i], 0), -1)) + alibi = jnp.array(final_alibi) + return jnp.reshape(alibi, (alibi.shape[0] * alibi.shape[1], 1, -1)) + else: + return alibi + + +BLOOM_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`BloomConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. 
If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +BLOOM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +def flax_unsqueeze(x, axis): + return jnp.expand_dims(x, axis) + + +def build_alibi_tensor_flax(max_seq_len, n_head, dtype): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + slopes = jnp.array(get_slopes(n_head))[:, None, None] + arange_tensor = jnp.arange(max_seq_len, dtype=dtype)[None, None, :] + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) + # => the batch_size and query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 + # batch_size = 1, n_head = n_head, query_length + batch_size = key_length = 1 + num_heads = n_head + query_length = arange_tensor.shape[-1] + + alibi = slopes * jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) + return alibi + + +class FlaxBloomAttention(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.n_head + self.head_dim = self.hidden_size // self.num_heads + self.attention_softmax_in_fp32 = self.config.attention_softmax_in_fp32 + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads})." + ) + + dense = partial( + nn.Dense, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + self.query_key_value = dense(self.hidden_size * 3) + self.dense = dense(self.hidden_size) + self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, 3 * self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) + + @nn.compact + # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + residual, + alibi, + layer_past=None, + attention_mask=None, + head_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + layer_number: int = None, + ): + batch_size = hidden_states.shape[0] + + # proj q, k, v + fused_qkv = self.query_key_value(hidden_states) + fused_qkv = self._split_heads(fused_qkv) + query, key, value = jnp.split(fused_qkv, 3, axis=-1) + + query_length, key_length = query.shape[1], key.shape[1] + + causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[1]), dtype="bool"), dtype="bool") + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = causal_mask[:, :, :query_length, :key_length] + + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # create attention mask + if attention_mask is not None: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + else: + attention_mask = causal_mask + + dropout_rng = None + if not deterministic and self.config.attention_dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
+ if self.has_variable("cache", "cached_key") or init_cache: + key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) + + # transform boolean mask into float mask + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), + ) + + if alibi.shape[1:] != (1, 1): + import ipdb; ipdb.set_trace() + + attention_bias = attention_bias + alibi + + # TODO: override softmax precision to fp32 if self.attention_softmax_in_fp32=True and self.dtype != fp32 + # usual dot product attention + attn_weights = dot_product_attention_weights( + query, + key, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_dropout, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) + attn_output = self._merge_heads(attn_output) + attn_output = self.dense(attn_output) + attn_output = self.resid_dropout(attn_output, deterministic=deterministic) + + attn_output = attn_output + residual + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class BloomGELU(nn.Module): + def setup(self): + self.dtype = jnp.float32 + + def __call__(self, x): + return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + +class FlaxBloomMLP(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + hidden_size = self.config.hidden_size + + self.pretraining_tp = self.config.pretraining_tp + self.slow_but_exact = self.config.slow_but_exact + + kernel_init = jax.nn.initializers.normal(self.config.initializer_range) + + self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init) + self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init) + self.hidden_dropout = nn.Dropout(self.config.hidden_dropout) + self.act = BloomGELU() + + def __call__(self, hidden_states, residual, deterministic: bool = True): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + + # TODO: this code block is from the pytorch implementation. needs changing to work. 
+ # if self.pretraining_tp > 1 and self.slow_but_exact: + if False: + intermediate_output = jnp.zeros_like(residual) + slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp + for i in range(self.pretraining_tp): + intermediate_output = intermediate_output + nn.functional.linear( + hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + intermediate_output = self.dense_4h_to_h(hidden_states) + + intermediate_output = intermediate_output + residual + hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic) + + return hidden_states + + +class FlaxBloomBlock(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self): + self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + + self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype) + self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + + self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype) + + self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm + self.hidden_dropout = self.config.hidden_dropout + + def __call__( + self, + hidden_states, + alibi, + attention_mask=None, + layer_number: int = None, + layer_past=None, + head_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + use_cache: bool = False, + ): + if self.use_scan: + hidden_states = hidden_states[0] + + layernorm_output = self.input_layernorm(hidden_states) + # layer norm before saving residual if config calls for it + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # self-attention + attn_outputs = self.self_attention( + layernorm_output, + residual=residual, + alibi=alibi, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + layer_number=layer_number, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + post_layernorm = self.post_attention_layernorm(attention_output) + + # set residual based on config + if self.apply_residual_connection_post_layernorm: + residual = post_layernorm + else: + residual = attention_output + + output = self.mlp(post_layernorm, residual, deterministic=deterministic) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + if self.use_scan: + outputs = (outputs, None) + + return outputs + + +# TODO: does this still require position_ids? +# TODO: gradient checkpointing +# TODO: _no_split_modules? +# TODO: check initialization +class FlaxBloomPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BloomConfig + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: BloomConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + use_scan: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, use_scan=use_scan, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + # TODO: check whether this is correct (position ids might not be required) + @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) + def __call__( + self, + input_ids, + attention_mask=None, + past_key_values: dict = None, + head_mask=None, + inputs_embeds=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, sequence_length = input_ids.shape + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxBloomAttention module + # TODO: check with patrick + # if isinstance(past_key_values, jnp.ndarray): + # inputs["cache"] = past_key_values + # mutable = ["cache"] + # else: + # mutable = False + mutable = False + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxBloomBlockCollection(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + # TODO (SG): re-write as a `setup` to conform to Transformers JAX/Flax conventions -> awaiting CG response on G Chat + @nn.compact + def __call__( + self, + hidden_states, + alibi, + attention_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.use_scan: + # since all decoder layers are the same, we use nn.scan directly + # assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`" + # assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" + hidden_states = (hidden_states,) + + hidden_states, _ = scan_with_axes( + FlaxBloomBlock, + variable_axes={"params": 0, "cache": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=(nn.broadcast, nn.broadcast, 0), + length=self.config.num_hidden_layers, + )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( + hidden_states, + alibi, + attention_mask, # kwargs not supported by scan + jnp.arange(self.config.num_hidden_layers), + ) + hidden_states = hidden_states[0] + + else: + for layer_number in range(self.config.num_hidden_layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False)( + hidden_states, + alibi=alibi, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + layer_number=layer_number, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxBloomModule` will filter them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + + +class FlaxBloomModule(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self): + # TODO: check initialization correctness + self.embed_dim = self.config.hidden_size + + embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) + + # word embeddings (no positional embedding layer) TODO: confirm this statement correct + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.embed_dim, + embedding_init=embedding_init, + ) + # post-embedding 
layernorm + self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon) + + # transformer layers + self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype, use_scan=self.use_scan) + + # final layernorm + self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + # TODO: change how gradient checkpointing is done + self.gradient_checkpointing = False + + def __call__( + self, + input_ids=None, + attention_mask=None, + past_key_values=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + # do post-embedding layernorm + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + batch_size, curr_seq_len, _ = hidden_states.shape + + alibi = build_alibi_tensor_flax(curr_seq_len, self.config.n_head, hidden_states.dtype) + + if False and attention_mask is not None: + alibi = shift_alibi_for_padding(alibi, attention_mask, batch_size) + + past_key_values = () if use_cache else None # TODO: come back to this line + # TODO: how to handle alibi? build alibi tensor here? + + # TODO: fix inputs to this (and args to submodules in general) + # TODO: gradient checkpointing + outputs = self.h( + hidden_states, + alibi=alibi, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + # TODO: don't think this return value / ordering is correct + return tuple(v for v in [outputs[0], past_key_values, outputs[-1]] if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +@add_start_docstrings( + "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.", + BLOOM_START_DOCSTRING, +) +# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom +class FlaxBloomModel(FlaxBloomPreTrainedModel): + module_class = FlaxBloomModule + + +append_call_sample_docstring( + FlaxBloomModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC +) + + +# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLMModule with GPTNeo->Bloom +class FlaxBloomForCausalLMModule(nn.Module): + config: BloomConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self): + self.transformer = FlaxBloomModule(self.config, dtype=self.dtype, use_scan=self.use_scan) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + + def __call__( + self, + input_ids, + attention_mask, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + deterministic=deterministic, 
+ init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T + lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +@add_start_docstrings( + """ + The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + BLOOM_START_DOCSTRING, +) +# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM with GPTNeo->Bloom +class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel): + module_class = FlaxBloomForCausalLMModule + + # TODO: check if this class is correct / take out position ids + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since Bloom uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + return model_kwargs + + +append_call_sample_docstring( + FlaxBloomForCausalLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC +) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 0010d3a8e2ad..784f235495ef 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -139,11 +139,21 @@ def get_slopes_power_of_2(n): + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] ) - slopes = jnp.array(get_slopes(n_head)) - slopes = flax_unsqueeze(flax_unsqueeze(slopes, 1), 1) - arange_tensor = flax_unsqueeze(flax_unsqueeze(jnp.arange(max_seq_len, dtype=dtype), 0), 0) - - alibi = slopes * jnp.broadcast_to(arange_tensor, (n_head, 1, arange_tensor.shape[-1])) + slopes = jnp.array(get_slopes(n_head))[None, :, None, None] + arange_tensor = jnp.arange(max_seq_len, dtype=dtype)[None, None, None, :] + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) + # => the batch_size and query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # 
https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 + # batch_size = 1, n_head = n_head, query_length + batch_size = query_length = 1 + num_heads = n_head + key_length = arange_tensor.shape[-1] + + alibi = slopes * jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) return alibi @@ -225,10 +235,11 @@ def __call__( output_attentions: bool = False, layer_number: int = None, ): - fused_qkv = self.query_key_value(hidden_states) + batch_size = hidden_states.shape[0] + # proj q, k, v + fused_qkv = self.query_key_value(hidden_states) fused_qkv = self._split_heads(fused_qkv) - query, key, value = jnp.split(fused_qkv, 3, axis=-1) query_length, key_length = query.shape[1], key.shape[1] @@ -244,9 +255,9 @@ def __call__( else: causal_mask = causal_mask[:, :, :query_length, :key_length] - batch_size = hidden_states.shape[0] causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + # create attention mask if attention_mask is not None: attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) attention_mask = combine_masks(attention_mask, causal_mask) @@ -269,6 +280,9 @@ def __call__( jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) +# if alibi.shape[1:] != (1, 1): +# import ipdb; ipdb.set_trace() + attention_bias = attention_bias + alibi # TODO: override softmax precision to fp32 if self.attention_softmax_in_fp32=True and self.dtype != fp32 From e2b67aa78f1327173df55650bbdc4837a67c07a0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Jul 2022 12:35:43 +0000 Subject: [PATCH 24/81] remove bogus file --- @ | 841 -------------------------------------------------------------- 1 file changed, 841 deletions(-) delete mode 100644 @ diff --git a/@ b/@ deleted file mode 100644 index ffebb5743bfc..000000000000 --- a/@ +++ /dev/null @@ -1,841 +0,0 @@ -# coding=utf-8 -# Copyright 2022 HuggingFace Inc. team and Bigscience Workshop. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" Flax BLOOM model.""" -# TODO: see todos throughout this file -# TODO: check correctness against pytorch implementation -# TODO: add unit tests -# TODO: add documentation / check that documentation is correct -# TODO: BLOOM_INPUTS_DOCSTRING might be wrong still (position_ids) -# TODO: check that code is jit-able - -import math -from functools import partial -from typing import Optional, Tuple - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask -from flax.linen.activation import tanh -from flax.linen.partitioning import scan_with_axes -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutput, -) -from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_bloom import BloomConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "bigscience/bloom" -_CONFIG_FOR_DOC = "BloomConfig" -_TOKENIZER_FOR_DOC = "BloomTokenizer" - - -def shift_alibi_for_padding(alibi, attention_mask, batch_size): - if jnp.isin(0, attention_mask): - unpadded_indices = nn.relu(lax.cumsum(attention_mask, axis=1) - 1) - reshaped_alibi = jnp.reshape(alibi, (batch_size, int(alibi.shape[0] / batch_size), alibi.shape[-1])) - - final_alibi = [] - for i in range(reshaped_alibi.shape[0]): - final_alibi.append(jnp.take_along_axis(reshaped_alibi[i], jnp.expand_dims(unpadded_indices[i], 0), -1)) - alibi = jnp.array(final_alibi) - return jnp.reshape(alibi, (alibi.shape[0] * alibi.shape[1], 1, -1)) - else: - return alibi - - -BLOOM_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`BloomConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. 
If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -BLOOM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
-""" - - -def flax_unsqueeze(x, axis): - return jnp.expand_dims(x, axis) - - -def build_alibi_tensor_flax(max_seq_len, n_head, dtype): - def get_slopes(n): - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] - ) - - slopes = jnp.array(get_slopes(n_head))[:, None, None] - arange_tensor = jnp.arange(max_seq_len, dtype=dtype)[None, None, :] - - # Note: alibi will added to the attention bias that will be applied to the query, key product of attention - # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) - # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) - # => the batch_size and query_length dimension will then be broadcasted correctly - # This is more or less identical to T5's relative position bias: - # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 - # batch_size = 1, n_head = n_head, query_length - batch_size = key_length = 1 - num_heads = n_head - query_length = arange_tensor.shape[-1] - - alibi = slopes * jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) - return alibi - - -class FlaxBloomAttention(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.hidden_size = self.config.hidden_size - self.num_heads = self.config.n_head - self.head_dim = self.hidden_size // self.num_heads - self.attention_softmax_in_fp32 = self.config.attention_softmax_in_fp32 - - if self.head_dim * self.num_heads != self.hidden_size: - raise ValueError( - f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - self.query_key_value = dense(self.hidden_size * 3) - self.dense = dense(self.hidden_size) - self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, 3 * self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) - - @nn.compact - # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slighly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. 
- is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - residual, - alibi, - layer_past=None, - attention_mask=None, - head_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - layer_number: int = None, - ): - batch_size = hidden_states.shape[0] - - # proj q, k, v - fused_qkv = self.query_key_value(hidden_states) - fused_qkv = self._split_heads(fused_qkv) - query, key, value = jnp.split(fused_qkv, 3, axis=-1) - - query_length, key_length = query.shape[1], key.shape[1] - - causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[1]), dtype="bool"), dtype="bool") - - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = causal_mask[:, :, :query_length, :key_length] - - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # create attention mask - if attention_mask is not None: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - else: - attention_mask = causal_mask - - dropout_rng = None - if not deterministic and self.config.attention_dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. 
- if self.has_variable("cache", "cached_key") or init_cache: - key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - - # transform boolean mask into float mask - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, -1e9).astype(self.dtype), - ) - - if alibi.shape[1:] != (1, 1): - import ipdb; ipdb.set_trace() - - attention_bias = attention_bias + alibi - - # TODO: override softmax precision to fp32 if self.attention_softmax_in_fp32=True and self.dtype != fp32 - # usual dot product attention - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_dropout, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.dense(attn_output) - attn_output = self.resid_dropout(attn_output, deterministic=deterministic) - - attn_output = attn_output + residual - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class BloomGELU(nn.Module): - def setup(self): - self.dtype = jnp.float32 - - def __call__(self, x): - return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - - -class FlaxBloomMLP(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - hidden_size = self.config.hidden_size - - self.pretraining_tp = self.config.pretraining_tp - self.slow_but_exact = self.config.slow_but_exact - - kernel_init = jax.nn.initializers.normal(self.config.initializer_range) - - self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init) - self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init) - self.hidden_dropout = nn.Dropout(self.config.hidden_dropout) - self.act = BloomGELU() - - def __call__(self, hidden_states, residual, deterministic: bool = True): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - - # TODO: this code block is from the pytorch implementation. needs changing to work. 
- # if self.pretraining_tp > 1 and self.slow_but_exact: - if False: - intermediate_output = jnp.zeros_like(residual) - slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp - for i in range(self.pretraining_tp): - intermediate_output = intermediate_output + nn.functional.linear( - hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], - ) - else: - intermediate_output = self.dense_4h_to_h(hidden_states) - - intermediate_output = intermediate_output + residual - hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic) - - return hidden_states - - -class FlaxBloomBlock(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - use_scan: bool = False - - def setup(self): - self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype) - self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype) - - self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm - self.hidden_dropout = self.config.hidden_dropout - - def __call__( - self, - hidden_states, - alibi, - attention_mask=None, - layer_number: int = None, - layer_past=None, - head_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - use_cache: bool = False, - ): - if self.use_scan: - hidden_states = hidden_states[0] - - layernorm_output = self.input_layernorm(hidden_states) - # layer norm before saving residual if config calls for it - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - # self-attention - attn_outputs = self.self_attention( - layernorm_output, - residual=residual, - alibi=alibi, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - layer_number=layer_number, - ) - - attention_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - post_layernorm = self.post_attention_layernorm(attention_output) - - # set residual based on config - if self.apply_residual_connection_post_layernorm: - residual = post_layernorm - else: - residual = attention_output - - output = self.mlp(post_layernorm, residual, deterministic=deterministic) - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - if self.use_scan: - outputs = (outputs, None) - - return outputs - - -# TODO: does this still require position_ids? -# TODO: gradient checkpointing -# TODO: _no_split_modules? -# TODO: check initialization -class FlaxBloomPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. 
- """ - - config_class = BloomConfig - base_model_prefix = "transformer" - module_class: nn.Module = None - - def __init__( - self, - config: BloomConfig, - input_shape: Tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - use_scan: bool = False, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, use_scan=use_scan, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - # TODO: check whether this is correct (position ids might not be required) - @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) - def __call__( - self, - input_ids, - attention_mask=None, - past_key_values: dict = None, - head_mask=None, - inputs_embeds=None, - params: dict = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, sequence_length = input_ids.shape - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxBloomAttention module - # TODO: check with patrick - # if isinstance(past_key_values, jnp.ndarray): - # inputs["cache"] = past_key_values - # mutable = ["cache"] - # else: - # mutable = False - mutable = False - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - not train, - False, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxBloomBlockCollection(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - use_scan: bool = False - - # TODO (SG): re-write as a `setup` to conform to Transformers JAX/Flax conventions -> awaiting CG response on G Chat - @nn.compact - def __call__( - self, - hidden_states, - alibi, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.use_scan: - # since all decoder layers are the same, we use nn.scan directly - # assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`" - # assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" - hidden_states = (hidden_states,) - - hidden_states, _ = scan_with_axes( - FlaxBloomBlock, - variable_axes={"params": 0, "cache": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, 0), - length=self.config.num_hidden_layers, - )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( - hidden_states, - alibi, - attention_mask, # kwargs not supported by scan - jnp.arange(self.config.num_hidden_layers), - ) - hidden_states = hidden_states[0] - - else: - for layer_number in range(self.config.num_hidden_layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False)( - hidden_states, - alibi=alibi, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - layer_number=layer_number, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - # this contains possible `None` values - `FlaxBloomModule` will filter them out - outputs = (hidden_states, all_hidden_states, all_attentions) - - return outputs - - -class FlaxBloomModule(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - use_scan: bool = False - - def setup(self): - # TODO: check initialization correctness - self.embed_dim = self.config.hidden_size - - embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) - - # word embeddings (no positional embedding layer) TODO: confirm this statement correct - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.embed_dim, - embedding_init=embedding_init, - ) - # post-embedding 
layernorm - self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon) - - # transformer layers - self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype, use_scan=self.use_scan) - - # final layernorm - self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - # TODO: change how gradient checkpointing is done - self.gradient_checkpointing = False - - def __call__( - self, - input_ids=None, - attention_mask=None, - past_key_values=None, - head_mask=None, - inputs_embeds=None, - use_cache=None, - deterministic=True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - # do post-embedding layernorm - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - - batch_size, curr_seq_len, _ = hidden_states.shape - - alibi = build_alibi_tensor_flax(curr_seq_len, self.config.n_head, hidden_states.dtype) - - if False and attention_mask is not None: - alibi = shift_alibi_for_padding(alibi, attention_mask, batch_size) - - past_key_values = () if use_cache else None # TODO: come back to this line - # TODO: how to handle alibi? build alibi tensor here? - - # TODO: fix inputs to this (and args to submodules in general) - # TODO: gradient checkpointing - outputs = self.h( - hidden_states, - alibi=alibi, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - # TODO: don't think this return value / ordering is correct - return tuple(v for v in [outputs[0], past_key_values, outputs[-1]] if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=outputs[1], - attentions=outputs[-1], - ) - - -@add_start_docstrings( - "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.", - BLOOM_START_DOCSTRING, -) -# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom -class FlaxBloomModel(FlaxBloomPreTrainedModel): - module_class = FlaxBloomModule - - -append_call_sample_docstring( - FlaxBloomModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC -) - - -# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLMModule with GPTNeo->Bloom -class FlaxBloomForCausalLMModule(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - use_scan: bool = False - - def setup(self): - self.transformer = FlaxBloomModule(self.config, dtype=self.dtype, use_scan=self.use_scan) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - input_ids, - attention_mask, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - deterministic=deterministic, 
- init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T - lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - -@add_start_docstrings( - """ - The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - BLOOM_START_DOCSTRING, -) -# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM with GPTNeo->Bloom -class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel): - module_class = FlaxBloomForCausalLMModule - - # TODO: check if this class is correct / take out position ids - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since Bloom uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - return model_kwargs - - -append_call_sample_docstring( - FlaxBloomForCausalLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC -) From c499433477729d4d2a47ebcf0749bdaf8c59e80b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Jul 2022 13:25:44 +0000 Subject: [PATCH 25/81] up --- .../models/bloom/modeling_flax_bloom.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 784f235495ef..c39336387069 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -119,11 +119,7 @@ """ -def flax_unsqueeze(x, axis): - return jnp.expand_dims(x, axis) - - -def build_alibi_tensor_flax(max_seq_len, n_head, dtype): +def build_alibi_tensor_flax(attention_mask, n_head, dtype): def get_slopes(n): def get_slopes_power_of_2(n): start = 2 ** (-(2 ** -(math.log2(n) - 3))) @@ -139,21 +135,24 @@ def get_slopes_power_of_2(n): + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] ) - slopes = jnp.array(get_slopes(n_head))[None, :, None, None] - arange_tensor = jnp.arange(max_seq_len, dtype=dtype)[None, None, None, :] - # Note: alibi will added to the attention bias that will be applied to the query, key product of attention # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, 
key_length) # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) - # => the batch_size and query_length dimension will then be broadcasted correctly + # => the query_length dimension will then be broadcasted correctly # This is more or less identical to T5's relative position bias: # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 # batch_size = 1, n_head = n_head, query_length - batch_size = query_length = 1 + batch_size, key_length = attention_mask.shape num_heads = n_head - key_length = arange_tensor.shape[-1] + query_length = 1 - alibi = slopes * jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) + slopes = jnp.array(get_slopes(n_head))[None, :, None, None] + arange_tensor = attention_mask.cumsum(-1, dtype=dtype)[:, None, None, :] - 1 + + slopes_broadcasted = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length)) + arange_broadcasted = jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) + + alibi = slopes_broadcasted * arange_broadcasted return alibi @@ -280,9 +279,6 @@ def __call__( jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) -# if alibi.shape[1:] != (1, 1): -# import ipdb; ipdb.set_trace() - attention_bias = attention_bias + alibi # TODO: override softmax precision to fp32 if self.attention_softmax_in_fp32=True and self.dtype != fp32 From 01e01a398c38ce06014a2e05f16c2b0dcae03aeb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Jul 2022 14:10:31 +0000 Subject: [PATCH 26/81] get generation to work --- .../models/bloom/modeling_flax_bloom.py | 60 +++++++++++-------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index c39336387069..5f1c4694c239 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -234,49 +234,59 @@ def __call__( output_attentions: bool = False, layer_number: int = None, ): - batch_size = hidden_states.shape[0] + batch_size, seq_length = hidden_states.shape[:2] # proj q, k, v fused_qkv = self.query_key_value(hidden_states) fused_qkv = self._split_heads(fused_qkv) query, key, value = jnp.split(fused_qkv, 3, axis=-1) - query_length, key_length = query.shape[1], key.shape[1] + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") - causal_mask = make_causal_mask(jnp.ones((1, hidden_states.shape[1]), dtype="bool"), dtype="bool") + # for fast decoding causal attention mask should be shifted + causal_attention_mask_shift = ( + self.variables["cache"]["cache_index"] if self.has_variable("cache", "cached_key") else 0 + ) + # fast decoding for generate requires special attention_mask if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + causal_attention_mask = jax.lax.dynamic_slice( + causal_attention_mask, + (0, 0, causal_attention_mask_shift, 0), + (1, 1, seq_length, max_decoder_length), ) - else: - causal_mask = causal_mask[:, :, :query_length, :key_length] - - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - # create attention mask - if attention_mask is not None: - 
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - else: - attention_mask = causal_mask + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape) + attention_mask = combine_masks(attention_mask, causal_attention_mask) dropout_rng = None if not deterministic and self.config.attention_dropout > 0.0: dropout_rng = self.make_rng("dropout") + # print("layer") + # + # if init_cache: + # print("init_cache") + # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.has_variable("cache", "cached_key") or init_cache: key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) + # if hidden_states.shape[:2] != (1, 1) and hidden_states.shape[1] < 10: + # import ipdb; ipdb.set_trace() + # transform boolean mask into float mask + mask_value = jnp.finfo(self.dtype).min attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, -1e9).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), ) attention_bias = attention_bias + alibi @@ -493,10 +503,9 @@ def init_cache(self, batch_size, max_length): # init input variables to retrieve cache input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True ) return unfreeze(init_variables["cache"]) @@ -535,13 +544,12 @@ def __call__( inputs = {"params": params or self.params} # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxBloomAttention module - # TODO: check with patrick - # if isinstance(past_key_values, jnp.ndarray): - # inputs["cache"] = past_key_values - # mutable = ["cache"] - # else: - # mutable = False - mutable = False + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + outputs = self.module.apply( inputs, jnp.array(input_ids, dtype="i4"), From d1329c97a762e8b0bf0239e5e6390856f81f0ac4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 6 Jul 2022 14:30:58 +0000 Subject: [PATCH 27/81] clean code a bit --- .../models/bloom/modeling_flax_bloom.py | 37 +++---------------- 1 file changed, 6 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 5f1c4694c239..0799bfd37645 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -228,7 +228,6 @@ def __call__( alibi, layer_past=None, attention_mask=None, - head_mask=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, @@ -268,19 +267,11 @@ def __call__( if not deterministic and self.config.attention_dropout > 0.0: dropout_rng = self.make_rng("dropout") - # print("layer") - # - # if init_cache: - # print("init_cache") - # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.has_variable("cache", "cached_key") or init_cache: key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - # if hidden_states.shape[:2] != (1, 1) and hidden_states.shape[1] < 10: - # import ipdb; ipdb.set_trace() - # transform boolean mask into float mask mask_value = jnp.finfo(self.dtype).min attention_bias = lax.select( @@ -304,9 +295,6 @@ def __call__( precision=None, ) - if head_mask is not None: - attn_weights = attn_weights * head_mask - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) attn_output = self.dense(attn_output) @@ -389,11 +377,9 @@ def __call__( attention_mask=None, layer_number: int = None, layer_past=None, - head_mask=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, - use_cache: bool = False, ): if self.use_scan: hidden_states = hidden_states[0] @@ -412,7 +398,6 @@ def __call__( alibi=alibi, layer_past=layer_past, attention_mask=attention_mask, - head_mask=head_mask, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, @@ -433,10 +418,11 @@ def __call__( output = self.mlp(post_layernorm, residual, deterministic=deterministic) - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] + # if use_cache: + # outputs = (output,) + outputs + # else: + + outputs = (output,) + outputs[1:] if self.use_scan: outputs = (outputs, None) @@ -516,8 +502,6 @@ def __call__( input_ids, attention_mask=None, past_key_values: dict = None, - head_mask=None, - inputs_embeds=None, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, @@ -673,10 +657,6 @@ def __call__( self, input_ids=None, attention_mask=None, - past_key_values=None, - head_mask=None, - inputs_embeds=None, - use_cache=None, deterministic=True, init_cache: bool = False, output_attentions: bool = False, @@ -693,10 +673,6 @@ def __call__( # TODO put repeat 
here alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) - past_key_values = () if use_cache else None # TODO: come back to this line - # TODO: how to handle alibi? build alibi tensor here? - - # TODO: fix inputs to this (and args to submodules in general) # TODO: gradient checkpointing outputs = self.h( hidden_states, @@ -719,11 +695,10 @@ def __call__( if not return_dict: # TODO: don't think this return value / ordering is correct - return tuple(v for v in [outputs[0], past_key_values, outputs[-1]] if v is not None) + return tuple(v for v in [outputs[0], outputs[-1]] if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=past_key_values, hidden_states=outputs[1], attentions=outputs[-1], ) From 27b4bdb365d6fc728f430286fa6f42729173f01d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 6 Jul 2022 17:04:41 +0200 Subject: [PATCH 28/81] added small tests --- .../models/bloom/test_modeling_flax_bloom.py | 47 ++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 2fcb4af60080..2442fffbfe41 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -16,7 +16,8 @@ import numpy as np import timeout_decorator # noqa -from transformers import BloomConfig, is_flax_available +from transformers import BloomConfig, is_flax_available, BloomTokenizerFast +from transformers.utils.import_utils import is_torch_cuda_available from transformers.testing_utils import require_flax, slow from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin @@ -34,6 +35,7 @@ import jax import jax.numpy as jnp from transformers import FlaxBloomForCausalLM, FlaxBloomModel + from transformers.models.bloom.modeling_flax_bloom import build_alibi_tensor_flax def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None): @@ -172,7 +174,6 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") - @require_flax class FlaxBloomModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin): all_model_classes = (FlaxBloomModel, FlaxBloomForCausalLM) if is_flax_available() else () @@ -198,3 +199,45 @@ def test_model_from_pretrained(self): input_ids = np.ones((1, 1)) * model.config.eos_token_id outputs = model(input_ids) self.assertIsNotNone(outputs) + + +#@slow +@require_flax +class FlaxBloomGenerationTest(unittest.TestCase): + all_model_classes = (FlaxBloomForCausalLM) if is_flax_available() else () + all_generative_model_classes = () if is_flax_available() else () + + def setUp(self): + self.model_id = "bigscience/bloom-350m" + self.tokenizer = BloomTokenizerFast.from_pretrained(self.model_id, padding_side="left") + self.model_tester = FlaxBloomModelTester(self) + self.model = FlaxBloomForCausalLM.from_pretrained(self.model_id, from_pt=True) + + def test_model_batched_gen(self): + # tests if the model outputs the same generation for the same batched input + input_sentences = ["Hello there is this string is definitely longer I believe that", "Hello there is this string is definitely longer I believe that"] + inputs = self.tokenizer(input_sentences, return_tensors="np", padding=True, truncation=True) + 
sequences_fx = self.model.generate(**inputs, max_length=20).sequences + self.assertEqual(sequences_fx[0], sequences_fx[1]) + + def test_model_batched_padding_left(self): + # tests if the model outputs the same generation for an input that is part of a batch + # and a single input + input_sentences_batch = ["Hello there is this string is definitely longer I believe that", "Hi I want to order"] + inputs = self.tokenizer(input_sentences_batch, return_tensors="np", padding=True, truncation=True) + sequences_fx_batch = self.model.generate(**inputs, max_length=20).sequences + + input_sentence_simple = "Hi I want to order" + inputs_simple = self.tokenizer(input_sentence_simple, return_tensors="np") + sequences_fx_simple = self.model.generate(**inputs_simple, max_length=20).sequences + + self.assertEqual(sequences_fx_batch[1][6:].tolist(), sequences_fx_simple[0][:-6].tolist()) + +@require_flax +class FlaxBloomConversionTest(unittest.TestCase): + def setup(self): + self.model_tester = FlaxBloomConversionTest(self) + + def test_alibi_padding(self): + # example_attn_mask = + pass From b61bbad36769f8aaf5a8ad5cecc0f86ff3fea78b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 6 Jul 2022 19:25:39 +0200 Subject: [PATCH 29/81] adding albii test --- .../models/bloom/test_modeling_flax_bloom.py | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 2442fffbfe41..217f661c584c 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -18,7 +18,7 @@ from transformers import BloomConfig, is_flax_available, BloomTokenizerFast from transformers.utils.import_utils import is_torch_cuda_available -from transformers.testing_utils import require_flax, slow +from transformers.testing_utils import require_flax, slow, is_pt_flax_cross_test, require_torch from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor @@ -35,7 +35,12 @@ import jax import jax.numpy as jnp from transformers import FlaxBloomForCausalLM, FlaxBloomModel + + +if is_pt_flax_cross_test: + from transformers.models.bloom.modeling_flax_bloom import build_alibi_tensor_flax + from transformers.models.bloom.modeling_bloom import build_alibi_tensor def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None): @@ -233,11 +238,32 @@ def test_model_batched_padding_left(self): self.assertEqual(sequences_fx_batch[1][6:].tolist(), sequences_fx_simple[0][:-6].tolist()) -@require_flax +@require_torch +@is_pt_flax_cross_test class FlaxBloomConversionTest(unittest.TestCase): def setup(self): + self.n_head = 16 self.model_tester = FlaxBloomConversionTest(self) + def test_flax_torch_alibi(self): + import torch + dtype = jnp.float16 + single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) + seq_len = single_attention_mask.shape[-1] + + alibi = build_alibi_tensor(seq_len, self.n_head, torch.float16) + alibi_flax = build_alibi_tensor_flax(single_attention_mask, self.n_head, dtype)[0] + + self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) + + def test_alibi_padding(self): - # example_attn_mask = - pass + dtype = jnp.bfloat16 + + batch_attention_mask = jnp.array([[1, 1, 1, 1, 1], [0, 0, 0, 1, 1]]) + single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) + + alibi_padd = build_alibi_tensor_flax(batch_attention_mask, self.n_head, dtype) + alibi_simple = 
build_alibi_tensor_flax(single_attention_mask, self.n_head, dtype) + + self.assertTrue(jnp.equal(alibi_simple[:, :, :, :2], alibi_padd[1][:, :, 3:]).all()) From 734884b5f64356fc49344d587e537614e9c49eb2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 14:50:03 +0200 Subject: [PATCH 30/81] make CI tests pass: - change init weight - add correct tuple for output attention - add scan test - make CI tests work --- docs/source/de/index.md | 2 +- .../models/auto/modeling_flax_auto.py | 2 + .../models/bloom/modeling_flax_bloom.py | 7 +- src/transformers/utils/dummy_flax_objects.py | 21 ++++++ .../models/bloom/test_modeling_flax_bloom.py | 64 +++++++++++-------- 5 files changed, 65 insertions(+), 31 deletions(-) diff --git a/docs/source/de/index.md b/docs/source/de/index.md index b2661345adff..4742a99f643c 100644 --- a/docs/source/de/index.md +++ b/docs/source/de/index.md @@ -218,7 +218,7 @@ Flax), PyTorch, und/oder TensorFlow haben. | BigBird-Pegasus | ❌ | ❌ | ✅ | ❌ | ❌ | | Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ | | BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ | -| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ | +| BLOOM | ❌ | ✅ | ✅ | ❌ | ✅ | | CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | CANINE | ✅ | ❌ | ✅ | ❌ | ❌ | | CLIP | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 44ef84448119..ebc768963429 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -35,6 +35,7 @@ ("big_bird", "FlaxBigBirdModel"), ("blenderbot", "FlaxBlenderbotModel"), ("blenderbot-small", "FlaxBlenderbotSmallModel"), + ("bloom", "FlaxBloomModel"), ("clip", "FlaxCLIPModel"), ("distilbert", "FlaxDistilBertModel"), ("electra", "FlaxElectraModel"), @@ -139,6 +140,7 @@ ("bart", "FlaxBartForCausalLM"), ("bert", "FlaxBertForCausalLM"), ("big_bird", "FlaxBigBirdForCausalLM"), + ("bloom", "FlaxBloomForCausalLM"), ("electra", "FlaxElectraForCausalLM"), ("gpt-sw3", "FlaxGPT2LMHeadModel"), ("gpt2", "FlaxGPT2LMHeadModel"), diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 0799bfd37645..63b9861d0c47 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -422,7 +422,7 @@ def __call__( # outputs = (output,) + outputs # else: - outputs = (output,) + outputs[1:] + outputs = (output,) + outputs if self.use_scan: outputs = (outputs, None) @@ -461,11 +461,10 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] + random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) @@ -718,7 +717,6 @@ class FlaxBloomModel(FlaxBloomPreTrainedModel): ) -# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLMModule with GPTNeo->Bloom class FlaxBloomForCausalLMModule(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 @@ -774,7 +772,6 @@ def __call__( """, BLOOM_START_DOCSTRING, ) -# Copied 
from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM with GPTNeo->Bloom class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel): module_class = FlaxBloomForCausalLMModule diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index cc3f48cb25e2..78be4ef747e9 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -520,6 +520,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxBloomForCausalLM(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxBloomModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxBloomPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxCLIPModel(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 217f661c584c..78a5d6793940 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -13,12 +13,10 @@ # limitations under the License. import unittest -import numpy as np -import timeout_decorator # noqa +import numpy as np # noqa -from transformers import BloomConfig, is_flax_available, BloomTokenizerFast -from transformers.utils.import_utils import is_torch_cuda_available -from transformers.testing_utils import require_flax, slow, is_pt_flax_cross_test, require_torch +from transformers import BloomConfig, BloomTokenizerFast, is_flax_available +from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_torch, slow from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor @@ -32,15 +30,14 @@ # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" - import jax import jax.numpy as jnp from transformers import FlaxBloomForCausalLM, FlaxBloomModel - + if is_pt_flax_cross_test: - from transformers.models.bloom.modeling_flax_bloom import build_alibi_tensor_flax from transformers.models.bloom.modeling_bloom import build_alibi_tensor + from transformers.models.bloom.modeling_flax_bloom import build_alibi_tensor_flax def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None): @@ -82,7 +79,7 @@ def __init__( self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = n_layer - self.n_head = n_head + self.num_attention_heads = n_head self.hidden_act = hidden_act self.hidden_dropout = hidden_dropout self.attention_probs_dropout_prob = attention_probs_dropout_prob @@ -101,7 +98,7 @@ def prepare_config_and_inputs(self): vocab_size=self.vocab_size, hidden_size=self.hidden_size, n_layer=self.num_hidden_layers, - n_head=self.n_head, + n_head=self.num_attention_heads, hidden_dropout=self.hidden_dropout, attention_dropout=self.attention_probs_dropout_prob, eos_token_id=self.eos_token_id, @@ -179,6 +176,7 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + 
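The cache checks in the tester above reduce to one invariant: a single uncached forward pass and a step-wise pass driven by init_cache / past_key_values must agree at the final position. A condensed, hedged sketch of that invariant follows; the model argument is assumed to be any FlaxBloomForCausalLM instance, and the 1e-3 tolerance is borrowed from the tester rather than being a hard requirement.

import jax.numpy as jnp
import numpy as np


def check_cached_forward_matches_full(model, input_ids, max_length):
    # reference: one uncached pass over the whole prompt
    full_logits = model(input_ids).logits

    # cached: all but the last token first, then the last token alone, reusing the cache
    batch_size = input_ids.shape[0]
    attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    cache = model.init_cache(batch_size, max_length)

    step_1 = model(input_ids[:, :-1], attention_mask=attention_mask, past_key_values=cache)
    step_2 = model(
        input_ids[:, -1:], attention_mask=attention_mask, past_key_values=step_1.past_key_values
    )

    # the logits for the final prompt token should match between the two code paths
    diff = np.max(np.abs(step_2.logits[:, -1] - full_logits[:, -1]))
    assert diff < 1e-3, f"Max diff is {diff}"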
@require_flax class FlaxBloomModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin): all_model_classes = (FlaxBloomModel, FlaxBloomForCausalLM) if is_flax_available() else () @@ -204,9 +202,9 @@ def test_model_from_pretrained(self): input_ids = np.ones((1, 1)) * model.config.eos_token_id outputs = model(input_ids) self.assertIsNotNone(outputs) - -#@slow + +@slow @require_flax class FlaxBloomGenerationTest(unittest.TestCase): all_model_classes = (FlaxBloomForCausalLM) if is_flax_available() else () @@ -220,15 +218,21 @@ def setUp(self): def test_model_batched_gen(self): # tests if the model outputs the same generation for the same batched input - input_sentences = ["Hello there is this string is definitely longer I believe that", "Hello there is this string is definitely longer I believe that"] + input_sentences = [ + "Hello there is this string is definitely longer I believe that", + "Hello there is this string is definitely longer I believe that", + ] inputs = self.tokenizer(input_sentences, return_tensors="np", padding=True, truncation=True) sequences_fx = self.model.generate(**inputs, max_length=20).sequences - self.assertEqual(sequences_fx[0], sequences_fx[1]) - + self.assertEqual(sequences_fx[0].tolist(), sequences_fx[1].tolist()) + def test_model_batched_padding_left(self): # tests if the model outputs the same generation for an input that is part of a batch # and a single input - input_sentences_batch = ["Hello there is this string is definitely longer I believe that", "Hi I want to order"] + input_sentences_batch = [ + "Hello there is this string is definitely longer I believe that", + "Hi I want to order", + ] inputs = self.tokenizer(input_sentences_batch, return_tensors="np", padding=True, truncation=True) sequences_fx_batch = self.model.generate(**inputs, max_length=20).sequences @@ -238,32 +242,42 @@ def test_model_batched_padding_left(self): self.assertEqual(sequences_fx_batch[1][6:].tolist(), sequences_fx_simple[0][:-6].tolist()) + def test_scan_model(self): + scan_model = FlaxBloomForCausalLM.from_pretrained("sanchit-gandhi/bloom-350m-scan", use_scan=True) + input_ids = np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32) + + unrolled_logits = self.model(input_ids).logits + scan_logits = scan_model(input_ids).logits + + self.assertTrue(np.max(np.abs(unrolled_logits - scan_logits)) <= 1e-3) + + @require_torch @is_pt_flax_cross_test class FlaxBloomConversionTest(unittest.TestCase): def setup(self): - self.n_head = 16 + self.num_attention_heads = 16 self.model_tester = FlaxBloomConversionTest(self) - + def test_flax_torch_alibi(self): import torch + dtype = jnp.float16 single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) seq_len = single_attention_mask.shape[-1] - alibi = build_alibi_tensor(seq_len, self.n_head, torch.float16) - alibi_flax = build_alibi_tensor_flax(single_attention_mask, self.n_head, dtype)[0] + alibi = build_alibi_tensor(seq_len, self.num_attention_heads, torch.float16) + alibi_flax = build_alibi_tensor_flax(single_attention_mask, self.num_attention_heads, dtype)[0] self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) - def test_alibi_padding(self): dtype = jnp.bfloat16 batch_attention_mask = jnp.array([[1, 1, 1, 1, 1], [0, 0, 0, 1, 1]]) single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) - - alibi_padd = build_alibi_tensor_flax(batch_attention_mask, self.n_head, dtype) - alibi_simple = build_alibi_tensor_flax(single_attention_mask, self.n_head, dtype) - + + alibi_padd = build_alibi_tensor_flax(batch_attention_mask, 
self.num_attention_heads, dtype) + alibi_simple = build_alibi_tensor_flax(single_attention_mask, self.num_attention_heads, dtype) + self.assertTrue(jnp.equal(alibi_simple[:, :, :, :2], alibi_padd[1][:, :, 3:]).all()) From b717b39a9f9f4f2d93c8d6c581664e531f6d7ac1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 15:05:43 +0200 Subject: [PATCH 31/81] fix few nits --- .../models/bloom/modeling_flax_bloom.py | 2 +- .../models/bloom/test_modeling_flax_bloom.py | 20 +++++++++---------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 63b9861d0c47..ad331dc15fa7 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -146,7 +146,7 @@ def get_slopes_power_of_2(n): num_heads = n_head query_length = 1 - slopes = jnp.array(get_slopes(n_head))[None, :, None, None] + slopes = jnp.array(get_slopes(n_head))[None, :, None, None].astype(dtype) arange_tensor = attention_mask.cumsum(-1, dtype=dtype)[:, None, None, :] - 1 slopes_broadcasted = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length)) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 78a5d6793940..2618cd9261f1 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -17,6 +17,7 @@ from transformers import BloomConfig, BloomTokenizerFast, is_flax_available from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_torch, slow +from transformers.utils.import_utils import is_torch_available from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor @@ -33,9 +34,7 @@ import jax.numpy as jnp from transformers import FlaxBloomForCausalLM, FlaxBloomModel - -if is_pt_flax_cross_test: - +if is_flax_available() and is_torch_available(): from transformers.models.bloom.modeling_bloom import build_alibi_tensor from transformers.models.bloom.modeling_flax_bloom import build_alibi_tensor_flax @@ -255,19 +254,16 @@ def test_scan_model(self): @require_torch @is_pt_flax_cross_test class FlaxBloomConversionTest(unittest.TestCase): - def setup(self): - self.num_attention_heads = 16 - self.model_tester = FlaxBloomConversionTest(self) - def test_flax_torch_alibi(self): import torch dtype = jnp.float16 single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) seq_len = single_attention_mask.shape[-1] + num_attention_heads = 16 - alibi = build_alibi_tensor(seq_len, self.num_attention_heads, torch.float16) - alibi_flax = build_alibi_tensor_flax(single_attention_mask, self.num_attention_heads, dtype)[0] + alibi = build_alibi_tensor(seq_len, num_attention_heads, torch.float16) + alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype)[0] self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) @@ -276,8 +272,10 @@ def test_alibi_padding(self): batch_attention_mask = jnp.array([[1, 1, 1, 1, 1], [0, 0, 0, 1, 1]]) single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) + num_attention_heads = 16 + - alibi_padd = build_alibi_tensor_flax(batch_attention_mask, self.num_attention_heads, dtype) - alibi_simple = build_alibi_tensor_flax(single_attention_mask, self.num_attention_heads, dtype) + alibi_padd = build_alibi_tensor_flax(batch_attention_mask, num_attention_heads, dtype) + 
alibi_simple = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype) self.assertTrue(jnp.equal(alibi_simple[:, :, :, :2], alibi_padd[1][:, :, 3:]).all()) From fdf439250be633f50a32e796ca24b30edfa04f12 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 15:08:33 +0200 Subject: [PATCH 32/81] fix nit onnx --- src/transformers/models/bloom/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/__init__.py b/src/transformers/models/bloom/__init__.py index 73c4410371b3..27e914aea306 100644 --- a/src/transformers/models/bloom/__init__.py +++ b/src/transformers/models/bloom/__init__.py @@ -27,6 +27,7 @@ "configuration_bloom": [ "BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", + "BloomOnnxConfig", ], } try: @@ -67,7 +68,7 @@ if TYPE_CHECKING: - from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig + from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig try: if not is_tokenizers_available(): From 7dd7e64b7e15938eeab8cfed4d9311d9e83a58b5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 15:08:49 +0200 Subject: [PATCH 33/81] fix onnx nit --- tests/models/bloom/test_modeling_flax_bloom.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 2618cd9261f1..08a2660d1fea 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -274,7 +274,6 @@ def test_alibi_padding(self): single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) num_attention_heads = 16 - alibi_padd = build_alibi_tensor_flax(batch_attention_mask, num_attention_heads, dtype) alibi_simple = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype) From 46a685da3b7ffc85f590bdf20f2964cce34c5f40 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 7 Jul 2022 18:00:27 +0100 Subject: [PATCH 34/81] add missing dtype args to nn.Modules --- src/transformers/models/bloom/modeling_flax_bloom.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index ad331dc15fa7..87852cfaa977 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -492,6 +492,10 @@ def init_cache(self, batch_size, max_length): init_variables = self.module.init( jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True ) + import ipdb + + ipdb.set_trace() + return unfreeze(init_variables["cache"]) # TODO: check whether this is correct (position ids might not be required) @@ -640,9 +644,11 @@ def setup(self): self.config.vocab_size, self.embed_dim, embedding_init=embedding_init, + dtype=self.dtype, ) + # post-embedding layernorm - self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon) + self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) # transformer layers self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype, use_scan=self.use_scan) From f34c854b884b8d552230fe61a1c33e6ba0584501 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 7 Jul 2022 20:03:38 +0100 Subject: [PATCH 35/81] remove debugging statements --- src/transformers/models/bloom/modeling_flax_bloom.py | 4 ---- 1 file changed, 4 deletions(-) diff --git 
a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 87852cfaa977..fbd15b67c874 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -492,10 +492,6 @@ def init_cache(self, batch_size, max_length): init_variables = self.module.init( jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True ) - import ipdb - - ipdb.set_trace() - return unfreeze(init_variables["cache"]) # TODO: check whether this is correct (position ids might not be required) From b83009eadaa9852a0de412963a5a85e24bfb4f28 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Fri, 8 Jul 2022 09:40:50 +0100 Subject: [PATCH 36/81] fix scan generate --- src/transformers/models/bloom/modeling_flax_bloom.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index fbd15b67c874..9214f1ec93f7 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -588,13 +588,16 @@ def __call__( FlaxBloomBlock, variable_axes={"params": 0, "cache": 0}, split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, 0), + in_axes=(nn.broadcast, nn.broadcast, 0, nn.broadcast, nn.broadcast, nn.broadcast), length=self.config.num_hidden_layers, )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( hidden_states, alibi, attention_mask, # kwargs not supported by scan jnp.arange(self.config.num_hidden_layers), + None, + deterministic, + init_cache, ) hidden_states = hidden_states[0] From 362ff0350cf116fa03eaa5ce09439dee8cdd04e5 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 18 Jul 2022 12:19:20 +0200 Subject: [PATCH 37/81] Update modeling_flax_bloom.py --- src/transformers/models/bloom/modeling_flax_bloom.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 9214f1ec93f7..db767decccd6 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -164,7 +164,6 @@ def setup(self): self.hidden_size = self.config.hidden_size self.num_heads = self.config.n_head self.head_dim = self.hidden_size // self.num_heads - self.attention_softmax_in_fp32 = self.config.attention_softmax_in_fp32 if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( From 5652b057a79a2d448fa391f1fd0c3b27ade80569 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 18 Jul 2022 12:28:57 +0200 Subject: [PATCH 38/81] Update test_modeling_flax_bloom.py --- tests/models/bloom/test_modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 08a2660d1fea..0615536bb0e9 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -262,7 +262,7 @@ def test_flax_torch_alibi(self): seq_len = single_attention_mask.shape[-1] num_attention_heads = 16 - alibi = build_alibi_tensor(seq_len, num_attention_heads, torch.float16) + alibi = build_alibi_tensor(seq_len, num_attention_heads, torch.float16, device=torch.device("cpu")) 
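The in_axes change a few hunks above is easier to follow with the scan mechanics spelled out in isolation. Below is a minimal, generic sketch of nn.scan with nn.broadcast; the Block and Stack modules, feature sizes and layer count are invented for illustration and are not the model's real modules.

import flax.linen as nn
import jax
import jax.numpy as jnp


class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, bias):
        # `x` is the scan carry; `bias` arrives via `nn.broadcast`, i.e. shared by every layer
        return nn.Dense(self.features)(x) + bias, None


class Stack(nn.Module):
    features: int
    num_layers: int

    @nn.compact
    def __call__(self, x, bias):
        # stack `num_layers` copies of Block, with per-layer params stored along axis 0
        ScannedBlock = nn.scan(
            Block,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=(nn.broadcast,),
            length=self.num_layers,
        )
        x, _ = ScannedBlock(self.features)(x, bias)
        return x


x = jnp.ones((2, 3, 8))
bias = jnp.zeros((8,))
params = Stack(features=8, num_layers=4).init(jax.random.PRNGKey(0), x, bias)
y = Stack(features=8, num_layers=4).apply(params, x, bias)

Arguments marked nn.broadcast are passed unchanged to every scanned layer, while arguments given an integer axis (such as the jnp.arange(num_hidden_layers) layer indices in the patch) are sliced one element per layer, which is what the widened in_axes tuple encodes.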
alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype)[0] self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) From 9e95e315aadb4cd4501b1949dbaee7b826af7879 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 18 Jul 2022 14:30:28 +0200 Subject: [PATCH 39/81] Update test_modeling_flax_bloom.py --- tests/models/bloom/test_modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 0615536bb0e9..9d833e8997d0 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -262,7 +262,7 @@ def test_flax_torch_alibi(self): seq_len = single_attention_mask.shape[-1] num_attention_heads = 16 - alibi = build_alibi_tensor(seq_len, num_attention_heads, torch.float16, device=torch.device("cpu")) + alibi = build_alibi_tensor(single_attention_mask, num_attention_heads, torch.float16, device=torch.device("cpu")) alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype)[0] self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) From 627e8e4b67c372ddba29607493eded586577293a Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 18 Jul 2022 16:15:32 +0200 Subject: [PATCH 40/81] Update test_modeling_flax_bloom.py --- tests/models/bloom/test_modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 9d833e8997d0..c7b6135b8b49 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -262,7 +262,7 @@ def test_flax_torch_alibi(self): seq_len = single_attention_mask.shape[-1] num_attention_heads = 16 - alibi = build_alibi_tensor(single_attention_mask, num_attention_heads, torch.float16, device=torch.device("cpu")) + alibi = build_alibi_tensor(torch.from_numpy(single_attention_mask), num_attention_heads, torch.float16, device=torch.device("cpu")) alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype)[0] self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) From 77dfc0a3e861d3a03de014bb8cdee7e9a019119f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 18 Jul 2022 20:12:39 +0200 Subject: [PATCH 41/81] fix small test issue + make style --- tests/models/bloom/test_modeling_flax_bloom.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index c7b6135b8b49..efd99e495353 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -258,11 +258,12 @@ def test_flax_torch_alibi(self): import torch dtype = jnp.float16 - single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) - seq_len = single_attention_mask.shape[-1] + single_attention_mask = np.array([[1, 1, 1, 1, 1]]) num_attention_heads = 16 - alibi = build_alibi_tensor(torch.from_numpy(single_attention_mask), num_attention_heads, torch.float16, device=torch.device("cpu")) + alibi = build_alibi_tensor( + torch.from_numpy(single_attention_mask), num_attention_heads, torch.float16, device=torch.device("cpu") + ) alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype)[0] 
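The padding test above hinges on how the mask-based construction derives position offsets: cumsum(-1) - 1 over the attention mask gives the real tokens of a left-padded row the same offsets as the start of an unpadded row. A tiny worked sketch, with the mask values copied from the test:

import jax.numpy as jnp

batch_mask = jnp.array([[1, 1, 1, 1, 1], [0, 0, 0, 1, 1]])
offsets = batch_mask.cumsum(-1) - 1
# row 0 -> [ 0,  1,  2,  3,  4]
# row 1 -> [-1, -1, -1,  0,  1]: the two real tokens of the left-padded row reuse offsets 0 and 1,
# which is exactly the slice comparison alibi_simple[:, :, :, :2] == alibi_padd[1][:, :, 3:]
assert jnp.array_equal(offsets[1, 3:], offsets[0, :2])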
self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) From 70b57efc9d1a78da561490009a8364ccd4d5c68f Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 27 Jul 2022 17:40:06 +0100 Subject: [PATCH 42/81] clean up --- .../models/bloom/modeling_flax_bloom.py | 64 +++++-------------- 1 file changed, 15 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index db767decccd6..e58c211ea5c1 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -12,13 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Flax BLOOM model.""" -# TODO: see todos throughout this file -# TODO: check correctness against pytorch implementation -# TODO: add unit tests -# TODO: add documentation / check that documentation is correct -# TODO: BLOOM_INPUTS_DOCSTRING might be wrong still (position_ids) -# TODO: check that code is jit-able +"""Flax BLOOM model.""" import math from functools import partial @@ -102,9 +96,6 @@ - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. 
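Tying the attention_mask and past_key_values inputs described just above back to the tests earlier in this series, the end-to-end usage is a left-padded batch fed to generate, which builds the cache and the extended attention mask internally. A short sketch; the checkpoint name and prompts are the ones used by the tests, and the weights are loaded from PyTorch since no Flax weights are assumed to exist yet.

from transformers import BloomTokenizerFast, FlaxBloomForCausalLM

tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-350m", padding_side="left")
model = FlaxBloomForCausalLM.from_pretrained("bigscience/bloom-350m", from_pt=True)

prompts = ["Hello there is this string is definitely longer I believe that", "Hi I want to order"]
inputs = tokenizer(prompts, return_tensors="np", padding=True)

# greedy generation; `sequences` holds the padded prompt followed by the generated ids
sequences = model.generate(**inputs, max_length=20).sequences
print(tokenizer.batch_decode(sequences, skip_special_tokens=True))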
@@ -135,10 +126,10 @@ def get_slopes_power_of_2(n): + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] ) - # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # Note: alibi will be added to the attention bias that is applied to the query, key product of attention # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) - # => the query_length dimension will then be broadcasted correctly + # => the query_length dimension will then be broadcast correctly # This is more or less identical to T5's relative position bias: # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 # batch_size = 1, n_head = n_head, query_length @@ -149,10 +140,10 @@ def get_slopes_power_of_2(n): slopes = jnp.array(get_slopes(n_head))[None, :, None, None].astype(dtype) arange_tensor = attention_mask.cumsum(-1, dtype=dtype)[:, None, None, :] - 1 - slopes_broadcasted = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length)) - arange_broadcasted = jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) + slopes_broadcast = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length)) + arange_broadcast = jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) - alibi = slopes_broadcasted * arange_broadcasted + alibi = slopes_broadcast * arange_broadcast return alibi @@ -182,7 +173,7 @@ def setup(self): self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout) def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, 3 * self.head_dim)) + return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3)) def _merge_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) @@ -281,7 +272,7 @@ def __call__( attention_bias = attention_bias + alibi - # TODO: override softmax precision to fp32 if self.attention_softmax_in_fp32=True and self.dtype != fp32 + # TODO(sanchit-gandhi): override softmax precision to fp32 if self.attention_softmax_in_fp32=True and self.dtype != fp32 # usual dot product attention attn_weights = dot_product_attention_weights( query, @@ -334,18 +325,7 @@ def __call__(self, hidden_states, residual, deterministic: bool = True): hidden_states = self.dense_h_to_4h(hidden_states) hidden_states = self.act(hidden_states) - # TODO: this code block is from the pytorch implementation. needs changing to work. 
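As a compact view of the slope and bias layout described in the ALiBi comments a few lines above, here is a hand-rolled sketch that assumes the head count is a power of two; the n_head=8 and seq_len=4 values are hypothetical and only for illustration.

import math

import jax.numpy as jnp


def alibi_sketch(attention_mask, n_head, dtype=jnp.float32):
    # geometric slopes, one per head: for 8 heads this is 1/2, 1/4, ..., 1/256
    start = 2 ** (-(2 ** -(math.log2(n_head) - 3)))
    slopes = jnp.array([start * start**i for i in range(n_head)], dtype=dtype)
    # mask-derived token offsets, so the bias follows real tokens rather than raw positions
    offsets = (attention_mask.cumsum(-1) - 1).astype(dtype)
    # (batch_size, num_heads, query_length=1, key_length), the layout named in the comments above
    return slopes[None, :, None, None] * offsets[:, None, None, :]


print(alibi_sketch(jnp.ones((1, 4), dtype="i4"), 8).shape)  # (1, 8, 1, 4)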
- # if self.pretraining_tp > 1 and self.slow_but_exact: - if False: - intermediate_output = jnp.zeros_like(residual) - slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp - for i in range(self.pretraining_tp): - intermediate_output = intermediate_output + nn.functional.linear( - hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], - ) - else: - intermediate_output = self.dense_4h_to_h(hidden_states) + intermediate_output = self.dense_4h_to_h(hidden_states) intermediate_output = intermediate_output + residual hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic) @@ -417,10 +397,6 @@ def __call__( output = self.mlp(post_layernorm, residual, deterministic=deterministic) - # if use_cache: - # outputs = (output,) + outputs - # else: - outputs = (output,) + outputs if self.use_scan: @@ -429,10 +405,6 @@ def __call__( return outputs -# TODO: does this still require position_ids? -# TODO: gradient checkpointing -# TODO: _no_split_modules? -# TODO: check initialization class FlaxBloomPreTrainedModel(FlaxPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -493,7 +465,6 @@ def init_cache(self, batch_size, max_length): ) return unfreeze(init_variables["cache"]) - # TODO: check whether this is correct (position ids might not be required) @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) def __call__( self, @@ -525,7 +496,8 @@ def __call__( inputs = {"params": params or self.params} - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxBloomAttention module + # If past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
+ # It has to be made sure that cache is marked as mutable so that it can be changed by FlaxBloomAttention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] @@ -562,7 +534,7 @@ class FlaxBloomBlockCollection(nn.Module): dtype: jnp.dtype = jnp.float32 use_scan: bool = False - # TODO (SG): re-write as a `setup` to conform to Transformers JAX/Flax conventions -> awaiting CG response on G Chat + # TODO(sanchit-gandhi): re-write as a `setup` to conform to Transformers JAX/Flax conventions @nn.compact def __call__( self, @@ -632,12 +604,11 @@ class FlaxBloomModule(nn.Module): use_scan: bool = False def setup(self): - # TODO: check initialization correctness self.embed_dim = self.config.hidden_size embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) - # word embeddings (no positional embedding layer) TODO: confirm this statement correct + # word embeddings (no positional embedding layer) self.word_embeddings = nn.Embed( self.config.vocab_size, self.embed_dim, @@ -653,8 +624,6 @@ def setup(self): # final layernorm self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - # TODO: change how gradient checkpointing is done - self.gradient_checkpointing = False def __call__( self, @@ -676,7 +645,6 @@ def __call__( # TODO put repeat here alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) - # TODO: gradient checkpointing outputs = self.h( hidden_states, alibi=alibi, @@ -697,7 +665,6 @@ def __call__( outputs = (hidden_states,) + outputs[1:] if not return_dict: - # TODO: don't think this return value / ordering is correct return tuple(v for v in [outputs[0], outputs[-1]] if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( @@ -779,15 +746,14 @@ def __call__( class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel): module_class = FlaxBloomForCausalLMModule - # TODO: check if this class is correct / take out position ids def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): # initializing the cache batch_size, seq_length = input_ids.shape past_key_values = self.init_cache(batch_size, max_length) # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since Bloom uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation + # But since Bloom uses a causal mask, those positions are masked anyway. 
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) From e1012632b5bcb69f375700158a5498e9c40a35a7 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 27 Jul 2022 18:53:43 +0200 Subject: [PATCH 43/81] Update tests/models/bloom/test_modeling_flax_bloom.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- tests/models/bloom/test_modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index efd99e495353..55c810e0b467 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -39,7 +39,7 @@ from transformers.models.bloom.modeling_flax_bloom import build_alibi_tensor_flax -def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None): +def prepare_bloom_inputs_dict(config, input_ids, attention_mask=None): if attention_mask is None: attention_mask = np.where(input_ids != config.pad_token_id, 1, 0) return { From 223c0622e3aa3fa282258157b57d17908629e8fa Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 27 Jul 2022 18:54:26 +0200 Subject: [PATCH 44/81] fix function name --- tests/models/bloom/test_modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 55c810e0b467..b7b11222d712 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -106,7 +106,7 @@ def prepare_config_and_inputs(self): is_encoder_decoder=False, use_cache=False, ) - inputs_dict = prepare_opt_inputs_dict(config, input_ids) + inputs_dict = prepare_bloom_inputs_dict(config, input_ids) return config, inputs_dict def prepare_config_and_inputs_for_common(self): From a359fc619ccc405b7e2eb0ae0f0723c2c4e07726 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 27 Jul 2022 19:01:53 +0200 Subject: [PATCH 45/81] small fix test --- tests/models/bloom/test_modeling_flax_bloom.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index b7b11222d712..891012127b15 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -121,7 +121,6 @@ def check_use_cache_forward(self, model_class_name, config, inputs_dict): attention_mask = inputs_dict["attention_mask"] past_key_values = model.init_cache(input_ids.shape[0], max_length) - attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4") outputs_cache = model( input_ids[:, :-1], From e838fcacc02f38923b94a0235bb130bc5d92d90c Mon Sep 17 00:00:00 2001 From: haileyschoelkopf Date: Thu, 28 Jul 2022 22:44:02 +0200 Subject: [PATCH 46/81] forward contrib credits from PR17761 From 49bf9e0aca16bb0a2f4a6b34a91c2eb42e5ac272 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 29 Jul 2022 10:24:07 +0200 Subject: [PATCH 47/81] Fix failing test --- tests/models/bloom/test_modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py 
b/tests/models/bloom/test_modeling_flax_bloom.py index 891012127b15..b94798bf2919 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -118,7 +118,7 @@ def check_use_cache_forward(self, model_class_name, config, inputs_dict): model = model_class_name(config) input_ids = inputs_dict["input_ids"] - attention_mask = inputs_dict["attention_mask"] + attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4") past_key_values = model.init_cache(input_ids.shape[0], max_length) From 5b54e4c13431bdce346d6882f99586a77bae455e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 29 Jul 2022 10:27:56 +0200 Subject: [PATCH 48/81] fix small typo documentation --- src/transformers/models/bloom/modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index e58c211ea5c1..2785d13784a2 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -42,7 +42,7 @@ _CHECKPOINT_FOR_DOC = "bigscience/bloom" _CONFIG_FOR_DOC = "BloomConfig" -_TOKENIZER_FOR_DOC = "BloomTokenizer" +_TOKENIZER_FOR_DOC = "BloomTokenizerFast" BLOOM_START_DOCSTRING = r""" From bb2e0e4a732c7ba1c2afc0b83d697df166b927f6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 10 Aug 2022 14:55:28 +0200 Subject: [PATCH 49/81] fix non passing test - remove device from build alibi --- tests/models/bloom/test_modeling_flax_bloom.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index b94798bf2919..1388a710eebf 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -260,9 +260,7 @@ def test_flax_torch_alibi(self): single_attention_mask = np.array([[1, 1, 1, 1, 1]]) num_attention_heads = 16 - alibi = build_alibi_tensor( - torch.from_numpy(single_attention_mask), num_attention_heads, torch.float16, device=torch.device("cpu") - ) + alibi = build_alibi_tensor(torch.from_numpy(single_attention_mask), num_attention_heads, torch.float16) alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype)[0] self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) From 78c3121b8f9825a36e2b44805f113de3efc8ec40 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 13:47:33 +0000 Subject: [PATCH 50/81] refactor call - refactor `FlaxBloomBlockCollection` module --- .../models/bloom/modeling_flax_bloom.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 2785d13784a2..1262f0db7951 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -534,8 +534,18 @@ class FlaxBloomBlockCollection(nn.Module): dtype: jnp.dtype = jnp.float32 use_scan: bool = False + def setup(self): + self.layers = [FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False) for layer_number in range(self.config.num_hidden_layers)] + if self.use_scan: + self.scan_fn = scan_with_axes( + FlaxBloomBlock, + variable_axes={"params": 0, "cache": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=(nn.broadcast, nn.broadcast, 0, nn.broadcast, nn.broadcast, 
nn.broadcast), + length=self.config.num_hidden_layers, + )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers") + # TODO(sanchit-gandhi): re-write as a `setup` to conform to Transformers JAX/Flax conventions - @nn.compact def __call__( self, hidden_states, @@ -555,13 +565,7 @@ def __call__( # assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" hidden_states = (hidden_states,) - hidden_states, _ = scan_with_axes( - FlaxBloomBlock, - variable_axes={"params": 0, "cache": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, 0, nn.broadcast, nn.broadcast, nn.broadcast), - length=self.config.num_hidden_layers, - )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers")( + hidden_states, _ = self.scan_fn( hidden_states, alibi, attention_mask, # kwargs not supported by scan @@ -577,8 +581,7 @@ def __call__( if output_hidden_states: all_hidden_states += (hidden_states,) - - layer_outputs = FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False)( + layer_outputs = self.layers[layer_number]( hidden_states, alibi=alibi, attention_mask=attention_mask, From c23221e505270c468dcbb5a6a9b09b854bb6877e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 13:48:44 +0000 Subject: [PATCH 51/81] make style --- .../models/bloom/modeling_flax_bloom.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 1262f0db7951..12122e4a4806 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -535,15 +535,18 @@ class FlaxBloomBlockCollection(nn.Module): use_scan: bool = False def setup(self): - self.layers = [FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False) for layer_number in range(self.config.num_hidden_layers)] + self.layers = [ + FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False) + for layer_number in range(self.config.num_hidden_layers) + ] if self.use_scan: self.scan_fn = scan_with_axes( - FlaxBloomBlock, - variable_axes={"params": 0, "cache": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, 0, nn.broadcast, nn.broadcast, nn.broadcast), - length=self.config.num_hidden_layers, - )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers") + FlaxBloomBlock, + variable_axes={"params": 0, "cache": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=(nn.broadcast, nn.broadcast, 0, nn.broadcast, nn.broadcast, nn.broadcast), + length=self.config.num_hidden_layers, + )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers") # TODO(sanchit-gandhi): re-write as a `setup` to conform to Transformers JAX/Flax conventions def __call__( From a2c0e98422081fae076b47d90946c09da595143f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 14:27:52 +0000 Subject: [PATCH 52/81] upcast to fp32 --- .../models/bloom/modeling_flax_bloom.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 12122e4a4806..e76d0525c5be 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -155,6 
+155,7 @@ def setup(self): self.hidden_size = self.config.hidden_size self.num_heads = self.config.n_head self.head_dim = self.hidden_size // self.num_heads + self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( @@ -272,8 +273,12 @@ def __call__( attention_bias = attention_bias + alibi - # TODO(sanchit-gandhi): override softmax precision to fp32 if self.attention_softmax_in_fp32=True and self.dtype != fp32 - # usual dot product attention + # Cast in fp32 if the original dtype is different from fp32 + if self.attention_softmax_in_fp32: + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) + attention_bias = attention_bias.astype(jnp.float32) + attn_weights = dot_product_attention_weights( query, key, @@ -285,6 +290,12 @@ def __call__( precision=None, ) + # Cast back in the original dtype if the native dtype is not fp32 + if self.attention_softmax_in_fp32: + query = query.astype(self.dtype) + key = key.astype(self.dtype) + attention_bias = attention_bias.astype(self.dtype) + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) attn_output = self.dense(attn_output) From bdda3fa25abde7d8ab1880169a41731364049538 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 14:32:31 +0000 Subject: [PATCH 53/81] cleaner way to upcast --- src/transformers/models/bloom/modeling_flax_bloom.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index e76d0525c5be..f8b5ab7c2d25 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -274,10 +274,7 @@ def __call__( attention_bias = attention_bias + alibi # Cast in fp32 if the original dtype is different from fp32 - if self.attention_softmax_in_fp32: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) - attention_bias = attention_bias.astype(jnp.float32) + attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype attn_weights = dot_product_attention_weights( query, @@ -286,15 +283,13 @@ def __call__( dropout_rng=dropout_rng, dropout_rate=self.config.attention_dropout, deterministic=deterministic, - dtype=self.dtype, + dtype=attention_dtype, precision=None, ) # Cast back in the original dtype if the native dtype is not fp32 if self.attention_softmax_in_fp32: - query = query.astype(self.dtype) - key = key.astype(self.dtype) - attention_bias = attention_bias.astype(self.dtype) + attn_weights = attn_weights.astype(self.dtype) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) From d8ddc11cef6cbd050d46e9bee06ecccbc9cbd046 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 14:34:06 +0000 Subject: [PATCH 54/81] remove unused args --- src/transformers/models/bloom/modeling_flax_bloom.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index f8b5ab7c2d25..f122484f1586 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -317,9 +317,6 @@ class FlaxBloomMLP(nn.Module): def setup(self): hidden_size = self.config.hidden_size - self.pretraining_tp = self.config.pretraining_tp - self.slow_but_exact = 
self.config.slow_but_exact - kernel_init = jax.nn.initializers.normal(self.config.initializer_range) self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init) From c4866ecb2b8315f7714886822124c7ff3e7658b5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 14:41:08 +0000 Subject: [PATCH 55/81] remove layer number --- src/transformers/models/bloom/modeling_flax_bloom.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index f122484f1586..80024f7fbc07 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -222,7 +222,6 @@ def __call__( deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, - layer_number: int = None, ): batch_size, seq_length = hidden_states.shape[:2] @@ -357,7 +356,6 @@ def __call__( hidden_states, alibi, attention_mask=None, - layer_number: int = None, layer_past=None, deterministic: bool = True, init_cache: bool = False, @@ -383,7 +381,6 @@ def __call__( deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, - layer_number=layer_number, ) attention_output = attn_outputs[0] @@ -594,7 +591,6 @@ def __call__( deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, - layer_number=layer_number, ) hidden_states = layer_outputs[0] From 12ba14ba9332e5b3e44b9ca60a4c808503cde625 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 16:20:39 +0000 Subject: [PATCH 56/81] fix scan test --- src/transformers/models/bloom/modeling_flax_bloom.py | 3 ++- tests/models/bloom/test_modeling_flax_bloom.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 80024f7fbc07..9669d7d25ee9 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -539,6 +539,7 @@ def setup(self): FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False) for layer_number in range(self.config.num_hidden_layers) ] + if self.use_scan: self.scan_fn = scan_with_axes( FlaxBloomBlock, @@ -572,10 +573,10 @@ def __call__( hidden_states, alibi, attention_mask, # kwargs not supported by scan - jnp.arange(self.config.num_hidden_layers), None, deterministic, init_cache, + output_attentions, ) hidden_states = hidden_states[0] diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 1388a710eebf..b9e03a837e1d 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -209,10 +209,10 @@ class FlaxBloomGenerationTest(unittest.TestCase): all_generative_model_classes = () if is_flax_available() else () def setUp(self): - self.model_id = "bigscience/bloom-350m" + self.model_id = "bigscience/bloom-560m" self.tokenizer = BloomTokenizerFast.from_pretrained(self.model_id, padding_side="left") self.model_tester = FlaxBloomModelTester(self) - self.model = FlaxBloomForCausalLM.from_pretrained(self.model_id, from_pt=True) + self.model = FlaxBloomForCausalLM.from_pretrained(self.model_id, from_pt=True, revision="gs555750") def test_model_batched_gen(self): # tests if the model outputs the same generation for the same batched input From 
dfe1697cf9209931cadce1df4532fd33fdd2e1b8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 12 Sep 2022 16:24:50 +0000 Subject: [PATCH 57/81] make style --- src/transformers/models/bloom/modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 9669d7d25ee9..58766f664040 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -539,7 +539,7 @@ def setup(self): FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False) for layer_number in range(self.config.num_hidden_layers) ] - + if self.use_scan: self.scan_fn = scan_with_axes( FlaxBloomBlock, From f5870458397394ee90d230020bd69fb3f4149af0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 13 Sep 2022 08:23:10 +0000 Subject: [PATCH 58/81] fix i4 casting --- src/transformers/models/bloom/modeling_flax_bloom.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 58766f664040..a3a52ddd17f9 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -457,7 +457,7 @@ def init_cache(self, batch_size, max_length): cache. """ # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length)) + input_ids = jnp.ones((batch_size, max_length), dtype="i4") attention_mask = jnp.ones_like(input_ids) init_variables = self.module.init( @@ -641,7 +641,7 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + inputs_embeds = self.word_embeddings(input_ids) # do post-embedding layernorm hidden_states = self.word_embeddings_layernorm(inputs_embeds) From 0a7fdc4826bca91f4b34c43af585d63f672b85a6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 15:32:43 +0000 Subject: [PATCH 59/81] fix slow test --- tests/models/bloom/test_modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index b9e03a837e1d..b8033259b0d2 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -196,7 +196,7 @@ def test_use_cache_forward_with_attn_mask(self): @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: - model = model_class_name.from_pretrained("bigscience/bloom-350m") + model = model_class_name.from_pretrained("bigscience/bloom-560m") input_ids = np.ones((1, 1)) * model.config.eos_token_id outputs = model(input_ids) self.assertIsNotNone(outputs) From b49957fe4a0215facb512cf3cce43ae804b9c5d2 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 11 Oct 2022 17:00:02 +0200 Subject: [PATCH 60/81] Update src/transformers/models/bloom/modeling_flax_bloom.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- src/transformers/models/bloom/modeling_flax_bloom.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index a3a52ddd17f9..910b7cb8d41e 100644 --- 
a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -549,7 +549,6 @@ def setup(self): length=self.config.num_hidden_layers, )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers") - # TODO(sanchit-gandhi): re-write as a `setup` to conform to Transformers JAX/Flax conventions def __call__( self, hidden_states, From eb023eb5a3d57119c5ab14c4f9484af9838f6cf7 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 15:40:29 +0000 Subject: [PATCH 61/81] remove `layer_past` --- src/transformers/models/bloom/modeling_flax_bloom.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 910b7cb8d41e..b8f9aba4b3c3 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -217,7 +217,6 @@ def __call__( hidden_states, residual, alibi, - layer_past=None, attention_mask=None, deterministic: bool = True, init_cache: bool = False, @@ -356,7 +355,6 @@ def __call__( hidden_states, alibi, attention_mask=None, - layer_past=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, @@ -376,7 +374,6 @@ def __call__( layernorm_output, residual=residual, alibi=alibi, - layer_past=layer_past, attention_mask=attention_mask, deterministic=deterministic, init_cache=init_cache, From b5a1c3fff2018f886bb19a23b92e04383572be95 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 15:42:21 +0000 Subject: [PATCH 62/81] refactor a bit --- .../models/bloom/modeling_flax_bloom.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index b8f9aba4b3c3..104edbd4afdd 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -537,14 +537,13 @@ def setup(self): for layer_number in range(self.config.num_hidden_layers) ] - if self.use_scan: - self.scan_fn = scan_with_axes( - FlaxBloomBlock, - variable_axes={"params": 0, "cache": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, 0, nn.broadcast, nn.broadcast, nn.broadcast), - length=self.config.num_hidden_layers, - )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers") + self.scan_fn = scan_with_axes( + FlaxBloomBlock, + variable_axes={"params": 0, "cache": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=(nn.broadcast, nn.broadcast, 0, nn.broadcast, nn.broadcast, nn.broadcast), + length=self.config.num_hidden_layers, + )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers") def __call__( self, From 961b0472fff9b31001b3e881877c58e93e74f3ee Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 15:58:08 +0000 Subject: [PATCH 63/81] fix `scan` slow test --- src/transformers/models/bloom/modeling_flax_bloom.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 104edbd4afdd..2bd1da03b2e7 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -541,7 +541,7 @@ def setup(self): FlaxBloomBlock, variable_axes={"params": 0, "cache": 0}, split_rngs={"params": True, 
"dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, 0, nn.broadcast, nn.broadcast, nn.broadcast), + in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), length=self.config.num_hidden_layers, )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers") @@ -568,7 +568,6 @@ def __call__( hidden_states, alibi, attention_mask, # kwargs not supported by scan - None, deterministic, init_cache, output_attentions, From 0951eac7c01cb30e5b5a624a68f262d2608b1026 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 17:05:43 +0000 Subject: [PATCH 64/81] remove useless import --- src/transformers/modeling_flax_pytorch_utils.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 956c262e3e93..daf68ecaccba 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -134,19 +134,9 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): - try: - import torch # noqa: F401 - except ImportError: - logger.error( - "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" - " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" - " instructions." - ) - raise - # convert pytorch tensor to numpy # numpy currently does not support bfloat16, need to go over float32 in this case to not loose precision - is_bfloat_16 = all(v.dtype == torch.bfloat16 for v in pt_state_dict.values()) + is_bfloat_16 = all(v.dtype == torch.bfloat16 for v in pt_state_dict.values()) # noqa: F821 pt_state_dict = {k: v.numpy() if not is_bfloat_16 else v.float().numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix From 1db26499cabc3428eb3c5ff7915d9187e6e5d0ea Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 17:17:17 +0000 Subject: [PATCH 65/81] major changes - remove unused code - refactor a bit - revert import `torch` --- src/transformers/modeling_flax_pytorch_utils.py | 9 +++++++++ src/transformers/models/bloom/modeling_flax_bloom.py | 12 +++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index daf68ecaccba..3e4fb00243c1 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -136,6 +136,15 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # convert pytorch tensor to numpy # numpy currently does not support bfloat16, need to go over float32 in this case to not loose precision + try: + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" + " instructions." 
+ ) + raise is_bfloat_16 = all(v.dtype == torch.bfloat16 for v in pt_state_dict.values()) # noqa: F821 pt_state_dict = {k: v.numpy() if not is_bfloat_16 else v.float().numpy() for k, v in pt_state_dict.items()} diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 2bd1da03b2e7..7395561db335 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -282,7 +282,6 @@ def __call__( dropout_rate=self.config.attention_dropout, deterministic=deterministic, dtype=attention_dtype, - precision=None, ) # Cast back in the original dtype if the native dtype is not fp32 @@ -606,13 +605,11 @@ class FlaxBloomModule(nn.Module): def setup(self): self.embed_dim = self.config.hidden_size - embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) - # word embeddings (no positional embedding layer) self.word_embeddings = nn.Embed( self.config.vocab_size, self.embed_dim, - embedding_init=embedding_init, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, ) @@ -639,11 +636,8 @@ def __call__( # do post-embedding layernorm hidden_states = self.word_embeddings_layernorm(inputs_embeds) - batch_size, curr_seq_len, _ = hidden_states.shape - - alibi = build_alibi_tensor_flax(curr_seq_len, self.config.n_head, hidden_states.dtype) - # TODO put repeat here - alibi = jnp.broadcast_to(alibi[None, :], (batch_size,) + alibi.shape).reshape((-1,) + alibi.shape[1:]) + # build alibi depending on `attention_mask` + alibi = build_alibi_tensor_flax(attention_mask, self.config.n_head, hidden_states.dtype) outputs = self.h( hidden_states, From c2230946162680fe918b18d5b32c0afa1515da30 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 11 Oct 2022 18:03:03 +0000 Subject: [PATCH 66/81] major refactoring - change build alibi --- .../models/bloom/modeling_flax_bloom.py | 113 ++++++++++++------ .../models/bloom/test_modeling_flax_bloom.py | 10 +- 2 files changed, 84 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 7395561db335..46c249a88dc9 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -110,41 +110,87 @@ """ -def build_alibi_tensor_flax(attention_mask, n_head, dtype): - def get_slopes(n): - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] - ) +# def build_alibi_tensor_flax(attention_mask, n_head, dtype): +# def get_slopes(n): +# def get_slopes_power_of_2(n): +# start = 2 ** (-(2 ** -(math.log2(n) - 3))) +# ratio = start +# return [start * ratio**i for i in range(n)] + +# if math.log2(n).is_integer(): +# return get_slopes_power_of_2(n) +# else: +# closest_power_of_2 = 2 ** math.floor(math.log2(n)) +# return ( +# get_slopes_power_of_2(closest_power_of_2) +# + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] +# ) + +# # Note: alibi will be added to the attention bias that is applied to the query, key product of attention +# # => therefore alibi will have to be of shape (batch_size, 
num_heads, query_length, key_length) +# # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) +# # => the query_length dimension will then be broadcast correctly +# # This is more or less identical to T5's relative position bias: +# # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 +# # batch_size = 1, n_head = n_head, query_length +# batch_size, key_length = attention_mask.shape +# num_heads = n_head +# query_length = 1 + +# slopes = jnp.array(get_slopes(n_head))[None, :, None, None].astype(dtype) +# arange_tensor = attention_mask.cumsum(-1, dtype=dtype)[:, None, None, :] - 1 + +# slopes_broadcast = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length)) +# arange_broadcast = jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) + +# alibi = slopes_broadcast * arange_broadcast +# return alibi + + +def build_alibi_tensor_flax(attention_mask, num_heads, dtype, return_torch_like=False): + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + A Flax implementation - # Note: alibi will be added to the attention bias that is applied to the query, key product of attention + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`jnp.ndarray`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). 
+ num_heads (`int`, *required*): + number of heads + dtype (`jnp.dtype`, *required*): + dtype of the output tensor + return_torch_like (`bool`, *optional, defaults to `False`*): + Whether to return in the same format as pytorch `(batch_size * num_heads, 1, seq_length)` + """ + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32) + powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32) + slopes = jax.lax.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32) + slopes = jnp.cat([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) - # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) - # => the query_length dimension will then be broadcast correctly + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly # This is more or less identical to T5's relative position bias: - # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 - # batch_size = 1, n_head = n_head, query_length - batch_size, key_length = attention_mask.shape - num_heads = n_head - query_length = 1 - - slopes = jnp.array(get_slopes(n_head))[None, :, None, None].astype(dtype) - arange_tensor = attention_mask.cumsum(-1, dtype=dtype)[:, None, None, :] - 1 - - slopes_broadcast = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length)) - arange_broadcast = jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) - - alibi = slopes_broadcast * arange_broadcast - return alibi + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if return_torch_like: + alibi = jnp.reshape(alibi, (batch_size * num_heads, 1, seq_length)) + else: + alibi = jnp.expand_dims(alibi, axis=2) + return jnp.asarray(alibi, dtype) class FlaxBloomAttention(nn.Module): @@ -558,9 +604,6 @@ def __call__( all_hidden_states = () if output_hidden_states else None if self.use_scan: - # since all decoder layers are the same, we use nn.scan directly - # assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`" - # assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" hidden_states = (hidden_states,) hidden_states, _ = self.scan_fn( diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index b8033259b0d2..e2208e58084c 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -261,7 +261,7 @@ def test_flax_torch_alibi(self): num_attention_heads = 16 alibi = 
build_alibi_tensor(torch.from_numpy(single_attention_mask), num_attention_heads, torch.float16) - alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype)[0] + alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype, return_torch_like=True) self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) @@ -272,7 +272,9 @@ def test_alibi_padding(self): single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) num_attention_heads = 16 - alibi_padd = build_alibi_tensor_flax(batch_attention_mask, num_attention_heads, dtype) - alibi_simple = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype) + alibi_padd = build_alibi_tensor_flax(batch_attention_mask, num_attention_heads, dtype, return_torch_like=True) + alibi_simple = build_alibi_tensor_flax( + single_attention_mask, num_attention_heads, dtype, return_torch_like=True + ) - self.assertTrue(jnp.equal(alibi_simple[:, :, :, :2], alibi_padd[1][:, :, 3:]).all()) + self.assertTrue(jnp.equal(alibi_simple[:, :, :2], alibi_padd[16:, :, 3:]).all()) From 3ce584b0be72489cb565b9da3c507a15365af4e8 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 19:00:11 +0100 Subject: [PATCH 67/81] remove scan --- .../models/bloom/modeling_flax_bloom.py | 68 +++++-------------- .../models/bloom/test_modeling_flax_bloom.py | 9 --- 2 files changed, 18 insertions(+), 59 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 46c249a88dc9..ef35251b8427 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 HuggingFace Inc. team and Bigscience Workshop. All rights reserved. +# Copyright 2023 HuggingFace Inc. Team and Bigscience Workshop. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -382,7 +382,6 @@ def __call__(self, hidden_states, residual, deterministic: bool = True): class FlaxBloomBlock(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self): self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) @@ -404,10 +403,8 @@ def __call__( init_cache: bool = False, output_attentions: bool = False, ): - if self.use_scan: - hidden_states = hidden_states[0] - layernorm_output = self.input_layernorm(hidden_states) + # layer norm before saving residual if config calls for it if self.apply_residual_connection_post_layernorm: residual = layernorm_output @@ -441,9 +438,6 @@ def __call__( outputs = (output,) + outputs - if self.use_scan: - outputs = (outputs, None) - return outputs @@ -464,10 +458,9 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, - use_scan: bool = False, **kwargs, ): - module = self.module_class(config=config, dtype=dtype, use_scan=use_scan, **kwargs) + module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: @@ -574,22 +567,13 @@ def __call__( class FlaxBloomBlockCollection(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self): self.layers = [ - FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype, use_scan=False) + FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype) for layer_number in range(self.config.num_hidden_layers) ] - self.scan_fn = scan_with_axes( - FlaxBloomBlock, - variable_axes={"params": 0, "cache": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), - length=self.config.num_hidden_layers, - )(self.config, dtype=self.dtype, use_scan=True, name="FlaxBloomBlockLayers") - def __call__( self, hidden_states, @@ -603,36 +587,22 @@ def __call__( all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - if self.use_scan: - hidden_states = (hidden_states,) + for layer_number in range(self.config.num_hidden_layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) - hidden_states, _ = self.scan_fn( + layer_outputs = self.layers[layer_number]( hidden_states, - alibi, - attention_mask, # kwargs not supported by scan - deterministic, - init_cache, - output_attentions, + alibi=alibi, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, ) - hidden_states = hidden_states[0] + hidden_states = layer_outputs[0] - else: - for layer_number in range(self.config.num_hidden_layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = self.layers[layer_number]( - hidden_states, - alibi=alibi, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) + if output_attentions: + all_attentions += (layer_outputs[1],) # this contains possible `None` values - `FlaxBloomModule` will filter them out outputs = (hidden_states, all_hidden_states, all_attentions) @@ -643,7 +613,6 @@ def __call__( class FlaxBloomModule(nn.Module): 
config: BloomConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self): self.embed_dim = self.config.hidden_size @@ -660,7 +629,7 @@ def setup(self): self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) # transformer layers - self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype, use_scan=self.use_scan) + self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype) # final layernorm self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) @@ -728,10 +697,9 @@ class FlaxBloomModel(FlaxBloomPreTrainedModel): class FlaxBloomForCausalLMModule(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self): - self.transformer = FlaxBloomModule(self.config, dtype=self.dtype, use_scan=self.use_scan) + self.transformer = FlaxBloomModule(self.config, dtype=self.dtype) self.lm_head = nn.Dense( self.config.vocab_size, use_bias=False, diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index e2208e58084c..c459ef955b11 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -240,15 +240,6 @@ def test_model_batched_padding_left(self): self.assertEqual(sequences_fx_batch[1][6:].tolist(), sequences_fx_simple[0][:-6].tolist()) - def test_scan_model(self): - scan_model = FlaxBloomForCausalLM.from_pretrained("sanchit-gandhi/bloom-350m-scan", use_scan=True) - input_ids = np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32) - - unrolled_logits = self.model(input_ids).logits - scan_logits = scan_model(input_ids).logits - - self.assertTrue(np.max(np.abs(unrolled_logits - scan_logits)) <= 1e-3) - @require_torch @is_pt_flax_cross_test From acda19b138b277079d0e7198902746bd47e3aec3 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 19:24:12 +0100 Subject: [PATCH 68/81] fix tests --- src/transformers/models/bloom/modeling_flax_bloom.py | 10 ++-------- tests/models/bloom/test_modeling_flax_bloom.py | 3 ++- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index ef35251b8427..04947483e1a5 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -24,7 +24,6 @@ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask from flax.linen.activation import tanh -from flax.linen.partitioning import scan_with_axes from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -42,7 +41,6 @@ _CHECKPOINT_FOR_DOC = "bigscience/bloom" _CONFIG_FOR_DOC = "BloomConfig" -_TOKENIZER_FOR_DOC = "BloomTokenizerFast" BLOOM_START_DOCSTRING = r""" @@ -689,9 +687,7 @@ class FlaxBloomModel(FlaxBloomPreTrainedModel): module_class = FlaxBloomModule -append_call_sample_docstring( - FlaxBloomModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC -) +append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) class FlaxBloomForCausalLMModule(nn.Module): @@ -773,6 +769,4 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): return model_kwargs -append_call_sample_docstring( - FlaxBloomForCausalLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC -) 
+append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index c459ef955b11..2c74450aea43 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -19,7 +19,7 @@ from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_torch, slow from transformers.utils.import_utils import is_torch_available -from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin +from ...generation.test_flax_utils import FlaxGenerationTesterMixin from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor @@ -32,6 +32,7 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" import jax.numpy as jnp + from transformers import FlaxBloomForCausalLM, FlaxBloomModel if is_flax_available() and is_torch_available(): From 933f93f3b998417cfd7476748bb9ecd993fe0db7 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 19:24:16 +0100 Subject: [PATCH 69/81] make style --- src/transformers/modeling_flax_pytorch_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 3e4fb00243c1..b4d52a922087 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -145,6 +145,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): " instructions." ) raise + is_bfloat_16 = all(v.dtype == torch.bfloat16 for v in pt_state_dict.values()) # noqa: F821 pt_state_dict = {k: v.numpy() if not is_bfloat_16 else v.float().numpy() for k, v in pt_state_dict.items()} @@ -208,11 +209,15 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): continue # also add unexpected weight so that warning is thrown - flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + flax_state_dict[("params",) + flax_key] = ( + jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + ) else: # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + flax_state_dict[flax_key] = ( + jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + ) return unflatten_dict(flax_state_dict) @@ -350,10 +355,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): pt_model_dict = pt_model.state_dict() load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and ( - pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()]) + pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()} ) load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and ( - pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()]) + pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()} ) # keep track of unexpected & missing keys From cae64fa1a8c99407d92ce601ae3130bb60c5e91d Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 19:43:29 +0100 Subject: [PATCH 70/81] clean-up alibi --- .../models/bloom/modeling_flax_bloom.py | 73 ++++--------------- 1 
file changed, 16 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 04947483e1a5..0e0096484886 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -108,61 +108,23 @@ """ -# def build_alibi_tensor_flax(attention_mask, n_head, dtype): -# def get_slopes(n): -# def get_slopes_power_of_2(n): -# start = 2 ** (-(2 ** -(math.log2(n) - 3))) -# ratio = start -# return [start * ratio**i for i in range(n)] - -# if math.log2(n).is_integer(): -# return get_slopes_power_of_2(n) -# else: -# closest_power_of_2 = 2 ** math.floor(math.log2(n)) -# return ( -# get_slopes_power_of_2(closest_power_of_2) -# + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] -# ) - -# # Note: alibi will be added to the attention bias that is applied to the query, key product of attention -# # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) -# # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) -# # => the query_length dimension will then be broadcast correctly -# # This is more or less identical to T5's relative position bias: -# # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 -# # batch_size = 1, n_head = n_head, query_length -# batch_size, key_length = attention_mask.shape -# num_heads = n_head -# query_length = 1 - -# slopes = jnp.array(get_slopes(n_head))[None, :, None, None].astype(dtype) -# arange_tensor = attention_mask.cumsum(-1, dtype=dtype)[:, None, None, :] - 1 - -# slopes_broadcast = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length)) -# arange_broadcast = jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) - -# alibi = slopes_broadcast * arange_broadcast -# return alibi - - -def build_alibi_tensor_flax(attention_mask, num_heads, dtype, return_torch_like=False): +def build_alibi_tensor(attention_mask: jax.Array, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32): """ - Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + Flax implementation of the BLOOM Alibi tensor. BLOOM Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value `softmax(l+a) = softmax(l)`. Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 - A Flax implementation + Link to paper: https://arxiv.org/abs/2108.12409 Args: - Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) attention_mask (`jnp.ndarray`): - Token-wise attention mask, this should be of shape (batch_size, max_seq_len). - num_heads (`int`, *required*): - number of heads - dtype (`jnp.dtype`, *required*): - dtype of the output tensor - return_torch_like (`bool`, *optional, defaults to `False`*): - Whether to return in the same format as pytorch `(batch_size * num_heads, 1, seq_length)` + Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`. + num_heads (`int`): + Number of attention heads. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The data type (dtype) of the output tensor. 
+ + Returns: Alibi tensor of shape `(batch_size * num_heads, 1, max_seq_len)`. """ batch_size, seq_length = attention_mask.shape closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) @@ -176,18 +138,15 @@ def build_alibi_tensor_flax(attention_mask, num_heads, dtype, return_torch_like= extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32) slopes = jnp.cat([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0) - # Note: alibi will added to the attention bias that will be applied to the query, key product of attention - # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # Note: the Alibi tensor will added to the attention bias that will be applied to the query, key product of attention + # therefore, Alibi will have to be of shape (batch_size, num_heads, query_length, key_length) # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) - # => the query_length dimension will then be broadcasted correctly + # so that the query_length dimension will then be broadcast correctly. # This is more or less identical to T5's relative position bias: # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :] alibi = slopes[..., None] * arange_tensor - if return_torch_like: - alibi = jnp.reshape(alibi, (batch_size * num_heads, 1, seq_length)) - else: - alibi = jnp.expand_dims(alibi, axis=2) + alibi = jnp.expand_dims(alibi, axis=2) return jnp.asarray(alibi, dtype) @@ -647,7 +606,7 @@ def __call__( hidden_states = self.word_embeddings_layernorm(inputs_embeds) # build alibi depending on `attention_mask` - alibi = build_alibi_tensor_flax(attention_mask, self.config.n_head, hidden_states.dtype) + alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype) outputs = self.h( hidden_states, @@ -747,7 +706,7 @@ def __call__( class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel): module_class = FlaxBloomForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape From 7171a77af48ed7ddf0204359d7fd792a23124f6e Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 20:03:55 +0100 Subject: [PATCH 71/81] add integration tests --- tests/models/bloom/test_modeling_bloom.py | 25 +++++++++++++++++++ .../models/bloom/test_modeling_flax_bloom.py | 17 +++++++++++++ 2 files changed, 42 insertions(+) diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index f7ef199febdb..a1f9d708197d 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -487,6 +487,31 @@ def test_batch_generation_padd(self): tokenizer.decode(greedy_output_without_pad[0, :-3], skip_special_tokens=True), ) + @slow + @require_torch_gpu + def test_batch_generated_text(self): + path_560m = "bigscience/bloom-560m" + + model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").cuda() + model = model.eval() + tokenizer = BloomTokenizerFast.from_pretrained(path_560m, padding_side="left") + + input_sentences = [ + "Hello what is", + "Running a quick test with the", + ] + inputs = 
tokenizer(input_sentences, return_tensors="pt", padding=True, truncation=True) + generated_ids = model.generate(inputs["input_ids"].cuda(), attention_mask=inputs["attention_mask"], max_length=20) + generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + # these generations match those of the PyTorch model + EXPECTED_GENERATIONS = [ + "Hello what is the best way to get the data from the server? I have tried", + "Running a quick test with the following command:\nsudo apt-get install python3\nsudo apt-get install python2", + ] + + self.assertListEqual(generated_text, EXPECTED_GENERATIONS) + @require_torch class BloomEmbeddingTest(unittest.TestCase): diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 2c74450aea43..09814589c955 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -241,6 +241,23 @@ def test_model_batched_padding_left(self): self.assertEqual(sequences_fx_batch[1][6:].tolist(), sequences_fx_simple[0][:-6].tolist()) + def test_batch_generated_text(self): + input_sentences = [ + "Hello what is", + "Running a quick test with the", + ] + inputs = self.tokenizer(input_sentences, return_tensors="np", padding=True, truncation=True) + generated_ids = self.model.generate(**inputs, max_length=20).sequences + generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + # these generations match those of the PyTorch model, ensuring correctness + EXPECTED_GENERATIONS = [ + "Hello what is the best way to get the data from the server? I have tried", + "Running a quick test with the following command:\nsudo apt-get install python3\nsudo apt-get install python2", + ] + + self.assertListEqual(generated_text, EXPECTED_GENERATIONS) + @require_torch @is_pt_flax_cross_test From dbc363f8683ee2f2fe869ca756aaff55833b886b Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 20:04:01 +0100 Subject: [PATCH 72/81] up --- docs/source/en/index.md | 2 +- src/transformers/models/bloom/modeling_flax_bloom.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 5063df249487..62f0469aa097 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -300,7 +300,7 @@ Flax), PyTorch, and/or TensorFlow. | BlenderbotSmall | ✅ | ✅ | ✅ | | BLIP | ✅ | ✅ | ❌ | | BLIP-2 | ✅ | ❌ | ❌ | -| BLOOM | ✅ | ❌ | ❌ | +| BLOOM | ✅ | ❌ | ✅ | | BridgeTower | ✅ | ❌ | ❌ | | CamemBERT | ✅ | ✅ | ❌ | | CANINE | ✅ | ❌ | ❌ | diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 0e0096484886..8271e2a7e44f 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -108,7 +108,7 @@ """ -def build_alibi_tensor(attention_mask: jax.Array, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32): +def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32): """ Flax implementation of the BLOOM Alibi tensor. 
BLOOM Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value From 2fe409edf3444a7cbf2b71004649085e26a26620 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 20:09:51 +0100 Subject: [PATCH 73/81] fix batch norm conversion --- .../modeling_flax_pytorch_utils.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index b4d52a922087..31af8983dab3 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -339,7 +339,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): raise # check if we have bf16 weights - is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() + is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() if any(is_type_bf16): # convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16 # and bf16 is not fully supported in PT yet. @@ -347,7 +347,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " "before loading those in PyTorch model." ) - flax_state = jax.tree_map( + flax_state = jax.tree_util.tree_map( lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state ) @@ -387,7 +387,34 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): elif flax_key_tuple[-1] in ["scale", "embedding"]: flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - flax_key = ".".join(flax_key_tuple) + # adding batch stats from flax batch norm to pt + elif "mean" in flax_key_tuple[-1]: + flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",) + elif "var" in flax_key_tuple[-1]: + flax_key_tuple = flax_key_tuple[:-1] + ("running_var",) + + if "batch_stats" in flax_state: + flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header + else: + flax_key = ".".join(flax_key_tuple) + + # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation. 
+ special_pt_names = {} + # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 + for key in pt_model_dict: + key_components = key.split(".") + name = None + if key_components[-3::2] == ["parametrizations", "original0"]: + name = key_components[-2] + "_g" + elif key_components[-3::2] == ["parametrizations", "original1"]: + name = key_components[-2] + "_v" + if name is not None: + key_components = key_components[:-3] + [name] + key_to_check = ".".join(key_components) + special_pt_names[key_to_check] = key + + if flax_key in special_pt_names: + flax_key = special_pt_names[flax_key] if flax_key in pt_model_dict: if flax_tensor.shape != pt_model_dict[flax_key].shape: From f64960e7de010f16a0e8f247504db91e0dc2a95a Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 20:09:55 +0100 Subject: [PATCH 74/81] style --- tests/models/bloom/test_modeling_bloom.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index a1f9d708197d..4e9b837c8adc 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -501,7 +501,9 @@ def test_batch_generated_text(self): "Running a quick test with the", ] inputs = tokenizer(input_sentences, return_tensors="pt", padding=True, truncation=True) - generated_ids = model.generate(inputs["input_ids"].cuda(), attention_mask=inputs["attention_mask"], max_length=20) + generated_ids = model.generate( + inputs["input_ids"].cuda(), attention_mask=inputs["attention_mask"], max_length=20 + ) generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) # these generations match those of the PyTorch model From 06fcc0620c7d5fe491ce4da2e9e8133afd3316b4 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 25 Jul 2023 20:20:44 +0100 Subject: [PATCH 75/81] style --- docs/source/en/model_doc/bloom.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/en/model_doc/bloom.md b/docs/source/en/model_doc/bloom.md index 3fba4af3de22..3c155fa58782 100644 --- a/docs/source/en/model_doc/bloom.md +++ b/docs/source/en/model_doc/bloom.md @@ -85,3 +85,13 @@ See also: [[autodoc]] BloomForQuestionAnswering - forward + +## FlaxBloomModel + +[[autodoc]] FlaxBloomModel + - __call__ + +## FlaxBloomForCausalLM + +[[autodoc]] FlaxBloomForCausalLM + - __call__ From 85862924a50c347957f3a0c676c1ac28a808f2bd Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 26 Jul 2023 09:48:54 +0100 Subject: [PATCH 76/81] update pt-fx cross tests --- .../models/bloom/test_modeling_flax_bloom.py | 39 +------------------ 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 09814589c955..81484ca01bda 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -13,11 +13,10 @@ # limitations under the License. 
import unittest -import numpy as np # noqa +import numpy as np from transformers import BloomConfig, BloomTokenizerFast, is_flax_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_torch, slow -from transformers.utils.import_utils import is_torch_available +from transformers.testing_utils import require_flax, slow from ...generation.test_flax_utils import FlaxGenerationTesterMixin from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor @@ -35,10 +34,6 @@ from transformers import FlaxBloomForCausalLM, FlaxBloomModel -if is_flax_available() and is_torch_available(): - from transformers.models.bloom.modeling_bloom import build_alibi_tensor - from transformers.models.bloom.modeling_flax_bloom import build_alibi_tensor_flax - def prepare_bloom_inputs_dict(config, input_ids, attention_mask=None): if attention_mask is None: @@ -257,33 +252,3 @@ def test_batch_generated_text(self): ] self.assertListEqual(generated_text, EXPECTED_GENERATIONS) - - -@require_torch -@is_pt_flax_cross_test -class FlaxBloomConversionTest(unittest.TestCase): - def test_flax_torch_alibi(self): - import torch - - dtype = jnp.float16 - single_attention_mask = np.array([[1, 1, 1, 1, 1]]) - num_attention_heads = 16 - - alibi = build_alibi_tensor(torch.from_numpy(single_attention_mask), num_attention_heads, torch.float16) - alibi_flax = build_alibi_tensor_flax(single_attention_mask, num_attention_heads, dtype, return_torch_like=True) - - self.assertTrue(jnp.equal(alibi_flax, alibi.numpy()).all()) - - def test_alibi_padding(self): - dtype = jnp.bfloat16 - - batch_attention_mask = jnp.array([[1, 1, 1, 1, 1], [0, 0, 0, 1, 1]]) - single_attention_mask = jnp.array([[1, 1, 1, 1, 1]]) - num_attention_heads = 16 - - alibi_padd = build_alibi_tensor_flax(batch_attention_mask, num_attention_heads, dtype, return_torch_like=True) - alibi_simple = build_alibi_tensor_flax( - single_attention_mask, num_attention_heads, dtype, return_torch_like=True - ) - - self.assertTrue(jnp.equal(alibi_simple[:, :, :2], alibi_padd[16:, :, 3:]).all()) From b0dd5ebfa3aa01d118bb8e75fe1707119c07a8a3 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 26 Jul 2023 09:49:15 +0100 Subject: [PATCH 77/81] update copyright --- tests/models/bloom/test_modeling_flax_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index 81484ca01bda..fecf6015591e 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From e9b85e5ca3eea89b57a7e4bc7b34d7a572cd6b0d Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 27 Jul 2023 15:57:05 +0100 Subject: [PATCH 78/81] Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_flax_pytorch_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 31af8983dab3..ad267de3ec03 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -146,8 +146,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): ) raise - is_bfloat_16 = all(v.dtype == torch.bfloat16 for v in pt_state_dict.values()) # noqa: F821 - pt_state_dict = {k: v.numpy() if not is_bfloat_16 else v.float().numpy() for k, v in pt_state_dict.items()} + pt_state_dict = {k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix From acfc88383086159b19312109dcbaf0de1f38bf14 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 27 Jul 2023 16:01:35 +0100 Subject: [PATCH 79/81] per-weight check --- src/transformers/modeling_flax_pytorch_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index ad267de3ec03..ac390e20d711 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -135,7 +135,7 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # convert pytorch tensor to numpy - # numpy currently does not support bfloat16, need to go over float32 in this case to not loose precision + # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision try: import torch # noqa: F401 except ImportError: @@ -146,6 +146,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): ) raise + weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} pt_state_dict = {k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix @@ -174,6 +175,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): pt_tuple_key = tuple(pt_key.split(".")) + is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16 # remove base model prefix if necessary has_base_model_prefix = pt_tuple_key[0] == model_prefix From 27d212c68ee4da1695943e9f99e86031845186c3 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 27 Jul 2023 16:04:30 +0100 Subject: [PATCH 80/81] style --- src/transformers/modeling_flax_pytorch_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index ac390e20d711..79d91da49729 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -147,7 +147,9 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): raise weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} - pt_state_dict 
= {k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()} + pt_state_dict = { + k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items() + } model_prefix = flax_model.base_model_prefix From b6d8dc6b76ba1b979fea8045d4ab17e3d0f959c3 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 27 Jul 2023 16:09:17 +0100 Subject: [PATCH 81/81] line formats --- src/transformers/models/bloom/__init__.py | 6 +----- .../models/bloom/modeling_flax_bloom.py | 15 +++++++++------ .../models/gptj/modeling_flax_gptj.py | 3 ++- tests/models/bloom/test_modeling_flax_bloom.py | 5 +---- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/bloom/__init__.py b/src/transformers/models/bloom/__init__.py index 27e914aea306..32e8617e8270 100644 --- a/src/transformers/models/bloom/__init__.py +++ b/src/transformers/models/bloom/__init__.py @@ -24,11 +24,7 @@ _import_structure = { - "configuration_bloom": [ - "BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", - "BloomConfig", - "BloomOnnxConfig", - ], + "configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"], } try: if not is_tokenizers_available(): diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py index 8271e2a7e44f..187230f35ab9 100644 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ b/src/transformers/models/bloom/modeling_flax_bloom.py @@ -207,7 +207,8 @@ def _concatenate_to_cache(self, key, value, query, attention_mask): cached_value.value = value num_updated_cache_vectors = query.shape[1] cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + # causal mask for cached decoder self-attention: our single query position should only attend to those key + # positions that have already been generated and cached, not the remaining zero elements. pad_mask = jnp.broadcast_to( jnp.arange(max_length) < cur_index + num_updated_cache_vectors, tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), @@ -488,8 +489,9 @@ def __call__( inputs = {"params": params or self.params} - # If past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. - # It has to be made sure that cache is marked as mutable so that it can be changed by FlaxBloomAttention module + # If past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxBloomAttention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] @@ -711,9 +713,10 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O batch_size, seq_length = input_ids.shape past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since Bloom uses a causal mask, those positions are masked anyway. 
- # Thus, we can create a single static attention_mask here, which is more efficient for compilation + # Note that usually one would have to put 0's in the attention_mask for + # x > input_ids.shape[-1] and x < cache_length. But since Bloom uses a causal mask, + # those positions are masked anyway. Thus, we can create a single static attention_mask here, + # which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) diff --git a/src/transformers/models/gptj/modeling_flax_gptj.py b/src/transformers/models/gptj/modeling_flax_gptj.py index 6270355129ff..8ec53aec46d5 100644 --- a/src/transformers/models/gptj/modeling_flax_gptj.py +++ b/src/transformers/models/gptj/modeling_flax_gptj.py @@ -194,7 +194,8 @@ def _concatenate_to_cache(self, key, value, query, attention_mask): cached_value.value = value num_updated_cache_vectors = query.shape[1] cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + # causal mask for cached decoder self-attention: our single query position should only attend to those key + # positions that have already been generated and cached, not the remaining zero elements. pad_mask = jnp.broadcast_to( jnp.arange(max_length) < cur_index + num_updated_cache_vectors, tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), diff --git a/tests/models/bloom/test_modeling_flax_bloom.py b/tests/models/bloom/test_modeling_flax_bloom.py index fecf6015591e..0e49039afe59 100644 --- a/tests/models/bloom/test_modeling_flax_bloom.py +++ b/tests/models/bloom/test_modeling_flax_bloom.py @@ -38,10 +38,7 @@ def prepare_bloom_inputs_dict(config, input_ids, attention_mask=None): if attention_mask is None: attention_mask = np.where(input_ids != config.pad_token_id, 1, 0) - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - } + return {"input_ids": input_ids, "attention_mask": attention_mask} @require_flax