vllm/config/vllm.py (6 changes: 0 additions & 6 deletions)
@@ -1264,12 +1264,6 @@ def _set_compile_ranges(self):
             computed_compile_ranges_split_points
         )

-    def recalculate_max_model_len(self, max_model_len: int):
-        # Can only be called in try_verify_and_update_config
-        model_config = self.model_config
-        max_model_len = model_config.get_and_verify_max_len(max_model_len)
-        self.model_config.max_model_len = max_model_len
-
     def try_verify_and_update_config(self):
         if self.model_config is None:
             return
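Review note: the deleted `recalculate_max_model_len` was a thin wrapper that forwarded to `ModelConfig.get_and_verify_max_len` and wrote the result back. After this PR, call sites that already hold a `ModelConfig` do both steps inline. A minimal sketch of the before/after, using a simplified stand-in for `ModelConfig` (the real vllm class verifies far more than the position-embedding ceiling shown here):

```python
# Simplified stand-in for ModelConfig; the real vllm class applies
# many more checks inside get_and_verify_max_len.
class ModelConfig:
    def __init__(self, max_model_len: int, max_position_embeddings: int):
        self.max_model_len = max_model_len
        self.max_position_embeddings = max_position_embeddings

    def get_and_verify_max_len(self, max_model_len: int) -> int:
        # Clamp the requested length to the model's trained ceiling.
        return min(max_model_len, self.max_position_embeddings)


model_config = ModelConfig(max_model_len=8192, max_position_embeddings=2048)

# Before: vllm_config.recalculate_max_model_len(512) did the two steps below.
# After: callers inline them on the ModelConfig they already hold.
model_config.max_model_len = model_config.get_and_verify_max_len(512)
assert model_config.max_model_len == 512
```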
vllm/model_executor/models/config.py (53 changes: 33 additions & 20 deletions)
@@ -113,8 +113,8 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None:

 class NomicBertModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config

         assert config.__class__.__name__ == "NomicBertConfig"
         assert config.activation_function in ["swiglu", "gelu"]
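The renamed entry point narrows the hook's surface: `NomicBertModelConfig` now receives the `ModelConfig` directly instead of reaching through the whole `VllmConfig`. A hedged sketch of the base-class shape this change implies (the actual `VerifyAndUpdateConfig` interface in vllm may differ):

```python
class VerifyAndUpdateConfig:
    """Sketch only; anything beyond the names in the diff is an assumption."""

    @staticmethod
    def verify_and_update_config(vllm_config) -> None:
        # Hooks that need engine-wide state keep the wide entry point.
        raise NotImplementedError

    @staticmethod
    def verify_and_update_model_config(model_config) -> None:
        # Hooks that only touch model-level settings override this
        # narrower entry point and never see the whole VllmConfig.
        raise NotImplementedError
```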
@@ -137,6 +137,10 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         config.intermediate_size = config.n_inner
         config.hidden_size = config.n_embd
         config.num_hidden_layers = config.n_layer
+        model_config.model_arch_config.hidden_size = config.hidden_size
+        model_config.model_arch_config.total_num_hidden_layers = (
+            config.num_hidden_layers
+        )

         head_dim = config.hidden_size // config.num_attention_heads
         max_trained_positions = getattr(config, "max_trained_positions", 2048)
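The four added lines keep `model_arch_config` consistent with the rewritten `hf_config`: NomicBert stores its dimensions under GPT-2-style names (`n_embd`, `n_layer`), and once those are copied onto the canonical attributes, the cached copies on `model_arch_config` must be refreshed too, or later readers would see stale values. A toy illustration of the staleness this avoids (field names taken from the diff; the surrounding classes are stand-ins):

```python
from dataclasses import dataclass


@dataclass
class ModelArchConfig:  # stand-in for the cached arch view
    hidden_size: int
    total_num_hidden_layers: int


@dataclass
class HfConfig:  # stand-in for NomicBertConfig
    n_embd: int = 768
    n_layer: int = 12
    hidden_size: int = 0
    num_hidden_layers: int = 0


hf = HfConfig()
arch = ModelArchConfig(
    hidden_size=hf.hidden_size,
    total_num_hidden_layers=hf.num_hidden_layers,
)

# Rewriting hf_config alone leaves the cached view stale...
hf.hidden_size = hf.n_embd
hf.num_hidden_layers = hf.n_layer
assert arch.hidden_size != hf.hidden_size

# ...so the diff refreshes the cache right after the rewrite.
arch.hidden_size = hf.hidden_size
arch.total_num_hidden_layers = hf.num_hidden_layers
assert arch.hidden_size == 768
```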
@@ -153,56 +157,65 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         # The context extension uses vllm style rope_theta and rope_parameters.
         # See #17785 #18755
         if (
-            not vllm_config.model_config.hf_overrides
-            and vllm_config.model_config.original_max_model_len is None
+            not model_config.hf_overrides
+            and model_config.original_max_model_len is None
         ):
             # Default
             # Reset max_model_len to max_trained_positions.
             # nomic-embed-text-v2-moe the length is set to 512
             # by sentence_bert_config.json.
-            max_model_len_before = vllm_config.model_config.max_model_len
-            max_model_len = min(
-                vllm_config.model_config.max_model_len, max_trained_positions
-            )
+            max_model_len_before = model_config.max_model_len
+            max_model_len = min(model_config.max_model_len, max_trained_positions)

-            vllm_config.recalculate_max_model_len(max_model_len)
-            logger.warning(
-                "Nomic context extension is disabled. "
-                "Changing max_model_len from %s to %s. "
-                "To enable context extension, see: "
-                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
-                max_model_len_before,
-                vllm_config.model_config.max_model_len,
+            model_config.max_model_len = model_config.get_and_verify_max_len(
+                max_model_len
             )
+
+            if model_config.max_model_len != max_model_len_before:
+                logger.warning(
+                    "Nomic context extension is disabled. "
+                    "Changing max_model_len from %s to %s. "
+                    "To enable context extension, see: "
+                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
+                    max_model_len_before,
+                    model_config.max_model_len,
+                )
         else:
             # We need to re-verify max_model_len to avoid lengths
             # greater than position_embedding.
-            model_config = vllm_config.model_config
             hf_text_config = model_config.hf_text_config

             if isinstance(model_config.hf_overrides, dict):
                 # hf_overrides_kw
                 max_model_len = model_config.hf_overrides.get(
-                    "max_model_len", vllm_config.model_config.max_model_len
+                    "max_model_len", model_config.max_model_len
                 )
             else:
                 # hf_overrides_fn
                 # This might be overridden by sentence_bert_config.json.
-                max_model_len = vllm_config.model_config.max_model_len
+                max_model_len = model_config.max_model_len

             # reset hf_text_config for recalculate_max_model_len.
             if hasattr(hf_text_config, "max_model_len"):
                 delattr(hf_text_config, "max_model_len")
             hf_text_config.max_position_embeddings = max_trained_positions
             hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

+            # Update the cached derived_max_model_len to enforce the limit
+            model_config.model_arch_config.derived_max_model_len_and_key = (
+                float(max_trained_positions),
+                "max_position_embeddings",
+            )
+
             # The priority of sentence_bert_config.json is higher
             # than max_position_embeddings
             encoder_config = deepcopy(model_config.encoder_config)
             encoder_config.pop("max_seq_length", None)
             model_config.encoder_config = encoder_config

-            vllm_config.recalculate_max_model_len(max_model_len)
+            model_config.max_model_len = model_config.get_and_verify_max_len(
+                max_model_len
+            )


 class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
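Two behavioral points in the final hunk are easy to miss. First, the warning in the default branch is now emitted only when `get_and_verify_max_len` actually changed the value, where the old code logged unconditionally. Second, in the `else` branch the PR overwrites the cached `derived_max_model_len_and_key` so that the subsequent `get_and_verify_max_len` call derives its ceiling from `max_trained_positions`. A self-contained sketch of the new default-branch flow (the stand-in `get_and_verify_max_len` below only clamps; the real method verifies much more):

```python
import logging

logger = logging.getLogger(__name__)


class ModelConfig:  # stand-in; mirrors only what the sketch needs
    def __init__(self, max_model_len: int, max_position_embeddings: int):
        self.max_model_len = max_model_len
        self.max_position_embeddings = max_position_embeddings

    def get_and_verify_max_len(self, max_model_len: int) -> int:
        return min(max_model_len, self.max_position_embeddings)


def clamp_to_trained_positions(model_config: ModelConfig,
                               max_trained_positions: int) -> None:
    max_model_len_before = model_config.max_model_len
    max_model_len = min(model_config.max_model_len, max_trained_positions)
    model_config.max_model_len = model_config.get_and_verify_max_len(max_model_len)
    # New behavior: warn only if the clamp actually changed the value.
    if model_config.max_model_len != max_model_len_before:
        logger.warning("Changing max_model_len from %s to %s.",
                       max_model_len_before, model_config.max_model_len)


# 8192 requested, model trained to 2048: clamps and emits one warning.
clamp_to_trained_positions(ModelConfig(8192, 4096), max_trained_positions=2048)
```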