diff --git a/common/arg.cpp b/common/arg.cpp index 163c9b71b0e..b45946671df 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -57,6 +57,7 @@ static std::initializer_list mmproj_examples = { LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI, + LLAMA_EXAMPLE_LIQUID_AUDIO, }; static std::string read_file(const std::string & fname) { @@ -1339,7 +1340,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.system_prompt = value; } - ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_MTMD})); + ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_LIQUID_AUDIO})); add_opt(common_arg( {"--perf"}, {"--no-perf"}, @@ -2159,7 +2160,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.image.emplace_back(item); } } - ).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI})); + ).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_LIQUID_AUDIO})); add_opt(common_arg( {"--image-min-tokens"}, "N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)", @@ -2639,7 +2640,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.out_file = value; } - ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE})); + ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_LIQUID_AUDIO})); add_opt(common_arg( {"-ofreq", "--output-frequency"}, "N", string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq), @@ -2771,14 +2772,14 @@ common_params_context 
common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.hostname = value; } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_HOST")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LIQUID_AUDIO}).set_env("LLAMA_ARG_HOST")); add_opt(common_arg( {"--port"}, "PORT", string_format("port to listen (default: %d)", params.port), [](common_params & params, int value) { params.port = value; } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LIQUID_AUDIO}).set_env("LLAMA_ARG_PORT")); add_opt(common_arg( {"--path"}, "PATH", string_format("path to serve static files from (default: %s)", params.public_path.c_str()), @@ -3425,7 +3426,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.vocoder.model.path = value; } - ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); + ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LIQUID_AUDIO})); add_opt(common_arg( {"--tts-use-guide-tokens"}, "Use guide tokens to improve TTS word recall", @@ -3439,7 +3440,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.vocoder.speaker_file = value; } - ).set_examples({LLAMA_EXAMPLE_TTS})); + ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LIQUID_AUDIO})); add_opt(common_arg( {"--diffusion-steps"}, "N", diff --git a/common/common.h b/common/common.h index 96c990c05d8..fe716b9782b 100644 --- a/common/common.h +++ b/common/common.h @@ -104,6 +104,7 @@ enum llama_example { LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_FIT_PARAMS, + LLAMA_EXAMPLE_LIQUID_AUDIO, LLAMA_EXAMPLE_COUNT, }; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ab015dd2c3a..1cf1f5cb500 100755 --- 
a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10296,7 +10296,7 @@ def _add_feed_forward_length(self): def set_gguf_parameters(self): # set num_key_value_heads only for attention layers self.hparams["num_key_value_heads"] = [ - self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0 + self.hparams["num_key_value_heads"] if layer_type != "conv" else 0 for layer_type in self.hparams["layer_types"] ] @@ -10345,6 +10345,25 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: yield f"{self.dense_tensor_name}.weight", tensor.clone() +@ModelBase.register("Lfm25AudioTokenizer") +class LFM25AudioTokenizer(LFM2Model): + model_arch = gguf.MODEL_ARCH.LFM2 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_embedding_length_out(self.hparams.get("output_size")) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "istft.window" or name.startswith("emb.emb"): + return [] + + if name.startswith("lin"): + name = name.replace("lin", "dense_2_out") + + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Lfm2MoeForCausalLM") class LFM2MoeModel(TextModel): model_arch = gguf.MODEL_ARCH.LFM2MOE diff --git a/deliverable/stress_test.py b/deliverable/stress_test.py new file mode 100755 index 00000000000..c9b772b31dd --- /dev/null +++ b/deliverable/stress_test.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "numpy", +# "soundfile", +# "openai", +# ] +# /// +""" +Stress test for LFM2.5-Audio server. +Ramps up requests per second (RPS) for TTS, ASR, and FUNC modes. 
+""" + +import argparse +import base64 +import concurrent.futures +import json +import statistics +import sys +import time +from dataclasses import dataclass, field + +import numpy as np +import soundfile as sf +from openai import OpenAI + + +@dataclass +class RequestResult: + mode: str + success: bool + latency: float # total request time in seconds + ttft: float | None = None # time to first token + error: str | None = None + text_tokens: int = 0 + audio_samples: int = 0 + + +@dataclass +class RPSStageResult: + target_rps: float + mode: str + results: list[RequestResult] = field(default_factory=list) + + @property + def actual_rps(self): + if not self.results: + return 0 + total_time = max(r.latency for r in self.results) if self.results else 1 + return len(self.results) / total_time if total_time > 0 else 0 + + @property + def success_count(self): + return sum(1 for r in self.results if r.success) + + @property + def fail_count(self): + return sum(1 for r in self.results if not r.success) + + @property + def success_rate(self): + return self.success_count / len(self.results) * 100 if self.results else 0 + + @property + def latencies(self): + return [r.latency for r in self.results if r.success] + + @property + def ttfts(self): + return [r.ttft for r in self.results if r.success and r.ttft is not None] + + +# Default prompts for each mode +TTS_PROMPTS = [ + "Hello, how are you doing today?", + "The quick brown fox jumps over the lazy dog.", + "Welcome to the audio stress test.", + "This is a sample sentence for text to speech synthesis.", + "Testing the server under increasing load.", +] + +FUNC_PROMPTS = [ + 'What is the weather in San Francisco?', + 'Book a meeting for tomorrow at 3pm.', + 'Search for flights from New York to London.', + 'Set a reminder to buy groceries at 5pm.', + 'Calculate the distance from Paris to Berlin.', +] + +TTS_SYSTEM = "Perform TTS. Use the US male voice." +ASR_SYSTEM = "Perform ASR." +FUNC_SYSTEM = "Respond in function calls." 
+ + +def single_request(base_url: str, mode: str, system_prompt: str, + user_content, max_tokens: int) -> RequestResult: + """Execute a single request and return the result.""" + client = OpenAI(base_url=base_url, api_key="dummy") + + modalities = ["audio"] if "TTS" in system_prompt else ["text"] + + t_start = time.time() + ttft = None + text_tokens = 0 + audio_samples = 0 + + try: + stream = client.chat.completions.create( + model="", + modalities=modalities, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ], + stream=True, + max_tokens=max_tokens, + ) + + completed = False + for chunk in stream: + if chunk.choices[0].finish_reason == "stop": + completed = True + break + + delta = chunk.choices[0].delta + + if text := delta.content: + if ttft is None: + ttft = time.time() - t_start + text_tokens += 1 + + if hasattr(delta, "audio") and delta.audio and "data" in delta.audio: + if ttft is None: + ttft = time.time() - t_start + pcm_bytes = base64.b64decode(delta.audio["data"]) + audio_samples += len(pcm_bytes) // 2 # int16 = 2 bytes + + latency = time.time() - t_start + + if not completed: + return RequestResult( + mode=mode, success=False, latency=latency, + error="Server disconnected before completion", + ) + + return RequestResult( + mode=mode, success=True, latency=latency, ttft=ttft, + text_tokens=text_tokens, audio_samples=audio_samples, + ) + + except Exception as e: + latency = time.time() - t_start + return RequestResult( + mode=mode, success=False, latency=latency, + error=str(e)[:200], + ) + + +def prepare_asr_content(wav_file: str): + """Load a WAV file and return OpenAI-compatible audio content.""" + with open(wav_file, "rb") as f: + wav_data = f.read() + encoded = base64.b64encode(wav_data).decode("utf-8") + return [{"type": "input_audio", "input_audio": {"data": encoded, "format": "wav"}}] + + +def run_rps_stage(base_url: str, mode: str, system_prompt: str, + contents: list, target_rps: float, 
duration: float, + max_tokens: int) -> RPSStageResult: + """Run a single RPS stage: fire requests at the target rate for `duration` seconds.""" + stage = RPSStageResult(target_rps=target_rps, mode=mode) + interval = 1.0 / target_rps if target_rps > 0 else 1.0 + total_requests = max(1, int(target_rps * duration)) + + futures = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=total_requests + 4) as pool: + t0 = time.time() + for i in range(total_requests): + # Schedule request at the right time + scheduled = t0 + i * interval + now = time.time() + if scheduled > now: + time.sleep(scheduled - now) + + content = contents[i % len(contents)] + futures.append( + pool.submit(single_request, base_url, mode, system_prompt, content, max_tokens) + ) + + # Wait for all to complete + for fut in concurrent.futures.as_completed(futures): + stage.results.append(fut.result()) + + return stage + + +def print_stage_report(stage: RPSStageResult): + """Print a summary for one RPS stage.""" + lats = stage.latencies + ttfts = stage.ttfts + + print(f" Target RPS: {stage.target_rps:>6.1f} | " + f"Requests: {len(stage.results):>4} | " + f"OK: {stage.success_count:>4} | " + f"Fail: {stage.fail_count:>3} | " + f"Success: {stage.success_rate:>5.1f}%") + + if lats: + print(f" Latency — min: {min(lats):.3f}s " + f"avg: {statistics.mean(lats):.3f}s " + f"p50: {statistics.median(lats):.3f}s " + f"p95: {sorted(lats)[int(len(lats) * 0.95)]:.3f}s " + f"max: {max(lats):.3f}s") + if ttfts: + print(f" TTFT — min: {min(ttfts):.3f}s " + f"avg: {statistics.mean(ttfts):.3f}s " + f"p50: {statistics.median(ttfts):.3f}s " + f"max: {max(ttfts):.3f}s") + + if stage.fail_count > 0: + errors = [r.error for r in stage.results if not r.success and r.error] + unique = set(errors) + for e in list(unique)[:3]: + print(f" Error: {e}") + + +def run_stress_test(base_url: str, modes: list[str], wav_file: str | None, + rps_stages: list[float], duration: float, max_tokens: int): + """Run the full stress test across 
modes and RPS levels.""" + # Prepare content for each mode + mode_configs: dict[str, tuple[str, list]] = {} + + if "tts" in modes: + mode_configs["tts"] = (TTS_SYSTEM, TTS_PROMPTS) + if "asr" in modes: + if not wav_file: + print("ERROR: ASR mode requires --wav argument", file=sys.stderr) + sys.exit(1) + asr_content = prepare_asr_content(wav_file) + mode_configs["asr"] = (ASR_SYSTEM, [asr_content]) + if "func" in modes: + mode_configs["func"] = (FUNC_SYSTEM, FUNC_PROMPTS) + + all_stages: list[RPSStageResult] = [] + + for mode, (sys_prompt, contents) in mode_configs.items(): + print(f"\n{'=' * 60}") + print(f" STRESS TEST: {mode.upper()}") + print(f"{'=' * 60}") + + for rps in rps_stages: + print(f"\n--- {mode.upper()} @ {rps} RPS (duration: {duration}s) ---") + stage = run_rps_stage(base_url, mode, sys_prompt, contents, rps, duration, max_tokens) + print_stage_report(stage) + all_stages.append(stage) + + # If success rate drops below 50%, stop escalating for this mode + if stage.success_rate < 50: + print(f" >> Success rate below 50%, stopping RPS ramp for {mode.upper()}") + break + + # Final summary + print(f"\n{'=' * 60}") + print(" SUMMARY") + print(f"{'=' * 60}") + print(f"{'Mode':<6} {'RPS':>6} {'Total':>6} {'OK':>6} {'Fail':>6} {'Rate':>7} {'Avg Lat':>8} {'Avg TTFT':>9}") + print("-" * 60) + for s in all_stages: + avg_lat = statistics.mean(s.latencies) if s.latencies else float("nan") + avg_ttft = statistics.mean(s.ttfts) if s.ttfts else float("nan") + print(f"{s.mode:<6} {s.target_rps:>6.1f} {len(s.results):>6} " + f"{s.success_count:>6} {s.fail_count:>6} {s.success_rate:>6.1f}% " + f"{avg_lat:>7.3f}s {avg_ttft:>8.3f}s") + + +def main(): + parser = argparse.ArgumentParser( + description="Stress test for LFM2.5-Audio server with increasing RPS" + ) + parser.add_argument( + "--modes", type=str, default="tts,asr,func", + help="Comma-separated list of modes to test: tts,asr,func (default: all)", + ) + parser.add_argument( + "--wav", type=str, + help="Path to 
input WAV file (required for ASR mode)", + ) + parser.add_argument( + "--rps", type=str, default="1,2,4,8,16,32,64", + help="Comma-separated RPS stages to ramp through (default: 1,2,4,8,16,32,64)", + ) + parser.add_argument( + "--duration", type=float, default=10.0, + help="Duration in seconds for each RPS stage (default: 10)", + ) + parser.add_argument( + "--max-tokens", type=int, default=512, + help="Max tokens per request (default: 512)", + ) + parser.add_argument( + "--base-url", type=str, default="http://127.0.0.1:8080/v1", + help="Server base URL (default: http://127.0.0.1:8080/v1)", + ) + + args = parser.parse_args() + + modes = [m.strip().lower() for m in args.modes.split(",")] + rps_stages = [float(r.strip()) for r in args.rps.split(",")] + + if "asr" in modes and not args.wav: + parser.error("--wav is required when testing ASR mode") + + print("Stress Test Configuration:") + print(f" Server: {args.base_url}") + print(f" Modes: {modes}") + print(f" RPS stages: {rps_stages}") + print(f" Duration: {args.duration}s per stage") + print(f" Max tokens: {args.max_tokens}") + + run_stress_test( + base_url=args.base_url, + modes=modes, + wav_file=args.wav, + rps_stages=rps_stages, + duration=args.duration, + max_tokens=args.max_tokens, + ) + + +if __name__ == "__main__": + main() diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 57485c534ee..a85e08d29b6 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2170,8 +2170,9 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() void llm_graph_context::build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const { - if (!cparams.embeddings || !(dense_2 || dense_3)) { + if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) { return; } ggml_tensor * cur = res->t_embd_pooled != nullptr ? 
res->t_embd_pooled : res->t_embd; @@ -2180,6 +2181,9 @@ void llm_graph_context::build_dense_out( if (dense_2) { cur = ggml_mul_mat(ctx0, dense_2, cur); } + if (dense_2_b) { + cur = ggml_add(ctx0, cur, dense_2_b); + } if (dense_3) { cur = ggml_mul_mat(ctx0, dense_3, cur); } diff --git a/src/llama-graph.h b/src/llama-graph.h index 93d32522d15..68fde483023 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -934,6 +934,7 @@ struct llm_graph_context { void build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const; }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b58b35a4268..411b66d4833 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2287,6 +2287,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 10752: type = LLM_TYPE_2_6B; break; default: type = LLM_TYPE_UNKNOWN; } + + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + } + } } break; case LLM_ARCH_LFM2MOE: { @@ -6597,7 +6604,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -8045,7 +8053,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } 
else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_SMALLTHINKER: { @@ -8101,7 +8113,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize - llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); llm->res->set_outputs(); diff --git a/src/llama-model.h b/src/llama-model.h index d1de16e3f28..575cc1eba20 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -465,8 +465,9 @@ struct llama_model { //Dense linear projections for SentenceTransformers models like embeddinggemma // For Sentence Transformers models structure see // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models - struct ggml_tensor * dense_2_out_layers = nullptr; - struct ggml_tensor * dense_3_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers_b = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; // gguf metadata std::unordered_map gguf_kv; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 048d65a75c2..99438160532 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -860,6 +860,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; + // do not quantize conv weights + quantize &= name.find("conv.dw.weight") == std::string::npos; + quantize &= name.find("conv.pw1.weight") == std::string::npos; + quantize &= name.find("conv.pw2.weight") == std::string::npos; + quantize &= name.find("conv1d") == std::string::npos; + quantize &= name.find("conv_dw.weight") == 
std::string::npos; + // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; diff --git a/src/models/lfm2.cpp b/src/models/lfm2.cpp index 7f805d78795..b1ca853b0e6 100644 --- a/src/models/lfm2.cpp +++ b/src/models/lfm2.cpp @@ -1,18 +1,26 @@ #include "models.h" - +// +#include "../llama-memory-hybrid-iswa.h" #include "../llama-memory-hybrid.h" - -llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : +template +llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { + using inp_hybrid_type = std::conditional_t; + inp_hybrid_type * inp_hybrid = nullptr; + if constexpr (iswa) { + inp_hybrid = build_inp_mem_hybrid_iswa(); + } else { + inp_hybrid = build_inp_mem_hybrid(); + } + ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); ggml_build_forward_expand(gf, cur); ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_hybrid = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { @@ -54,29 +62,27 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_lfm2::build_moe_feed_forward(ggml_tensor * cur, int il) const { - return build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, - static_cast(hparams.expert_gating_func), il); +template ggml_tensor * llm_build_lfm2::build_moe_feed_forward(ggml_tensor * cur, int il) const { + return build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, 
LLM_FFN_SILU, true, false, 0.0, + static_cast(hparams.expert_gating_func), il); } -ggml_tensor * llm_build_lfm2::build_dense_feed_forward(ggml_tensor * cur, int il) const { +template ggml_tensor * llm_build_lfm2::build_dense_feed_forward(ggml_tensor * cur, int il) const { GGML_ASSERT(!model.layers[il].ffn_up_b); GGML_ASSERT(!model.layers[il].ffn_gate_b); GGML_ASSERT(!model.layers[il].ffn_down_b); - return build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + return build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); } -ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv * inp_attn, - int il) const { +template +ggml_tensor * llm_build_lfm2::build_attn_block( + ggml_tensor * cur, + ggml_tensor * inp_pos, + std::conditional_t * inp_attn, + int il) const { GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); const auto n_embd_head = hparams.n_embd_head_v; const auto n_head_kv = hparams.n_head_kv(il); @@ -104,17 +110,22 @@ ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor * cur, k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_attn, model.layers[il].wo, NULL, q, k, v, nullptr, nullptr, nullptr, + 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "model.layers.{}.self_attn.out_proj", il); return cur; } -ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) { - const auto * mctx_cur = static_cast(mctx)->get_recr(); +template +ggml_tensor * 
llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) { + const llama_memory_recurrent_context * mctx_cur; + if constexpr (iswa) { + mctx_cur = static_cast(mctx)->get_recr(); + } else { + mctx_cur = static_cast(mctx)->get_recr(); + } const uint32_t kv_head = mctx_cur->get_head(); const int64_t n_seq_tokens = ubatch.n_seq_tokens; const int64_t n_seqs = ubatch.n_seqs; @@ -173,3 +184,7 @@ ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph return y; } + +// Explicit template instantiations +template struct llm_build_lfm2; +template struct llm_build_lfm2; diff --git a/src/models/models.h b/src/models/models.h index 3a44f7f140f..8efc3bfb372 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -288,13 +288,14 @@ struct llm_build_jamba : public llm_graph_context_mamba { llm_build_jamba(const llama_model & model, const llm_graph_params & params); }; +template struct llm_build_lfm2 : public llm_graph_context { const llama_model & model; llm_build_lfm2(const llama_model & model, const llm_graph_params & params); ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const; ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const; - ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const; + ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, std::conditional_t * inp_attn, int il) const; ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il); }; diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 518f8b9ae74..746df5a91c3 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -37,4 +37,5 @@ else() add_subdirectory(export-lora) endif() add_subdirectory(fit-params) + add_subdirectory(liquid-audio) endif() diff --git a/tools/liquid-audio/CMakeLists.txt b/tools/liquid-audio/CMakeLists.txt new file mode 100644 index 
00000000000..cfe27235227 --- /dev/null +++ b/tools/liquid-audio/CMakeLists.txt @@ -0,0 +1,22 @@ +# lib +set(TARGET_LIB liquid-audio) +add_library(${TARGET_LIB} runner.cpp) +target_include_directories(${TARGET_LIB} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(${TARGET_LIB} PUBLIC llama common mtmd ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET_LIB} PRIVATE cxx_std_17) + +# cli +set(TARGET_CLI llama-liquid-audio-cli) +add_executable(${TARGET_CLI} cli.cpp) +target_link_libraries(${TARGET_CLI} PRIVATE ${TARGET_LIB}) +if(LLAMA_TOOLS_INSTALL) + install(TARGETS ${TARGET_CLI} RUNTIME) +endif() + +# server +set(TARGET_SERVER llama-liquid-audio-server) +add_executable(${TARGET_SERVER} server.cpp) +target_link_libraries(${TARGET_SERVER} PRIVATE ${TARGET_LIB} cpp-httplib) +if(LLAMA_TOOLS_INSTALL) + install(TARGETS ${TARGET_SERVER} RUNTIME) +endif() diff --git a/tools/liquid-audio/README.md b/tools/liquid-audio/README.md new file mode 100644 index 00000000000..9915dedd568 --- /dev/null +++ b/tools/liquid-audio/README.md @@ -0,0 +1,116 @@ +--- +license: other +license_name: lfm1.0 +license_link: LICENSE +language: +- en +tags: +- liquid +- lfm2.5 +- edge +- llama.cpp +- audio +- speech +- gguf +base_model: +- LiquidAI/LFM2.5-Audio-1.5B +widget: + - text: "Demo" + output: + url: demo.mp4 +--- + +
+ Liquid AI +
+ Try LFM • + Documentation • + LEAP +
+
+ +# LFM2.5-Audio-1.5B + +Find more details in the original model card: https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B + +## Runners + +The `runners` folder contains runners for various architectures, including: + +- llama-liquid-audio-cli +- llama-liquid-audio-server + +## Convert GGUFs + +```bash +export CKPT=/path/to/LFM2.5-Audio-1.5B +export MODEL=LFM2.5-Audio-1.5B +# backbone +python convert_hf_to_gguf.py $CKPT --outfile $CKPT/${MODEL}-F16.gguf --outtype f16 +./llama-quantize $CKPT/${MODEL}-F16.gguf $CKPT/${MODEL}-Q8_0.gguf Q8_0 +./llama-quantize $CKPT/${MODEL}-F16.gguf $CKPT/${MODEL}-Q4_0.gguf Q4_0 +# mmproj +python convert_hf_to_gguf.py $CKPT --mmproj --outfile $CKPT/mmproj-${MODEL}-F16.gguf --outtype f16 +./llama-quantize $CKPT/mmproj-${MODEL}-F16.gguf $CKPT/mmproj-${MODEL}-Q8_0.gguf Q8_0 +./llama-quantize $CKPT/mmproj-${MODEL}-F16.gguf $CKPT/mmproj-${MODEL}-Q4_0.gguf Q4_0 +# vocoder +python tools/liquid-audio/convert_vocoder_to_gguf.py $CKPT --outfile $CKPT/vocoder-${MODEL}-F16.gguf --outtype f16 +python tools/liquid-audio/convert_vocoder_to_gguf.py $CKPT --outfile $CKPT/vocoder-${MODEL}-Q8_0.gguf --outtype q8_0 +python tools/liquid-audio/convert_vocoder_to_gguf.py $CKPT --outfile $CKPT/vocoder-${MODEL}-Q4_0.gguf --outtype q4_0 +# tokenizer +python convert_hf_to_gguf.py $CKPT/audio_detokenizer --outfile $CKPT/tokenizer-${MODEL}-F16.gguf --outtype f16 +./llama-quantize $CKPT/tokenizer-${MODEL}-F16.gguf $CKPT/tokenizer-${MODEL}-Q8_0.gguf Q8_0 +./llama-quantize $CKPT/tokenizer-${MODEL}-F16.gguf $CKPT/tokenizer-${MODEL}-Q4_0.gguf Q4_0 +``` + +# 🏃 How to run LFM2.5 + +## CLI + +Set the environment variables: 
+``` +export CKPT=/path/to/LFM2.5-Audio-1.5B-GGUF +export INPUT_WAV=/path/to/input.wav +export OUTPUT_WAV=/path/to/output.wav +``` + +### ASR (audio -> text) + +```bash +./llama-liquid-audio-cli -m $CKPT/LFM2.5-Audio-1.5B-Q4_0.gguf -mm $CKPT/mmproj-LFM2.5-Audio-1.5B-Q4_0.gguf -mv $CKPT/vocoder-LFM2.5-Audio-1.5B-Q4_0.gguf --tts-speaker-file $CKPT/tokenizer-LFM2.5-Audio-1.5B-Q4_0.gguf -sys "Perform ASR." --audio $INPUT_WAV +``` + +### TTS (text -> audio) + +```bash +./llama-liquid-audio-cli -m $CKPT/LFM2.5-Audio-1.5B-Q4_0.gguf -mm $CKPT/mmproj-LFM2.5-Audio-1.5B-Q4_0.gguf -mv $CKPT/vocoder-LFM2.5-Audio-1.5B-Q4_0.gguf --tts-speaker-file $CKPT/tokenizer-LFM2.5-Audio-1.5B-Q4_0.gguf -sys "Perform TTS." -p "Hi, how are you?" --output $OUTPUT_WAV +``` + +### Interleaved (audio/text -> audio + text) + +```bash +./llama-liquid-audio-cli -m $CKPT/LFM2.5-Audio-1.5B-Q4_0.gguf -mm $CKPT/mmproj-LFM2.5-Audio-1.5B-Q4_0.gguf -mv $CKPT/vocoder-LFM2.5-Audio-1.5B-Q4_0.gguf --tts-speaker-file $CKPT/tokenizer-LFM2.5-Audio-1.5B-Q4_0.gguf -sys "Respond with interleaved text and audio." --audio $INPUT_WAV --output $OUTPUT_WAV +``` + + +## Server + +Start the server: +``` +export CKPT=/path/to/LFM2.5-Audio-1.5B-GGUF +./llama-liquid-audio-server -m $CKPT/LFM2.5-Audio-1.5B-Q4_0.gguf -mm $CKPT/mmproj-LFM2.5-Audio-1.5B-Q4_0.gguf -mv $CKPT/vocoder-LFM2.5-Audio-1.5B-Q4_0.gguf --tts-speaker-file $CKPT/tokenizer-LFM2.5-Audio-1.5B-Q4_0.gguf +``` + +Use the `liquid_audio_chat.py` script to communicate with the server. 
+ +```bash +uv run liquid_audio_chat.py +``` + +# Demo + + diff --git a/tools/liquid-audio/cli.cpp b/tools/liquid-audio/cli.cpp new file mode 100644 index 00000000000..2bdbbc19514 --- /dev/null +++ b/tools/liquid-audio/cli.cpp @@ -0,0 +1,191 @@ +#include "mtmd-helper.h" +#include "mtmd.h" +#include "runner.h" + +// +#include "arg.h" +#include "common.h" +#include "ggml.h" +#include "log.h" + +#include + +namespace { +std::vector load_file(const char * fname) { + std::vector buf; + FILE * f = fopen(fname, "rb"); + if (!f) { + LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno)); + exit(1); + } + + fseek(f, 0, SEEK_END); + long file_size = ftell(f); + fseek(f, 0, SEEK_SET); + buf.resize(file_size); + + size_t n_read = fread(buf.data(), 1, file_size, f); + fclose(f); + if (n_read != (size_t) file_size) { + LOG_ERR("Failed to read entire file %s", fname); + exit(1); + } + + return buf; +} +} // namespace + +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) +# include +# include +#elif defined(_WIN32) +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#endif + +static void show_additional_info(int /*argc*/, char ** argv) { + LOG("CLI for LFM2.5-Audio-1.5B\n\n" + "Usage: %s [options] -m --mmproj -mv --tts-speaker-file " + " " + "-sys [--audio " + "