diff --git a/README.md b/README.md
index 0906c65deeda..32ab0fc894c8 100644
--- a/README.md
+++ b/README.md
@@ -407,6 +407,7 @@ Current number of checkpoints:
1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[UPerNet](https://huggingface.co/docs/transformers/main/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
diff --git a/README_es.md b/README_es.md
index 341fd87923ca..7bcaab0cc3d7 100644
--- a/README_es.md
+++ b/README_es.md
@@ -407,6 +407,7 @@ Número actual de puntos de control:
1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[UPerNet](https://huggingface.co/docs/transformers/main/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
diff --git a/README_hd.md b/README_hd.md
index 194aa1ab7a8b..6443d230ee1e 100644
--- a/README_hd.md
+++ b/README_hd.md
@@ -380,6 +380,7 @@ conda install -c huggingface transformers
1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (माइक्रोसॉफ्ट रिसर्च से) साथ में दिया गया पेपर [UniSpeech: यूनिफाइड स्पीच रिप्रेजेंटेशन लर्निंग विद लेबलेड एंड अनलेबल्ड डेटा](https:/ /arxiv.org/abs/2101.07597) चेंगई वांग, यू वू, याओ कियान, केनिची कुमातानी, शुजी लियू, फुरु वेई, माइकल ज़ेंग, ज़ुएदोंग हुआंग द्वारा।
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (माइक्रोसॉफ्ट रिसर्च से) कागज के साथ [UNISPEECH-SAT: यूनिवर्सल स्पीच रिप्रेजेंटेशन लर्निंग विद स्पीकर अवेयर प्री-ट्रेनिंग ](https://arxiv.org/abs/2110.05752) सानयुआन चेन, यू वू, चेंग्यी वांग, झेंगयांग चेन, झूओ चेन, शुजी लियू, जियान वू, याओ कियान, फुरु वेई, जिन्यु ली, जियांगज़ान यू द्वारा पोस्ट किया गया।
+1. **[UPerNet](https://huggingface.co/docs/transformers/main/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (सिंघुआ यूनिवर्सिटी और ननकाई यूनिवर्सिटी से) साथ में पेपर [विजुअल अटेंशन नेटवर्क](https://arxiv.org/ pdf/2202.09741.pdf) मेंग-हाओ गुओ, चेंग-ज़े लू, झेंग-निंग लियू, मिंग-मिंग चेंग, शि-मिन हू द्वारा।
1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (मल्टीमीडिया कम्प्यूटिंग ग्रुप, नानजिंग यूनिवर्सिटी से) साथ में पेपर [वीडियोएमएई: मास्क्ड ऑटोएन्कोडर स्व-पर्यवेक्षित वीडियो प्री-ट्रेनिंग के लिए डेटा-कुशल सीखने वाले हैं] (https://arxiv.org/abs/2203.12602) ज़ान टोंग, यिबिंग सॉन्ग, जुए द्वारा वांग, लिमिन वांग द्वारा पोस्ट किया गया।
1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (NAVER AI Lab/Kakao Enterprise/Kakao Brain से) साथ में कागज [ViLT: Vision-and-Language Transformer बिना कनवल्शन या रीजन सुपरविजन](https://arxiv.org/abs/2102.03334) वोनजे किम, बोक्यूंग सोन, इल्डू किम द्वारा पोस्ट किया गया।
diff --git a/README_ja.md b/README_ja.md
index 72f23dbeae3d..6c0f50af716a 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -442,6 +442,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (Google Research から) Yi Tay, Mostafa Dehghani, Vinh Q から公開された研究論文: [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (Microsoft Research から) Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang から公開された研究論文: [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597)
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (Microsoft Research から) Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu から公開された研究論文: [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752)
+1. **[UPerNet](https://huggingface.co/docs/transformers/main/model_doc/upernet)** (Peking University から) Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. から公開された研究論文 [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221)
1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (Tsinghua University and Nankai University から) Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu から公開された研究論文: [Visual Attention Network](https://arxiv.org/abs/2202.09741)
1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (Multimedia Computing Group, Nanjing University から) Zhan Tong, Yibing Song, Jue Wang, Limin Wang から公開された研究論文: [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602)
1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (NAVER AI Lab/Kakao Enterprise/Kakao Brain から) Wonjae Kim, Bokyung Son, Ildoo Kim から公開された研究論文: [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334)
diff --git a/README_ko.md b/README_ko.md
index 8d0443fd6f50..83f240d02ed3 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -357,6 +357,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (Google Research 에서) Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzle 의 [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) 논문과 함께 발표했습니다.
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (Microsoft Research 에서) Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang 의 [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) 논문과 함께 발표했습니다.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (Microsoft Research 에서) Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu 의 [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) 논문과 함께 발표했습니다.
+1. **[UPerNet](https://huggingface.co/docs/transformers/main/model_doc/upernet)** (Peking University 에서 제공)은 Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.의 [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221)논문과 함께 발표했습니다.
1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (Tsinghua University and Nankai University 에서) Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu 의 [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) 논문과 함께 발표했습니다.
1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (Multimedia Computing Group, Nanjing University 에서) Zhan Tong, Yibing Song, Jue Wang, Limin Wang 의 [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) 논문과 함께 발표했습니다.
1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (NAVER AI Lab/Kakao Enterprise/Kakao Brain 에서) Wonjae Kim, Bokyung Son, Ildoo Kim 의 [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) 논문과 함께 발표했습니다.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index 8a7b507599b3..542c5740b5a9 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -381,6 +381,7 @@ conda install -c huggingface transformers
1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (来自 Microsoft Research) 伴随论文 [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) 由 Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang 发布。
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (来自 Microsoft Research) 伴随论文 [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) 由 Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu 发布。
+1. **[UPerNet](https://huggingface.co/docs/transformers/main/model_doc/upernet)** (来自 Peking University) 伴随论文 [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) 由 Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun 发布。
1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (来自 Tsinghua University and Nankai University) 伴随论文 [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) 由 Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu 发布。
1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (来自 Multimedia Computing Group, Nanjing University) 伴随论文 [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) 由 Zhan Tong, Yibing Song, Jue Wang, Limin Wang 发布。
1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (来自 NAVER AI Lab/Kakao Enterprise/Kakao Brain) 伴随论文 [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) 由 Wonjae Kim, Bokyung Son, Ildoo Kim 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index 5d0f1b9057a3..f0b83136f61a 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -393,6 +393,7 @@ conda install -c huggingface transformers
1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[UPerNet](https://huggingface.co/docs/transformers/main/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 3573c6070cdc..d242c20b3b28 100755
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -450,6 +450,8 @@
title: Table Transformer
- local: model_doc/timesformer
title: TimeSformer
+ - local: model_doc/upernet
+    title: UPerNet
- local: model_doc/van
title: VAN
- local: model_doc/videomae
diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx
index 7f5f80dba063..6aefe26686ec 100644
--- a/docs/source/en/index.mdx
+++ b/docs/source/en/index.mdx
@@ -194,6 +194,7 @@ The documentation is organized into five sections:
1. **[UL2](model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler
1. **[UniSpeech](model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[UPerNet](model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
1. **[VAN](model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
1. **[VideoMAE](model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
1. **[ViLT](model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
@@ -363,6 +364,7 @@ Flax), PyTorch, and/or TensorFlow.
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
+| UPerNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| VAN | ❌ | ❌ | ✅ | ❌ | ❌ |
| VideoMAE | ❌ | ❌ | ✅ | ❌ | ❌ |
| ViLT | ❌ | ❌ | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/upernet.mdx b/docs/source/en/model_doc/upernet.mdx
new file mode 100644
index 000000000000..cd83e2c6723c
--- /dev/null
+++ b/docs/source/en/model_doc/upernet.mdx
@@ -0,0 +1,65 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# UPerNet
+
+## Overview
+
+The UPerNet model was proposed in [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221)
+by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. UPerNet is a general framework to effectively segment
+a wide range of concepts from images, leveraging any vision backbone like [ConvNeXt](convnext) or [Swin](swin).
+
+The abstract from the paper is the following:
+
+*Humans recognize the visual world at multiple levels: we effortlessly categorize scenes and detect objects inside, while also identifying the textures and surfaces of the objects along with their different compositional parts. In this paper, we study a new task called Unified Perceptual Parsing, which requires the machine vision systems to recognize as many visual concepts as possible from a given image. A multi-task framework called UPerNet and a training strategy are developed to learn from heterogeneous image annotations. We benchmark our framework on Unified Perceptual Parsing and show that it is able to effectively segment a wide range of concepts from images. The trained networks are further applied to discover visual knowledge in natural scenes.*
+
+
+
+<small>UPerNet framework. Taken from the original paper.</small>
+
+This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code is based on OpenMMLab's mmsegmentation [here](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/uper_head.py).
+
+## Usage
+
+UPerNet is a general framework for semantic segmentation. It can be used with any vision backbone, like so:
+
+```py
+from transformers import SwinConfig, UperNetConfig, UperNetForSemanticSegmentation
+
+backbone_config = SwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
+
+config = UperNetConfig(backbone_config=backbone_config)
+model = UperNetForSemanticSegmentation(config)
+```
+
+To use another vision backbone, like [ConvNeXt](convnext), simply instantiate the model with the appropriate backbone:
+
+```py
+from transformers import ConvNextConfig, UperNetConfig, UperNetForSemanticSegmentation
+
+backbone_config = ConvNextConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
+
+config = UperNetConfig(backbone_config=backbone_config)
+model = UperNetForSemanticSegmentation(config)
+```
+
+Note that this will randomly initialize all the weights of the model.
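+
+To run inference with pretrained weights instead, load a converted checkpoint with `from_pretrained`. Below is a minimal sketch, assuming the `openmmlab/upernet-convnext-tiny` checkpoint produced by the conversion script is available on the Hub:
+
+```py
+import requests
+import torch
+from PIL import Image
+
+from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+
+# checkpoint name assumed to match the conversion script's push_to_hub target
+processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-tiny")
+model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny")
+
+url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
+image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+inputs = processor(image, return_tensors="pt")
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# logits of shape (batch_size, num_labels, height, width); take the highest-scoring class per pixel
+predicted_map = outputs.logits.argmax(dim=1)
+```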
+
+## UperNetConfig
+
+[[autodoc]] UperNetConfig
+
+## UperNetForSemanticSegmentation
+
+[[autodoc]] UperNetForSemanticSegmentation
+ - forward
\ No newline at end of file
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 829c0a18bdc2..a6652b853806 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -424,6 +424,7 @@
"UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"UniSpeechSatConfig",
],
+ "models.upernet": ["UperNetConfig"],
"models.van": ["VAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "VanConfig"],
"models.videomae": ["VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "VideoMAEConfig"],
"models.vilt": [
@@ -1224,6 +1225,7 @@
_import_structure["models.convnext"].extend(
[
"CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "ConvNextBackbone",
"ConvNextForImageClassification",
"ConvNextModel",
"ConvNextPreTrainedModel",
@@ -2259,6 +2261,12 @@
"UniSpeechSatPreTrainedModel",
]
)
+ _import_structure["models.upernet"].extend(
+ [
+ "UperNetForSemanticSegmentation",
+ "UperNetPreTrainedModel",
+ ]
+ )
_import_structure["models.van"].extend(
[
"VAN_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -3772,6 +3780,7 @@
from .models.trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig, TrOCRProcessor
from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
+ from .models.upernet import UperNetConfig
from .models.van import VAN_PRETRAINED_CONFIG_ARCHIVE_MAP, VanConfig
from .models.videomae import VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP, VideoMAEConfig
from .models.vilt import (
@@ -4456,6 +4465,7 @@
)
from .models.convnext import (
CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ ConvNextBackbone,
ConvNextForImageClassification,
ConvNextModel,
ConvNextPreTrainedModel,
@@ -5292,6 +5302,7 @@
UniSpeechSatModel,
UniSpeechSatPreTrainedModel,
)
+ from .models.upernet import UperNetForSemanticSegmentation, UperNetPreTrainedModel
from .models.van import (
VAN_PRETRAINED_MODEL_ARCHIVE_LIST,
VanForImageClassification,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 43ed17f30dee..a788bd53087a 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -166,6 +166,7 @@
trocr,
unispeech,
unispeech_sat,
+ upernet,
van,
videomae,
vilt,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 6a49d2f4e2c0..8a60ea42fbe3 100755
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -162,6 +162,7 @@
("trocr", "TrOCRConfig"),
("unispeech", "UniSpeechConfig"),
("unispeech-sat", "UniSpeechSatConfig"),
+ ("upernet", "UperNetConfig"),
("van", "VanConfig"),
("videomae", "VideoMAEConfig"),
("vilt", "ViltConfig"),
@@ -311,6 +312,7 @@
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("upernet", "UPERNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("videomae", "VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -489,6 +491,7 @@
("ul2", "UL2"),
("unispeech", "UniSpeech"),
("unispeech-sat", "UniSpeechSat"),
+ ("upernet", "UPerNet"),
("van", "VAN"),
("videomae", "VideoMAE"),
("vilt", "ViLT"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index e23458955c68..42b658f83a94 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -79,6 +79,7 @@
("swinv2", "ViTImageProcessor"),
("table-transformer", "DetrImageProcessor"),
("timesformer", "VideoMAEImageProcessor"),
+ ("upernet", "SegformerImageProcessor"),
("van", "ConvNextImageProcessor"),
("videomae", "VideoMAEImageProcessor"),
("vilt", "ViltImageProcessor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index a6c43a2f1e78..21882e6f1b5d 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -438,6 +438,7 @@
("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
("mobilevit", "MobileViTForSemanticSegmentation"),
("segformer", "SegformerForSemanticSegmentation"),
+ ("upernet", "UperNetForSemanticSegmentation"),
]
)
@@ -891,6 +892,7 @@
[
# Backbone mapping
("bit", "BitBackbone"),
+ ("convnext", "ConvNextBackbone"),
("dinat", "DinatBackbone"),
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
diff --git a/src/transformers/models/convnext/__init__.py b/src/transformers/models/convnext/__init__.py
index 109f79daea8f..6d2d5c7bbe07 100644
--- a/src/transformers/models/convnext/__init__.py
+++ b/src/transformers/models/convnext/__init__.py
@@ -51,6 +51,7 @@
"ConvNextForImageClassification",
"ConvNextModel",
"ConvNextPreTrainedModel",
+ "ConvNextBackbone",
]
try:
@@ -85,6 +86,7 @@
else:
from .modeling_convnext import (
CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ ConvNextBackbone,
ConvNextForImageClassification,
ConvNextModel,
ConvNextPreTrainedModel,
diff --git a/src/transformers/models/convnext/configuration_convnext.py b/src/transformers/models/convnext/configuration_convnext.py
index 4027973e08de..41562cbcd44e 100644
--- a/src/transformers/models/convnext/configuration_convnext.py
+++ b/src/transformers/models/convnext/configuration_convnext.py
@@ -64,6 +64,9 @@ class ConvNextConfig(PretrainedConfig):
The initial value for the layer scale.
drop_path_rate (`float`, *optional*, defaults to 0.0):
The drop rate for stochastic depth.
+ out_features (`List[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). Will default to the last stage if unset.
Example:
```python
@@ -93,6 +96,7 @@ def __init__(
layer_scale_init_value=1e-6,
drop_path_rate=0.0,
image_size=224,
+ out_features=None,
**kwargs
):
super().__init__(**kwargs)
@@ -108,6 +112,16 @@ def __init__(
self.layer_scale_init_value = layer_scale_init_value
self.drop_path_rate = drop_path_rate
self.image_size = image_size
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
+ if out_features is not None:
+ if not isinstance(out_features, list):
+ raise ValueError("out_features should be a list")
+ for feature in out_features:
+ if feature not in self.stage_names:
+ raise ValueError(
+ f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
+ )
+ self.out_features = out_features
class ConvNextOnnxConfig(OnnxConfig):
diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py
index 605c01dbd72c..dc0da67280d0 100755
--- a/src/transformers/models/convnext/modeling_convnext.py
+++ b/src/transformers/models/convnext/modeling_convnext.py
@@ -24,12 +24,19 @@
from ...activations import ACT2FN
from ...modeling_outputs import (
+ BackboneOutput,
BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
from .configuration_convnext import ConvNextConfig
@@ -290,7 +297,7 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, ConvNextModel):
+ if isinstance(module, ConvNextEncoder):
module.gradient_checkpointing = value
@@ -465,3 +472,102 @@ def forward(
logits=logits,
hidden_states=outputs.hidden_states,
)
+
+
+@add_start_docstrings(
+ """
+ ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
+ """,
+ CONVNEXT_START_DOCSTRING,
+)
+class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.stage_names = config.stage_names
+ self.embeddings = ConvNextEmbeddings(config)
+ self.encoder = ConvNextEncoder(config)
+
+ self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+
+ out_feature_channels = {}
+ out_feature_channels["stem"] = config.hidden_sizes[0]
+ for idx, stage in enumerate(self.stage_names[1:]):
+ out_feature_channels[stage] = config.hidden_sizes[idx]
+
+ self.out_feature_channels = out_feature_channels
+
+ # Add layer norms to hidden states of out_features
+ hidden_states_norms = dict()
+ for stage, num_channels in zip(self.out_features, self.channels):
+ hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @property
+ def channels(self):
+ return [self.out_feature_channels[name] for name in self.out_features]
+
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BackboneOutput:
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = processor(image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ embedding_output = self.embeddings(pixel_values)
+
+ outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+
+ hidden_states = outputs.hidden_states
+
+ feature_maps = ()
+ # we skip the stem
+    for stage, hidden_state in zip(self.stage_names[1:], hidden_states[1:]):
+ if stage in self.out_features:
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
+ feature_maps += (hidden_state,)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ output += (outputs.hidden_states,)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=None,
+ )
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index 46b0c54c4cbf..701ce16e6a1f 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -577,8 +577,12 @@ def forward(
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
- self.set_shift_and_window_size(input_dimensions)
+    # when always_partition is True, keep the configured window size and shift instead of adapting them to the input resolution
+    if not always_partition:
+        self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
@@ -668,13 +672,16 @@ def forward(
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
- layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+ )
hidden_states = layer_outputs[0]
@@ -725,6 +732,7 @@ def forward(
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, DonutSwinEncoderOutput]:
all_hidden_states = () if output_hidden_states else None
@@ -754,7 +762,9 @@ def custom_forward(*inputs):
create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
)
else:
- layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+ )
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
diff --git a/src/transformers/models/segformer/image_processing_segformer.py b/src/transformers/models/segformer/image_processing_segformer.py
index acc6026451f0..b96c38b4fb1b 100644
--- a/src/transformers/models/segformer/image_processing_segformer.py
+++ b/src/transformers/models/segformer/image_processing_segformer.py
@@ -23,7 +23,7 @@
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
-from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
+from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
@@ -159,30 +159,6 @@ def resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
- def center_crop(
- self,
- image: np.ndarray,
- size: Dict[str, int],
- data_format: Optional[Union[str, ChannelDimension]] = None,
- **kwargs
- ) -> np.ndarray:
- """
- Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
- any edge, the image is padded with 0's and then center cropped.
-
- Args:
- image (`np.ndarray`):
- Image to center crop.
- size (`Dict[str, int]`):
- Size of the output image.
- data_format (`str` or `ChannelDimension`, *optional*):
- The channel dimension format of the image. If not provided, it will be the same as the input image.
- """
- size = get_size_dict(size)
- if "height" not in size or "width" not in size:
- raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
- return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
-
def rescale(
self,
image: np.ndarray,
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index fe46e7f532c3..ef9a4ef95f06 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -644,8 +644,12 @@ def forward(
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
- self.set_shift_and_window_size(input_dimensions)
+    # when always_partition is True, keep the configured window size and shift instead of adapting them to the input resolution
+    if not always_partition:
+        self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
@@ -734,13 +738,16 @@ def forward(
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
- layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+ )
hidden_states = layer_outputs[0]
@@ -790,6 +797,7 @@ def forward(
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, SwinEncoderOutput]:
all_hidden_states = () if output_hidden_states else None
@@ -819,7 +827,9 @@ def custom_forward(*inputs):
create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
)
else:
- layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+ )
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
@@ -1315,6 +1325,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=True,
output_hidden_states_before_downsampling=True,
+ always_partition=True,
return_dict=True,
)
diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py
index b40ce8868cdb..0683f42c6722 100644
--- a/src/transformers/models/swin2sr/modeling_swin2sr.py
+++ b/src/transformers/models/swin2sr/modeling_swin2sr.py
@@ -576,8 +576,12 @@ def forward(
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
- self.set_shift_and_window_size(input_dimensions)
+    # when always_partition is True, keep the configured window size and shift instead of adapting them to the input resolution
+    if not always_partition:
+        self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index c73afc096607..7ae7f2fbfc9c 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -718,8 +718,12 @@ def forward(
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
- self.set_shift_and_window_size(input_dimensions)
+    # when always_partition is True, keep the configured window size and shift instead of adapting them to the input resolution
+    if not always_partition:
+        self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
@@ -808,13 +812,16 @@ def forward(
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
- layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+ )
hidden_states = layer_outputs[0]
@@ -868,6 +875,7 @@ def forward(
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, Swinv2EncoderOutput]:
all_hidden_states = () if output_hidden_states else None
@@ -897,7 +905,9 @@ def custom_forward(*inputs):
create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
)
else:
- layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+ )
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
diff --git a/src/transformers/models/upernet/__init__.py b/src/transformers/models/upernet/__init__.py
new file mode 100644
index 000000000000..fed32f4c6d88
--- /dev/null
+++ b/src/transformers/models/upernet/__init__.py
@@ -0,0 +1,55 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+# rely on isort to merge the imports
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_upernet": ["UperNetConfig"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_upernet"] = [
+ "UperNetForSemanticSegmentation",
+ "UperNetPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_upernet import UperNetConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_upernet import UperNetForSemanticSegmentation, UperNetPreTrainedModel
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/upernet/configuration_upernet.py b/src/transformers/models/upernet/configuration_upernet.py
new file mode 100644
index 000000000000..449e69def3e3
--- /dev/null
+++ b/src/transformers/models/upernet/configuration_upernet.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" UperNet model configuration"""
+
+import copy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class UperNetConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`UperNetForSemanticSegmentation`]. It is used to
+ instantiate an UperNet model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the UperNet
+ [openmmlab/upernet-convnext-tiny](https://huggingface.co/openmmlab/upernet-convnext-tiny) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
+ The configuration of the backbone model.
+ hidden_size (`int`, *optional*, defaults to 512):
+ The number of hidden units in the convolutional layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
+ Pooling scales used in Pooling Pyramid Module applied on the last feature map.
+ use_auxiliary_head (`bool`, *optional*, defaults to `True`):
+ Whether to use an auxiliary head during training.
+        auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
+            Weight of the cross-entropy loss of the auxiliary head.
+        auxiliary_in_channels (`int`, *optional*, defaults to 384):
+            Number of input channels for the auxiliary head.
+ auxiliary_channels (`int`, *optional*, defaults to 256):
+ Number of channels to use in the auxiliary head.
+ auxiliary_num_convs (`int`, *optional*, defaults to 1):
+ Number of convolutional layers to use in the auxiliary head.
+ auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
+ Whether to concatenate the output of the auxiliary head with the input before the classification layer.
+ loss_ignore_index (`int`, *optional*, defaults to 255):
+ The index that is ignored by the loss function.
+
+ Examples:
+
+ ```python
+ >>> from transformers import UperNetConfig, UperNetForSemanticSegmentation
+
+ >>> # Initializing a configuration
+ >>> configuration = UperNetConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = UperNetForSemanticSegmentation(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "upernet"
+
+ def __init__(
+ self,
+ backbone_config=None,
+ hidden_size=512,
+ initializer_range=0.02,
+ pool_scales=[1, 2, 3, 6],
+ use_auxiliary_head=True,
+ auxiliary_loss_weight=0.4,
+ auxiliary_in_channels=384,
+ auxiliary_channels=256,
+ auxiliary_num_convs=1,
+ auxiliary_concat_input=False,
+ loss_ignore_index=255,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ if backbone_config is None:
+ logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
+ backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage1", "stage2", "stage3", "stage4"])
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.get("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ self.backbone_config = backbone_config
+ self.hidden_size = hidden_size
+ self.initializer_range = initializer_range
+ self.pool_scales = pool_scales
+ self.use_auxiliary_head = use_auxiliary_head
+ self.auxiliary_loss_weight = auxiliary_loss_weight
+ self.auxiliary_in_channels = auxiliary_in_channels
+ self.auxiliary_channels = auxiliary_channels
+ self.auxiliary_num_convs = auxiliary_num_convs
+ self.auxiliary_concat_input = auxiliary_concat_input
+ self.loss_ignore_index = loss_ignore_index
+
+ def to_dict(self):
+ """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["backbone_config"] = self.backbone_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+ return output
diff --git a/src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py b/src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py
new file mode 100644
index 000000000000..0b7b6e11b11d
--- /dev/null
+++ b/src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py
@@ -0,0 +1,214 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert ConvNext + UperNet checkpoints from mmsegmentation."""
+
+import argparse
+import json
+
+import torch
+from PIL import Image
+
+import requests
+from huggingface_hub import hf_hub_download
+from transformers import ConvNextConfig, SegformerImageProcessor, UperNetConfig, UperNetForSemanticSegmentation
+
+
+def get_upernet_config(model_name):
+ auxiliary_in_channels = 384
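+    # note: the size checks below are not mutually exclusive; for "xlarge", the "large" branch matches first and is then overridden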
+ if "tiny" in model_name:
+ depths = [3, 3, 9, 3]
+ hidden_sizes = [96, 192, 384, 768]
+ if "small" in model_name:
+ depths = [3, 3, 27, 3]
+ hidden_sizes = [96, 192, 384, 768]
+ if "base" in model_name:
+ depths = [3, 3, 27, 3]
+ hidden_sizes = [128, 256, 512, 1024]
+ auxiliary_in_channels = 512
+ if "large" in model_name:
+ depths = [3, 3, 27, 3]
+ hidden_sizes = [192, 384, 768, 1536]
+ auxiliary_in_channels = 768
+ if "xlarge" in model_name:
+ depths = [3, 3, 27, 3]
+ hidden_sizes = [256, 512, 1024, 2048]
+ auxiliary_in_channels = 1024
+
+ # set label information
+ num_labels = 150
+ repo_id = "huggingface/label-files"
+ filename = "ade20k-id2label.json"
+ id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+ label2id = {v: k for k, v in id2label.items()}
+
+ backbone_config = ConvNextConfig(
+ depths=depths, hidden_sizes=hidden_sizes, out_features=["stage1", "stage2", "stage3", "stage4"]
+ )
+ config = UperNetConfig(
+ backbone_config=backbone_config,
+ auxiliary_in_channels=auxiliary_in_channels,
+ num_labels=num_labels,
+ id2label=id2label,
+ label2id=label2id,
+ )
+
+ return config
+
+
+# here we list all keys to be renamed (original name on the left, our name on the right)
+def create_rename_keys(config):
+ rename_keys = []
+
+ # fmt: off
+ # stem
+ rename_keys.append(("backbone.downsample_layers.0.0.weight", "backbone.embeddings.patch_embeddings.weight"))
+ rename_keys.append(("backbone.downsample_layers.0.0.bias", "backbone.embeddings.patch_embeddings.bias"))
+ rename_keys.append(("backbone.downsample_layers.0.1.weight", "backbone.embeddings.layernorm.weight"))
+ rename_keys.append(("backbone.downsample_layers.0.1.bias", "backbone.embeddings.layernorm.bias"))
+ # stages
+ for i in range(len(config.backbone_config.depths)):
+ for j in range(config.backbone_config.depths[i]):
+ rename_keys.append((f"backbone.stages.{i}.{j}.gamma", f"backbone.encoder.stages.{i}.layers.{j}.layer_scale_parameter"))
+ rename_keys.append((f"backbone.stages.{i}.{j}.depthwise_conv.weight", f"backbone.encoder.stages.{i}.layers.{j}.dwconv.weight"))
+ rename_keys.append((f"backbone.stages.{i}.{j}.depthwise_conv.bias", f"backbone.encoder.stages.{i}.layers.{j}.dwconv.bias"))
+ rename_keys.append((f"backbone.stages.{i}.{j}.norm.weight", f"backbone.encoder.stages.{i}.layers.{j}.layernorm.weight"))
+ rename_keys.append((f"backbone.stages.{i}.{j}.norm.bias", f"backbone.encoder.stages.{i}.layers.{j}.layernorm.bias"))
+ rename_keys.append((f"backbone.stages.{i}.{j}.pointwise_conv1.weight", f"backbone.encoder.stages.{i}.layers.{j}.pwconv1.weight"))
+ rename_keys.append((f"backbone.stages.{i}.{j}.pointwise_conv1.bias", f"backbone.encoder.stages.{i}.layers.{j}.pwconv1.bias"))
+ rename_keys.append((f"backbone.stages.{i}.{j}.pointwise_conv2.weight", f"backbone.encoder.stages.{i}.layers.{j}.pwconv2.weight"))
+ rename_keys.append((f"backbone.stages.{i}.{j}.pointwise_conv2.bias", f"backbone.encoder.stages.{i}.layers.{j}.pwconv2.bias"))
+ if i > 0:
+ rename_keys.append((f"backbone.downsample_layers.{i}.0.weight", f"backbone.encoder.stages.{i}.downsampling_layer.0.weight"))
+ rename_keys.append((f"backbone.downsample_layers.{i}.0.bias", f"backbone.encoder.stages.{i}.downsampling_layer.0.bias"))
+ rename_keys.append((f"backbone.downsample_layers.{i}.1.weight", f"backbone.encoder.stages.{i}.downsampling_layer.1.weight"))
+ rename_keys.append((f"backbone.downsample_layers.{i}.1.bias", f"backbone.encoder.stages.{i}.downsampling_layer.1.bias"))
+
+ rename_keys.append((f"backbone.norm{i}.weight", f"backbone.hidden_states_norms.stage{i+1}.weight"))
+ rename_keys.append((f"backbone.norm{i}.bias", f"backbone.hidden_states_norms.stage{i+1}.bias"))
+
+ # decode head
+ rename_keys.extend(
+ [
+ ("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
+ ("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
+ ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
+ ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
+ ]
+ )
+ # fmt: on
+
+ return rename_keys
+
+
+def rename_key(dct, old, new):
+ val = dct.pop(old)
+ dct[new] = val
+
+
+def convert_upernet_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
+ model_name_to_url = {
+ "upernet-convnext-tiny": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553-cad485de.pth",
+ "upernet-convnext-small": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208-1b1e394f.pth",
+ "upernet-convnext-base": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth",
+ "upernet-convnext-large": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532-e57aa54d.pth",
+ "upernet-convnext-xlarge": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344-95fc38c2.pth",
+ }
+ checkpoint_url = model_name_to_url[model_name]
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["state_dict"]
+
+ config = get_upernet_config(model_name)
+ model = UperNetForSemanticSegmentation(config)
+ model.eval()
+
+ # replace "bn" => "batch_norm"
+ for key in state_dict.copy().keys():
+ val = state_dict.pop(key)
+ if "bn" in key:
+ key = key.replace("bn", "batch_norm")
+ state_dict[key] = val
+
+ # rename keys
+ rename_keys = create_rename_keys(config)
+ for src, dest in rename_keys:
+ rename_key(state_dict, src, dest)
+
+ model.load_state_dict(state_dict)
+
+ # verify on image
+ url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+ processor = SegformerImageProcessor()
+ pixel_values = processor(image, return_tensors="pt").pixel_values
+
+ with torch.no_grad():
+ outputs = model(pixel_values)
+
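+    # verify the first 3x3 values of the logits against the reference slice recorded for each checkpoint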
+ if model_name == "upernet-convnext-tiny":
+ expected_slice = torch.tensor(
+ [[-8.8110, -8.8110, -8.6521], [-8.8110, -8.8110, -8.6521], [-8.7746, -8.7746, -8.6130]]
+ )
+ elif model_name == "upernet-convnext-small":
+ expected_slice = torch.tensor(
+ [[-8.8236, -8.8236, -8.6771], [-8.8236, -8.8236, -8.6771], [-8.7638, -8.7638, -8.6240]]
+ )
+ elif model_name == "upernet-convnext-base":
+ expected_slice = torch.tensor(
+ [[-8.8558, -8.8558, -8.6905], [-8.8558, -8.8558, -8.6905], [-8.7669, -8.7669, -8.6021]]
+ )
+ elif model_name == "upernet-convnext-large":
+ expected_slice = torch.tensor(
+ [[-8.6660, -8.6660, -8.6210], [-8.6660, -8.6660, -8.6210], [-8.6310, -8.6310, -8.5964]]
+ )
+ elif model_name == "upernet-convnext-xlarge":
+ expected_slice = torch.tensor(
+ [[-8.4980, -8.4980, -8.3977], [-8.4980, -8.4980, -8.3977], [-8.4379, -8.4379, -8.3412]]
+ )
+ print("Logits:", outputs.logits[0, 0, :3, :3])
+ assert torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4)
+ print("Looks ok!")
+
+ if pytorch_dump_folder_path is not None:
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+ print(f"Saving processor to {pytorch_dump_folder_path}")
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_hub:
+ print(f"Pushing model and processor for {model_name} to hub")
+ model.push_to_hub(f"openmmlab/{model_name}")
+ processor.push_to_hub(f"openmmlab/{model_name}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default="upernet-convnext-tiny",
+ type=str,
+ choices=[f"upernet-convnext-{size}" for size in ["tiny", "small", "base", "large", "xlarge"]],
+ help="Name of the ConvNext UperNet model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ )
+ parser.add_argument(
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
+ )
+
+ args = parser.parse_args()
+ convert_upernet_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py b/src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py
new file mode 100644
index 000000000000..a44549a4703c
--- /dev/null
+++ b/src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py
@@ -0,0 +1,297 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Swin Transformer + UperNet checkpoints from mmsegmentation.
+
+URL: https://github.com/open-mmlab/mmsegmentation/tree/master/configs/swin
+"""
+
+import argparse
+import json
+
+import torch
+from PIL import Image
+
+import requests
+from huggingface_hub import hf_hub_download
+from transformers import SegformerImageProcessor, SwinConfig, UperNetConfig, UperNetForSemanticSegmentation
+
+
+def get_upernet_config(model_name):
+ auxiliary_in_channels = 384
+ window_size = 7
+ if "tiny" in model_name:
+ embed_dim = 96
+ depths = (2, 2, 6, 2)
+ num_heads = (3, 6, 12, 24)
+ elif "small" in model_name:
+ embed_dim = 96
+ depths = (2, 2, 18, 2)
+ num_heads = (3, 6, 12, 24)
+ elif "base" in model_name:
+ embed_dim = 128
+ depths = (2, 2, 18, 2)
+ num_heads = (4, 8, 16, 32)
+ window_size = 12
+ auxiliary_in_channels = 512
+ elif "large" in model_name:
+ embed_dim = 192
+ depths = (2, 2, 18, 2)
+ num_heads = (6, 12, 24, 48)
+ window_size = 12
+ auxiliary_in_channels = 768
+
+ # set label information
+ num_labels = 150
+ repo_id = "huggingface/label-files"
+ filename = "ade20k-id2label.json"
+ id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+ label2id = {v: k for k, v in id2label.items()}
+
+ backbone_config = SwinConfig(
+ embed_dim=embed_dim,
+ depths=depths,
+ num_heads=num_heads,
+ window_size=window_size,
+ out_features=["stage1", "stage2", "stage3", "stage4"],
+ )
+ config = UperNetConfig(
+ backbone_config=backbone_config,
+ auxiliary_in_channels=auxiliary_in_channels,
+ num_labels=num_labels,
+ id2label=id2label,
+ label2id=label2id,
+ )
+
+ return config
+
+
+# here we list all keys to be renamed (original name on the left, our name on the right)
+def create_rename_keys(config):
+ rename_keys = []
+
+ # fmt: off
+ # stem
+ rename_keys.append(("backbone.patch_embed.projection.weight", "backbone.embeddings.patch_embeddings.projection.weight"))
+ rename_keys.append(("backbone.patch_embed.projection.bias", "backbone.embeddings.patch_embeddings.projection.bias"))
+ rename_keys.append(("backbone.patch_embed.norm.weight", "backbone.embeddings.norm.weight"))
+ rename_keys.append(("backbone.patch_embed.norm.bias", "backbone.embeddings.norm.bias"))
+ # stages
+ for i in range(len(config.backbone_config.depths)):
+ for j in range(config.backbone_config.depths[i]):
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.norm1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.weight"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.norm1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.bias"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.attn.w_msa.relative_position_bias_table", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_bias_table"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.attn.w_msa.relative_position_index", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_index"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.attn.w_msa.proj.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.attn.w_msa.proj.bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.norm2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.weight"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.norm2.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.bias"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.ffn.layers.0.0.weight", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.ffn.layers.0.0.bias", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.ffn.layers.1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.weight"))
+ rename_keys.append((f"backbone.stages.{i}.blocks.{j}.ffn.layers.1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.bias"))
+
+ if i < 3:
+ rename_keys.append((f"backbone.stages.{i}.downsample.reduction.weight", f"backbone.encoder.layers.{i}.downsample.reduction.weight"))
+ rename_keys.append((f"backbone.stages.{i}.downsample.norm.weight", f"backbone.encoder.layers.{i}.downsample.norm.weight"))
+ rename_keys.append((f"backbone.stages.{i}.downsample.norm.bias", f"backbone.encoder.layers.{i}.downsample.norm.bias"))
+ rename_keys.append((f"backbone.norm{i}.weight", f"backbone.hidden_states_norms.stage{i+1}.weight"))
+ rename_keys.append((f"backbone.norm{i}.bias", f"backbone.hidden_states_norms.stage{i+1}.bias"))
+
+ # decode head
+ rename_keys.extend(
+ [
+ ("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
+ ("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
+ ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
+ ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
+ ]
+ )
+ # fmt: on
+
+ return rename_keys
+
+
+def rename_key(dct, old, new):
+ val = dct.pop(old)
+ dct[new] = val
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
+def read_in_q_k_v(state_dict, backbone_config):
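+    # the hidden size of the Swin backbone doubles at every stage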
+ num_features = [int(backbone_config.embed_dim * 2**i) for i in range(len(backbone_config.depths))]
+ for i in range(len(backbone_config.depths)):
+ dim = num_features[i]
+ for j in range(backbone_config.depths[i]):
+ # fmt: off
+ # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
+ in_proj_weight = state_dict.pop(f"backbone.stages.{i}.blocks.{j}.attn.w_msa.qkv.weight")
+ in_proj_bias = state_dict.pop(f"backbone.stages.{i}.blocks.{j}.attn.w_msa.qkv.bias")
+ # next, add query, keys and values (in that order) to the state dict
+ state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :]
+ state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.bias"] = in_proj_bias[: dim]
+ state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[
+ dim : dim * 2, :
+ ]
+ state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.key.bias"] = in_proj_bias[
+ dim : dim * 2
+ ]
+ state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[
+ -dim :, :
+ ]
+ state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.bias"] = in_proj_bias[-dim :]
+ # fmt: on
+
+
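+# The four helpers below reorder the patch merging (downsampling) reduction and norm parameters,
+# since the original implementation concatenates the four neighboring patches in a different order.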
+def correct_unfold_reduction_order(x):
+ out_channel, in_channel = x.shape
+ x = x.reshape(out_channel, 4, in_channel // 4)
+ x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
+ return x
+
+
+def reverse_correct_unfold_reduction_order(x):
+ out_channel, in_channel = x.shape
+ x = x.reshape(out_channel, in_channel // 4, 4)
+ x = x[:, :, [0, 2, 1, 3]].transpose(1, 2).reshape(out_channel, in_channel)
+
+ return x
+
+
+def correct_unfold_norm_order(x):
+ in_channel = x.shape[0]
+ x = x.reshape(4, in_channel // 4)
+ x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
+ return x
+
+
+# There was an incompatibility with this version, due to a new implementation of their downsampling operation
+# using nn.Unfold. It was resolved as seen here:
+# https://github.com/open-mmlab/mmdetection/blob/31c84958f54287a8be2b99cbf87a6dcf12e57753/mmdet/models/utils/ckpt_convert.py#L96.
+def reverse_correct_unfold_norm_order(x):
+ in_channel = x.shape[0]
+ x = x.reshape(in_channel // 4, 4)
+ x = x[:, [0, 2, 1, 3]].transpose(0, 1).reshape(in_channel)
+ return x
+
+
+def convert_upernet_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
+ model_name_to_url = {
+ "upernet-swin-tiny": "https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth",
+ "upernet-swin-small": "https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth",
+ "upernet-swin-base": "https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth",
+ "upernet-swin-large": "https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k_20220318_091743-9ba68901.pth",
+ }
+ checkpoint_url = model_name_to_url[model_name]
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", file_name=model_name)[
+ "state_dict"
+ ]
+
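+    # print all parameter names and shapes of the original checkpoint, for debugging purposes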
+ for name, param in state_dict.items():
+ print(name, param.shape)
+
+ config = get_upernet_config(model_name)
+ model = UperNetForSemanticSegmentation(config)
+ model.eval()
+
+ # replace "bn" => "batch_norm"
+ for key in state_dict.copy().keys():
+ val = state_dict.pop(key)
+ if "bn" in key:
+ key = key.replace("bn", "batch_norm")
+ state_dict[key] = val
+
+ # rename keys
+ rename_keys = create_rename_keys(config)
+ for src, dest in rename_keys:
+ rename_key(state_dict, src, dest)
+ read_in_q_k_v(state_dict, config.backbone_config)
+
+ # fix downsample parameters
+ for key, value in state_dict.items():
+ if "downsample" in key:
+ if "reduction" in key:
+ state_dict[key] = reverse_correct_unfold_reduction_order(value)
+ if "norm" in key:
+ state_dict[key] = reverse_correct_unfold_norm_order(value)
+
+ model.load_state_dict(state_dict)
+
+ # verify on image
+ url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+ processor = SegformerImageProcessor()
+ pixel_values = processor(image, return_tensors="pt").pixel_values
+
+ with torch.no_grad():
+ outputs = model(pixel_values)
+ logits = outputs.logits
+
+ print(logits.shape)
+ print("First values of logits:", logits[0, 0, :3, :3])
+ # assert values
+ if model_name == "upernet-swin-tiny":
+ expected_slice = torch.tensor(
+ [[-7.5958, -7.5958, -7.4302], [-7.5958, -7.5958, -7.4302], [-7.4797, -7.4797, -7.3068]]
+ )
+ elif model_name == "upernet-swin-small":
+ expected_slice = torch.tensor(
+ [[-7.1921, -7.1921, -6.9532], [-7.1921, -7.1921, -6.9532], [-7.0908, -7.0908, -6.8534]]
+ )
+ elif model_name == "upernet-swin-base":
+ expected_slice = torch.tensor(
+ [[-6.5851, -6.5851, -6.4330], [-6.5851, -6.5851, -6.4330], [-6.4763, -6.4763, -6.3254]]
+ )
+ elif model_name == "upernet-swin-large":
+ expected_slice = torch.tensor(
+ [[-7.5297, -7.5297, -7.3802], [-7.5297, -7.5297, -7.3802], [-7.4044, -7.4044, -7.2586]]
+ )
+ print("Logits:", outputs.logits[0, 0, :3, :3])
+ assert torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4)
+ print("Looks ok!")
+
+ if pytorch_dump_folder_path is not None:
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+ print(f"Saving processor to {pytorch_dump_folder_path}")
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_hub:
+ print(f"Pushing model and processor for {model_name} to hub")
+ model.push_to_hub(f"openmmlab/{model_name}")
+ processor.push_to_hub(f"openmmlab/{model_name}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default="upernet-swin-tiny",
+ type=str,
+ choices=[f"upernet-swin-{size}" for size in ["tiny", "small", "base", "large"]],
+ help="Name of the Swin + UperNet model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ )
+ parser.add_argument(
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
+ )
+
+ args = parser.parse_args()
+ convert_upernet_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py
new file mode 100644
index 000000000000..56190d050389
--- /dev/null
+++ b/src/transformers/models/upernet/modeling_upernet.py
@@ -0,0 +1,442 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch UperNet model. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation."""
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoBackbone
+
+from ...modeling_outputs import SemanticSegmenterOutput
+from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from .configuration_upernet import UperNetConfig
+
+
+UPERNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "openmmlab/upernet-convnext-tiny",
+ # See all UperNet models at https://huggingface.co/models?filter=upernet
+]
+
+# General docstring
+_CONFIG_FOR_DOC = "UperNetConfig"
+
+
+class UperNetConvModule(nn.Module):
+ """
+ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
+ layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ padding: Union[int, Tuple[int, int], str] = 0,
+ bias: bool = False,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ ) -> None:
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ bias=bias,
+ dilation=dilation,
+ )
+ self.batch_norm = nn.BatchNorm2d(out_channels)
+ self.activation = nn.ReLU()
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ output = self.conv(input)
+ output = self.batch_norm(output)
+ output = self.activation(output)
+
+ return output
+
+
+class UperNetPyramidPoolingBlock(nn.Module):
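+    """One block of the Pyramid Pooling Module (PPM): adaptive average pooling followed by a 1x1 convolution block."""
+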
+ def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
+ super().__init__()
+ self.layers = [
+ nn.AdaptiveAvgPool2d(pool_scale),
+ UperNetConvModule(in_channels, channels, kernel_size=1),
+ ]
+ for i, layer in enumerate(self.layers):
+ self.add_module(str(i), layer)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+class UperNetPyramidPoolingModule(nn.Module):
+ """
+ Pyramid Pooling Module (PPM) used in PSPNet.
+
+ Args:
+ pool_scales (`Tuple[int]`):
+ Pooling scales used in Pooling Pyramid Module.
+ in_channels (`int`):
+ Input channels.
+ channels (`int`):
+ Channels after modules, before conv_seg.
+ align_corners (`bool`):
+ align_corners argument of F.interpolate.
+ """
+
+ def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
+ super().__init__()
+ self.pool_scales = pool_scales
+ self.align_corners = align_corners
+ self.in_channels = in_channels
+ self.channels = channels
+ self.blocks = []
+ for i, pool_scale in enumerate(pool_scales):
+ block = UperNetPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels)
+ self.blocks.append(block)
+ self.add_module(str(i), block)
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ ppm_outs = []
+ for ppm in self.blocks:
+ ppm_out = ppm(x)
+ upsampled_ppm_out = nn.functional.interpolate(
+ ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
+ )
+ ppm_outs.append(upsampled_ppm_out)
+ return ppm_outs
+
+
+class UperNetHead(nn.Module):
+ """
+ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
+ [UPerNet](https://arxiv.org/abs/1807.10221).
+ """
+
+ def __init__(self, config, in_channels):
+ super().__init__()
+
+ self.config = config
+ self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
+ self.in_channels = in_channels
+ self.channels = config.hidden_size
+ self.align_corners = False
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+ # PSP Module
+ self.psp_modules = UperNetPyramidPoolingModule(
+ self.pool_scales,
+ self.in_channels[-1],
+ self.channels,
+ align_corners=self.align_corners,
+ )
+ self.bottleneck = UperNetConvModule(
+ self.in_channels[-1] + len(self.pool_scales) * self.channels,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ )
+ # FPN Module
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+ for in_channels in self.in_channels[:-1]: # skip the top layer
+ l_conv = UperNetConvModule(in_channels, self.channels, kernel_size=1)
+ fpn_conv = UperNetConvModule(self.channels, self.channels, kernel_size=3, padding=1)
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ self.fpn_bottleneck = UperNetConvModule(
+ len(self.in_channels) * self.channels,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ )
+
+ def init_weights(self):
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Conv2d):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def psp_forward(self, inputs):
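+        # apply the Pyramid Pooling Module on the last feature map and fuse its outputs with it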
+ x = inputs[-1]
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = torch.cat(psp_outs, dim=1)
+ output = self.bottleneck(psp_outs)
+
+ return output
+
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ # build laterals
+ laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+ laterals.append(self.psp_forward(encoder_hidden_states))
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
+ laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
+ )
+
+ # build outputs
+ fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+ # append psp feature
+ fpn_outs.append(laterals[-1])
+
+ for i in range(used_backbone_levels - 1, 0, -1):
+ fpn_outs[i] = nn.functional.interpolate(
+ fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
+ )
+ fpn_outs = torch.cat(fpn_outs, dim=1)
+ output = self.fpn_bottleneck(fpn_outs)
+ output = self.classifier(output)
+
+ return output
+
+
+class UperNetFCNHead(nn.Module):
+ """
+    Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of
+    [FCNNet](https://arxiv.org/abs/1411.4038).
+
+ Args:
+ config:
+ Configuration.
+        in_index (int):
+            Index of the backbone feature map to use as input. Default: 2.
+ kernel_size (int):
+ The kernel size for convs in the head. Default: 3.
+ dilation (int):
+ The dilation rate for convs in the head. Default: 1.
+ """
+
+ def __init__(
+ self, config, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, Tuple[int, int]] = 1
+ ) -> None:
+ super().__init__()
+
+ self.config = config
+ self.in_channels = config.auxiliary_in_channels
+ self.channels = config.auxiliary_channels
+ self.num_convs = config.auxiliary_num_convs
+ self.concat_input = config.auxiliary_concat_input
+ self.in_index = in_index
+
+ conv_padding = (kernel_size // 2) * dilation
+ convs = []
+ convs.append(
+ UperNetConvModule(
+ self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
+ )
+ )
+ for i in range(self.num_convs - 1):
+ convs.append(
+ UperNetConvModule(
+ self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
+ )
+ )
+ if self.num_convs == 0:
+ self.convs = nn.Identity()
+ else:
+ self.convs = nn.Sequential(*convs)
+ if self.concat_input:
+ self.conv_cat = UperNetConvModule(
+ self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
+ )
+
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+ def init_weights(self):
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Conv2d):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ # just take the relevant feature maps
+ hidden_states = encoder_hidden_states[self.in_index]
+ output = self.convs(hidden_states)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
+ output = self.classifier(output)
+ return output
+
+
+class UperNetPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = UperNetConfig
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def init_weights(self):
+ """Initialize the weights"""
+ self.backbone.init_weights()
+ self.decode_head.init_weights()
+ self.auxiliary_head.init_weights()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, BackboneMixin):
+ module.gradient_checkpointing = value
+
+
+UPERNET_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`UperNetConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+UPERNET_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`] for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers in case the backbone has them. See
+ `attentions` under returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers of the backbone. See `hidden_states` under
+ returned tensors for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ """UperNet framework leveraging any vision backbone e.g. for ADE20k, CityScapes.""",
+ UPERNET_START_DOCSTRING,
+)
+class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
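+        # instantiate the vision backbone (e.g. ConvNeXt or Swin) from its configuration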
+ self.backbone = AutoBackbone.from_config(config.backbone_config)
+
+ # Semantic segmentation head(s)
+ self.decode_head = UperNetHead(config, in_channels=self.backbone.channels)
+ self.auxiliary_head = UperNetFCNHead(config) if config.use_auxiliary_head else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(UPERNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SemanticSegmenterOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+ >>> from PIL import Image
+ >>> from huggingface_hub import hf_hub_download
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-tiny")
+ >>> model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny")
+
+ >>> filepath = hf_hub_download(
+ ... repo_id="hf-internal-testing/fixtures_ade20k", filename="ADE_val_00000001.jpg", repo_type="dataset"
+ ... )
+ >>> image = Image.open(filepath).convert("RGB")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits # shape (batch_size, num_labels, height, width)
+ >>> list(logits.shape)
+ [1, 150, 512, 512]
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ outputs = self.backbone.forward_with_filtered_kwargs(
+ pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
+ )
+ features = outputs.feature_maps
+
+ logits = self.decode_head(features)
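+        # upsample the logits to the resolution of the original input image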
+ logits = nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)
+
+ auxiliary_logits = None
+ if self.auxiliary_head is not None:
+ auxiliary_logits = self.auxiliary_head(features)
+ auxiliary_logits = nn.functional.interpolate(
+ auxiliary_logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False
+ )
+
+        loss = None
+        if labels is not None:
+            if self.config.num_labels == 1:
+                raise ValueError("The number of labels should be greater than one")
+            else:
+                # compute weighted loss
+                loss_fct = CrossEntropyLoss(ignore_index=self.config.loss_ignore_index)
+                loss = loss_fct(logits, labels)
+                # only add the auxiliary loss when an auxiliary head is actually used
+                if auxiliary_logits is not None:
+                    auxiliary_loss = loss_fct(auxiliary_logits, labels)
+                    loss = loss + self.config.auxiliary_loss_weight * auxiliary_loss
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (logits,) + outputs[1:]
+ else:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SemanticSegmenterOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index b3dc278739c9..8ff9fbcde5f9 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -1576,6 +1576,13 @@ def load_tf_weights_in_convbert(*args, **kwargs):
CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+class ConvNextBackbone(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class ConvNextForImageClassification(metaclass=DummyObject):
_backends = ["torch"]
@@ -5866,6 +5873,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class UperNetForSemanticSegmentation(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class UperNetPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
VAN_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/tests/models/convnext/test_modeling_convnext.py b/tests/models/convnext/test_modeling_convnext.py
index 6cdaafabec35..b6abc332c424 100644
--- a/tests/models/convnext/test_modeling_convnext.py
+++ b/tests/models/convnext/test_modeling_convnext.py
@@ -29,7 +29,7 @@
if is_torch_available():
import torch
- from transformers import ConvNextForImageClassification, ConvNextModel
+ from transformers import ConvNextBackbone, ConvNextForImageClassification, ConvNextModel
from transformers.models.convnext.modeling_convnext import CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -53,9 +53,9 @@ def __init__(
use_labels=True,
intermediate_size=37,
hidden_act="gelu",
- type_sequence_label_size=10,
+ num_labels=10,
initializer_range=0.02,
- num_labels=3,
+ out_features=["stage2", "stage3", "stage4"],
scope=None,
):
self.parent = parent
@@ -69,8 +69,9 @@ def __init__(
self.use_labels = use_labels
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
- self.type_sequence_label_size = type_sequence_label_size
+ self.num_labels = num_labels
self.initializer_range = initializer_range
+ self.out_features = out_features
self.scope = scope
def prepare_config_and_inputs(self):
@@ -78,7 +79,7 @@ def prepare_config_and_inputs(self):
labels = None
if self.use_labels:
- labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
@@ -93,6 +94,8 @@ def get_config(self):
hidden_act=self.hidden_act,
is_decoder=False,
initializer_range=self.initializer_range,
+ out_features=self.out_features,
+ num_labels=self.num_labels,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -107,12 +110,40 @@ def create_and_check_model(self, config, pixel_values, labels):
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
- config.num_labels = self.type_sequence_label_size
model = ConvNextForImageClassification(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
- self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def create_and_check_backbone(self, config, pixel_values, labels):
+ model = ConvNextBackbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+        # verify feature maps
+ self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
+ self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
+
+ # verify channels
+ self.parent.assertEqual(len(model.channels), len(config.out_features))
+ self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
+
+ # verify backbone works with out_features=None
+ config.out_features = None
+ model = ConvNextBackbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ # verify feature maps
+ self.parent.assertEqual(len(result.feature_maps), 1)
+ self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
+
+ # verify channels
+ self.parent.assertEqual(len(model.channels), 1)
+ self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -132,6 +163,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
(
ConvNextModel,
ConvNextForImageClassification,
+ ConvNextBackbone,
)
if is_torch_available()
else ()
@@ -167,6 +199,10 @@ def test_inputs_embeds(self):
def test_model_common_attributes(self):
pass
+ @unittest.skip(reason="ConvNext does not use feedforward chunking")
+ def test_feed_forward_chunking(self):
+ pass
+
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/upernet/__init__.py b/tests/models/upernet/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/upernet/test_modeling_upernet.py b/tests/models/upernet/test_modeling_upernet.py
new file mode 100644
index 000000000000..2e9daec542bc
--- /dev/null
+++ b/tests/models/upernet/test_modeling_upernet.py
@@ -0,0 +1,307 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch UperNet framework. """
+
+
+import inspect
+import unittest
+
+from huggingface_hub import hf_hub_download
+from transformers import ConvNextConfig, UperNetConfig
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import UperNetForSemanticSegmentation
+ from transformers.models.upernet.modeling_upernet import UPERNET_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoImageProcessor
+
+
+class UperNetModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=32,
+ num_channels=3,
+ num_stages=4,
+ hidden_sizes=[10, 20, 30, 40],
+ depths=[2, 2, 3, 2],
+ is_training=True,
+ use_labels=True,
+ intermediate_size=37,
+ hidden_act="gelu",
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ out_features=["stage2", "stage3", "stage4"],
+ num_labels=3,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.num_stages = num_stages
+ self.hidden_sizes = hidden_sizes
+ self.depths = depths
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.out_features = out_features
+ self.num_labels = num_labels
+ self.scope = scope
+ self.num_hidden_layers = num_stages
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_backbone_config(self):
+ return ConvNextConfig(
+ num_channels=self.num_channels,
+ num_stages=self.num_stages,
+ hidden_sizes=self.hidden_sizes,
+ depths=self.depths,
+ is_training=self.is_training,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ out_features=self.out_features,
+ )
+
+ def get_config(self):
+ return UperNetConfig(
+ backbone_config=self.get_backbone_config(),
+ hidden_size=512,
+ pool_scales=[1, 2, 3, 6],
+ use_auxiliary_head=True,
+ auxiliary_loss_weight=0.4,
+ auxiliary_in_channels=40,
+ auxiliary_channels=256,
+ auxiliary_num_convs=1,
+ auxiliary_concat_input=False,
+ loss_ignore_index=255,
+ num_labels=self.num_labels,
+ )
+
+ def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels):
+ model = UperNetForSemanticSegmentation(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ pixel_values,
+ labels,
+ ) = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class UperNetModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as UperNet does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (UperNetForSemanticSegmentation,) if is_torch_available() else ()
+ fx_compatible = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ test_torchscript = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = UperNetModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=UperNetConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_for_semantic_segmentation(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)
+
+ @unittest.skip(reason="UperNet does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="UperNet does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ @unittest.skip(reason="UperNet does not have a base model")
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @unittest.skip(reason="UperNet does not have a base model")
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_stages = self.model_tester.num_stages
+ self.assertEqual(len(hidden_states), expected_num_stages + 1)
+
+ # ConvNext's feature maps are of shape (batch_size, num_channels, height, width)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [self.model_tester.image_size // 4, self.model_tester.image_size // 4],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ configs_no_init.backbone_config = _config_zero_init(configs_no_init.backbone_config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ @unittest.skip(reason="UperNet does not have tied weights")
+ def test_tied_model_weights_key_ignore(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in UPERNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = UperNetForSemanticSegmentation.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of ADE20k
+def prepare_img():
+ filepath = hf_hub_download(
+ repo_id="hf-internal-testing/fixtures_ade20k", repo_type="dataset", filename="ADE_val_00000001.jpg"
+ )
+ image = Image.open(filepath).convert("RGB")
+ return image
+
+
+@require_torch
+@require_vision
+@slow
+class UperNetModelIntegrationTest(unittest.TestCase):
+ def test_inference_swin_backbone(self):
+ processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-swin-tiny")
+ model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-tiny")
+
+ image = prepare_img()
+ inputs = processor(images=image, return_tensors="pt").to(torch_device)
+
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[-7.5958, -7.5958, -7.4302], [-7.5958, -7.5958, -7.4302], [-7.4797, -7.4797, -7.3068]]
+ ).to(torch_device)
+ self.assertTrue(torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4))
+
+ def test_inference_convnext_backbone(self):
+ processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-tiny")
+ model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny")
+
+ image = prepare_img()
+ inputs = processor(images=image, return_tensors="pt").to(torch_device)
+
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[-8.8110, -8.8110, -8.6521], [-8.8110, -8.8110, -8.6521], [-8.7746, -8.7746, -8.6130]]
+ ).to(torch_device)
+ self.assertTrue(torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4))
diff --git a/utils/check_repo.py b/utils/check_repo.py
index fc687ba464e9..b0e3637a8ec9 100755
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -689,14 +689,15 @@ def find_all_documented_objects():
"PyTorchBenchmarkArguments",
"TensorFlowBenchmark",
"TensorFlowBenchmarkArguments",
- "BitBackbone",
- "MaskFormerSwinBackbone",
- "ResNetBackbone",
"AutoBackbone",
+ "BitBackbone",
+ "ConvNextBackbone",
"DinatBackbone",
- "NatBackbone",
+ "MaskFormerSwinBackbone",
"MaskFormerSwinConfig",
"MaskFormerSwinModel",
+ "NatBackbone",
+ "ResNetBackbone",
"SwinBackbone",
]
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 7839f58a2016..2e4613d918f5 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -179,6 +179,7 @@ src/transformers/models/trocr/modeling_trocr.py
src/transformers/models/unispeech/configuration_unispeech.py
src/transformers/models/unispeech/modeling_unispeech.py
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+src/transformers/models/upernet/modeling_upernet.py
src/transformers/models/van/modeling_van.py
src/transformers/models/videomae/modeling_videomae.py
src/transformers/models/vilt/modeling_vilt.py