From e455aee8b1f2f3b6de5580fdb821c6c4d4f99512 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 21 Dec 2022 09:09:01 +0100
Subject: [PATCH] Use config.num_channels in CLIP-like modeling files

---
 .../models/chinese_clip/modeling_chinese_clip.py    | 6 +++++-
 src/transformers/models/clip/modeling_clip.py       | 6 +++++-
 src/transformers/models/clipseg/modeling_clipseg.py | 6 +++++-
 src/transformers/models/x_clip/modeling_x_clip.py   | 6 +++++-
 4 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
index d321065f6bb4..635c7a95f784 100644
--- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -181,7 +181,11 @@ def __init__(self, config: ChineseCLIPVisionConfig):
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index be5f9d83d6ae..aa46008f91d6 100644
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -178,7 +178,11 @@ def __init__(self, config: CLIPVisionConfig):
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py
index ff4af1297661..d5074a8dff20 100644
--- a/src/transformers/models/clipseg/modeling_clipseg.py
+++ b/src/transformers/models/clipseg/modeling_clipseg.py
@@ -171,7 +171,11 @@ def __init__(self, config: CLIPSegVisionConfig):
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py
index 67e15dad80e2..83ca74761274 100644
--- a/src/transformers/models/x_clip/modeling_x_clip.py
+++ b/src/transformers/models/x_clip/modeling_x_clip.py
@@ -129,7 +129,11 @@ def __init__(self, config: XCLIPVisionConfig):
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
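
A minimal usage sketch (not part of the patch itself) of what the change enables: with in_channels read from the
config instead of being hard-coded to 3, a CLIP-like vision tower can be instantiated for non-RGB inputs. It
assumes CLIPVisionConfig exposes the num_channels attribute that the patched code reads; the small hyperparameter
values below are arbitrary illustration choices, not recommended settings.

    import torch
    from transformers import CLIPVisionConfig, CLIPVisionModel

    # num_channels=1 exercises the patched patch_embedding Conv2d, which previously assumed 3 channels.
    config = CLIPVisionConfig(
        num_channels=1,
        image_size=224,
        patch_size=32,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
    )
    model = CLIPVisionModel(config)

    # Random single-channel input matching the configured number of channels.
    pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
    outputs = model(pixel_values=pixel_values)
    print(outputs.last_hidden_state.shape)  # torch.Size([1, 50, 64]): 49 patches + class token, hidden_size 64

The same pattern should apply to the other three models touched by the patch, since each reads num_channels from
its own vision config in the same way.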