diff --git a/README.md b/README.md index 4dda4a2ae132..ed7b628473b4 100644 --- a/README.md +++ b/README.md @@ -340,6 +340,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. +1. **[GeoV](https://huggingface.co/docs/transformers/main/model_doc/geov)** (from ) released with the paper []() by . 1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. diff --git a/README_es.md b/README_es.md index 5022db49326f..2f03e94c7d82 100644 --- a/README_es.md +++ b/README_es.md @@ -328,6 +328,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. +1. **[GeoV](https://huggingface.co/docs/transformers/main/model_doc/geov)** (from ) released with the paper []() by . 1. 
**[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. diff --git a/README_hd.md b/README_hd.md index a1cfe94d849f..959705902d0d 100644 --- a/README_hd.md +++ b/README_hd.md @@ -300,6 +300,7 @@ conda install -c huggingface transformers 1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (FLAVA: A फाउंडेशनल लैंग्वेज एंड विजन अलाइनमेंट मॉडल) (https://arxiv) साथ वाला पेपर .org/abs/2112.04482) अमनप्रीत सिंह, रोंगहांग हू, वेदानुज गोस्वामी, गुइल्यूम कुएरॉन, वोज्शिएक गालुबा, मार्कस रोहरबैक, और डौवे कीला द्वारा। 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (गूगल रिसर्च से) साथ वाला पेपर [FNet: मिक्सिंग टोकन विद फूरियर ट्रांसफॉर्म्स](https://arxiv.org /abs/2105.03824) जेम्स ली-थॉर्प, जोशुआ आइंस्ली, इल्या एकस्टीन, सैंटियागो ओंटानन द्वारा। 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (सीएमयू/गूगल ब्रेन से) साथ में कागज [फ़नल-ट्रांसफॉर्मर: कुशल भाषा प्रसंस्करण के लिए अनुक्रमिक अतिरेक को छानना](https://arxiv.org/abs/2006.03236) जिहांग दाई, गुओकुन लाई, यिमिंग यांग, क्वोक वी. ले ​​द्वारा रिहाई। +1. **[GeoV](https://huggingface.co/docs/transformers/main/model_doc/geov)** (from ) released with the paper []() by . 1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (KAIST से) साथ वाला पेपर [वर्टिकल कटडेप्थ के साथ मोनोकुलर डेप्थ एस्टीमेशन के लिए ग्लोबल-लोकल पाथ नेटवर्क्स](https:/ /arxiv.org/abs/2201.07436) डोयोन किम, वूंगह्युन गा, प्युंगवान आह, डोंगग्यू जू, सेहवान चुन, जुनमो किम द्वारा। 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (OpenAI से) साथ में दिया गया पेपर [जेनरेटिव प्री-ट्रेनिंग द्वारा भाषा की समझ में सुधार](https://blog .openai.com/language-unsupervised/) एलेक रैडफोर्ड, कार्तिक नरसिम्हन, टिम सालिमन्स और इल्या सुत्स्केवर द्वारा। diff --git a/README_ja.md b/README_ja.md index f36113a634d7..359bc5225d91 100644 --- a/README_ja.md +++ b/README_ja.md @@ -362,6 +362,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (Facebook AI から) Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela から公開された研究論文: [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) 1. 
**[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (Google Research から) James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon から公開された研究論文: [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (CMU/Google Brain から) Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le から公開された研究論文: [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) +1. **[GeoV](https://huggingface.co/docs/transformers/main/model_doc/geov)** (from ) released with the paper []() by . 1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (Microsoft Research から) Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. から公開された研究論文 [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (KAIST から) Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim から公開された研究論文: [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (OpenAI から) Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever から公開された研究論文: [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) diff --git a/README_ko.md b/README_ko.md index 9d26141de03a..d56205257984 100644 --- a/README_ko.md +++ b/README_ko.md @@ -277,6 +277,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. +1. **[GeoV](https://huggingface.co/docs/transformers/main/model_doc/geov)** (from ) released with the paper []() by . 1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. 
**[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. diff --git a/README_zh-hans.md b/README_zh-hans.md index cf9d6f69d590..8cee81be1b58 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -301,6 +301,7 @@ conda install -c huggingface transformers 1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (来自 Facebook AI) 伴随论文 [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) 由 Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela 发布。 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (来自 Google Research) 伴随论文 [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) 由 James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon 发布。 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (来自 CMU/Google Brain) 伴随论文 [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) 由 Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le 发布。 +1. **[GeoV](https://huggingface.co/docs/transformers/main/model_doc/geov)** (from ) released with the paper []() by . 1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (来自 Microsoft Research) 伴随论文 [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) 由 Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang 发布。 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (来自 KAIST) 伴随论文 [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) 由 Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim 发布。 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (来自 OpenAI) 伴随论文 [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) 由 Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index d25cbb6057c5..450597932e7a 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -313,6 +313,7 @@ conda install -c huggingface transformers 1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. +1. 
**[GeoV](https://huggingface.co/docs/transformers/main/model_doc/geov)** (from ) released with the paper []() by . 1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 26dbe5767e56..7a4d980ef5d2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -299,6 +299,8 @@ title: Funnel Transformer - local: model_doc/openai-gpt title: GPT + - local: model_doc/geov + title: GeoV - local: model_doc/gpt_neo title: GPT Neo - local: model_doc/gpt_neox diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 216425b29f62..38c702ef77fd 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -114,6 +114,7 @@ The documentation is organized into five sections: 1. **[FLAVA](model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. **[FNet](model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. +1. **[GeoV](model_doc/geov)** (from ) released with the paper []() by . 1. **[GIT](model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. 1. **[GLPN](model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. **[GPT](model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. @@ -313,6 +314,7 @@ Flax), PyTorch, and/or TensorFlow. 
| FLAVA | ❌ | ❌ | ✅ | ❌ | ❌ |
| FNet | ✅ | ✅ | ✅ | ❌ | ❌ |
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
+| GeoV | ✅ | ❌ | ✅ | ❌ | ❌ |
| GIT | ❌ | ❌ | ✅ | ❌ | ❌ |
| GLPN | ❌ | ❌ | ✅ | ❌ | ❌ |
| GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
diff --git a/docs/source/en/model_doc/geov.mdx b/docs/source/en/model_doc/geov.mdx
new file mode 100644
index 000000000000..e038dac9a965
--- /dev/null
+++ b/docs/source/en/model_doc/geov.mdx
@@ -0,0 +1,70 @@
+
+
+# GeoV
+
+## Overview
+
+The GeoV model was designed by Georges Harik and uses [Rotary Positional Embeddings with Relative distances (RoPER)](http://research.labml.ai/RoPER.html) by [Georges Harik](https://twitter.com/ghark) and [Varuna Jayasiri](https://twitter.com/vpj).
+
+[RoPER](http://research.labml.ai/RoPER.html), in addition to using relative positions in the attention score calculation via RoPE embeddings, adds relative positional information explicitly to the value embeddings. Specifically, it incorporates the relative positions of the tokens that are attended to. RoPER has given better performance on some algorithmic tasks, and seems comparable to RoPE in language modeling.
+
+The GeoV tokenizer uses a [SentencePiece](https://github.com/google/sentencepiece) [unigram language model](https://arxiv.org/abs/1804.10959) and tokenizes symbols, digits and newline characters separately, in order to achieve better performance on mathematical content and code.
+
+This model was contributed by [gharik](https://huggingface.co/gharik) and [vpj](https://huggingface.co/vpj).
+
+We have shared a 9B-parameter pre-trained model at [GeoV/GeoV-9b](https://huggingface.co/GeoV/GeoV-9b).
+We plan to release checkpoints roughly every 20b training tokens from here until around 300b tokens.
+We will also train smaller and larger versions, with the aim of making models of a broad range of sizes widely available.
+
+The original model code is available at [geov-ai/geov](https://github.com/geov-ai/geov).
+
+## Generation
+
+The `generate()` method can be used to generate text with the GeoV model.
+
+```python
+>>> from transformers import GeoVForCausalLM, GeoVTokenizer
+
+>>> model = GeoVForCausalLM.from_pretrained("GeoV/GeoV-9b")
+>>> tokenizer = GeoVTokenizer.from_pretrained("GeoV/GeoV-9b")
+
+>>> prompt = "In mathematics, topology is the study of"
+
+>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+>>> gen_tokens = model.generate(
+...     input_ids,
+...     do_sample=True,
+...     temperature=0.9,
+...     max_length=100,
+...
) +>>> gen_text = tokenizer.batch_decode(gen_tokens)[0] +``` + +## GeoVConfig + +[[autodoc]] GeoVConfig + +## GeoVTokenizer + +[[autodoc]] GeoVTokenizer + +## GeoVModel + +[[autodoc]] GeoVModel + - forward + +## GeoVForCausalLM + +[[autodoc]] GeoVForCausalLM + - forward diff --git a/docs/source/en/tasks/language_modeling.mdx b/docs/source/en/tasks/language_modeling.mdx index a5e4ae6bb38e..3ae08a288071 100644 --- a/docs/source/en/tasks/language_modeling.mdx +++ b/docs/source/en/tasks/language_modeling.mdx @@ -34,7 +34,7 @@ Choose one of the following architectures: -[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) +[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GeoV](../model_doc/geov), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), 
[Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 75ff4e8345a2..3658525a1405 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -300,6 +300,7 @@ "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"], "models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"], "models.gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"], + "models.geov": ["GEOV_PRETRAINED_CONFIG_ARCHIVE_MAP", "GeoVConfig"], "models.gpt_neox_japanese": ["GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXJapaneseConfig"], "models.gpt_sw3": [], "models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"], @@ -733,6 +734,7 @@ _import_structure["models.funnel"].append("FunnelTokenizerFast") _import_structure["models.gpt2"].append("GPT2TokenizerFast") _import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast") + _import_structure["models.geov"].append("GeoVTokenizer") _import_structure["models.gpt_neox_japanese"].append("GPTNeoXJapaneseTokenizer") _import_structure["models.herbert"].append("HerbertTokenizerFast") _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast") @@ -1659,6 +1661,15 @@ "GPTNeoXPreTrainedModel", ] ) + _import_structure["models.geov"].extend( + [ + "GEOV_PRETRAINED_MODEL_ARCHIVE_LIST", + "GeoVForCausalLM", + "GeoVLayer", + "GeoVModel", + "GeoVPreTrainedModel", + ] + ) _import_structure["models.gpt_neox_japanese"].extend( [ "GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3968,6 +3979,7 @@ from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig from .models.gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig + from .models.geov import GEOV_PRETRAINED_CONFIG_ARCHIVE_MAP, GeoVConfig from .models.gpt_neox_japanese import GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXJapaneseConfig from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig from .models.gptsan_japanese import ( @@ -4366,6 +4378,7 @@ from .models.funnel import FunnelTokenizerFast from .models.gpt2 import GPT2TokenizerFast from .models.gpt_neox import GPTNeoXTokenizerFast + from .models.geov import GeoVTokenizer from .models.gpt_neox_japanese import GPTNeoXJapaneseTokenizer from .models.herbert import HerbertTokenizerFast from .models.layoutlm import LayoutLMTokenizerFast @@ -5131,6 +5144,13 @@ GPTNeoXModel, GPTNeoXPreTrainedModel, ) + from .models.geov import ( + GEOV_PRETRAINED_MODEL_ARCHIVE_LIST, + GeoVForCausalLM, + GeoVLayer, + GeoVModel, + GeoVPreTrainedModel, + ) from .models.gpt_neox_japanese import ( GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST, GPTNeoXJapaneseForCausalLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index a086b879650a..5690ddaf63bd 100644 --- a/src/transformers/models/__init__.py +++ 
b/src/transformers/models/__init__.py @@ -84,6 +84,7 @@ gpt2, gpt_neo, gpt_neox, + geov, gpt_neox_japanese, gpt_sw3, gptj, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 03492c0c5aad..4ba3ebaf6cb8 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -93,6 +93,7 @@ ("gpt2", "GPT2Config"), ("gpt_neo", "GPTNeoConfig"), ("gpt_neox", "GPTNeoXConfig"), + ("geov", "GeoVConfig"), ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"), ("gptj", "GPTJConfig"), ("gptsan-japanese", "GPTSanJapaneseConfig"), @@ -274,6 +275,7 @@ ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neox", "GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("geov", "GEOV_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neox_japanese", "GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gptsan-japanese", "GPTSAN_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -456,6 +458,7 @@ ("gpt2", "OpenAI GPT-2"), ("gpt_neo", "GPT Neo"), ("gpt_neox", "GPT NeoX"), + ("geov", "GeoV"), ("gpt_neox_japanese", "GPT NeoX Japanese"), ("gptj", "GPT-J"), ("gptsan-japanese", "GPTSAN-japanese"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index c6eb89327941..5fc8e4e36520 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -91,6 +91,7 @@ ("gpt2", "GPT2Model"), ("gpt_neo", "GPTNeoModel"), ("gpt_neox", "GPTNeoXModel"), + ("geov", "GeoVModel"), ("gpt_neox_japanese", "GPTNeoXJapaneseModel"), ("gptj", "GPTJModel"), ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), @@ -295,6 +296,7 @@ ("gpt2", "GPT2LMHeadModel"), ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), + ("geov", "GeoVForCausalLM"), ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gptj", "GPTJForCausalLM"), ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), @@ -364,6 +366,7 @@ ("gpt2", "GPT2LMHeadModel"), ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), + ("geov", "GeoVForCausalLM"), ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gptj", "GPTJForCausalLM"), ("llama", "LlamaForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index b23b0c39d621..5212b9a49eff 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -158,6 +158,7 @@ ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("geov", ("GeoVTokenizer", None)), ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)), ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)), diff --git a/src/transformers/models/geov/__init__.py b/src/transformers/models/geov/__init__.py new file mode 100644 index 000000000000..49a2f70dbe54 --- /dev/null +++ b/src/transformers/models/geov/__init__.py @@ -0,0 +1,74 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable
+
+
+_import_structure = {"configuration_geov": ["GEOV_PRETRAINED_CONFIG_ARCHIVE_MAP", "GeoVConfig"]}
+
+try:
+    if not is_tokenizers_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tokenization_geov"] = ["GeoVTokenizer"]
+
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_geov"] = [
+        "GEOV_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "GeoVForCausalLM",
+        "GeoVLayer",
+        "GeoVModel",
+        "GeoVPreTrainedModel",
+    ]
+
+
+if TYPE_CHECKING:
+    from .configuration_geov import GEOV_PRETRAINED_CONFIG_ARCHIVE_MAP, GeoVConfig
+
+    try:
+        if not is_tokenizers_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .tokenization_geov import GeoVTokenizer
+
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_geov import (
+            GEOV_PRETRAINED_MODEL_ARCHIVE_LIST,
+            GeoVForCausalLM,
+            GeoVLayer,
+            GeoVModel,
+            GeoVPreTrainedModel,
+        )
+
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/geov/configuration_geov.py b/src/transformers/models/geov/configuration_geov.py
new file mode 100644
index 000000000000..e301ad42372b
--- /dev/null
+++ b/src/transformers/models/geov/configuration_geov.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+# Copyright 2023 Better Planet Investments and labml.ai team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" GeoV model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+GEOV_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "GeoV/GeoV-9b": "https://huggingface.co/GeoV/GeoV-9b/resolve/main/config.json",
+}
+
+
+class GeoVConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`GeoVModel`]. It is used to instantiate a
+    GeoV model according to the specified arguments, defining the model architecture.
Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the GeoV
+    [GeoV/GeoV-9b](https://huggingface.co/GeoV/GeoV-9b) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 65536):
+            Vocabulary size of the GeoV model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`GeoVModel`].
+        hidden_size (`int`, *optional*, defaults to 5120):
+            Dimension of the decoder layers.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 40):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        intermediate_size (`int`, *optional*, defaults to 20480):
+            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer decoder.
+        rotary_emb_base (`int`, *optional*, defaults to 10000):
+            Base for computing the rotary embedding frequencies.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-4):
+            The epsilon used by the layer normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        use_extra_biases_ffn (`bool`, *optional*, defaults to `False`):
+            Whether or not to have extra bias parameters in the final layer of FFN modules.
+        use_parallel_residual (`bool`, *optional*, defaults to `False`):
+            Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
+            speedup at large scales (e.g. 20B).
+
+    Example:
+
+    ```python
+    >>> from transformers import GeoVConfig, GeoVModel
+
+    >>> # Initializing a GeoV configuration
+    >>> configuration = GeoVConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = GeoVModel(configuration)  # doctest: +SKIP
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config  # doctest: +SKIP
+    ```"""
+    model_type = "geov"
+
+    def __init__(
+        self,
+        vocab_size=65_536,
+        hidden_size=5_120,
+        num_hidden_layers=32,
+        num_attention_heads=40,
+        intermediate_size=20_480,
+        layer_norm_eps=1e-4,
+        rotary_emb_base=10000,
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        use_extra_biases_ffn=False,
+        use_cache=True,
+        bos_token_id=0,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        use_parallel_residual=False,
+        **kwargs,
+    ):
+        super().__init__(
+            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.rotary_emb_base = rotary_emb_base
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.use_extra_biases_ffn = use_extra_biases_ffn
+        self.use_parallel_residual = use_parallel_residual
diff --git a/src/transformers/models/geov/modeling_geov.py b/src/transformers/models/geov/modeling_geov.py
new file mode 100644
index 000000000000..4f77e5b91200
--- /dev/null
+++ b/src/transformers/models/geov/modeling_geov.py
@@ -0,0 +1,698 @@
+# coding=utf-8
+# Copyright 2023 Better Planet Investments and labml.ai team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
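+#
+# A brief sketch of RoPER, per http://research.labml.ai/RoPER.html and the code below: as
+# in RoPE, queries and keys are rotated by their absolute positions so that attention
+# scores depend only on relative distances. GeoV additionally rotates the values before
+# attention and applies the inverse rotation to the attention output, which leaves the
+# relative positions of the attended tokens encoded in the value pathway.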
+""" PyTorch GeoV model.""" +import math +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from .configuration_geov import GeoVConfig +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import logging + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "GeoV/GeoV-9b" +_REAL_CHECKPOINT_FOR_DOC = "GeoV/GeoV-9b" +_CONFIG_FOR_DOC = "GeoVConfig" + +GEOV_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "GeoV/GeoV-9b", + # See all GeoV models at https://huggingface.co/models?filter=geov +] + + +# Copied (and modified) from transformers.models.gpt_neox.modeling_gpt_neox.RotaryEmbedding.forward +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + self.max_seq_len_cached = -1 + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos()[None, None, :, :].to(x.dtype) + self.sin_cached = emb.sin()[None, None, :, :].to(x.dtype) + return self.cos_cached.to(x.device), self.sin_cached.to(x.device) + + +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, cos, sin, position_ids): + """Apply positional embeddings""" + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +def apply_rotary_pos_emb_reverse(q, cos, sin, position_ids): + """Apply positional embeddings in reverse""" + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) - (rotate_half(q) * sin) + return q_embed + + +class GeoVAttention(nn.Module): + """ + Attention module + """ + + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + max_positions = config.max_position_embeddings + 
self.register_buffer("causal_mask", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))) + self.rotary_emb = RotaryEmbedding(self.head_size, base=config.rotary_emb_base) + self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.qkv(hidden_states) + query, key, value = torch.tensor_split(qkv, 3, dim=-1) + + # 'b l (h q) -> b h l q' + query = self._split_heads(query, self.num_attention_heads) + key = self._split_heads(key, self.num_attention_heads) + value = self._split_heads(value, self.num_attention_heads) + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + if has_layer_past: + seq_len += layer_past[0].shape[-2] + + cos, sin = self.rotary_emb(query, seq_len=seq_len) + query = apply_rotary_pos_emb(query, cos, sin, position_ids) + key = apply_rotary_pos_emb(key, cos, sin, position_ids) + value = apply_rotary_pos_emb(value, cos, sin, position_ids) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = apply_rotary_pos_emb_reverse(attn_output, cos, sin, position_ids) + + # Reshape outputs + attn_output = self._merge_heads(attn_output) + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + @classmethod + def _split_heads(cls, tensor, num_attention_heads): + """ + Splits hidden dim into num_attention_heads + """ + # tensor: [bs, seq_len, hidden_size] + new_shape = tensor.shape[:-1] + (num_attention_heads, tensor.shape[-1] // num_attention_heads) + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(new_shape) + # -> [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3) + return tensor + + @classmethod + def _merge_heads(cls, tensor): + """ + Merges heads + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(*tensor.shape[:2], tensor.shape[2] * tensor.shape[3]) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.shape + key_length = key.shape[-2] + + causal_mask = self.causal_mask[None, None, key_length - query_length : key_length, :key_length] + + attn_scores = torch.einsum("bhid,bhjd->bhij", query, key) / math.sqrt(attn_head_size) + + attn_scores.masked_fill_(causal_mask == 0, 
torch.finfo(attn_scores.dtype).min) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + + +class GeoVMLP(nn.Module): + """Position wise Feed-forward network""" + + def __init__(self, config: "GeoVConfig"): + super().__init__() + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense_2h_to_h = nn.Linear( + config.intermediate_size // 2, config.hidden_size, bias=config.use_extra_biases_ffn + ) + self.act = nn.GELU() + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + # Gated GELU + gate, pass_through = torch.tensor_split(hidden_states, 2, dim=-1) + gate = self.act(gate) + hidden_states = gate * pass_through + + hidden_states = self.dense_2h_to_h(hidden_states) + return hidden_states + + +class GeoVLayer(nn.Module): + """GeoV transformer layer""" + + def __init__(self, config): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = GeoVAttention(config) + self.mlp = GeoVMLP(config) + + # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + layer_past: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + ): + attention_layer_outputs = self.attention( + self.input_layernorm(hidden_states), + attention_mask=attention_mask, + position_ids=position_ids, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) + outputs = attention_layer_outputs[1:] + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + hidden_states = mlp_output + attn_output + + if use_cache: + outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights) + else: + outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) + + return outputs + + +GEOV_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`~GeoVConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+GEOV_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        past_key_values (`Tuple[Tuple[torch.FloatTensor]]` of length `n_layers`, with each tuple having 2 tensors of shape `(batch_size, n_heads, seq_len - 1, head_size)`, *optional*):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class GeoVPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+ """ + + config_class = GeoVConfig + base_model_prefix = "geov" + supports_gradient_checkpointing = True + _no_split_modules = ["GeoVLayer"] + + # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXPreTrainedModel._init_weights + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GeoVModel): + module.gradient_checkpointing = value + + +@add_start_docstrings( + "The bare GeoV Model transformer outputting raw hidden-states without any specific head on top.", + GEOV_START_DOCSTRING, +) +class GeoVModel(GeoVPreTrainedModel): + def __init__(self, config: "GeoVConfig"): + super().__init__(config) + self.config = config + + self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([GeoVLayer(config) for _ in range(config.num_hidden_layers)]) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.embed_in.to(torch.bfloat16) + self.layers.to(torch.bfloat16) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel.get_input_embeddings + def get_input_embeddings(self): + return self.embed_in + + # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embed_in = value + + @add_start_docstrings_to_model_forward(GEOV_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + # Copied (and modified) from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel.forward + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] 
+ else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * self.config.num_hidden_layers) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for layer_past + return module(*inputs, use_cache, None, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + position_ids, + head_mask[i], + ) + else: + outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_attentions = all_attentions + (outputs[2 if use_cache else 1],) + + # Cast the hidden state to final layer norm data type (this is the modification from GPTNeoX) + hidden_states = self.final_layer_norm(hidden_states.to(self.final_layer_norm.weight.dtype)) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +@add_start_docstrings( + """GeoV Model with a `language modeling` head on top for CLM fine-tuning.""", GEOV_START_DOCSTRING +) +class GeoVForCausalLM(GeoVPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"causal_mask", r"inv_freq"] + + def __init__(self, config: "GeoVConfig"): + super().__init__(config) + + self.geov = GeoVModel(config) + self.embed_out = nn.Linear(config.hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.embed_out + + # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.embed_out = new_embeddings + + @add_start_docstrings_to_model_forward(GEOV_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Copied (and modified) from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM.forward + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for 
computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, GeoVForCausalLM, GeoVConfig
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("GeoV/GeoV-9b")
+        >>> model = GeoVForCausalLM.from_pretrained("GeoV/GeoV-9b")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> prediction_logits = outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.geov(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        lm_logits = self.embed_out(hidden_states)
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shift_logits = lm_logits[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
+        input_shape = input_ids.shape
+
+        # cut decoder_input_ids if past is used
+        if past_key_values and past_key_values[0] is not None:
+            input_ids = input_ids[:, -1:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation: a cumulative sum over the
+            # mask gives each real token its 0-based position while ignoring left padding
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "past_key_values": past_key_values,
+        }
+
+    # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM._reorder_cache
+    def _reorder_cache(self, past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
diff --git a/src/transformers/models/geov/tokenization_geov.py b/src/transformers/models/geov/tokenization_geov.py
new file mode 100644
index 000000000000..6fdcbf44084e
--- /dev/null
+++ b/src/transformers/models/geov/tokenization_geov.py
@@ -0,0 +1,177 @@
+# coding=utf-8
+# Copyright 2023 Better Planet Investments and labml.ai team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for GeoV."""
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import SPIECE_UNDERLINE, logging
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "GeoV/GeoV-9b": "https://huggingface.co/GeoV/GeoV-9b/resolve/main/spiece.model",
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "GeoV-9b": 2048,
+}
+
+
+class GeoVTokenizer(PreTrainedTokenizer):
+    """
+    Construct a GeoV tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining.
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+
+        new_line_token_id (`int`, *optional*, defaults to `65_499`):
+            The token id of the newline character.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
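+
+    Example (an illustrative sketch, not a verified doctest; the exact pieces and ids depend on the released
+    `spiece.model`):
+
+    ```python
+    >>> from transformers import GeoVTokenizer
+
+    >>> tokenizer = GeoVTokenizer.from_pretrained("GeoV/GeoV-9b")
+    >>> # Newlines are handled outside of SentencePiece: each line is encoded separately and an
+    >>> # explicit "\n" token (mapped to `new_line_token_id`) is inserted between lines.
+    >>> tokenizer.tokenize("hello\nworld")  # doctest: +SKIP
+    ['▁hello', '\n', '▁world']
+    ```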
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + unk_token="", + new_line_token_id=65_499, + **kwargs, + ) -> None: + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + new_line_token_id=new_line_token_id, + **kwargs, + ) + self.vocab_file = vocab_file + self.new_line_token_id = new_line_token_id + + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + + @property + def vocab_size(self): + return len(self.sp_model) + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> List[str]: + """Tokenize a string.""" + ret = [] + split_text = text.splitlines() + for l in split_text: + rl = self.sp_model.encode(l, out_type=str) + ret.extend(rl) + ret.append("\n") + ret = ret[:-1] + return ret + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token == "\n": + return self.new_line_token_id + return self.sp_model.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index == self.new_line_token_id: + return "\n" + return self.sp_model.IdToPiece(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + spaces_between_special_tokens: bool = True, + **kwargs, + ) -> str: + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + + if skip_special_tokens: + filtered_tokens = [t for t in filtered_tokens if t not in self.all_special_ids] + + text = self.convert_tokens_to_string(filtered_tokens) + + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + save_directory = Path(save_directory) + if not save_directory.is_dir(): + raise ValueError(f"Vocabulary path ({save_directory}) should be a directory") + vocab_fn = VOCAB_FILES_NAMES["vocab_file"] + filename_prefix = f"{filename_prefix}-" if filename_prefix else "" + + vocab_file = save_directory / f"{filename_prefix}{vocab_fn}" + + with open(str(vocab_file), "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (str(vocab_file),) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 6b246a8463f7..b3a0a5de5395 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3219,6 +3219,37 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) 
+GEOV_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class GeoVForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GeoVLayer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GeoVModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GeoVPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index ec65edcc1dcf..235b8be69fb5 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -170,6 +170,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) +class GeoVTokenizer(metaclass=DummyObject): + _backends = ["tokenizers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tokenizers"]) + + class GPTNeoXJapaneseTokenizer(metaclass=DummyObject): _backends = ["tokenizers"] diff --git a/tests/models/geov/__init__.py b/tests/models/geov/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/geov/test_modeling_geov.py b/tests/models/geov/test_modeling_geov.py new file mode 100644 index 000000000000..b70fa5740a16 --- /dev/null +++ b/tests/models/geov/test_modeling_geov.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2023 The Better Planet Investments, labml.ai and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch GeoV model. 
""" + +import unittest + +from transformers import AutoTokenizer, GeoVConfig, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + +if is_torch_available(): + import torch + + from transformers import GeoVForCausalLM, GeoVModel + + +class GeoVModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=3, + num_attention_heads=4, + intermediate_size=32 * 4, + max_position_embeddings=512, + num_labels=3, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_labels = num_labels + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_labels = None + if self.use_labels: + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + + config = self.get_config() + + return config, input_ids, input_mask, token_labels + + def get_config(self): + return GeoVConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + max_position_embeddings=self.max_position_embeddings, + ) + + def prepare_config_and_inputs_for_decoder(self): + config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs() + + config.is_decoder = True + + return config, input_ids, input_mask, token_labels + + def create_and_check_model(self, config, input_ids, input_mask): + model = GeoVModel(config=config) + model.to(torch_device) + model.eval() + _ = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder(self, config, input_ids, input_mask): + config.add_cross_attention = True + model = GeoVModel(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_labels): + model = GeoVForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask): + config.is_decoder = True + model = GeoVForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first 
forward pass
+        outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
+        past_key_values = outputs.past_key_values
+
+        # create hypothetical multiple next tokens and extend next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+        # append to next input_ids and attention mask
+        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True)
+        output_from_no_past = output_from_no_past["hidden_states"][0]
+        output_from_past = model(
+            next_tokens,
+            attention_mask=next_attention_mask,
+            past_key_values=past_key_values,
+            output_hidden_states=True,
+        )["hidden_states"][0]
+
+        # select random slice
+        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+        # test that outputs are equal for slice
+        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, input_mask, token_labels = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+        return config, inputs_dict
+
+
+@require_torch
+class GeoVModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+    all_model_classes = (GeoVModel, GeoVForCausalLM) if is_torch_available() else ()
+    all_generative_model_classes = (GeoVForCausalLM,) if is_torch_available() else ()
+    test_pruning = False
+    test_missing_keys = False
+    test_model_parallel = False
+    test_head_masking = False
+    test_initialization = False
+    test_left_padding_compatibility = False
+
+    pipeline_model_mapping = {"text-generation": GeoVForCausalLM} if is_torch_available() else {}
+
+    def setUp(self):
+        self.model_tester = GeoVModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=GeoVConfig, hidden_size=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_model(self):
+        config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(config, input_ids, input_mask)
+
+    def test_model_as_decoder(self):
+        config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
+        self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
+
+    def test_model_as_decoder_with_default_input_mask(self):
+        # This regression test was failing with PyTorch < 1.3
+        config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
+
+        input_mask = None
+
+        self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
+
+    def test_decoder_model_past_large_inputs(self):
+        config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask)
+
+    def test_model_for_causal_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+
+    @unittest.skip(reason="Feed forward chunking is not implemented")
+    def test_feed_forward_chunking(self):
+        pass
+
+
+@require_torch
+class GeoVLanguageGenerationTest(unittest.TestCase):
+    @slow
+    def test_lm_generate_geov(self):
+        tokenizer = AutoTokenizer.from_pretrained("GeoV/GeoV-9b")
+        for checkpointing in [True, False]:
+            model = GeoVForCausalLM.from_pretrained("GeoV/GeoV-9b")
+
+            if checkpointing:
+                model.gradient_checkpointing_enable()
+            else:
+                model.gradient_checkpointing_disable()
+            model.to(torch_device)
+
+            inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
+            EXPECTED_OUTPUT = "My favorite food is pizza. I love pizza. I love pizza. I"
+
+            output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
+            output_str = tokenizer.batch_decode(output_ids)[0]
+
+            self.assertEqual(output_str, EXPECTED_OUTPUT)
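+
+    # A minimal additional check, sketched here rather than taken from the original test plan:
+    # greedy decoding is deterministic, so generation with and without the KV cache should
+    # agree exactly. Assumes the same "GeoV/GeoV-9b" hub checkpoint used above.
+    @slow
+    def test_greedy_generation_cache_consistency(self):
+        tokenizer = AutoTokenizer.from_pretrained("GeoV/GeoV-9b")
+        model = GeoVForCausalLM.from_pretrained("GeoV/GeoV-9b").to(torch_device)
+
+        inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
+
+        # `use_cache` toggles the past-key-values fast path exercised in GeoVModel.forward;
+        # both paths must produce identical greedy outputs
+        with_cache = model.generate(**inputs, do_sample=False, max_new_tokens=8, use_cache=True)
+        without_cache = model.generate(**inputs, do_sample=False, max_new_tokens=8, use_cache=False)
+
+        self.assertTrue(torch.equal(with_cache, without_cache))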