diff --git a/conversion/nemotron.py b/conversion/nemotron.py index dfeeb9785822..1e1d74343db2 100644 --- a/conversion/nemotron.py +++ b/conversion/nemotron.py @@ -39,23 +39,40 @@ def get_vision_config(self) -> dict[str, Any] | None: } return vision_config + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config.get("sound_config") + def set_gguf_parameters(self): if "image_mean" not in self.preprocessor_config: self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406] if "image_std" not in self.preprocessor_config: self.preprocessor_config["image_std"] = [0.229, 0.224, 0.225] + if self.hparams_audio is not None: + self.has_vision_encoder = True + self.has_audio_encoder = True + self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(1e-5) + self.gguf_writer.add_audio_subsampling_factor(self.hparams_audio["subsampling_factor"]) + self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.PARAKEET) + self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL) + else: + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL) + super().set_gguf_parameters() hparams = self.global_config - self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL) self.gguf_writer.add_vision_attention_layernorm_eps(1e-6) self.gguf_writer.add_vision_use_gelu(True) downsample_ratio = hparams.get("downsample_ratio", 0.5) self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) def tensor_force_quant(self, name, new_name, bid, n_dims): - if ".position_embd." in new_name or "pos_embed" in new_name: - return gguf.GGMLQuantizationType.F32 + if "sound_encoder" in name or new_name.startswith("mm.a."): + if "bias" in new_name or "norm" in new_name: + return gguf.GGMLQuantizationType.F32 + if "conv" in new_name and "weight" in new_name: + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) @classmethod @@ -65,18 +82,25 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca if "input_conditioner" in name: return None + if "language_model" in name: + return None + # mtmd does not support video yet so skip tensors related to video. if "radio_model.model.patch_generator.video_embedder" in name: return None - if not name.startswith("vision_model.radio_model.model.") and not name.startswith("mlp1."): + if not name.startswith(("vision_model.radio_model.model.", "mlp1.", "sound_encoder.", "sound_projection.")): return None if "patch_generator.pos_embed" in name: if not name.endswith(".weight"): name += ".weight" - return super().filter_tensors((name, gen)) + # num_batches is only used for training not inference. + if "conv.norm" in name and "num_batches" in name: + return None + + return name, gen def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it @@ -104,7 +128,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_embd = self.hparams["hidden_size"] data_torch = data_torch.reshape(n_embd, 3, patch_size, patch_size) - yield from super().modify_tensors(data_torch, name, bid) + if "depthwise_conv.weight" in name: + data_torch = data_torch.unsqueeze(-1) + data_torch = data_torch.permute(3, 1, 0, 2).contiguous() + + if "pointwise_conv" in name and name.endswith(".weight"): + if len(data_torch.shape) == 3 and data_torch.shape[2] == 1: + data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1]) + + if "subsampling.layers" in name and name.endswith(".bias"): + if len(data_torch.shape) == 1: + data_torch = data_torch.reshape(1, -1, 1, 1) + + if "pointwise_conv" in name and name.endswith(".bias"): + if len(data_torch.shape) == 1: + data_torch = data_torch.reshape(1, -1, 1, 1) + + for mapped_name, tensor in super().modify_tensors(data_torch, name, bid): + if name.startswith("sound_projection.") and mapped_name.startswith("mm.model.mlp."): + mapped_name = mapped_name.replace("mm.model.mlp.", "mm.a.mlp.") + yield mapped_name, tensor @ModelBase.register("NemotronForCausalLM") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ce556ec9b655..4b81d8b3db30 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -346,6 +346,7 @@ class ClipAudio: FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length" PROJECTION_DIM = "clip.audio.projection_dim" BLOCK_COUNT = "clip.audio.block_count" + SUBSAMPLING_FACTOR = "clip.audio.subsampling_factor" CHUNK_SIZE = "clip.audio.chunk_size" CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size" MAX_POS_EMB = "clip.audio.max_pos_emb" @@ -882,6 +883,10 @@ class MODEL_TENSOR(IntEnum): A_ENC_CONV_NORM = auto() # SSM conv A_ENC_CONV_PW1 = auto() A_ENC_CONV_PW2 = auto() + A_ENC_CONV_NORM_MEAN = auto() # parakeet + A_ENC_CONV_NORM_VAR = auto() # parakeet + A_ENC_MEL_FILTERS = auto() # parakeet + A_ENC_WINDOW = auto() # parakeet A_CTC_OUT = auto() A_CTC_OUT_MID = auto() A_ENC_ATTN_REL_POS_EMB = auto() @@ -1396,6 +1401,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.A_ENC_CONV_NORM: "a.blk.{bid}.conv_norm", MODEL_TENSOR.A_ENC_CONV_PW1: "a.blk.{bid}.conv_pw1", MODEL_TENSOR.A_ENC_CONV_PW2: "a.blk.{bid}.conv_pw2", + MODEL_TENSOR.A_ENC_CONV_NORM_MEAN: "a.blk.{bid}.conv_norm_mean", + MODEL_TENSOR.A_ENC_CONV_NORM_VAR: "a.blk.{bid}.conv_norm_var", + MODEL_TENSOR.A_ENC_MEL_FILTERS: "a.mel_filters", + MODEL_TENSOR.A_ENC_WINDOW: "a.window", MODEL_TENSOR.A_CTC_OUT: "a.enc_ctc_out", MODEL_TENSOR.A_CTC_OUT_MID: "a.enc_ctc_out_mid", MODEL_TENSOR.A_ENC_ATTN_REL_POS_EMB: "a.blk.{bid}.attn_rel_pos_emb", @@ -1569,6 +1578,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.A_ENC_CONV_NORM, MODEL_TENSOR.A_ENC_CONV_PW1, MODEL_TENSOR.A_ENC_CONV_PW2, + MODEL_TENSOR.A_ENC_CONV_NORM_MEAN, + MODEL_TENSOR.A_ENC_CONV_NORM_VAR, + MODEL_TENSOR.A_ENC_MEL_FILTERS, + MODEL_TENSOR.A_ENC_WINDOW, MODEL_TENSOR.A_MM_INP_PROJ, MODEL_TENSOR.A_MM_SOFT_EMB_NORM, MODEL_TENSOR.A_MM_EMBEDDING, @@ -4385,6 +4398,7 @@ class VisionProjectorType: YOUTUVL = "youtuvl" NEMOTRON_V2_VL = "nemotron_v2_vl" HUNYUANVL = "hunyuanvl" + PARAKEET = "parakeet" MINICPMV4_6 = "minicpmv4_6" GRANITE_SPEECH = "granite_speech" # audio MIMOVL = "mimovl" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 875d0f73d964..00aeaf254eca 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1278,6 +1278,9 @@ def add_audio_num_mel_bins(self, value: int) -> None: def add_audio_stack_factor(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) + def add_audio_subsampling_factor(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.SUBSAMPLING_FACTOR, value) + def add_audio_chunk_size(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.CHUNK_SIZE, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 82f26e7b303d..7cde82d9d949 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1972,6 +1972,7 @@ class TensorNameMap: "conformer.pre_encode.conv.{bid}", # lfm2 "model.audio_tower.subsample_conv_projection.conv_{bid}.conv", # gemma3n "conformer.subsample_conv_projection.layer{bid}.conv", # gemma4 + "sound_encoder.encoder.subsampling.layers.{bid}", # parakeet ), MODEL_TENSOR.A_ENC_CONV1D_NORM: ( @@ -2003,6 +2004,7 @@ class TensorNameMap: "conformer.layers.{bid}.self_attn.linear_q", # lfm2 "conformer.layers.{bid}.attention.attn.q_proj", # gemma3n "conformer.layers.{bid}.self_attn.q_proj", # gemma4 + "sound_encoder.encoder.layers.{bid}.self_attn.q_proj", # parakeet "encoder.layers.{bid}.attn.to_q", # granite_speech ), @@ -2011,6 +2013,7 @@ class TensorNameMap: "conformer.layers.{bid}.self_attn.linear_k", # lfm2 "conformer.layers.{bid}.attention.attn.k_proj", # gemma3n "conformer.layers.{bid}.self_attn.k_proj", # gemma4 + "sound_encoder.encoder.layers.{bid}.self_attn.k_proj", # parakeet "encoder.layers.{bid}.attn.to_k", # granite_speech (split from to_kv) ), @@ -2019,6 +2022,7 @@ class TensorNameMap: "conformer.layers.{bid}.self_attn.linear_v", # lfm2 "conformer.layers.{bid}.attention.attn.v_proj", # gemma3n "conformer.layers.{bid}.self_attn.v_proj", # gemma4 + "sound_encoder.encoder.layers.{bid}.self_attn.v_proj", # parakeet "encoder.layers.{bid}.attn.to_v", # granite_speech (split from to_kv) ), @@ -2047,6 +2051,7 @@ class TensorNameMap: "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox "conformer.layers.{bid}.norm_self_att", # lfm2 "conformer.layers.{bid}.attention.pre_attn_norm", # gemma3n + "sound_encoder.encoder.layers.{bid}.norm_self_att", # parakeet "encoder.layers.{bid}.attn.pre_norm", # granite_speech ), @@ -2055,6 +2060,7 @@ class TensorNameMap: "conformer.layers.{bid}.self_attn.linear_out", # lfm2 "conformer.layers.{bid}.attention.post", # gemma3n "conformer.layers.{bid}.self_attn.post", # gemma4 + "sound_encoder.encoder.layers.{bid}.self_attn.o_proj", # parakeet "encoder.layers.{bid}.attn.to_out", # granite_speech ), @@ -2062,6 +2068,7 @@ class TensorNameMap: "audio_tower.layers.{bid}.final_layer_norm", # ultravox "conformer.layers.{bid}.norm_out", # lfm2 "conformer.layers.{bid}.attention.post_norm", # gemma3n + "sound_encoder.encoder.layers.{bid}.norm_out", # parakeet "encoder.layers.{bid}.post_norm", # granite_speech ), @@ -2069,6 +2076,7 @@ class TensorNameMap: "conformer.layers.{bid}.norm_feed_forward1", # lfm2 "conformer.layers.{bid}.ffw_layer_start.pre_layer_norm", # gemma3n "conformer.layers.{bid}.feed_forward1.pre_layer_norm", # gemma4 + "sound_encoder.encoder.layers.{bid}.norm_feed_forward1", # parakeet "encoder.layers.{bid}.ff1.pre_norm", # granite_speech ), @@ -2086,6 +2094,7 @@ class TensorNameMap: "conformer.layers.{bid}.feed_forward1.linear1", # lfm2 "conformer.layers.{bid}.ffw_layer_start.ffw_layer_1", # gemma3n "conformer.layers.{bid}.feed_forward1.ffw_layer_1", # gemma4 + "sound_encoder.encoder.layers.{bid}.feed_forward1.linear1", # parakeet "encoder.layers.{bid}.ff1.up_proj", # granite_speech ), @@ -2096,6 +2105,7 @@ class TensorNameMap: "conformer.layers.{bid}.feed_forward1.linear2", # lfm2 "conformer.layers.{bid}.ffw_layer_start.ffw_layer_2", # gemma3n "conformer.layers.{bid}.feed_forward1.ffw_layer_2", # gemma4 + "sound_encoder.encoder.layers.{bid}.feed_forward1.linear2", # parakeet "encoder.layers.{bid}.ff1.down_proj", # granite_speech ), @@ -2103,6 +2113,7 @@ class TensorNameMap: "conformer.layers.{bid}.feed_forward2.linear1", # lfm2 "conformer.layers.{bid}.ffw_layer_end.ffw_layer_1", # gemma3n "conformer.layers.{bid}.feed_forward2.ffw_layer_1", # gemma4 + "sound_encoder.encoder.layers.{bid}.feed_forward2.linear1", # parakeet "encoder.layers.{bid}.ff2.up_proj", # granite_speech ), @@ -2110,6 +2121,7 @@ class TensorNameMap: "conformer.layers.{bid}.feed_forward2.linear2", # lfm2 "conformer.layers.{bid}.ffw_layer_end.ffw_layer_2", # gemma3n "conformer.layers.{bid}.feed_forward2.ffw_layer_2", # gemma4 + "sound_encoder.encoder.layers.{bid}.feed_forward2.linear2", # parakeet "encoder.layers.{bid}.ff2.down_proj", # granite_speech ), @@ -2117,6 +2129,7 @@ class TensorNameMap: "conformer.layers.{bid}.norm_feed_forward2", # lfm2 "conformer.layers.{bid}.ffw_layer_end.pre_layer_norm", # gemma3n "conformer.layers.{bid}.feed_forward2.pre_layer_norm", # gemma4 + "sound_encoder.encoder.layers.{bid}.norm_feed_forward2", # parakeet "encoder.layers.{bid}.ff2.pre_norm", # granite_speech ), @@ -2132,20 +2145,24 @@ class TensorNameMap: MODEL_TENSOR.A_ENC_LINEAR_POS: ( "conformer.layers.{bid}.self_attn.linear_pos", # lfm2 "conformer.layers.{bid}.attention.attn.relative_position_embedding.pos_proj", # gemma3n + "sound_encoder.encoder.layers.{bid}.self_attn.relative_k_proj", # parakeet ), MODEL_TENSOR.A_ENC_POS_BIAS_U: ( "conformer.layers.{bid}.self_attn.pos_bias_u", # lfm2 + "sound_encoder.encoder.layers.{bid}.self_attn.bias_u", # parakeet ), MODEL_TENSOR.A_ENC_POS_BIAS_V: ( "conformer.layers.{bid}.self_attn.pos_bias_v", # lfm2 + "sound_encoder.encoder.layers.{bid}.self_attn.bias_v", # parakeet ), MODEL_TENSOR.A_ENC_OUT: ( "conformer.pre_encode.out", # lfm2 "model.audio_tower.subsample_conv_projection.input_proj_linear", # gemma3n (note: it should be A_ENC_INP_PROJ, this is a mistake; it should be corrected in C++ code when it's supported) "conformer.output_proj", # gemma4 + "sound_encoder.encoder.subsampling.linear", # parakeet ), # note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors @@ -2155,6 +2172,7 @@ class TensorNameMap: "audio.multi_modal_projector.linear_{bid}", # ultravox, meralion "audio_adapter.model.{bid}", # lfm2 "audio_tower.proj{bid}", # qwen3omni + "sound_projection.linear{bid}", # parakeet (linear1, linear2) ), MODEL_TENSOR.A_MMPROJ_FC: ( @@ -2165,6 +2183,7 @@ class TensorNameMap: MODEL_TENSOR.A_MM_NORM_PRE: ( "audio.multi_modal_projector.ln_pre", # ultravox + "sound_projection.norm", # parakeet ), MODEL_TENSOR.A_MM_NORM_MID: ( @@ -2174,30 +2193,43 @@ class TensorNameMap: MODEL_TENSOR.A_ENC_CONV_DW: ( "conformer.layers.{bid}.conv.depthwise_conv", # lfm2 "conformer.layers.{bid}.lconv1d.depthwise_conv1d", # gemma3n + "sound_encoder.encoder.layers.{bid}.conv.depthwise_conv", # parakeet "encoder.layers.{bid}.conv.depth_conv.conv", # granite_speech ), MODEL_TENSOR.A_ENC_CONV_NORM: ( "conformer.layers.{bid}.conv.batch_norm", # lfm2 "conformer.layers.{bid}.lconv1d.pre_layer_norm", # gemma3n + "sound_encoder.encoder.layers.{bid}.conv.norm", # parakeet + ), + + MODEL_TENSOR.A_ENC_CONV_NORM_MEAN: ( + "sound_encoder.encoder.layers.{bid}.conv.norm.running_mean", # parakeet + ), + + MODEL_TENSOR.A_ENC_CONV_NORM_VAR: ( + "sound_encoder.encoder.layers.{bid}.conv.norm.running_var", # parakeet "encoder.layers.{bid}.conv.batch_norm", # granite_speech ), MODEL_TENSOR.A_ENC_CONV_PW1: ( "conformer.layers.{bid}.conv.pointwise_conv1", # lfm2 "conformer.layers.{bid}.lconv1d.linear_start", # gemma3n + "sound_encoder.encoder.layers.{bid}.conv.pointwise_conv1", # parakeet "encoder.layers.{bid}.conv.up_conv", # granite_speech ), MODEL_TENSOR.A_ENC_CONV_PW2: ( "conformer.layers.{bid}.conv.pointwise_conv2", # lfm2 "conformer.layers.{bid}.lconv1d.linear_end", # gemma3n + "sound_encoder.encoder.layers.{bid}.conv.pointwise_conv2", # parakeet "encoder.layers.{bid}.conv.down_conv", # granite_speech ), MODEL_TENSOR.A_ENC_NORM_CONV: ( "conformer.layers.{bid}.norm_conv", # lfm2 "conformer.layers.{bid}.lconv1d.conv_norm", # gemma3n + "sound_encoder.encoder.layers.{bid}.norm_conv", # parakeet "encoder.layers.{bid}.conv.norm", # granite_speech ), @@ -2209,6 +2241,14 @@ class TensorNameMap: "conformer.layers.{bid}.attention.attn.per_dim_scale", # gemma4 ), + MODEL_TENSOR.A_ENC_MEL_FILTERS: ( + "sound_encoder.encoder.feature_extractor.featurizer.fb", # parakeet + ), + + MODEL_TENSOR.A_ENC_WINDOW: ( + "sound_encoder.encoder.feature_extractor.featurizer.window", # parakeet + ), + MODEL_TENSOR.A_MM_EMBEDDING: ( "model.embed_audio.embedding", # gemma3n ), diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 93f005652b7d..57e841725dfc 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -47,6 +47,7 @@ add_library(mtmd models/mobilenetv5.cpp models/youtuvl.cpp models/yasa2.cpp + models/parakeet.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index c055cfb75419..277ccaa6b8f3 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -72,6 +72,8 @@ #define KEY_A_PROJ_DOWNSAMPLE_RATE "clip.audio.projector.downsample_rate" #define KEY_A_PROJ_HEAD_COUNT "clip.audio.projector.head_count" +#define KEY_AUDIO_SUBSAMPLING_FACTOR "clip.audio.subsampling_factor" + // // tensor name constants @@ -296,6 +298,12 @@ #define TN_YASA_STAGE_DOWN_CONV "v.stage.%d.down.conv.%s" #define TN_YASA_STAGE_BLK "v.stage.%d.blk.%d.%s.%s" +// parakeet +#define TN_MEL_FILTERS "a.mel_filters" +#define TN_WINDOW "a.window" +#define TN_CONV_NORM_MEAN "%s.blk.%d.conv_norm_mean" +#define TN_CONV_NORM_VAR "%s.blk.%d.conv_norm_var" + // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -350,6 +358,7 @@ enum projector_type { PROJECTOR_TYPE_KIMIK25, PROJECTOR_TYPE_NEMOTRON_V2_VL, PROJECTOR_TYPE_HUNYUANVL, + PROJECTOR_TYPE_PARAKEET, PROJECTOR_TYPE_EXAONE4_5, PROJECTOR_TYPE_MINICPMV4_6, PROJECTOR_TYPE_GRANITE_SPEECH, @@ -404,6 +413,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"}, { PROJECTOR_TYPE_EXAONE4_5, "exaone4_5"}, { PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"}, + { PROJECTOR_TYPE_PARAKEET, "parakeet"}, { PROJECTOR_TYPE_MINICPMV4_6, "minicpmv4_6"}, { PROJECTOR_TYPE_GRANITE_SPEECH, "granite_speech"}, { PROJECTOR_TYPE_MIMOVL, "mimovl"}, diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 238f805a9aae..fb07dced3e54 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -104,6 +104,8 @@ struct clip_hparams { // audio int32_t n_mel_bins = 0; // whisper preprocessor int32_t proj_stack_factor = 0; // ultravox + int32_t subsampling_factor = 0; // parakeet + // int32_t audio_chunk_size = 0; int32_t audio_conv_kernel_size = 0; int32_t audio_max_pos_emb = 0; @@ -118,6 +120,10 @@ struct clip_hparams { int32_t audio_window_len = -1; int32_t audio_hop_len = -1; + // parakeet + std::vector mel_filters; + std::vector window; + // legacy bool has_llava_projector = false; int minicpmv_version = 0; @@ -229,14 +235,18 @@ struct clip_layer { ggml_tensor * norm_conv_b = nullptr; ggml_tensor * linear_pos_w = nullptr; - ggml_tensor * conv_norm_w = nullptr; - ggml_tensor * conv_norm_b = nullptr; - ggml_tensor * conv_dw_w = nullptr; - ggml_tensor * conv_dw_b = nullptr; - ggml_tensor * conv_pw1_w = nullptr; - ggml_tensor * conv_pw1_b = nullptr; - ggml_tensor * conv_pw2_w = nullptr; - ggml_tensor * conv_pw2_b = nullptr; + ggml_tensor * conv_norm_w = nullptr; + ggml_tensor * conv_norm_b = nullptr; + ggml_tensor * conv_norm_mean = nullptr; // parakeet + ggml_tensor * conv_norm_var = nullptr; // parakeet + ggml_tensor * conv_dw_w = nullptr; + ggml_tensor * conv_dw_b = nullptr; + ggml_tensor * conv_pw1_w = nullptr; + ggml_tensor * conv_pw1_b = nullptr; + ggml_tensor * conv_pw2_w = nullptr; + ggml_tensor * conv_pw2_b = nullptr; + + struct ggml_tensor * attn_pos_w; // gemma4 audio conformer per-layer ggml_tensor * attn_pre_norm_w = nullptr; @@ -547,6 +557,9 @@ struct clip_model { ggml_tensor * net_2; ggml_tensor * net_3; + // Parakeet + ggml_tensor * mm_norm_w = nullptr; + int32_t n_sam_layers = 12; // used by deepseek-ocr sam encoder std::vector sam_layers; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 80136ed86672..79581a331d62 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -997,6 +997,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_PARAKEET: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1305,6 +1309,15 @@ struct clip_model_loader { { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); } break; + case PROJECTOR_TYPE_PARAKEET: + { + get_u32(KEY_AUDIO_SUBSAMPLING_FACTOR, hparams.subsampling_factor, false); + hparams.audio_chunk_len = 0; + hparams.audio_sample_rate = 16000; + hparams.audio_n_fft = 512; + hparams.audio_window_len = 400; + hparams.audio_hop_len = 160; + } break; case PROJECTOR_TYPE_IDEFICS3: { // use default llava-uhd preprocessing params @@ -2592,6 +2605,84 @@ struct clip_model_loader { layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias")); } } break; + case PROJECTOR_TYPE_PARAKEET: + { + auto get_vector = [&](const std::string & name) { + std::vector result; + auto it = tensor_offset.find(name); + if (it == tensor_offset.end()) { + return result; + } + + int idx = gguf_find_tensor(ctx_gguf.get(), name.c_str()); + GGML_ASSERT(idx >= 0); + size_t n_bytes = gguf_get_tensor_size(ctx_gguf.get(), idx); + size_t n_elems = n_bytes / sizeof(float); + result.resize(n_elems); + fin.seekg(it->second, std::ios::beg); + fin.read(reinterpret_cast(result.data()), n_bytes); + return result; + }; + + hparams.mel_filters = get_vector(TN_MEL_FILTERS); + hparams.window = get_vector(TN_WINDOW); + + // Subsampling layers (conv1d) + for (int i : {0, 2, 3, 5, 6}) { + model.pre_encode_conv_X_w[i] = get_tensor(string_format(TN_CONV1D, i, "weight")); + model.pre_encode_conv_X_b[i] = get_tensor(string_format(TN_CONV1D, i, "bias")); + } + model.pre_encode_out_w = get_tensor(string_format(TN_PRE_ENCODE_OUT, "weight")); + model.pre_encode_out_b = get_tensor(string_format(TN_PRE_ENCODE_OUT, "bias")); + + // Projection layers + model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"), false); + model.mm_0_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"), false); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"), false); + + // Encoder layers + for (int il = 0; il < hparams.n_layer; ++il) { + auto & layer = model.layers[il]; + + // Attention (from shared above) + + // Relative position encoding + layer.linear_pos_w = get_tensor(string_format(TN_LINEAR_POS, prefix, il, "weight")); + layer.pos_bias_u = get_tensor(string_format(TN_POS_BIAS_U, prefix, il)); + layer.pos_bias_v = get_tensor(string_format(TN_POS_BIAS_V, prefix, il)); + + // Convolution module + layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight")); + layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"), false); + layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight")); + layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"), false); + layer.conv_norm_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight")); + layer.conv_norm_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"), false); + layer.conv_norm_mean = get_tensor(string_format(TN_CONV_NORM_MEAN, prefix, il), false); + layer.conv_norm_var = get_tensor(string_format(TN_CONV_NORM_VAR, prefix, il), false); + layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight")); + layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"), false); + + // Feed-forward networks + layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight")); + layer.ff_norm_b = get_tensor(string_format(TN_FFN_NORM, prefix, il, "bias"), false); + + layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight")); + layer.ff_norm_1_b = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "bias"), false); + layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight")); + layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"), false); + layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight")); + layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"), false); + + // Layer norms + layer.norm_conv_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight")); + layer.norm_conv_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"), false); + } + + model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + } break; case PROJECTOR_TYPE_GRANITE_SPEECH: { model.inp_proj_w = get_tensor(string_format(TN_INP_PROJ, "weight")); @@ -3387,6 +3478,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } n_patches = n; } break; + case PROJECTOR_TYPE_PARAKEET: + { + n_patches = (img->nx + (params.subsampling_factor - 1)) / params.subsampling_factor; + } break; case PROJECTOR_TYPE_GEMMA4UA: { n_patches = img->nx; // no downsampling: one token per raw waveform frame @@ -3517,7 +3612,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const auto & mel_inp = imgs.entries[0]; const int n_step = mel_inp->nx; const int n_mel = mel_inp->ny; + std::vector inp_raw(n_step * n_mel); + std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float)); set_input_f32("inp_raw", inp_raw); } @@ -4196,6 +4293,51 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } set_input_f32("pos_emb", pos_emb); } break; + case PROJECTOR_TYPE_PARAKEET: + { + struct ggml_tensor * attn_mask = ggml_graph_get_tensor(gf, "attn_mask"); + const int n_q = attn_mask->ne[1]; + const int n_k = attn_mask->ne[0]; + const int n_tokens_real = (1101 + hparams.subsampling_factor-1) / hparams.subsampling_factor; + const float mask_value = -1e30f; + + std::vector mask_data(n_q * n_k); + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + bool is_padding = (k >= n_tokens_real); + mask_data[q * n_k + k] = (is_padding) ? mask_value : 0.0f; + } + } + set_input_f32(attn_mask->name, mask_data); + + // Generate rotation frequencies for relative positional encoding. + { + const int n_state = hparams.n_embd; + const int d_half = n_state / 2; + const float log_10000 = logf(10000.0f); + std::vector freqs(d_half); + for (int k = 0; k < d_half; ++k) { + freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state))); + } + set_input_f32("pos_freqs", freqs); + } + + // Generate relative positional distance values which scaled by + // the frequency to produce the angles for sin/cos. + { + // window_size is only known after graph construction since it depends on + // n_time from the conv output, so we read it back from the graph tensor. + struct ggml_tensor * rel_pos = ggml_graph_get_tensor(gf, "rel_positions"); + const int window_size = rel_pos->ne[1]; + const int n_time = (window_size + 1) / 2; + std::vector pos(window_size); + for (int t = 0; t < window_size; ++t) { + // The range of the values is high to low which the original model has. + pos[t] = float(n_time - 1 - t); + } + set_input_f32(rel_pos->name, pos); + } + } break; case PROJECTOR_TYPE_GRANITE_SPEECH: { const int context_size = ctx->model.hparams.audio_chunk_size; @@ -4387,6 +4529,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.qf_proj_linear_w->ne[1]; case PROJECTOR_TYPE_GLM4V: return ctx->model.mm_ffn_down_w->ne[1]; + case PROJECTOR_TYPE_PARAKEET: + return ctx->model.mm_1_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index b882f800dd77..590f539c7e6f 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -207,6 +207,14 @@ struct clip_graph_kimik25 : clip_graph { ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode); }; +struct clip_graph_parakeet : clip_graph { + clip_graph_parakeet(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * parakeet_build_graph_conv(); + ggml_tensor * parakeet_build_graph_encoder(ggml_tensor * cur); +}; + struct clip_graph_exaone4_5 : clip_graph { clip_graph_exaone4_5(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/parakeet.cpp b/tools/mtmd/models/parakeet.cpp new file mode 100644 index 000000000000..a77371eb96b2 --- /dev/null +++ b/tools/mtmd/models/parakeet.cpp @@ -0,0 +1,337 @@ +#include "models.h" + +ggml_cgraph * clip_graph_parakeet::build() { + // Build convolution graph + ggml_tensor * cur = parakeet_build_graph_conv(); + ggml_build_forward_expand(gf, cur); + + // Build encoder graph + cur = parakeet_build_graph_encoder(cur); + + cur = ggml_rms_norm(ctx0, cur, 1e-6); + cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); + cb(cur, "sound_projection.norm", -1); + + cur = build_ffn(cur, model.mm_0_w, model.mm_0_b, nullptr, nullptr, model.mm_1_w, model.mm_1_b, FFN_RELU_SQR, -1); + cb(cur, "projected", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; +} + +ggml_tensor * clip_graph_parakeet::parakeet_build_graph_conv() { + ggml_tensor * inp = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.ny, img.nx, 1); + ggml_set_name(inp, "inp_raw"); + ggml_set_input(inp); + + // [freq, time, channels, batch] + ggml_tensor * cur = ggml_conv_2d(ctx0, model.pre_encode_conv_X_w[0], inp, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[0]); + cb(cur, "pre_conv_0", -1); + ggml_set_output(cur); + + cur = ggml_relu(ctx0, cur); + cb(cur, "pre_conv_0_relu", -1); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[2], cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[2]); + cb(cur, "pre_conv_2", -1); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.pre_encode_conv_X_w[3], cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[3]); + cb(cur, "pre_conv_3", -1); + + cur = ggml_relu(ctx0, cur); + cb(cur, "pre_conv_3_relu", -1); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[5], cur, 2, 2, 1, 1, 1, 1); + cb(cur, "pre_conv_5_direct", -1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[5]); + cb(cur, "pre_conv_5", -1); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.pre_encode_conv_X_w[6], cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[6]); + cb(cur, "pre_conv_6", -1); + + cur = ggml_relu(ctx0, cur); + cb(cur, "pre_conv_6_relu", -1); + + // [freq, time, chan] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // [freq, chan, time] + cur = ggml_cont(ctx0, cur); + + const int n_freq = cur->ne[0]; + const int n_chan = cur->ne[1]; + const int n_frames = cur->ne[2]; + + // [freq, time, chan, batch] -> [(freq * chan), time] + cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames); + + cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur); + cur = ggml_add(ctx0, cur, model.pre_encode_out_b); + + ggml_set_name(cur, "pre_enc_out"); + ggml_set_output(cur); + + return cur; +} + +ggml_tensor * clip_graph_parakeet::parakeet_build_graph_encoder(ggml_tensor * cur) { + const auto & hparams = model.hparams; + const int n_layer = hparams.n_layer; + const int n_state = hparams.n_embd; + const float fc_factor = 0.5f; + + // [time_frames, time_frames, 1, 1]] + struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, cur->ne[1], cur->ne[1]); + ggml_set_name(attn_mask, "attn_mask"); + ggml_set_input(attn_mask); + + const int n_time = cur->ne[1]; + const int window_size = 2 * n_time - 1; + const int d_half = n_state / 2; + + struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half); + ggml_set_name(pos_freqs, "pos_freqs"); + ggml_set_input(pos_freqs); + + struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size); + ggml_set_name(rel_positions, "rel_positions"); + ggml_set_input(rel_positions); + + struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1); + struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions); + + struct ggml_tensor * sin = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size); + struct ggml_tensor * cos = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size); + struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin, cos, 0)), n_state, window_size); + ggml_set_name(pos_emb, "pos_emb"); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + // FFN1 + { + struct ggml_tensor * residual = cur; + //ggml_format_name(cur, "enc_%d_res", il); + + // norm + cur = ggml_norm(ctx0, cur, 1e-5); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.ff_norm_w), layer.ff_norm_b); + ggml_format_name(cur, "enc_%d_ffn_norm_1", il); + + // ffn_1 + cur = ggml_mul_mat(ctx0, layer.ff_up_w, cur); + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_silu", il); + + cur = ggml_mul_mat(ctx0, layer.ff_down_w, cur); + ggml_format_name(cur, "enc_%d_ffn_1", il); + + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor)); + ggml_format_name(cur, "enc_%d_res_ffn", il); + } + + // self attention block using relative positional encoding from model.position_embedding. + { + // [feat, time_frames, 1, 1] + struct ggml_tensor * residual = cur; + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.ln_1_w), layer.ln_1_b); + ggml_format_name(cur, "enc_%d_attn_norm", il); + + const int n_head = hparams.n_head; + const int d_head = n_state / n_head; + const int n_time = cur->ne[1]; + + // [feat, time_frames, 1, 1] + struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, layer.q_w, cur); + struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, layer.k_w, cur); + struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, layer.v_w, cur); + + // [d_head, n_heads, time_frames, 1] + Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time); + K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time); + V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time); + + // [n_state, window_size] + struct ggml_tensor * pos = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb); + ggml_format_name(pos, "enc_%d_attn_pos", il); + + // Add the content bias to Q. + // [feat, head, time_frames, batch] + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, layer.pos_bias_u); + ggml_format_name(Q_u, "enc_%d_attn_q_u", il); + + // [feat, time_frames, head, 1] + struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3); + // [feat, time_frames, head, 1] + struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3); + // [feat, feat, head, 1] + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep); + ggml_format_name(content_scores, "enc_%d_attn_content_scores", il); + + // Add the position bias to Q. + // [feat, head, time_frames, batch] + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, layer.pos_bias_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v", il); + + // [feat, window_size, 1, 1] and we are doing multi-head attention so + // we need to split this into heads. + // [feat, head, window_size, 1] + pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, pos_emb->ne[1]); + + // [feat, window_size, head, 1] + pos = ggml_permute(ctx0, pos, 0, 2, 1, 3); + pos = ggml_cont(ctx0, pos); + ggml_format_name(pos, "enc_%d_attn_pos_perm", il); + // [feat, time, head, 1] + Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3); + Q_v = ggml_cont(ctx0, Q_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il); + + // [window_size, time_frames, head, 1] + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il); + + // Relative positional shift + { + + const auto pos_window = rel_pos_scores->ne[0]; + const auto n_frame = rel_pos_scores->ne[1]; + const auto n_head = rel_pos_scores->ne[2]; + + // [feat_padded, window_size, head, 1] + rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0); + rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0); + + rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head); + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il); + + int center = pos_window / 2; + size_t offset = rel_pos_scores->nb[0] * (center+1); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + n_frame, pos_window, n_head, + (pos_window) * 4, + rel_pos_scores->nb[2], + offset); + + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + content_scores->ne[0], + content_scores->ne[1], + rel_pos_scores->ne[2], + rel_pos_scores->nb[1], + rel_pos_scores->nb[2], + 0); + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il); + } + + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + attn_scores = ggml_cont(ctx0, attn_scores); + ggml_format_name(attn_scores, "enc_%d_attn_scores", il); + attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head)); + attn_scores = ggml_add(ctx0, attn_scores, attn_mask); + ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores); + ggml_format_name(probs, "enc_%d_attn_probs", il); + + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3)); + ggml_format_name(V_cur, "enc_%d_attn_v_cur", il); + cur = ggml_mul_mat(ctx0, probs, V_cur); + ggml_format_name(cur, "enc_%d_attn_inp", il); + + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); + cur = ggml_cont_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.o_w, cur); + ggml_format_name(cur, "enc_%d_attn_out", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_attn_res", il); + } + + // Convolution + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_residual_conv", il); + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_conv_w), layer.norm_conv_b); + ggml_format_name(cur, "enc_%d_norm_conv", il); + + // pointwise 1d convolution: + cur = ggml_mul_mat(ctx0, layer.conv_pw1_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw1", il); + + { + int64_t d = cur->ne[0] / 2; + struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0); + struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]); + + cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate)); + ggml_format_name(cur, "enc_%d_conv_glu", il); + } + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + // use ggml_ssm_conv for f32 precision + cur = ggml_pad(ctx0, cur, 4, 0, 0, 0); + cur = ggml_roll(ctx0, cur, 4, 0, 0, 0); + cur = ggml_pad(ctx0, cur, 4, 0, 0, 0); + ggml_format_name(cur, "enc_%d_conv_dw_pad", il); + + cur = ggml_ssm_conv(ctx0, cur, layer.conv_dw_w); + ggml_format_name(cur, "enc_%d_conv_1d_dw", il); + + cur = ggml_sub(ctx0, cur, layer.conv_norm_mean); + struct ggml_tensor * std = ggml_sqrt(ctx0, layer.conv_norm_var); + cur = ggml_div(ctx0, cur, std); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.conv_norm_w), layer.conv_norm_b); + ggml_format_name(cur, "enc_%d_conv_bn", il); + + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_conv_silu", il); + + cur = ggml_mul_mat(ctx0, layer.conv_pw2_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw2", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_conv_res", il); + } + + // FFN2 + { + struct ggml_tensor * residual = cur; + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.ff_norm_1_w), layer.ff_norm_1_b); + ggml_format_name(cur, "enc_%d_ffn_norm_2", il); + + cur = ggml_mul_mat(ctx0, layer.ff_up_1_w, cur); + cur = ggml_silu(ctx0, cur); + cur = ggml_mul_mat(ctx0, layer.ff_down_1_w, cur); + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5)); + ggml_format_name(cur, "enc_%d_ffn_res", il); + } + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.ln_2_w), layer.ln_2_b); + } + + cb(cur, "encoder_out", -1); + + ggml_build_forward_expand(gf, cur); + + return cur; +} diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index 13f211fd9021..cff329a0b7cc 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -943,6 +943,224 @@ bool mtmd_audio_preprocessor_gemma4a::preprocess(const float * s } // +// mtmd_audio_preprocessor_parakeet implementation +// + +void mtmd_audio_preprocessor_parakeet::worker_thread( + int ith, + const float * window_func, + int window_size, + const std::vector & samples, + int n_samples, + int frame_size, + int frame_step, + int n_threads, + int n_fft_bins, + const mtmd_audio_cache & cache, + mtmd_audio_mel & mel) { + std::vector fft_in(frame_size * 2, 0.0); + std::vector fft_out(frame_size * 2 * 2 * 2); + + int n_fb = n_fft_bins; + int i = ith; + + GGML_ASSERT(n_fb == 1 + (frame_size / 2)); + + const double eps = 5.960464477539063e-08; + + for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + const int offset = i * frame_step; + const int window_pad_left = (frame_size - window_size) / 2; + + // Zero-pad left. + std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f); + + // Apply windowed samples in the center. + const int n_to_process = std::min({window_size, n_samples - offset}); + for (int j = 0; j < n_to_process; j++) { + fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j]; + } + + // Zero-pad right. + std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + frame_size, 0.0f); + + // FFT. + fft(cache, fft_in.data(), frame_size, fft_out.data()); + + // Calculate modulus^2 of complex numbers. + for (int j = 0; j < n_fb; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram. + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + int k = 0; + for (k = 0; k < n_fb - 3; k += 4) { + sum += + fft_out[k + 0] * cache.filters.data[j * n_fb + k + 0] + + fft_out[k + 1] * cache.filters.data[j * n_fb + k + 1] + + fft_out[k + 2] * cache.filters.data[j * n_fb + k + 2] + + fft_out[k + 3] * cache.filters.data[j * n_fb + k + 3]; + } + for (; k < n_fb; k++) { + sum += fft_out[k] * cache.filters.data[j * n_fb + k]; + } + mel.data[i * mel.n_mel + j] = std::log(sum + eps); + } + } + + // Otherwise fft_out are all zero. + const double empty_sum = std::log(eps); + for (; i < mel.n_len; i += n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[i * mel.n_mel + j] = empty_sum; + } + } +} + +void mtmd_audio_preprocessor_parakeet::initialize() { + cache.fill_sin_cos_table(hparams.audio_n_fft); + + const size_t n_fft = hparams.audio_n_fft / 2 + 1; + GGML_ASSERT(hparams.mel_filters.size() == (size_t)hparams.n_mel_bins * n_fft); + cache.filters.n_mel = hparams.n_mel_bins; + cache.filters.n_fft = n_fft; + cache.filters.data = hparams.mel_filters; + + GGML_ASSERT(hparams.window.size() == (size_t)hparams.audio_window_len); + GGML_ASSERT(hparams.window.size() <= (size_t) hparams.audio_n_fft); + cache.hann_window = hparams.window; +} + +bool mtmd_audio_preprocessor_parakeet::preprocess(const float * samples, + size_t n_samples_in, + std::vector & output) { + if (n_samples_in == 0) { + return false; + } + + filter_params params; + params.n_mel = hparams.n_mel_bins; + params.n_fft_bins = 1 + (hparams.audio_n_fft / 2); + params.hann_window_size = hparams.audio_window_len; + params.hop_length = hparams.audio_hop_len; + params.sample_rate = hparams.audio_sample_rate; + + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); + + const float * window_func = cache.hann_window.data(); + const int window_size = params.hann_window_size; + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; + + // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1] + std::vector samples_preprocessed(samples, samples + n_samples_in); + { + const float preemph = 0.97f; + for (int i = n_samples_in - 1; i > 0; i--) { + samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1]; + } + } + + // Parakeet uses centered constant padding + const size_t pad = (size_t)(frame_size / 2); + std::vector samples_padded(n_samples_in + 2 * pad, 0.0f); + std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad); + + mtmd_audio_mel out_full; + out_full.n_mel = params.n_mel; + out_full.n_len = (samples_padded.size() - frame_size) / frame_step + 1; + out_full.n_len_org = out_full.n_len; + out_full.data.resize(out_full.n_mel * out_full.n_len); + + const int n_threads = 4; + + if (n_threads == 1) { + worker_thread(0, + window_func, + window_size, + samples_padded, + samples_padded.size(), + frame_size, + frame_step, + 1, + params.n_fft_bins, + cache, + out_full); + } else { + std::vector workers(n_threads - 1); + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw] = std::thread( + worker_thread, iw + 1, + window_func, + window_size, + std::cref(samples_padded), + samples_padded.size(), + frame_size, + frame_step, + n_threads, + params.n_fft_bins, + std::cref(cache), + std::ref(out_full) + ); + } + + worker_thread(0, + window_func, + window_size, + samples_padded, + samples_padded.size(), + frame_size, + frame_step, + n_threads, + params.n_fft_bins, + cache, + out_full); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + // Per-feature normalization (only on valid frames) + { + const double eps = 1e-5; + int valid_frames = n_samples_in / frame_step; + + for (int j = 0; j < out_full.n_mel; j++) { + double sum = 0.0; + double sq_diff_sum = 0.0; + + // Calculate Mean ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + sum += (double)out_full.data[i * out_full.n_mel + j]; + } + double mean = sum / valid_frames; + + // Calculate Variance ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + double diff = (double)out_full.data[i * out_full.n_mel + j] - mean; + sq_diff_sum += diff * diff; + } + + double std_dev = std::sqrt(sq_diff_sum / (valid_frames - 1.0)); + double denominator = std_dev + eps; + + // Apply to ALL frames (including the padded ones) + for (int i = 0; i < out_full.n_len; i++) { + out_full.data[i * out_full.n_mel + j] = (float)((out_full.data[i * out_full.n_mel + j] - mean) / denominator); + } + } + } + + output.push_back(std::move(out_full)); + return true; +} + + // mtmd_audio_preprocessor_gemma4ua // diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index 9656e3940f53..b00715a9b5b2 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -111,6 +111,21 @@ struct mtmd_audio_preprocessor_qwen3a : mtmd_audio_preprocessor { mtmd_audio_cache cache; }; +struct mtmd_audio_preprocessor_parakeet : mtmd_audio_preprocessor { + mtmd_audio_preprocessor_parakeet(clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) { } + void initialize() override; + bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; + + static void worker_thread(int ith, const float * window_func, int window_size, + const std::vector & samples, int n_samples, + int frame_size, int frame_step, int n_threads, + int n_fft_bins, + const mtmd_audio_cache & cache, mtmd_audio_mel & mel); +}; + // // streaming ISTFT - converts spectrogram frames back to audio one frame at a time // diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 0b5caa6cb5c1..4d2406742c1a 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -577,6 +577,10 @@ struct mtmd_context { aud_end = ""; audio_preproc = std::make_unique(ctx_a); } break; + case PROJECTOR_TYPE_PARAKEET: + { + audio_preproc = std::make_unique(ctx_a); + } break; case PROJECTOR_TYPE_GEMMA4UA: { aud_beg = "<|audio>";