diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 19c8a3f7..4bd9c426 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -47,7 +47,7 @@
 from contextlib import asynccontextmanager
 from typing import AsyncIterator, Optional
 
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, UploadFile
 from fastapi.responses import StreamingResponse
 
 # Import from new modular API
@@ -173,7 +173,10 @@ def load_model(
         logger.info(f"Loading model with SimpleEngine: {model_name}")
         _engine = SimpleEngine(model_name=model_name)
         # Start SimpleEngine synchronously (no background loop)
-        asyncio.get_event_loop().run_until_complete(_engine.start())
+        # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated)
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(_engine.start())
         model_type = "MLLM" if _engine.is_mllm else "LLM"
         logger.info(f"{model_type} model loaded (simple mode): {model_name}")
 
@@ -278,6 +281,235 @@ async def execute_mcp_tool(request: MCPExecuteRequest) -> MCPExecuteResponse:
     )
 
 
+# =============================================================================
+# Audio Endpoints
+# =============================================================================
+
+# Global audio engines (lazy loaded; replaced when another model is requested)
+_stt_engine = None
+_tts_engine = None
+
+# Short STT model aliases mapped to full model names. Names not listed here
+# are passed through to the engine unchanged.
+_STT_MODEL_ALIASES = {
+    "whisper-large-v3": "mlx-community/whisper-large-v3-mlx",
+    "whisper-large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
+    "whisper-medium": "mlx-community/whisper-medium-mlx",
+    "whisper-small": "mlx-community/whisper-small-mlx",
+    "parakeet": "mlx-community/parakeet-tdt-0.6b-v2",
+    "parakeet-tdt-0.6b-v2": "mlx-community/parakeet-tdt-0.6b-v2",
+    "parakeet-v3": "mlx-community/parakeet-tdt-0.6b-v3",
+    "parakeet-tdt-0.6b-v3": "mlx-community/parakeet-tdt-0.6b-v3",
+}
+
+# Short TTS model aliases mapped to full model names (same pass-through rule).
+_TTS_MODEL_ALIASES = {
+    "kokoro": "mlx-community/Kokoro-82M-bf16",
+    "kokoro-4bit": "mlx-community/Kokoro-82M-4bit",
+    "chatterbox": "mlx-community/chatterbox-turbo-fp16",
+    "chatterbox-4bit": "mlx-community/chatterbox-turbo-4bit",
+    "vibevoice": "mlx-community/VibeVoice-Realtime-0.5B-4bit",
+    "voxcpm": "mlx-community/VoxCPM1.5",
+}
+
+# Formats whose MIME type is not simply "audio/" followed by the format name.
+_AUDIO_MEDIA_TYPES = {"mp3": "audio/mpeg"}
+
+_AUDIO_DEPS_MISSING = "mlx-audio not installed. Install with: pip install mlx-audio"
+
+
+@app.post("/v1/audio/transcriptions")
+async def create_transcription(
+    file: UploadFile,
+    model: str = "whisper-large-v3",
+    language: Optional[str] = None,
+    response_format: str = "json",
+):
+    """
+    Transcribe audio to text (OpenAI Whisper API compatible).
+
+    Supported models:
+    - whisper-large-v3 (multilingual, best quality)
+    - whisper-large-v3-turbo (faster)
+    - whisper-medium, whisper-small (lighter)
+    - parakeet, parakeet-tdt-0.6b-v2 (English, fastest)
+    - parakeet-v3, parakeet-tdt-0.6b-v3
+
+    Any other value is passed to the STT engine as a full model name.
+
+    Args:
+        file: Uploaded audio file.
+        model: Model alias or full model name.
+        language: Optional language value forwarded to the engine.
+        response_format: "text" returns the transcript as plain text; any
+            other value returns JSON with text, language and duration.
+
+    Raises:
+        HTTPException: 503 if an audio dependency cannot be imported,
+            500 if loading the model or transcribing fails.
+
+    NOTE(review): model, language and response_format are plain defaults,
+    so FastAPI reads them from the query string, while OpenAI clients send
+    them as multipart form fields. Confirm whether they should use Form().
+    """
+    global _stt_engine
+    import os
+    import tempfile
+
+    from fastapi.responses import PlainTextResponse
+
+    model_name = _STT_MODEL_ALIASES.get(model, model)
+    # Keep the upload's extension for the temp file; fall back to ".wav"
+    # when the client sent no usable filename.
+    suffix = os.path.splitext(file.filename or "")[1] or ".wav"
+    tmp_path = None
+
+    try:
+        from .audio.stt import STTEngine
+
+        # Load engine if needed. The global is updated only after load()
+        # succeeds, so a failed load is retried on the next request.
+        engine = _stt_engine
+        if engine is None or engine.model_name != model_name:
+            engine = STTEngine(model_name)
+            engine.load()
+            _stt_engine = engine
+
+        # Save uploaded file temporarily (transcribe() is given a path)
+        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
+            tmp_path = tmp.name
+            tmp.write(await file.read())
+
+        # Use the local reference: the global may have been swapped by
+        # another request while this one awaited the upload.
+        result = engine.transcribe(tmp_path, language=language)
+
+        if response_format == "text":
+            # A bare str return would be JSON-encoded (quoted) by FastAPI.
+            return PlainTextResponse(result.text)
+
+        return {
+            "text": result.text,
+            "language": result.language,
+            "duration": result.duration,
+        }
+
+    except ImportError as e:
+        raise HTTPException(status_code=503, detail=_AUDIO_DEPS_MISSING) from e
+    except Exception as e:
+        logger.exception("Transcription failed: %s", e)
+        raise HTTPException(status_code=500, detail=str(e)) from e
+    finally:
+        # Covers every exit path, including a failed upload read/write.
+        if tmp_path is not None:
+            try:
+                os.unlink(tmp_path)
+            except OSError:
+                logger.warning("Could not remove temp file: %s", tmp_path)
+
+
+@app.post("/v1/audio/speech")
+async def create_speech(
+    model: str = "kokoro",
+    input: str = "",
+    voice: str = "af_heart",
+    speed: float = 1.0,
+    response_format: str = "wav",
+):
+    """
+    Generate speech from text (OpenAI TTS API compatible).
+
+    Supported models:
+    - kokoro (fast, lightweight)
+    - chatterbox (multilingual, expressive)
+    - vibevoice (realtime)
+    - voxcpm (Chinese/English)
+
+    Any other value is passed to the TTS engine as a full model name.
+
+    Args:
+        model: Model alias or full model name.
+        input: Text to synthesize; must not be empty or whitespace.
+        voice: Voice name forwarded to the engine.
+        speed: Speed value forwarded to the engine; must be positive.
+        response_format: Audio format forwarded to the engine.
+
+    Returns:
+        Response whose body is the encoded audio.
+
+    Raises:
+        HTTPException: 400 for empty input or non-positive speed, 503 if an
+            audio dependency cannot be imported, 500 if generation fails.
+
+    NOTE(review): these parameters are plain defaults, so FastAPI reads them
+    from the query string, while OpenAI clients send a JSON body. Confirm
+    whether a request body model should be used instead.
+    """
+    global _tts_engine
+    from fastapi.responses import Response
+
+    # Validate outside the try block: an HTTPException raised inside it
+    # would be caught by the generic handler and turned into a 500.
+    if not input.strip():
+        raise HTTPException(status_code=400, detail="'input' must not be empty")
+    if not speed > 0:
+        raise HTTPException(status_code=400, detail="'speed' must be positive")
+
+    model_name = _TTS_MODEL_ALIASES.get(model, model)
+
+    try:
+        from .audio.tts import TTSEngine
+
+        # Load engine if needed; publish it only after load() succeeds.
+        engine = _tts_engine
+        if engine is None or engine.model_name != model_name:
+            engine = TTSEngine(model_name)
+            engine.load()
+            _tts_engine = engine
+
+        audio = engine.generate(input, voice=voice, speed=speed)
+        audio_bytes = engine.to_bytes(audio, format=response_format)
+
+        content_type = _AUDIO_MEDIA_TYPES.get(
+            response_format, f"audio/{response_format}"
+        )
+        return Response(content=audio_bytes, media_type=content_type)
+
+    except ImportError as e:
+        raise HTTPException(status_code=503, detail=_AUDIO_DEPS_MISSING) from e
+    except Exception as e:
+        logger.exception("TTS generation failed: %s", e)
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@app.get("/v1/audio/voices")
+async def list_voices(model: str = "kokoro"):
+    """
+    List available voices for a TTS model.
+
+    Args:
+        model: Model name, matched case-insensitively on the substrings
+            "kokoro" and "chatterbox".
+
+    Returns:
+        A dict with a "voices" list; ["default"] for any other model.
+
+    Raises:
+        HTTPException: 503 if the TTS module cannot be imported.
+    """
+    try:
+        from .audio.tts import CHATTERBOX_VOICES, KOKORO_VOICES
+    except ImportError as e:
+        raise HTTPException(status_code=503, detail=_AUDIO_DEPS_MISSING) from e
+
+    name = model.lower()
+    if "kokoro" in name:
+        return {"voices": KOKORO_VOICES}
+    if "chatterbox" in name:
+        return {"voices": CHATTERBOX_VOICES}
+    return {"voices": ["default"]}
+
+
 # =============================================================================
 # Completion Endpoints
 # =============================================================================