diff --git a/README.md b/README.md
index b11f5bc48dc8..9b0909f6cf84 100644
--- a/README.md
+++ b/README.md
@@ -399,6 +399,7 @@ Current number of checkpoints:
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_de.md b/README_de.md
index a78fd6b11b2e..a3f50c382a3c 100644
--- a/README_de.md
+++ b/README_de.md
@@ -395,6 +395,7 @@ Aktuelle Anzahl der Checkpoints:
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_es.md b/README_es.md
index bcfbaa4e81e6..19523cc7b371 100644
--- a/README_es.md
+++ b/README_es.md
@@ -372,6 +372,7 @@ Número actual de puntos de control:
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_fr.md b/README_fr.md
index 788d5344c212..2aaaf243570c 100644
--- a/README_fr.md
+++ b/README_fr.md
@@ -393,6 +393,7 @@ Nombre actuel de points de contrôle :
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (d'OpenAI) a été publié dans l'article [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) par Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (de l'Université de Beihang, UC Berkeley, Rutgers University, SEDD Company) a été publié dans l'article [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) par Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong et Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (de Salesforce) a été publié dans l'article [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) de Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (d'OpenAI) a été publié dans l'article [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) de Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (de Microsoft Research Asia) a été publié dans l'article [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) de Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (de Microsoft Research Asia) a été publié dans l'article [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) de Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_hd.md b/README_hd.md
index d4500d685002..b859c2931e78 100644
--- a/README_hd.md
+++ b/README_hd.md
@@ -346,6 +346,7 @@ conda install conda-forge::transformers
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (Salesforce से) Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi. द्वाराअनुसंधान पत्र [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) के साथ जारी किया गया
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_ja.md b/README_ja.md
index dd0ca695a890..af10255ac418 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -406,6 +406,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (OpenAI から) Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever から公開された研究論文: [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/)
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (Salesforce から) Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi. から公開された研究論文 [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500)
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (OpenAI から) Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever から公開された研究論文: [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf)
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (Microsoft Research Asia から) Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou から公開された研究論文: [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318)
diff --git a/README_ko.md b/README_ko.md
index 8e90082e7603..f471611a6fce 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -321,6 +321,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (OpenAI 에서) Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever 의 [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) 논문과 함께 발표했습니다.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (Salesforce 에서 제공)은 Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.의 [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500)논문과 함께 발표했습니다.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (OpenAI 에서) Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever 의 [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) 논문과 함께 발표했습니다.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (Microsoft Research Asia 에서) Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou 의 [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) 논문과 함께 발표했습니다.
diff --git a/README_pt-br.md b/README_pt-br.md
index 76f5e8ca08d2..379176fa42ee 100644
--- a/README_pt-br.md
+++ b/README_pt-br.md
@@ -404,6 +404,7 @@ Número atual de pontos de verificação:
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_ru.md b/README_ru.md
index 0f75eac533d2..ca8d27b4e880 100644
--- a/README_ru.md
+++ b/README_ru.md
@@ -394,6 +394,7 @@ conda install conda-forge::transformers
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_te.md b/README_te.md
index a629d02f4718..4651bb7f6389 100644
--- a/README_te.md
+++ b/README_te.md
@@ -396,6 +396,7 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్స్టా
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_vi.md b/README_vi.md
index 0f3129d8c029..67981890770a 100644
--- a/README_vi.md
+++ b/README_vi.md
@@ -395,6 +395,7 @@ Số lượng điểm kiểm tra hiện tại:
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (từ OpenAI) được phát hành với bài báo [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (từ Beihang University, UC Berkeley, Rutgers University, SEDD Company) được phát hành với bài báo [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (từ Salesforce) được phát hành với bài báo [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (từ OpenAI) được phát hành với bài báo [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (từ Microsoft Research Asia) được phát hành với bài báo [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (từ Microsoft Research Asia) được phát hành với bài báo [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index e5a5eb7c0b16..04d9c22d51b9 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -345,6 +345,7 @@ conda install conda-forge::transformers
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (来自 OpenAI) 伴随论文 [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) 由 Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever 发布。
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (来自 Salesforce) 伴随论文 [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) 由 Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi 发布。
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) 由 Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index ea934e1cb874..36f51ba5c12e 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -357,6 +357,7 @@ conda install conda-forge::transformers
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
+1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from AI21 Labs Ltd.) released with the paper [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) by Opher Lieber, Barak Lenz, Hofit Bata, Gal Cohen, Jhonathan Osin, Itay Dalmedigos, Erez Safahi, Shaked Meirom, Yonatan Belinkov, Shai Shalev-Shwartz, Omri Abend, Raz Alon, Tomer Asida, Amir Bergman, Roman Glozman, Michael Gokhman, Avshalom Manevich, Nir Ratner, Noam Rozen, Erez Shwartz, Mor Zusman, Yoav Shoham.
1. **[Jukebox](https://huggingface.co/docs/transformers/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever.
1. **[KOSMOS-2](https://huggingface.co/docs/transformers/model_doc/kosmos-2)** (from Microsoft Research Asia) released with the paper [Kosmos-2: Grounding Multimodal Large Language Models to the World](https://arxiv.org/abs/2306.14824) by Zhiliang Peng, Wenhui Wang, Li Dong, Yaru Hao, Shaohan Huang, Shuming Ma, Furu Wei.
1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 2591ebea1c91..52e7587fae7f 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -382,6 +382,8 @@
title: HerBERT
- local: model_doc/ibert
title: I-BERT
+ - local: model_doc/jamba
+ title: Jamba
- local: model_doc/jukebox
title: Jukebox
- local: model_doc/led
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 90cea85e077a..ea4eb92a38d7 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -164,6 +164,7 @@ Flax), PyTorch, and/or TensorFlow.
| [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ |
| [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
| [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |
+| [Jamba](model_doc/jamba) | ✅ | ❌ | ❌ |
| [Jukebox](model_doc/jukebox) | ✅ | ❌ | ❌ |
| [KOSMOS-2](model_doc/kosmos-2) | ✅ | ❌ | ❌ |
| [LayoutLM](model_doc/layoutlm) | ✅ | ✅ | ❌ |
diff --git a/docs/source/en/model_doc/jamba.md b/docs/source/en/model_doc/jamba.md
new file mode 100644
index 000000000000..d8de36771da2
--- /dev/null
+++ b/docs/source/en/model_doc/jamba.md
@@ -0,0 +1,122 @@
+
+
+# Jamba
+
+## Overview
+
+Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It is the first production-scale Mamba implementation, which opens up interesting research and application opportunities. While this initial experimentation shows encouraging gains, we expect these to be further enhanced with future optimizations and explorations.
+
+For full details of this model please read the [release blog post](https://www.ai21.com/blog/announcing-jamba).
+
+### Model Details
+
+Jamba is a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length and can fit up to 140K tokens on a single 80GB GPU.
+
+As depicted in the architecture diagram in the [release blog post](https://www.ai21.com/blog/announcing-jamba), Jamba features a blocks-and-layers approach that allows it to integrate the Transformer and Mamba architectures. Each Jamba block contains either an attention or a Mamba layer, followed by a multi-layer perceptron (MLP), producing an overall ratio of one Transformer layer out of every eight total layers.
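+
+To make the stated ratio concrete, here is a small illustrative sketch (not the library implementation; the layer count and the exact position of the attention layers are assumptions for illustration only):
+
+```python
+# Enumerate a hypothetical Jamba-style layer stack: one attention layer out of
+# every eight layers, with every layer followed by an MLP.
+num_layers = 32          # assumed depth, for illustration only
+attention_period = 8     # one attention layer per eight total layers
+
+for i in range(num_layers):
+    mixer = "attention" if i % attention_period == 0 else "mamba"
+    print(f"layer {i:2d}: {mixer} + MLP")
+```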
+
+
+
+## Usage
+
+### Prerequisites
+
+Jamba requires you to use `transformers` version 4.39.0 or higher:
+```bash
+pip install "transformers>=4.39.0"
+```
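+
+To verify that the installed version meets this requirement, you can run a quick check (a minimal sketch; `packaging` ships as a dependency of `transformers`):
+
+```python
+from packaging import version
+
+import transformers
+
+# Jamba support requires transformers >= 4.39.0
+assert version.parse(transformers.__version__) >= version.parse("4.39.0")
+```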
+
+In order to run optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`:
+```bash
+pip install mamba-ssm "causal-conv1d>=1.2.0"
+```
+You also need to have the model on a CUDA device.
+
+You can also run the model without the optimized Mamba kernels, but this is **not** recommended as it results in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model, as shown in the example below.
+
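+A minimal sketch of loading the model with the optimized kernels disabled, using the `use_mamba_kernels` flag described above:
+
+```python
+from transformers import AutoModelForCausalLM
+
+# Falls back to the pure-PyTorch Mamba path; expect noticeably slower generation.
+model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", use_mamba_kernels=False)
+```
+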
+### Run the model
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
+tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
+
+input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
+
+outputs = model.generate(input_ids, max_new_tokens=216)
+
+print(tokenizer.batch_decode(outputs))
+# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
+```
+
+
+### Loading the model in half precision
+
+The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you need to specify `torch_dtype`:
+
+```python
+from transformers import AutoModelForCausalLM
+import torch
+model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16)
+# you can also use torch_dtype=torch.float16
+```
+
+When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the attention blocks. In order to use it, you also need the model on a CUDA device. Since in this precision the model is too big to fit on a single 80GB GPU, you'll also need to parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
+```python
+from transformers import AutoModelForCausalLM
+import torch
+model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+ device_map="auto")
+```
+
+
+### Load the model in 8-bit
+
+**Using 8-bit precision, it is possible to fit up to 140K sequence lengths on a single 80GB GPU.** You can easily quantize the model to 8-bit using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index). To avoid degrading model quality, we recommend excluding the Mamba blocks from quantization:
+
+```python
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+import torch
+
+quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["mamba"])
+model = AutoModelForCausalLM.from_pretrained(
+ "ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", quantization_config=quantization_config
+)
+```
+
+
+## JambaConfig
+
+[[autodoc]] JambaConfig
+
+
+## JambaModel
+
+[[autodoc]] JambaModel
+ - forward
+
+
+## JambaForCausalLM
+
+[[autodoc]] JambaForCausalLM
+ - forward
+
+
+## JambaForSequenceClassification
+
+[[autodoc]] JambaForSequenceClassification
+ - forward
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 2a6c2e2b136c..08a3f0fbe126 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -49,6 +49,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
+* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
@@ -184,6 +185,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
+* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
diff --git a/docs/source/en/tasks/language_modeling.md b/docs/source/en/tasks/language_modeling.md
index e8bbb7925365..c51b45528ac9 100644
--- a/docs/source/en/tasks/language_modeling.md
+++ b/docs/source/en/tasks/language_modeling.md
@@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:
-[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OLMo](../model_doc/olmo), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [RecurrentGemma](../model_doc/recurrent_gemma), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
+[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [Jamba](../model_doc/jamba), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OLMo](../model_doc/olmo), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [RecurrentGemma](../model_doc/recurrent_gemma), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
diff --git a/docs/source/en/tasks/sequence_classification.md b/docs/source/en/tasks/sequence_classification.md
index 572bc6a8b824..55f05e0956b5 100644
--- a/docs/source/en/tasks/sequence_classification.md
+++ b/docs/source/en/tasks/sequence_classification.md
@@ -33,7 +33,7 @@ The task illustrated in this tutorial is supported by the following model archit
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [Gemma](../model_doc/gemma), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [Gemma](../model_doc/gemma), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [Jamba](../model_doc/jamba), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 5fe4c1642d81..333d8feebbdb 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -20,6 +20,7 @@
__version__ = "4.40.0.dev0"
+
from typing import TYPE_CHECKING
# Check the dependencies satisfy the minimal versions required.
@@ -517,6 +518,7 @@
"InstructBlipQFormerConfig",
"InstructBlipVisionConfig",
],
+ "models.jamba": ["JambaConfig"],
"models.jukebox": [
"JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP",
"JukeboxConfig",
@@ -1473,6 +1475,7 @@
"AlignVisionModel",
]
)
+
_import_structure["models.altclip"].extend(
[
"ALTCLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2480,6 +2483,14 @@
"InstructBlipVisionModel",
]
)
+ _import_structure["models.jamba"].extend(
+ [
+ "JambaForCausalLM",
+ "JambaForSequenceClassification",
+ "JambaModel",
+ "JambaPreTrainedModel",
+ ]
+ )
_import_structure["models.jukebox"].extend(
[
"JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -5439,6 +5450,7 @@
InstructBlipQFormerConfig,
InstructBlipVisionConfig,
)
+ from .models.jamba import JambaConfig
from .models.jukebox import (
JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP,
JukeboxConfig,
@@ -7213,6 +7225,12 @@
InstructBlipQFormerModel,
InstructBlipVisionModel,
)
+ from .models.jamba import (
+ JambaForCausalLM,
+ JambaForSequenceClassification,
+ JambaModel,
+ JambaPreTrainedModel,
+ )
from .models.jukebox import (
JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST,
JukeboxModel,
@@ -7852,8 +7870,6 @@
SamModel,
SamPreTrainedModel,
)
-
- # PyTorch model imports
from .models.seamless_m4t import (
SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST,
SeamlessM4TCodeHifiGan,
diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 735431fe6fec..6bd55c5f6b51 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -18,6 +18,8 @@
import torch
+from ..cache_utils import DynamicCache
+
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
@@ -371,7 +373,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
else:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
- else:
+ elif isinstance(past_key_values, DynamicCache):
+ for idx in range(len(past_key_values.key_cache)):
+ if past_key_values.value_cache[idx].shape[-1] != 0:
+ past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
+ past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]
+
+ elif past_key_values is not None:
for idx in range(len(past_key_values)):
new_past.append(
(
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 36e62794a435..002cea9d73ca 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -598,7 +598,11 @@ def _expand_inputs_for_generation(
def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
- if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ ):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
@@ -2094,7 +2098,8 @@ def _contrastive_search(
# Replicates the new past_key_values to match the `top_k` candidates
new_key_values = []
- for layer in model_kwargs["past_key_values"]:
+ past = model_kwargs["past_key_values"]
+ for layer in past:
items = []
# item is either the key or the value matrix
for item in layer:
@@ -2103,7 +2108,13 @@ def _contrastive_search(
else:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(tuple(items))
- model_kwargs["past_key_values"] = tuple(new_key_values)
+ if not isinstance(past, DynamicCache):
+ past = tuple(new_key_values)
+ else:
+ for layer_idx in range(len(new_key_values)):
+ past.key_cache[layer_idx] = new_key_values[layer_idx][0]
+ past.value_cache[layer_idx] = new_key_values[layer_idx][1]
+ model_kwargs["past_key_values"] = past
if sequential:
all_outputs = []
@@ -2178,16 +2189,22 @@ def _contrastive_search(
else:
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
- new_key_values = ()
+ new_key_values = []
for layer in next_past_key_values:
- items = ()
+ items = []
# item is either the key or the value matrix
for item in layer:
item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz]
- items += (item,)
- new_key_values += (items,)
- next_past_key_values = new_key_values
+ items += [item]
+ new_key_values += [items]
+
+ if not isinstance(next_past_key_values, DynamicCache):
+ next_past_key_values = tuple(new_key_values)
+ else:
+ for layer_idx in range(len(new_key_values)):
+ next_past_key_values.key_cache[layer_idx] = new_key_values[layer_idx][0]
+ next_past_key_values.value_cache[layer_idx] = new_key_values[layer_idx][1]
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
@@ -3127,6 +3144,7 @@ def _beam_search(
"transo_xl",
"xlnet",
"cpm",
+ "jamba",
]
):
raise RuntimeError(
@@ -4645,21 +4663,22 @@ def _assisted_decoding(
# we use this forward pass to also pick the subsequent logits in the original model.
# 2.1. Prepare the model inputs
- candidate_kwargs = copy.copy(model_kwargs)
- candidate_kwargs = _prepare_attention_mask(
- candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
+ model_kwargs = _prepare_attention_mask(
+ model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
- candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
- if "cache_position" in candidate_kwargs:
- candidate_kwargs["cache_position"] = torch.cat(
+ model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1])
+ if "cache_position" in model_kwargs:
+ model_kwargs["cache_position"] = torch.cat(
(
- candidate_kwargs["cache_position"],
+ model_kwargs["cache_position"],
torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
),
dim=0,
)
- model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
+ model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **model_kwargs)
+ if "num_logits_to_keep" in model_inputs:
+ model_inputs["num_logits_to_keep"] = candidate_length + 1
# 2.2. Run a forward pass on the candidate sequence
outputs = self(
@@ -4985,7 +5004,7 @@ def _split_model_inputs(
# ModelOutput object.
# bool should not be split but replicated for each split
bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
- keys_to_ignore = ["cache_position", "encoder_outputs"]
+ keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
# we split the tensors and tuples of tensors
@@ -5001,6 +5020,11 @@ def _split_model_inputs(
data_split_list = [
{**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
]
+ # num_logits_to_keep should be replicated for each split, similar to bool values
+ if "num_logits_to_keep" in model_input:
+ data_split_list = [
+ {**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list
+ ]
# Convert each dictionary in the list to an object of the inferred class
split_model_inputs: List[Union[ModelOutput, Dict]] = [
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 5b92ab6a7350..50c96c837080 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -115,6 +115,7 @@
imagegpt,
informer,
instructblip,
+ jamba,
jukebox,
kosmos2,
layoutlm,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 1cdab274d0df..d2ea6b7682d4 100755
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -129,6 +129,7 @@
("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
+ ("jamba", "JambaConfig"),
("jukebox", "JukeboxConfig"),
("kosmos-2", "Kosmos2Config"),
("layoutlm", "LayoutLMConfig"),
@@ -397,6 +398,7 @@
("imagegpt", "ImageGPT"),
("informer", "Informer"),
("instructblip", "InstructBLIP"),
+ ("jamba", "Jamba"),
("jukebox", "Jukebox"),
("kosmos-2", "KOSMOS-2"),
("layoutlm", "LayoutLM"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 66e7a0b0fa8b..85b818fff553 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -123,6 +123,7 @@
("idefics2", "Idefics2Model"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
+ ("jamba", "JambaModel"),
("jukebox", "JukeboxModel"),
("kosmos-2", "Kosmos2Model"),
("layoutlm", "LayoutLMModel"),
@@ -451,6 +452,7 @@
("gpt_neox", "GPTNeoXForCausalLM"),
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
("gptj", "GPTJForCausalLM"),
+ ("jamba", "JambaForCausalLM"),
("llama", "LlamaForCausalLM"),
("mamba", "MambaForCausalLM"),
("marian", "MarianForCausalLM"),
@@ -851,6 +853,7 @@
("gpt_neox", "GPTNeoXForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
+ ("jamba", "JambaForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 8822ef5e1cb8..64dc057a5f7f 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -203,6 +203,13 @@
("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "jamba",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
("jukebox", ("JukeboxTokenizer", None)),
(
"kosmos-2",
diff --git a/src/transformers/models/jamba/__init__.py b/src/transformers/models/jamba/__init__.py
new file mode 100644
index 000000000000..f6b7c2137b20
--- /dev/null
+++ b/src/transformers/models/jamba/__init__.py
@@ -0,0 +1,58 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_jamba": ["JambaConfig"],
+}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_jamba"] = [
+ "JambaForCausalLM",
+ "JambaForSequenceClassification",
+ "JambaModel",
+ "JambaPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_jamba import JambaConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_jamba import (
+ JambaForCausalLM,
+ JambaForSequenceClassification,
+ JambaModel,
+ JambaPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/jamba/configuration_jamba.py b/src/transformers/models/jamba/configuration_jamba.py
new file mode 100644
index 000000000000..de9cd378bdc1
--- /dev/null
+++ b/src/transformers/models/jamba/configuration_jamba.py
@@ -0,0 +1,223 @@
+# coding=utf-8
+# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Jamba model configuration"""
+import math
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class JambaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
+ Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Jamba-v0.1 model.
+
+ [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 65536):
+ Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`JambaModel`]
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if
+            the model has an output word embedding layer.
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 14336):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+            integer value, only the last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
+            logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
+            sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
+            significantly.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary load balancing loss.
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+ The aux loss factor for the total loss.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ sliding_window (`int`, *optional*):
+ Sliding window attention window size. If not specified, will default to `None`.
+ max_position_embeddings (`int`, *optional*, defaults to 262144):
+            This value doesn't have any real effect; it is the maximum sequence length that this model is intended to
+            be used with. The model can be used with longer sequences, but performance may degrade.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
+            The number of experts to route per token; can also be interpreted as the `top-k` routing
+            parameter
+ num_experts (`int`, *optional*, defaults to 16):
+ Number of experts per Sparse MLP layer.
+ expert_layer_period (`int`, *optional*, defaults to 2):
+ Once in this many layers, we will have an expert layer
+ expert_layer_offset (`int`, *optional*, defaults to 1):
+ The first layer index that contains an expert mlp layer
+ attn_layer_period (`int`, *optional*, defaults to 8):
+ Once in this many layers, we will have a vanilla attention layer
+ attn_layer_offset (`int`, *optional*, defaults to 4):
+            The first layer index that contains a vanilla attention layer
+ use_mamba_kernels (`bool`, *optional*, defaults to `True`):
+ Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
+ `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
+ `True` and kernels are not available
+ mamba_d_state (`int`, *optional*, defaults to 16):
+            The dimension of the mamba state space latents
+ mamba_d_conv (`int`, *optional*, defaults to 4):
+ The size of the mamba convolution kernel
+ mamba_expand (`int`, *optional*, defaults to 2):
+ Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
+ mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+            Rank of the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
+ mamba_conv_bias (`bool`, *optional*, defaults to `True`):
+ Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`):
+ Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
+
+ """
+
+ model_type = "jamba"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=65536,
+ tie_word_embeddings=False,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ num_logits_to_keep=1,
+ output_router_logits=False,
+ router_aux_loss_coef=0.001,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ sliding_window=None,
+ max_position_embeddings=262144,
+ attention_dropout=0.0,
+ num_experts_per_tok=2,
+ num_experts=16,
+ expert_layer_period=2,
+ expert_layer_offset=1,
+ attn_layer_period=8,
+ attn_layer_offset=4,
+ use_mamba_kernels=True,
+ mamba_d_state=16,
+ mamba_d_conv=4,
+ mamba_expand=2,
+ mamba_dt_rank="auto",
+ mamba_conv_bias=True,
+ mamba_proj_bias=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.tie_word_embeddings = tie_word_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+ self.max_position_embeddings = max_position_embeddings
+ self.attention_dropout = attention_dropout
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+
+ self.use_cache = use_cache
+ self.num_logits_to_keep = num_logits_to_keep
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_experts = num_experts
+ self.expert_layer_period = expert_layer_period
+ self.expert_layer_offset = expert_layer_offset
+ self.attn_layer_period = attn_layer_period
+ self.attn_layer_offset = attn_layer_offset
+
+ self.use_mamba_kernels = use_mamba_kernels
+ self.mamba_d_state = mamba_d_state
+ self.mamba_d_conv = mamba_d_conv
+ self.mamba_expand = mamba_expand
+ self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
+ self.mamba_conv_bias = mamba_conv_bias
+ self.mamba_proj_bias = mamba_proj_bias
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ @property
+ def layers_block_type(self):
+ return [
+ "attention" if i % self.attn_layer_period == self.attn_layer_offset else "mamba"
+ for i in range(self.num_hidden_layers)
+ ]
+
+ @property
+ def layers_num_experts(self):
+ return [
+ self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
+ for i in range(self.num_hidden_layers)
+ ]
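+
+
+# Example (illustrative sketch, not exercised anywhere in this patch): with the default
+# `attn_layer_period=8`, `attn_layer_offset=4`, `expert_layer_period=2` and `expert_layer_offset=1`,
+# the two properties above lay out the hybrid stack as follows for a hypothetical 8-layer config:
+#
+#     config = JambaConfig(num_hidden_layers=8)
+#     config.layers_block_type   # ['mamba', 'mamba', 'mamba', 'mamba', 'attention', 'mamba', 'mamba', 'mamba']
+#     config.layers_num_experts  # [1, 16, 1, 16, 1, 16, 1, 16]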
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
new file mode 100755
index 000000000000..9780d95d4ee3
--- /dev/null
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -0,0 +1,1882 @@
+# coding=utf-8
+# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Jamba model."""
+import inspect
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import DynamicCache # we need __iter__ and __len__ of pkv
+from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+)
+from ...modeling_outputs import (
+ MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from ...utils.import_utils import (
+ is_causal_conv1d_available,
+ is_flash_attn_2_available,
+ is_mamba_ssm_available,
+)
+from .configuration_jamba import JambaConfig
+
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+
+
+if is_mamba_ssm_available():
+ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+else:
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
+
+if is_causal_conv1d_available():
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+else:
+ causal_conv1d_update, causal_conv1d_fn = None, None
+
+is_fast_path_available = all(
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+)
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "JambaConfig"
+
+
+# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func with gate->router
+def load_balancing_loss_func(
+ router_logits: torch.Tensor,
+    num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> float:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+        router_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
+ Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention_mask used in the forward function, of
+            shape [batch_size X sequence_length] if not None.
+ num_experts (`int`, *optional*):
+ Number of experts
+
+ Returns:
+ The auxiliary loss.
+ """
+ if router_logits is None or not isinstance(router_logits, tuple):
+ return 0
+
+ if isinstance(router_logits, tuple):
+ compute_device = router_logits[0].device
+ concatenated_router_logits = torch.cat(
+ [layer_router.to(compute_device) for layer_router in router_logits], dim=0
+ )
+
+ routing_weights = torch.nn.functional.softmax(concatenated_router_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+        # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_router_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+        # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
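+    # The returned value is the Switch Transformers auxiliary loss, num_experts * sum_e(f_e * P_e), where f_e is
+    # the fraction of top-k assignments routed to expert e (`tokens_per_expert`) and P_e is the mean router
+    # probability assigned to expert e (`router_prob_per_expert`); see eqs. (4)-(6) of the paper.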
+ return overall_loss * num_experts
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba
+class JambaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ JambaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
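+        # y = weight * x / sqrt(mean(x^2, dim=-1) + eps); the statistics are computed in float32 for stability
+        # and the result is cast back to the input dtype.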
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
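+    # With the Jamba defaults (32 attention heads, 8 key/value heads), n_rep = 4, so a (batch, 8, seq_len, head_dim)
+    # key/value tensor becomes (batch, 32, seq_len, head_dim) before the attention matmul.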
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class HybridMambaAttentionDynamicCache(DynamicCache):
+ """
+ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
+ (which has a constant shape regardless of seq_len).
+
+ This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
+    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors.
+    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
+ while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
+ For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
+ while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
+ and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
+ """
+
+ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
+ self.dtype = dtype
+ self.layers_block_type = config.layers_block_type
+ self.has_previous_state = False # only used by mamba
+ intermediate_size = config.mamba_expand * config.hidden_size
+ ssm_state_size = config.mamba_d_state
+ conv_kernel_size = config.mamba_d_conv
+ self.conv_states = []
+ self.ssm_states = []
+ for i in range(config.num_hidden_layers):
+ if self.layers_block_type[i] == "mamba":
+ self.conv_states += [
+ torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+ ]
+ self.ssm_states += [
+ torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+ ]
+ else:
+ self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
+ self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
+
+ self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+ self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the cache
+ if self.key_cache[layer_idx].shape[-1] == 0:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
+
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Reorders the cache for beam search, given the selected beam indices."""
+ for layer_idx in range(len(self.key_cache)):
+ device = self.key_cache[layer_idx].device
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+ device = self.value_cache[layer_idx].device
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+ device = self.conv_states[layer_idx].device
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
+ device = self.ssm_states[layer_idx].device
+ self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
+
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+ raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+
+ @classmethod
+ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+ raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
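+
+
+# Example (illustrative sketch, not part of the modeling code): constructing the hybrid cache for a hypothetical
+# small config and inspecting the per-layer state shapes. The config values below are assumptions for the example.
+#
+#     config = JambaConfig(num_hidden_layers=4, hidden_size=64)   # all 4 layers are mamba with the default periods
+#     cache = HybridMambaAttentionDynamicCache(config, batch_size=2, dtype=torch.float32)
+#     cache.conv_states[0].shape   # torch.Size([2, 128, 4])  -> (batch, mamba_expand * hidden_size, mamba_d_conv)
+#     cache.ssm_states[0].shape    # torch.Size([2, 128, 16]) -> (batch, mamba_expand * hidden_size, mamba_d_state)
+#     cache.key_cache[0].shape     # torch.Size([2, 0])       -> attention K/V caches start out empty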
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba
+class JambaAttention(nn.Module):
+ """
+    Multi-headed attention from the 'Attention Is All You Need' paper. Modified to use sliding window attention, as in
+    Longformer and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: JambaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba
+class JambaFlashAttention2(JambaAttention):
+ """
+    Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass, where it needs to correctly call the public API
+    of flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = cache_position[-1]
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+
+ if not _flash_supports_window_size:
+ logger.warning_once(
+                "The current flash attention version does not support sliding window attention. For a more memory efficient implementation,"
+                " make sure to upgrade the flash-attn library."
+ )
+
+ if past_key_value is not None:
+            # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = cache_position[0] > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32, this might be related to"
+                f" the fact that you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ use_sliding_windows=use_sliding_windows,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ if not use_sliding_windows:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ if not use_sliding_windows:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+
+ # On the first iteration we need to properly re-create the padding mask
+ # by slicing it on the proper place
+ if kv_seq_len != attention_mask.shape[-1]:
+ attention_mask_num_tokens = attention_mask.shape[-1]
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
+
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
+class JambaSdpaAttention(JambaAttention):
+ """
+    Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `JambaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
+    the SDPA API.
+ """
+
+ # Adapted from JambaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "JambaModel is using JambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+JAMBA_ATTENTION_CLASSES = {
+ "eager": JambaAttention,
+ "flash_attention_2": JambaFlashAttention2,
+ "sdpa": JambaSdpaAttention,
+}
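+# The concrete attention class is presumably selected from `config._attn_implementation` when the decoder layers are
+# built, following the same dispatch pattern as other models in the library (e.g. passing
+# `attn_implementation="flash_attention_2"` to `from_pretrained`).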
+
+
+# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
+class JambaMambaMixer(nn.Module):
+ """
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+ and is why Mamba is called **selective** state spaces)
+ """
+
+ def __init__(self, config: JambaConfig, layer_idx):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.ssm_state_size = config.mamba_d_state
+ self.conv_kernel_size = config.mamba_d_conv
+ self.intermediate_size = config.mamba_expand * config.hidden_size
+ self.time_step_rank = config.mamba_dt_rank
+ self.use_conv_bias = config.mamba_conv_bias
+ self.use_bias = config.mamba_proj_bias
+ self.conv1d = nn.Conv1d(
+ in_channels=self.intermediate_size,
+ out_channels=self.intermediate_size,
+ bias=self.use_conv_bias,
+ kernel_size=self.conv_kernel_size,
+ groups=self.intermediate_size,
+ padding=self.conv_kernel_size - 1,
+ )
+
+ self.activation = config.hidden_act
+ self.act = ACT2FN[config.hidden_act]
+
+ self.use_fast_kernels = config.use_mamba_kernels
+
+ # projection of the input hidden states
+ self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
+        # selective projection used to make dt, B and C input dependent
+ self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+ # time step projection (discretization)
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
+
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+ A = A.expand(self.intermediate_size, -1).contiguous()
+
+ self.A_log = nn.Parameter(torch.log(A))
+ self.D = nn.Parameter(torch.ones(self.intermediate_size))
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
+
+ self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
+ self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
+ self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
+
+ if not is_fast_path_available:
+ logger.warning_once(
+                "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                " is None. To install, follow https://github.com/state-spaces/mamba/#installation and"
+                " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config."
+ )
+
+ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
+ batch_size, seq_len, _ = hidden_states.shape
+ use_precomputed_states = (
+ cache_params is not None
+ and cache_params.has_previous_state
+ and seq_len == 1
+ and cache_params.conv_states[self.layer_idx].shape[0]
+ == cache_params.ssm_states[self.layer_idx].shape[0]
+ == batch_size
+ )
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
+
+        # We can't use `mamba_inner_fn` even if in training and without cache params because we have the
+        # inner layernorms, which aren't supported by this fused kernel
+ hidden_states, gate = projected_states.chunk(2, dim=1)
+
+ # 2. Convolution sequence transformation
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
+ if use_precomputed_states:
+ hidden_states = causal_conv1d_update(
+ hidden_states.squeeze(-1),
+ cache_params.conv_states[self.layer_idx],
+ conv_weights,
+ self.conv1d.bias,
+ self.activation,
+ )
+ hidden_states = hidden_states.unsqueeze(-1)
+ else:
+ if cache_params is not None:
+ conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
+ cache_params.conv_states[self.layer_idx].copy_(conv_states)
+ hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
+
+ # 3. State Space Model sequence transformation
+ # 3.a. input varying initialization of time_step, B and C
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+ time_step, B, C = torch.split(
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+ )
+
+ time_step = self.dt_layernorm(time_step)
+ B = self.b_layernorm(B)
+ C = self.c_layernorm(C)
+
+ # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
+ # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
+ # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
+        # linear layers, and requires calling the forward pass directly.
+ # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
+ time_proj_bias = self.dt_proj.bias
+ self.dt_proj.bias = None
+ discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
+ self.dt_proj.bias = time_proj_bias
+
+ A = -torch.exp(self.A_log.float())
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+ time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
+ if use_precomputed_states:
+ scan_outputs = selective_state_update(
+ cache_params.ssm_states[self.layer_idx],
+ hidden_states[..., 0],
+ discrete_time_step[..., 0],
+ A,
+ B[:, 0],
+ C[:, 0],
+ self.D,
+ gate[..., 0],
+ time_proj_bias,
+ dt_softplus=True,
+ ).unsqueeze(-1)
+ else:
+ scan_outputs, ssm_state = selective_scan_fn(
+ hidden_states,
+ discrete_time_step,
+ A,
+ B.transpose(1, 2),
+ C.transpose(1, 2),
+ self.D.float(),
+ gate,
+ time_proj_bias,
+ delta_softplus=True,
+ return_last_state=True,
+ )
+ if ssm_state is not None and cache_params is not None:
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
+
+ return contextualized_states
+
+ # fmt: off
+ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
+ batch_size, seq_len, _ = input_states.shape
+ dtype = input_states.dtype
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
+ hidden_states, gate = projected_states.chunk(2, dim=1)
+
+        use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
+ # 2. Convolution sequence transformation
+ if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
+ if self.training:
+ # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+ else:
+ ssm_state = cache_params.ssm_states[self.layer_idx]
+
+ if cache_params.has_previous_state and seq_len == 1 and \
+ cache_params.conv_states[self.layer_idx].shape[0] == batch_size:
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
+ conv_state[:, :, -1] = hidden_states[:, :, 0]
+ cache_params.conv_states[self.layer_idx] = conv_state
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+ if self.use_conv_bias:
+ hidden_states += self.conv1d.bias
+ hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
+ else:
+ conv_state = nn.functional.pad(
+ hidden_states,
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
+ )
+ cache_params.conv_states[self.layer_idx] = conv_state
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
+ else:
+ ssm_state = torch.zeros(
+ (batch_size, self.intermediate_size, self.ssm_state_size),
+ device=hidden_states.device, dtype=dtype
+ )
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
+
+ # 3. State Space Model sequence transformation
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+ time_step, B, C = torch.split(
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+ )
+
+ time_step = self.dt_layernorm(time_step)
+ B = self.b_layernorm(B)
+ C = self.c_layernorm(C)
+
+ discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
+
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+ A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
+ discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
+ discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
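+ # The loop below implements the discretized state-space recurrence one timestep at a time:
+ # ssm_state_t = exp(A * dt_t) * ssm_state_{t-1} + dt_t * B_t * x_t
+ # y_t = C_t · ssm_state_t (the skip connection D * x_t and the gate are applied after the loop)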
+ scan_outputs = []
+ for i in range(seq_len):
+ ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state_size]
+ scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
+ scan_outputs.append(scan_output[:, :, 0])
+ scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
+ scan_output = (scan_output * self.act(gate))
+
+ if use_cache:
+ cache_params.ssm_states[self.layer_idx] = ssm_state
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
+ return contextualized_states
+ # fmt: on
+
+ def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
+ if self.use_fast_kernels:
+ if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
+ raise ValueError(
+ "Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
+ )
+ return self.cuda_kernels_forward(hidden_states, cache_params)
+ return self.slow_forward(hidden_states, cache_params)
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
+class JambaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
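+ # Gated feed-forward: down_proj(act(gate_proj(x)) * up_proj(x)). With a silu activation
+ # (Jamba's default `hidden_act`), this is the SwiGLU-style MLP used by Mistral/Mixtral.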
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mixtral->Jamba
+class JambaSparseMoeBlock(nn.Module):
+ """
+ This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens).
+ It's faster since it formulates MoE operations in terms of block-sparse operations to accommodate
+ imbalanced assignments of tokens to experts, whereas standard MoE either (1) drops tokens at the
+ cost of reduced performance or (2) sets the capacity factor to the number of experts and thus
+ wastes computation and memory on padding.
+ """
+
+ def __init__(self, config: JambaConfig):
+ super().__init__()
+ self.hidden_dim = config.hidden_size
+ self.ffn_dim = config.intermediate_size
+ self.num_experts = config.num_experts
+ self.top_k = config.num_experts_per_tok
+
+ self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+ self.experts = nn.ModuleList([JambaMLP(config) for _ in range(self.num_experts)])
+
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """ """
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.router(hidden_states)
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
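+ # e.g. with the released Jamba-v0.1 settings (num_experts=16, num_experts_per_tok=2), both
+ # routing_weights and selected_experts have shape (batch_size * sequence_length, 2)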
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ # One hot encode the selected experts to create an expert mask
+ # this will be used to easily index which expert each token is routed to
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ # Loop over all available experts in the model and perform the computation on each expert
+ for expert_idx in range(self.num_experts):
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx])
+
+ if top_x.shape[0] == 0:
+ continue
+
+ # Index the correct hidden states and compute the expert hidden state for
+ # the current expert. We need to make sure to multiply the output hidden
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+ # However, `index_add_` only supports torch tensors for indexing, so we use
+ # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states, router_logits
+
+
+class JambaAttentionDecoderLayer(nn.Module):
+ def __init__(self, config: JambaConfig, layer_idx: int):
+ super().__init__()
+ num_experts = config.layers_num_experts[layer_idx]
+ self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
+ self.feed_forward = ffn_layer_class(config)
+ self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ output_attentions: Optional[bool] = False,
+ output_router_logits: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ # residual connection after attention
+ hidden_states = residual + hidden_states
+
+ # feed-forward (experts/MLP)
+ residual = hidden_states
+ hidden_states = self.pre_ff_layernorm(hidden_states)
+ ff_outputs = self.feed_forward(hidden_states)
+ if isinstance(ff_outputs, tuple):
+ hidden_states, router_logits = ff_outputs
+ else:
+ hidden_states, router_logits = ff_outputs, None
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ if output_router_logits:
+ outputs += (router_logits,)
+
+ return outputs
+
+
+class JambaMambaDecoderLayer(nn.Module):
+ def __init__(self, config: JambaConfig, layer_idx: int):
+ super().__init__()
+ num_experts = config.layers_num_experts[layer_idx]
+ self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)
+
+ ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
+ self.feed_forward = ffn_layer_class(config)
+ self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ output_attentions: Optional[bool] = False,
+ output_router_logits: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states = self.mamba(
+ hidden_states=hidden_states,
+ cache_params=past_key_value,
+ )
+ self_attn_weights = None
+
+ # residual connection after mamba
+ hidden_states = residual + hidden_states
+
+ # feed-forward (experts/MLP)
+ residual = hidden_states
+ hidden_states = self.pre_ff_layernorm(hidden_states)
+ ff_outputs = self.feed_forward(hidden_states)
+ if isinstance(ff_outputs, tuple):
+ hidden_states, router_logits = ff_outputs
+ else:
+ hidden_states, router_logits = ff_outputs, None
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (past_key_value,)
+
+ if output_router_logits:
+ outputs += (router_logits,)
+
+ return outputs
+
+
+JAMBA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`JambaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
+ JAMBA_START_DOCSTRING,
+)
+class JambaPreTrainedModel(PreTrainedModel):
+ config_class = JambaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+JAMBA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
+ self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
+ Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
+ `(batch_size, d_inner, d_state)` respectively.
+ See the `HybridMambaAttentionDynamicCache` class for more details.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
+
+
+@add_start_docstrings(
+ "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
+ JAMBA_START_DOCSTRING,
+)
+# Adapted from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->JAMBA, Mistral->Jamba
+class JambaModel(JambaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`JambaDecoderLayer`]
+
+ Args:
+ config: JambaConfig
+ """
+
+ def __init__(self, config: JambaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ decoder_layers = []
+ for i in range(config.num_hidden_layers):
+ layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
+ decoder_layers.append(layer_class(config, layer_idx=i))
+ self.layers = nn.ModuleList(decoder_layers)
+
+ self._attn_implementation = config._attn_implementation
+ self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+
+ if use_cache and past_key_values is None:
+ logger.warning_once(
+ "Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
+ "provided, so no cache will be returned."
+ )
+
+ if cache_position is None:
+ cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ output_router_logits,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ output_router_logits=output_router_logits,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ if layer_outputs[1] is not None:
+ # append attentions only for attention layers. Mamba layers return `None` as the attention weights
+ all_self_attns += (layer_outputs[1],)
+
+ if output_router_logits:
+ if layer_outputs[-1] is not None:
+ # append router logits only for expert layers. Regular MLP layers return `None` as the router logits
+ all_router_logits += (layer_outputs[-1],)
+
+ hidden_states = self.final_layernorm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if past_key_values and not past_key_values.has_previous_state:
+ past_key_values.has_previous_state = True
+
+ next_cache = None if not use_cache else past_key_values
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ target_length = cache_position[-1] + 1
+
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
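+ # keep the large negative bias only for key positions strictly after each query's cache position
+ # (i.e. future tokens); earlier and current positions are zeroed out and therefore attended to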
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ if attention_mask.dim() == 2:
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+
+# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
+class JambaForCausalLM(JambaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: JambaConfig):
+ super().__init__(config)
+ self.model = JambaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.router_aux_loss_coef = config.router_aux_loss_coef
+ self.num_experts = config.num_experts
+ self.num_experts_per_tok = config.num_experts_per_tok
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ # Ignore copy
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: Optional[int] = None,
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ num_logits_to_keep (`int` or `None`, *optional*):
+ Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
+ `input_ids`. Only the last token's logits are needed for generation, and computing them only for that
+ token saves memory, which becomes quite significant for long sequences.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, JambaForCausalLM
+
+ >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ if num_logits_to_keep is None:
+ logits = self.lm_head(hidden_states)
+ else:
+ logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits if return_dict else outputs[-1],
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure the aux loss is on the same device as the loss
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ if output_router_logits:
+ output = (aux_loss,) + output
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ output_router_logits=False,
+ cache_position=None,
+ **kwargs,
+ ):
+ empty_past_kv = past_key_values is None
+
+ # Omit tokens covered by past_key_values
+ if not empty_past_kv:
+ past_length = cache_position[0] if cache_position is not None else attention_mask.shape[1]
+ max_cache_length = self.config.sliding_window
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and past_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+ else:
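+ # Jamba only returns a cache when one is passed in (see the warning in `JambaModel.forward`),
+ # so allocate a fresh hybrid cache at the start of generation.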
+ past_key_values = HybridMambaAttentionDynamicCache(
+ self.config, input_ids.shape[0], self.dtype, device=self.device
+ )
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if not empty_past_kv:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and empty_past_kv:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "output_router_logits": output_router_logits,
+ "num_logits_to_keep": self.config.num_logits_to_keep,
+ "cache_position": cache_position,
+ }
+ )
+ return model_inputs
+
+
+@add_start_docstrings(
+ """
+ The Jamba Model with a sequence classification head on top (linear layer).
+
+ [`JambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+ each row of the batch).
+ """,
+ JAMBA_START_DOCSTRING,
+)
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->Jamba, MIXTRAL->JAMBA
+class JambaForSequenceClassification(JambaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = JambaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 3c5c2831d3a5..e074cfb6252a 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -4526,6 +4526,34 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class JambaForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class JambaForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class JambaModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class JambaPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index d6b4840c4910..8382273bef4b 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1050,7 +1050,7 @@ def test_contrastive_generate_low_memory(self):
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
self.skipTest("Won't fix: old model with different cache format")
- if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
+ if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
self.skipTest("TODO: fix me")
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
@@ -1098,6 +1098,7 @@ def test_beam_search_low_memory(self):
"transo_xl",
"xlnet",
"cpm",
+ "jamba",
]
):
self.skipTest("May fix in the future: need model-specific fixes")
@@ -1735,11 +1736,12 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
use_cache=use_cache,
)
- # Past Key Value States -- two notes here:
+ # Past Key Value States -- a few notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
- # 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
- models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
+ # 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
+ # complete
+ models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)
diff --git a/tests/models/jamba/__init__.py b/tests/models/jamba/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
new file mode 100644
index 000000000000..f8e9fdb77b20
--- /dev/null
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -0,0 +1,730 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Jamba model. """
+import math
+import tempfile
+import unittest
+
+import pytest
+from parameterized import parameterized
+
+from transformers import AutoTokenizer, JambaConfig, is_torch_available
+from transformers.testing_utils import (
+ require_bitsandbytes,
+ require_flash_attn,
+ require_torch,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ JambaForCausalLM,
+ JambaForSequenceClassification,
+ JambaModel,
+ )
+ from transformers.models.jamba.modeling_jamba import (
+ HybridMambaAttentionDynamicCache,
+ )
+
+
+class JambaModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ attn_layer_offset=1,
+ attn_layer_period=8,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.attn_layer_offset = attn_layer_offset
+ self.attn_layer_period = attn_layer_period
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ sequence_labels = None
+ token_labels = None
+ choice_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+ choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+ config = self.get_config()
+
+ return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+ def get_config(self):
+ return JambaConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ attn_layer_offset=self.attn_layer_offset,
+ attn_layer_period=self.attn_layer_period,
+ num_attention_heads=self.num_attention_heads,
+ num_key_value_heads=self.num_key_value_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ is_decoder=True,
+ initializer_range=self.initializer_range,
+ use_mamba_kernels=False,
+ num_experts=2,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ (
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = self.prepare_config_and_inputs()
+
+ config.is_decoder = True
+
+ return (
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ )
+
+ def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
+ model = JambaModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_causal_lm(
+ self,
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+ model = JambaForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+ result = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids, labels=token_labels)
+ result = model(input_ids)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_decoder_model_past_large_inputs(
+ self,
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+ config.is_decoder = True
+ config.add_cross_attention = True
+ model = JambaForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ # Attention: Jamba needs the cache to be initialized to return a cache!
+ past_key_values = HybridMambaAttentionDynamicCache(
+ config, input_ids.shape[0], model.dtype, device=model.device
+ )
+ outputs = model(
+ input_ids,
+ attention_mask=input_mask,
+ past_key_values=past_key_values,
+ use_cache=True,
+ )
+ past_key_values = outputs.past_key_values
+
+ # create hypothetical multiple next tokens and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append to next input_ids and next_attention_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(
+ next_input_ids,
+ attention_mask=next_attention_mask,
+ output_hidden_states=True,
+ )["hidden_states"][0]
+ output_from_past = model(
+ next_tokens,
+ attention_mask=next_attention_mask,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ cache_position=torch.arange(
+ input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
+ ),
+ )["hidden_states"][0]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_for_sequence_classification(
+ self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ config.num_labels = self.num_labels
+ model = JambaForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (
+ JambaModel,
+ JambaForCausalLM,
+ JambaForSequenceClassification,
+ )
+ if is_torch_available()
+ else ()
+ )
+ all_generative_model_classes = (JambaForCausalLM,) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": JambaModel,
+ "text-classification": JambaForSequenceClassification,
+ "text-generation": JambaForCausalLM,
+ "zero-shot": JambaForSequenceClassification,
+ }
+ if is_torch_available()
+ else {}
+ )
+ test_headmasking = False
+ test_pruning = False
+
+ def setUp(self):
+ self.model_tester = JambaModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=JambaConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_causal_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+
+ def test_for_sequence_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_load_balancing_loss(self):
+ r"""
+ Let's make sure we can actually compute the auxiliary load-balancing loss and that padding tokens are handled correctly.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.num_labels = 3
+ config.num_experts = 16
+ config.output_router_logits = True
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(config.pad_token_id).to(torch_device)
+ model = JambaForCausalLM(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=attention_mask)
+ bs, seqlen = input_ids.shape
+ self.assertEqual(result.router_logits[0].shape, (bs * seqlen, config.num_experts))
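+ # Note: with a freshly initialized (roughly uniform) router, the Switch-style auxiliary loss
+ # num_experts * sum_i(f_i * P_i) reduces to approximately num_experts_per_tok, i.e. ~2 here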
+ torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)
+
+ # First, we make sure that adding padding tokens doesn't change the loss
+ # loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding)
+ pad_length = 1000
+ # Add padding tokens to input_ids
+ padding_block = config.pad_token_id * torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(
+ torch_device
+ )
+ padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left
+ padded_attention_mask = padded_input_ids.ne(config.pad_token_id).to(torch_device)
+
+ padded_result = model(padded_input_ids, attention_mask=padded_attention_mask)
+ torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4)
+
+ # We make sure that the loss of including padding tokens != the loss without padding tokens
+ # if attention_mask=None --> we don't exclude padding tokens
+ include_padding_result = model(padded_input_ids, attention_mask=None)
+
+ # This is to mimic torch.testing.assert_not_close
+ self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())
+
+ def test_initialization(self):
+ r"""
+ Overriding the test_initialization test as the A_log and D params of the Mamba block are initialized differently
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ if "A_log" in name:
+ A = torch.arange(1, config.mamba_d_state + 1, dtype=torch.float32)[None, :]
+ self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5))
+ elif "D" in name:
+ # check that it's initialized as a tensor of ones
+ self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def test_mismatched_shapes_have_properly_initialized_weights(self):
+ r"""
+ Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the
+ Mamba block are initialized differently and we tested that in test_initialization
+ """
+ self.skipTest("Cumbersome and redundant for Jamba")
+
+ def test_attention_outputs(self):
+ r"""
+ Overriding the test_attention_outputs test as the Jamba model outputs attentions only for its attention layers
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+
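+ # attention layers sit at layer indices attn_layer_offset, attn_layer_offset + attn_layer_period, ...
+ # so the model contains ceil((num_hidden_layers - attn_layer_offset) / attn_layer_period) of them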
+ expected_num_attentions = math.ceil(
+ (self.model_tester.num_hidden_layers - self.model_tester.attn_layer_offset)
+ / self.model_tester.attn_layer_period
+ )
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), expected_num_attentions)
+
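+            # Each attention tensor has shape (batch_size, num_heads, query_length, key_length); only the
+            # last three dims are checked here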
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
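+            # Requesting hidden states should add exactly one element (the hidden states tuple) to the outputs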
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), expected_num_attentions)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
+
+ def test_left_padding_compatibility(self):
+ r"""
+        Overriding the test_left_padding_compatibility test as the Mamba layers accentuate the small numerical
+        differences caused by left padding (see the issue linked in the note below). A more permissive tolerance
+        is used.
+ """
+ import inspect
+ # NOTE: left-padding results in small numerical differences. This is expected.
+ # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
+
+        # First, keep only models that support left padding, i.e. generative, decoder-only models.
+        # Jamba is a decoder-only architecture, so all generative model classes qualify.
+ decoder_only_classes = self.all_generative_model_classes
+
+ # Then, test left-padding
+ def _prepare_model_kwargs(input_ids, attention_mask, signature):
+ model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
+ if "position_ids" in signature:
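+                # derive position_ids from the attention mask so left-padding does not shift the positions of
+                # the real tokens; padded positions get an arbitrary valid value (1) since they are masked out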
+ position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ model_kwargs["position_ids"] = position_ids
+ if "cache_position" in signature:
+ cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
+ model_kwargs["cache_position"] = cache_position
+ return model_kwargs
+
+ for model_class in decoder_only_classes:
+ config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
+ model = model_class(config).to(torch_device).eval()
+ signature = inspect.signature(model.forward).parameters.keys()
+
+ # Without padding
+ model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
+ next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
+
+ # With left-padding (length 32)
+ pad_size = (input_ids.shape[0], 32)
+ padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
+ padded_input_ids = torch.cat((padding, input_ids), dim=1)
+ padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+ model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
+ next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
+
+ # They should result in very similar logits
+ self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
+
+ @require_flash_attn
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_fp32_ln(self):
+ r"""
+ Overriding the test_flash_attn_2_fp32_ln test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA2
+ """
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_input = inputs_dict[model.main_input_name]
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+                # NOTE: Jamba does not support right padding + use_cache with FA2, so make sure the last
+                # position is attended (i.e. no right padding).
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
+ load_in_4bit=True,
+ )
+
+ for _, param in model.named_parameters():
+                    # upcast the remaining fp16/bf16 params (mainly the layer norms under 4-bit loading) to fp32
+ if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+ param.data = param.data.to(torch.float32)
+
+ _ = model(dummy_input)
+ # with attention mask
+ _ = model(dummy_input, attention_mask=dummy_attention_mask)
+
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_generate_padding_right(self):
+ r"""
+ Overriding the test_flash_attn_2_generate_padding_right test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA2
+ """
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
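+                # Reloaded with flash_attention_2, generating on the right-padded batch is expected to raise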
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_generate_use_cache(self):
+ r"""
+ Overriding the test_flash_attn_2_generate_use_cache test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA2
+ """
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+                # NOTE: Jamba does not support right padding + use_cache with FA2, so make sure the last
+                # position is attended (i.e. no right padding).
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_inference_equivalence_right_padding(self):
+ r"""
+        Overriding the test_flash_attn_2_inference_equivalence_right_padding test as the Jamba model, like Mixtral,
+        doesn't support right padding + use cache with FA2
+ """
+ self.skipTest("Jamba flash attention does not support right padding")
+
+ @unittest.skip("Jamba has its own special cache type")
+ @parameterized.expand([(1, False), (1, True), (4, False)])
+ def test_new_cache_format(self, num_beams, do_sample):
+ pass
+
+
+@require_torch
+class JambaModelIntegrationTest(unittest.TestCase):
+ model = None
+ tokenizer = None
+
+ @classmethod
+ def setUpClass(cls):
+ model_id = "ai21labs/Jamba-tiny-random"
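+        # Tiny, randomly initialized checkpoint: generations are gibberish but deterministic, so the exact
+        # strings and logits asserted below are stable.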
+ cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
+ cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+ @slow
+ def test_simple_generate(self):
+ self.model.to(torch_device)
+
+ input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
+ "input_ids"
+ ].to(torch_device)
+ out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
+ output_sentence = self.tokenizer.decode(out[0, :])
+ self.assertEqual(
+ output_sentence,
+ "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats",
+ )
+
+ with torch.no_grad():
+ logits = self.model(input_ids=input_ids).logits
+
+ EXPECTED_LOGITS_NO_GRAD = torch.tensor(
+ [
+ 0.0140, -0.2246, 0.0408, -0.1016, 0.0471, 0.2715, -0.1465, 0.1631,
+ -0.2949, -0.0297, 0.0250, -0.5586, -0.2139, -0.1426, -0.1602, 0.1309,
+ 0.0703, 0.2236, 0.1729, -0.2285, -0.1152, -0.1177, -0.1367, 0.0289,
+ 0.1245, 0.2363, 0.0442, 0.1094, -0.1348, -0.2295, 0.1494, -0.3945,
+ 0.1777, -0.4570, -0.0408, 0.2412, 0.1562, -0.1943, 0.2373, -0.0593
+ ]
+ , dtype=torch.float32) # fmt: skip
+
+ torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)
+
+ @slow
+ def test_simple_batched_generate_with_padding(self):
+ self.model.to(torch_device)
+
+ inputs = self.tokenizer(
+ ["Hey how are you doing on this lovely evening?", "Tell me a story"], padding=True, return_tensors="pt"
+ ).to(torch_device)
+ out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
+ output_sentences = self.tokenizer.batch_decode(out)
+ self.assertEqual(
+ output_sentences[0],
+ "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats",
+ )
+ self.assertEqual(
+ output_sentences[1],
+ "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
+ )
+
+ with torch.no_grad():
+ logits = self.model(input_ids=inputs["input_ids"]).logits
+
+ EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
+ [
+ 0.0140, -0.2246, 0.0408, -0.1016, 0.0471, 0.2715, -0.1465, 0.1631,
+ -0.2949, -0.0297, 0.0250, -0.5586, -0.2139, -0.1426, -0.1602, 0.1309,
+ 0.0703, 0.2236, 0.1729, -0.2285, -0.1152, -0.1177, -0.1367, 0.0289,
+ 0.1245, 0.2363, 0.0442, 0.1094, -0.1348, -0.2295, 0.1494, -0.3945,
+ 0.1777, -0.4570, -0.0408, 0.2412, 0.1562, -0.1943, 0.2373, -0.0593
+ ]
+ , dtype=torch.float32) # fmt: skip
+
+ EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
+ [
+ -0.1289, 0.2363, -0.4180, -0.0302, -0.0476, 0.0327, 0.2578, 0.0874,
+ 0.1484, 0.2305, -0.1152, -0.1396, -0.1494, -0.1113, -0.0021, -0.2832,
+ 0.2002, -0.2676, 0.0598, -0.1982, -0.2539, -0.1133, -0.1973, 0.2148,
+ 0.0559, 0.1670, 0.1846, 0.1270, 0.1680, -0.1250, -0.2656, -0.2871,
+ 0.2344, 0.2637, 0.0510, -0.1855, 0.2158, -0.1289, 0.1758, 0.0074
+ ]
+ , dtype=torch.float32) # fmt: skip
+
+ torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
+ torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 736ee681c376..f631c59b75d4 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -32,6 +32,15 @@
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
+    # 'max_position_embeddings' is not used in the modeling file, but needed for eval frameworks like Hugging Face's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
+    # periods and offsets are not used in the modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
+ "JambaConfig": [
+ "max_position_embeddings",
+ "attn_layer_offset",
+ "attn_layer_period",
+ "expert_layer_offset",
+ "expert_layer_period",
+ ],
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],
# used to compute the property `self.layers_block_type`
diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt
index 54924759cafd..1869836909e6 100644
--- a/utils/not_doctested.txt
+++ b/utils/not_doctested.txt
@@ -631,6 +631,8 @@ src/transformers/models/instructblip/configuration_instructblip.py
src/transformers/models/instructblip/convert_instructblip_original_to_pytorch.py
src/transformers/models/instructblip/modeling_instructblip.py
src/transformers/models/instructblip/processing_instructblip.py
+src/transformers/models/jamba/configuration_jamba.py
+src/transformers/models/jamba/modeling_jamba.py
src/transformers/models/jukebox/configuration_jukebox.py
src/transformers/models/jukebox/convert_jukebox.py
src/transformers/models/jukebox/modeling_jukebox.py