Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
09e9221
First commit
younesbelkada Jul 5, 2022
5c21527
step 1 working
younesbelkada Jul 5, 2022
794a508
add alibi
younesbelkada Jul 5, 2022
8416c37
placeholder for `scan`
Jul 5, 2022
2eaa092
add matrix mult alibi
younesbelkada Jul 5, 2022
274767e
beta scaling factor for bmm
Jul 5, 2022
f923d51
working v1 - simple forward pass
younesbelkada Jul 5, 2022
2425e28
move layer_number from attribute to arg in call
Jul 5, 2022
8caac99
partial functioning scan
Jul 5, 2022
91b7f75
hacky working scan
Jul 5, 2022
2530051
add more modifs
younesbelkada Jul 5, 2022
448c1e9
add test
younesbelkada Jul 5, 2022
3aed702
update scan for new kwarg order
Jul 5, 2022
9ad912c
fix position_ids problem
patrickvonplaten Jul 5, 2022
fc83079
fix bug in attention layer
patrickvonplaten Jul 5, 2022
541429a
small fix
younesbelkada Jul 6, 2022
40ff6df
prelim refactor
Jul 6, 2022
8d1f137
finish refactor
Jul 6, 2022
22796ae
alibi shifting
younesbelkada Jul 6, 2022
2917cac
incorporate dropout_add to attention module
Jul 6, 2022
a6669b4
make style
patrickvonplaten Jul 6, 2022
312a2f1
make padding work again
patrickvonplaten Jul 6, 2022
143a135
update
patrickvonplaten Jul 6, 2022
e2b67aa
remove bogus file
patrickvonplaten Jul 6, 2022
c499433
up
patrickvonplaten Jul 6, 2022
01e01a3
get generation to work
patrickvonplaten Jul 6, 2022
d1329c9
clean code a bit
patrickvonplaten Jul 6, 2022
27b4bdb
added small tests
younesbelkada Jul 6, 2022
b61bbad
adding alibi test
younesbelkada Jul 6, 2022
734884b
make CI tests pass:
younesbelkada Jul 7, 2022
b717b39
fix few nits
younesbelkada Jul 7, 2022
fdf4392
fix nit onnx
younesbelkada Jul 7, 2022
7dd7e64
fix onnx nit
younesbelkada Jul 7, 2022
46a685d
add missing dtype args to nn.Modules
Jul 7, 2022
f34c854
remove debugging statements
Jul 7, 2022
b83009e
fix scan generate
Jul 8, 2022
362ff03
Update modeling_flax_bloom.py
younesbelkada Jul 18, 2022
5652b05
Update test_modeling_flax_bloom.py
younesbelkada Jul 18, 2022
9e95e31
Update test_modeling_flax_bloom.py
younesbelkada Jul 18, 2022
627e8e4
Update test_modeling_flax_bloom.py
younesbelkada Jul 18, 2022
77dfc0a
fix small test issue + make style
younesbelkada Jul 18, 2022
70b57ef
clean up
Jul 27, 2022
e101263
Update tests/models/bloom/test_modeling_flax_bloom.py
younesbelkada Jul 27, 2022
223c062
fix function name
younesbelkada Jul 27, 2022
a359fc6
small fix test
younesbelkada Jul 27, 2022
e838fca
forward contrib credits from PR17761
haileyschoelkopf Jul 28, 2022
49bf9e0
Fix failing test
younesbelkada Jul 29, 2022
5b54e4c
fix small typo documentation
younesbelkada Jul 29, 2022
bb2e0e4
fix non passing test
younesbelkada Aug 10, 2022
78c3121
refactor call
younesbelkada Sep 12, 2022
c23221e
make style
younesbelkada Sep 12, 2022
a2c0e98
upcast to fp32
younesbelkada Sep 12, 2022
bdda3fa
cleaner way to upcast
younesbelkada Sep 12, 2022
d8ddc11
remove unused args
younesbelkada Sep 12, 2022
c4866ec
remove layer number
younesbelkada Sep 12, 2022
12ba14b
fix scan test
younesbelkada Sep 12, 2022
dfe1697
make style
younesbelkada Sep 12, 2022
f587045
fix i4 casting
younesbelkada Sep 13, 2022
0a7fdc4
fix slow test
younesbelkada Oct 11, 2022
b49957f
Update src/transformers/models/bloom/modeling_flax_bloom.py
younesbelkada Oct 11, 2022
eb023eb
remove `layer_past`
younesbelkada Oct 11, 2022
b5a1c3f
refactor a bit
younesbelkada Oct 11, 2022
961b047
fix `scan` slow test
younesbelkada Oct 11, 2022
0951eac
remove useless import
younesbelkada Oct 11, 2022
1db2649
major changes
younesbelkada Oct 11, 2022
c223094
major refactoring
younesbelkada Oct 11, 2022
3ce584b
remove scan
Jul 25, 2023
acda19b
fix tests
Jul 25, 2023
933f93f
make style
Jul 25, 2023
cae64fa
clean-up alibi
Jul 25, 2023
7171a77
add integration tests
Jul 25, 2023
dbc363f
up
Jul 25, 2023
2fe409e
fix batch norm conversion
Jul 25, 2023
f64960e
style
Jul 25, 2023
06fcc06
style
Jul 25, 2023
8586292
update pt-fx cross tests
Jul 26, 2023
b0dd5eb
update copyright
Jul 26, 2023
e9b85e5
Update src/transformers/modeling_flax_pytorch_utils.py
sanchit-gandhi Jul 27, 2023
acfc883
per-weight check
Jul 27, 2023
27d212c
style
Jul 27, 2023
b6d8dc6
line formats
Jul 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/de/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ Flax), PyTorch, und/oder TensorFlow haben.
| BigBird-Pegasus | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ |
| BLOOM | ❌ | ✅ | ✅ | ❌ | ✅ |
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| CANINE | ✅ | ❌ | ✅ | ❌ | ❌ |
| CLIP | ✅ | ✅ | ✅ | ✅ | ✅ |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ Flax), PyTorch, and/or TensorFlow.
| BlenderbotSmall | ✅ | ✅ | ✅ |
| BLIP | ✅ | ✅ | ❌ |
| BLIP-2 | ✅ | ❌ | ❌ |
| BLOOM | ✅ | ❌ | ❌ |
| BLOOM | ✅ | ❌ | ✅ |
| BridgeTower | ✅ | ❌ | ❌ |
| CamemBERT | ✅ | ✅ | ❌ |
| CANINE | ✅ | ❌ | ❌ |
Expand Down
10 changes: 10 additions & 0 deletions docs/source/en/model_doc/bloom.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,13 @@ See also:

