Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 151 additions & 2 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional

from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import StreamingResponse

# Import from new modular API
Expand Down Expand Up @@ -173,7 +173,10 @@ def load_model(
logger.info(f"Loading model with SimpleEngine: {model_name}")
_engine = SimpleEngine(model_name=model_name)
# Start SimpleEngine synchronously (no background loop)
asyncio.get_event_loop().run_until_complete(_engine.start())
# Create and install a fresh event loop explicitly: calling asyncio.get_event_loop() with no running loop is deprecated since Python 3.10
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(_engine.start())
model_type = "MLLM" if _engine.is_mllm else "LLM"
logger.info(f"{model_type} model loaded (simple mode): {model_name}")

Expand Down Expand Up @@ -278,6 +281,152 @@ async def execute_mcp_tool(request: MCPExecuteRequest) -> MCPExecuteResponse:
)


# =============================================================================
# Audio Endpoints
# =============================================================================

# Global audio engines (lazy loaded)
# Both start as None. The audio endpoints below create an engine on first use
# and replace it when a request asks for a different model name.
_stt_engine = None  # speech-to-text engine used by /v1/audio/transcriptions
_tts_engine = None  # text-to-speech engine used by /v1/audio/speech


@app.post("/v1/audio/transcriptions")
async def create_transcription(
    file: UploadFile,
    model: str = "whisper-large-v3",
    language: Optional[str] = None,
    response_format: str = "json",
):
    """
    Transcribe audio to text (OpenAI Whisper API compatible).

    Model aliases (expanded to full repository names):
    - whisper-large-v3 (multilingual, best quality)
    - whisper-large-v3-turbo (faster)
    - whisper-medium, whisper-small (lighter)
    - parakeet, parakeet-v3

    Any other ``model`` value is passed to the engine unchanged.

    Args:
        file: Uploaded audio file to transcribe.
        model: Model alias or full model name.
        language: Optional language hint forwarded to the engine.
        response_format: ``"text"`` returns only the transcript; any other
            value returns a dict with ``text``, ``language`` and ``duration``.

    Raises:
        HTTPException: 503 if the audio dependencies cannot be imported,
            500 for any other failure.
    """
    global _stt_engine
    import os
    import tempfile

    try:
        from .audio.stt import STTEngine

        # Map model aliases to full names
        model_map = {
            "whisper-large-v3": "mlx-community/whisper-large-v3-mlx",
            "whisper-large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
            "whisper-medium": "mlx-community/whisper-medium-mlx",
            "whisper-small": "mlx-community/whisper-small-mlx",
            "parakeet": "mlx-community/parakeet-tdt-0.6b-v2",
            "parakeet-v3": "mlx-community/parakeet-tdt-0.6b-v3",
        }
        model_name = model_map.get(model, model)

        # Load engine if needed. Publish it to the module-level cache only
        # after load() succeeds, so a failed load is retried on the next
        # request instead of leaving a half-initialised engine cached.
        if _stt_engine is None or _stt_engine.model_name != model_name:
            engine = STTEngine(model_name)
            engine.load()
            _stt_engine = engine

        # Read the upload before creating the temp file so a failed read
        # cannot leave an orphaned file on disk.
        content = await file.read()

        tmp_path = None
        try:
            # delete=False: the engine reopens the file by path after the
            # handle is closed; cleanup is done explicitly in `finally`.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp_path = tmp.name
                tmp.write(content)
            result = _stt_engine.transcribe(tmp_path, language=language)
        finally:
            if tmp_path is not None:
                os.unlink(tmp_path)

        if response_format == "text":
            return result.text

        return {
            "text": result.text,
            "language": result.language,
            "duration": result.duration,
        }

    except ImportError as e:
        raise HTTPException(
            status_code=503,
            detail="mlx-audio not installed. Install with: pip install mlx-audio",
        ) from e
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Transcription failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/v1/audio/speech")
async def create_speech(
    model: str = "kokoro",
    input: str = "",
    voice: str = "af_heart",
    speed: float = 1.0,
    response_format: str = "wav",
):
    """
    Generate speech from text (OpenAI TTS API compatible).

    Supported models:
    - kokoro, kokoro-4bit (fast, lightweight)
    - chatterbox, chatterbox-4bit (multilingual, expressive)
    - vibevoice (realtime)
    - voxcpm (Chinese/English)

    Any other ``model`` value is passed to the engine unchanged.

    Args:
        model: Model alias or full model name.
        input: Text to synthesize (name mirrors the OpenAI API field, which
            is why it shadows the builtin).
        voice: Voice identifier forwarded to the engine.
        speed: Speed value forwarded to the engine.
        response_format: Audio format passed to the engine's ``to_bytes``
            and used to derive the response media type.

    Returns:
        A ``Response`` whose body is the encoded audio.

    Raises:
        HTTPException: 503 if the audio dependencies cannot be imported,
            500 for any other failure.
    """
    global _tts_engine
    from fastapi.responses import Response

    try:
        from .audio.tts import TTSEngine

        # Map model aliases to full names
        model_map = {
            "kokoro": "mlx-community/Kokoro-82M-bf16",
            "kokoro-4bit": "mlx-community/Kokoro-82M-4bit",
            "chatterbox": "mlx-community/chatterbox-turbo-fp16",
            "chatterbox-4bit": "mlx-community/chatterbox-turbo-4bit",
            "vibevoice": "mlx-community/VibeVoice-Realtime-0.5B-4bit",
            "voxcpm": "mlx-community/VoxCPM1.5",
        }
        model_name = model_map.get(model, model)

        # Load engine if needed. Publish it to the module-level cache only
        # after load() succeeds, so a failed load is retried on the next
        # request instead of leaving a half-initialised engine cached.
        if _tts_engine is None or _tts_engine.model_name != model_name:
            engine = TTSEngine(model_name)
            engine.load()
            _tts_engine = engine

        audio = _tts_engine.generate(input, voice=voice, speed=speed)
        audio_bytes = _tts_engine.to_bytes(audio, format=response_format)

        # Most formats map directly to "audio/<format>", but the registered
        # media type for MP3 is "audio/mpeg", not "audio/mp3".
        media_type_overrides = {"mp3": "audio/mpeg"}
        content_type = media_type_overrides.get(
            response_format, f"audio/{response_format}"
        )
        return Response(content=audio_bytes, media_type=content_type)

    except ImportError as e:
        raise HTTPException(
            status_code=503,
            detail="mlx-audio not installed. Install with: pip install mlx-audio",
        ) from e
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"TTS generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/v1/audio/voices")
async def list_voices(model: str = "kokoro"):
    """
    List available voices for a TTS model.

    Matching is a case-insensitive substring test on ``model``: names
    containing "kokoro" return the Kokoro voice list, names containing
    "chatterbox" return the Chatterbox voice list, and anything else
    returns a single "default" voice.

    Raises:
        HTTPException: 503 if the audio dependencies cannot be imported,
            matching the other audio endpoints.
    """
    try:
        from .audio.tts import KOKORO_VOICES, CHATTERBOX_VOICES
    except ImportError as e:
        # Same response the transcription and speech endpoints give when
        # the optional audio package is missing.
        raise HTTPException(
            status_code=503,
            detail="mlx-audio not installed. Install with: pip install mlx-audio",
        ) from e

    model_key = model.lower()
    if "kokoro" in model_key:
        return {"voices": KOKORO_VOICES}
    if "chatterbox" in model_key:
        return {"voices": CHATTERBOX_VOICES}
    return {"voices": ["default"]}


# =============================================================================
# Completion Endpoints
# =============================================================================
Expand Down