[Model][VLM] Decouple weight loading logic for Paligemma #8269

Merged 3 commits on Sep 7, 2024
112 changes: 35 additions & 77 deletions vllm/model_executor/models/paligemma.py
@@ -1,3 +1,4 @@
+import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)

@@ -13,7 +14,7 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.gemma import GemmaModel
+from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
@@ -22,14 +23,10 @@
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import merge_multimodal_embeddings
+from .utils import filter_weights, merge_multimodal_embeddings

logger = init_logger(__name__)

-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.model": "language_model",
-}


class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
@@ -151,8 +148,8 @@ def __init__(self,
projection_dim=config.vision_config.projection_dim)

self.quant_config = quant_config
-        self.language_model = GemmaModel(config.text_config, cache_config,
-                                         quant_config)
+        self.language_model = GemmaForCausalLM(config.text_config,
+                                               cache_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
@@ -252,7 +249,8 @@ def forward(self,
vision_embeddings = vision_embeddings * (self.config.hidden_size**
-0.5)

-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
@@ -262,87 +260,47 @@ def forward(self,
else:
inputs_embeds = None

-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            None,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  None,
+                                                  inputs_embeds=inputs_embeds)

return hidden_states

-    # Copied from vllm/model_executor/models/gemma.py
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.language_model.embed_tokens,
-                                       hidden_states, sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)

-    # Copied from vllm/model_executor/models/gemma.py
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)

-    # Adapted from vllm/model_executor/models/gemma.py
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params = set()
-        for name, loaded_weight in weights:
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" not in name or self.vision_tower.shard_weight:
-                for (param_name, shard_name,
-                     shard_id) in stacked_params_mapping:
-                    if shard_name not in name:
-                        continue
-                    name = name.replace(shard_name, param_name)
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    # lm_head is not used in vllm as it is tied with
-                    # embed_token. To prevent errors, skip loading
-                    # lm_head.weight.
-                    if "lm_head.weight" in name:
-                        continue
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    use_default_weight_loading = True
-            else:
-                use_default_weight_loading = True
-
-            if use_default_weight_loading:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-
-            loaded_params.add(name)
-
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            logger.warning(
-                "Some weights are not initialized from checkpoints: %s",
-                unloaded_params)
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+
+        # load vision tower
+        vit_weights = filter_weights(vit_weights, "vision_tower")
+        self.vision_tower.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
+        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
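
Note on the new loading scheme: `weights` is a one-shot iterator, so `itertools.tee(weights, 3)` duplicates it before each component consumes its own copy, and `filter_weights` (imported from `.utils` above) narrows each copy to a single component before delegating to that component's own `load_weights`. The helper's body is not part of this diff; a minimal sketch of the behavior its call sites imply — keep only names under a given prefix and strip that prefix — could look like this (an illustrative reconstruction, not the actual code in vllm/model_executor/models/utils.py):

from typing import Iterable, Tuple

import torch


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    # Keep only weights whose name starts with `prefix`, stripping the
    # prefix so the submodule sees names relative to itself, e.g.
    # "vision_tower.vision_model.encoder.x" -> "vision_model.encoder.x".
    for name, loaded_weight in weights:
        head, _, tail = name.partition(".")
        if head == prefix:
            yield tail, loaded_weight

With this shape, `self.vision_tower.load_weights(vit_weights)` receives names relative to `SiglipVisionModel`, which is what lets each submodule own its loading logic instead of PaliGemma handling every checkpoint key itself.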
23 changes: 19 additions & 4 deletions vllm/model_executor/models/siglip.py
@@ -529,6 +529,12 @@ def forward(
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)

@@ -544,7 +550,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if layer_idx >= layer_count:
continue

-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
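
For context on the SigLIP change: when `shard_weight` is set, the checkpoint still stores separate `q_proj`/`k_proj`/`v_proj` tensors while the model holds a single fused `qkv_proj` parameter, so each checkpoint tensor must be loaded as one shard of the fused parameter. The `for`/`else` above relies on Python's loop-`else` semantics: the `else` branch runs only if the loop finished without `break`, i.e. the name matched no stacked mapping and falls through to the default loader. A standalone sketch of that control flow, using a hypothetical checkpoint name:

# (fused param, checkpoint shard name, shard id) -- same mapping as above.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

# Hypothetical checkpoint entry, for illustration only.
name = "vision_model.encoder.layers.0.self_attn.k_proj.weight"

for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name not in name:
        continue
    fused_name = name.replace(weight_name, param_name)
    print(f"load into {fused_name} as shard {shard_id!r}")
    break
else:
    # Only reached when no mapping matched (no break): default loading.
    print(f"load {name} with the default weight loader")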