Skip to content

Commit cdc72e3

Browse files
authored
[Model] Remap FP8 kv_scale in CommandR and DBRX (#9174)
1 parent 7627172 commit cdc72e3

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

vllm/model_executor/models/commandr.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
from vllm.model_executor.layers.vocab_parallel_embedding import (
4242
VocabParallelEmbedding)
4343
from vllm.model_executor.model_loader.weight_utils import (
44-
default_weight_loader, row_parallel_weight_loader)
44+
default_weight_loader, maybe_remap_kv_scale_name,
45+
row_parallel_weight_loader)
4546
from vllm.model_executor.sampling_metadata import SamplingMetadata
4647
from vllm.model_executor.utils import set_weight_attrs
4748
from vllm.sequence import IntermediateTensors
@@ -426,6 +427,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
426427
# Skip loading extra bias for GPTQ models.
427428
if name.endswith(".bias") and name not in params_dict:
428429
continue
430+
# Remapping the name of FP8 kv-scale.
431+
name = maybe_remap_kv_scale_name(name, params_dict)
432+
if name is None:
433+
continue
434+
429435
if is_pp_missing_parameter(name, self):
430436
continue
431437
param = params_dict[name]

vllm/model_executor/models/dbrx.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
1919
from vllm.model_executor.layers.vocab_parallel_embedding import (
2020
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
21-
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
21+
from vllm.model_executor.model_loader.weight_utils import (
22+
default_weight_loader, maybe_remap_kv_scale_name)
2223
from vllm.model_executor.sampling_metadata import SamplingMetadata
2324
from vllm.sequence import IntermediateTensors
2425
from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -425,6 +426,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
425426
weight_loader(param, loaded_weight, weight_name)
426427
break
427428
else:
429+
# Remapping the name of FP8 kv-scale.
430+
name = maybe_remap_kv_scale_name(name, params_dict)
431+
if name is None:
432+
continue
433+
428434
if is_pp_missing_parameter(name, self):
429435
continue
430436
param = params_dict[name]

0 commit comments

Comments
 (0)