Conversation

@ekgren (Contributor) commented on Nov 14, 2022:

This PR adds the GPT-SW3 models and tokenizer to Hugging Face. The models are developed by AI Sweden and others. They are GPT models trained from scratch with the NeMo-Megatron framework and will initially range in size from 128M to 20B parameters. The models are multilingual, covering English, Swedish, Norwegian, Danish and Icelandic.

Fixes #20176

@ArthurZucker

@ekgren marked this pull request as a draft on November 14, 2022 at 14:04.
@ArthurZucker (Collaborator) commented:

Hey! Feel free to ping me if you need any pointers! :)
Seems like the history is a bit broken at this point; rebasing with a force push should help.

@JoeyOhman (Contributor) commented:

> Actually, it seems like the modeling code is exactly the same as for GPT2? In this case you can just add a correspondence ("gpt-sw3", "GPT2Model") in the auto-mappings, without needing to add a new model module.

Thank you for your feedback, we're happy to follow your lead on how to proceed! So, if we understand you correctly, we should remove modeling_gpt_sw3.py and configuration_gpt_sw3.py entirely?

> Yep, sorry for the late reply! Let's do the same as what was done with BertJapanese. I'll review again, sorry for not realising the `# Copied from` sooner 😅

Should we await further review, or simply get started on this? @sgugger @ArthurZucker

@sgugger (Collaborator) commented on Dec 9, 2022:

Yes, that would be easier. Just remove the model and config files and, in the auto mapping, use the GPT2 classes.

("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("glpn", "GLPNModel"),
("gpt-sw3", "GPT2Model"),
Review comment (Collaborator):

You will need to map all the other model with heads too. Also you get the TF and Flax models for free if you add the same things in modeling_tf_auto and modeling_flax_auto :-)
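
For reference, a hedged sketch of what those additional mappings could look like. The dict name `GPT_SW3_AUTO_ENTRIES` is purely illustrative and not part of the PR; the GPT2 head classes named here are the existing ones in transformers, and this is not the exact merged diff.

```python
from collections import OrderedDict

# Illustrative sketch only: each auto-mapping in modeling_auto.py gets a
# ("gpt-sw3", <GPT2 class>) entry, so the existing GPT2 implementations are
# reused for the base model and every head.
GPT_SW3_AUTO_ENTRIES = OrderedDict(
    [
        ("MODEL_MAPPING_NAMES", ("gpt-sw3", "GPT2Model")),
        ("MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", ("gpt-sw3", "GPT2LMHeadModel")),
        ("MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", ("gpt-sw3", "GPT2ForSequenceClassification")),
        ("MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES", ("gpt-sw3", "GPT2ForTokenClassification")),
    ]
)

# Adding the same entries in modeling_tf_auto.py / modeling_flax_auto.py
# (e.g. "TFGPT2LMHeadModel", "FlaxGPT2LMHeadModel") provides TF and Flax support.
```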

Reply (Contributor):

Thank you, hopefully fixed now!

@JoeyOhman (Contributor) commented:

Thank you again for your help, I hope we have now resolved all of your issues. Do you see anything else required from our side in this PR? @sgugger @ArthurZucker

@ArthurZucker (Collaborator) left a review:

Well, great work! It's super clean and LGTM; let's wait for @sgugger and then I think we can merge!

```python
    output_type=BaseModelOutputWithPastAndCrossAttentions,
    config_class=_CONFIG_FOR_DOC,
)
def forward(
```

Review comment (Collaborator):

Okay good job, exactly what needed to be done 😄

```python
    def test_vocab_size(self):
        self.assertEqual(self.get_tokenizer().vocab_size, 2_000)

    # TODO: these tests will differ with our 2 tokenizers, might be able to hard-code it for one
```

Review comment (Collaborator):

To remove? (the TODO)

```
    - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
      BPE-dropout.
```

Review comment (Collaborator):

Very small nit: it would be cool to have an example here of importing the tokenizer and tokenizing a Swedish sentence! 😉

Reply (Contributor):

Agree, added! :)
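
For reference, a minimal sketch of the kind of tokenizer example that was added; the checkpoint name and the Swedish sentence below are illustrative, not necessarily what ended up in the docstring.

```python
from transformers import GPTSw3Tokenizer

# Illustrative repo id; substitute one of the released GPT-SW3 tokenizer checkpoints.
tokenizer = GPTSw3Tokenizer.from_pretrained("AI-Sweden/gpt-sw3-126m")

text = "Träd är fina för att de är gröna"  # a Swedish example sentence
input_ids = tokenizer(text)["input_ids"]
print(input_ids)
print(tokenizer.decode(input_ids))
```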

@sgugger (Collaborator) left a review:

Thanks for bearing with us, it all looks good to me!

@ArthurZucker (Collaborator) commented:

One last nit, @JoeyOhman: could you add an example to the doc of loading one of the pretrained models you released? Like what is done with BertJapanese here. It would help people understand that they can use the GPT2 model with this tokenizer 😉
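
For reference, a sketch of the kind of doc example being requested: loading a released GPT-SW3 checkpoint with the new tokenizer and the plain GPT2/auto model classes. The repo id is illustrative, not a confirmed checkpoint name.

```python
from transformers import AutoModelForCausalLM, GPTSw3Tokenizer

# Illustrative repo id; substitute an actual released GPT-SW3 checkpoint.
model_id = "AI-Sweden/gpt-sw3-356m"

tokenizer = GPTSw3Tokenizer.from_pretrained(model_id)
# Via the auto-mappings added in this PR, "gpt-sw3" resolves to the GPT2 classes.
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Idag är det en fin dag", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0]))
```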

@sgugger merged commit 5f94855 into huggingface:main on Dec 12, 2022.
@ekgren deleted the add_gpt_sw3 branch on December 13, 2022 at 09:30.

@ArthurZucker (Collaborator) commented:

Hey @ekgren, could you add the correct checkpoints? They are probably private. See our CI failure here.

@ekgren (Contributor, Author) commented on Dec 14, 2022:

@sgugger @ArthurZucker Thank you for all the help and guidance! We have made all the tokenizers referred to in this PR public.

We encountered some internal issues with model sharing at the last minute; we are very sorry for that. Currently we are not allowed to share the model files publicly. However, we can share the tokenizer and would very much like for it to be included in Hugging Face, since those with private access to the models can then easily use the full HF ecosystem. We hope to be able to share the models fully publicly in the near future.

Hopefully our PR can still be included in the release now that the tests should pass.

@ArthurZucker (Collaborator) commented:

No problem, I was thinking about the tokenizer rather than the actual checkpoints! You were mostly adding a tokenizer so I don't really see an issue with this 😉 Thanks for the contribution!
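
For reference, a hedged sketch of the private-access workflow described above: the publicly shared tokenizer combined with model weights that only approved users can download. The repo id is illustrative, and `use_auth_token=True` simply reads the token stored by `huggingface-cli login`.

```python
from transformers import AutoModelForCausalLM, GPTSw3Tokenizer

# Illustrative repo id; at the time of this PR the model weights were private
# while the tokenizers had been made public.
repo_id = "AI-Sweden/gpt-sw3-1.3b"

# Users who have been granted access authenticate with the token stored by
# `huggingface-cli login` and can then use the full HF ecosystem as usual.
tokenizer = GPTSw3Tokenizer.from_pretrained(repo_id, use_auth_token=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, use_auth_token=True)
```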

mpierrau pushed a commit to mpierrau/transformers that referenced this pull request on Dec 15, 2022:
* Add templates for gpt-sw3

* Add templates for gpt-sw3

* Added sentencepiece tokenizer

* intermediate commit with many changes

* fixed conflicts

* Init commit for tokenization port

* Tokenization progress

* Remove fast tokenizer

* Clean up and rename spm.model -> spiece.model

* Remove TF -> PT conversion script template, Clean up Megatron -> PT script

* Optimize encode & decode performance

* added new attention

* added new attention

* attention for gpt-sw3 working

* attention good

* Cache is now working

* fixed attention mask so that it works with causal attention

* fixed badbmm bug for cpu and caching

* updated config with correct parameters

* Refactor and leave optimizations as separate functions to avoid breaking expected functionality

* Fix special tokens mapping for both tokenizers

* cleaning up of code and comments

* HF compatible attention outputs

* Tokenizer now passing tests, add documentation

* Update documentation

* reverted back to base implementation after checking that it is identical to pretrained model

* updated gpt-sw3 config

* updated conversion script

* aligned parameters with gpt-sw3 config

* changed default scale_attn_by_inverse_layer_idx to true

* removed flag from conversion script

* added temporary model path

* reverted back to functioning convert script

* small changes to default config

* updated tests for gpt-sw3

* make style, make quality, minor cleanup

* Change local paths to testing online repository

* Change name: GptSw3 -> GPTSw3

* Remove GPTSw3TokenizerFast references

* Use official model repository and add more model sizes

* Added reference to 6.7b model

* Add GPTSw3DoubleHeadsModel to IGNORE_NON_AUTO_CONFIGURED, like GPT2DoubleHeadsModel

* Remove pointers to non-existing TFGPTSw3

* Add GPTSw3 to docs/_toctree.yml

* Remove TF artifacts from GPTSw3 in __init__ files

* Update README:s with 'make fix-copies'

* Add 20b model to archive list

* Add documentation for GPT-Sw3

* Fix typo in documentation for GPT-Sw3

* Do 'make fix-copies' again after having updated docs

* Fix some typos in docs

* Update src/transformers/models/gpt_sw3/configuration_gpt_sw3.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/configuration_gpt_sw3.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/__init__.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/__init__.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/modeling_gpt_sw3.py

Co-authored-by: Arthur <[email protected]>

* Update tests/models/gpt_sw3/test_tokenization_gpt_sw3.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/modeling_gpt_sw3.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/modeling_gpt_sw3.py

Co-authored-by: Arthur <[email protected]>

* Resolve comments from PR feedback

* Resolve more comments from PR feedback, also set use_cache=True in convert script

* Add '# Copied from' comments for GPTSw3 modeling

* Set 'is_parallelizable = False'

* Remove '# Copied from' where code was modified and add 'with x->y' when appropriate

* Remove parallelize in mdx

* make style, make quality

* Update GPTSw3Config default values and corresponding documentation

* Update src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py

Co-authored-by: Sylvain Gugger <[email protected]>

* Update src/transformers/models/gpt_sw3/__init__.py

Co-authored-by: Sylvain Gugger <[email protected]>

* Clean up and protect GPTSw3Tokenizer imports with is_sentencepiece_available

* Make style, make quality

* Add dummy object for GPTSw3Tokenizer via 'make fix-copies'

* make fix-copies

* Remove GPTSw3 modeling classes

* make style, make quality

* Add GPTSw3 auto-mappings for other GPT2 heads

* Update docs/source/en/model_doc/gpt-sw3.mdx

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py

Co-authored-by: Arthur <[email protected]>

* Remove old TODO-comment

* Add example usage to GPTSw3Tokenizer docstring

* make style, make quality

* Add implementation details and example usage to gpt-sw3.mdx

Co-authored-by: JoeyOhman <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>