From 6124dbad15cec63d25af498cfdc603d42e9664ef Mon Sep 17 00:00:00 2001 From: seinpark Date: Tue, 27 Jan 2026 18:03:55 +0900 Subject: [PATCH 01/14] support colqwen series and add pytest for colqwen2.5 auto --- .../transformers/models/colqwen2/modeling_colqwen2.py | 3 +-- tests/test_transformers.py | 8 ++++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py b/src/optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py index 262660fc6..71816c052 100644 --- a/src/optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py @@ -113,8 +113,7 @@ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"): model.vlm.model.lm_head = model.embedding_proj_layer model.vlm.model.config.embedding_dim = model.config.embedding_dim - # Some of the model weights are different from the model.dtype(vidore/colqwen2-v1.0-hf) - return model.to(model.dtype) + return model def forward( self, diff --git a/tests/test_transformers.py b/tests/test_transformers.py index 43b41de98..d292954a6 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -508,6 +508,14 @@ class TestColQwen2_5Model(TestColQwen2Model): HF_MODEL_ID = "Sahil-Kabir/colqwen2.5-v0.2-hf" +class TestColQwen2_5Model_Auto(TestColQwen2Model): + TEST_LEVEL = TestLevel.FULL + HF_MODEL_ID = "Sahil-Kabir/colqwen2.5-v0.2-hf" + HF_CONFIG_KWARGS = { + "dtype": "auto", + } + + class TestWav2VecModel(BaseTest.TestModel): RBLN_AUTO_CLASS = RBLNAutoModelForCTC RBLN_CLASS = RBLNWav2Vec2ForCTC From 5feb4f81a18d4d0186d68c6eb3e61ef917da73b2 Mon Sep 17 00:00:00 2001 From: seinpark Date: Tue, 27 Jan 2026 18:28:56 +0900 Subject: [PATCH 02/14] feat: support blip2 non-fp32 --- .../models/blip_2/modeling_blip_2.py | 1 + tests/test_llm.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git 
a/src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py b/src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py index dc7311fda..c4e47f0c3 100644 --- a/src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +++ b/src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py @@ -315,6 +315,7 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi auto_model_class = AutoModelForVisualQuestionAnswering _rbln_submodules = [{"name": "vision_model"}, {"name": "qformer"}, {"name": "language_model"}] + _supports_non_fp32 = True def __getattr__(self, __name: str) -> Any: def redirect(func): diff --git a/tests/test_llm.py b/tests/test_llm.py index de46125e3..2bd6dad79 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -613,6 +613,27 @@ def _inner_test_save_load(self, tmpdir): ) +class TestBlip2ForConditionalGeneration_Auto(TestBlip2ForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + HF_CONFIG_KWARGS = { + "dtype": "auto", + } + + +class TestBlip2ForConditionalGeneration_Bfloat16(TestBlip2ForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + HF_CONFIG_KWARGS = { + "dtype": "bfloat16", + } + + +class TestBlip2ForConditionalGeneration_Float16(TestBlip2ForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + HF_CONFIG_KWARGS = { + "dtype": "float16", + } + + class TestIdefics3ForConditionalGeneration(LLMTest.TestLLM): RBLN_AUTO_CLASS = RBLNAutoModelForVision2Seq RBLN_CLASS = RBLNIdefics3ForConditionalGeneration From 2ae88bbfe9fd3fe5272ade2ec9328bf98ca0e981 Mon Sep 17 00:00:00 2001 From: seinpark Date: Tue, 27 Jan 2026 18:29:15 +0900 Subject: [PATCH 03/14] (wip): support idefics3 --- .../models/idefics3/modeling_idefics3.py | 1 + tests/test_llm.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py b/src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py index 1b29fe232..35f8d28bc 100644 --- 
a/src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +++ b/src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py @@ -231,6 +231,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationM auto_model_class = AutoModelForVision2Seq _rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}] _rbln_submodule_prefix = "model" + _supports_non_fp32 = True def __getattr__(self, __name: str) -> Any: def redirect(func): diff --git a/tests/test_llm.py b/tests/test_llm.py index 2bd6dad79..1482f7e4f 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -663,6 +663,26 @@ def get_inputs(self): return inputs +class TestIdefics3ForConditionalGeneration_Auto(TestIdefics3ForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + HF_CONFIG_KWARGS = { + "dtype": "auto", + } + + +class TestIdefics3ForConditionalGeneration_Bfloat16(TestIdefics3ForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + HF_CONFIG_KWARGS = { + "dtype": "bfloat16", + } + + +class TestIdefics3ForConditionalGeneration_Float16(TestIdefics3ForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + HF_CONFIG_KWARGS = { + "dtype": "float16", + } + class TestQwen2VLForConditionalGeneration(LLMTest.TestLLM): RBLN_AUTO_CLASS = RBLNAutoModelForVision2Seq RBLN_CLASS = RBLNQwen2VLForConditionalGeneration From 32640ce28588986bb0eebb747ab62c5fe24cf711 Mon Sep 17 00:00:00 2001 From: seinpark Date: Tue, 27 Jan 2026 18:40:21 +0900 Subject: [PATCH 04/14] feat: support llavanext for non-fp32 dtype --- .../models/llava_next/modeling_llava_next.py | 1 + tests/test_llm.py | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py b/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py index 9f7281f60..339a8f3d0 100644 --- a/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +++ 
b/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py @@ -133,6 +133,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGeneration {"name": "vision_tower"}, {"name": "language_model"}, ] + _supports_non_fp32 = True def __getattr__(self, __name: str) -> Any: def redirect(func): diff --git a/tests/test_llm.py b/tests/test_llm.py index 1482f7e4f..03ad95e37 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -2,6 +2,7 @@ import os import unittest import warnings +import copy import pytest import torch @@ -563,6 +564,42 @@ def test_complicate_config(self): _ = self.RBLN_CLASS.from_pretrained(model_id=self.HF_MODEL_ID, **rbln_class_kwargs) +class TestLlavaNextForConditionalGeneration_Auto(TestLlavaNextForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + + # override + @classmethod + def setUpClass(cls): + cls.HF_CONFIG_KWARGS.update({ + "dtype": "auto", + }) + return super().setUpClass() + + +class TestLlavaNextForConditionalGeneration_Bfloat16(TestLlavaNextForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + + # override + @classmethod + def setUpClass(cls): + cls.HF_CONFIG_KWARGS.update({ + "dtype": "bfloat16", + }) + return super().setUpClass() + + +class TestLlavaNextForConditionalGeneration_Float16(TestLlavaNextForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + + # override + @classmethod + def setUpClass(cls): + cls.HF_CONFIG_KWARGS.update({ + "dtype": "float16", + }) + return super().setUpClass() + + class TestBlip2ForConditionalGeneration(LLMTest.TestLLM): RBLN_AUTO_CLASS = RBLNAutoModelForVision2Seq RBLN_CLASS = RBLNBlip2ForConditionalGeneration From ae938067420a0ff04011a205b95d5ac600c9bdfd Mon Sep 17 00:00:00 2001 From: seinpark Date: Tue, 27 Jan 2026 18:51:38 +0900 Subject: [PATCH 05/14] feat: support llava for non-fp32 dtype --- .../models/llava/modeling_llava.py | 1 + tests/test_llm.py | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git 
a/src/optimum/rbln/transformers/models/llava/modeling_llava.py b/src/optimum/rbln/transformers/models/llava/modeling_llava.py index 074b28cb4..3b74b8ce4 100644 --- a/src/optimum/rbln/transformers/models/llava/modeling_llava.py +++ b/src/optimum/rbln/transformers/models/llava/modeling_llava.py @@ -168,6 +168,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi {"name": "vision_tower"}, {"name": "language_model"}, ] + _supports_non_fp32 = True def __getattr__(self, __name: str) -> Any: def redirect(func): diff --git a/tests/test_llm.py b/tests/test_llm.py index 03ad95e37..d1a9b11b2 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -466,6 +466,41 @@ def _inner_test_save_load(self, tmpdir): ) +class TestLlavaForConditionalGeneration_Auto(TestLlavaForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + + # override + @classmethod + def setUpClass(cls): + cls.HF_CONFIG_KWARGS.update({ + "dtype": "auto", + }) + return super().setUpClass() + +class TestLlavaForConditionalGeneration_Bfloat16(TestLlavaForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + + # override + @classmethod + def setUpClass(cls): + cls.HF_CONFIG_KWARGS.update({ + "dtype": "bfloat16", + }) + return super().setUpClass() + + +class TestLlavaForConditionalGeneration_Float16(TestLlavaForConditionalGeneration): + TEST_LEVEL = TestLevel.FULL + + # override + @classmethod + def setUpClass(cls): + cls.HF_CONFIG_KWARGS.update({ + "dtype": "float16", + }) + return super().setUpClass() + + class TestPegasusModel(LLMTest.TestLLM): RBLN_AUTO_CLASS = RBLNAutoModelForSeq2SeqLM RBLN_CLASS = RBLNPegasusForConditionalGeneration From 66e5cb7393b1110b044baf16465d03fbf6dd7b66 Mon Sep 17 00:00:00 2001 From: seinpark Date: Tue, 27 Jan 2026 19:07:48 +0900 Subject: [PATCH 06/14] feat: support colpali for non-fp32 dtype --- .../models/colpali/modeling_colpali.py | 1 + tests/test_transformers.py | 26 ++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git 
a/src/optimum/rbln/transformers/models/colpali/modeling_colpali.py b/src/optimum/rbln/transformers/models/colpali/modeling_colpali.py index 359f63bc3..9f7b04fdf 100644 --- a/src/optimum/rbln/transformers/models/colpali/modeling_colpali.py +++ b/src/optimum/rbln/transformers/models/colpali/modeling_colpali.py @@ -157,6 +157,7 @@ class RBLNColPaliForRetrieval(RBLNModel): _rbln_submodules = [ {"name": "vlm"}, ] + _supports_non_fp32 = True def __post_init__(self, **kwargs): self.vlm_model = self.rbln_submodules[0] diff --git a/tests/test_transformers.py b/tests/test_transformers.py index d292954a6..bf1cc8c5b 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -421,6 +421,26 @@ def _inner_propagate_rbln_config(self, tmpdir): assert model.rbln_config.vlm.language_model.device == 2 +class TestColPaliModel_Auto(TestColPaliModel): + HF_CONFIG_KWARGS = { + "dtype": "auto", + } + + +class TestColPaliModel_BFloat16(TestColPaliModel): + TEST_LEVEL = TestLevel.FULL + HF_CONFIG_KWARGS = { + "dtype": "bfloat16", + } + + +class TestColPaliModel_Float16(TestColPaliModel): + TEST_LEVEL = TestLevel.FULL + HF_CONFIG_KWARGS = { + "dtype": torch.float16, + } + + class TestColQwen2Model(BaseTest.TestModel): RBLN_AUTO_CLASS = None RBLN_CLASS = RBLNColQwen2ForRetrieval @@ -490,16 +510,15 @@ class TestColQwen2Model_BFloat16(TestColQwen2Model): class TestColQwen2Model_Auto(TestColQwen2Model): - TEST_LEVEL = TestLevel.FULL HF_CONFIG_KWARGS = { "dtype": "auto", } -class TestColQwen2Model_Float32(TestColQwen2Model): +class TestColQwen2Model_Float16(TestColQwen2Model): TEST_LEVEL = TestLevel.FULL HF_CONFIG_KWARGS = { - "dtype": torch.float32, + "dtype": torch.float16, } @@ -509,7 +528,6 @@ class TestColQwen2_5Model(TestColQwen2Model): class TestColQwen2_5Model_Auto(TestColQwen2Model): - TEST_LEVEL = TestLevel.FULL HF_MODEL_ID = "Sahil-Kabir/colqwen2.5-v0.2-hf" HF_CONFIG_KWARGS = { "dtype": "auto", From 95fc7a24c6a65e6e3013a9dee4ae5d57c40752ef Mon Sep 17 00:00:00 2001 From: 
seinpark Date: Tue, 27 Jan 2026 19:08:04 +0900 Subject: [PATCH 07/14] fix: update test level for auto dtype --- tests/test_llm.py | 67 ++++++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/tests/test_llm.py b/tests/test_llm.py index d1a9b11b2..f68f0521d 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -2,7 +2,6 @@ import os import unittest import warnings -import copy import pytest import torch @@ -467,37 +466,42 @@ def _inner_test_save_load(self, tmpdir): class TestLlavaForConditionalGeneration_Auto(TestLlavaForConditionalGeneration): - TEST_LEVEL = TestLevel.FULL - # override @classmethod def setUpClass(cls): - cls.HF_CONFIG_KWARGS.update({ - "dtype": "auto", - }) + cls.HF_CONFIG_KWARGS.update( + { + "dtype": "auto", + } + ) return super().setUpClass() + class TestLlavaForConditionalGeneration_Bfloat16(TestLlavaForConditionalGeneration): TEST_LEVEL = TestLevel.FULL - + # override @classmethod def setUpClass(cls): - cls.HF_CONFIG_KWARGS.update({ - "dtype": "bfloat16", - }) + cls.HF_CONFIG_KWARGS.update( + { + "dtype": "bfloat16", + } + ) return super().setUpClass() class TestLlavaForConditionalGeneration_Float16(TestLlavaForConditionalGeneration): TEST_LEVEL = TestLevel.FULL - + # override @classmethod def setUpClass(cls): - cls.HF_CONFIG_KWARGS.update({ - "dtype": "float16", - }) + cls.HF_CONFIG_KWARGS.update( + { + "dtype": "float16", + } + ) return super().setUpClass() @@ -600,38 +604,42 @@ def test_complicate_config(self): class TestLlavaNextForConditionalGeneration_Auto(TestLlavaNextForConditionalGeneration): - TEST_LEVEL = TestLevel.FULL - # override @classmethod def setUpClass(cls): - cls.HF_CONFIG_KWARGS.update({ - "dtype": "auto", - }) + cls.HF_CONFIG_KWARGS.update( + { + "dtype": "auto", + } + ) return super().setUpClass() class TestLlavaNextForConditionalGeneration_Bfloat16(TestLlavaNextForConditionalGeneration): TEST_LEVEL = TestLevel.FULL - + # override @classmethod def setUpClass(cls): 
- cls.HF_CONFIG_KWARGS.update({ - "dtype": "bfloat16", - }) + cls.HF_CONFIG_KWARGS.update( + { + "dtype": "bfloat16", + } + ) return super().setUpClass() class TestLlavaNextForConditionalGeneration_Float16(TestLlavaNextForConditionalGeneration): TEST_LEVEL = TestLevel.FULL - + # override @classmethod def setUpClass(cls): - cls.HF_CONFIG_KWARGS.update({ - "dtype": "float16", - }) + cls.HF_CONFIG_KWARGS.update( + { + "dtype": "float16", + } + ) return super().setUpClass() @@ -686,7 +694,6 @@ def _inner_test_save_load(self, tmpdir): class TestBlip2ForConditionalGeneration_Auto(TestBlip2ForConditionalGeneration): - TEST_LEVEL = TestLevel.FULL HF_CONFIG_KWARGS = { "dtype": "auto", } @@ -736,7 +743,6 @@ def get_inputs(self): class TestIdefics3ForConditionalGeneration_Auto(TestIdefics3ForConditionalGeneration): - TEST_LEVEL = TestLevel.FULL HF_CONFIG_KWARGS = { "dtype": "auto", } @@ -749,12 +755,13 @@ class TestIdefics3ForConditionalGeneration_Bfloat16(TestIdefics3ForConditionalGe } -class TestIdefics3ForConditionalGeneration_Float16(TestIdefics3ForConditionalGeneration): +class TestIdefics3ForConditionalGeneration_Float16(TestIdefics3ForConditionalGeneration): TEST_LEVEL = TestLevel.FULL HF_CONFIG_KWARGS = { "dtype": "float16", } + class TestQwen2VLForConditionalGeneration(LLMTest.TestLLM): RBLN_AUTO_CLASS = RBLNAutoModelForVision2Seq RBLN_CLASS = RBLNQwen2VLForConditionalGeneration From 6c8f5250456a0d6c53da6ebd5c9921adfbf55547 Mon Sep 17 00:00:00 2001 From: seinpark Date: Wed, 28 Jan 2026 19:56:09 +0900 Subject: [PATCH 08/14] feat:support non-fp32 dtype for idefics3 --- .../rbln/transformers/models/clip/modeling_clip.py | 1 + .../models/decoderonly/modeling_decoderonly.py | 2 +- .../transformers/models/idefics3/modeling_idefics3.py | 10 ++++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/optimum/rbln/transformers/models/clip/modeling_clip.py b/src/optimum/rbln/transformers/models/clip/modeling_clip.py index 3274bf760..380a4f5c8 100644 --- 
a/src/optimum/rbln/transformers/models/clip/modeling_clip.py +++ b/src/optimum/rbln/transformers/models/clip/modeling_clip.py @@ -158,6 +158,7 @@ class RBLNCLIPVisionModel(RBLNModel): """ _tp_support = False + _supports_non_fp32 = True @classmethod def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module: diff --git a/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py b/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py index 8b6e10839..8b2e55579 100644 --- a/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +++ b/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py @@ -79,7 +79,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin): def __post_init__(self, **kwargs): if self.rbln_config.use_inputs_embeds: artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False) - self.embed_tokens = self._create_embedding_layer() + self.embed_tokens = self._create_embedding_layer().to(dtype=self.rbln_config.dtype) self.embed_tokens.load_state_dict(artifacts["embed_tokens"]) else: self.embed_tokens = None diff --git a/src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py b/src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py index 35f8d28bc..bc6bd9c0f 100644 --- a/src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +++ b/src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py @@ -87,12 +87,14 @@ def forward( class RBLNIdefics3VisionTransformer(RBLNModel): _tp_support = False + _supports_non_fp32 = True def __post_init__(self, **kwargs): artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False) with no_init_weights(): self.embeddings = Idefics3VisionEmbeddings(self.config) self.embeddings.load_state_dict(artifacts["embeddings"]) + self.embeddings = 
self.embeddings.to(dtype=self.rbln_config.dtype) self.model = RBLNRuntimeVisionModel( self.model[0], main_input_name="pixel_values", config=self.config, embeddings=self.embeddings ) @@ -150,7 +152,7 @@ def _update_rbln_config( (model_config.image_size // model_config.patch_size) ** 2, model_config.hidden_size, ], - "float32", + rbln_config.dtype, ), ] @@ -170,7 +172,7 @@ def forward( (self.config.image_size // self.config.patch_size) ** 2, self.config.hidden_size, ] - last_hidden_state = torch.empty(size=last_hidden_state_size, dtype=torch.float32, device="cpu") + last_hidden_state = torch.empty(size=last_hidden_state_size, dtype=pixel_values.dtype, device="cpu") for i in range(pixel_values.shape[0]): if patch_attention_mask is not None: batch_attention_mask = patch_attention_mask[i : i + 1,] @@ -295,7 +297,7 @@ def _update_rbln_config( (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2, model_config.vision_config.hidden_size, ], - "float32", + rbln_config.dtype, ), ] @@ -435,7 +437,7 @@ def _preprocess_prefill( image_hidden_states.shape[1] // self.config.scale_factor**2, self.config.text_config.hidden_size, ] - connector_outputs = torch.empty(size=connector_output_size, dtype=torch.float32, device="cpu") + connector_outputs = torch.empty(size=connector_output_size, dtype=image_hidden_states.dtype, device="cpu") for i in range(image_hidden_states.shape[0]): self.connector(image_hidden_states[i : i + 1,], out=connector_outputs[i : i + 1,]) image_hidden_states = connector_outputs From 3d187506d6560655d86ecd3a60db3999c8e526d9 Mon Sep 17 00:00:00 2001 From: seinpark Date: Wed, 28 Jan 2026 20:08:56 +0900 Subject: [PATCH 09/14] feat:support non-fp32 dtype for llavanext --- .../transformers/models/llava_next/modeling_llava_next.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py 
b/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py index 339a8f3d0..f35ac01f6 100644 --- a/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +++ b/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py @@ -222,7 +222,7 @@ def _update_rbln_config( ( "image_features", [rbln_config.vision_tower.batch_size, selected_image_feature_dim, feature_size], - "float32", + rbln_config.dtype, ) ] rbln_compile_config = RBLNCompileConfig(input_info=input_info) @@ -309,15 +309,15 @@ def get_image_features( pooler_out_size = [pixel_values.shape[0] * pixel_values.shape[1], self.config.vision_config.hidden_size] vision_out_buffer = [] for _ in range(self.config.vision_config.num_hidden_layers + 2): - vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu")) - vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu")) + vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=self.rbln_config.dtype, device="cpu")) + vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=self.rbln_config.dtype, device="cpu")) projector_out_size = [ pixel_values.shape[0] * pixel_values.shape[1], (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2, self.config.text_config.hidden_size, ] - projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")] + projector_out_buffer = [torch.empty(size=projector_out_size, dtype=self.rbln_config.dtype, device="cpu")] if pixel_values.dim() == 5: # stacked if input is (batch_size, num_patches, num_channels, height, width) From 31164d239fb151a0ebcf9c524ef357bc6b41eb59 Mon Sep 17 00:00:00 2001 From: seinpark Date: Wed, 28 Jan 2026 20:09:05 +0900 Subject: [PATCH 10/14] feat:support non-fp32 dtype for llava --- .../rbln/transformers/models/llava/modeling_llava.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/src/optimum/rbln/transformers/models/llava/modeling_llava.py b/src/optimum/rbln/transformers/models/llava/modeling_llava.py index 3b74b8ce4..548102528 100644 --- a/src/optimum/rbln/transformers/models/llava/modeling_llava.py +++ b/src/optimum/rbln/transformers/models/llava/modeling_llava.py @@ -250,7 +250,7 @@ def _update_rbln_config( selected_image_feature_dim, model_config.vision_config.hidden_size, ], - "float32", + rbln_config.dtype, ) ] @@ -342,9 +342,9 @@ def get_image_features( vision_out_buffer = [] for _ in range(self.config.vision_config.num_hidden_layers + 2): - vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu")) + vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=self.rbln_config.dtype, device="cpu")) if pooler_out_size is not None: - vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu")) + vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=self.rbln_config.dtype, device="cpu")) image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, out=vision_out_buffer, **kwargs) @@ -380,7 +380,7 @@ def get_image_features( split_features = torch.cat(chunks, dim=0) num_chunks = len(chunks) projector_out_size = [1, max_patches * num_chunks, self.config.text_config.hidden_size] - projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")] + projector_out_buffer = [torch.empty(size=projector_out_size, dtype=self.rbln_config.dtype, device="cpu")] projected_features = self.multi_modal_projector(split_features, out=projector_out_buffer) projected_features = projected_features.view( selected_image_feature.shape[0], num_chunks * max_patches, self.config.text_config.hidden_size @@ -392,7 +392,7 @@ def get_image_features( (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2, self.config.text_config.hidden_size, ] - projector_out_buffer = [torch.empty(size=projector_out_size, 
dtype=torch.float32, device="cpu")] + projector_out_buffer = [torch.empty(size=projector_out_size, dtype=self.rbln_config.dtype, device="cpu")] image_features = self.multi_modal_projector(selected_image_feature, out=projector_out_buffer) return image_features From aee80b0cbd254cda9eba656663e0b564d63e51c3 Mon Sep 17 00:00:00 2001 From: seinpark Date: Thu, 29 Jan 2026 12:10:13 +0900 Subject: [PATCH 11/14] fix:specify the out buffer dtype --- src/optimum/rbln/transformers/models/llava/modeling_llava.py | 4 ++-- .../transformers/models/llava_next/modeling_llava_next.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/optimum/rbln/transformers/models/llava/modeling_llava.py b/src/optimum/rbln/transformers/models/llava/modeling_llava.py index 548102528..4b98aeea9 100644 --- a/src/optimum/rbln/transformers/models/llava/modeling_llava.py +++ b/src/optimum/rbln/transformers/models/llava/modeling_llava.py @@ -342,9 +342,9 @@ def get_image_features( vision_out_buffer = [] for _ in range(self.config.vision_config.num_hidden_layers + 2): - vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=self.rbln_config.dtype, device="cpu")) + vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu")) if pooler_out_size is not None: - vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=self.rbln_config.dtype, device="cpu")) + vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu")) image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, out=vision_out_buffer, **kwargs) diff --git a/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py b/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py index f35ac01f6..277d57a2a 100644 --- a/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +++ 
b/src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py @@ -309,8 +309,8 @@ def get_image_features( pooler_out_size = [pixel_values.shape[0] * pixel_values.shape[1], self.config.vision_config.hidden_size] vision_out_buffer = [] for _ in range(self.config.vision_config.num_hidden_layers + 2): - vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=self.rbln_config.dtype, device="cpu")) - vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=self.rbln_config.dtype, device="cpu")) + vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu")) + vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu")) projector_out_size = [ pixel_values.shape[0] * pixel_values.shape[1], From cdc9ff1e6d559655827c07917f65ca84f8a5fb60 Mon Sep 17 00:00:00 2001 From: seinpark Date: Thu, 29 Jan 2026 16:13:59 +0900 Subject: [PATCH 12/14] (tmp)feat:support non-fp32 dtype for blip2 --- .../models/blip_2/modeling_blip_2.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py b/src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py index c4e47f0c3..9aa53a6d5 100644 --- a/src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +++ b/src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py @@ -66,6 +66,7 @@ class RBLNBlip2VisionModel(RBLNModel): """ _tp_support = False + _supports_non_fp32 = True def get_input_embeddings(self): return self.embeddings @@ -100,7 +101,7 @@ def _update_rbln_config( model_config.image_size, model_config.image_size, ], - "float32", + rbln_config.dtype, ), ] @@ -213,7 +214,7 @@ def _update_rbln_config( rbln_config.num_query_tokens, model_config.hidden_size, ], - "float32", + rbln_config.dtype, ), ( "encoder_hidden_states", @@ -223,7 +224,7 @@ def _update_rbln_config( 
rbln_config.image_text_hidden_size + 1, model_config.encoder_hidden_size, ], - "float32", + rbln_config.dtype, ), ( "encoder_attention_mask", @@ -375,6 +376,8 @@ def _update_rbln_config( model_config: Optional["PretrainedConfig"] = None, rbln_config: Optional[RBLNModelConfig] = None, ) -> RBLNModelConfig: + # FIXME(seinpark): need to check all dtypes are properly set. + rbln_config.dtype = model.language_projection.weight.dtype input_info = [ ( "query_output", @@ -383,7 +386,7 @@ def _update_rbln_config( model_config.num_query_tokens, model_config.qformer_config.hidden_size, ], - "float32", + rbln_config.dtype, ), ] @@ -475,7 +478,7 @@ def generate( """ batch_size = pixel_values.shape[0] image_embeds = self.vision_model( - pixel_values, + pixel_values.to(self.rbln_config.vision_model.dtype), return_dict=True, interpolate_pos_encoding=interpolate_pos_encoding, ).last_hidden_state @@ -483,8 +486,8 @@ def generate( query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_outputs = self.qformer( - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, + query_embeds=query_tokens.to(self.rbln_config.qformer.dtype), + encoder_hidden_states=image_embeds.to(self.rbln_config.qformer.dtype), encoder_attention_mask=image_attention_mask, return_dict=True, ) @@ -514,6 +517,7 @@ def generate( else: special_image_mask = input_ids == self.config.image_token_id + inputs_embeds = inputs_embeds.to(self.rbln_config.language_model.dtype) special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) From 07792fcc469dc7e96511ed4abb7e8d229e6067dc Mon Sep 17 00:00:00 2001 From: seinpark Date: Thu, 29 Jan 2026 16:19:05 +0900 Subject: [PATCH 13/14] fix: test name --- tests/test_llm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) 
diff --git a/tests/test_llm.py b/tests/test_llm.py index f3c9b8c88..bb9f17448 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -477,7 +477,7 @@ def setUpClass(cls): return super().setUpClass() -class TestLlavaForConditionalGeneration_Bfloat16(TestLlavaForConditionalGeneration): +class TestLlavaForConditionalGeneration_BFloat16(TestLlavaForConditionalGeneration): TEST_LEVEL = TestLevel.FULL # override @@ -615,7 +615,7 @@ def setUpClass(cls): return super().setUpClass() -class TestLlavaNextForConditionalGeneration_Bfloat16(TestLlavaNextForConditionalGeneration): +class TestLlavaNextForConditionalGeneration_BFloat16(TestLlavaNextForConditionalGeneration): TEST_LEVEL = TestLevel.FULL # override @@ -699,7 +699,7 @@ class TestBlip2ForConditionalGeneration_Auto(TestBlip2ForConditionalGeneration): } -class TestBlip2ForConditionalGeneration_Bfloat16(TestBlip2ForConditionalGeneration): +class TestBlip2ForConditionalGeneration_BFloat16(TestBlip2ForConditionalGeneration): TEST_LEVEL = TestLevel.FULL HF_CONFIG_KWARGS = { "dtype": "bfloat16", @@ -748,7 +748,7 @@ class TestIdefics3ForConditionalGeneration_Auto(TestIdefics3ForConditionalGenera } -class TestIdefics3ForConditionalGeneration_Bfloat16(TestIdefics3ForConditionalGeneration): +class TestIdefics3ForConditionalGeneration_BFloat16(TestIdefics3ForConditionalGeneration): TEST_LEVEL = TestLevel.FULL HF_CONFIG_KWARGS = { "dtype": "bfloat16", From 2bcbf9072f8fc0b9cdbfa7cb2841a4e178d1a68b Mon Sep 17 00:00:00 2001 From: seinpark Date: Thu, 29 Jan 2026 17:11:44 +0900 Subject: [PATCH 14/14] feat:support non-fp32 dtype for colpali, paligemma, siglip --- .../models/paligemma/modeling_paligemma.py | 10 ++++++---- .../rbln/transformers/models/siglip/modeling_siglip.py | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/optimum/rbln/transformers/models/paligemma/modeling_paligemma.py b/src/optimum/rbln/transformers/models/paligemma/modeling_paligemma.py index 8c385e515..a78cfad65 100644 --- 
a/src/optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +++ b/src/optimum/rbln/transformers/models/paligemma/modeling_paligemma.py @@ -99,6 +99,7 @@ class RBLNPaliGemmaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGeneration {"name": "vision_tower"}, {"name": "language_model"}, ] + _supports_non_fp32 = True def __getattr__(self, __name: str) -> Any: def redirect(func): @@ -152,7 +153,7 @@ def __post_init__(self, **kwargs): artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False) self.embed_tokens = self._create_embedding_layer() self.embed_tokens.load_state_dict(artifacts["embed_tokens"]) - self.multi_modal_projector = self._create_multi_modal_projector() + self.multi_modal_projector = self._create_multi_modal_projector().to(self.rbln_config.dtype) self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"]) return super().__post_init__(**kwargs) @@ -239,7 +240,7 @@ def get_image_features(self, pixel_values: torch.Tensor): self.config.vision_config.num_image_tokens, self.config.vision_config.hidden_size, ] - vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu") + vision_output = torch.empty(size=vision_output_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu") self.vision_tower(pixel_values, out=vision_output) image_features = self.multi_modal_projector(vision_output) image_features = image_features / (self.config.text_config.hidden_size**0.5) @@ -383,6 +384,7 @@ class RBLNPaliGemmaModel(RBLNModel): {"name": "vision_tower"}, {"name": "language_model"}, ] + _supports_non_fp32 = True def __post_init__(self, **kwargs): self.vision_tower = LoopVisionTower(self.rbln_submodules[0]) @@ -401,7 +403,7 @@ def __post_init__(self, **kwargs): artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False) self.embed_tokens = self._create_embedding_layer() 
self.embed_tokens.load_state_dict(artifacts["embed_tokens"]) - self.multi_modal_projector = self._create_multi_modal_projector() + self.multi_modal_projector = self._create_multi_modal_projector().to(self.rbln_config.dtype) self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"]) return super().__post_init__(**kwargs) @@ -459,7 +461,7 @@ def get_image_features(self, pixel_values: torch.Tensor): self.config.vision_config.num_image_tokens, self.config.vision_config.hidden_size, ] - vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu") + vision_output = torch.empty(size=vision_output_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu") self.vision_tower(pixel_values, out=vision_output) image_features = self.multi_modal_projector(vision_output) image_features = image_features / (self.config.text_config.hidden_size**0.5) diff --git a/src/optimum/rbln/transformers/models/siglip/modeling_siglip.py b/src/optimum/rbln/transformers/models/siglip/modeling_siglip.py index 2889509d1..90bada888 100644 --- a/src/optimum/rbln/transformers/models/siglip/modeling_siglip.py +++ b/src/optimum/rbln/transformers/models/siglip/modeling_siglip.py @@ -65,6 +65,7 @@ class RBLNSiglipVisionModel(RBLNModel): """ _tp_support = False + _supports_non_fp32 = True @classmethod def _wrap_model_if_needed( @@ -108,7 +109,7 @@ def _update_rbln_config( rbln_config.image_height, rbln_config.image_width, ], - "float32", + rbln_config.dtype, ) ] )