diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e64756a74ab..7ad88314af9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -552,7 +552,6 @@ def prepare_tensors(self): break for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): - # TODO: why do we squeeze here? # data = data_torch.squeeze().numpy() data = data_torch.numpy() @@ -646,6 +645,9 @@ def prepare_tensors(self): # n_dims is implicit in the shape logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + # Debug: print all tensors being added + print(f"DEBUG ADD: {new_name}, shape={data.shape}") + self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) def set_type(self): @@ -3814,6 +3816,313 @@ def prepare_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("Ernie4_5_VLMoeForConditionalGeneration") +class Ernie4_5VLMoeModel(Ernie4_5MoeModel): + model_arch = gguf.MODEL_ARCH.ERNIE4_5_VL_MOE + _experts: list[dict[str, Tensor]] | None = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._experts = [{} for _ in range(self.block_count)] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # Handle list-based expert configurations by taking the first value + moe_num_experts = self.hparams["moe_num_experts"] + if isinstance(moe_num_experts, list): + moe_num_experts = moe_num_experts[0] + self.gguf_writer.add_expert_count(moe_num_experts) + + self.gguf_writer.add_expert_used_count(self.hparams["moe_k"]) + self.gguf_writer.add_interleave_moe_layer_step(self.hparams["moe_layer_interval"]) + + moe_layer_start_index = self.hparams["moe_layer_start_index"] + if isinstance(moe_layer_start_index, list): + moe_layer_start_index = moe_layer_start_index[0] + self.gguf_writer.add_leading_dense_block_count(moe_layer_start_index) + + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + if isinstance(moe_intermediate_size, list): + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0]) + if len(moe_intermediate_size) > 1: + self.gguf_writer.add_vision_expert_feed_forward_length(moe_intermediate_size[1]) + else: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + + if (shared_expert_count := self.hparams.get('moe_num_shared_experts')) is not None: + self.gguf_writer.add_expert_shared_count(shared_expert_count) + if shared_expert_count > 0 and (shared_expert_intermediate_size := self.hparams.get('intermediate_size')) is not None and (num_key_value_heads := self.hparams.get('num_key_value_heads')) is not None: + self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size // num_key_value_heads) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision and multimodal tensors - they are not part of the text model + if name.startswith("vision_model") or name.startswith("resampler_model") or \ + name.startswith("model.vision_model") or name.startswith("model.resampler_model") or \ + name.endswith(".rotary_emb.original_inv_freq") or name.endswith(".rotary_emb.inv_freq"): + return + + # todo(megemini): gate_inp weight/weight_1 + # weight + if name.endswith(".mlp.gate.weight") or name.endswith(".mlp.gate.weight_1"): + if name.endswith(".mlp.gate.weight_1"): + name = name.replace(".mlp.gate.weight_1", ".mlp.gate.vision.weight") + + data_torch = data_torch.t() + # Extract bid from name if not provided + if bid is None: + match = re.search(r"model\.layers\.(\d+)", name) + if match: + bid = int(match.group(1)) + # todo(megemini): + logger.info("Processing gate.weight/weight_1: %s -> shape %s", name, data_torch.shape) + # Map the tensor name and ensure it has .weight suffix + mapped_name = self.map_tensor_name(name) + + yield (mapped_name, data_torch) + return + + # todo(megemini): e_score_correction.bias/bias_1 for weight/weight_1 + if name.endswith(".mlp.moe_statics.e_score_correction_bias"): + name_text = name.replace("e_score_correction_bias", "e_score_correction.bias") + data_torch_text = data_torch[0, :] + + name_vision = name.replace("e_score_correction_bias", "e_score_correction.vision.bias") + data_torch_vision = data_torch[1, :] + + yield (self.map_tensor_name(name_text), data_torch_text) + yield (self.map_tensor_name(name_vision), data_torch_vision) + return + + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["moe_num_experts"] + + # Handle n_experts being a list (for models with multiple expert groups) + if isinstance(n_experts, list): + total_experts = sum(n_experts) + else: + total_experts = n_experts + + assert bid is not None + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + # Only merge routed experts (not shared experts) + # Total tensors = total_experts * 3 (gate, up, down) + if len(self._experts[bid]) >= total_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # For models with multiple expert groups of different sizes, + for w_name in ["gate_proj", "up_proj", "down_proj"]: + # Collect all experts for this weight type + expert_data: dict[int, Tensor] = {} + for xid in range(total_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + if ename in self._experts[bid]: + expert_data[xid] = self._experts[bid][ename] + del self._experts[bid][ename] + + if not expert_data: + continue + + # Group experts by shape (to handle different intermediate sizes) + shape_groups: dict[tuple[int, ...], list[tuple[int, Tensor]]] = {} + for xid, tensor in expert_data.items(): + shape_key = tuple(tensor.shape) + if shape_key not in shape_groups: + shape_groups[shape_key] = [] + shape_groups[shape_key].append((xid, tensor)) + + # For each shape group, stack the experts + # For ERNIE-4.5-VL with multiple expert groups of different sizes, + # we need to save them separately as llama.cpp doesn't support mixed sizes yet + if len(shape_groups) > 1: + # Sort shape groups by number of experts (descending) + sorted_groups = sorted(shape_groups.items(), key=lambda x: len(x[1]), reverse=True) + + for group_idx, (shape_key, expert_list) in enumerate(sorted_groups): + # Sort by expert ID to maintain order + expert_list.sort(key=lambda x: x[0]) + datas = [tensor for _, tensor in expert_list] + + data_torch = torch.stack(datas, dim=0) + + # Use group suffix for additional groups + if group_idx == 0: + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + else: + merged_name = f"model.vision.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + else: + # Single shape - stack all experts + expert_list = list(shape_groups.values())[0] + expert_list.sort(key=lambda x: x[0]) + datas = [tensor for _, tensor in expert_list] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + for tensor_tuple in tensors: + yield tensor_tuple + return + else: + return + yield (self.map_tensor_name(name), data_torch) + return + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("Ernie4_5_VLMoeForConditionalGeneration") +class Ernie4_5VLMoeVisionModel(MmprojModel): + # Resampler tensor name mapping: HF name -> GGUF name + _resampler_mapping = { + "model.resampler_model.spatial_linear.0": "mm.0", + "model.resampler_model.spatial_linear.2": "mm.2", + "model.resampler_model.spatial_linear.3": "mm.3", + "model.resampler_model.temporal_linear.0": "mm_temp.0", + "model.resampler_model.temporal_linear.2": "mm_temp.2", + "model.resampler_model.temporal_linear.3": "mm_temp.3", + "model.resampler_model.mlp": "mm.mlp", + "model.resampler_model.after_norm": "mm.norm", + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + # Set default vision parameters for ERNIE-4.5-VL + if "image_size" not in self.hparams_vision: + self.hparams_vision["image_size"] = 448 # default for ERNIE-4.5-VL + if "patch_size" not in self.hparams_vision: + self.hparams_vision["patch_size"] = 14 + if "hidden_size" not in self.hparams_vision: + self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim", 1280) + if "intermediate_size" not in self.hparams_vision: + self.hparams_vision["intermediate_size"] = self.hparams_vision.get("mlp_ratio", 4) * self.hparams_vision["hidden_size"] + if "num_attention_heads" not in self.hparams_vision: + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads", 16) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ERNIE45VLMOE) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6)) + # ERNIE-VL uses quick_gelu activation (C++ default when neither use_gelu nor use_silu is set) + ffn_op = self.hparams_vision.get("hidden_act", "quick_gelu") + if ffn_op == "gelu": + self.gguf_writer.add_vision_use_gelu(True) + elif ffn_op == "silu": + self.gguf_writer.add_vision_use_silu(True) + # quick_gelu: don't set either flag, C++ defaults to FFN_GELU_QUICK + + def tensor_force_quant(self, name, new_name, bid, n_dims): + # Handle resampler tensors: bias should be F32, weights F16 + # new_name is already mapped by modify_tensors (e.g., "mm.0.weight", "mm.0.bias") + if new_name.startswith("mm.") or new_name.startswith("mm_"): + if new_name.endswith(".bias"): + return gguf.GGMLQuantizationType.F32 + else: + return gguf.GGMLQuantizationType.F16 + # Let parent handle other tensors + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def prepare_tensors(self): + # Call parent prepare_tensors - resampler tensors will be handled by modify_tensors + # and their types will be controlled by tensor_force_quant + super().prepare_tensors() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Handle resampler tensors with manual mapping + for hf_prefix, gguf_prefix in self._resampler_mapping.items(): + if name.startswith(hf_prefix): + suffix = name[len(hf_prefix):] # e.g. ".weight" or ".bias" + new_name = gguf_prefix + suffix + print(f"DEBUG: Resampler mapping: {name} -> {new_name}, shape={data_torch.shape}") + # Yield the tensor - it will be handled by prepare_tensors + yield (new_name, data_torch) + return + + # Debug: print all model.* tensors that are being skipped + if name.startswith("model."): + print(f"DEBUG: Skipping model tensor: {name}") + + # Handle vision encoder tensors + if name.startswith("vision_model."): + # Split fused QKV into separate Q, K, V + if ".attn.qkv." in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + yield from super().modify_tensors(wq, name.replace("qkv", "q"), bid) + yield from super().modify_tensors(wk, name.replace("qkv", "k"), bid) + yield from super().modify_tensors(wv, name.replace("qkv", "v"), bid) + # Split Conv3D patch_embed into Conv2Ds (similar to QWEN2VL) + elif 'patch_embed.proj.weight' in name: + print(f"DEBUG: patch_embed.proj.weight shape = {data_torch.shape}") + if data_torch.ndim == 5: + # Conv3D: [out_channels, in_channels, 2, height, width] for spatial merge + c1, c2, kt, kh, kw = data_torch.shape + del c1, c2, kh, kw # unused + assert kt == 2, "Current implementation only supports spatial_merge_size of 2" + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]) + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]) + elif data_torch.ndim == 4: + # Conv2D: [out_channels, in_channels, height, width] - use as is + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch) + elif data_torch.ndim == 2: + # Linear projection: [out_features, in_features] = (1280, 588) + # Convert to Conv2D: (out_channels, in_channels, height, width) = (1280, 3, 14, 14) + # ERNIE-VL uses a linear layer, but we convert it to Conv2D for compatibility + out_ch, in_ch = data_torch.shape + patch_size = 14 + channels = 3 + assert in_ch == channels * patch_size * patch_size, \ + f"Expected in_features={channels * patch_size * patch_size}, got {in_ch}" + # Reshape: (out_ch, in_ch) -> (out_ch, channels, patch_size, patch_size) + # Note: data is stored as (out_ch, in_ch) = (1280, 588) + # We need to reshape to (out_ch, channels, patch_size, patch_size) = (1280, 3, 14, 14) + # The memory layout is contiguous, so we can view directly + data_conv = data_torch.view(out_ch, channels, patch_size, patch_size) + print(f"DEBUG: Converted linear to Conv2D: {data_conv.shape}") + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_conv) + else: + raise ValueError(f"Unexpected patch_embed.proj.weight shape: {data_torch.shape}") + # Handle patch_embed bias - it's used by the C++ code + # NOTE: The conv_2d output is f32 because inp_raw is created as f32 in build_inp_raw() + # So we must keep bias as f32 to match the output type of conv_2d + elif 'patch_embed.proj.bias' in name: + # Keep as f32 to match output type + if data_torch.dtype != torch.float32: + data_torch = data_torch.to(torch.float32) + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch) + else: + yield from super().modify_tensors(data_torch, name, bid) + # Skip text model tensors (model.* but not model.resampler_model.* which is handled above) + elif name.startswith("model.") or name.startswith("ernie."): + return + else: + yield from super().modify_tensors(data_torch, name, bid) + @ModelBase.register( "Qwen2VLModel", diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f759e2d5883..2baf875fc7b 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -252,6 +252,7 @@ #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 #define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000 +#define GGML_ROPE_TYPE_ERNIE3D 72 // binary: 1001000, ERNIE-VL 3D RoPE (NORMAL rotation + interleaved h/w freq) #define GGML_MROPE_SECTIONS 4 diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ce15b18ce0e..3b608d682fa 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5651,6 +5651,43 @@ static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * } } +static void ggml_ernie3d_rope_cache_init( + float theta_base_t, float theta_base_h, float theta_base_w, + int sections[4], + float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // n_hw = sections[0] + sections[1] = total number of interleaved h/w frequencies + int n_hw = sections[0] + sections[1]; + + float theta_accum = 1.0f; // accumulated theta_scale^freq_idx + + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + int freq_idx = (int)(i0 / 2); + const float ff = freq_factors ? freq_factors[freq_idx] : 1.0f; + + float theta; + if (freq_idx < n_hw) { + if (freq_idx % 2 == 0) { + // even freq index -> height position + theta = theta_base_h * theta_accum; + } else { + // odd freq index -> width position + theta = theta_base_w * theta_accum; + } + } else { + // temporal position + theta = theta_base_t * theta_accum; + } + + rope_yarn( + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + ); + cache[i0 + 1] *= sin_sign; + + theta_accum *= theta_scale; + } +} + template //float or ggml_fp16_t static void ggml_compute_forward_rope_flt( const ggml_compute_params * params, @@ -5723,7 +5760,7 @@ static void ggml_compute_forward_rope_flt( if (is_vision) { GGML_ASSERT(n_dims == ne0/2); } - + const bool is_ernie3d = mode == GGML_ROPE_TYPE_ERNIE3D; const float * freq_factors = NULL; if (src2 != NULL) { GGML_ASSERT(src2->type == GGML_TYPE_F32); @@ -5745,6 +5782,14 @@ static void ggml_compute_forward_rope_flt( if (!mrope_used) { const int64_t p = pos[i2]; ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } else if (is_ernie3d) { + // ERNIE-VL 3D RoPE: interleaved h/w freq with NORMAL rotation + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + ggml_ernie3d_rope_cache_init( + p_t, p_h, p_w, sections, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); } else { const int64_t p_t = pos[i2]; @@ -5765,6 +5810,7 @@ static void ggml_compute_forward_rope_flt( switch (mode) { case GGML_ROPE_TYPE_NORMAL: + case GGML_ROPE_TYPE_ERNIE3D: rotate_pairs(n_dims, 1, cache, src, dst_data, 1); break; case GGML_ROPE_TYPE_NEOX: diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 45a49a5dc2a..875e989704b 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -264,6 +264,68 @@ static __global__ void rope_multi(const T * x, dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } +template +static __global__ void rope_ernie3d( + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, + const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + // NORMAL rotation: pair (x[i0], x[i0+1]), stored at adjacent positions + const int idst = row_dst*ne0 + i0; + const int ix = channel_x*s2 + row_x*s1 + i0; + + if (i0 >= n_dims) { + dst[idst + 0] = x[ix + 0]; + dst[idst + 1] = x[ix + 1]; + return; + } + + // freq_idx = i0/2 (which frequency pair this is) + const int freq_idx = i0 / 2; + // n_hw = sections[0] + sections[1] = total number of h+w interleaved frequencies + const int n_hw = sections.v[0] + sections.v[1]; + + // Determine which position slot to use based on interleaved pattern + // Position slots: slot 0 = t_position, slot 1 = h_position, slot 2 = w_position + float theta_base = 0.0f; + if (freq_idx < n_hw) { + if (freq_idx % 2 == 0) { + // even freq index -> height position (slot 1) + theta_base = pos[channel_x + ne2 * 1] * powf(theta_scale, (float)freq_idx); + } else { + // odd freq index -> width position (slot 2) + theta_base = pos[channel_x + ne2 * 2] * powf(theta_scale, (float)freq_idx); + } + } else { + // temporal position (slot 0) + theta_base = pos[channel_x] * powf(theta_scale, (float)freq_idx); + } + + const float freq_factor = has_ff ? freq_factors[freq_idx] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + // NORMAL (GPT-J) rotation: adjacent pair (x[i0], x[i0+1]) + const float x0 = x[ix + 0]; + const float x1 = x[ix + 1]; + + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + 1] = x0*sin_theta + x1*cos_theta; +} + template static __global__ void rope_vision(const T * x, T * dst, @@ -453,6 +515,29 @@ static void rope_multi_cuda(const T * x, } } +template +static void rope_ernie3d_cuda( + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_ernie3d<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + } else { + rope_ernie3d<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + } +} + template static void rope_vision_cuda(const T * x, T * dst, @@ -603,7 +688,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, row_indices, set_rows_stride, stream); - } else { + } else if (is_ernie3d) { + if (src0->type == GGML_TYPE_F32) { + rope_ernie3d_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_ernie3d_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + } else { + GGML_ABORT("fatal error"); + } + } + else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8a3fab1e1c3..454bd9b9508 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -110,6 +110,7 @@ class LLM: LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" + VISION_EXPERT_FEED_FORWARD_LENGTH = "{arch}.vision_expert_feed_forward_length" EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length" EXPERT_CHUNK_FEED_FORWARD_LENGTH = "{arch}.expert_chunk_feed_forward_length" USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" @@ -447,6 +448,7 @@ class MODEL_ARCH(IntEnum): AFMOE = auto() ERNIE4_5 = auto() ERNIE4_5_MOE = auto() + ERNIE4_5_VL_MOE = auto() HUNYUAN_MOE = auto() HUNYUAN_DENSE = auto() SMOLLM3 = auto() @@ -723,6 +725,17 @@ class MODEL_TENSOR(IntEnum): V_DS_NORM = auto() # qwen3vl V_DS_FC1 = auto() # qwen3vl V_DS_FC2 = auto() # qwen3vl + V_FFN_GATE_INP = auto() # ernie45vlmoe + V_FFN_UP_EXPS = auto() # ernie45vlmoe + V_FFN_DOWN_EXPS = auto() # ernie45vlmoe + V_FFN_NORM_EXPS = auto() # ernie45vlmoe + V_FFN_GATE_EXPS = auto() # ernie45vlmoe + V_FFN_GATE_SHEXP = auto() # ernie45vlmoe + V_FFN_UP_SHEXP = auto() # ernie45vlmoe + V_FFN_DOWN_SHEXP = auto() # ernie45vlmoe + V_FFN_GATE_INP_SHEXP = auto() # ernie45vlmoe + V_FFN_NORM_SHEXP = auto() # ernie45vlmoe + V_FFN_EXP_PROBS_B = auto() # ernie45vlmoe V_MM_POST_FC_NORM = auto() # cogvlm V_MM_UP = auto() # cogvlm V_MM_DOWN = auto() # cogvlm @@ -879,6 +892,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.AFMOE: "afmoe", MODEL_ARCH.ERNIE4_5: "ernie4_5", MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe", + MODEL_ARCH.ERNIE4_5_VL_MOE: "ernie4_5-vl-moe", MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense", @@ -1159,6 +1173,11 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_GATE: "mm.gate", MODEL_TENSOR.V_TOK_BOI: "v.boi", MODEL_TENSOR.V_TOK_EOI: "v.eoi", + MODEL_TENSOR.V_FFN_GATE_INP: "blk.{bid}.v_ffn_gate_inp", + MODEL_TENSOR.V_FFN_GATE_EXPS: "blk.{bid}.v_ffn_gate_exps", + MODEL_TENSOR.V_FFN_DOWN_EXPS: "blk.{bid}.v_ffn_down_exps", + MODEL_TENSOR.V_FFN_UP_EXPS: "blk.{bid}.v_ffn_up_exps", + MODEL_TENSOR.V_FFN_EXP_PROBS_B: "blk.{bid}.v_exp_probs_b", # audio (mtmd) # note: all audio tensor names must use prefix "a." or "mm.a." MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", @@ -2597,6 +2616,33 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, ], + MODEL_ARCH.ERNIE4_5_VL_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.V_FFN_GATE_INP, + MODEL_TENSOR.V_FFN_GATE_EXPS, + MODEL_TENSOR.V_FFN_DOWN_EXPS, + MODEL_TENSOR.V_FFN_UP_EXPS, + MODEL_TENSOR.V_FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], MODEL_ARCH.PLM: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, @@ -3770,6 +3816,7 @@ class VisionProjectorType: MUSIC_FLAMINGO = "musicflamingo" # audio GLM4V = "glm4v" YOUTUVL = "youtuvl" + ERNIE45VLMOE = "ernie4.5vl_moe" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 62172b24c38..a8633afc257 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -717,6 +717,9 @@ def add_feed_forward_length(self, length: int | Sequence[int]) -> None: def add_expert_feed_forward_length(self, length: int) -> None: self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_vision_expert_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.VISION_EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_expert_shared_feed_forward_length(self, length: int) -> None: self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 43f32c7b522..26d6162c24e 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1324,6 +1324,8 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + "vision_model.embeddings.class_embedding", # ernie4.5-vl-moe + "vision_model.patch_embed.cls_embedding", # ernie4.5-vl-moe ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -1338,10 +1340,13 @@ class TensorNameMap: "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm "siglip2.vision_model.embeddings.patch_embedding", + "vision_model.embeddings.patch_embedding", # ernie4.5-vl-moe + "vision_model.patch_embed.proj", # ernie4.5-vl-moe ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( "visual.post_conv_layernorm", # glm4v + "vision_model.ln", # ernie4.5-vl-moe ), MODEL_TENSOR.V_ENC_EMBD_POS: ( @@ -1354,11 +1359,14 @@ class TensorNameMap: "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm "visual.embeddings.position_embedding", # glm4v + "vision_model.embeddings.position_embedding", # ernie4.5-vl-moe + "vision_model.patch_embed.pos_emb", # ernie4.5-vl-moe ), MODEL_TENSOR.V_ENC_ATTN_QKV: ( "visual.blocks.{bid}.attn.qkv", # qwen3vl "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm + "vision_model.blocks.{bid}.attn.qkv", # ernie4.5-vl-moe ), MODEL_TENSOR.V_ENC_ATTN_Q: ( @@ -1372,6 +1380,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl + "vision_model.blocks.{bid}.attn.q", # ernie4.5-vl-moe ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1390,6 +1399,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj", + "vision_model.blocks.{bid}.attn.k", # ernie4.5-vl-moe ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1408,6 +1418,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj", + "vision_model.blocks.{bid}.attn.v", # ernie4.5-vl-moe ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1421,6 +1432,7 @@ class TensorNameMap: "vision_model.model.layers.{bid}.input_layernorm", # llama4 "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) + "vision_model.blocks.{bid}.norm1", # ernie4.5-vl-moe "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", ), @@ -1428,6 +1440,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_O: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL + "vision_model.blocks.{bid}.attn.proj", # ernie4.5-vl-moe "model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1 "vpm.encoder.layers.{bid}.self_attn.out_proj", "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM @@ -1452,6 +1465,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) + "vision_model.blocks.{bid}.norm2", # ernie4.5-vl-moe "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", ), @@ -1467,6 +1481,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.fc1", # qwen2vl "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl + "vision_model.blocks.{bid}.mlp.fc1", # ernie4.5-vl-moe "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", @@ -1489,6 +1504,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.fc2", # qwen2vl "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl + "vision_model.blocks.{bid}.mlp.fc2", # ernie4.5-vl-moe "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", @@ -1519,6 +1535,7 @@ class TensorNameMap: "vision_tower.encoder.final_layernorm", # kimi-vl "visual.post_layernorm", # glm4v "siglip2.vision_model.post_layernorm", + "vision_model.post_layernorm", # ernie4.5-vl-moe ), MODEL_TENSOR.V_MM_POST_NORM: ( @@ -1544,18 +1561,22 @@ class TensorNameMap: MODEL_TENSOR.V_RESMPL_POS_EMBD_K: ( "resampler.pos_embed_k", + "resampler_model.pos_embed_k", # ernie4.5-vl-moe ), MODEL_TENSOR.V_RESMPL_ATTN_Q: ( "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj + "resampler_model.attn.in_proj_q", # ernie4.5-vl-moe ), MODEL_TENSOR.V_RESMPL_ATTN_K: ( "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj + "resampler_model.attn.in_proj_k", # ernie4.5-vl-moe ), MODEL_TENSOR.V_RESMPL_ATTN_V: ( "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj + "resampler_model.attn.in_proj_v", # ernie4.5-vl-moe ), MODEL_TENSOR.V_RESMPL_ATTN_OUT: ( @@ -1584,6 +1605,7 @@ class TensorNameMap: MODEL_TENSOR.V_RESMPL_QUERY: ( "resampler.query", + "resampler_model.query", # ernie4.5-vl-moe ), MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: ( @@ -1635,6 +1657,26 @@ class TensorNameMap: "model.vision.eoi", # cogvlm ), + MODEL_TENSOR.V_FFN_GATE_INP: ( + "model.layers.{bid}.mlp.gate.vision", # ernie4.5-vl-moe + ), + + MODEL_TENSOR.V_FFN_GATE_EXPS: ( + "model.vision.layers.{bid}.mlp.experts.gate_proj", # ernie4.5-vl-moe + ), + + MODEL_TENSOR.V_FFN_DOWN_EXPS: ( + "model.vision.layers.{bid}.mlp.experts.down_proj", # ernie4.5-vl-moe + ), + + MODEL_TENSOR.V_FFN_UP_EXPS: ( + "model.vision.layers.{bid}.mlp.experts.up_proj", # ernie4.5-vl-moe + ), + + MODEL_TENSOR.V_FFN_EXP_PROBS_B: ( + "model.layers.{bid}.mlp.moe_statics.e_score_correction.vision", # ernie4.5-vl-moe + ), + # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: ( diff --git a/include/llama.h b/include/llama.h index bf4e28a8be1..7ed25e7ab22 100644 --- a/include/llama.h +++ b/include/llama.h @@ -85,6 +85,7 @@ extern "C" { LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE, LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE, LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION, + LLAMA_ROPE_TYPE_ERNIE3D = GGML_ROPE_TYPE_ERNIE3D, }; enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0c164617a12..2d3f13209b0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -61,6 +61,7 @@ add_library(llama models/dots1.cpp models/dream.cpp models/ernie4-5-moe.cpp + models/ernie4-5-vl-moe.cpp models/ernie4-5.cpp models/exaone.cpp models/exaone4.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fce46772d7e..ab6f9082bb2 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -101,6 +101,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_AFMOE, "afmoe" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, + { LLM_ARCH_ERNIE4_5_VL_MOE, "ernie4_5-vl-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, { LLM_ARCH_SMOLLM3, "smollm3" }, @@ -163,6 +164,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_VISION_EXPERT_FEED_FORWARD_LENGTH, "%s.vision_expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, @@ -333,6 +335,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_V_FFN_GATE_INP, "blk.%d.v_ffn_gate_inp" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, @@ -341,8 +344,11 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_V_FFN_GATE_EXPS, "blk.%d.v_ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_V_FFN_DOWN_EXPS, "blk.%d.v_ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_V_FFN_UP_EXPS, "blk.%d.v_ffn_up_exps" }, { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, @@ -352,6 +358,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_V_FFN_EXP_PROBS_B, "blk.%d.v_exp_probs_b" }, { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, @@ -2106,6 +2113,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_EXP_PROBS_B, }; case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_ERNIE4_5_VL_MOE: return { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, @@ -2127,6 +2135,11 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_V_FFN_GATE_INP, + LLM_TENSOR_V_FFN_GATE_EXPS, + LLM_TENSOR_V_FFN_DOWN_EXPS, + LLM_TENSOR_V_FFN_UP_EXPS, + LLM_TENSOR_V_FFN_EXP_PROBS_B, }; case LLM_ARCH_HUNYUAN_MOE: return { @@ -2511,6 +2524,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_GATE_INP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_GATE_INP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_FFN_GATE_INP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_IN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2598,6 +2615,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_V_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // altup / laurel (gemma 3n) {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index a392ecce2b4..c6ff643ed01 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -105,6 +105,7 @@ enum llm_arch { LLM_ARCH_AFMOE, LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5_MOE, + LLM_ARCH_ERNIE4_5_VL_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_DENSE, LLM_ARCH_SMOLLM3, @@ -157,7 +158,6 @@ enum llm_kv { LLM_KV_GENERAL_LICENSE, LLM_KV_GENERAL_SOURCE_URL, LLM_KV_GENERAL_SOURCE_HF_REPO, - LLM_KV_VOCAB_SIZE, LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, @@ -167,6 +167,7 @@ enum llm_kv { LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_VISION_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, LLM_KV_SWIGLU_CLAMP_EXP, @@ -201,7 +202,6 @@ enum llm_kv { LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_INTERLEAVE_MOE_LAYER_STEP, - LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, LLM_KV_ATTENTION_MAX_ALIBI_BIAS, @@ -228,7 +228,6 @@ enum llm_kv { LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, - LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, @@ -351,6 +350,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_SINKS, LLM_TENSOR_ATTN_GATE, LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_V_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_POST_NORM, @@ -363,8 +363,11 @@ enum llm_tensor { LLM_TENSOR_FFN_UP_EXP, LLM_TENSOR_FFN_NORM_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, // merged experts + LLM_TENSOR_V_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_V_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_V_FFN_UP_EXPS, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -372,6 +375,7 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_CHEXPS, LLM_TENSOR_FFN_UP_CHEXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_V_FFN_EXP_PROBS_B, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 756dda1a7ab..a584d033949 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -178,7 +178,7 @@ bool llama_hparams::is_recurrent(uint32_t il) const { } uint32_t llama_hparams::n_pos_per_embd() const { - return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1; + return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE || rope_type == LLAMA_ROPE_TYPE_ERNIE3D ? 4 : 1; } bool llama_hparams::is_swa(uint32_t il) const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 6c695bdbf66..3a2f6ca5f48 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -70,6 +70,7 @@ struct llama_hparams { uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; uint32_t n_ff_exp = 0; + uint32_t n_ff_v_exp = 0; uint32_t n_ff_shexp = 0; uint32_t n_ff_chexp = 0; uint32_t n_expert_shared = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8fc61aee372..d409c68b528 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -123,6 +123,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; + case LLM_TYPE_28B_A3B: return "28B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_48B_A3B: return "48B.A3B"; @@ -2195,6 +2196,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_ERNIE4_5_VL_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_VISION_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_v_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + if (!ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false)) { + hparams.rope_sections[0] = 22; + hparams.rope_sections[1] = 22; + hparams.rope_sections[2] = 20; + hparams.rope_sections[3] = 0; + } + + LLAMA_LOG_INFO("%s: ERNIE-VL rope_sections=[%d,%d,%d,%d]\n", __func__, + hparams.rope_sections[0], hparams.rope_sections[1], + hparams.rope_sections[2], hparams.rope_sections[3]); + + if (hparams.n_ff_v_exp == 0) { + hparams.n_ff_v_exp = 512; // ERNIE-VL default + } + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_28B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_FALCON_H1: { // Common parameters @@ -6397,6 +6426,67 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + } + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; + case LLM_ARCH_ERNIE4_5_VL_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + int n_ff_exp = hparams.n_ff_exp; + int n_ff_v_exp = hparams.n_ff_v_exp; // Vision expert intermediate size + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Vision expert MoE tensors + layer.v_ffn_gate_inp = create_tensor(tn(LLM_TENSOR_V_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.v_ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_V_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + layer.v_ffn_gate_exps = create_tensor(tn(LLM_TENSOR_V_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_v_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.v_ffn_down_exps = create_tensor(tn(LLM_TENSOR_V_FFN_DOWN_EXPS, "weight", i), { n_ff_v_exp, n_embd, n_expert}, 0); + layer.v_ffn_up_exps = create_tensor(tn(LLM_TENSOR_V_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_v_exp, n_expert}, 0); + + // Shared expert (if present) if (hparams.n_ff_shexp > 0) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); @@ -8430,6 +8520,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_ERNIE4_5_VL_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_HUNYUAN_MOE: { llm = std::make_unique(*this, params); @@ -8771,6 +8865,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NORM; case LLM_ARCH_GLM4_MOE: return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + case LLM_ARCH_ERNIE4_5_VL_MOE: + return LLAMA_ROPE_TYPE_ERNIE3D; // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: diff --git a/src/llama-model.h b/src/llama-model.h index 7b580043b33..876cd30280f 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -116,6 +116,7 @@ enum llm_type { LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small + LLM_TYPE_28B_A3B, // Ernie MoE vl small LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, LLM_TYPE_48B_A3B, // Kimi Linear @@ -286,6 +287,12 @@ struct llama_layer { struct ggml_tensor * ffn_down_exps_b = nullptr; struct ggml_tensor * ffn_up_exps_b = nullptr; + // ff Vision expert MoE + struct ggml_tensor * v_ffn_gate_inp = nullptr; + struct ggml_tensor * v_ffn_gate_exps = nullptr; + struct ggml_tensor * v_ffn_down_exps = nullptr; + struct ggml_tensor * v_ffn_up_exps = nullptr; + // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; struct ggml_tensor * ffn_gate_shexp = nullptr; @@ -303,6 +310,7 @@ struct llama_layer { struct ggml_tensor * ffn_up_b = nullptr; // b3 struct ggml_tensor * ffn_act = nullptr; struct ggml_tensor * ffn_exp_probs_b = nullptr; + struct ggml_tensor * v_ffn_exp_probs_b = nullptr; // mamba proj struct ggml_tensor * ssm_in = nullptr; diff --git a/src/models/ernie4-5-vl-moe.cpp b/src/models/ernie4-5-vl-moe.cpp new file mode 100644 index 00000000000..f246085d2c0 --- /dev/null +++ b/src/models/ernie4-5-vl-moe.cpp @@ -0,0 +1,185 @@ +#include "models.h" + +llm_build_ernie4_5_vl_moe::llm_build_ernie4_5_vl_moe(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0"); + + // Get MROPE sections from hparams + // ERNIE-VL uses [22, 22, 20, 0] for (n_h, n_w, n_t, extra) + // With ERNIE3D rope type: interleaved [h,w,h,w,...,t,t,...] frequency layout + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::end(hparams.rope_sections), sections); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // ERNIE-VL uses ERNIE3D RoPE (NORMAL rotation + interleaved 3D frequency) + // sections [22, 22, 20, 0]: n_h=22, n_w=22, n_t=20 + // Frequency layout: [h0,w1,h2,w3,...,h42,w43,t44,...,t63] + // Position slots: [t, h, w, 0] + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + bool is_moe_layer = + static_cast(il) >= hparams.n_layer_dense_lead && (il + 1) % hparams.n_moe_layer_step == 0; + + if (!is_moe_layer) { + // Dense layer + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = nullptr; + + // Use vision experts for vision tokens, text experts for text tokens + if (ubatch.embd) { + // Vision tokens: use vision MoE experts + moe_out = build_moe_ffn(cur, + model.layers[il].v_ffn_gate_inp, + model.layers[il].v_ffn_up_exps, + model.layers[il].v_ffn_gate_exps, + model.layers[il].v_ffn_down_exps, + model.layers[il].v_ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + } else { + // Text tokens: use text MoE experts + moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + } + + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_residual", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 2a750c168ea..dbb75dc9571 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -205,6 +205,10 @@ struct llm_build_ernie4_5_moe : public llm_graph_context { llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_ernie4_5_vl_moe : public llm_graph_context { + llm_build_ernie4_5_vl_moe(const llama_model & model, const llm_graph_params & params); +}; + template struct llm_build_exaone4 : public llm_graph_context { llm_build_exaone4(const llama_model & model, const llm_graph_params & params); diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 751440af323..ac840a23ed0 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -29,6 +29,7 @@ add_library(mtmd models/whisper-enc.cpp models/mobilenetv5.cpp models/youtuvl.cpp + models/ernie45vlmoe.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index ad232178bf4..af7508364a4 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -235,6 +235,7 @@ enum projector_type { PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, + PROJECTOR_TYPE_ERNIE45VLMOE, PROJECTOR_TYPE_UNKNOWN, }; @@ -268,6 +269,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, + { PROJECTOR_TYPE_ERNIE45VLMOE, "ernie4.5vl_moe"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9fa5afc390e..43efae422bd 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -829,6 +829,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_ERNIE45VLMOE: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: case PROJECTOR_TYPE_LDP: @@ -1139,6 +1143,12 @@ struct clip_model_loader { hparams.set_limit_image_tokens(8, 1024); hparams.set_warmup_n_tokens(256); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_ERNIE45VLMOE: + { + hparams.n_merge = 2; + hparams.set_limit_image_tokens(8, 1024); + hparams.set_warmup_n_tokens(256); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_GEMMA3: { // default value (used by all model sizes in gemma 3 family) @@ -1831,6 +1841,29 @@ struct clip_model_loader { layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias")); } } break; + case PROJECTOR_TYPE_ERNIE45VLMOE: + { + // spatial path: Linear -> GELU -> Linear -> LayerNorm + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + model.mm_post_norm_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight")); + model.mm_post_norm_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false); + + // temporal path: Linear -> GELU -> Linear -> LayerNorm (optional, not used for single images) + model.mm_1_w = get_tensor("mm_temp.0.weight", false); + model.mm_1_b = get_tensor("mm_temp.0.bias", false); + model.mm_3_w = get_tensor("mm_temp.2.weight", false); + model.mm_3_b = get_tensor("mm_temp.2.bias", false); + model.mm_input_norm_w = get_tensor("mm_temp.3.weight", false); + model.mm_input_norm_b = get_tensor("mm_temp.3.bias", false); + + // output MLP + RMS norm + model.mm_fc_w = get_tensor("mm.mlp.weight"); + model.mm_fc_b = get_tensor("mm.mlp.bias"); + model.mm_norm_mid_w = get_tensor("mm.norm.weight"); + } break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -3003,7 +3036,21 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(img_f32)); } break; - + case PROJECTOR_TYPE_ERNIE45VLMOE: + { + GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); + clip_image_u8 resized_image; + const int cur_merge = params.n_merge == 0 ? 1 : params.n_merge; + const clip_image_size target_size = img_tool::calc_size_preserved_ratio( + original_size, + params.patch_size * cur_merge, + params.image_min_pixels, + params.image_max_pixels); + img_tool::resize(*img, resized_image, target_size, img_tool::RESIZE_ALGO_BILINEAR); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(img_f32)); + } break; case PROJECTOR_TYPE_LLAMA4: { GGML_ASSERT(!params.image_res_candidates.empty()); @@ -3145,6 +3192,8 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_YOUTUVL: return (img->nx / params.patch_size) / 2; + case PROJECTOR_TYPE_ERNIE45VLMOE: + return (img->nx / params.patch_size) / 2; default: break; } @@ -3161,6 +3210,8 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_YOUTUVL: return (img->ny / params.patch_size) / 2; + case PROJECTOR_TYPE_ERNIE45VLMOE: + return (img->nx / params.patch_size) / 2; default: break; } @@ -3230,6 +3281,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int y_patch = img->ny / (params.patch_size * 2); n_patches = x_patch * y_patch; } break; + case PROJECTOR_TYPE_ERNIE45VLMOE: + { + // dynamic size (2 conv, so double patch size) + int x_patch = img->nx / (params.patch_size * 2); + int y_patch = img->ny / (params.patch_size * 2); + n_patches = x_patch * y_patch; + } break; case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: @@ -3584,6 +3642,25 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } + set_input_i32("positions", positions); + } break; + case PROJECTOR_TYPE_ERNIE45VLMOE: + { + const int pw = image_size_width / patch_size; + const int ph = image_size_height / patch_size; + std::vector positions(n_pos * 4); + int ptr = 0; + + for (int y = 0; y < ph; y++) { + for (int x = 0; x < pw; x++) { + positions[ ptr] = y; + positions[ num_patches + ptr] = x; + positions[2 * num_patches + ptr] = 0; + positions[3 * num_patches + ptr] = 0; + ptr++; + } + } + set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_PIXTRAL: @@ -3777,6 +3854,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.position_embeddings->ne[0]; case PROJECTOR_TYPE_GLM4V: return ctx->model.mm_ffn_down_w->ne[1]; + case PROJECTOR_TYPE_ERNIE45VLMOE: + return ctx->model.mm_fc_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } @@ -3795,6 +3874,14 @@ bool clip_is_glm(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE; } +bool clip_is_mrope(const struct clip_ctx * ctx) { + return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL + || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL + || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL + || ctx->proj_type() == PROJECTOR_TYPE_ERNIE45VLMOE + || ctx->proj_type() == PROJECTOR_TYPE_GLM4V; +} + bool clip_is_llava(const struct clip_ctx * ctx) { return ctx->model.hparams.has_llava_projector; } diff --git a/tools/mtmd/models/ernie45vlmoe.cpp b/tools/mtmd/models/ernie45vlmoe.cpp new file mode 100644 index 00000000000..23679e0d0ef --- /dev/null +++ b/tools/mtmd/models/ernie45vlmoe.cpp @@ -0,0 +1,100 @@ +#include "models.h" + +ggml_cgraph * clip_graph_ernie45vlmoe::build() { + // ERNIE-4.5-VL-MoE Vision + Resampler: + // 1. ViT encoder with 2D position embeddings and M-RoPE support + // 2. Resampler with spatial conv (2x2 grouping) + optional temporal + MLP + RMS norm + + const int n_pos = n_patches; + // Use n_merge for patch merge size (same as spatial_conv_size = 2) + const int spatial_merge_size = hparams.n_merge > 0 ? hparams.n_merge : 2; + + GGML_ASSERT(spatial_merge_size == 2 && "ERNIE-4.5-VL-MoE requires n_merge=2"); + + // ERNIE-VL Vision uses 2D position lookup RoPE: + // - Front half of frequencies use h_position + // - Back half of frequencies use w_position + // For d_head=80, n_dims=40, we need sections[0]=20 (for h) and sections[1]=20 (for w) + // GGML_ROPE_TYPE_VISION uses only 2 sections: sect_0 for first pos slot, sect_1 for second + int mrope_sections[4] = {d_head/4, d_head/4, 0, 0}; // [20, 20, 0, 0] for d_head=80 + + const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // Use the standard build_inp() which handles Conv2D patch embeddings + // The Python conversion now converts linear weights to Conv2D format + ggml_tensor * inp = build_inp(); + + // Build ViT encoder using the generic build_vit() with M-RoPE position encoding + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return ggml_rope_multi( + ctx0, cur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, + 32768, 10000, 1, 0, 1, 32, 1); + }; + + ggml_tensor * embeddings = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + nullptr, // no learned position embeddings, using RoPE + add_pos); + cb(embeddings, "vision_output", -1); + + // ------------------------------------------- + // Resampler projection + // ------------------------------------------- + // Group 2x2 patches: 40x40 -> 20x20, output shape [n_embd*4, n_groups] + embeddings = build_patch_merge_permute(embeddings, spatial_merge_size); + cb(embeddings, "spatial_reshape", -1); + + // Spatial linear path: Linear -> GELU -> Linear -> LayerNorm + // Weights are expected to be already transposed in GGUF format + ggml_tensor * spatial_out = embeddings; + + spatial_out = build_ffn(spatial_out, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); + + // LayerNorm + spatial_out = build_norm(spatial_out, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, eps, -1); + cb(spatial_out, "spatial_norm", -1); + + ggml_tensor * resampler_out = spatial_out; + + // Temporal processing for single images (t=1): + // Following ERNIE-VL original: when t=1, slice_offsets and slice_offsets2 both point to the same frame + resampler_out = ggml_concat(ctx0, resampler_out, resampler_out, 0); + + // Temporal linear path: Linear -> GELU -> Linear -> LayerNorm + // Weights are expected to be already transposed in GGUF format + resampler_out = build_ffn(resampler_out, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_3_w, model.mm_3_b, + FFN_GELU, + -1); + + // LayerNorm + resampler_out = build_norm(resampler_out, model.mm_input_norm_w, model.mm_input_norm_b, NORM_TYPE_NORMAL, eps, -1); + cb(resampler_out, "temporal_norm", -1); + + // Final MLP: Linear (weights are expected to be already transposed in GGUF format) + resampler_out = ggml_mul_mat(ctx0, model.mm_fc_w, resampler_out); + resampler_out = ggml_add(ctx0, resampler_out, model.mm_fc_b); + cb(resampler_out, "mlp", -1); + + // RMS norm (final output normalization) + resampler_out = build_norm(resampler_out, model.mm_norm_mid_w, nullptr, NORM_TYPE_RMS, eps, -1); + cb(resampler_out, "after_norm", -1); + + // Build the graph + ggml_build_forward_expand(gf, resampler_out); + + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 9970980c7bc..2e7966c40f4 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -109,3 +109,8 @@ struct clip_graph_mobilenetv5 : clip_graph { ggml_tensor * inp, const mobilenetv5_block & block); }; + +struct clip_graph_ernie45vlmoe : clip_graph { + clip_graph_ernie45vlmoe(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d037e834f3b..48018e85ac5 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -315,6 +315,11 @@ struct mtmd_context { img_end = "<|end_of_image|>"; } + else if (proj == PROJECTOR_TYPE_ERNIE45VLMOE) { + img_beg = "<|IMAGE_START|>"; + img_end = "<|IMAGE_END|>"; + + } } void init_audio() {