Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 19 additions & 41 deletions agents-core/vision_agents/core/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import asyncio
import io
import importlib.metadata
import logging
import re
import os
import importlib.metadata
import re
from dataclasses import dataclass
from typing import Dict, Optional
from PIL import Image

import httpx

logger = logging.getLogger(__name__)


# Type alias for markdown file contents: maps filename to file content
MarkdownFileContents = Dict[str, str]
Expand Down Expand Up @@ -52,7 +53,9 @@ def _read_markdown_file_sync(file_path: str) -> str:
return ""


async def parse_instructions_async(text: str, base_dir: Optional[str] = None) -> Instructions:
async def parse_instructions_async(
text: str, base_dir: Optional[str] = None
) -> Instructions:
"""
Async version: Parse instructions from a string, extracting @ mentioned markdown files and their contents.

Expand Down Expand Up @@ -127,32 +130,6 @@ def parse_instructions(text: str, base_dir: Optional[str] = None) -> Instruction
)


def frame_to_png_bytes(frame) -> bytes:
    """
    Convert a video frame to PNG bytes.

    Conversion is best-effort: any failure is logged together with its
    traceback and reported to the caller as empty bytes, never as an
    exception.

    Args:
        frame: Video frame object that can be converted to an image. If it
            has a ``to_image()`` method, that is used; otherwise
            ``to_ndarray(format="rgb24")`` is called and the resulting array
            is wrapped with ``Image.fromarray``.

    Returns:
        PNG bytes of the frame, or empty bytes if conversion fails
    """
    logger = logging.getLogger(__name__)
    try:
        if hasattr(frame, "to_image"):
            img = frame.to_image()
        else:
            arr = frame.to_ndarray(format="rgb24")
            img = Image.fromarray(arr)

        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return buf.getvalue()
    except Exception as e:
        # Same message as before, but logger.exception also records the
        # traceback, which a plain error() call throws away.
        logger.exception("Error converting frame to PNG: %s", e)
        return b""


def get_vision_agents_version() -> Optional[str]:
"""
Get the installed vision-agents package version.
Expand All @@ -166,43 +143,44 @@ def get_vision_agents_version() -> Optional[str]:
async def ensure_model(path: str, url: str) -> str:
    """
    Download a model file asynchronously if it doesn't exist.

    The response body is streamed chunk by chunk into a uniquely named
    temporary file next to ``path`` and moved into place only once the
    download has completed. The whole model is therefore never held in
    memory, and ``path`` never contains a partially written file.

    Args:
        path: Local path where the model should be saved
        url: URL to download the model from

    Returns:
        The path to the model file

    Raises:
        RuntimeError: If the HTTP request or transfer fails.
    """
    if os.path.exists(path):
        return path

    import uuid

    logger = logging.getLogger(__name__)
    model_name = os.path.basename(path)
    # Unique suffix so concurrent downloads of the same model cannot
    # interleave their writes in a shared temporary file.
    tmp_path = f"{path}.{uuid.uuid4().hex}.part"
    logger.info(f"Downloading {model_name}...")

    try:
        async with httpx.AsyncClient(
            timeout=300.0, follow_redirects=True
        ) as client:
            async with client.stream("GET", url) as response:
                response.raise_for_status()

                # File I/O runs in worker threads so the event loop is
                # never blocked; 1 MiB chunks keep the thread hops cheap.
                out = await asyncio.to_thread(open, tmp_path, "wb")
                try:
                    async for chunk in response.aiter_bytes(
                        chunk_size=1024 * 1024
                    ):
                        await asyncio.to_thread(out.write, chunk)
                finally:
                    await asyncio.to_thread(out.close)

        # Atomic rename: readers see either no file or the complete one.
        await asyncio.to_thread(os.replace, tmp_path, path)
        logger.info(f"{model_name} downloaded.")
    except httpx.HTTPError as e:
        raise RuntimeError(f"Failed to download {model_name}: {e}") from e
    finally:
        # Clean up the partial download on any failure (HTTP error, disk
        # error, cancellation). After a successful rename this is a no-op.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return path
22 changes: 22 additions & 0 deletions agents-core/vision_agents/core/utils/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import av
from PIL.Image import Resampling
from PIL import Image


