19 changes: 12 additions & 7 deletions src/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
@@ -66,6 +66,7 @@ class RBLNBlip2VisionModel(RBLNModel):
"""

_tp_support = False
+ _supports_non_fp32 = True

def get_input_embeddings(self):
return self.embeddings
@@ -100,7 +101,7 @@ def _update_rbln_config(
model_config.image_size,
model_config.image_size,
],
"float32",
rbln_config.dtype,
),
]

@@ -213,7 +214,7 @@ def _update_rbln_config(
rbln_config.num_query_tokens,
model_config.hidden_size,
],
"float32",
rbln_config.dtype,
),
(
"encoder_hidden_states",
@@ -223,7 +224,7 @@
rbln_config.image_text_hidden_size + 1,
model_config.encoder_hidden_size,
],
"float32",
rbln_config.dtype,
),
(
"encoder_attention_mask",
@@ -315,6 +316,7 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):

auto_model_class = AutoModelForVisualQuestionAnswering
_rbln_submodules = [{"name": "vision_model"}, {"name": "qformer"}, {"name": "language_model"}]
+ _supports_non_fp32 = True

def __getattr__(self, __name: str) -> Any:
def redirect(func):
@@ -374,6 +376,8 @@ def _update_rbln_config(
model_config: Optional["PretrainedConfig"] = None,
rbln_config: Optional[RBLNModelConfig] = None,
) -> RBLNModelConfig:
+ # FIXME(seinpark): need to check all dtypes are properly set.
+ rbln_config.dtype = model.language_projection.weight.dtype
input_info = [
(
"query_output",
@@ -382,7 +386,7 @@
model_config.num_query_tokens,
model_config.qformer_config.hidden_size,
],
"float32",
rbln_config.dtype,
),
]

@@ -474,16 +478,16 @@ def generate(
"""
batch_size = pixel_values.shape[0]
image_embeds = self.vision_model(
- pixel_values,
+ pixel_values.to(self.rbln_config.vision_model.dtype),
return_dict=True,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
- query_embeds=query_tokens,
- encoder_hidden_states=image_embeds,
+ query_embeds=query_tokens.to(self.rbln_config.qformer.dtype),
+ encoder_hidden_states=image_embeds.to(self.rbln_config.qformer.dtype),
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
@@ -513,6 +517,7 @@ def generate(
else:
special_image_mask = input_ids == self.config.image_token_id

+ inputs_embeds = inputs_embeds.to(self.rbln_config.language_model.dtype)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
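The recurring pattern in this file: the compile-time `input_info` dtype now follows `rbln_config.dtype`, and tensors are cast to each submodule's dtype at runtime boundaries. A minimal sketch of that pattern; `VisionConfig`, `build_input_info`, and `run_vision_model` are illustrative stand-ins, not the real optimum-rbln API:

```python
import torch

# Illustrative stand-ins only; the real classes live in optimum-rbln.
class VisionConfig:
    def __init__(self, dtype=torch.float32, batch_size=1, image_size=224):
        self.dtype = dtype          # was implicitly float32 before this PR
        self.batch_size = batch_size
        self.image_size = image_size

def build_input_info(cfg: VisionConfig):
    # Compile-time input declaration: dtype comes from the config,
    # mirroring the "float32" -> rbln_config.dtype replacements above.
    return [("pixel_values", [cfg.batch_size, 3, cfg.image_size, cfg.image_size], cfg.dtype)]

def run_vision_model(pixel_values: torch.Tensor, cfg: VisionConfig) -> torch.Tensor:
    # Runtime boundary: cast inputs to the submodule's compiled dtype,
    # mirroring pixel_values.to(self.rbln_config.vision_model.dtype).
    return pixel_values.to(cfg.dtype)

cfg = VisionConfig(dtype=torch.bfloat16)
assert run_vision_model(torch.rand(1, 3, 224, 224), cfg).dtype == torch.bfloat16
```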
1 change: 1 addition & 0 deletions src/optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -158,6 +158,7 @@ class RBLNCLIPVisionModel(RBLNModel):
"""

_tp_support = False
+ _supports_non_fp32 = True

@classmethod
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
src/optimum/rbln/transformers/models/colpali/modeling_colpali.py
@@ -157,6 +157,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
_rbln_submodules = [
{"name": "vlm"},
]
+ _supports_non_fp32 = True

def __post_init__(self, **kwargs):
self.vlm_model = self.rbln_submodules[0]
src/optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py
@@ -113,8 +113,7 @@ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
model.vlm.model.lm_head = model.embedding_proj_layer
model.vlm.model.config.embedding_dim = model.config.embedding_dim

- # Some of the model weights are different from the model.dtype(vidore/colqwen2-v1.0-hf)
- return model.to(model.dtype)
+ return model

def forward(
self,
src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -79,7 +79,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
def __post_init__(self, **kwargs):
if self.rbln_config.use_inputs_embeds:
artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
- self.embed_tokens = self._create_embedding_layer()
+ self.embed_tokens = self._create_embedding_layer().to(dtype=self.rbln_config.dtype)
self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
else:
self.embed_tokens = None
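Casting the embedding layer before `load_state_dict` is enough here: `load_state_dict` copies checkpoint tensors into the existing parameters with `Tensor.copy_`, which converts to the destination dtype. A short sketch, assuming an fp32 checkpoint and a bf16 `rbln_config.dtype`:

```python
import torch
import torch.nn as nn

# load_state_dict copies into existing parameters via Tensor.copy_,
# so casting the module first makes an fp32 checkpoint land as bf16 weights.
fp32_state = nn.Embedding(10, 4).state_dict()          # fp32 checkpoint
embed = nn.Embedding(10, 4).to(dtype=torch.bfloat16)   # set runtime dtype first
embed.load_state_dict(fp32_state)
assert embed.weight.dtype == torch.bfloat16
```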
src/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -87,12 +87,14 @@ def forward(

class RBLNIdefics3VisionTransformer(RBLNModel):
_tp_support = False
+ _supports_non_fp32 = True

def __post_init__(self, **kwargs):
artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
with no_init_weights():
self.embeddings = Idefics3VisionEmbeddings(self.config)
self.embeddings.load_state_dict(artifacts["embeddings"])
+ self.embeddings = self.embeddings.to(dtype=self.rbln_config.dtype)
self.model = RBLNRuntimeVisionModel(
self.model[0], main_input_name="pixel_values", config=self.config, embeddings=self.embeddings
)
@@ -150,7 +152,7 @@ def _update_rbln_config(
(model_config.image_size // model_config.patch_size) ** 2,
model_config.hidden_size,
],
"float32",
rbln_config.dtype,
),
]

@@ -170,7 +172,7 @@ def forward(
(self.config.image_size // self.config.patch_size) ** 2,
self.config.hidden_size,
]
- last_hidden_state = torch.empty(size=last_hidden_state_size, dtype=torch.float32, device="cpu")
+ last_hidden_state = torch.empty(size=last_hidden_state_size, dtype=pixel_values.dtype, device="cpu")
for i in range(pixel_values.shape[0]):
if patch_attention_mask is not None:
batch_attention_mask = patch_attention_mask[i : i + 1,]
@@ -231,6 +233,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
auto_model_class = AutoModelForVision2Seq
_rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}]
_rbln_submodule_prefix = "model"
+ _supports_non_fp32 = True

def __getattr__(self, __name: str) -> Any:
def redirect(func):
@@ -294,7 +297,7 @@ def _update_rbln_config(
(model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2,
model_config.vision_config.hidden_size,
],
"float32",
rbln_config.dtype,
),
]

@@ -434,7 +437,7 @@ def _preprocess_prefill(
image_hidden_states.shape[1] // self.config.scale_factor**2,
self.config.text_config.hidden_size,
]
- connector_outputs = torch.empty(size=connector_output_size, dtype=torch.float32, device="cpu")
+ connector_outputs = torch.empty(size=connector_output_size, dtype=image_hidden_states.dtype, device="cpu")
for i in range(image_hidden_states.shape[0]):
self.connector(image_hidden_states[i : i + 1,], out=connector_outputs[i : i + 1,])
image_hidden_states = connector_outputs
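Same idea for the preallocated CPU staging buffers: their dtype now follows the activations being written into them instead of a hardcoded `torch.float32`. A sketch of the per-sample loop, with `connector` as a plain-torch stand-in for the compiled runtime call that writes into `out`:

```python
import torch

def connector(x: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for the compiled projection kernel; only the out= contract matters here.
    out.copy_(x * 2.0)

image_hidden_states = torch.rand(4, 64, 32, dtype=torch.bfloat16)
# Buffer dtype follows the input, mirroring dtype=image_hidden_states.dtype above.
connector_outputs = torch.empty(
    size=image_hidden_states.shape, dtype=image_hidden_states.dtype, device="cpu"
)
for i in range(image_hidden_states.shape[0]):
    connector(image_hidden_states[i : i + 1], out=connector_outputs[i : i + 1])
assert connector_outputs.dtype == torch.bfloat16
```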
11 changes: 6 additions & 5 deletions src/optimum/rbln/transformers/models/llava/modeling_llava.py
@@ -168,6 +168,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
{"name": "vision_tower"},
{"name": "language_model"},
]
+ _supports_non_fp32 = True

def __getattr__(self, __name: str) -> Any:
def redirect(func):
@@ -249,7 +250,7 @@ def _update_rbln_config(
selected_image_feature_dim,
model_config.vision_config.hidden_size,
],
"float32",
rbln_config.dtype,
)
]

@@ -341,9 +342,9 @@ def get_image_features(

vision_out_buffer = []
for _ in range(self.config.vision_config.num_hidden_layers + 2):
- vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+ vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu"))
if pooler_out_size is not None:
- vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
+ vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu"))

image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, out=vision_out_buffer, **kwargs)

@@ -379,7 +380,7 @@ def get_image_features(
split_features = torch.cat(chunks, dim=0)
num_chunks = len(chunks)
projector_out_size = [1, max_patches * num_chunks, self.config.text_config.hidden_size]
- projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+ projector_out_buffer = [torch.empty(size=projector_out_size, dtype=self.rbln_config.dtype, device="cpu")]
projected_features = self.multi_modal_projector(split_features, out=projector_out_buffer)
projected_features = projected_features.view(
selected_image_feature.shape[0], num_chunks * max_patches, self.config.text_config.hidden_size
@@ -391,7 +392,7 @@
(self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
self.config.text_config.hidden_size,
]
- projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+ projector_out_buffer = [torch.empty(size=projector_out_size, dtype=self.rbln_config.dtype, device="cpu")]
image_features = self.multi_modal_projector(selected_image_feature, out=projector_out_buffer)

return image_features
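On the buffer count in `get_image_features`: with `output_hidden_states=True` the vision tower emits `last_hidden_state` plus `num_hidden_layers + 1` hidden states (the embeddings output and one per layer), hence `num_hidden_layers + 2` preallocated buffers, now in the tower's configured dtype. Shapes below are illustrative:

```python
import torch

# num_hidden_layers + 2 = last_hidden_state + (embeddings + one state per layer).
num_hidden_layers = 24                      # illustrative; comes from vision_config
vision_out_size = [1, 577, 1024]            # (batch, tokens, hidden), illustrative
dtype = torch.bfloat16                      # follows rbln_config.vision_tower.dtype

vision_out_buffer = [
    torch.empty(size=vision_out_size, dtype=dtype, device="cpu")
    for _ in range(num_hidden_layers + 2)
]
assert len(vision_out_buffer) == 26
```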
src/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -133,6 +133,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
{"name": "vision_tower"},
{"name": "language_model"},
]
+ _supports_non_fp32 = True

def __getattr__(self, __name: str) -> Any:
def redirect(func):
@@ -221,7 +222,7 @@ def _update_rbln_config(
(
"image_features",
[rbln_config.vision_tower.batch_size, selected_image_feature_dim, feature_size],
"float32",
rbln_config.dtype,
)
]
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
@@ -308,15 +309,15 @@ def get_image_features(
pooler_out_size = [pixel_values.shape[0] * pixel_values.shape[1], self.config.vision_config.hidden_size]
vision_out_buffer = []
for _ in range(self.config.vision_config.num_hidden_layers + 2):
- vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
- vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
+ vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu"))
+ vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu"))

projector_out_size = [
pixel_values.shape[0] * pixel_values.shape[1],
(self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
self.config.text_config.hidden_size,
]
- projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+ projector_out_buffer = [torch.empty(size=projector_out_size, dtype=self.rbln_config.dtype, device="cpu")]

if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
src/optimum/rbln/transformers/models/paligemma/modeling_paligemma.py
@@ -99,6 +99,7 @@ class RBLNPaliGemmaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
{"name": "vision_tower"},
{"name": "language_model"},
]
+ _supports_non_fp32 = True

def __getattr__(self, __name: str) -> Any:
def redirect(func):
@@ -152,7 +153,7 @@ def __post_init__(self, **kwargs):
artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
self.embed_tokens = self._create_embedding_layer()
self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
- self.multi_modal_projector = self._create_multi_modal_projector()
+ self.multi_modal_projector = self._create_multi_modal_projector().to(self.rbln_config.dtype)
self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])

return super().__post_init__(**kwargs)
@@ -239,7 +240,7 @@ def get_image_features(self, pixel_values: torch.Tensor):
self.config.vision_config.num_image_tokens,
self.config.vision_config.hidden_size,
]
- vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+ vision_output = torch.empty(size=vision_output_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu")
self.vision_tower(pixel_values, out=vision_output)
image_features = self.multi_modal_projector(vision_output)
image_features = image_features / (self.config.text_config.hidden_size**0.5)
@@ -383,6 +384,7 @@ class RBLNPaliGemmaModel(RBLNModel):
{"name": "vision_tower"},
{"name": "language_model"},
]
+ _supports_non_fp32 = True

def __post_init__(self, **kwargs):
self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
@@ -401,7 +403,7 @@ def __post_init__(self, **kwargs):
artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
self.embed_tokens = self._create_embedding_layer()
self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
- self.multi_modal_projector = self._create_multi_modal_projector()
+ self.multi_modal_projector = self._create_multi_modal_projector().to(self.rbln_config.dtype)
self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])

return super().__post_init__(**kwargs)
@@ -459,7 +461,7 @@ def get_image_features(self, pixel_values: torch.Tensor):
self.config.vision_config.num_image_tokens,
self.config.vision_config.hidden_size,
]
- vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+ vision_output = torch.empty(size=vision_output_size, dtype=self.rbln_config.vision_tower.dtype, device="cpu")
self.vision_tower(pixel_values, out=vision_output)
image_features = self.multi_modal_projector(vision_output)
image_features = image_features / (self.config.text_config.hidden_size**0.5)
src/optimum/rbln/transformers/models/siglip/modeling_siglip.py
@@ -65,6 +65,7 @@ class RBLNSiglipVisionModel(RBLNModel):
"""

_tp_support = False
+ _supports_non_fp32 = True

@classmethod
def _wrap_model_if_needed(
@@ -108,7 +109,7 @@ def _update_rbln_config(
rbln_config.image_height,
rbln_config.image_width,
],
"float32",
rbln_config.dtype,
)
]
)
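From the user side, `_supports_non_fp32 = True` is what lets these models compile with a non-fp32 dtype at all. A hypothetical loading sketch, assuming the public `from_pretrained` accepts an `rbln_config` dict whose `dtype` entry populates the `RBLNModelConfig.dtype` field used throughout this PR; the exact key name is an assumption, not confirmed by the diff:

```python
from optimum.rbln import RBLNBlip2ForConditionalGeneration

# Assumption: a string "dtype" key in rbln_config reaches RBLNModelConfig.dtype.
model = RBLNBlip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    export=True,                          # compile for the RBLN NPU at load time
    rbln_config={"dtype": "bfloat16"},    # previously effectively fp32-only
)
```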