Add TF port of BLIP#22090

Merged
Rocketknight1 merged 57 commits into main from add_tf_blip on Apr 4, 2023

@Rocketknight1

Member

Work in progress right now, will update this when it's closer to being ready!

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 10, 2023


The documentation is not available anymore as the PR was closed or merged.

@Rocketknight1 Rocketknight1 marked this pull request as ready for review March 24, 2023 14:15
@Rocketknight1

Member Author

The TF port is mostly complete now and tests are passing locally - I just need to go around updating docs and auto classes and so on. The main code should be ready for review!

@Rocketknight1 Rocketknight1 force-pushed the add_tf_blip branch 2 times, most recently from fb88fd4 to 120d189 Compare March 24, 2023 15:26

@amyeroberts amyeroberts left a comment

Contributor

Nice 🔥 Thanks for adding this model!

Mostly nits. The main comment is a question: have the TODO comments been resolved?

Comment thread utils/check_repo.py
Comment thread src/transformers/models/blip/modeling_tf_blip.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py

@gante gante left a comment

Contributor

Oof, this is a long one -- good job on getting it to the finish line! It should be close to a mergeable state.

A few general comments:

  1. Missing: new modeling files in doctests;
  2. The PR has some minor issues that came from the PT implementation. It would be nice to correct them as well!
  3. The training argument is missing 😱 It needs to be added throughout the code.
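
As a rough illustration of the third point, the sketch below shows what threading a `training` flag through nested layers looks like. It uses plain Python stand-ins rather than Keras; `ToyDropout`, `ToyBlock`, and their `call` signatures are hypothetical names for this example and are not taken from the BLIP code.

```python
import random


class ToyDropout:
    """Hypothetical stand-in for a dropout layer (not tf.keras.layers.Dropout)."""

    def __init__(self, rate, seed=0):
        self.rate = rate
        self.rng = random.Random(seed)

    def call(self, values, training=False):
        # Dropout is the identity at inference time.
        if not training:
            return list(values)
        keep = 1.0 - self.rate
        # Zero each value with probability `rate` and rescale the survivors.
        return [v / keep if self.rng.random() < keep else 0.0 for v in values]


class ToyBlock:
    """Hypothetical outer layer that must forward `training` to its children."""

    def __init__(self):
        self.dropout = ToyDropout(rate=0.5)

    def call(self, values, training=False):
        # In this sketch, forgetting to pass `training` here would leave the
        # child at its default (False) and silently disable dropout.
        return self.dropout.call(values, training=training)


block = ToyBlock()
inputs = [1.0, 2.0, 3.0, 4.0]
inference_out = block.call(inputs, training=False)
training_out = block.call(inputs, training=True)
```

The point of the sketch is only that every layer between the model's entry point and the dropout call has to accept and forward the flag.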

("bert", "TFBertModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("blip", "TFBlipModel"),

Contributor

I think we are missing a few auto classes -- also missing on the PT side!

Comment thread src/transformers/modeling_tf_utils.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip.py
Comment thread src/transformers/models/blip/modeling_tf_blip.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip.py
Comment thread src/transformers/models/blip/modeling_tf_blip.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py Outdated
@sgugger

sgugger commented Mar 27, 2023

Collaborator

Looks like there are many comments to address for now. Please ping me again when it's ready for second review!

@Rocketknight1

Member Author

Got through a lot of the comments today, but I have a couple of other things to do - will try to finish them tomorrow!

@Rocketknight1

Member Author

The last remaining big issue is that some of the pt-tf equivalence tests fail when weights don't match up between models. This is caused by the cross-attention weights not being built, presumably because those layers aren't being called in the forward pass. I'm working on figuring out why and resolving that!

@Rocketknight1

Member Author

The issue seems to be that in all of our other models, cross-attention layers are only added when config.add_cross_attention is True, but in the case of BLIP it only checks config.is_decoder. As a result, the PyTorch models often initialize cross-attention layers that aren't used, which causes weight mismatch issues for us in crossloading tests, because TF only creates weights on first use.
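
The mismatch described here can be sketched with two toy classes. This is a plain-Python illustration under stated assumptions, not the actual BLIP or transformers code: `EagerTextLayer`, `LazyTextLayer`, and the weight names are hypothetical, and the "weights" are just dictionary entries.

```python
class EagerTextLayer:
    """Hypothetical PyTorch-style layer: weights exist as soon as it is constructed."""

    def __init__(self, is_decoder):
        self.weights = {"self_attention.kernel": [0.0]}
        if is_decoder:
            # Created up front, whether or not the forward pass ever uses it.
            self.weights["cross_attention.kernel"] = [0.0]


class LazyTextLayer:
    """Hypothetical TF-style layer: each weight is created on first use."""

    def __init__(self, is_decoder):
        self.is_decoder = is_decoder
        self.weights = {}

    def call(self, hidden, encoder_hidden=None):
        self.weights.setdefault("self_attention.kernel", [0.0])
        if self.is_decoder and encoder_hidden is not None:
            # Only reached when encoder states are actually passed in.
            self.weights.setdefault("cross_attention.kernel", [0.0])
        return hidden


eager = EagerTextLayer(is_decoder=True)
lazy = LazyTextLayer(is_decoder=True)
lazy.call(hidden=[1.0])  # no encoder states, so cross-attention never runs

# Weight names present in the eager model but never built in the lazy one.
missing_in_lazy = sorted(set(eager.weights) - set(lazy.weights))
```

With these toy classes, the eager model carries a cross-attention weight that the lazy model never builds, which is the kind of key mismatch a crossloading test would report.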

@gante gante left a comment

Contributor

Looks good!

Two high-level items from the previous review remaining:

  1. Missing: new modeling files in doctests;
  2. The training argument is missing 😱 It needs to be added throughout the code.

@Rocketknight1

Member Author

It's coming, don't worry! This cross-attention behaviour is just very odd and I'm trying to track it down first.

Rocketknight1 and others added 13 commits March 30, 2023 13:48
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@Rocketknight1

Member Author

Hi all! I've addressed all comments and local tests look good. The remaining issues are:

  • Converting checkpoints so the tests don't need from_pt
  • Maybe adding more auto classes

I'm not sure about the auto classes, though - they're missing in the original PT version of the model as well, so this didn't seem like the right PR to add them.

@Rocketknight1

Member Author

cc @sgugger - I think this is ready for a final review at last!

@gante gante left a comment

Contributor

It has my blessing 🪄

@sgugger sgugger left a comment

Collaborator

Thanks for your PR! It sadly cannot be merged until the pt/tf equivalence tests are all passing. No model tester in the code base skips them, so let's not let BLIP be the first one.

If the fact that BLIP is an encoder/decoder makes changes to the base tests in the model tester classes necessary, the test can be overridden.

Comment thread src/transformers/modeling_tf_utils.py Outdated
"""
return cls(config, **kwargs)

def invert_attention_mask(self, encoder_attention_mask: tf.Tensor) -> tf.Tensor:

Collaborator

This does not use the state, so it would be better as a function in tf_utils (same for the other two below).

We should probably clean up the PyTorch side to do the same.

Member Author

Done! I didn't touch the PyTorch side yet because that's a bigger refactor that touches several models, but I can do it in another PR after this if you want.
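
The refactor discussed in this thread, moving a method that never reads `self` out into a shared utility function, can be sketched as follows. This is a plain-Python illustration only: the real helper operates on `tf.Tensor` values, and the list-based signature, the `large_negative` default, and `ToyModel` are assumptions made for this example.

```python
def invert_attention_mask(mask, large_negative=-1e9):
    """Turn a 1/0 keep-mask into an additive mask.

    Kept positions become 0.0 and masked positions become a large negative
    number, so adding the result to attention scores suppresses masked tokens.
    The list-of-lists input and the constant are assumptions of this sketch.
    """
    return [[0.0 if m else large_negative for m in row] for row in mask]


class ToyModel:
    """Hypothetical model class that used to own the helper as a method."""

    def forward_mask(self, mask):
        # The model calls the shared function instead of defining a method
        # that never touches `self`.
        return invert_attention_mask(mask)


additive = ToyModel().forward_mask([[1, 1, 0]])
```

Because the function depends only on its arguments, several models can import it from one place instead of each inheriting or redefining it.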

Comment thread src/transformers/models/blip/modeling_tf_blip.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py Outdated
Comment thread src/transformers/models/blip/modeling_tf_blip_text.py Outdated
Comment thread tests/models/blip/test_modeling_blip.py Outdated
Comment on lines +345 to +347
@unittest.skip(reason="This test class covers encoder-decoder models that the base test does not work with.")
def test_pt_tf_model_equivalence(self):
    pass

Collaborator

Needs to be rewritten then. We cannot skip this test.

Member Author

Re-enabled!

Collaborator

It is not.

self.assertIsNotNone(model)

@unittest.skip(reason="This test class covers encoder-decoder models that the base test does not work with.")
def test_pt_tf_model_equivalence(self):

Collaborator

Same here.

Member Author

Re-enabled!

Collaborator

Same, it is not.

self.assertIsNotNone(model)

@unittest.skip(reason="This test class covers encoder-decoder models that the base test does not work with.")
def test_pt_tf_model_equivalence(self):

Collaborator

Same there

Member Author

Re-enabled!

Collaborator

This one isn't either.

Comment thread tests/models/blip/test_modeling_tf_blip.py Outdated
@Rocketknight1

Member Author

Got it, I'll figure out some way to re-enable those tests, or override them with versions that do work!
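
One way to keep the equivalence check running is to override the pieces of the shared test that do not fit, rather than skipping it. The following is a minimal sketch using only the standard library `unittest`; `CommonEquivalenceMixin`, `reference_outputs`, `ported_outputs`, and `prepare_inputs` are hypothetical names for this example, not the transformers test utilities.

```python
import unittest


class CommonEquivalenceMixin:
    """Hypothetical shared mixin, standing in for a common model tester."""

    def prepare_inputs(self):
        return [1, 2, 3]

    def reference_outputs(self, inputs):
        # Stand-in for the original implementation's forward pass.
        return [2 * x for x in inputs]

    def ported_outputs(self, inputs):
        # Stand-in for the ported implementation's forward pass.
        return [x + x for x in inputs]

    def test_equivalence(self):
        inputs = self.prepare_inputs()
        self.assertEqual(self.reference_outputs(inputs), self.ported_outputs(inputs))


class EncoderDecoderEquivalenceTest(CommonEquivalenceMixin, unittest.TestCase):
    def prepare_inputs(self):
        # Override only the part that does not fit: this hypothetical model
        # needs extra inputs, so it supplies its own input builder.
        return [1, 2, 3, 4, 5, 6]

    def test_equivalence(self):
        # Reuse the inherited check, so the test still runs instead of being skipped.
        super().test_equivalence()


suite = unittest.defaultTestLoader.loadTestsFromTestCase(EncoderDecoderEquivalenceTest)
result = unittest.TestResult()
suite.run(result)
```

In this sketch the subclass changes how inputs are built but keeps the comparison itself, so the result object records one executed test and no skips.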

@Rocketknight1

Member Author

@sgugger this should be ready for review with all comments addressed! The failing test is in an unrelated model.

@sgugger sgugger left a comment

Collaborator

Still some equivalence tests missing.

self._override_model_class = override_model_class

- def get_inputs(self, pt_model, config):
+ def get_inputs(self, pt_model, tf_dummy_inputs, config):

Collaborator

The changes here seem unrelated to this PR and would be better in their own PR, no?

Member Author

Fair! I added them because they were needed for the pt-to-tf code to port the BLIP models correctly. If you'd rather I move them to a separate PR though, that's fine!

Comment thread tests/models/blip/test_modeling_blip.py Outdated
Comment on lines +345 to +347
@unittest.skip(reason="This test class covers encoder-decoder models that the base test does not work with.")
def test_pt_tf_model_equivalence(self):
    pass

Collaborator

It is not.

self.assertIsNotNone(model)

@unittest.skip(reason="This test class covers encoder-decoder models that the base test does not work with.")
def test_pt_tf_model_equivalence(self):

Collaborator

Same, it is not.

self.assertIsNotNone(model)

@unittest.skip(reason="This test class covers encoder-decoder models that the base test does not work with.")
def test_pt_tf_model_equivalence(self):

Collaborator

This one isn't either.

@Rocketknight1

Member Author

@sgugger Sorry for the confusion - that equivalence test is present in both the test_modeling_tf_blip and test_modeling_blip files. Do we want to keep it in both?

@sgugger

sgugger commented Apr 4, 2023

Collaborator

Yes we do.

@sgugger sgugger left a comment

Collaborator

Failing tests are unrelated.

@gante gante left a comment

Contributor

(pt-to-tf changes LGTM 👍 )

@Rocketknight1

Member Author

Going to leave the pt-to-tf changes in this PR rather than making a separate one, since they're needed for proper BLIP conversion!

@Rocketknight1 Rocketknight1 merged commit 5f3ea66 into main Apr 4, 2023
@Rocketknight1 Rocketknight1 deleted the add_tf_blip branch April 4, 2023 15:05
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
* Initial commit

* more stash commit

* Yet another stash commit

* yet more stash commit

* Mostly working except for docs / repo consistency

* Stop importing model list from torch file

* Add TF BLIP models to docs

* Add auto classes

* Move get_text_features and get_image_features

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/blip/test_modeling_tf_blip_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Use channels_last convolutions in TF (better performance + compatibility)

* Remove _shape function

* Move multi-line statement to one line in PT + TF

* Specify tf.keras.layers instead of importing from it

* Remove test_gradient_checkpointing and empty test_training methods

* move some multi-line statements to one line

* Update docstring for generate

* Remove pruned heads set

* Remove self.seq_len_dim

* Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states

* ensure original model follows config in more cases

* Skip the same cross-attention tests in the PT tests - didn't realize we did it twice!

* Add training args throughout the models and layers

* make fixup

* Fix docstring for inputs_embeds

* Add docstring for is_decoder

* Add docstrings to text models

* Remove redundant computation

* Add unpack_inputs / keras_serializable

* Add modeling_tf_blip to doctests

* Add config classes for keras serialization

* Changes to allow model porting with pt-to-tf

* Quick fix to decoder head and test tweaks

* Revert an issue with masking the embeddings outputs

* Allow missing keys in some equivalence tests (for unused layers)

* Add tf-pt equivalence tests back in

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make fixup

* Refactor invert_attention_mask out into tf_utils

* Re-enable cross-tests on the PT side too

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023