diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py
index a526ce5d7af1..bc1185c9d1d5 100644
--- a/src/transformers/models/janus/modeling_janus.py
+++ b/src/transformers/models/janus/modeling_janus.py
@@ -1083,6 +1083,12 @@ def get_image_features(self, pixel_values):
         image_embeds = self.aligner(image_embeds.last_hidden_state)
         return image_embeds
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -1192,10 +1198,10 @@ def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     @can_return_tuple
     @auto_docstring
diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py
index 0d484ffb0c05..711c81d0d4b6 100644
--- a/src/transformers/models/janus/modular_janus.py
+++ b/src/transformers/models/janus/modular_janus.py
@@ -921,6 +921,12 @@ def get_image_features(self, pixel_values):
         image_embeds = self.aligner(image_embeds.last_hidden_state)
         return image_embeds
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -1030,10 +1036,10 @@ def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     @can_return_tuple
     @auto_docstring
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index c0e990971527..cccac89e26db 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -154,7 +154,7 @@ def __init__(
         up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, the
         cos_sin_cache will be recomputed during the forward pass.
         """
-        super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
+        super().__init__(dim=dim, base=base, device=device, interleaved=False)
         self.max_seqlen = max_seqlen
 
         if max_seqlen is not None and device is not None and dtype is not None:
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index 137673cfa590..8909875381d1 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -417,7 +417,7 @@ def __init__(
         up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, the
         cos_sin_cache will be recomputed during the forward pass.
         """
-        super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
+        super().__init__(dim=dim, base=base, device=device, interleaved=False)
         self.max_seqlen = max_seqlen
 
         if max_seqlen is not None and device is not None and dtype is not None:
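The Janus hunks change `set_decoder`/`get_decoder` on the `ForConditionalGeneration` wrapper from overwriting or returning `self.model` (the whole multimodal backbone) to delegating into the inner model, which now exposes the actual decoder, `self.language_model`. A minimal sketch of that delegation pattern, using hypothetical stand-in classes rather than the real transformers ones:

```python
# Minimal sketch of the delegation pattern, with hypothetical stand-ins
# (InnerModel / Wrapper) for JanusModel / JanusForConditionalGeneration.
import torch.nn as nn


class InnerModel(nn.Module):
    """Stand-in for JanusModel: owns the language decoder."""

    def __init__(self):
        super().__init__()
        self.language_model = nn.Linear(8, 8)  # placeholder decoder

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model


class Wrapper(nn.Module):
    """Stand-in for JanusForConditionalGeneration: delegates decoder access."""

    def __init__(self):
        super().__init__()
        self.model = InnerModel()
        self.lm_head = nn.Linear(8, 16)

    def set_decoder(self, decoder):
        # Swap only the decoder inside the backbone, rather than replacing
        # self.model wholesale (which would discard the rest of the backbone,
        # as the pre-patch code effectively did).
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()


wrapper = Wrapper()
wrapper.set_decoder(nn.Linear(8, 8))
assert wrapper.get_decoder() is wrapper.model.language_model
```

The ModernBERT hunks are independent of the Janus ones: they only drop the `pos_idx_in_fp32=True` keyword from the call into the parent rotary-embedding constructor, presumably because the parent class no longer accepts that argument; the remaining arguments are unchanged.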