
Commit 6fcc39b

support qwen3-vl series
1 parent 03792ad commit 6fcc39b

15 files changed: +765 -46 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 *.so
 *.swp
 *.tmp
+*.patch
 
 # IDE / OS

check_model.py

Whitespace-only changes.

convert_hf_to_gguf.py

Lines changed: 236 additions & 1 deletion

@@ -3538,6 +3538,147 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.has_vision_encoder
+        assert self.hparams_vision is not None
+
+        # Compute image_size if not present
+        if "image_size" not in self.hparams_vision:
+            # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
+            num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
+            patch_size = self.hparams_vision.get("patch_size", 16)
+            # num_position_embeddings = (image_size / patch_size) ** 2
+            # So image_size = sqrt(num_position_embeddings) * patch_size
+            import math
+            image_size = int(math.sqrt(num_pos) * patch_size)
+            self.hparams_vision["image_size"] = image_size
+
+        # Rename config values for compatibility
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+
+        self.deepstack_layers: list[int] = list(self.hparams_vision.get("deepstack_visual_indexes", []))
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
+
+        if self.hparams_vision is not None:
+            merge_size = self.hparams_vision.get("spatial_merge_size")
+            if merge_size is not None:
+                self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
+
+            hidden_act = (self.hparams_vision.get("hidden_act") or "").lower()
+            if hidden_act:
+                if "gelu" in hidden_act:
+                    self.gguf_writer.add_vision_use_gelu(True)
+                elif hidden_act == "silu":
+                    self.gguf_writer.add_vision_use_silu(True)
+                else:
+                    raise ValueError(f"Unsupported hidden_act: {hidden_act}")
+
+        # Use text config's rms_norm_eps for vision attention layernorm eps (similar to qwen2vl)
+        rms_norm_eps = self.global_config.get("rms_norm_eps")
+        if rms_norm_eps is None:
+            # Try text_config
+            text_config = self.global_config.get("text_config", {})
+            rms_norm_eps = text_config.get("rms_norm_eps", 1e-6)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+        if self.deepstack_layers:
+            self.gguf_writer.add_vision_deepstack_layers(self.deepstack_layers)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip text model tensors - they go in the text model file
+        if name.startswith("model.language_model.") or name.startswith("lm_head."):
+            return []
+
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.", 1)
+
+        if name.startswith("visual.deepstack_merger_list."):
+            prefix, rest = name.split(".", maxsplit=3)[2:]
+            idx = int(prefix)
+            target = rest
+
+            tensor_type: gguf.MODEL_TENSOR
+            if target.startswith("norm."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc1."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc2."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
+                suffix = target.split(".", 1)[1]
+            else:
+                raise ValueError(f"Unexpected deepstack tensor: {name}")
+
+            new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
+            return [(new_name, data_torch)]
+
+        if name.startswith("visual.merger."):
+            suffix = name.split(".", 2)[2]
+            if suffix.startswith("linear_fc"):
+                fc_idx_str, tail = suffix.split(".", 1)
+                fc_num = int(fc_idx_str.replace("linear_fc", ""))
+                # Qwen3VLMoe has linear_fc1 and linear_fc2
+                # Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
+                if fc_num == 1:
+                    fc_idx = 0
+                elif fc_num == 2:
+                    fc_idx = 2
+                else:
+                    raise ValueError(f"unexpected fc index {fc_num} in {name}")
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
+            elif suffix.startswith("norm."):
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
+            else:
+                raise ValueError(f"Unexpected merger tensor: {name}")
+            return [(new_name, data_torch)]
+
+        if name == "visual.patch_embed.proj.weight":
+            # split Conv3D into Conv2Ds along temporal dimension
+            c1, c2, kt, _, _ = data_torch.shape
+            del c1, c2
+            if kt != 2:
+                raise ValueError("Current implementation only supports temporal_patch_size of 2")
+            return [
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+            ]
+
+        if name == "visual.patch_embed.proj.bias":
+            # Include the bias - it's used by the C++ code
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
+
+        if name.startswith("visual."):
+            if ".qkv." in name:
+                if data_torch.ndim == 2:
+                    c3, _ = data_torch.shape
+                else:
+                    c3 = data_torch.shape[0]
+                if c3 % 3 != 0:
+                    raise ValueError(f"Unexpected QKV shape for {name}: {data_torch.shape}")
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                base = name.replace("qkv", "{placeholder}")
+                return [
+                    (self.map_tensor_name(base.format(placeholder="q")), wq),
+                    (self.map_tensor_name(base.format(placeholder="k")), wk),
+                    (self.map_tensor_name(base.format(placeholder="v")), wv),
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Fall back to parent class for other tensors
+        return super().modify_tensors(data_torch, name, bid)
+
 @ModelBase.register("InternVisionModel")
 class InternVisionModel(MmprojModel):
     def set_gguf_parameters(self):
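
A quick way to sanity-check the conversion rules in this hunk is to replay them on dummy tensors. The following standalone Python sketch (not part of the commit; shapes are illustrative) exercises the image_size arithmetic, the fused-QKV split, and the Conv3D-to-Conv2D temporal slicing:

    import torch

    # image_size recovery: num_position_embeddings = (image_size / patch_size) ** 2,
    # so image_size = sqrt(num_position_embeddings) * patch_size = sqrt(2304) * 16
    num_pos, patch_size = 2304, 16
    image_size = int(num_pos ** 0.5 * patch_size)
    assert image_size == 768  # 48 * 16

    # fused QKV weight: rows are [q; k; v], each block c rows tall
    c = 8
    qkv = torch.randn(3 * c, 16)
    wq, wk, wv = qkv[:c], qkv[c:c * 2], qkv[c * 2:]
    assert torch.equal(torch.cat([wq, wk, wv]), qkv)

    # Conv3D patch embedding (out_ch, in_ch, kt=2, kh, kw) -> two Conv2D slices
    w = torch.randn(32, 3, 2, 16, 16)
    w0, w1 = w[:, :, 0, ...], w[:, :, 1, ...]
    assert w0.shape == (32, 3, 16, 16) and w1.shape == (32, 3, 16, 16)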
@@ -3678,7 +3819,43 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         name = name.replace("language_model.", "") # InternVL
-        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
+
+        # handle aggregated expert tensors
+        # GGUF stores dimensions reversed from PyTorch, so:
+        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
+        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
+        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
+            mapped = f"{name}.weight" if not name.endswith(".weight") else name
+            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
+            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
+            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
+            permuted = data_torch.permute(0, 2, 1).contiguous()
+            return [(self.map_tensor_name(mapped), permuted)]
+
+        if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
+            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
+                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
+            split_dim = data_torch.shape[-1] // 2
+            gate = data_torch[..., :split_dim].contiguous()
+            up = data_torch[..., split_dim:].contiguous()
+            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
+            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
+            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
+            base_name = name.removesuffix(".weight")
+            base = base_name.rsplit('.', 1)[0]
+            mapped_gate = f"{base}.gate_proj.weight"
+            mapped_up = f"{base}.up_proj.weight"
+            perm_gate = gate.permute(0, 2, 1).contiguous()
+            perm_up = up.permute(0, 2, 1).contiguous()
+            return [
+                (self.map_tensor_name(mapped_gate), perm_gate),
+                (self.map_tensor_name(mapped_up), perm_up),
+            ]
+
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
             # skip visual tensors
             return []
         if name.find("experts") != -1:
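
The shape comments in this hunk are easiest to follow on a concrete tensor. Here is a standalone sketch (not part of the commit; toy sizes stand in for n_expert=128, n_embd=2048, n_ff_exp=768) of the same split-and-permute:

    import torch

    n_expert, n_embd, n_ff = 4, 6, 10
    gate_up = torch.randn(n_expert, n_embd, 2 * n_ff)

    # split the last axis in half: first half is gate, second half is up
    split_dim = gate_up.shape[-1] // 2
    gate = gate_up[..., :split_dim].contiguous()  # (n_expert, n_embd, n_ff)
    up = gate_up[..., split_dim:].contiguous()

    # permute to (n_expert, n_ff, n_embd); since GGUF writes dimensions reversed,
    # this lands as ne = {n_embd, n_ff, n_expert}, as the comments above require
    perm_gate = gate.permute(0, 2, 1).contiguous()
    assert perm_gate.shape == (n_expert, n_ff, n_embd)

    # down_proj starts as (n_expert, n_ff, n_embd) and permutes the other way,
    # giving ne = {n_ff, n_embd, n_expert} after the GGUF reversal
    down = torch.randn(n_expert, n_ff, n_embd)
    perm_down = down.permute(0, 2, 1).contiguous()
    assert perm_down.shape == (n_expert, n_embd, n_ff)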
@@ -3826,6 +4003,64 @@ def set_vocab(self):
         super().set_vocab()
 
 
+@ModelBase.register("Qwen3VLForConditionalGeneration")
+class Qwen3VLTextModel(Qwen3Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3VL
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        rope_scaling = text_config.get("rope_scaling") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLMoeTextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        rope_scaling = text_config.get("rope_scaling") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("GPT2LMHeadModel")
 class GPT2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT2
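
Both text-model classes share the same MRoPE handling; here is the padding step in isolation, as a minimal sketch (not part of the commit; the section values are illustrative, not taken from a real config):

    # HF configs list mrope_section as [time, height, width]; the converter
    # zero-pads to four entries [time, height, width, extra] before writing
    mrope_section = [24, 20, 20]
    while len(mrope_section) < 4:
        mrope_section.append(0)
    assert mrope_section[:4] == [24, 20, 20, 0]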

ggml/src/ggml-cpu/ops.cpp

Lines changed: 1 addition & 0 deletions

@@ -5509,6 +5509,7 @@ static void ggml_mrope_cache_init(
         }
 
         float theta = theta_t;
+
         if (sector >= sections[0] && sector < sec_w) {
             theta = theta_h;
         }
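
For orientation, the code around this hunk assigns each rotary dimension "sector" to one of the MRoPE axes. Below is a hedged Python paraphrase of that selection, inferred from the visible context only (sec_e and the theta_w/theta_e branches are assumptions, not shown in the diff):

    def pick_theta(sector: int, sections: list[int],
                   theta_t: float, theta_h: float, theta_w: float, theta_e: float) -> float:
        # sectors below sections[0] rotate with the temporal angle; the next
        # sections[1] sectors use the height angle, and so on per axis
        sec_w = sections[0] + sections[1]
        sec_e = sec_w + sections[2]
        if sector < sections[0]:
            return theta_t
        if sector < sec_w:
            return theta_h
        if sector < sec_e:
            return theta_w
        return theta_e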
