
Commit 93cf015

Fixed mypy
1 parent a9361a8 commit 93cf015

File tree

1 file changed: +18 -19 lines changed


plugins/baseten/vision_agents/plugins/baseten/baseten_vlm.py

Lines changed: 18 additions & 19 deletions
@@ -3,25 +3,23 @@
 import logging
 import os
 from collections import deque
-from typing import Iterator, Optional
+from typing import Iterator, Optional, cast
 
-import aiortc
 import av
-from PIL.Image import Resampling
-from openai import AsyncOpenAI
-
+from aiortc.mediastreams import MediaStreamTrack, VideoStreamTrack
 from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant
-
-from vision_agents.core.llm.llm import LLMResponseEvent, VideoLLM
+from openai import AsyncOpenAI, AsyncStream
+from openai.types.chat import ChatCompletionChunk
+from PIL.Image import Resampling
 from vision_agents.core.llm.events import (
     LLMResponseChunkEvent,
     LLMResponseCompletedEvent,
 )
-from vision_agents.core.utils.video_forwarder import VideoForwarder
-from . import events
-
+from vision_agents.core.llm.llm import LLMResponseEvent, VideoLLM
 from vision_agents.core.processors import Processor
+from vision_agents.core.utils.video_forwarder import VideoForwarder
 
+from . import events
 
 logger = logging.getLogger(__name__)
 
@@ -118,8 +116,7 @@ async def simple_response(
             )
             return LLMResponseEvent(original=None, text="")
 
-        messages = []
-
+        messages: list[dict] = []
         # Add Agent's instructions as system prompt.
         if self.instructions:
             messages.append(
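
Note: the annotation added in this hunk matters because mypy cannot infer an element type for a bare empty list literal and typically reports "Need type annotation". A minimal sketch of the same pattern, with hypothetical message payloads rather than the ones this file builds:

    from typing import Any

    # Annotating the empty literal tells mypy what may be appended later;
    # without it, mypy reports: Need type annotation for "messages".
    messages: list[dict[str, Any]] = []
    messages.append({"role": "system", "content": "You are a helpful assistant."})
    messages.append({"role": "user", "content": "Describe the current frame."})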
@@ -162,8 +159,10 @@ async def simple_response(
 
         # TODO: Maybe move it to a method, too much code
         try:
-            response = await self._client.chat.completions.create(
-                messages=messages, model=self.model, stream=True
+            response = await self._client.chat.completions.create(  # type: ignore[arg-type]
+                messages=messages,  # type: ignore[arg-type]
+                model=self.model,
+                stream=True,
             )
         except Exception as e:
             # Send an error event if the request failed
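
Note: the ignore comments in this hunk are scoped to a single mypy error code, so other error codes on the same lines would still be reported. A tiny illustrative sketch of a scoped ignore (toy function, not from this repo):

    def takes_int(x: int) -> int:
        return x + 1

    value: object = 41
    # Only the arg-type complaint is suppressed; any other error code
    # reported on this line would still be flagged by mypy.
    takes_int(value)  # type: ignore[arg-type]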
@@ -180,12 +179,12 @@ async def simple_response(
             return LLMResponseEvent(original=None, text="")
 
         i = 0
-        llm_response_event: Optional[LLMResponseEvent] = LLMResponseEvent(
-            original=None, text=""
+        llm_response_event: LLMResponseEvent[Optional[ChatCompletionChunk]] = (
+            LLMResponseEvent(original=None, text="")
         )
         text_chunks: list[str] = []
         total_text = ""
-        async for chunk in response:
+        async for chunk in cast(AsyncStream[ChatCompletionChunk], response):
             if not chunk.choices:
                 continue
 
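Note: cast() here is a static-typing hint only; it tells mypy to treat response as AsyncStream[ChatCompletionChunk] and does nothing at runtime. A self-contained sketch of the same narrowing idea, using a stand-in async iterator instead of the OpenAI client:

    import asyncio
    from typing import AsyncIterator, Union, cast

    async def fake_stream() -> AsyncIterator[str]:
        for piece in ("Hello", ", ", "world"):
            yield piece

    async def main() -> None:
        # Suppose the declared return type is a union of a full response and
        # a stream; cast() narrows it for mypy at zero runtime cost, with no
        # runtime check performed.
        response: Union[str, AsyncIterator[str]] = fake_stream()
        chunks: list[str] = []
        async for chunk in cast(AsyncIterator[str], response):
            chunks.append(chunk)
        print("".join(chunks))

    asyncio.run(main())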

@@ -226,7 +225,7 @@ async def simple_response(
 
     async def watch_video_track(
         self,
-        track: aiortc.mediastreams.VideoStreamTrack,  # TODO: Check if this works, maybe I need to update typings everywhere
+        track: MediaStreamTrack,
         shared_forwarder: Optional[VideoForwarder] = None,
     ) -> None:
         """
@@ -249,7 +248,7 @@ async def watch_video_track(
         logger.info("🎥 BasetenVLM subscribing to VideoForwarder")
         if not shared_forwarder:
             self._video_forwarder = shared_forwarder or VideoForwarder(
-                track,
+                cast(VideoStreamTrack, track),
                 max_buffer=10,
                 fps=1.0,  # Low FPS for VLM
                 name="baseten_vlm_forwarder",
