Skip to content

Commit a0b5c9d

Browse files
committed
WIP local version
1 parent f838a1e commit a0b5c9d

File tree

1 file changed

+367
-0
lines changed

1 file changed

+367
-0
lines changed
Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
import asyncio
2+
import logging
3+
import os
4+
from typing import Optional, List
5+
from concurrent.futures import ThreadPoolExecutor
6+
7+
import aiortc
8+
import av
9+
import torch
10+
from PIL import Image
11+
from transformers import AutoModelForCausalLM
12+
13+
from vision_agents.core import llm
14+
from vision_agents.core.agents.agents import AgentOptions, default_agent_options
15+
from vision_agents.core.stt.events import STTTranscriptEvent
16+
from vision_agents.core.llm.events import (
17+
LLMResponseChunkEvent,
18+
LLMResponseCompletedEvent,
19+
)
20+
from vision_agents.core.llm.llm import LLMResponseEvent
21+
from vision_agents.core.processors import Processor
22+
from vision_agents.core.utils.video_forwarder import VideoForwarder
23+
from vision_agents.core.utils.queue import LatestNQueue
24+
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
class LocalVLM(llm.VideoLLM):
30+
"""
31+
Local VLM using Moondream model for captioning or visual queries.
32+
33+
Note: The moondream3-preview model is gated and requires authentication:
34+
- Request access at https://huggingface.co/moondream/moondream3-preview
35+
- Once approved, authenticate using one of:
36+
- Set HF_TOKEN environment variable: export HF_TOKEN=your_token_here
37+
- Run: huggingface-cli login
38+
39+
Args:
40+
mode: "vqa" for visual question answering or "caption" for image captioning (default: "vqa")
41+
conf_threshold: Confidence threshold (unused for VLM, kept for API compatibility)
42+
max_workers: Number of worker threads for async operations
43+
device: Device to run inference on ('cuda', 'mps', or 'cpu').
44+
Auto-detects CUDA, then MPS (Apple Silicon), then defaults to CPU.
45+
Note: MPS is automatically converted to CPU due to model compatibility.
46+
model_name: Hugging Face model identifier (default: "moondream/moondream3-preview")
47+
options: AgentOptions for model directory configuration.
48+
If not provided, uses default_agent_options()
49+
"""
50+
51+
def __init__(
52+
self,
53+
mode: str = "vqa",
54+
conf_threshold: float = 0.3,
55+
max_workers: int = 10,
56+
device: Optional[str] = None,
57+
model_name: str = "moondream/moondream3-preview",
58+
options: Optional[AgentOptions] = None,
59+
):
60+
super().__init__()
61+
62+
self.conf_threshold = conf_threshold
63+
self.max_workers = max_workers
64+
self.mode = mode
65+
self.model_name = model_name
66+
self._shutdown = False
67+
68+
if options is None:
69+
self.options = default_agent_options()
70+
else:
71+
self.options = options
72+
73+
if device is None:
74+
if torch.cuda.is_available():
75+
self.device = "cuda"
76+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
77+
self.device = "cpu"
78+
logger.info("⚠️ MPS detected but using CPU (moondream model has CUDA dependencies incompatible with MPS)")
79+
else:
80+
self.device = "cpu"
81+
else:
82+
if device == "mps":
83+
self.device = "cpu"
84+
logger.warning("⚠️ MPS device requested but using CPU instead (moondream model has CUDA dependencies incompatible with MPS)")
85+
else:
86+
self.device = device
87+
88+
self._frame_buffer: LatestNQueue[av.VideoFrame] = LatestNQueue(maxlen=10)
89+
self._latest_frame: Optional[av.VideoFrame] = None
90+
self._video_forwarder: Optional[VideoForwarder] = None
91+
self._stt_subscription_setup = False
92+
self._is_processing = False
93+
94+
self.executor = ThreadPoolExecutor(max_workers=max_workers)
95+
self.model = None
96+
97+
logger.info("🌙 Moondream Local VLM initialized")
98+
logger.info(f"🔧 Device: {self.device}")
99+
logger.info(f"📝 Mode: {self.mode}")
100+
101+
async def warmup(self) -> None:
102+
"""Initialize and load the model."""
103+
if self.model is None:
104+
await self._prepare_moondream()
105+
106+
async def _prepare_moondream(self):
107+
"""Load the Moondream model from Hugging Face."""
108+
logger.info(f"Loading Moondream model: {self.model_name}")
109+
logger.info(f"Device: {self.device}")
110+
111+
self.model = await asyncio.to_thread( # type: ignore[func-returns-value]
112+
lambda: self._load_model_sync()
113+
)
114+
logger.info("✅ Moondream model loaded")
115+
116+
def _load_model_sync(self):
117+
"""Synchronous model loading function run in thread pool."""
118+
try:
119+
hf_token = os.getenv("HF_TOKEN")
120+
if not hf_token:
121+
logger.warning(
122+
"⚠️ HF_TOKEN environment variable not set. "
123+
"This model requires authentication. "
124+
"Set HF_TOKEN or run 'huggingface-cli login'"
125+
)
126+
127+
load_kwargs = {
128+
"trust_remote_code": True,
129+
"dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
130+
"cache_dir": self.options.model_dir,
131+
}
132+
133+
if hf_token:
134+
load_kwargs["token"] = hf_token
135+
else:
136+
load_kwargs["token"] = True
137+
138+
if self.device == "cuda":
139+
load_kwargs["device_map"] = {"": "cuda"}
140+
else:
141+
load_kwargs["device_map"] = "cpu"
142+
143+
model = AutoModelForCausalLM.from_pretrained(
144+
self.model_name,
145+
**load_kwargs,
146+
)
147+
148+
model.eval()
149+
150+
if self.device == "cuda":
151+
logger.info("✅ Model loaded on CUDA device")
152+
else:
153+
logger.info("✅ Model loaded on CPU device")
154+
155+
try:
156+
model.compile()
157+
except Exception as compile_error:
158+
logger.warning(f"⚠️ Model compilation failed, continuing without compilation: {compile_error}")
159+
160+
return model
161+
except Exception as e:
162+
error_msg = str(e)
163+
if "gated repo" in error_msg.lower() or "403" in error_msg or "authorized" in error_msg.lower():
164+
logger.exception(
165+
"❌ Failed to load Moondream model: Model requires authentication.\n"
166+
"This model is gated and requires access approval:\n"
167+
f"1. Visit https://huggingface.co/{self.model_name} to request access\n"
168+
"2. Once approved, authenticate using one of:\n"
169+
" - Set HF_TOKEN environment variable: export HF_TOKEN=your_token_here\n"
170+
" - Run: huggingface-cli login\n"
171+
f"Original error: {e}"
172+
)
173+
else:
174+
logger.exception(f"❌ Failed to load Moondream model: {e}")
175+
raise
176+
177+
async def watch_video_track(
178+
self,
179+
track: aiortc.mediastreams.MediaStreamTrack,
180+
shared_forwarder: Optional[VideoForwarder] = None
181+
) -> None:
182+
"""Setup video forwarding and STT subscription."""
183+
if self._video_forwarder is not None and shared_forwarder is None:
184+
logger.warning("Video forwarder already running, stopping previous one")
185+
await self._stop_watching_video_track()
186+
187+
if self.model is None:
188+
await self._prepare_moondream()
189+
190+
if shared_forwarder is not None:
191+
self._video_forwarder = shared_forwarder
192+
logger.info("🎥 Moondream Local VLM subscribing to shared VideoForwarder")
193+
await self._video_forwarder.start_event_consumer(
194+
self._on_frame_received,
195+
fps=1.0,
196+
consumer_name="moondream_local_vlm"
197+
)
198+
else:
199+
self._video_forwarder = VideoForwarder(
200+
track, # type: ignore[arg-type]
201+
max_buffer=10,
202+
fps=1.0,
203+
name="moondream_local_vlm_forwarder",
204+
)
205+
await self._video_forwarder.start()
206+
await self._video_forwarder.start_event_consumer(
207+
self._on_frame_received
208+
)
209+
210+
if not self._stt_subscription_setup and self.agent:
211+
self._setup_stt_subscription()
212+
self._stt_subscription_setup = True
213+
214+
async def _on_frame_received(self, frame: av.VideoFrame):
215+
"""Callback to receive frames and add to buffer."""
216+
try:
217+
self._frame_buffer.put_latest_nowait(frame)
218+
self._latest_frame = frame
219+
except Exception as e:
220+
logger.error(f"Error adding frame to buffer: {e}")
221+
222+
def _setup_stt_subscription(self):
223+
if not self.agent:
224+
logger.warning("Cannot setup STT subscription: agent not set")
225+
return
226+
227+
@self.agent.events.subscribe
228+
async def on_stt_transcript(event: STTTranscriptEvent):
229+
await self._on_stt_transcript(event)
230+
231+
def _consume_stream(self, generator):
232+
"""Consume the generator stream from model query/caption methods."""
233+
chunks = []
234+
for chunk in generator:
235+
logger.debug(f"Moondream stream chunk: {type(chunk)} - {chunk}")
236+
if isinstance(chunk, str):
237+
chunks.append(chunk)
238+
else:
239+
logger.warning(f"Unexpected chunk type: {type(chunk)}, value: {chunk}")
240+
if chunk:
241+
chunks.append(str(chunk))
242+
result = "".join(chunks)
243+
logger.debug(f"Moondream stream result: {result}")
244+
return result
245+
246+
async def _process_frame(self, text: Optional[str] = None) -> Optional[LLMResponseEvent]:
247+
if self._latest_frame is None:
248+
logger.warning("No frames available, skipping Moondream processing")
249+
return None
250+
251+
if self._is_processing:
252+
logger.debug("Moondream processing already in progress, skipping")
253+
return None
254+
255+
if self.model is None:
256+
logger.warning("Model not loaded, skipping Moondream processing")
257+
return None
258+
259+
latest_frame = self._latest_frame
260+
261+
try:
262+
frame_array = latest_frame.to_ndarray(format="rgb24")
263+
image = Image.fromarray(frame_array)
264+
265+
if self.mode == "vqa":
266+
if not text:
267+
logger.warning("VQA mode requires text/question")
268+
return None
269+
270+
self._is_processing = True
271+
result = await asyncio.to_thread(self.model.query, image, text, stream=True)
272+
273+
if isinstance(result, dict) and "answer" in result:
274+
stream = result["answer"]
275+
else:
276+
stream = result
277+
278+
answer = await asyncio.to_thread(self._consume_stream, stream)
279+
280+
if not answer:
281+
logger.warning("Moondream query returned empty answer")
282+
self._is_processing = False
283+
return None
284+
285+
self.events.send(LLMResponseChunkEvent(delta=answer))
286+
self.events.send(LLMResponseCompletedEvent(text=answer))
287+
logger.info(f"Moondream VQA response: {answer}")
288+
self._is_processing = False
289+
return LLMResponseEvent(original=answer, text=answer)
290+
291+
elif self.mode == "caption":
292+
self._is_processing = True
293+
result = await asyncio.to_thread(self.model.caption, image, length="normal", stream=True)
294+
295+
if isinstance(result, dict) and "caption" in result:
296+
stream = result["caption"]
297+
else:
298+
stream = result
299+
300+
caption = await asyncio.to_thread(self._consume_stream, stream)
301+
302+
if not caption:
303+
logger.warning("Moondream caption returned empty result")
304+
self._is_processing = False
305+
return None
306+
307+
self.events.send(LLMResponseChunkEvent(delta=caption))
308+
self.events.send(LLMResponseCompletedEvent(text=caption))
309+
logger.info(f"Moondream caption: {caption}")
310+
self._is_processing = False
311+
return LLMResponseEvent(original=caption, text=caption)
312+
else:
313+
logger.error(f"Unknown mode: {self.mode}")
314+
self._is_processing = False
315+
return None
316+
317+
except Exception as e:
318+
logger.exception(f"Error processing frame: {e}")
319+
self._is_processing = False
320+
return LLMResponseEvent(original=None, text="", exception=e)
321+
322+
async def _on_stt_transcript(self, event: STTTranscriptEvent):
323+
"""Handle STT transcript event - process with Moondream."""
324+
if not event.text:
325+
return
326+
327+
await self._process_frame(text=event.text)
328+
329+
async def simple_response(
330+
self,
331+
text: str,
332+
processors: Optional[List[Processor]] = None,
333+
participant: Optional[Participant] = None,
334+
) -> LLMResponseEvent:
335+
"""
336+
simple_response is a standardized way to create a response.
337+
338+
Args:
339+
text: The text/question to respond to
340+
processors: list of processors (which contain state) about the video/voice AI
341+
participant: optionally the participant object
342+
343+
Examples:
344+
await llm.simple_response("What do you see in this image?")
345+
"""
346+
result = await self._process_frame(text=text if self.mode == "vqa" else None)
347+
if result is None:
348+
return LLMResponseEvent(original=None, text="",
349+
exception=ValueError("No frame available or processing failed"))
350+
return result
351+
352+
async def _stop_watching_video_track(self) -> None:
353+
"""Stop video forwarding."""
354+
if self._video_forwarder is not None:
355+
await self._video_forwarder.stop()
356+
self._video_forwarder = None
357+
logger.info("Stopped video forwarding")
358+
359+
def close(self):
360+
"""Clean up resources."""
361+
self._shutdown = True
362+
if hasattr(self, "executor"):
363+
self.executor.shutdown(wait=False)
364+
if self.model is not None:
365+
del self.model
366+
self.model = None
367+
logger.info("🛑 Moondream Local VLM closed")

0 commit comments

Comments
 (0)