def ensure_even_dimensions(frame: av.VideoFrame) -> av.VideoFrame:
Expand Down Expand Up @@ -61,3 +62,24 @@ def frame_to_jpeg_bytes(
buf = io.BytesIO()
resized.save(buf, "JPEG", quality=quality, optimize=True)
return buf.getvalue()


def frame_to_png_bytes(frame: av.VideoFrame) -> bytes:
    """
    Convert a video frame to PNG bytes.

    Args:
        frame: Video frame to encode. ``frame.to_image()`` is used when the
            frame provides it; otherwise the frame is converted with
            ``to_ndarray(format="rgb24")`` and wrapped via ``Image.fromarray``.

    Returns:
        PNG bytes of the frame

    Raises:
        Exception: Errors raised while converting or encoding the frame are
            not caught here and propagate to the caller unchanged.
    """
    if hasattr(frame, "to_image"):
        img = frame.to_image()
    else:
        arr = frame.to_ndarray(format="rgb24")
        img = Image.fromarray(arr)

    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
43 changes: 13 additions & 30 deletions plugins/gemini/tests/test_gemini_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@
from vision_agents.core.tts.manual_test import play_pcm_with_ffplay
from vision_agents.plugins.gemini import Realtime
from vision_agents.core.llm.events import RealtimeAudioOutputEvent
from vision_agents.core.utils.utils import frame_to_png_bytes
from getstream.video.rtc import PcmData, AudioFormat

# Load environment variables
load_dotenv()


class TestGeminiRealtime:
"""Integration tests for Realtime2 connect flow"""
@pytest.fixture
async def realtime():
"""Create and manage Realtime connection lifecycle"""
realtime = Realtime(
model="gemini-2.0-flash-exp",
)
try:
yield realtime
finally:
await realtime.close()

@pytest.fixture
async def realtime(self):
"""Create and manage Realtime connection lifecycle"""
realtime = Realtime(
model="gemini-2.0-flash-exp",
)
try:
yield realtime
finally:
await realtime.close()

class TestGeminiRealtime:
"""Integration tests for Gemini Realtime connect flow"""

@pytest.mark.integration
async def test_simple_response_flow(self, realtime):
Expand Down Expand Up @@ -93,20 +93,3 @@ async def on_audio(event: RealtimeAudioOutputEvent):
# Stop video sender
await realtime._stop_watching_video_track()
assert len(events) > 0

async def test_frame_to_png_bytes_with_bunny_video(self, bunny_video_track):
"""Test that frame_to_png_bytes works with real bunny video frames"""
# Get a frame from the bunny video track
frame = await bunny_video_track.recv()
png_bytes = frame_to_png_bytes(frame)

# Verify we got PNG data
assert isinstance(png_bytes, bytes)
assert len(png_bytes) > 0

# Verify it's actually PNG data (PNG files start with specific bytes)
assert png_bytes.startswith(b"\x89PNG\r\n\x1a\n")

print(
f"Successfully converted bunny video frame to PNG: {len(png_bytes)} bytes"
)
20 changes: 7 additions & 13 deletions plugins/gemini/tests/test_realtime_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ async def realtime_instance(self):
finally:
await realtime.close()

@pytest.mark.asyncio
async def test_convert_tools_to_provider_format(self):
"""Test tool conversion to Gemini Live format."""
# Create a minimal instance just for testing the conversion method
Expand Down Expand Up @@ -92,7 +91,6 @@ async def test_convert_tools_to_provider_format(self):
assert "expression" in tool2["parameters"]["properties"]

@pytest.mark.integration
@pytest.mark.asyncio
async def test_live_function_calling_basic(self, realtime_instance):
"""Test basic live function calling with weather function."""
realtime = realtime_instance
Expand All @@ -116,7 +114,7 @@ def get_weather(location: str) -> Dict[str, str]:
# Set up event listeners for audio output
@realtime.events.subscribe
async def handle_audio_output(event: RealtimeAudioOutputEvent):
if event.audio_data:
if event.data:
# Audio was received - this indicates Gemini responded
text_responses.append("audio_response_received")

Expand All @@ -143,7 +141,6 @@ async def handle_response(event: RealtimeResponseEvent):
# Remove the text response assertion

@pytest.mark.integration
@pytest.mark.asyncio
async def test_live_function_calling_error_handling(self, realtime_instance):
"""Test live function calling with error handling."""
realtime = realtime_instance
Expand All @@ -164,7 +161,7 @@ def unreliable_function(input_data: str) -> Dict[str, Any]:
# Set up event listeners for audio output
@realtime.events.subscribe
async def handle_audio_output(event: RealtimeAudioOutputEvent):
if event.audio_data:
if event.data:
# Audio was received - this indicates Gemini responded
text_responses.append("audio_response_received")

Expand All @@ -191,7 +188,6 @@ async def handle_response(event: RealtimeResponseEvent):
assert len(text_responses) > 0, "No response received from Gemini"

@pytest.mark.integration
@pytest.mark.asyncio
async def test_live_function_calling_multiple_functions(self, realtime_instance):
"""Test live function calling with multiple functions in one request."""
realtime = realtime_instance
Expand All @@ -216,7 +212,7 @@ def get_status() -> Dict[str, str]:
# Set up event listeners for audio output
@realtime.events.subscribe
async def handle_audio_output(event: RealtimeAudioOutputEvent):
if event.audio_data:
if event.data:
# Audio was received - this indicates Gemini responded
text_responses.append("audio_response_received")

Expand Down Expand Up @@ -245,8 +241,7 @@ async def handle_response(event: RealtimeResponseEvent):
# Verify we got a response
assert len(text_responses) > 0, "No response received from Gemini"

@pytest.mark.asyncio
async def test_create_config_with_tools(self):
async def test_get_config_with_tools(self):
"""Test that tools are added to the config."""
# Create a minimal instance for testing config creation
realtime = gemini.Realtime(model="test-model", api_key="test-key")
Expand All @@ -256,7 +251,7 @@ async def test_create_config_with_tools(self):
def test_func(param: str) -> str:
return f"test: {param}"

config = realtime._get_config_with_resumption()
config = realtime.get_config()

# Verify tools were added
assert "tools" in config
Expand All @@ -265,13 +260,12 @@ def test_func(param: str) -> str:
assert len(config["tools"][0]["function_declarations"]) == 1
assert config["tools"][0]["function_declarations"][0]["name"] == "test_func"

@pytest.mark.asyncio
async def test_create_config_without_tools(self):
async def test_get_config_without_tools(self):
"""Test config creation when no tools are available."""
# Create a minimal instance without registering any functions
realtime = gemini.Realtime(model="test-model", api_key="test-key")

config = realtime._create_config()
config = realtime.get_config()

# Verify tools were not added
assert "tools" not in config
2 changes: 1 addition & 1 deletion plugins/gemini/vision_agents/plugins/gemini/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .gemini_realtime import Realtime
from .gemini_llm import GeminiLLM as LLM
from .gemini_realtime import GeminiRealtime as Realtime

__all__ = ["Realtime", "LLM"]
Loading