[[autodoc]] BloomForQuestionAnswering
- forward

## FlaxBloomModel

[[autodoc]] FlaxBloomModel
- __call__

## FlaxBloomForCausalLM

[[autodoc]] FlaxBloomForCausalLM
- __call__
8 changes: 8 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3883,6 +3883,13 @@
"FlaxBlenderbotSmallPreTrainedModel",
]
)
_import_structure["models.bloom"].extend(
[
"FlaxBloomForCausalLM",
"FlaxBloomModel",
"FlaxBloomPreTrainedModel",
]
)
_import_structure["models.clip"].extend(
[
"FlaxCLIPModel",
Expand Down Expand Up @@ -7263,6 +7270,7 @@
FlaxBlenderbotSmallModel,
FlaxBlenderbotSmallPreTrainedModel,
)
from .models.bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel
from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
Expand Down
25 changes: 22 additions & 3 deletions src/transformers/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,21 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:

def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# convert pytorch tensor to numpy
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
try:
import torch # noqa: F401
except ImportError:
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise

weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
}

model_prefix = flax_model.base_model_prefix

Expand Down Expand Up @@ -163,6 +177,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16

# remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix
Expand Down Expand Up @@ -197,11 +212,15 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
continue

# also add unexpected weight so that warning is thrown
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
flax_state_dict[("params",) + flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)

else:
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
flax_state_dict[flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)

return unflatten_dict(flax_state_dict)

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
("big_bird", "FlaxBigBirdModel"),
("blenderbot", "FlaxBlenderbotModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("bloom", "FlaxBloomModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
Expand Down Expand Up @@ -139,6 +140,7 @@
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("bloom", "FlaxBloomForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("gpt-sw3", "FlaxGPT2LMHeadModel"),
("gpt2", "FlaxGPT2LMHeadModel"),
Expand Down
28 changes: 27 additions & 1 deletion src/transformers/models/bloom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@

from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tokenizers_available,
is_torch_available,
)


_import_structure = {
Expand Down Expand Up @@ -44,6 +50,19 @@
"BloomForQuestionAnswering",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_bloom"] = [
"FlaxBloomForCausalLM",
"FlaxBloomModel",
"FlaxBloomPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig

Expand Down Expand Up @@ -71,6 +90,13 @@
BloomPreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel
else:
import sys

Expand Down
Loading