diff --git a/README.md b/README.md index 32ab0fc894c8..bdda50acc776 100644 --- a/README.md +++ b/README.md @@ -346,6 +346,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. +1. **[Mask2Former](https://huggingface.co/docs/transformers/main/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. 1. **[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. 
**[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. diff --git a/README_es.md b/README_es.md index 7bcaab0cc3d7..c66f4b77b2e4 100644 --- a/README_es.md +++ b/README_es.md @@ -346,6 +346,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. +1. 
**[Mask2Former](https://huggingface.co/docs/transformers/main/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. 1. **[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. diff --git a/README_hd.md b/README_hd.md index 6443d230ee1e..31f6b448cdde 100644 --- a/README_hd.md +++ b/README_hd.md @@ -319,6 +319,7 @@ conda install -c huggingface transformers 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (फेसबुक से) साथ देने वाला पेपर [बियॉन्ड इंग्लिश-सेंट्रिक मल्टीलिंगुअल मशीन ट्रांसलेशन](https://arxiv.org/ एब्स/2010.11125) एंजेला फैन, श्रुति भोसले, होल्गर श्वेन्क, झी मा, अहमद अल-किश्की, सिद्धार्थ गोयल, मनदीप बैनेस, ओनूर सेलेबी, गुइल्लाम वेन्जेक, विश्रव चौधरी, नमन गोयल, टॉम बर्च, विटाली लिपचिंस्की, सर्गेई एडुनोव, एडौर्ड द्वारा ग्रेव, माइकल औली, आर्मंड जौलिन द्वारा पोस्ट किया गया। 1. 
**[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Jörg द्वारा [OPUS](http://opus.nlpl.eu/) डेटा से प्रशिक्षित मशीनी अनुवाद मॉडल पोस्ट किया गया टाइडेमैन द्वारा। [मैरियन फ्रेमवर्क](https://marian-nmt.github.io/) माइक्रोसॉफ्ट ट्रांसलेटर टीम द्वारा विकसित। 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (माइक्रोसॉफ्ट रिसर्च एशिया से) साथ में पेपर [मार्कअपएलएम: विजुअली-रिच डॉक्यूमेंट अंडरस्टैंडिंग के लिए टेक्स्ट और मार्कअप लैंग्वेज का प्री-ट्रेनिंग] (https://arxiv.org/abs/2110.08518) जुनलॉन्ग ली, यिहेंग जू, लेई कुई, फुरु द्वारा वी द्वारा पोस्ट किया गया। +1. **[Mask2Former](https://huggingface.co/docs/transformers/main/model_doc/mask2former)** (FAIR and UIUC से) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar द्वारा अनुसंधान पत्र [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) के साथ जारी किया गया। 1. **[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (मेटा और UIUC से) पेपर के साथ जारी किया गया [प्रति-पिक्सेल वर्गीकरण वह सब नहीं है जिसकी आपको सिमेंटिक सेगमेंटेशन की आवश्यकता है] (https://arxiv.org/abs/2107.06278) बोवेन चेंग, अलेक्जेंडर जी. श्विंग, अलेक्जेंडर किरिलोव द्वारा >>>>>> रिबेस ठीक करें 1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (फेसबुक से) साथ में पेपर [न्यूरल मशीन ट्रांसलेशन के लिए मल्टीलिंगुअल डीनोइजिंग प्री-ट्रेनिंग](https://arxiv. org/abs/2001.08210) यिनहान लियू, जियाताओ गु, नमन गोयल, जियान ली, सर्गेई एडुनोव, मार्जन ग़ज़विनिनेजाद, माइक लुईस, ल्यूक ज़ेटलमॉयर द्वारा। 1. 
**[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (फेसबुक से) साथ में पेपर [एक्स्टेंसिबल बहुभाषी प्रीट्रेनिंग और फाइनट्यूनिंग के साथ बहुभाषी अनुवाद](https://arxiv युकिंग टैंग, चाउ ट्रान, जियान ली, पेंग-जेन चेन, नमन गोयल, विश्रव चौधरी, जियाताओ गु, एंजेला फैन द्वारा .org/abs/2008.00401)। diff --git a/README_ja.md b/README_ja.md index 6c0f50af716a..9120113dc31e 100644 --- a/README_ja.md +++ b/README_ja.md @@ -381,6 +381,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (Facebook から) Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin から公開された研究論文: [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Jörg Tiedemann から. [OPUS](http://opus.nlpl.eu/) を使いながら学習された "Machine translation" (マシントランスレーション) モデル. [Marian Framework](https://marian-nmt.github.io/) はMicrosoft Translator Team が現在開発中です. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (Microsoft Research Asia から) Junlong Li, Yiheng Xu, Lei Cui, Furu Wei から公開された研究論文: [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) +1. **[Mask2Former](https://huggingface.co/docs/transformers/main/model_doc/mask2former)** (FAIR and UIUC から) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar から公開された研究論文: [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) 1. **[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (Meta and UIUC から) Bowen Cheng, Alexander G. 
Schwing, Alexander Kirillov から公開された研究論文: [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) 1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (Facebook から) Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer から公開された研究論文: [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) 1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (Facebook から) Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan から公開された研究論文: [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) diff --git a/README_ko.md b/README_ko.md index 83f240d02ed3..c93eb28415e0 100644 --- a/README_ko.md +++ b/README_ko.md @@ -296,6 +296,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (Facebook 에서) Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 의 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 논문과 함께 발표했습니다. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. 
**[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (Microsoft Research Asia 에서) Junlong Li, Yiheng Xu, Lei Cui, Furu Wei 의 [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 논문과 함께 발표했습니다. +1. **[Mask2Former](https://huggingface.co/docs/transformers/main/model_doc/mask2former)** (FAIR and UIUC 에서) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar 의 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) 논문과 함께 발표했습니다. 1. **[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (Meta and UIUC 에서) Bowen Cheng, Alexander G. Schwing, Alexander Kirillov 의 [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) 논문과 함께 발표했습니다. 1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (Facebook 에서) Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer 의 [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) 논문과 함께 발표했습니다. 1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (Facebook 에서) Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan 의 [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 542c5740b5a9..18cdbc015599 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -320,6 +320,7 @@ conda install -c huggingface transformers 1. 
**[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (来自 Facebook) 伴随论文 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 由 Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 发布。 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** 用 [OPUS](http://opus.nlpl.eu/) 数据训练的机器翻译模型由 Jörg Tiedemann 发布。[Marian Framework](https://marian-nmt.github.io/) 由微软翻译团队开发。 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (来自 Microsoft Research Asia) 伴随论文 [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 由 Junlong Li, Yiheng Xu, Lei Cui, Furu Wei 发布。 +1. **[Mask2Former](https://huggingface.co/docs/transformers/main/model_doc/mask2former)** (来自 FAIR and UIUC) 伴随论文 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) 由 Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar 发布。 1. **[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov >>>>>>> Fix rebase 1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) 由 Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer 发布。 1. 
**[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) 由 Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index f0b83136f61a..a4de0541fd7c 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -332,6 +332,7 @@ conda install -c huggingface transformers 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. +1. **[Mask2Former](https://huggingface.co/docs/transformers/main/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. 1. 
**[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov 1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. diff --git a/docs/source/de/index.mdx b/docs/source/de/index.mdx index c7d6511053ec..031df91237f1 100644 --- a/docs/source/de/index.mdx +++ b/docs/source/de/index.mdx @@ -115,6 +115,7 @@ Die Bibliothek enthält derzeit JAX-, PyTorch- und TensorFlow-Implementierungen, 1. **[M-CTC-T](model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. 1. **[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. 
**[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. +1. **[Mask2Former](model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. 1. **[MaskFormer](model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[mBART](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[mBART-50](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 33989739082d..8aa735a1e2d8 100755 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -422,6 +422,8 @@ title: ImageGPT - local: model_doc/levit title: LeViT + - local: model_doc/mask2former + title: Mask2Former - local: model_doc/maskformer title: MaskFormer - local: model_doc/mobilenet_v1 diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 6aefe26686ec..335d26ebbbb7 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -133,6 +133,7 @@ The documentation is organized into five sections: 1. 
**[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. +1. **[Mask2Former](model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. 1. **[MaskFormer](model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[mBART](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. 
**[mBART-50](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. @@ -306,6 +307,7 @@ Flax), PyTorch, and/or TensorFlow. | M2M100 | ✅ | ❌ | ✅ | ❌ | ❌ | | Marian | ✅ | ❌ | ✅ | ✅ | ✅ | | MarkupLM | ✅ | ✅ | ✅ | ❌ | ❌ | +| Mask2Former | ❌ | ❌ | ✅ | ❌ | ❌ | | MaskFormer | ❌ | ❌ | ✅ | ❌ | ❌ | | MaskFormerSwin | ❌ | ❌ | ❌ | ❌ | ❌ | | mBART | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/mask2former.mdx b/docs/source/en/model_doc/mask2former.mdx new file mode 100644 index 000000000000..30c8bbdf98a9 --- /dev/null +++ b/docs/source/en/model_doc/mask2former.mdx @@ -0,0 +1,48 @@ + + +# Mask2Former + +## Overview + +The Mask2Former model was proposed in [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. Mask2Former is a unified framework for panoptic, instance and semantic segmentation and features significant performance and efficiency improvements over [MaskFormer](maskformer). + +The abstract from the paper is the following: + +*Image segmentation groups pixels with different semantics, e.g., category or instance membership. Each choice +of semantics defines a task. While only the semantics of each task differ, current research focuses on designing specialized architectures for each task. We present Masked-attention Mask Transformer (Mask2Former), a new architecture capable of addressing any image segmentation task (panoptic, instance or semantic). Its key components include masked attention, which extracts localized features by constraining cross-attention within predicted mask regions. 
In addition to reducing the research effort by at least three times, it outperforms the best specialized architectures by a significant margin on four popular datasets. Most notably, Mask2Former sets a new state-of-the-art for panoptic segmentation (57.8 PQ on COCO), instance segmentation (50.1 AP on COCO) and semantic segmentation (57.7 mIoU on ADE20K).* + +Tips: +- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`MaskFormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model. +- To get the final segmentation, call [`~MaskFormerImageProcessor.post_process_semantic_segmentation`], [`~MaskFormerImageProcessor.post_process_instance_segmentation`], or [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`], depending on the task. All three tasks can be solved using the output of [`Mask2FormerForUniversalSegmentation`]. Panoptic segmentation additionally accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object(s) (e.g. sky) together. + +This model was contributed by [Shivalika Singh](https://huggingface.co/shivi) and [Alara Dirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/facebookresearch/Mask2Former). 
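To make the semantic post-processing tip more concrete, here is a minimal, library-free sketch of the idea behind combining per-query class predictions with per-query mask predictions into one label per pixel. It is an illustration under stated assumptions, not the code of `post_process_semantic_segmentation`: the helper names (`softmax`, `sigmoid`, `semantic_map`) and the toy inputs are hypothetical, and the last class entry is assumed to be the "no object" class.

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(value):
    """Logistic function mapping a mask logit to a probability."""
    return 1.0 / (1.0 + math.exp(-value))


def semantic_map(class_logits, mask_logits):
    """Combine per-query class and mask predictions into a per-pixel label map.

    class_logits: Q lists of C + 1 scores (last entry assumed to be "no object").
    mask_logits:  Q masks, each an H x W grid of scores.
    """
    # Class probabilities per query, dropping the trailing "no object" entry.
    class_probs = [softmax(row)[:-1] for row in class_logits]
    num_queries = len(class_probs)
    num_classes = len(class_probs[0])
    height, width = len(mask_logits[0]), len(mask_logits[0][0])

    labels = []
    for y in range(height):
        row = []
        for x in range(width):
            # Score of class c at this pixel: sum over queries of
            # P(class c | query) * P(pixel belongs to the query's mask).
            scores = [
                sum(
                    class_probs[q][c] * sigmoid(mask_logits[q][y][x])
                    for q in range(num_queries)
                )
                for c in range(num_classes)
            ]
            # Keep the index of the highest-scoring class.
            row.append(max(range(num_classes), key=scores.__getitem__))
        labels.append(row)
    return labels


# Toy input: two queries, two real classes plus "no object", on a 2 x 2 image.
toy_class_logits = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]]
toy_mask_logits = [
    [[5.0, 5.0], [-5.0, -5.0]],  # query 0 covers the top row
    [[-5.0, -5.0], [5.0, 5.0]],  # query 1 covers the bottom row
]
result = semantic_map(toy_class_logits, toy_mask_logits)
print(result)
```

In practice the image processor methods named in the tips handle this step (plus resizing to the target size) on batched tensors; the sketch only shows why a single set of query outputs can yield a dense semantic map.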
+ +## Mask2Former specific outputs + +[[autodoc]] models.mask2former.modeling_mask2former.Mask2FormerModelOutput + +[[autodoc]] models.mask2former.modeling_mask2former.Mask2FormerForUniversalSegmentationOutput + +## Mask2FormerConfig + +[[autodoc]] Mask2FormerConfig + +## Mask2FormerModel + +[[autodoc]] Mask2FormerModel + - forward + +## Mask2FormerForUniversalSegmentation + +[[autodoc]] Mask2FormerForUniversalSegmentation + - forward diff --git a/docs/source/en/model_doc/maskformer.mdx b/docs/source/en/model_doc/maskformer.mdx index 4060cbab9a8f..5620c803bf28 100644 --- a/docs/source/en/model_doc/maskformer.mdx +++ b/docs/source/en/model_doc/maskformer.mdx @@ -83,4 +83,4 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The ## MaskFormerForInstanceSegmentation [[autodoc]] MaskFormerForInstanceSegmentation - - forward + - forward \ No newline at end of file diff --git a/docs/source/es/index.mdx b/docs/source/es/index.mdx index 5091a52c8231..30ba547832d1 100644 --- a/docs/source/es/index.mdx +++ b/docs/source/es/index.mdx @@ -100,6 +100,7 @@ La biblioteca actualmente contiene implementaciones de JAX, PyTorch y TensorFlow 1. **[LXMERT](model_doc/lxmert)** (de UNC Chapel Hill) publicado con el paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) por Hao Tan y Mohit Bansal. 1. **[M2M100](model_doc/m2m_100)** (de Facebook) publicado con el paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) por Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](model_doc/marian)** Modelos de traducción automática entrenados usando [OPUS](http://opus.nlpl.eu/) data por Jörg Tiedemann. 
El [Marian Framework](https://marian-nmt.github.io/) está siendo desarrollado por el equipo de traductores de Microsoft. +1. **[Mask2Former](model_doc/mask2former)** (de FAIR y UIUC) publicado con el paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) por Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. 1. **[MaskFormer](model_doc/maskformer)** (de Meta y UIUC) publicado con el paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) por Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[MBart](model_doc/mbart)** (de Facebook) publicado con el paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) por Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[MBart-50](model_doc/mbart)** (de Facebook) publicado con el paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) por Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. diff --git a/docs/source/it/index.mdx b/docs/source/it/index.mdx index e612c3699b59..aa1d7e25d6d9 100644 --- a/docs/source/it/index.mdx +++ b/docs/source/it/index.mdx @@ -109,7 +109,8 @@ La libreria attualmente contiene implementazioni in JAX, PyTorch e TensorFlow, p 1. **[LXMERT](model_doc/lxmert)** (da UNC Chapel Hill) rilasciato con il paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) da Hao Tan e Mohit Bansal. 1. 
**[M2M100](model_doc/m2m_100)** (da Facebook) rilasciato con il paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) da Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](model_doc/marian)** Modello di machine learning per le traduzioni allenato utilizzando i dati [OPUS](http://opus.nlpl.eu/) di Jörg Tiedemann. Il [Framework Marian](https://marian-nmt.github.io/) è stato sviluppato dal Microsoft Translator Team. -1. **[MaskFormer](model_doc/maskformer)** (da Meta and UIUC) rilasciato con il paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) da Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. +1. **[Mask2Former](model_doc/mask2former)** (da FAIR e UIUC) rilasciato con il paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) da Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. +1. **[MaskFormer](model_doc/maskformer)** (da Meta e UIUC) rilasciato con il paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) da Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[MBart](model_doc/mbart)** (da Facebook) rilasciato con il paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) da Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. 
**[MBart-50](model_doc/mbart)** (da Facebook) rilasciato con il paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) da Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 1. **[Megatron-BERT](model_doc/megatron-bert)** (da NVIDIA) rilasciato con il paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) da Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper e Bryan Catanzaro. diff --git a/docs/source/ko/index.mdx b/docs/source/ko/index.mdx index 0aa73ff2a577..4ae17ab7becd 100644 --- a/docs/source/ko/index.mdx +++ b/docs/source/ko/index.mdx @@ -124,6 +124,7 @@ specific language governing permissions and limitations under the License. 1. **[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. +1. 
**[Mask2Former](model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. 1. **[MaskFormer](model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[mBART](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[mBART-50](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. diff --git a/docs/source/pt/index.mdx b/docs/source/pt/index.mdx index 745460f53554..7f1d26b6c447 100644 --- a/docs/source/pt/index.mdx +++ b/docs/source/pt/index.mdx @@ -114,6 +114,7 @@ Atualmente a biblioteca contém implementações do PyTorch, TensorFlow e JAX, p 1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. 
**[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. +1. **[Mask2Former](model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. 1. **[MaskFormer](model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[MBart](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[MBart-50](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a6652b853806..8e91bacb26f6 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -322,6 +322,10 @@ "MarkupLMProcessor", "MarkupLMTokenizer", ], + "models.mask2former": [ + "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Mask2FormerConfig", + ], "models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig", "MaskFormerSwinConfig"], "models.mbart": ["MBartConfig"], "models.mbart50": [], @@ -1708,6 +1712,14 @@ "MarkupLMPreTrainedModel", ] ) + _import_structure["models.mask2former"].extend( + [ + "MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "Mask2FormerForUniversalSegmentation", + "Mask2FormerModel", + "Mask2FormerPreTrainedModel", + ] + ) _import_structure["models.maskformer"].extend( [ "MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3695,6 +3707,7 @@ MarkupLMProcessor, MarkupLMTokenizer, ) + from .models.mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig, MaskFormerSwinConfig from .models.mbart import MBartConfig from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTProcessor @@ -4857,6 +4870,12 @@ MarkupLMModel, MarkupLMPreTrainedModel, ) + from .models.mask2former import ( + MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + Mask2FormerForUniversalSegmentation, + Mask2FormerModel, + Mask2FormerPreTrainedModel, + ) from .models.maskformer import ( MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, MaskFormerForInstanceSegmentation, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index a788bd53087a..cf30880faa1c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -102,6 +102,7 @@ m2m_100, marian, markuplm, + mask2former, maskformer, mbart, mbart50, diff --git a/src/transformers/models/auto/configuration_auto.py 
b/src/transformers/models/auto/configuration_auto.py index 8a60ea42fbe3..cc3d48fe3be8 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -104,6 +104,7 @@ ("m2m_100", "M2M100Config"), ("marian", "MarianConfig"), ("markuplm", "MarkupLMConfig"), + ("mask2former", "Mask2FormerConfig"), ("maskformer", "MaskFormerConfig"), ("maskformer-swin", "MaskFormerSwinConfig"), ("mbart", "MBartConfig"), @@ -262,6 +263,7 @@ ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("markuplm", "MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mask2former", "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mctct", "MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -425,6 +427,7 @@ ("m2m_100", "M2M100"), ("marian", "Marian"), ("markuplm", "MarkupLM"), + ("mask2former", "Mask2Former"), ("maskformer", "MaskFormer"), ("maskformer-swin", "MaskFormerSwin"), ("mbart", "mBART"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 42b658f83a94..3b057fa2a8a7 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -61,6 +61,7 @@ ("layoutlmv2", "LayoutLMv2ImageProcessor"), ("layoutlmv3", "LayoutLMv3ImageProcessor"), ("levit", "LevitImageProcessor"), + ("mask2former", "MaskFormerImageProcessor"), ("maskformer", "MaskFormerImageProcessor"), ("mobilenet_v1", "MobileNetV1ImageProcessor"), ("mobilenet_v2", "MobileNetV2ImageProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 21882e6f1b5d..4465097dfeed 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -103,6 +103,7 @@ ("m2m_100", "M2M100Model"), 
("marian", "MarianModel"), ("markuplm", "MarkupLMModel"), + ("mask2former", "Mask2FormerModel"), ("maskformer", "MaskFormerModel"), ("maskformer-swin", "MaskFormerSwinModel"), ("mbart", "MBartModel"), @@ -454,6 +455,7 @@ [ # Model for Universal Segmentation mapping ("detr", "DetrForSegmentation"), + ("mask2former", "Mask2FormerForUniversalSegmentation"), ("maskformer", "MaskFormerForInstanceSegmentation"), ] ) diff --git a/src/transformers/models/mask2former/__init__.py b/src/transformers/models/mask2former/__init__.py new file mode 100644 index 000000000000..995533c0102d --- /dev/null +++ b/src/transformers/models/mask2former/__init__.py @@ -0,0 +1,64 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_mask2former": [ + "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Mask2FormerConfig", + ], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mask2former"] = [ + "MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "Mask2FormerForUniversalSegmentation", + "Mask2FormerModel", + "Mask2FormerPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mask2former import ( + MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + Mask2FormerForUniversalSegmentation, + Mask2FormerModel, + Mask2FormerPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/mask2former/configuration_mask2former.py b/src/transformers/models/mask2former/configuration_mask2former.py new file mode 100644 index 000000000000..94fc4866bef5 --- /dev/null +++ b/src/transformers/models/mask2former/configuration_mask2former.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Mask2Former model configuration""" +import copy +from typing import Dict, List, Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/mask2former-swin-small-coco-instance": ( + "https://huggingface.co/facebook/mask2former-swin-small-coco-instance/blob/main/config.json" + ) + # See all Mask2Former models at https://huggingface.co/models?filter=mask2former +} + +logger = logging.get_logger(__name__) + + +class Mask2FormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Mask2FormerModel`]. It is used to instantiate a + Mask2Former model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Mask2Former + [facebook/mask2former-swin-small-coco-instance](https://huggingface.co/facebook/mask2former-swin-small-coco-instance) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Currently, Mask2Former only supports the [Swin Transformer](swin) as a backbone. + + Args: + backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `SwinConfig()`): + The configuration of the backbone model. If unset, the configuration corresponding to + `swin-small-patch4-window7-224` will be used.
+ feature_size (`int`, *optional*, defaults to 256): + The features (channels) of the resulting feature maps. + mask_feature_size (`int`, *optional*, defaults to 256): + The masks' feature size. This value is also used to specify the Feature Pyramid Network features' + size. + hidden_dim (`int`, *optional*, defaults to 256): + Dimensionality of the encoder layers. + encoder_feedforward_dim (`int`, *optional*, defaults to 1024): + Dimension of the feedforward network for the Deformable DETR encoder used as part of the pixel decoder. + encoder_layers (`int`, *optional*, defaults to 6): + Number of layers in the Deformable DETR encoder used as part of the pixel decoder. + decoder_layers (`int`, *optional*, defaults to 10): + Number of layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + dim_feedforward (`int`, *optional*, defaults to 2048): + Feature dimension of the feedforward network in the Transformer decoder. + pre_norm (`bool`, *optional*, defaults to `False`): + Whether to use pre-LayerNorm in the Transformer decoder. + enforce_input_projection (`bool`, *optional*, defaults to `False`): + Whether to add an input projection 1x1 convolution even if the input channels and hidden dim are identical + in the Transformer decoder. + common_stride (`int`, *optional*, defaults to 4): + Parameter used for determining the number of FPN levels used as part of the pixel decoder. + ignore_value (`int`, *optional*, defaults to 255): + Category id to be ignored during training. + num_queries (`int`, *optional*, defaults to 100): + Number of queries for the decoder. + no_object_weight (`float`, *optional*, defaults to 0.1): + The weight to apply to the null (no object) class. + class_weight (`float`, *optional*, defaults to 2.0): + The weight for the cross entropy loss.
+ mask_weight (`float`, *optional*, defaults to 5.0): + The weight for the mask loss. + dice_weight (`float`, *optional*, defaults to 5.0): + The weight for the dice loss. + train_num_points (`int`, *optional*, defaults to 12544): + Number of points used for sampling during loss calculation. + oversample_ratio (`float`, *optional*, defaults to 3.0): + Oversampling parameter used for calculating the number of sampled points. + importance_sample_ratio (`float`, *optional*, defaults to 0.75): + Ratio of points that are sampled via importance sampling. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + init_xavier_std (`float`, *optional*, defaults to 1.0): + The scaling factor used for the Xavier initialization gain in the HM Attention map module. + use_auxiliary_loss (`bool`, *optional*, defaults to `True`): + If `True`, [`Mask2FormerForUniversalSegmentationOutput`] will contain the auxiliary losses computed using + the logits from each decoder stage. + feature_strides (`List[int]`, *optional*, defaults to `[4, 8, 16, 32]`): + Feature strides corresponding to features generated from the backbone network. + output_auxiliary_logits (`bool`, *optional*): + Whether the model should output its `auxiliary_logits`.
+ + Examples: + + ```python + >>> from transformers import Mask2FormerConfig, Mask2FormerModel + + >>> # Initializing a Mask2Former facebook/mask2former-swin-small-coco-instance configuration + >>> configuration = Mask2FormerConfig() + + >>> # Initializing a model (with random weights) from the facebook/mask2former-swin-small-coco-instance style configuration + >>> model = Mask2FormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + + """ + model_type = "mask2former" + backbones_supported = ["swin"] + attribute_map = {"hidden_size": "hidden_dim"} + + def __init__( + self, + backbone_config: Optional[Dict] = None, + feature_size: int = 256, + mask_feature_size: int = 256, + hidden_dim: int = 256, + encoder_feedforward_dim: int = 1024, + activation_function: str = "relu", + encoder_layers: int = 6, + decoder_layers: int = 10, + num_attention_heads: int = 8, + dropout: float = 0.0, + dim_feedforward: int = 2048, + pre_norm: bool = False, + enforce_input_projection: bool = False, + common_stride: int = 4, + ignore_value: int = 255, + num_queries: int = 100, + no_object_weight: float = 0.1, + class_weight: float = 2.0, + mask_weight: float = 5.0, + dice_weight: float = 5.0, + train_num_points: int = 12544, + oversample_ratio: float = 3.0, + importance_sample_ratio: float = 0.75, + init_std: float = 0.02, + init_xavier_std: float = 1.0, + use_auxiliary_loss: bool = True, + feature_strides: List[int] = [4, 8, 16, 32], + output_auxiliary_logits: bool = None, + **kwargs, + ): + if backbone_config is None: + logger.info("`backbone_config` is `None`. 
Initializing the config with the default `Swin` backbone.") + backbone_config = CONFIG_MAPPING["swin"]( + image_size=224, + in_channels=3, + patch_size=4, + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + drop_path_rate=0.3, + use_absolute_embeddings=False, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + self.backbone_config = backbone_config + self.feature_size = feature_size + self.mask_feature_size = mask_feature_size + self.hidden_dim = hidden_dim + self.encoder_feedforward_dim = encoder_feedforward_dim + self.activation_function = activation_function + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.dim_feedforward = dim_feedforward + self.pre_norm = pre_norm + self.enforce_input_projection = enforce_input_projection + self.common_stride = common_stride + self.ignore_value = ignore_value + self.num_queries = num_queries + self.no_object_weight = no_object_weight + self.class_weight = class_weight + self.mask_weight = mask_weight + self.dice_weight = dice_weight + self.train_num_points = train_num_points + self.oversample_ratio = oversample_ratio + self.importance_sample_ratio = importance_sample_ratio + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.use_auxiliary_loss = use_auxiliary_loss + self.feature_strides = feature_strides + self.output_auxiliary_logits = output_auxiliary_logits + self.num_hidden_layers = decoder_layers + + super().__init__(**kwargs) + + @classmethod + def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`Mask2FormerConfig`] (or a derived class) from a pre-trained backbone model configuration. 
+ + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + + Returns: + [`Mask2FormerConfig`]: An instance of a configuration object. + """ + return cls( + backbone_config=backbone_config, + **kwargs, + ) + + def to_dict(self) -> Dict[str, any]: + """ + Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + output["backbone_config"] = self.backbone_config.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..19bba580783f --- /dev/null +++ b/src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,1020 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import json +import sys +from argparse import ArgumentParser +from dataclasses import dataclass +from pathlib import Path +from pprint import pformat +from typing import Any, Dict, Iterator, List, Set, Tuple + +import torch +import torchvision.transforms as T +from PIL import Image +from torch import Tensor, nn + +import requests +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.projects.deeplab import add_deeplab_config +from huggingface_hub import hf_hub_download +from transformers import ( + Mask2FormerConfig, + Mask2FormerForUniversalSegmentation, + Mask2FormerModel, + MaskFormerImageProcessor, + SwinConfig, +) +from transformers.models.mask2former.modeling_mask2former import ( + Mask2FormerForUniversalSegmentationOutput, + Mask2FormerModelOutput, +) +from transformers.utils import logging + + +StateDict = Dict[str, Tensor] + +logging.set_verbosity_info() +logger = logging.get_logger() + +torch.manual_seed(0) + + +class TrackedStateDict: + def __init__(self, to_track: Dict): + """This class "tracks" a python dictionary by keeping track of which item is accessed. + + Args: + to_track (Dict): The dictionary we wish to track + """ + self.to_track = to_track + self._seen: Set[str] = set() + + def __getitem__(self, key: str) -> Any: + return self.to_track[key] + + def __setitem__(self, key: str, item: Any): + self._seen.add(key) + self.to_track[key] = item + + def diff(self) -> List[str]: + """This method returns a set difference between the keys in the tracked state dict and the one we have access so far. 
+ This is an effective method to check if we have updated all the keys. + + Returns: + List[str]: List of keys not yet updated + """ + return list(set(self.to_track.keys()) - self._seen) + + def copy(self) -> Dict: + # proxy the call to the internal dictionary + return self.to_track.copy() + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + img_data = requests.get(url, stream=True).raw + im = Image.open(img_data) + return im + + +@dataclass +class Args: + """Fake command line arguments needed by mask2former/detectron implementation""" + + config_file: str + + +def setup_cfg(args: Args): + # load config from file + cfg = get_cfg() + add_deeplab_config(cfg) + add_maskformer2_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.freeze() + return cfg + + +class OriginalMask2FormerConfigToOursConverter: + def __call__(self, original_config: object) -> Mask2FormerConfig: + model = original_config.MODEL + + repo_id = "huggingface/label-files" + if model.SEM_SEG_HEAD.NUM_CLASSES == 847: + filename = "mask2former-ade20k-full-id2label.json" + elif model.SEM_SEG_HEAD.NUM_CLASSES == 150: + filename = "ade20k-id2label.json" + elif model.SEM_SEG_HEAD.NUM_CLASSES == 80: + filename = "coco-detection-mmdet-id2label.json" + elif model.SEM_SEG_HEAD.NUM_CLASSES == 171: + filename = "mask2former-coco-stuff-id2label.json" + elif model.SEM_SEG_HEAD.NUM_CLASSES == 133: + filename = "coco-panoptic-id2label.json" + elif model.SEM_SEG_HEAD.NUM_CLASSES == 19: + filename = "cityscapes-id2label.json" + elif model.SEM_SEG_HEAD.NUM_CLASSES == 8: + filename = "cityscapes-instance-id2label.json" + elif model.SEM_SEG_HEAD.NUM_CLASSES == 65: + filename = "mapillary-vistas-id2label.json" + + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + label2id = {label: idx for idx, label in
id2label.items()} + + if model.SWIN.EMBED_DIM == 96: + backbone_config = SwinConfig.from_pretrained( + "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"] + ) + elif model.SWIN.EMBED_DIM == 128: + backbone_config = SwinConfig( + embed_dim=128, + window_size=12, + depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + + elif model.SWIN.EMBED_DIM == 192: + backbone_config = SwinConfig.from_pretrained( + "microsoft/swin-large-patch4-window12-384", out_features=["stage1", "stage2", "stage3", "stage4"] + ) + else: + raise ValueError(f"embed dim {model.SWIN.EMBED_DIM} not supported for Swin!") + + backbone_config.drop_path_rate = model.SWIN.DROP_PATH_RATE + backbone_config.attention_probs_dropout_prob = model.SWIN.ATTN_DROP_RATE + backbone_config.depths = model.SWIN.DEPTHS + + config: Mask2FormerConfig = Mask2FormerConfig( + ignore_value=model.SEM_SEG_HEAD.IGNORE_VALUE, + num_labels=model.SEM_SEG_HEAD.NUM_CLASSES, + num_queries=model.MASK_FORMER.NUM_OBJECT_QUERIES, + no_object_weight=model.MASK_FORMER.NO_OBJECT_WEIGHT, + class_weight=model.MASK_FORMER.CLASS_WEIGHT, + mask_weight=model.MASK_FORMER.MASK_WEIGHT, + dice_weight=model.MASK_FORMER.DICE_WEIGHT, + train_num_points=model.MASK_FORMER.TRAIN_NUM_POINTS, + oversample_ratio=model.MASK_FORMER.OVERSAMPLE_RATIO, + importance_sample_ratio=model.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO, + init_std=0.02, + init_xavier_std=1.0, + use_auxiliary_loss=model.MASK_FORMER.DEEP_SUPERVISION, + feature_strides=[4, 8, 16, 32], + backbone_config=backbone_config, + id2label=id2label, + label2id=label2id, + feature_size=model.SEM_SEG_HEAD.CONVS_DIM, + mask_feature_size=model.SEM_SEG_HEAD.MASK_DIM, + hidden_dim=model.MASK_FORMER.HIDDEN_DIM, + encoder_layers=model.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS, + encoder_feedforward_dim=1024, + decoder_layers=model.MASK_FORMER.DEC_LAYERS, + num_attention_heads=model.MASK_FORMER.NHEADS, + 
dropout=model.MASK_FORMER.DROPOUT, + dim_feedforward=model.MASK_FORMER.DIM_FEEDFORWARD, + pre_norm=model.MASK_FORMER.PRE_NORM, + enforce_input_projection=model.MASK_FORMER.ENFORCE_INPUT_PROJ, + common_stride=model.SEM_SEG_HEAD.COMMON_STRIDE, + ) + return config + + +class OriginalMask2FormerConfigToFeatureExtractorConverter: + def __call__(self, original_config: object) -> MaskFormerImageProcessor: + model = original_config.MODEL + model_input = original_config.INPUT + + return MaskFormerImageProcessor( + image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(), + image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(), + size=model_input.MIN_SIZE_TEST, + max_size=model_input.MAX_SIZE_TEST, + num_labels=model.SEM_SEG_HEAD.NUM_CLASSES, + ignore_index=model.SEM_SEG_HEAD.IGNORE_VALUE, + size_divisibility=32, + ) + + +class OriginalMask2FormerCheckpointToOursConverter: + def __init__(self, original_model: nn.Module, config: Mask2FormerConfig): + self.original_model = original_model + self.config = config + + def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict): + for src_key, dst_key in renamed_keys: + dst_state_dict[dst_key] = src_state_dict.pop(src_key) + + def replace_maskformer_swin_backbone( + self, dst_state_dict: StateDict, src_state_dict: StateDict, config: Mask2FormerConfig + ): + dst_prefix: str = "pixel_level_module.encoder" + src_prefix: str = "backbone" + + renamed_keys = [ + ( + f"{src_prefix}.patch_embed.proj.weight", + f"{dst_prefix}.model.embeddings.patch_embeddings.projection.weight", + ), + (f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.model.embeddings.patch_embeddings.projection.bias"), + (f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.model.embeddings.norm.weight"), + (f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.model.embeddings.norm.bias"), + ] + num_layers = len(config.backbone_config.depths) + for layer_idx in range(num_layers): + for block_idx in
range(config.backbone_config.depths[layer_idx]): + renamed_keys.extend( + [ # src, dst + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table", + ), + ] + ) + # now we need to handle the attentions + # read in weights + bias of input projection layer of cross-attention + + src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"] + src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"] + + size = src_att_weight.shape[0] + offset = size // 3 + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight" + ] = src_att_weight[:offset, :] + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias" + ] = src_att_bias[:offset] + + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight" + ] = src_att_weight[offset : offset * 2, :] + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias" + ] = src_att_bias[offset : offset * 2] + + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight" + ] = src_att_weight[-offset:, :] + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias" + ] = src_att_bias[-offset:] + + # let's pop them + 
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight") + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias") + # proj + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias", + ), + ] + ) + + # second norm + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias", + ), + ] + ) + + # mlp + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias", + ), + ] + ) + + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index", + ) + ] + ) + + if layer_idx < num_layers - 1: + # 
patch merging + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.reduction.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.norm.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.norm.bias", + ), + ] + ) + + # hidden states norms + renamed_keys.extend( + [ + ( + f"{src_prefix}.norm{layer_idx}.weight", + f"{dst_prefix}.hidden_states_norms.{layer_idx}.weight", + ), + ( + f"{src_prefix}.norm{layer_idx}.bias", + f"{dst_prefix}.hidden_states_norms.{layer_idx}.bias", + ), + ] + ) + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def replace_swin_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: Mask2FormerConfig): + dst_prefix: str = "pixel_level_module.encoder" + src_prefix: str = "backbone" + + renamed_keys = [ + ( + f"{src_prefix}.patch_embed.proj.weight", + f"{dst_prefix}.embeddings.patch_embeddings.projection.weight", + ), + (f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.embeddings.patch_embeddings.projection.bias"), + (f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.embeddings.norm.weight"), + (f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.embeddings.norm.bias"), + ] + + for layer_idx in range(len(config.backbone_config.depths)): + for block_idx in range(config.backbone_config.depths[layer_idx]): + renamed_keys.extend( + [ # src, dst + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias", + ), + ( + 
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table", + ), + ] + ) + # now we need to handle the attentions + # read in weights + bias of input projection layer of cross-attention + + src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"] + src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"] + + size = src_att_weight.shape[0] + offset = size // 3 + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight" + ] = src_att_weight[:offset, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias" + ] = src_att_bias[:offset] + + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight" + ] = src_att_weight[offset : offset * 2, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias" + ] = src_att_bias[offset : offset * 2] + + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight" + ] = src_att_weight[-offset:, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias" + ] = src_att_bias[-offset:] + + # let's pop them + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight") + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias") + # proj + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias", + 
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias", + ), + ] + ) + + # second norm + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias", + ), + ] + ) + + # mlp + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias", + ), + ] + ) + + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index", + ) + ] + ) + + if layer_idx < 3: + # patch merging + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.reduction.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.bias", + ), + ] + ) + + # hidden states norms + renamed_keys.extend( + [ + ( + 
f"{src_prefix}.norm{layer_idx}.weight", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.weight", + ), + ( + f"{src_prefix}.norm{layer_idx}.bias", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.bias", + ), + ] + ) + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + # Backbone + Pixel Decoder + def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "pixel_level_module.decoder" + src_prefix: str = "sem_seg_head.pixel_decoder" + + self.replace_swin_backbone(dst_state_dict, src_state_dict, self.config) + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str): + self_attn_keys = [] + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.attention_weights", f"{dst_prefix}.attention_weights") + ) + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.output_proj", f"{dst_prefix}.output_proj") + ) + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.sampling_offsets", f"{dst_prefix}.sampling_offsets") + ) + self_attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.value_proj", f"{dst_prefix}.value_proj")) + + return self_attn_keys + + def rename_keys_for_encoder_layer(src_prefix: str, dst_prefix: str): + encoder_keys = [] + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.fc1")) + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.fc2")) + encoder_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.self_attn_layer_norm") + ) + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.final_layer_norm")) + encoder_keys.extend(rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")) + + return 
encoder_keys + + # convolution layer for final features + renamed_keys = [ + (f"{src_prefix}.adapter_1.weight", f"{dst_prefix}.adapter_1.0.weight"), + (f"{src_prefix}.adapter_1.norm.weight", f"{dst_prefix}.adapter_1.1.weight"), + (f"{src_prefix}.adapter_1.norm.bias", f"{dst_prefix}.adapter_1.1.bias"), + ] + + renamed_keys.extend( + [ + (f"{src_prefix}.layer_1.weight", f"{dst_prefix}.layer_1.0.weight"), + (f"{src_prefix}.layer_1.norm.weight", f"{dst_prefix}.layer_1.1.weight"), + (f"{src_prefix}.layer_1.norm.bias", f"{dst_prefix}.layer_1.1.bias"), + ] + ) + + # proj layers + for i in range(3): + for j in range(2): + renamed_keys.extend( + [ + (f"{src_prefix}.input_proj.{i}.{j}.weight", f"{dst_prefix}.input_projections.{i}.{j}.weight"), + (f"{src_prefix}.input_proj.{i}.{j}.bias", f"{dst_prefix}.input_projections.{i}.{j}.bias"), + ] + ) + + renamed_keys.extend([(f"{src_prefix}.transformer.level_embed", f"{dst_prefix}.level_embed")]) + + # layers + for layer_idx in range(self.config.encoder_layers): + renamed_keys.extend( + rename_keys_for_encoder_layer( + f"{src_prefix}.transformer.encoder.layers.{layer_idx}", f"{dst_prefix}.encoder.layers.{layer_idx}" + ) + ) + + # proj + renamed_keys.extend( + [ + (f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"), + (f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"), + ] + ) + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + # Transformer Decoder + def rename_keys_in_masked_attention_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module.decoder" + src_prefix: str = "sem_seg_head.predictor" + + rename_keys = [] + for i in range(self.config.decoder_layers - 1): + + rename_keys.append( + ( + f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.out_proj.weight", + f"{dst_prefix}.layers.{i}.self_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + 
f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.out_proj.bias", + f"{dst_prefix}.layers.{i}.self_attn.out_proj.bias", + ) + ) + + rename_keys.append( + ( + f"{src_prefix}.transformer_self_attention_layers.{i}.norm.weight", + f"{dst_prefix}.layers.{i}.self_attn_layer_norm.weight", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.transformer_self_attention_layers.{i}.norm.bias", + f"{dst_prefix}.layers.{i}.self_attn_layer_norm.bias", + ) + ) + + rename_keys.append( + ( + f"{src_prefix}.transformer_cross_attention_layers.{i}.multihead_attn.in_proj_weight", + f"{dst_prefix}.layers.{i}.cross_attn.in_proj_weight", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.transformer_cross_attention_layers.{i}.multihead_attn.in_proj_bias", + f"{dst_prefix}.layers.{i}.cross_attn.in_proj_bias", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.transformer_cross_attention_layers.{i}.multihead_attn.out_proj.weight", + f"{dst_prefix}.layers.{i}.cross_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.transformer_cross_attention_layers.{i}.multihead_attn.out_proj.bias", + f"{dst_prefix}.layers.{i}.cross_attn.out_proj.bias", + ) + ) + + rename_keys.append( + ( + f"{src_prefix}.transformer_cross_attention_layers.{i}.norm.weight", + f"{dst_prefix}.layers.{i}.cross_attn_layer_norm.weight", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.transformer_cross_attention_layers.{i}.norm.bias", + f"{dst_prefix}.layers.{i}.cross_attn_layer_norm.bias", + ) + ) + + rename_keys.append( + (f"{src_prefix}.transformer_ffn_layers.{i}.linear1.weight", f"{dst_prefix}.layers.{i}.fc1.weight") + ) + rename_keys.append( + (f"{src_prefix}.transformer_ffn_layers.{i}.linear1.bias", f"{dst_prefix}.layers.{i}.fc1.bias") + ) + rename_keys.append( + (f"{src_prefix}.transformer_ffn_layers.{i}.linear2.weight", f"{dst_prefix}.layers.{i}.fc2.weight") + ) + rename_keys.append( + (f"{src_prefix}.transformer_ffn_layers.{i}.linear2.bias", f"{dst_prefix}.layers.{i}.fc2.bias") + ) + 
rename_keys.append( + ( + f"{src_prefix}.transformer_ffn_layers.{i}.norm.weight", + f"{dst_prefix}.layers.{i}.final_layer_norm.weight", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.transformer_ffn_layers.{i}.norm.bias", + f"{dst_prefix}.layers.{i}.final_layer_norm.bias", + ) + ) + + return rename_keys + + def replace_masked_attention_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module.decoder" + src_prefix: str = "sem_seg_head.predictor" + + renamed_keys = self.rename_keys_in_masked_attention_decoder(dst_state_dict, src_state_dict) + + # add more + renamed_keys.extend( + [ + (f"{src_prefix}.decoder_norm.weight", f"{dst_prefix}.layernorm.weight"), + (f"{src_prefix}.decoder_norm.bias", f"{dst_prefix}.layernorm.bias"), + ] + ) + + mlp_len = 3 + for i in range(mlp_len): + renamed_keys.extend( + [ + ( + f"{src_prefix}.mask_embed.layers.{i}.weight", + f"{dst_prefix}.mask_predictor.mask_embedder.{i}.0.weight", + ), + ( + f"{src_prefix}.mask_embed.layers.{i}.bias", + f"{dst_prefix}.mask_predictor.mask_embedder.{i}.0.bias", + ), + ] + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def replace_keys_qkv_transformer_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module.decoder.layers" + src_prefix: str = "sem_seg_head.predictor" + for i in range(self.config.decoder_layers - 1): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = src_state_dict.pop( + f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_weight" + ) + in_proj_bias = src_state_dict.pop( + f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_bias" + ) + # next, add query, keys and values (in that order) to the state dict + dst_state_dict[f"{dst_prefix}.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + 
dst_state_dict[f"{dst_prefix}.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + + def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module" + src_prefix: str = "sem_seg_head.predictor" + + self.replace_masked_attention_decoder(dst_state_dict, src_state_dict) + + renamed_keys = [ + (f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"), + (f"{src_prefix}.query_feat.weight", f"{dst_prefix}.queries_features.weight"), + (f"{src_prefix}.level_embed.weight", f"{dst_prefix}.level_embed.weight"), + ] + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + self.replace_keys_qkv_transformer_decoder(dst_state_dict, src_state_dict) + + def replace_universal_segmentation_module(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "" + src_prefix: str = "sem_seg_head.predictor" + + renamed_keys = [ + (f"{src_prefix}.class_embed.weight", f"{dst_prefix}class_predictor.weight"), + (f"{src_prefix}.class_embed.bias", f"{dst_prefix}class_predictor.bias"), + ] + + logger.info(f"Replacing keys {pformat(renamed_keys)}") + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def convert(self, mask2former: Mask2FormerModel) -> Mask2FormerModel: + dst_state_dict = TrackedStateDict(mask2former.state_dict()) + src_state_dict = self.original_model.state_dict() + + self.replace_pixel_module(dst_state_dict, src_state_dict) + self.replace_transformer_module(dst_state_dict, src_state_dict) + + logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}") + logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}") + logger.info("🙌 Done") + + state_dict = {key: dst_state_dict[key] 
for key in dst_state_dict.to_track.keys()} + mask2former.load_state_dict(state_dict) + return mask2former + + def convert_universal_segmentation( + self, mask2former: Mask2FormerForUniversalSegmentation + ) -> Mask2FormerForUniversalSegmentation: + dst_state_dict = TrackedStateDict(mask2former.state_dict()) + src_state_dict = self.original_model.state_dict() + + self.replace_universal_segmentation_module(dst_state_dict, src_state_dict) + + state_dict = {key: dst_state_dict[key] for key in dst_state_dict.to_track.keys()} + mask2former.load_state_dict(state_dict) + + return mask2former + + @staticmethod + def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[Path, Path]]: + # `Path.glob` returns a lazy iterator, not a list + checkpoints: Iterator[Path] = checkpoints_dir.glob("**/*.pkl") + + for checkpoint in checkpoints: + logger.info(f"💪 Converting {checkpoint.stem}") + # find associated config file + + # dataset_name e.g. 'coco' + dataset_name = checkpoint.parents[2].stem + if dataset_name == "ade": + dataset_name = "ade20k" + + # task type e.g. 'instance-segmentation' + segmentation_task = checkpoint.parents[1].stem + + # config file corresponding to checkpoint + config_file_name = f"{checkpoint.parents[0].stem}.yaml" + + config: Path = config_dir / dataset_name / segmentation_task / "swin" / config_file_name + yield config, checkpoint + + +def test( + original_model, + our_model: Mask2FormerForUniversalSegmentation, + feature_extractor: MaskFormerImageProcessor, + tolerance: float, +): + with torch.no_grad(): + original_model = original_model.eval() + our_model = our_model.eval() + + im = prepare_img() + x = feature_extractor(images=im, return_tensors="pt")["pixel_values"] + + original_model_backbone_features = original_model.backbone(x.clone()) + our_model_output: Mask2FormerModelOutput = our_model.model(x.clone(), output_hidden_states=True) + + # Test backbone + for original_model_feature, our_model_feature in zip( + original_model_backbone_features.values(), 
our_model_output.encoder_hidden_states + ): + assert torch.allclose( + original_model_feature, our_model_feature, atol=tolerance + ), "The backbone features are not the same." + + # Test pixel decoder + mask_features, _, multi_scale_features = original_model.sem_seg_head.pixel_decoder.forward_features( + original_model_backbone_features + ) + + for original_model_feature, our_model_feature in zip( + multi_scale_features, our_model_output.pixel_decoder_hidden_states + ): + assert torch.allclose( + original_model_feature, our_model_feature, atol=tolerance + ), "The pixel decoder features are not the same." + + # Let's test the full model + tr_complete = T.Compose( + [T.Resize((384, 384)), T.ToTensor()], + ) + y = (tr_complete(im) * 255.0).to(torch.int).float() + + # NOTE: this requires modifying the original Mask2Former code to return class and mask logits + original_class_logits, original_mask_logits = original_model([{"image": y.clone().squeeze(0)}]) + + our_model_out: Mask2FormerForUniversalSegmentationOutput = our_model(x.clone()) + our_mask_logits = our_model_out.masks_queries_logits + our_class_logits = our_model_out.class_queries_logits + + assert original_mask_logits.shape == our_mask_logits.shape, "Output mask shapes do not match." + assert original_class_logits.shape == our_class_logits.shape, "Output class logits shapes do not match." + assert torch.allclose( + original_class_logits, our_class_logits, atol=tolerance + ), "The class logits are not the same." + assert torch.allclose( + original_mask_logits, our_mask_logits, atol=tolerance + ), "The predicted masks are not the same." 
+ + logger.info("✅ Test passed!") + + +def get_model_name(checkpoint_file: Path): + # model_name_raw is something like maskformer2_swin_small_bs16_50ep + model_name_raw: str = checkpoint_file.parents[0].stem + + # `segmentation_task_name` must be one of the following: `instance-segmentation`, `panoptic-segmentation`, `semantic-segmentation` + segmentation_task_name: str = checkpoint_file.parents[1].stem + if segmentation_task_name not in ["instance-segmentation", "panoptic-segmentation", "semantic-segmentation"]: + raise ValueError( + f"Unexpected segmentation task '{segmentation_task_name}'; acceptable values are: instance-segmentation," + " panoptic-segmentation, semantic-segmentation." + ) + + # dataset name must be one of the following: `coco`, `ade`, `cityscapes`, `mapillary-vistas` + dataset_name: str = checkpoint_file.parents[2].stem + if dataset_name not in ["coco", "ade", "cityscapes", "mapillary-vistas"]: + raise ValueError( + f"Unexpected dataset name '{dataset_name}'; acceptable values are: coco, ade, cityscapes," + " mapillary-vistas." + ) + + backbone = "swin" + backbone_types = ["tiny", "small", "base_IN21k", "base", "large"] + backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0].replace("_", "-") + + model_name = f"mask2former-{backbone}-{backbone_type}-{dataset_name}-{segmentation_task_name.split('-')[0]}" + + return model_name + + +if __name__ == "__main__": + parser = ArgumentParser( + description="Command line to convert the original mask2formers (with swin backbone) to our implementations." + ) + + parser.add_argument( + "--checkpoints_dir", + type=Path, + help=( + "A directory containing the model's checkpoints. The directory has to have the following structure:" + " ///.pkl" + ), + ) + parser.add_argument( + "--configs_dir", + type=Path, + help=( + "A directory containing the model's configs, see detectron2 doc. 
The directory has to have the following" + " structure: ///.yaml" + ), + ) + parser.add_argument( + "--mask2former_dir", + required=True, + type=Path, + help=( + "A path to Mask2Former's original implementation directory. You can download from here:" + " https://github.com/facebookresearch/Mask2Former" + ), + ) + + args = parser.parse_args() + + checkpoints_dir: Path = args.checkpoints_dir + config_dir: Path = args.configs_dir + mask2former_dir: Path = args.mask2former_dir + # append the path to the parents to mask2former dir + sys.path.append(str(mask2former_dir.parent)) + # import original Mask2Former config and model from original source code repo + from Mask2Former.mask2former.config import add_maskformer2_config + from Mask2Former.mask2former.maskformer_model import MaskFormer as OriginalMask2Former + + for config_file, checkpoint_file in OriginalMask2FormerCheckpointToOursConverter.using_dirs( + checkpoints_dir, config_dir + ): + model_name = get_model_name(checkpoint_file) + feature_extractor = OriginalMask2FormerConfigToFeatureExtractorConverter()( + setup_cfg(Args(config_file=config_file)) + ) + feature_extractor.size = {"height": 384, "width": 384} + + original_config = setup_cfg(Args(config_file=config_file)) + mask2former_kwargs = OriginalMask2Former.from_config(original_config) + original_model = OriginalMask2Former(**mask2former_kwargs).eval() + + DetectionCheckpointer(original_model).load(str(checkpoint_file)) + + config: Mask2FormerConfig = OriginalMask2FormerConfigToOursConverter()(original_config) + mask2former = Mask2FormerModel(config=config).eval() + + converter = OriginalMask2FormerCheckpointToOursConverter(original_model, config) + mask2former = converter.convert(mask2former) + + mask2former_for_segmentation = Mask2FormerForUniversalSegmentation(config=config).eval() + mask2former_for_segmentation.model = mask2former + + mask2former_for_segmentation = converter.convert_universal_segmentation(mask2former_for_segmentation) + + tolerance = 3e-1 
+ high_tolerance_models = [ + "mask2former-swin-base-IN21k-coco-instance", + "mask2former-swin-base-coco-instance", + "mask2former-swin-small-cityscapes-semantic", + ] + + if model_name in high_tolerance_models: + tolerance = 3e-1 + + logger.info(f"🪄 Testing {model_name}...") + test(original_model, mask2former_for_segmentation, feature_extractor, tolerance) + logger.info(f"🪄 Pushing {model_name} to hub...") + + feature_extractor.push_to_hub(model_name) + mask2former_for_segmentation.push_to_hub(model_name) diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py new file mode 100644 index 000000000000..a9032e9d2db1 --- /dev/null +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -0,0 +1,2494 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch Mask2Former model.""" + +import math +import random +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import Tensor, nn + +from transformers import AutoBackbone, SwinConfig +from transformers.utils import logging + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_scipy_available, + replace_return_docstrings, + requires_backends, +) +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions +from ...modeling_utils import PreTrainedModel +from .configuration_mask2former import Mask2FormerConfig + + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "Mask2FormerConfig" +_CHECKPOINT_FOR_DOC = "facebook/mask2former-swin-small-coco-instance" +_IMAGE_PROCESSOR_FOR_DOC = "MaskFormerImageProcessor" + +MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/mask2former-swin-small-coco-instance", + # See all mask2former models at https://huggingface.co/models?filter=mask2former +] + + +@dataclass +class Mask2FormerPixelDecoderOutput(ModelOutput): + """ + Mask2Former's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns + the mask features and the multi-scale features. + + Args: + multi_scale_features (`tuple(torch.FloatTensor)`): + Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height, + width)` from the Multi-Scale Deformable Attention based Pixel Decoder. + mask_features (`torch.FloatTensor`): + Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder + Layer. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights from pixel decoder. Returned when `output_attentions=True` is passed + or when `config.output_attentions=True` + """ + + multi_scale_features: Tuple[torch.FloatTensor] = None + mask_features: torch.FloatTensor = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Mask2FormerMaskedAttentionDecoderOutput(BaseModelOutputWithCrossAttentions): + """ + Base class for outputs of the Transformer decoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions for mask predictions logits and a tuple of intermediate decoder activations, + i.e. the output of each decoder layer, each of them gone through a layernorm. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. Returned when `output_hidden_states=True`. + attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. Returned when `output_attentions=True`. + masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`): + Tuple of mask predictions from all layers of the transformer decoder. 
+ intermediate_hidden_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`): + Intermediate decoder activations, i.e. the output of each decoder layer, each passed through a + layernorm. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[torch.FloatTensor] = None + masks_queries_logits: Tuple[torch.FloatTensor] = None + intermediate_hidden_states: Tuple[torch.FloatTensor] = None + + +@dataclass +class Mask2FormerPixelLevelModuleOutput(ModelOutput): + """ + Mask2Former's pixel level module output. It returns the output of the encoder (optional) and all hidden states + (multi-scale features) from the `decoder`. By default, the `encoder` is a Swin Backbone and the `decoder` is a + Multi-Scale Deformable Attention based decoder. + + The `decoder_last_hidden_state` holds the **per-pixel embeddings**, while `decoder_hidden_states` refers to the + multi-scale feature maps produced using the **multi-scaling strategy** defined in the paper. + + Args: + encoder_last_hidden_state (`torch.FloatTensor`): + Last hidden states (final feature map of shape `(batch_size, num_channels, height, width)`) of the last + stage of the encoder. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden states (also + called feature maps) of the model at the output of each stage. Returned if `output_hidden_states` is set to + `True`. + decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + 1/4 scale features from the last Pixel Decoder Layer. + decoder_hidden_states (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden states (also + called feature maps) of the model at the output of each stage. 
+ """ + + encoder_last_hidden_state: torch.FloatTensor = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_last_hidden_state: torch.FloatTensor = None + decoder_hidden_states: Tuple[torch.FloatTensor] = None + + +@dataclass +class Mask2FormerModelOutput(ModelOutput): + """ + Class for outputs of [`Mask2FormerModel`]. This class returns all the needed hidden states to compute the logits. + + Args: + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): + Last hidden states (final feature map) of the last stage of the encoder model (backbone). Returned when + `output_hidden_states=True` is passed. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. Returned when `output_hidden_states=True` is passed. + pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): + Last hidden states (final feature map) of the last stage of the pixel decoder model. + pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel + decoder model at the output of each stage. + transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Final output of the transformer decoder. 
+ transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the + transformer decoder at the output of each stage. Returned when `output_hidden_states=True` is passed. + transformer_decoder_intermediate_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`): + Intermediate decoder activations, i.e. the output of each decoder layer, each passed through a + layernorm. + masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`): + Mask predictions from each layer in the transformer decoder. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed): + Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Self-attention weights from the transformer decoder. + """ + + encoder_last_hidden_state: torch.FloatTensor = None + pixel_decoder_last_hidden_state: torch.FloatTensor = None + transformer_decoder_last_hidden_state: torch.FloatTensor = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + transformer_decoder_intermediate_states: Tuple[torch.FloatTensor] = None + masks_queries_logits: Tuple[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Mask2FormerForUniversalSegmentationOutput(ModelOutput): + """ + Class for outputs of [`Mask2FormerForUniversalSegmentation`].
+ + This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or + [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or + [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please see + [`~MaskFormerImageProcessor`] for details regarding usage. + + Args: + loss (`torch.Tensor`, *optional*): + The computed loss, returned when labels are present. + class_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each + query. Note the `+ 1` is needed because we incorporate the null class. + masks_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each + query. + auxiliary_logits (`List[Dict(str, torch.FloatTensor)]`, *optional*): + List of class and mask predictions from each layer of the transformer decoder. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the encoder model (backbone). + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. + pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the pixel decoder model.
+ pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel + decoder model at the output of each stage. + transformer_decoder_last_hidden_state (`tuple(torch.FloatTensor)`): + Final output of the transformer decoder `(batch_size, sequence_length, hidden_size)`. + transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the + transformer decoder at the output of each stage. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Self and Cross Attentions weights from transformer decoder. 
+ """ + + loss: Optional[torch.FloatTensor] = None + class_queries_logits: torch.FloatTensor = None + masks_queries_logits: torch.FloatTensor = None + auxiliary_logits: Optional[List[Dict[str, torch.FloatTensor]]] = None + encoder_last_hidden_state: torch.FloatTensor = None + pixel_decoder_last_hidden_state: torch.FloatTensor = None + transformer_decoder_last_hidden_state: torch.FloatTensor = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.detr.modeling_detr._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): + """ + Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. + """ + batch_size, source_len = mask.size() + target_len = target_len if target_len is not None else source_len + + expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py +def sample_point( + input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs +) -> torch.Tensor: + """ + A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors. 
+ + Args: + input_features (`torch.Tensor` of shape (batch_size, channels, height, width)): + A tensor that contains a feature map on a height * width grid + point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width, 2)): + A tensor that contains [0, 1] * [0, 1] normalized point coordinates + add_dim (`bool`): + boolean value to keep track of added dimension + + Returns: + point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels, + height_grid, width_grid)): + A tensor that contains features for points in `point_coordinates`. + """ + if point_coordinates.dim() == 3: + add_dim = True + point_coordinates = point_coordinates.unsqueeze(2) + + # use nn.functional.grid_sample to get features for points in `point_coordinates` via bilinear interpolation + point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs) + if add_dim: + point_features = point_features.squeeze(3) + + return point_features + + +# Copied from transformers.models.maskformer.modeling_maskformer.dice_loss +def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: + r""" + Compute the DICE loss, similar to generalized IOU for masks as follows: + + $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y + 1}{x \cup y + 1} $$ + + In practice, since `labels` is a binary mask (only 0s and 1s), dice can be computed as follows: + + $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y + 1}{x + y + 1} $$ + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + num_masks (`int`): + The number of masks present in the current batch, used for normalization. + + Returns: + `torch.Tensor`: The computed loss.
+ """ + probs = inputs.sigmoid().flatten(1) + numerator = 2 * (probs * labels).sum(-1) + denominator = probs.sum(-1) + labels.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + loss = loss.sum() / num_masks + return loss + + +def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor: + r""" + Args: + inputs (`torch.Tensor`): + A float tensor of arbitrary shape. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + loss (`torch.Tensor`): The computed loss. + """ + criterion = nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss = criterion(inputs, labels) + + loss = cross_entropy_loss.mean(1).sum() / num_masks + return loss + + +# Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss +def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: + """ + A pair wise version of the dice loss, see `dice_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + `torch.Tensor`: The computed loss between each pairs. + """ + inputs = inputs.sigmoid().flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels) + # using broadcasting to get a [num_queries, NUM_CLASSES] matrix + denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + r""" + A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. 
+ labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + loss (`torch.Tensor`): The computed loss between each pair. + """ + + height_and_width = inputs.shape[1] + + criterion = nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) + cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) + + loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum( + "nc,mc->nm", cross_entropy_loss_neg, (1 - labels) + ) + loss = loss / height_and_width + return loss + + +# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py +class Mask2FormerHungarianMatcher(nn.Module): + """This class computes an assignment between the labels and the predictions of the network. + + For efficiency reasons, the labels don't include the no_object class. Because of this, in general, there are more + predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + """ + + def __init__( + self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544 + ): + """Creates the matcher + + Params: + cost_class (`float`, *optional*, defaults to 1.0): + Relative weight of the classification error in the matching cost. + cost_mask (`float`, *optional*, defaults to 1.0): + This is the relative weight of the sigmoid cross entropy loss of the binary mask in the matching cost. + cost_dice (`float`, *optional*, defaults to 1.0): + This is the relative weight of the dice loss of the binary mask in the matching cost. + num_points (`int`, *optional*, defaults to 12544): + Number of points to sample on which the mask loss will be calculated.
The same set of K points is + uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite + matching. + """ + super().__init__() + if cost_class == 0 and cost_mask == 0 and cost_dice == 0: + raise ValueError("All costs can't be 0") + + self.num_points = num_points + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + + @torch.no_grad() + def forward( + self, + masks_queries_logits: torch.Tensor, + class_queries_logits: torch.Tensor, + mask_labels: torch.Tensor, + class_labels: torch.Tensor, + ) -> List[Tuple[Tensor]]: + """ + Params: + masks_queries_logits (`torch.Tensor`): + A tensor of dim `batch_size, num_queries, height, width` with the predicted masks. + class_queries_logits (`torch.Tensor`): + A tensor of dim `batch_size, num_queries, num_labels` with the classification logits. + class_labels (`torch.Tensor`): + A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the + target) containing the class labels. + mask_labels (`torch.Tensor`): + A tensor of dim `num_target_boxes, height, width` containing the target masks. + + Returns: + matched_indices (`List[Tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j) + where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected labels (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes). + """ + indices: List[Tuple[np.array]] = [] + + # iterate through batch size + batch_size = masks_queries_logits.shape[0] + for i in range(batch_size): + pred_probs = class_queries_logits[i].softmax(-1) + pred_mask = masks_queries_logits[i] + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it by 1 - proba[target class]. The 1 is a constant that doesn't change the matching, so it can be omitted.
+ cost_class = -pred_probs[:, class_labels[i]] + target_mask = mask_labels[i].to(pred_mask) + target_mask = target_mask[:, None] + pred_mask = pred_mask[:, None] + + # Sample ground truth and predicted masks + point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device) + + target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1) + target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1) + + pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1) + pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1) + + # compute the cross entropy loss between each pair of masks -> shape (num_queries, num_labels) + cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask) + # Compute the dice loss between each pair of masks -> shape (num_queries, num_labels) + cost_dice = pair_wise_dice_loss(pred_mask, target_mask) + # final cost matrix + cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice + # do the assignment using the Hungarian algorithm in scipy + assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu()) + indices.append(assigned_indices) + + # It could be stacked in one tensor + matched_indices = [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices + ] + return matched_indices + + +# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py +class Mask2FormerLoss(nn.Module): + def __init__(self, config: Mask2FormerConfig, weight_dict: Dict[str, float]): + """ + The Mask2Former Loss. The loss is computed very similarly to DETR.
The process happens in two steps: 1) we + compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair + of matched ground-truth / prediction (supervise class and mask) + + Args: + config (`Mask2FormerConfig`): + The configuration for Mask2Former model also containing loss calculation specific parameters. + weight_dict (`Dict[str, float]`): + A dictionary of weights to be applied to the different losses. + """ + super().__init__() + requires_backends(self, ["scipy"]) + self.num_labels = config.num_labels + self.weight_dict = weight_dict + + # Weight to apply to the null class + self.eos_coef = config.no_object_weight + empty_weight = torch.ones(self.num_labels + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + + # pointwise mask loss parameters + self.num_points = config.train_num_points + self.oversample_ratio = config.oversample_ratio + self.importance_sample_ratio = config.importance_sample_ratio + + self.matcher = Mask2FormerHungarianMatcher( + cost_class=1.0, + cost_dice=config.dice_weight, + cost_mask=config.mask_weight, + num_points=self.num_points, + ) + + def _max_by_axis(self, sizes: List[List[int]]) -> List[int]: + maxes = sizes[0] + for sublist in sizes[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + # Adapted from nested_tensor_from_tensor_list() in original implementation + def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]: + # get the maximum size in the batch + max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) + # compute final size + batch_shape = [len(tensors)] + max_size + batch_size, _, height, width = batch_shape + dtype = tensors[0].dtype + device = tensors[0].device + padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) + padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) + # pad 
the tensors to the size of the biggest one + for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): + padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) + padding_mask[: tensor.shape[1], : tensor.shape[2]] = False + + return padded_tensors, padding_masks + + def loss_labels( + self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array] + ) -> Dict[str, Tensor]: + """Compute the losses related to the labels using cross entropy. + + Args: + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_labels` + class_labels (`List[torch.Tensor]`): + List of class labels of shape `(labels)`. + indices (`Tuple[np.array]`): + The indices computed by the Hungarian matcher. + + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + """ + pred_logits = class_queries_logits + batch_size, num_queries, _ = pred_logits.shape + criterion = nn.CrossEntropyLoss(weight=self.empty_weight) + idx = self._get_predictions_permutation_indices(indices) # tuple of (batch_indices, prediction_indices) + target_classes_o = torch.cat( + [target[j] for target, (_, j) in zip(class_labels, indices)] + ) # matched target classes, concatenated over the batch + target_classes = torch.full( + (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device + ) + target_classes[idx] = target_classes_o + # Transpose pred_logits: (batch_size, num_queries, num_classes) -> (batch_size, num_classes, num_queries) + pred_logits_transposed = pred_logits.transpose(1, 2) + loss_ce = criterion(pred_logits_transposed, target_classes) + losses = {"loss_cross_entropy": loss_ce} + return losses + + def loss_masks( + self, + masks_queries_logits: torch.Tensor, + mask_labels: List[torch.Tensor], + indices: Tuple[np.array], + num_masks: int, + ) ->
Dict[str, torch.Tensor]: + """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, height, width)`. + mask_labels (`List[torch.Tensor]`): + List of mask labels of shape `(labels, height, width)`. + indices (`Tuple[np.array]`): + The indices computed by the Hungarian matcher. + num_masks (`int`): + The number of masks, used for normalization. + + Returns: + losses (`Dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys: + - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth + masks. + - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth + masks. + """ + src_idx = self._get_predictions_permutation_indices(indices) + tgt_idx = self._get_targets_permutation_indices(indices) + # shape (batch_size * num_queries, height, width) + pred_masks = masks_queries_logits[src_idx] + # shape (batch_size, num_queries, height, width) + # pad all and stack the targets to the num_labels dimension + target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) + target_masks = target_masks[tgt_idx] + + # No need to upsample predictions as we are using normalized coordinates + pred_masks = pred_masks[:, None] + target_masks = target_masks[:, None] + + # Sample point coordinates + with torch.no_grad(): + point_coordinates = self.sample_points_using_uncertainty( + pred_masks, + lambda logits: self.calculate_uncertainty(logits), + self.num_points, + self.oversample_ratio, + self.importance_sample_ratio, + ) + + point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1) + + point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1) + + losses = { + "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), + "loss_dice": dice_loss(point_logits, point_labels,
num_masks), + } + + del pred_masks + del target_masks + return losses + + def _get_predictions_permutation_indices(self, indices): + # Permute predictions following indices + batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + predictions_indices = torch.cat([src for (src, _) in indices]) + return batch_indices, predictions_indices + + def _get_targets_permutation_indices(self, indices): + # Permute labels following indices + batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + target_indices = torch.cat([tgt for (_, tgt) in indices]) + return batch_indices, target_indices + + def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor: + """ + In the Mask2Former paper, uncertainty is estimated as the L1 distance between 0.0 and the logit prediction in + `logits` for the foreground class in `classes`. + + Args: + logits (`torch.Tensor`): + A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. The values are logits. + + Returns: + scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most + uncertain locations having the highest uncertainty score. + """ + uncertainty_scores = -(torch.abs(logits)) + return uncertainty_scores + + def sample_points_using_uncertainty( + self, + logits: torch.Tensor, + uncertainty_function, + num_points: int, + oversample_ratio: int, + importance_sample_ratio: float, + ) -> torch.Tensor: + """ + This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The + uncertainty is calculated for each point using the passed `uncertainty_function` that takes the points' logit + predictions as input. + + Args: + logits (`torch.Tensor`): + Logit predictions for P points.
+ uncertainty_function: + A function that takes logit predictions for P points and returns their uncertainties. + num_points (`int`): + The number of points P to sample. + oversample_ratio (`int`): + Oversampling parameter. + importance_sample_ratio (`float`): + Ratio of points that are sampled via importance sampling. + + Returns: + point_coordinates (`torch.Tensor`): + Coordinates for P sampled points. + """ + + num_boxes = logits.shape[0] + num_points_sampled = int(num_points * oversample_ratio) + + # Get random point coordinates + point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device) + # Get sampled prediction value for the point coordinates + point_logits = sample_point(logits, point_coordinates, align_corners=False) + # Calculate the uncertainties based on the sampled prediction values of the points + point_uncertainties = uncertainty_function(point_logits) + + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) + idx += shift[:, None] + point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) + + if num_random_points > 0: + point_coordinates = torch.cat( + [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)], + dim=1, + ) + return point_coordinates + + def forward( + self, + masks_queries_logits: torch.Tensor, + class_queries_logits: torch.Tensor, + mask_labels: List[torch.Tensor], + class_labels: List[torch.Tensor], + auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None, + ) -> Dict[str, torch.Tensor]: + """ + This performs the loss computation. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, height, width)`. 
+ class_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, num_labels)`. + mask_labels (`torch.Tensor`): + List of mask labels of shape `(labels, height, width)`. + class_labels (`List[torch.Tensor]`): + List of class labels of shape `(labels)`. + auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*): + if `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], then it contains the logits from + the inner layers of the Mask2FormerMaskedAttentionDecoder. + + Returns: + losses (`Dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth + masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth + masks. + if `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], the dictionary contains additional + losses for each auxiliary predictions. + """ + + # retrieve the matching between the outputs of the last layer and the labels + indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + # compute the average number of target masks for normalization purposes + num_masks = self.get_num_masks(class_labels, device=class_labels[0].device) + # get all the losses + losses: Dict[str, Tensor] = { + **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), + **self.loss_labels(class_queries_logits, class_labels, indices), + } + # in case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
+ if auxiliary_predictions is not None: + for idx, aux_outputs in enumerate(auxiliary_predictions): + masks_queries_logits = aux_outputs["masks_queries_logits"] + class_queries_logits = aux_outputs["class_queries_logits"] + loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()} + losses.update(loss_dict) + + return losses + + def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor: + """ + Computes the total number of target masks in the batch, for normalization purposes. + """ + num_masks = sum(len(classes) for classes in class_labels) + num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) + return num_masks_pt + + +def multi_scale_deformable_attention( + value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor +): + batch_size, _, num_attn_head, hidden_dim = value.shape + _, num_queries, num_attn_head, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + + for idx, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_attn_head, hidden_dim -> batch_size*num_attn_head, hidden_dim, height, width + value_l_ = ( + value_list[idx].flatten(2).transpose(1, 2).reshape(batch_size * num_attn_head, hidden_dim, height, width) + ) + # (batch_size, num_queries, num_attn_head, num_points) -> (batch_size * num_attn_head, num_queries, num_points, 2) + sampling_grid_l_ = sampling_grids[:, :, :, idx].transpose(1, 2).flatten(0, 1) + # batch_size*num_attn_head, D_, num_queries, num_points + sampling_value_l_ = torch.nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) +
sampling_value_list.append(sampling_value_l_) + + # (batch_size, num_queries, num_attn_head, num_levels, num_points) -> (batch_size, num_attn_head, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_attn_head, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_attn_head * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +# Copied from transformers.models.maskformer.modeling_maskformer.MaskFormerSinePositionEmbedding with MaskFormer->Mask2Former +class Mask2FormerSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, 
None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +# Modified from transformers.models.detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention +class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + """ + + def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}" + ) + dim_per_head = embed_dim // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." 
+ ) + + self.im2col_step = 128 + + self.d_model = embed_dim + self.n_levels = n_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # zero out the values at masked positions (the mask is True where a position is ignored; no inversion here) + value = value.masked_fill(attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + 
attention_weights = nn.functional.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) + output = self.output_proj(output) + + return output, attention_weights + + +class Mask2FormerPixelDecoderEncoderLayer(nn.Module): + def __init__(self, config: Mask2FormerConfig): + super().__init__() + self.embed_dim = config.feature_size + self.self_attn = Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + n_levels=3, + n_points=4, + ) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = nn.functional.relu + self.activation_dropout = config.dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim) + self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape 
`(batch_size, sequence_length, hidden_size)`): + Input to the layer. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Attention mask. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings, to be added to `hidden_states`. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes of the backbone feature maps. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps. + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights.transpose(1, 0),) + + return outputs + + +# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetrEncoder->Mask2FormerPixelDecoderEncoderOnly +class Mask2FormerPixelDecoderEncoderOnly(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a + [`Mask2FormerPixelDecoderEncoderLayer`]. The encoder updates the flattened multi-scale feature maps through + multiple deformable attention layers. + + Args: + config: Mask2FormerConfig + """ + + def __init__(self, config: Mask2FormerConfig): + super().__init__() + + self.config = config + self.dropout = config.dropout + self.layers = nn.ModuleList( + [Mask2FormerPixelDecoderEncoderLayer(config) for _ in range(config.encoder_layers)] + ) + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """ + Get reference points for each feature map. Used in decoder. + + Args: + spatial_shapes (`torch.LongTensor`): + Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`. + valid_ratios (`torch.FloatTensor`): + Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`. + device (`torch.device`): + Device on which to create the tensors. 
+ Returns: + `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` + """ + reference_points_list = [] + for lvl, (height, width) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device), + torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device), + indexing="ij", + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + + return reference_points + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. 
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states.transpose(1, 0),) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states.transpose(1, 0),) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, 
attentions=all_attentions + ) + + +# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel with DeformableDetrModel->Mask2FormerPixelDecoder +class Mask2FormerPixelDecoder(nn.Module): + def __init__(self, config: Mask2FormerConfig, feature_channels): + super().__init__() + + self.config = config + + feature_dim = config.feature_size + mask_dim = config.mask_feature_size + num_pos_features = feature_dim // 2 + + self.position_embedding = Mask2FormerSinePositionEmbedding(num_pos_feats=num_pos_features, normalize=True) + self.num_feature_levels = 3 + transformer_in_channels = feature_channels[-self.num_feature_levels :] + + self.transformer_feature_strides = config.feature_strides[-self.num_feature_levels :] + self.feature_channels = feature_channels + self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, feature_dim)) + + # Create input projection layers + if self.num_feature_levels > 1: + input_projections_list = [] + for in_channels in transformer_in_channels[::-1]: + input_projections_list.append( + nn.Sequential( + nn.Conv2d(in_channels, feature_dim, kernel_size=1), + nn.GroupNorm(32, feature_dim), + ) + ) + self.input_projections = nn.ModuleList(input_projections_list) + else: + self.input_projections = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(transformer_in_channels[-1], feature_dim, kernel_size=1), + nn.GroupNorm(32, feature_dim), + ) + ] + ) + + self.encoder = Mask2FormerPixelDecoderEncoderOnly(config) + self.mask_projection = nn.Conv2d(feature_dim, mask_dim, kernel_size=1, stride=1, padding=0) + + # Extra FPN levels + stride = min(self.transformer_feature_strides) + self.common_stride = config.common_stride + self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride)) + + lateral_convs = [] + output_convs = [] + + for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]): + lateral_conv = nn.Sequential( + nn.Conv2d(in_channels, feature_dim, kernel_size=1, 
bias=False), + nn.GroupNorm(32, feature_dim), + ) + + output_conv = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False), + nn.GroupNorm(32, feature_dim), + nn.ReLU(), + ) + self.add_module("adapter_{}".format(idx + 1), lateral_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + + # Order convolutional layers from low to high resolution + self.lateral_convolutions = lateral_convs[::-1] + self.output_convolutions = output_convs[::-1] + + def get_valid_ratio(self, mask): + """Get the valid ratio of all feature maps.""" + + _, height, width = mask.shape + valid_height = torch.sum(~mask[:, :, 0], 1) + valid_width = torch.sum(~mask[:, 0, :], 1) + valid_ratio_height = valid_height.float() / height + valid_ratio_width = valid_width.float() / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) + return valid_ratio + + def forward( + self, + features, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # Apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + input_embeds = [] + position_embeddings = [] + for level, x in enumerate(features[::-1][: self.num_feature_levels]): + input_embeds.append(self.input_projections[level](x.float())) + position_embeddings.append(self.position_embedding(x.float())) + + masks = [ + torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in input_embeds + ] + + # Prepare encoder inputs (by flattening) + spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds] + input_embeds_flat = 
torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device) + masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1) + + position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings] + level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)] + level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1) + + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(mask) for mask in masks], 1) + + # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=input_embeds_flat, + attention_mask=masks_flat, + position_embeddings=level_pos_embed_flat, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + batch_size = last_hidden_state.shape[0] + + split_sizes = [None] * self.num_feature_levels + for i in range(self.num_feature_levels): + if i < self.num_feature_levels - 1: + split_sizes[i] = level_start_index[i + 1] - level_start_index[i] + else: + split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i] + + encoder_output = torch.split(last_hidden_state, split_sizes, dim=1) + + # Compute final features + outputs = [ + x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) + for i, x in enumerate(encoder_output) + ] + + # Append extra FPN levels to outputs, ordered from low to high resolution + for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]): + lateral_conv = 
self.lateral_convolutions[idx] + output_conv = self.output_convolutions[idx] + current_fpn = lateral_conv(feature.float()) + + # Upsample the previous output to the current FPN resolution (bilinear here, unlike the nearest upsampling of the original FPN) + out = current_fpn + nn.functional.interpolate( + outputs[-1], size=current_fpn.shape[-2:], mode="bilinear", align_corners=False + ) + out = output_conv(out) + outputs.append(out) + + num_cur_levels = 0 + multi_scale_features = [] + + for out in outputs: + if num_cur_levels < self.num_feature_levels: + multi_scale_features.append(out) + num_cur_levels += 1 + + return Mask2FormerPixelDecoderOutput( + mask_features=self.mask_projection(outputs[-1]), + multi_scale_features=tuple(multi_scale_features), + attentions=encoder_outputs.attentions, + ) + + +class Mask2FormerPixelLevelModule(nn.Module): + def __init__(self, config: Mask2FormerConfig): + """ + Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image + Segmentation](https://arxiv.org/abs/2112.01527). It runs the input image through a backbone and a pixel + decoder, generating multi-scale feature maps and pixel embeddings. + + Args: + config ([`Mask2FormerConfig`]): + The configuration used to instantiate this model. 
+ """ + super().__init__() + + backbone_config_dict = config.backbone_config.to_dict() + backbone_config = SwinConfig.from_dict(backbone_config_dict) + + self.encoder = AutoBackbone.from_config(backbone_config) + self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels) + + def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput: + backbone_features = self.encoder(pixel_values).feature_maps + decoder_output = self.decoder(backbone_features, output_hidden_states=output_hidden_states) + + return Mask2FormerPixelLevelModuleOutput( + encoder_last_hidden_state=backbone_features[-1], + encoder_hidden_states=tuple(backbone_features) if output_hidden_states else None, + decoder_last_hidden_state=decoder_output.mask_features, + decoder_hidden_states=decoder_output.multi_scale_features, + ) + + +# Modified from transformers.models.detr.modeling_detr.DetrAttention with Detr->Mask2Former +class Mask2FormerAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and + keys (as explained in the DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + key_value_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Time x Batch x Channel""" + + hidden_states = hidden_states.permute(1, 0, 2) if hidden_states is not None else None + position_embeddings = position_embeddings.permute(1, 0, 2) if position_embeddings is not None else None + key_value_states = key_value_states.permute(1, 0, 2) if key_value_states is not None else None + key_value_position_embeddings = ( + key_value_position_embeddings.permute(1, 0, 2) if key_value_position_embeddings is not None else None + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size, target_len, embed_dim = hidden_states.size() + + # add position embeddings to the hidden states before projecting to queries and keys + # (keep the original states: values are projected without position embeddings) + hidden_states_original = hidden_states + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # add key-value position 
embeddings to the key value states + key_value_states_original = key_value_states + if key_value_position_embeddings is not None: + key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + attn_weights += attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size * self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output).permute(1, 0, 2) + + return attn_output, attn_weights_reshaped + + +class Mask2FormerMaskedAttentionDecoderLayer(nn.Module): + """ + The Mask2FormerMaskedAttentionDecoderLayer is made up of self-attention, cross (masked) attention as well as FFN + blocks. The cross attention block used as part of `Mask2FormerMaskedAttentionDecoderLayer` is actually a `masked + attention` block that restricts the attention to localized features centered around predicted segments which leads + to faster convergence and improved performance. The order of the self and cross (i.e. masked) attention blocks has + also been swapped in Mask2FormerMaskedAttentionDecoder compared to a standard DetrDecoder as an optimization + improvement. + + Args: + config (`Mask2FormerConfig`): + The configuration used to initialize the Mask2FormerMaskedAttentionDecoder. 
+ """ + + def __init__(self, config: Mask2FormerConfig): + super().__init__() + self.config = config + self.embed_dim = self.config.hidden_dim + self.pre_norm = self.config.pre_norm + self.self_attn = Mask2FormerAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.dropout, + is_decoder=True, + ) + + self.dropout = self.config.dropout + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout = self.config.dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.cross_attn = nn.MultiheadAttention(self.embed_dim, self.config.num_attention_heads, self.config.dropout) + self.cross_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, self.config.dim_feedforward) + self.fc2 = nn.Linear(self.config.dim_feedforward, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + hidden_states: torch.Tensor, + level_index: int = None, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + + # Masked(Cross)-Attention Block + cross_attn_weights = None + self_attn_weights = None + + residual = hidden_states + + hidden_states, cross_attn_weights = self.cross_attn( + query=self.with_pos_embed(hidden_states, query_position_embeddings), + key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]), + value=encoder_hidden_states[level_index], + attn_mask=encoder_attention_mask, + key_padding_mask=None, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + 
hidden_states = residual + hidden_states + hidden_states = self.cross_attn_layer_norm(hidden_states) + + # Self Attention Block + residual = hidden_states + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=query_position_embeddings, + attention_mask=None, + output_attentions=True, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + def forward_pre( + self, + hidden_states: torch.Tensor, + level_index: int = None, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + + # Masked(Cross)-Attention Block + cross_attn_weights = None + self_attn_weights = None + + residual = hidden_states + + hidden_states = self.cross_attn_layer_norm(hidden_states) + + hidden_states, cross_attn_weights = self.cross_attn( + query=self.with_pos_embed(hidden_states, query_position_embeddings), + key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]), + value=encoder_hidden_states[level_index], + 
attn_mask=encoder_attention_mask, + key_padding_mask=None, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Self Attention Block + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=query_position_embeddings, + attention_mask=None, + output_attentions=True, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + def forward( + self, + hidden_states: torch.Tensor, + level_index: int = None, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(seq_len, batch, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(1, seq_len, tgt_len, src_len)`. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings that are added to the keys in the masked-attention layer. 
+ query_position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings that are added to the queries and keys in the self-attention layer. + encoder_hidden_states (`torch.FloatTensor`): + Cross attention input to the layer of shape `(seq_len, batch, embed_dim)`. + encoder_attention_mask (`torch.FloatTensor`): + Encoder attention mask of size `(1, seq_len, tgt_len, src_len)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + + if self.pre_norm: + outputs = self.forward_pre( + hidden_states=hidden_states, + level_index=level_index, + position_embeddings=position_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + else: + outputs = self.forward_post( + hidden_states=hidden_states, + level_index=level_index, + position_embeddings=position_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + return outputs + + +class Mask2FormerMaskedAttentionDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a + [`Mask2FormerMaskedAttentionDecoderLayer`]. The decoder updates the query embeddings through multiple cross + (masked) and self-attention layers. The decoder uses a new **masked attention** mechanism instead of the standard + cross-attention, which extracts localized features by constraining cross-attention to within the foreground region + of the predicted mask for each query, instead of attending to the full feature map. + + Args: + config: (`Mask2FormerConfig`): + Configuration used to instantiate Mask2FormerMaskedAttentionDecoder. 
+ """ + + def __init__(self, config: Mask2FormerConfig): + super().__init__() + + self.config = config + self.mask_feature_size = config.mask_feature_size + self.dropout = config.dropout + self.layerdrop = config.dropout + self.num_feature_levels = 3 # level embedding (3 scales) + self.decoder_layers = config.decoder_layers - 1 + + self.layers = nn.ModuleList( + [Mask2FormerMaskedAttentionDecoderLayer(self.config) for _ in range(self.decoder_layers)] + ) + self.layernorm = nn.LayerNorm(config.hidden_dim) + + self.mask_predictor = Mask2FormerMaskPredictor( + hidden_size=config.hidden_dim, + num_heads=config.num_attention_heads, + mask_feature_size=self.mask_feature_size, + ) + + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds: torch.Tensor = None, + multi_stage_positional_embeddings: torch.Tensor = None, + pixel_embeddings: torch.Tensor = None, + encoder_hidden_states: torch.Tensor = None, + query_position_embeddings: torch.Tensor = None, + feature_size_list: List = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): + The query embeddings that are passed into the decoder. + multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`): + Position embeddings that are added to the keys in each cross(masked)-attention layer. + pixel_embeddings (`torch.FloatTensor`): + Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel + Decoder. + query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each self-attention layer.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross(masked)-attention of the decoder. + feature_size_list (`List[torch.Size]` ): + This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # intermediate hidden states with layernorm applied - required for predicting class logits + intermediate = () + + # decoder layers + all_hidden_states = () if output_hidden_states else None + attentions = () if output_attentions else None + + # intermediate mask predictions from transformer decoder layers + intermediate_mask_predictions = () + + intermediate_hidden_states = self.layernorm(inputs_embeds) + intermediate += (intermediate_hidden_states,) + + predicted_mask, attention_mask = self.mask_predictor( + intermediate_hidden_states, pixel_embeddings, feature_size_list[0] + ) + intermediate_mask_predictions += (predicted_mask,) + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) 
+ + dropout_probability = random.uniform(0, 1) + + if self.training and (dropout_probability < self.layerdrop): + continue + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, + None, + ) + + else: + level_index = idx % self.num_feature_levels + + attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False + + layer_outputs = decoder_layer( + hidden_states, + level_index=level_index, + position_embeddings=multi_stage_positional_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + intermediate_hidden_states = self.layernorm(layer_outputs[0]) + + predicted_mask, attention_mask = self.mask_predictor( + intermediate_hidden_states, + pixel_embeddings, + feature_size_list[(idx + 1) % self.num_feature_levels], + ) + + intermediate_mask_predictions += (predicted_mask,) + + # add intermediate hidden states with layer norm applied which will be used for predicting class logits + intermediate += (intermediate_hidden_states,) + + hidden_states = layer_outputs[0] + + if output_attentions: + attentions += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = hidden_states.transpose(1, 0) + if not return_dict: + outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions] + return tuple(v for v in outputs if v is not None) + + return Mask2FormerMaskedAttentionDecoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + 
attentions=attentions, + intermediate_hidden_states=intermediate, + masks_queries_logits=intermediate_mask_predictions, + ) + + +# Copied from transformers.models.maskformer.modeling_maskformer.PredictionBlock with MaskFormer->Mask2Former +class Mask2FormerPredictionBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None: + super().__init__() + self.layers = [nn.Linear(in_dim, out_dim), activation] + # Maintain submodule indexing as if part of a Sequential block + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class Mask2FormerMLPPredictionHead(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3): + """ + A classic Multi Layer Perceptron (MLP). + + Args: + input_dim (`int`): + The input dimensions. + hidden_dim (`int`): + The hidden dimensions. + output_dim (`int`): + The output dimensions. + num_layers (int, *optional*, defaults to 3): + The number of layers. + """ + super().__init__() + in_dims = [input_dim] + [hidden_dim] * (num_layers - 1) + out_dims = [hidden_dim] * (num_layers - 1) + [output_dim] + + self.layers = [] + for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): + activation = nn.ReLU() if i < num_layers - 1 else nn.Identity() + layer = Mask2FormerPredictionBlock(in_dim, out_dim, activation=activation) + self.layers.append(layer) + # Provide backwards compatibility from when the class inherited from nn.Sequential + # In nn.Sequential subclasses, the name given to the layer is its index in the sequence. + # In nn.Module subclasses they derived from the instance attribute they are assigned to e.g. + # self.my_layer_name = Layer() + # We can't give instance attributes integer names i.e. 
self.0 is not permitted and so need to register + # explicitly + self.add_module(str(i), layer) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class Mask2FormerMaskPredictor(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mask_feature_size: torch.Tensor): + """ + This class is used to get the predicted mask for a given Mask2FormerMaskedAttentionDecoder layer. It also + generates the binarized attention mask associated with the given predicted mask. The attention mask obtained + using predicted mask of the (l-1)th decoder layer is fed to the cross(masked)-attention block of the next + decoder layer as input. + + Args: + hidden_size (`int`): + The feature dimension of the Mask2FormerMaskedAttentionDecoder + num_heads (`int`): + The number of heads used in the Mask2FormerMaskedAttentionDecoder + mask_feature_size: (`torch.Tensor`): + one of the output dimensions of the predicted masks for each query + """ + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + + self.mask_embedder = Mask2FormerMLPPredictionHead(self.hidden_size, self.hidden_size, mask_feature_size) + + def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None): + + mask_embeddings = self.mask_embedder(outputs.transpose(0, 1)) + + # Sum up over the channels + outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings) + + attention_mask = nn.functional.interpolate( + outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False + ) + + attention_mask = attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1) + attention_mask = (attention_mask.flatten(0, 1) < 0.5).bool() + attention_mask = attention_mask.detach() + + return outputs_mask, attention_mask + + +class Mask2FormerTransformerModule(nn.Module): + """ + The 
Mask2Former's transformer module. + """ + + def __init__(self, in_features: int, config: Mask2FormerConfig): + super().__init__() + hidden_dim = config.hidden_dim + self.num_feature_levels = 3 + self.position_embedder = Mask2FormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True) + self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim) + self.queries_features = nn.Embedding(config.num_queries, hidden_dim) + self.input_projections = [] + + for _ in range(self.num_feature_levels): + if in_features != hidden_dim or config.enforce_input_projection: + self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1)) + else: + self.input_projections.append(nn.Sequential()) + + self.decoder = Mask2FormerMaskedAttentionDecoder(config=config) + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + + def forward( + self, + multi_scale_features: List[Tensor], + mask_features: Tensor, + output_hidden_states: bool = False, + output_attentions: bool = False, + ) -> Mask2FormerMaskedAttentionDecoderOutput: + + multi_stage_features = [] + multi_stage_positional_embeddings = [] + size_list = [] + + for i in range(self.num_feature_levels): + size_list.append(multi_scale_features[i].shape[-2:]) + multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2)) + multi_stage_features.append( + self.input_projections[i](multi_scale_features[i]).flatten(2) + + self.level_embed.weight[i][None, :, None] + ) + + # Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels) + multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) + multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) + + _, batch_size, _ = multi_stage_features[0].shape + + # [num_queries, batch_size, num_channels] + query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1) + query_features = 
self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1) + + decoder_output = self.decoder( + inputs_embeds=query_features, + multi_stage_positional_embeddings=multi_stage_positional_embeddings, + pixel_embeddings=mask_features, + encoder_hidden_states=multi_stage_features, + query_position_embeddings=query_embeddings, + feature_size_list=size_list, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + + return decoder_output + + +MASK2FORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Mask2FormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MASK2FORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of Mask2Former's decoder attention layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~Mask2FormerModelOutput`] instead of a plain tuple. +""" + + +class Mask2FormerPreTrainedModel(PreTrainedModel): + config_class = Mask2FormerConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + + def _init_weights(self, module: nn.Module): + xavier_std = self.config.init_xavier_std + std = self.config.init_std + + if isinstance(module, Mask2FormerTransformerModule): + if module.input_projections is not None: + for input_projection in module.input_projections: + if not isinstance(input_projection, nn.Sequential): + nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std) + nn.init.constant_(input_projection.bias, 0) + + elif isinstance(module, Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) + + elif isinstance(module, Mask2FormerMaskedAttentionDecoderLayer): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p, gain=xavier_std) + + elif
isinstance(module, Mask2FormerPixelLevelModule): + for submodule in module.modules(): + if isinstance(submodule, (nn.Conv2d, nn.Linear)): + submodule.weight.data.normal_(mean=0.0, std=std) + if submodule.bias is not None: + submodule.bias.data.zero_() + + elif isinstance(module, Mask2FormerPixelDecoder): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + nn.init.normal_(module.level_embed, std=0) + + elif isinstance(module, Mask2FormerPixelDecoderEncoderOnly): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if hasattr(module, "reference_points"): + nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) + nn.init.constant_(module.reference_points.bias.data, 0.0) + + +@add_start_docstrings( + "The bare Mask2Former Model outputting raw hidden-states without any specific head on top.", + MASK2FORMER_START_DOCSTRING, +) +class Mask2FormerModel(Mask2FormerPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: Mask2FormerConfig): + super().__init__(config) + self.pixel_level_module = Mask2FormerPixelLevelModule(config) + self.transformer_module = Mask2FormerTransformerModule(in_features=config.feature_size, config=config) + + self.post_init() + + @add_start_docstrings_to_model_forward(MASK2FORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Mask2FormerModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Tensor, + pixel_mask: Optional[Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + 
) -> Mask2FormerModelOutput: + r""" + Returns: + `Mask2FormerModelOutput` + + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoImageProcessor, Mask2FormerModel + + >>> # download test image + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # Load image preprocessor and Mask2FormerModel trained on ADE20K instance segmentation dataset + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-ade-instance") + >>> model = Mask2FormerModel.from_pretrained("facebook/mask2former-swin-small-ade-instance") + >>> inputs = image_processor(image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + ``` + """ + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, _, height, width = pixel_values.shape + + if pixel_mask is None: + pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) + + pixel_level_module_output = self.pixel_level_module( + pixel_values=pixel_values, output_hidden_states=output_hidden_states + ) + + transformer_module_output = self.transformer_module( + multi_scale_features=pixel_level_module_output.decoder_hidden_states, + mask_features=pixel_level_module_output.decoder_last_hidden_state, + output_hidden_states=True, + output_attentions=output_attentions, + ) + + encoder_hidden_states = None + pixel_decoder_hidden_states = None + transformer_decoder_hidden_states = None + transformer_decoder_intermediate_states = None + +
if output_hidden_states: + encoder_hidden_states = pixel_level_module_output.encoder_hidden_states + pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states + transformer_decoder_hidden_states = transformer_module_output.hidden_states + transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states + + output = Mask2FormerModelOutput( + encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state, + pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state, + transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state, + encoder_hidden_states=encoder_hidden_states, + pixel_decoder_hidden_states=pixel_decoder_hidden_states, + transformer_decoder_hidden_states=transformer_decoder_hidden_states, + transformer_decoder_intermediate_states=transformer_decoder_intermediate_states, + attentions=transformer_module_output.attentions, + masks_queries_logits=transformer_module_output.masks_queries_logits, + ) + + if not return_dict: + output = tuple(v for v in output.values() if v is not None) + + return output + + +@add_start_docstrings( + "The Mask2Former Model with heads on top for instance/semantic/panoptic segmentation.", + MASK2FORMER_START_DOCSTRING, +) +class Mask2FormerForUniversalSegmentation(Mask2FormerPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: Mask2FormerConfig): + super().__init__(config) + self.model = Mask2FormerModel(config) + + self.weight_dict: Dict[str, float] = { + "loss_cross_entropy": config.class_weight, + "loss_mask": config.mask_weight, + "loss_dice": config.dice_weight, + } + + self.class_predictor = nn.Linear(config.hidden_dim, config.num_labels + 1) + + self.criterion = Mask2FormerLoss(config=config, weight_dict=self.weight_dict) + self.post_init() + + def get_loss_dict( + self, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + mask_labels: Tensor, + class_labels: Tensor, + 
auxiliary_predictions: Dict[str, Tensor], + ) -> Dict[str, Tensor]: + loss_dict: Dict[str, Tensor] = self.criterion( + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + mask_labels=mask_labels, + class_labels=class_labels, + auxiliary_predictions=auxiliary_predictions, + ) + + # weight each loss by `self.weight_dict[]` including auxiliary losses + for key, weight in self.weight_dict.items(): + for loss_key, loss in loss_dict.items(): + if key in loss_key: + loss *= weight + + return loss_dict + + def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor: + return sum(loss_dict.values()) + + def get_auxiliary_logits(self, classes: torch.Tensor, output_masks: torch.Tensor): + auxiliary_logits: List[Dict[str, Tensor]] = [] + + for aux_binary_masks, aux_classes in zip(output_masks[:-1], classes[:-1]): + auxiliary_logits.append({"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes}) + + return auxiliary_logits + + @add_start_docstrings_to_model_forward(MASK2FORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Mask2FormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Tensor, + mask_labels: Optional[List[Tensor]] = None, + class_labels: Optional[List[Tensor]] = None, + pixel_mask: Optional[Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_auxiliary_logits: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Mask2FormerForUniversalSegmentationOutput: + r""" + mask_labels (`List[torch.Tensor]`, *optional*): + List of mask labels of shape `(num_labels, height, width)` to be fed to a model. + class_labels (`List[torch.LongTensor]`, *optional*): + List of target class labels of shape `(num_labels,)` to be fed to a model. They identify the + labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
+ + Returns: + `Mask2FormerForUniversalSegmentationOutput` + + Examples: + ```python + >>> from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation + >>> from PIL import Image + >>> import requests + >>> import torch + + >>> # Load Mask2Former trained on ADE20K panoptic segmentation dataset + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-ade-panoptic") + >>> model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-ade-panoptic") + + >>> url = ( + ... "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg" + ... ) + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = image_processor(image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # Model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)` + >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` + >>> class_queries_logits = outputs.class_queries_logits + >>> masks_queries_logits = outputs.masks_queries_logits + + >>> # Perform post-processing to get semantic, instance or panoptic segmentation maps + >>> pred_semantic_map = image_processor.post_process_semantic_segmentation( + ... outputs, target_sizes=[image.size[::-1]] + ... )[0] + >>> pred_instance_map = image_processor.post_process_instance_segmentation( + ... outputs, target_sizes=[image.size[::-1]] + ... )[0]["segmentation"] + >>> pred_panoptic_map = image_processor.post_process_panoptic_segmentation( + ... outputs, target_sizes=[image.size[::-1]] + ...
)[0]["segmentation"] + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, + output_attentions=output_attentions, + return_dict=True, + ) + + loss, loss_dict, auxiliary_logits = None, None, None + class_queries_logits = () + + for decoder_output in outputs.transformer_decoder_intermediate_states: + class_prediction = self.class_predictor(decoder_output.transpose(0, 1)) + class_queries_logits += (class_prediction,) + + masks_queries_logits = outputs.masks_queries_logits + + auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits) + + if mask_labels is not None and class_labels is not None: + loss_dict = self.get_loss_dict( + masks_queries_logits=masks_queries_logits[-1], + class_queries_logits=class_queries_logits[-1], + mask_labels=mask_labels, + class_labels=class_labels, + auxiliary_predictions=auxiliary_logits, + ) + loss = self.get_loss(loss_dict) + + encoder_hidden_states = None + pixel_decoder_hidden_states = None + transformer_decoder_hidden_states = None + + if output_hidden_states: + encoder_hidden_states = outputs.encoder_hidden_states + pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states + transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states + + output_auxiliary_logits = ( + self.config.use_auxiliary_loss if output_auxiliary_logits is None else output_auxiliary_logits + ) + if not output_auxiliary_logits: + auxiliary_logits = None + + output = Mask2FormerForUniversalSegmentationOutput( + loss=loss, + class_queries_logits=class_queries_logits[-1], + 
masks_queries_logits=masks_queries_logits[-1], + auxiliary_logits=auxiliary_logits, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state, + transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state, + encoder_hidden_states=encoder_hidden_states, + pixel_decoder_hidden_states=pixel_decoder_hidden_states, + transformer_decoder_hidden_states=transformer_decoder_hidden_states, + attentions=outputs.attentions, + ) + + if not return_dict: + output = tuple(v for v in output.values() if v is not None) + return output diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 8ff9fbcde5f9..1ae0f4f2f3a1 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3568,6 +3568,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class Mask2FormerForUniversalSegmentation(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Mask2FormerModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Mask2FormerPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/mask2former/__init__.py b/tests/models/mask2former/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py new file mode 100644 index 000000000000..a3e95088ef50 --- /dev/null +++ b/tests/models/mask2former/test_modeling_mask2former.py
@@ -0,0 +1,425 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Mask2Former model. """ + +import inspect +import unittest + +import numpy as np + +from tests.test_modeling_common import floats_tensor +from transformers import Mask2FormerConfig, is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.utils import cached_property + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin + + +if is_torch_available(): + import torch + + from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerModel + + if is_vision_available(): + from transformers import MaskFormerImageProcessor + +if is_vision_available(): + from PIL import Image + + +class Mask2FormerModelTester: + def __init__( + self, + parent, + batch_size=2, + is_training=True, + use_auxiliary_loss=False, + num_queries=10, + num_channels=3, + min_size=32 * 8, + max_size=32 * 8, + num_labels=4, + hidden_dim=64, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_auxiliary_loss = use_auxiliary_loss + self.num_queries = num_queries + self.num_channels = num_channels + self.min_size = min_size + self.max_size = max_size + self.num_labels = num_labels + self.hidden_dim = 
hidden_dim + self.mask_feature_size = hidden_dim + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to( + torch_device + ) + + pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device) + + mask_labels = ( + torch.rand([self.batch_size, self.num_labels, self.min_size, self.max_size], device=torch_device) > 0.5 + ).float() + class_labels = (torch.rand((self.batch_size, self.num_labels), device=torch_device) > 0.5).long() + + config = self.get_config() + return config, pixel_values, pixel_mask, mask_labels, class_labels + + def get_config(self): + config = Mask2FormerConfig( + hidden_size=self.hidden_dim, + ) + config.num_queries = self.num_queries + config.num_labels = self.num_labels + + config.backbone_config.depths = [1, 1, 1, 1] + config.backbone_config.num_channels = self.num_channels + + config.encoder_feedforward_dim = 64 + config.dim_feedforward = 128 + config.hidden_dim = self.hidden_dim + config.mask_feature_size = self.hidden_dim + config.feature_size = self.hidden_dim + return config + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, _, _ = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask} + return config, inputs_dict + + def check_output_hidden_state(self, output, config): + encoder_hidden_states = output.encoder_hidden_states + pixel_decoder_hidden_states = output.pixel_decoder_hidden_states + transformer_decoder_hidden_states = output.transformer_decoder_hidden_states + + self.parent.assertTrue(len(encoder_hidden_states), len(config.backbone_config.depths)) + self.parent.assertTrue(len(pixel_decoder_hidden_states), len(config.backbone_config.depths)) + self.parent.assertTrue(len(transformer_decoder_hidden_states), config.decoder_layers) + + def create_and_check_mask2former_model(self, config, pixel_values, pixel_mask, 
output_hidden_states=False): + with torch.no_grad(): + model = Mask2FormerModel(config=config) + model.to(torch_device) + model.eval() + + output = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + output = model(pixel_values, output_hidden_states=True) + + self.parent.assertEqual( + output.transformer_decoder_last_hidden_state.shape, + (self.batch_size, self.num_queries, self.hidden_dim), + ) + # let's ensure the other two hidden states exist + self.parent.assertTrue(output.pixel_decoder_last_hidden_state is not None) + self.parent.assertTrue(output.encoder_last_hidden_state is not None) + + if output_hidden_states: + self.check_output_hidden_state(output, config) + + def create_and_check_mask2former_instance_segmentation_head_model( + self, config, pixel_values, pixel_mask, mask_labels, class_labels + ): + model = Mask2FormerForUniversalSegmentation(config=config) + model.to(torch_device) + model.eval() + + def comm_check_on_output(result): + # let's still check that all the required stuff is there + self.parent.assertTrue(result.transformer_decoder_last_hidden_state is not None) + self.parent.assertTrue(result.pixel_decoder_last_hidden_state is not None) + self.parent.assertTrue(result.encoder_last_hidden_state is not None) + # okay, now we need to check the logits shape + # due to the encoder compression, masks have a //4 spatial size + self.parent.assertEqual( + result.masks_queries_logits.shape, + (self.batch_size, self.num_queries, self.min_size // 4, self.max_size // 4), + ) + # + 1 for null class + self.parent.assertEqual( + result.class_queries_logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1) + ) + + with torch.no_grad(): + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + comm_check_on_output(result) + + result = model( + pixel_values=pixel_values, pixel_mask=pixel_mask, mask_labels=mask_labels, class_labels=class_labels + ) + + comm_check_on_output(result) + + 
self.parent.assertTrue(result.loss is not None) + self.parent.assertEqual(result.loss.shape, torch.Size([1])) + + +@require_torch +class Mask2FormerModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else () + + is_encoder_decoder = False + test_pruning = False + test_head_masking = False + test_missing_keys = False + + def setUp(self): + self.model_tester = Mask2FormerModelTester(self) + self.config_tester = ConfigTester(self, config_class=Mask2FormerConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_mask2former_model(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.create_and_check_mask2former_model(config, **inputs, output_hidden_states=False) + + def test_mask2former_instance_segmentation_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mask2former_instance_segmentation_head_model(*config_and_inputs) + + @unittest.skip(reason="Mask2Former does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Mask2Former does not have a get_input_embeddings method") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Mask2Former is not a generative model") + def test_generate_without_input_ids(self): + pass + + @unittest.skip(reason="Mask2Former does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @require_torch_multi_gpu + @unittest.skip( + reason="Mask2Former has some layers using `add_module` which doesn't work well with `nn.DataParallel`" + ) + def test_multi_gpu_data_parallel_forward(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = 
inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + @slow + def test_model_from_pretrained(self): + for model_name in ["facebook/mask2former-swin-small-coco-instance"]: + model = Mask2FormerModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_model_with_labels(self): + size = (self.model_tester.min_size,) * 2 + inputs = { + "pixel_values": torch.randn((2, 3, *size), device=torch_device), + "mask_labels": torch.randn((2, 10, *size), device=torch_device), + "class_labels": torch.zeros(2, 10, device=torch_device).long(), + } + config = self.model_tester.get_config() + + model = Mask2FormerForUniversalSegmentation(config).to(torch_device) + outputs = model(**inputs) + self.assertTrue(outputs.loss is not None) + + def test_hidden_states_output(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.create_and_check_mask2former_model(config, **inputs, output_hidden_states=True) + + def test_attention_outputs(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + outputs = model(**inputs, output_attentions=True) + self.assertTrue(outputs.attentions is not None) + + def test_training(self): + if not self.model_tester.is_training: + return + + model_class = self.all_model_classes[1] + config, pixel_values, pixel_mask, mask_labels, class_labels = self.model_tester.prepare_config_and_inputs() + + model = model_class(config) + model.to(torch_device) + model.train() + + loss = model(pixel_values, mask_labels=mask_labels, class_labels=class_labels).loss + loss.backward() + + def test_retain_grad_hidden_states_attentions(self): + model_class = self.all_model_classes[1] + config, 
pixel_values, pixel_mask, mask_labels, class_labels = self.model_tester.prepare_config_and_inputs() + config.output_hidden_states = True + config.output_attentions = True + + model = model_class(config).to(torch_device) + model.train() + + outputs = model(pixel_values, mask_labels=mask_labels, class_labels=class_labels) + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_hidden_states.retain_grad() + + pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states[0] + pixel_decoder_hidden_states.retain_grad() + + transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states[0] + transformer_decoder_hidden_states.retain_grad() + + attentions = outputs.attentions[0] + attentions.retain_grad() + + outputs.loss.backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(pixel_decoder_hidden_states.grad) + self.assertIsNotNone(transformer_decoder_hidden_states.grad) + self.assertIsNotNone(attentions.grad) + + +TOLERANCE = 1e-4 + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_vision +@slow +class Mask2FormerModelIntegrationTest(unittest.TestCase): + @cached_property + def model_checkpoints(self): + return "facebook/mask2former-swin-small-coco-instance" + + @cached_property + def default_feature_extractor(self): + return MaskFormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None + + def test_inference_no_head(self): + model = Mask2FormerModel.from_pretrained(self.model_checkpoints).to(torch_device) + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(image, return_tensors="pt").to(torch_device) + inputs_shape = inputs["pixel_values"].shape + # check size is divisible by 32 + self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0) + # check 
size + self.assertEqual(inputs_shape, (1, 3, 384, 384)) + + with torch.no_grad(): + outputs = model(**inputs) + + expected_slice_hidden_state = torch.tensor( + [[-0.2790, -1.0717, -1.1668], [-0.5128, -0.3128, -0.4987], [-0.5832, 0.1971, -0.0197]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + expected_slice_hidden_state = torch.tensor( + [[0.8973, 1.1847, 1.1776], [1.1934, 1.5040, 1.5128], [1.1153, 1.4486, 1.4951]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + expected_slice_hidden_state = torch.tensor( + [[2.1152, 1.7000, -0.8603], [1.5808, 1.8004, -0.9353], [1.6043, 1.7495, -0.5999]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + def test_inference_universal_segmentation_head(self): + model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(image, return_tensors="pt").to(torch_device) + inputs_shape = inputs["pixel_values"].shape + # check size is divisible by 32 + self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0) + # check size + self.assertEqual(inputs_shape, (1, 3, 384, 384)) + + with torch.no_grad(): + outputs = model(**inputs) + # masks_queries_logits + masks_queries_logits = outputs.masks_queries_logits + self.assertEqual( + masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4) + ) + expected_slice = [ + [-8.7839, -9.0056, -8.8121], + [-7.4104, -7.0313, -6.5401], + [-6.6105, -6.3427, -6.4675], + ] + expected_slice = torch.tensor(expected_slice).to(torch_device) + 
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE)) + # class_queries_logits + class_queries_logits = outputs.class_queries_logits + self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1)) + expected_slice = torch.tensor( + [ + [1.8324, -8.0835, -4.1922], + [0.8450, -9.0050, -3.6053], + [0.3045, -7.7293, -3.0275], + ] + ).to(torch_device) + self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_with_segmentation_maps_and_loss(self): + model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() + feature_extractor = self.default_feature_extractor + + inputs = feature_extractor( + [np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))], + segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)], + return_tensors="pt", + ) + + inputs["pixel_values"] = inputs["pixel_values"].to(torch_device) + inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]] + inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]] + + with torch.no_grad(): + outputs = model(**inputs) + + self.assertTrue(outputs.loss is not None)