diff --git a/packages/sdk/bun.lock b/packages/sdk/bun.lock index ef3dd9e1dc..596513074b 100644 --- a/packages/sdk/bun.lock +++ b/packages/sdk/bun.lock @@ -18,7 +18,7 @@ "@qvac/transcription-parakeet": "^0.3.1", "@qvac/transcription-whispercpp": "^0.6.1", "@qvac/translation-nmtcpp": "^1.0.1", - "@qvac/tts-onnx": "^0.8.2", + "@qvac/tts-onnx": "^0.8.3", "fast-safe-stringify": "2.1.1", "which-runtime": "^1.3.2", "zod": "^4.0.17", @@ -546,7 +546,7 @@ "@qvac/translation-nmtcpp": ["@qvac/translation-nmtcpp@1.0.1", "", { "dependencies": { "@qvac/dl-hyperdrive": "^0.1.0", "@qvac/error": "^0.1.0", "@qvac/infer-base": "^0.2.0", "bare-path": "^3.0.0" } }, "sha512-KMF/8A2b7SxSMTsHxSZWH/YTFDyW1LSUEKs/x7f4P/zZTrqJYzZqOkEvxe9IS27kiz7dRAh7i3Q1hyvCnLdzIg=="], - "@qvac/tts-onnx": ["@qvac/tts-onnx@0.8.2", "", { "dependencies": { "@qvac/error": "^0.1.0", "@qvac/infer-base": "^0.4.0", "@qvac/onnx": "^0.14.0", "bare-fs": "^4.5.1", "bare-path": "^3.0.0" } }, "sha512-/x6T4aGrHnBpndBZZN9W9FHQmjb+X7uXp1AQv+mi4dJai4XNxhahh7tK8OCS5Z3HdB/Iv/6YZ2PL0wnVf+OgVQ=="], + "@qvac/tts-onnx": ["@qvac/tts-onnx@0.8.3", "", { "dependencies": { "@qvac/error": "^0.1.0", "@qvac/infer-base": "^0.4.0", "@qvac/onnx": "^0.14.0", "bare-fs": "^4.5.6", "bare-os": "^3.8.0", "bare-path": "^3.0.0" } }, "sha512-l7EUfE7zNr2cDn2elieF39RKqJsckxs2xN2inemb/YEJuiPGdY4Q2wOZknNfHgFXF4aC//nmKgRy5Q8nfhk8rA=="], "@react-native/assets-registry": ["@react-native/assets-registry@0.84.1", "", {}, "sha512-lAJ6PDZv95FdT9s9uhc9ivhikW1Zwh4j9XdXM7J2l4oUA3t37qfoBmTSDLuPyE3Bi+Xtwa11hJm0BUTT2sc/gg=="], diff --git a/packages/sdk/examples/tts/chatterbox-enhanced.ts b/packages/sdk/examples/tts/chatterbox-enhanced.ts new file mode 100644 index 0000000000..8db1717328 --- /dev/null +++ b/packages/sdk/examples/tts/chatterbox-enhanced.ts @@ -0,0 +1,132 @@ +import { + loadModel, + textToSpeech, + unloadModel, + type ModelProgressUpdate, + TTS_TOKENIZER_EN_CHATTERBOX, + TTS_SPEECH_ENCODER_EN_CHATTERBOX_FP32, + TTS_EMBED_TOKENS_EN_CHATTERBOX_FP32, + TTS_CONDITIONAL_DECODER_EN_CHATTERBOX_FP32, + TTS_LANGUAGE_MODEL_EN_CHATTERBOX_FP32, + TTS_ENHANCER_BACKBONE_LAVASR_FP32, + TTS_ENHANCER_SPEC_HEAD_LAVASR_FP32, + TTS_DENOISER_LAVASR_FP32, +} from "@qvac/sdk"; +import { + createWav, + playAudio, + int16ArrayToBuffer, + createWavHeader, +} from "./utils"; + +// A/B comparison: Chatterbox TTS with and without LavaSR neural speech enhancement. +// Produces two WAV files so you can hear the difference. +// Usage: node chatterbox-enhanced.js +// LavaSR models are loaded from the QVAC Registry automatically. +const [referenceAudioSrc] = process.argv.slice(2); + +if (!referenceAudioSrc) { + console.error("Usage: node chatterbox-enhanced.js "); + process.exit(1); +} + +const CHATTERBOX_SAMPLE_RATE = 24000; +const ENHANCED_SAMPLE_RATE = 48000; +const SYNTHESIS_TEXT = + "Hello! This sentence is synthesized twice, once at standard quality and once with LavaSR neural enhancement, so you can hear the difference."; + +const chatterboxConfig = { + ttsEngine: "chatterbox" as const, + language: "en" as const, + ttsTokenizerSrc: TTS_TOKENIZER_EN_CHATTERBOX.src, + ttsSpeechEncoderSrc: TTS_SPEECH_ENCODER_EN_CHATTERBOX_FP32.src, + ttsEmbedTokensSrc: TTS_EMBED_TOKENS_EN_CHATTERBOX_FP32.src, + ttsConditionalDecoderSrc: TTS_CONDITIONAL_DECODER_EN_CHATTERBOX_FP32.src, + ttsLanguageModelSrc: TTS_LANGUAGE_MODEL_EN_CHATTERBOX_FP32.src, + referenceAudioSrc, +}; + +function onProgress(progress: ModelProgressUpdate) { + console.log(progress); +} + +function saveAndPlay(samples: number[], sampleRate: number, filename: string) { + createWav(samples, sampleRate, filename); + console.log(`Saved ${filename}`); + const audioData = int16ArrayToBuffer(samples); + const wavBuffer = Buffer.concat([ + createWavHeader(audioData.length, sampleRate), + audioData, + ]); + playAudio(wavBuffer); +} + +try { + // --- Pass 1: Raw Chatterbox (no enhancer) --- + console.log("\n--- Pass 1: Raw Chatterbox (24 kHz) ---\n"); + + const rawModelId = await loadModel({ + modelSrc: TTS_TOKENIZER_EN_CHATTERBOX.src, + modelType: "tts", + modelConfig: chatterboxConfig, + onProgress, + }); + + const rawResult = textToSpeech({ + modelId: rawModelId, + text: SYNTHESIS_TEXT, + inputType: "text", + stream: false, + }); + + const rawBuffer = await rawResult.buffer; + console.log(`Raw TTS complete. ${rawBuffer.length} samples @ ${CHATTERBOX_SAMPLE_RATE} Hz`); + saveAndPlay(rawBuffer, CHATTERBOX_SAMPLE_RATE, "tts-raw-output.wav"); + + await unloadModel({ modelId: rawModelId }); + console.log("Raw model unloaded.\n"); + + // --- Pass 2: Chatterbox + LavaSR enhancement --- + console.log("--- Pass 2: Chatterbox + LavaSR enhancement (48 kHz) ---\n"); + + const enhancedModelId = await loadModel({ + modelSrc: TTS_TOKENIZER_EN_CHATTERBOX.src, + modelType: "tts", + modelConfig: { + ...chatterboxConfig, + enhancer: { + type: "lavasr", + enhance: true, + denoise: true, + backboneSrc: TTS_ENHANCER_BACKBONE_LAVASR_FP32.src, + specHeadSrc: TTS_ENHANCER_SPEC_HEAD_LAVASR_FP32.src, + denoiserSrc: TTS_DENOISER_LAVASR_FP32.src, + }, + }, + onProgress, + }); + + const enhancedResult = textToSpeech({ + modelId: enhancedModelId, + text: SYNTHESIS_TEXT, + inputType: "text", + stream: false, + }); + + const enhancedBuffer = await enhancedResult.buffer; + console.log(`Enhanced TTS complete. ${enhancedBuffer.length} samples @ ${ENHANCED_SAMPLE_RATE} Hz`); + saveAndPlay(enhancedBuffer, ENHANCED_SAMPLE_RATE, "tts-enhanced-output.wav"); + + await unloadModel({ modelId: enhancedModelId }); + console.log("Enhanced model unloaded."); + + console.log("\n--- Done ---"); + console.log("Compare the two files:"); + console.log(" tts-raw-output.wav -- 24 kHz standard Chatterbox"); + console.log(" tts-enhanced-output.wav -- 48 kHz with LavaSR enhancement"); + + process.exit(0); +} catch (error) { + console.error("❌ Error:", error); + process.exit(1); +} diff --git a/packages/sdk/package.json b/packages/sdk/package.json index e33345ec97..e406687e6f 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -186,7 +186,7 @@ "@qvac/transcription-parakeet": "^0.3.1", "@qvac/transcription-whispercpp": "^0.6.1", "@qvac/translation-nmtcpp": "^1.0.1", - "@qvac/tts-onnx": "^0.8.2", + "@qvac/tts-onnx": "^0.8.3", "fast-safe-stringify": "2.1.1", "which-runtime": "^1.3.2", "zod": "^4.0.17" diff --git a/packages/sdk/schemas/text-to-speech.ts b/packages/sdk/schemas/text-to-speech.ts index cf7ed6b7b1..b553092059 100644 --- a/packages/sdk/schemas/text-to-speech.ts +++ b/packages/sdk/schemas/text-to-speech.ts @@ -11,9 +11,33 @@ export const TTS_LANGUAGES = [ const ttsLanguageSchema = z.enum(TTS_LANGUAGES); +const lavaSREnhancerRuntimeSchema = z.object({ + type: z.literal("lavasr"), + enhance: z.boolean().optional(), + denoise: z.boolean().optional(), +}); + +const ttsEnhancerRuntimeConfigSchema = z.discriminatedUnion("type", [ + lavaSREnhancerRuntimeSchema, +]); + +export const lavaSREnhancerConfigSchema = lavaSREnhancerRuntimeSchema.extend({ + backboneSrc: modelSrcInputSchema, + specHeadSrc: modelSrcInputSchema, + denoiserSrc: modelSrcInputSchema.optional(), +}); + +export const ttsEnhancerConfigSchema = z + .discriminatedUnion("type", [lavaSREnhancerConfigSchema]) + .refine( + (data) => data.type !== "lavasr" || !data.denoise || data.denoiserSrc !== undefined, + { message: "denoiserSrc is required when denoise is true", path: ["denoiserSrc"] }, + ); + export const ttsChatterboxRuntimeConfigSchema = z.object({ ttsEngine: z.literal("chatterbox"), language: ttsLanguageSchema, + enhancer: ttsEnhancerRuntimeConfigSchema.optional(), }); export const ttsSupertonicRuntimeConfigSchema = z.object({ @@ -22,6 +46,7 @@ export const ttsSupertonicRuntimeConfigSchema = z.object({ ttsSpeed: z.number().optional(), ttsNumInferenceSteps: z.number().optional(), ttsSupertonicMultilingual: z.boolean().optional(), + enhancer: ttsEnhancerRuntimeConfigSchema.optional(), }); export const ttsRuntimeConfigSchema = z.union([ @@ -36,6 +61,7 @@ export const ttsChatterboxConfigSchema = ttsChatterboxRuntimeConfigSchema.extend ttsConditionalDecoderSrc: modelSrcInputSchema, ttsLanguageModelSrc: modelSrcInputSchema, referenceAudioSrc: modelSrcInputSchema, + enhancer: ttsEnhancerConfigSchema.optional(), }); export const ttsSupertonicConfigSchema = ttsSupertonicRuntimeConfigSchema.extend({ @@ -46,6 +72,7 @@ export const ttsSupertonicConfigSchema = ttsSupertonicRuntimeConfigSchema.extend ttsUnicodeIndexerSrc: modelSrcInputSchema, ttsTtsConfigSrc: modelSrcInputSchema, ttsVoiceStyleSrc: modelSrcInputSchema, + enhancer: ttsEnhancerConfigSchema.optional(), }); export const ttsConfigSchema = z.union([ @@ -87,6 +114,9 @@ export type TtsSupertonicRuntimeConfig = z.infer< typeof ttsSupertonicRuntimeConfigSchema >; export type TtsRuntimeConfig = z.infer; +export type TtsEnhancerRuntimeConfig = z.infer; +export type TtsEnhancerConfig = z.infer; +export type LavaSREnhancerConfig = z.infer; export type TtsClientParams = z.infer; export type TtsRequest = z.infer; export type TtsResponse = z.infer; diff --git a/packages/sdk/server/bare/plugins/onnx-tts/plugin.ts b/packages/sdk/server/bare/plugins/onnx-tts/plugin.ts index d7173fcdd8..9994a7ac74 100644 --- a/packages/sdk/server/bare/plugins/onnx-tts/plugin.ts +++ b/packages/sdk/server/bare/plugins/onnx-tts/plugin.ts @@ -17,6 +17,8 @@ import { type TtsChatterboxRuntimeConfig, type TtsSupertonicRuntimeConfig, type TtsRuntimeConfig, + type TtsEnhancerConfig, + type TtsEnhancerRuntimeConfig, } from "@/schemas"; import { createStreamLogger, registerAddonLogger } from "@/logging"; import { @@ -27,6 +29,69 @@ import { textToSpeech } from "@/server/bare/plugins/onnx-tts/ops/text-to-speech" import { attachModelExecutionMs } from "@/profiling/model-execution"; import { loadReferenceAudioAt24k } from "@/server/bare/plugins/onnx-tts/wav-helper"; +async function resolveEnhancerArtifacts( + enhancer: TtsEnhancerConfig | undefined, + resolve: ResolveContext["resolveModelPath"], +) { + if (!enhancer) return {}; + + switch (enhancer.type) { + case "lavasr": { + const [enhancerBackbonePath, enhancerSpecHeadPath, denoiserPath] = await Promise.all([ + resolve(enhancer.backboneSrc), + resolve(enhancer.specHeadSrc), + enhancer.denoiserSrc ? resolve(enhancer.denoiserSrc) : undefined, + ]); + return { + enhancerBackbonePath, + enhancerSpecHeadPath, + ...(denoiserPath && { denoiserPath }), + }; + } + default: + throw new Error(`Unknown enhancer type: ${(enhancer as { type: string }).type}`); + } +} + +function buildRuntimeEnhancer(enhancer: TtsEnhancerConfig | undefined) { + if (!enhancer) return undefined; + switch (enhancer.type) { + case "lavasr": + return { + type: "lavasr" as const, + enhance: enhancer.enhance ?? false, + denoise: enhancer.denoise ?? false, + }; + default: + throw new Error(`Unknown enhancer type: ${(enhancer as { type: string }).type}`); + } +} + +function buildEnhancerArg( + enhancer: TtsEnhancerRuntimeConfig | undefined, + artifacts: Record, +) { + if (!enhancer) return undefined; + + switch (enhancer.type) { + case "lavasr": { + const backbonePath = artifacts["enhancerBackbonePath"]; + const specHeadPath = artifacts["enhancerSpecHeadPath"]; + if (!backbonePath || !specHeadPath) return undefined; + return { + type: "lavasr" as const, + ...(enhancer.enhance !== undefined && { enhance: enhancer.enhance }), + ...(enhancer.denoise !== undefined && { denoise: enhancer.denoise }), + backbonePath, + specHeadPath, + ...(artifacts["denoiserPath"] && { denoiserPath: artifacts["denoiserPath"] }), + }; + } + default: + throw new Error(`Unknown enhancer type: ${(enhancer as { type: string }).type}`); + } +} + async function resolveChatterboxConfig( config: TtsChatterboxConfig, ctx: ResolveContext, @@ -39,6 +104,7 @@ async function resolveChatterboxConfig( ttsLanguageModelSrc, referenceAudioSrc, language, + enhancer, } = config; if ( @@ -56,25 +122,34 @@ async function resolveChatterboxConfig( const resolve = ctx.resolveModelPath; const [ - tokenizerPath, - speechEncoderPath, - embedTokensPath, - conditionalDecoderPath, - languageModelPath, - referenceAudioPath, + [ + tokenizerPath, + speechEncoderPath, + embedTokensPath, + conditionalDecoderPath, + languageModelPath, + referenceAudioPath, + ], + enhancerArtifacts, ] = await Promise.all([ - resolve(ttsTokenizerSrc), - resolve(ttsSpeechEncoderSrc), - resolve(ttsEmbedTokensSrc), - resolve(ttsConditionalDecoderSrc), - resolve(ttsLanguageModelSrc), - resolve(referenceAudioSrc), + Promise.all([ + resolve(ttsTokenizerSrc), + resolve(ttsSpeechEncoderSrc), + resolve(ttsEmbedTokensSrc), + resolve(ttsConditionalDecoderSrc), + resolve(ttsLanguageModelSrc), + resolve(referenceAudioSrc), + ]), + resolveEnhancerArtifacts(enhancer, resolve), ]); + const runtimeEnhancer = buildRuntimeEnhancer(enhancer); + return { config: { ttsEngine: "chatterbox", language, + ...(runtimeEnhancer && { enhancer: runtimeEnhancer }), } as TtsChatterboxRuntimeConfig, artifacts: { tokenizerPath, @@ -83,6 +158,7 @@ async function resolveChatterboxConfig( conditionalDecoderPath, languageModelPath, referenceAudioPath, + ...enhancerArtifacts, }, }; } @@ -103,6 +179,7 @@ async function resolveSupertonicConfig( ttsNumInferenceSteps, ttsSupertonicMultilingual, language, + enhancer, } = config; if ( @@ -119,23 +196,31 @@ async function resolveSupertonicConfig( const resolve = ctx.resolveModelPath; const [ - textEncoderPath, - durationPredictorPath, - vectorEstimatorPath, - vocoderPath, - unicodeIndexerPath, - ttsConfigPath, - voiceStylePath, + [ + textEncoderPath, + durationPredictorPath, + vectorEstimatorPath, + vocoderPath, + unicodeIndexerPath, + ttsConfigPath, + voiceStylePath, + ], + enhancerArtifacts, ] = await Promise.all([ - resolve(ttsTextEncoderSrc), - resolve(ttsDurationPredictorSrc), - resolve(ttsVectorEstimatorSrc), - resolve(ttsVocoderSrc), - resolve(ttsUnicodeIndexerSrc), - resolve(ttsTtsConfigSrc), - resolve(ttsVoiceStyleSrc), + Promise.all([ + resolve(ttsTextEncoderSrc), + resolve(ttsDurationPredictorSrc), + resolve(ttsVectorEstimatorSrc), + resolve(ttsVocoderSrc), + resolve(ttsUnicodeIndexerSrc), + resolve(ttsTtsConfigSrc), + resolve(ttsVoiceStyleSrc), + ]), + resolveEnhancerArtifacts(enhancer, resolve), ]); + const runtimeEnhancer = buildRuntimeEnhancer(enhancer); + return { config: { ttsEngine: "supertonic", @@ -143,6 +228,7 @@ async function resolveSupertonicConfig( ttsSpeed, ttsNumInferenceSteps, ttsSupertonicMultilingual, + ...(runtimeEnhancer && { enhancer: runtimeEnhancer }), } as TtsSupertonicRuntimeConfig, artifacts: { textEncoderPath, @@ -152,6 +238,7 @@ async function resolveSupertonicConfig( unicodeIndexerPath, ttsConfigPath, voiceStylePath, + ...enhancerArtifacts, }, }; } @@ -184,6 +271,7 @@ function createChatterboxModel( const logger = createStreamLogger(modelId, ModelType.onnxTts); registerAddonLogger(modelId, ModelType.onnxTts, logger); const referenceAudio = loadReferenceAudioAt24k(referenceAudioPath); + const enhancerArg = buildEnhancerArg(config.enhancer, artifacts); const model = new ONNXTTS({ files: { tokenizerPath, @@ -198,6 +286,7 @@ function createChatterboxModel( logger, opts: { stats: true }, exclusiveRun: true, + ...(enhancerArg && { enhancer: enhancerArg }), } as never); return { model, loader: undefined }; } @@ -230,6 +319,7 @@ function createSupertonicModel( const logger = createStreamLogger(modelId, ModelType.onnxTts); registerAddonLogger(modelId, ModelType.onnxTts, logger); const voiceName = path.basename(voiceStylePath).replace(/\.json$/i, "") || "F1"; + const enhancerArg = buildEnhancerArg(config.enhancer, artifacts); const model = new ONNXTTS({ files: { textEncoderPath, @@ -249,6 +339,7 @@ function createSupertonicModel( logger, opts: { stats: true }, exclusiveRun: true, + ...(enhancerArg && { enhancer: enhancerArg }), } as never); return { model, loader: undefined }; } diff --git a/packages/sdk/test/unit/tts-schemas.test.ts b/packages/sdk/test/unit/tts-schemas.test.ts new file mode 100644 index 0000000000..785930b3d0 --- /dev/null +++ b/packages/sdk/test/unit/tts-schemas.test.ts @@ -0,0 +1,307 @@ +// @ts-expect-error brittle has no type declarations +import test from "brittle"; +import { + lavaSREnhancerConfigSchema, + ttsEnhancerConfigSchema, + ttsChatterboxRuntimeConfigSchema, + ttsSupertonicRuntimeConfigSchema, + ttsChatterboxConfigSchema, + ttsSupertonicConfigSchema, +} from "@/schemas/text-to-speech"; + +// --- lavaSREnhancerConfigSchema (load-time) --- + +test("lavaSREnhancerConfigSchema: accepts valid config with required model sources", (t) => { + const result = lavaSREnhancerConfigSchema.safeParse({ + type: "lavasr", + enhance: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }); + t.is(result.success, true); +}); + +test("lavaSREnhancerConfigSchema: accepts config with optional denoiserSrc", (t) => { + const result = lavaSREnhancerConfigSchema.safeParse({ + type: "lavasr", + enhance: true, + denoise: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + denoiserSrc: "denoiser.onnx", + }); + t.is(result.success, true); +}); + +test("lavaSREnhancerConfigSchema: rejects missing backboneSrc", (t) => { + const result = lavaSREnhancerConfigSchema.safeParse({ + type: "lavasr", + enhance: true, + specHeadSrc: "spechead.onnx", + }); + t.is(result.success, false); +}); + +test("lavaSREnhancerConfigSchema: rejects missing specHeadSrc", (t) => { + const result = lavaSREnhancerConfigSchema.safeParse({ + type: "lavasr", + enhance: true, + backboneSrc: "backbone.onnx", + }); + t.is(result.success, false); +}); + +test("lavaSREnhancerConfigSchema: rejects wrong type discriminator", (t) => { + const result = lavaSREnhancerConfigSchema.safeParse({ + type: "unknown", + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }); + t.is(result.success, false); +}); + +// --- ttsEnhancerConfigSchema (discriminated union) --- + +test("ttsEnhancerConfigSchema: accepts valid lavasr config", (t) => { + const result = ttsEnhancerConfigSchema.safeParse({ + type: "lavasr", + enhance: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }); + t.is(result.success, true); +}); + +test("ttsEnhancerConfigSchema: rejects unknown enhancer type", (t) => { + const result = ttsEnhancerConfigSchema.safeParse({ + type: "unknown-enhancer", + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }); + t.is(result.success, false); +}); + +// --- Runtime config schemas (enhancer without model sources) --- + +test("ttsChatterboxRuntimeConfigSchema: accepts config with runtime enhancer", (t) => { + const result = ttsChatterboxRuntimeConfigSchema.safeParse({ + ttsEngine: "chatterbox", + language: "en", + enhancer: { type: "lavasr", enhance: true, denoise: false }, + }); + t.is(result.success, true); +}); + +test("ttsChatterboxRuntimeConfigSchema: accepts config without enhancer", (t) => { + const result = ttsChatterboxRuntimeConfigSchema.safeParse({ + ttsEngine: "chatterbox", + language: "en", + }); + t.is(result.success, true); +}); + +test("ttsSupertonicRuntimeConfigSchema: accepts config with runtime enhancer", (t) => { + const result = ttsSupertonicRuntimeConfigSchema.safeParse({ + ttsEngine: "supertonic", + language: "en", + enhancer: { type: "lavasr", enhance: true }, + }); + t.is(result.success, true); +}); + +// --- Load-time config: enhancer overrides runtime schema --- + +test("ttsChatterboxConfigSchema: requires model sources in enhancer at load time", (t) => { + const withSources = ttsChatterboxConfigSchema.safeParse({ + ttsEngine: "chatterbox", + language: "en", + ttsTokenizerSrc: "tok.bin", + ttsSpeechEncoderSrc: "enc.onnx", + ttsEmbedTokensSrc: "emb.onnx", + ttsConditionalDecoderSrc: "dec.onnx", + ttsLanguageModelSrc: "lm.onnx", + referenceAudioSrc: "ref.wav", + enhancer: { + type: "lavasr", + enhance: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }, + }); + t.is(withSources.success, true); + + const withoutSources = ttsChatterboxConfigSchema.safeParse({ + ttsEngine: "chatterbox", + language: "en", + ttsTokenizerSrc: "tok.bin", + ttsSpeechEncoderSrc: "enc.onnx", + ttsEmbedTokensSrc: "emb.onnx", + ttsConditionalDecoderSrc: "dec.onnx", + ttsLanguageModelSrc: "lm.onnx", + referenceAudioSrc: "ref.wav", + enhancer: { + type: "lavasr", + enhance: true, + }, + }); + t.is(withoutSources.success, false); +}); + +test("ttsSupertonicConfigSchema: requires model sources in enhancer at load time", (t) => { + const withSources = ttsSupertonicConfigSchema.safeParse({ + ttsEngine: "supertonic", + language: "en", + ttsTextEncoderSrc: "enc.onnx", + ttsDurationPredictorSrc: "dp.onnx", + ttsVectorEstimatorSrc: "ve.onnx", + ttsVocoderSrc: "voc.onnx", + ttsUnicodeIndexerSrc: "ui.json", + ttsTtsConfigSrc: "tts.json", + ttsVoiceStyleSrc: "voice.json", + enhancer: { + type: "lavasr", + enhance: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }, + }); + t.is(withSources.success, true); +}); + +test("ttsChatterboxConfigSchema: accepts config without enhancer", (t) => { + const result = ttsChatterboxConfigSchema.safeParse({ + ttsEngine: "chatterbox", + language: "en", + ttsTokenizerSrc: "tok.bin", + ttsSpeechEncoderSrc: "enc.onnx", + ttsEmbedTokensSrc: "emb.onnx", + ttsConditionalDecoderSrc: "dec.onnx", + ttsLanguageModelSrc: "lm.onnx", + referenceAudioSrc: "ref.wav", + }); + t.is(result.success, true); +}); + +test("ttsSupertonicConfigSchema: accepts config without enhancer", (t) => { + const result = ttsSupertonicConfigSchema.safeParse({ + ttsEngine: "supertonic", + language: "en", + ttsTextEncoderSrc: "enc.onnx", + ttsDurationPredictorSrc: "dp.onnx", + ttsVectorEstimatorSrc: "ve.onnx", + ttsVocoderSrc: "voc.onnx", + ttsUnicodeIndexerSrc: "ui.json", + ttsTtsConfigSrc: "tts.json", + ttsVoiceStyleSrc: "voice.json", + }); + t.is(result.success, true); +}); + +// --- Runtime schemas reject load-time model source fields --- + +test("ttsChatterboxRuntimeConfigSchema: strips load-time model source fields from enhancer", (t) => { + const result = ttsChatterboxRuntimeConfigSchema.safeParse({ + ttsEngine: "chatterbox", + language: "en", + enhancer: { + type: "lavasr", + enhance: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }, + }); + t.is(result.success, true); + const keys = Object.keys(result.data?.enhancer ?? {}); + t.ok(!keys.includes("backboneSrc"), "backboneSrc should be stripped from runtime config"); + t.ok(!keys.includes("specHeadSrc"), "specHeadSrc should be stripped from runtime config"); +}); + +test("ttsSupertonicRuntimeConfigSchema: strips load-time model source fields from enhancer", (t) => { + const result = ttsSupertonicRuntimeConfigSchema.safeParse({ + ttsEngine: "supertonic", + language: "en", + enhancer: { + type: "lavasr", + enhance: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }, + }); + t.is(result.success, true); + const keys = Object.keys(result.data?.enhancer ?? {}); + t.ok(!keys.includes("backboneSrc"), "backboneSrc should be stripped from runtime config"); + t.ok(!keys.includes("specHeadSrc"), "specHeadSrc should be stripped from runtime config"); +}); + +// --- ttsEnhancerConfigSchema: denoise/denoiserSrc refinement --- + +test("ttsEnhancerConfigSchema: rejects denoise true without denoiserSrc", (t) => { + const result = ttsEnhancerConfigSchema.safeParse({ + type: "lavasr", + enhance: true, + denoise: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }); + t.is(result.success, false); +}); + +test("ttsEnhancerConfigSchema: accepts denoise true with denoiserSrc", (t) => { + const result = ttsEnhancerConfigSchema.safeParse({ + type: "lavasr", + enhance: true, + denoise: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + denoiserSrc: "denoiser.onnx", + }); + t.is(result.success, true); +}); + +test("ttsEnhancerConfigSchema: accepts denoise false without denoiserSrc", (t) => { + const result = ttsEnhancerConfigSchema.safeParse({ + type: "lavasr", + enhance: true, + denoise: false, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }); + t.is(result.success, true); +}); + +test("ttsEnhancerConfigSchema: rejects empty object", (t) => { + const result = ttsEnhancerConfigSchema.safeParse({}); + t.is(result.success, false); +}); + +test("ttsEnhancerConfigSchema: rejects object without type discriminator", (t) => { + const result = ttsEnhancerConfigSchema.safeParse({ + enhance: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }); + t.is(result.success, false); +}); + +// --- Load-time config: denoise refinement propagates through parent schemas --- + +test("ttsChatterboxConfigSchema: rejects enhancer with denoise true but no denoiserSrc", (t) => { + const result = ttsChatterboxConfigSchema.safeParse({ + ttsEngine: "chatterbox", + language: "en", + ttsTokenizerSrc: "tok.bin", + ttsSpeechEncoderSrc: "enc.onnx", + ttsEmbedTokensSrc: "emb.onnx", + ttsConditionalDecoderSrc: "dec.onnx", + ttsLanguageModelSrc: "lm.onnx", + referenceAudioSrc: "ref.wav", + enhancer: { + type: "lavasr", + enhance: true, + denoise: true, + backboneSrc: "backbone.onnx", + specHeadSrc: "spechead.onnx", + }, + }); + t.is(result.success, false); +});