Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + I<sup>E+</sup> | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
Expand Down
38 changes: 38 additions & 0 deletions examples/offline_inference/vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,42 @@ def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData:
)


# GLM-OCR
def run_glm_ocr(questions: list[str], modality: str) -> ModelRequestData:
model_name = "zai-org/GLM-OCR"

engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
mm_processor_kwargs={
"size": {"shortest_edge": 12544, "longest_edge": 47040000},
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
enforce_eager=True,
)

if modality == "image":
placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
elif modality == "video":
placeholder = "<|begin_of_video|><|video|><|end_of_video|>"

prompts = [
(
"[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n"
f"{placeholder}"
f"{question}<|assistant|>assistant\n"
)
for question in questions
]

return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)


# H2OVL-Mississippi
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
Expand Down Expand Up @@ -1962,6 +1998,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"glm4_1v": run_glm4_1v,
"glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8,
"glm_ocr": run_glm_ocr,
"h2ovl_chat": run_h2ovl,
"hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
Expand Down Expand Up @@ -2013,6 +2050,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:

MODELS_NEED_VIDEO_METADATA = [
"glm4_1v",
"glm_ocr",
"glm4_5v",
"glm4_5v_fp8",
"molmo2",
Expand Down
14 changes: 14 additions & 0 deletions tests/models/multimodal/generation/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,20 @@
],
marks=[large_gpu_mark(min_gb=32)],
),
"glm_ocr": VLMTestInfo(
models=["zai-org/GLM-OCR"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>",
video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>",
max_model_len=2048,
max_num_seqs=2,
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
auto_cls=AutoModelForImageTextToText,
marks=[large_gpu_mark(min_gb=32)],
),
"h2ovl": VLMTestInfo(
models=[
"h2oai/h2ovl-mississippi-800m",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,19 @@
"use_processor": True,
"question": "What is the content of each image?",
},
"glm_ocr": {
"model_name": "zai-org/GLM-OCR",
"interface": "llm_generate",
"max_model_len": 131072,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"use_processor": True,
"question": "Text Recognition:",
},
"keye_vl": {
"model_name": "Kwai-Keye/Keye-VL-8B-Preview",
"interface": "llm_generate",
Expand Down
1 change: 1 addition & 0 deletions tests/models/multimodal/processing/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def glmasr_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
"ernie4_5_moe_vl": qwen3_vl_patch_mm_data,
"glm4v": glm4_1v_patch_mm_data,
"glm4v_moe": glm4_1v_patch_mm_data,
"glm_ocr": glm4_1v_patch_mm_data,
"glmasr": glmasr_patch_mm_data,
"molmo2": qwen3_vl_patch_mm_data,
"qwen3_vl": qwen3_vl_patch_mm_data,
Expand Down
12 changes: 11 additions & 1 deletion tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ def check_available_online(
"Glm4MoeLiteForCausalLM": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0.dev",
is_available_online=False,
),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo(
Expand Down Expand Up @@ -707,6 +706,11 @@ def check_available_online(
),
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"),
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V"),
"GlmOcrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-OCR",
is_available_online=False,
min_transformers_version="5.0.0.dev",
),
"H2OVLChatModel": _HfExamplesInfo(
"h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
Expand Down Expand Up @@ -1053,7 +1057,13 @@ def check_available_online(
"Glm4MoeLiteMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash",
speculative_model="zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0.dev",
),
"GlmOcrMTPModel": _HfExamplesInfo(
"zai-org/GLM-OCR",
speculative_model="zai-org/GLM-OCR",
Comment thread
zRzRzRzRzRzRzR marked this conversation as resolved.
is_available_online=False,
min_transformers_version="5.0.0.dev",
),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
Expand Down
12 changes: 12 additions & 0 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"mimo_mtp",
"glm4_moe_mtp",
"glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp",
Expand Down Expand Up @@ -221,6 +222,17 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
}
)

if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)

if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
Expand Down
101 changes: 99 additions & 2 deletions vllm/model_executor/models/glm4.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,22 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType

from .interfaces import SupportsLoRA, SupportsPP
from .llama import LlamaMLP as Glm4MLP
from .llama import LlamaModel
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
is_pp_missing_parameter,
maybe_prefix,
)


class Glm4Attention(nn.Module):
Expand Down Expand Up @@ -78,7 +87,15 @@ def __init__(
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)

rope_params = getattr(config, "rope_parameters", None)
if isinstance(rope_params, dict) and "partial_rotary_factor" in rope_params:
config.rope_parameters.setdefault(
"partial_rotary_factor", rope_params["partial_rotary_factor"]
)
else:
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)

self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
Expand Down Expand Up @@ -220,6 +237,73 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config=vllm_config, prefix=prefix, layer_type=Glm4DecoderLayer
)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
Expand Down Expand Up @@ -293,3 +377,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)


def get_spec_layer_idx_from_weight_name(
config: Glm4Config, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if f"layers.{layer_idx + i}." in weight_name:
return layer_idx + i
return None
5 changes: 3 additions & 2 deletions vllm/model_executor/models/glm4_1v.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
"""Inference-only GLM-4.1V & GLM-4.6V-Flash, AutoGLM-Phone-9B model
compatible with HuggingFace weights."""

import itertools
import math
Expand Down Expand Up @@ -1418,7 +1419,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=maybe_prefix(prefix, "visual"),
)

if config.model_type == "glm4v":
if config.model_type in ("glm4v", "glm_ocr"):
architectures = ["Glm4ForCausalLM"]
elif config.model_type == "glm4v_moe":
architectures = ["Glm4MoeForCausalLM"]
Expand Down
Loading