Skip to content
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
b52d755
Set asyncio_default_test_loop_scope to avoid warning about the value …
dagardner-nv May 20, 2025
7ac47f3
Define the default event loop scope of asynchronous tests avoids warn…
dagardner-nv May 20, 2025
f7da997
Silence known deprecation warnings being triggered by weave
dagardner-nv May 20, 2025
fb6d99c
Rename config classes such that pytest doesn't believe they are test …
dagardner-nv May 20, 2025
57030b1
Rename config classes such that pytest doesn't believe they are test …
dagardner-nv May 20, 2025
abd0c44
Replace usage of update_forward_refs() with model_rebuild() per https…
dagardner-nv May 20, 2025
70169d4
Silence expected builder warnings with explicitly testing that they a…
dagardner-nv May 20, 2025
3feefc4
Silence expected warning at test completion
dagardner-nv May 20, 2025
4ea2590
Replace deprecated dict method with model_dump
dagardner-nv May 20, 2025
99627f7
Ignore expected warning about the mock method not being called
dagardner-nv May 20, 2025
c280f82
Silence expected warning
dagardner-nv May 20, 2025
8268d6a
Replace deprecated MultiCommand with Group
dagardner-nv May 20, 2025
a602c1a
Avoid deprecation warning about calling apply on dataframe group http…
dagardner-nv May 20, 2025
02c63a1
Always await the search method, fix type-o in collection name
dagardner-nv May 20, 2025
d8b5a67
Remove edgecolor argument which was being overriden by the color argu…
dagardner-nv May 20, 2025
e5c945c
Replace deprecated import of DeterministicFakeEmbedding, add missing …
dagardner-nv May 20, 2025
dc367a9
Attempt to await all coroutines, filter warnings for the rest
dagardner-nv May 20, 2025
41283c3
Merge branch 'develop' of github.com:NVIDIA/AIQToolkit into david-rem…
dagardner-nv May 20, 2025
27e4bec
Revert async hacks
dagardner-nv May 21, 2025
7568c95
Revert async hacks
dagardner-nv May 21, 2025
f60f859
Create a fixture that yields a loop that runs in another thread
dagardner-nv May 21, 2025
0f7b857
Make the RunLoopThread class private to the fixture, update tests to …
dagardner-nv May 21, 2025
72d7fe2
Avoid mocking the asyncio.run_coroutine_threadsafe method
dagardner-nv May 21, 2025
b4b2396
Check session_manager.workflow.has_single_output early before calling…
dagardner-nv May 21, 2025
012bac0
Ensure that any running coroutines are canceled in the exception handler
dagardner-nv May 21, 2025
88c6605
Don't patch the AIQSessionManager constructor, instead just pass in s…
dagardner-nv May 21, 2025
5712df3
Ensure the NVIDIA_API_KEY environment variable is set (to a fake valu…
dagardner-nv May 21, 2025
0adc43f
Silence syntax warnings being emitted from the qdrant-client (used by…
dagardner-nv May 21, 2025
acaeaf1
Fix return type hint, and avoid unneeded import
dagardner-nv May 21, 2025
1d25deb
Replace the warnings.catch_warnings statement with a global pytest ig…
dagardner-nv May 21, 2025
e13e349
Fix the warning filter
dagardner-nv May 21, 2025
2346a46
Remove unused import
dagardner-nv May 21, 2025
25967e8
Merge branch 'develop' of github.com:NVIDIA/AIQToolkit into david-rem…
dagardner-nv May 21, 2025
c08f320
Remove unintented change
dagardner-nv May 21, 2025
ca738c8
Warning is different in Python 3.11
dagardner-nv May 21, 2025
c4010df
Fix toml
dagardner-nv May 21, 2025
b438189
Fix ignore for 3.11
dagardner-nv May 22, 2025
363e4bd
Remove attempt to filter the 3.11 deprecation warning
dagardner-nv May 22, 2025
a38c941
Explcitly cast entry point objects to a list to avoid a warning in Py…
dagardner-nv May 22, 2025
f1023e3
base_output is not a coroutine
dagardner-nv May 27, 2025
5d06d6b
Merge branch 'develop' of github.com:NVIDIA/AIQToolkit into david-rem…
dagardner-nv May 27, 2025
24b3a37
Revert unintended change
dagardner-nv May 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vale.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ MinAlertLevel = error
Vocab = aiq

# Configs for markdown and reStructuredText files
[*{.md,.rst}]
[*{.md,.mdx,.rst}]

BasedOnStyles = Vale

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def create_trace_flow_diagram(df: pd.DataFrame, temp_dir: str) -> TraceFlowInfo:

# Draw span box
color = colors.get(span_kind, "lightgray")
rect = plt.Rectangle((x_start, 0.5 * i - 0.2), x_end - x_start, 0.4, color=color, alpha=0.8, edgecolor="black")
rect = plt.Rectangle((x_start, 0.5 * i - 0.2), x_end - x_start, 0.4, color=color, alpha=0.8)

ax.add_patch(rect)

Expand Down
123 changes: 63 additions & 60 deletions packages/aiqtoolkit_agno/tests/test_tool_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import threading
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
Expand All @@ -28,6 +30,40 @@
from aiq.plugins.agno.tool_wrapper import process_result


@pytest.fixture(name="run_loop_thread")
def fixture_run_loop_thread():
"""
Fixture to create an asyncio event loop running in another thread.
Useful for creating a loop that can be used with the asyncio.run_coroutine_threadsafe function.
"""

class RunLoopThread(threading.Thread):

def __init__(self, loop: asyncio.AbstractEventLoop, release_event: threading.Event):
super().__init__()
self._loop = loop
self._release_event = release_event

def run(self):
asyncio.set_event_loop(self._loop)
self._release_event.set()
self._loop.run_forever()

loop = asyncio.new_event_loop()
release_event = threading.Event()
thread = RunLoopThread(loop=loop, release_event=release_event)
thread.start()

# Wait for the thread to set the event
release_event.wait()

yield loop

# Stop the loop and join the thread
loop.call_soon_threadsafe(loop.stop)
thread.join()


class TestToolWrapper:
"""Tests for the agno_tool_wrapper function."""

Expand Down Expand Up @@ -205,135 +241,104 @@ async def mock_acall_invoke(*args, **kwargs):

@patch("aiq.plugins.agno.tool_wrapper._tool_call_counters", {})
@patch("aiq.plugins.agno.tool_wrapper._tool_initialization_done", {})
@patch("aiq.plugins.agno.tool_wrapper.asyncio.run_coroutine_threadsafe")
def test_execute_agno_tool_initialization(self, mock_run_coroutine_threadsafe, mock_event_loop):
def test_execute_agno_tool_initialization(self, run_loop_thread: asyncio.AbstractEventLoop):
"""Test that execute_agno_tool correctly handles tool initialization."""
# Set up the mock future
mock_future = MagicMock()
mock_future.result.return_value = "initialization_result"
mock_run_coroutine_threadsafe.return_value = mock_future

# Create a mock coroutine function
mock_coroutine_fn = AsyncMock()
mock_coroutine_fn.return_value = "initialization_result"

# Call the function under test for a tool with an empty kwargs dict (initialization)
result = execute_agno_tool("test_tool", mock_coroutine_fn, ["query"], mock_event_loop)
result = execute_agno_tool("test_tool", mock_coroutine_fn, ["query"], run_loop_thread)

# Verify that the counters and initialization flags were set correctly
from aiq.plugins.agno.tool_wrapper import _tool_call_counters
from aiq.plugins.agno.tool_wrapper import _tool_initialization_done
assert "test_tool" in _tool_call_counters
assert "test_tool" in _tool_initialization_done

# Verify that run_coroutine_threadsafe was called with the coroutine function
mock_run_coroutine_threadsafe.assert_called()
# Verify that the coroutine function was called
mock_coroutine_fn.assert_called_once_with()

# Verify the result
assert result == "initialization_result"

@patch("aiq.plugins.agno.tool_wrapper._tool_call_counters", {"search_api_tool": 0})
@patch("aiq.plugins.agno.tool_wrapper._tool_initialization_done", {"search_api_tool": True})
@patch("aiq.plugins.agno.tool_wrapper.asyncio.run_coroutine_threadsafe")
def test_execute_agno_tool_search_api_empty_query(self, mock_run_coroutine_threadsafe, mock_event_loop):
def test_execute_agno_tool_search_api_empty_query(self, run_loop_thread):
"""Test that execute_agno_tool correctly handles search API tools with empty queries."""
# Create a mock coroutine function
mock_coroutine_fn = AsyncMock()

# Call the function under test for a search tool with an empty query
result = execute_agno_tool("search_api_tool", mock_coroutine_fn, ["query"], mock_event_loop, query="")
result = execute_agno_tool("search_api_tool", mock_coroutine_fn, ["query"], run_loop_thread, query="")

# Verify that an error message is returned for empty query after initialization
assert "ERROR" in result
assert "requires a valid query" in result

# Verify that run_coroutine_threadsafe was not called since we blocked the empty query
mock_run_coroutine_threadsafe.assert_not_called()
# Verify that coroutine was not called since we called execute_agno_tool with an empty query
mock_coroutine_fn.assert_not_called()

@patch("aiq.plugins.agno.tool_wrapper._tool_call_counters", {"test_tool": 0})
@patch("aiq.plugins.agno.tool_wrapper._tool_initialization_done", {"test_tool": False})
@patch("aiq.plugins.agno.tool_wrapper.asyncio.run_coroutine_threadsafe")
def test_execute_agno_tool_filtered_kwargs(self, mock_run_coroutine_threadsafe, mock_event_loop):
def test_execute_agno_tool_filtered_kwargs(self, run_loop_thread: asyncio.AbstractEventLoop):
"""Test that execute_agno_tool correctly filters reserved keywords."""
# Set up the mock future
mock_future = MagicMock()
mock_future.result.return_value = "filtered_result"
mock_run_coroutine_threadsafe.return_value = mock_future

# Create a mock process_result future
process_future = MagicMock()
process_future.result.return_value = "processed_result"
mock_run_coroutine_threadsafe.side_effect = [mock_future, process_future]

# Create a mock coroutine function
mock_coroutine_fn = AsyncMock()
mock_coroutine_fn.return_value = "processed_result"

# Call the function under test with kwargs containing reserved keywords
result = execute_agno_tool("test_tool",
mock_coroutine_fn, ["query"],
mock_event_loop,
run_loop_thread,
query="test query",
model_config="should be filtered",
_type="should be filtered")

# Verify that run_coroutine_threadsafe was called with filtered kwargs
args, kwargs = mock_coroutine_fn.call_args
assert "query" in kwargs
assert "model_config" not in kwargs
assert "_type" not in kwargs
# Verify that mock_coroutine_fn was called with filtered kwargs
mock_coroutine_fn.assert_called_once_with(query="test query")

# Verify the result
assert result == "processed_result"

@patch("aiq.plugins.agno.tool_wrapper._tool_call_counters", {"test_tool": 0})
@patch("aiq.plugins.agno.tool_wrapper._tool_initialization_done", {"test_tool": False})
@patch("aiq.plugins.agno.tool_wrapper.asyncio.run_coroutine_threadsafe")
def test_execute_agno_tool_wrapped_kwargs(self, mock_run_coroutine_threadsafe, mock_event_loop):
def test_execute_agno_tool_wrapped_kwargs(self, run_loop_thread: asyncio.AbstractEventLoop):
"""Test that execute_agno_tool correctly unwraps nested kwargs."""
# Set up the mock future
mock_future = MagicMock()
mock_future.result.return_value = "unwrapped_result"

# Create a mock process_result future
process_future = MagicMock()
process_future.result.return_value = "processed_result"
mock_run_coroutine_threadsafe.side_effect = [mock_future, process_future]

# Create a mock coroutine function
mock_coroutine_fn = AsyncMock()
mock_coroutine_fn.return_value = "processed_result"

# Call the function under test with wrapped kwargs
result = execute_agno_tool("test_tool",
mock_coroutine_fn, ["query"],
mock_event_loop,
run_loop_thread,
kwargs={
"query": "test query", "other_param": "value"
})

# Verify that run_coroutine_threadsafe was called with unwrapped kwargs
args, kwargs = mock_coroutine_fn.call_args
assert "query" in kwargs
assert kwargs["query"] == "test query"
assert "other_param" in kwargs
assert kwargs["other_param"] == "value"
# Verify that mock_coroutine_fn was called with unwrapped kwargs
mock_coroutine_fn.assert_called_once_with(query="test query", other_param="value")

# Verify the result
assert result == "processed_result"

@patch("aiq.plugins.agno.tool_wrapper._tool_call_counters", {"test_tool": 0})
@patch("aiq.plugins.agno.tool_wrapper._MAX_EMPTY_CALLS", 2)
@patch("aiq.plugins.agno.tool_wrapper.asyncio.run_coroutine_threadsafe")
def test_execute_agno_tool_infinite_loop_detection(self, mock_run_coroutine_threadsafe, mock_event_loop):
def test_execute_agno_tool_infinite_loop_detection(self, run_loop_thread: asyncio.AbstractEventLoop):
"""Test that execute_agno_tool detects and prevents infinite loops."""
# Create a mock coroutine function
mock_coroutine_fn = AsyncMock()

# First call with only metadata should increment counter but proceed
execute_agno_tool("test_tool", mock_coroutine_fn, ["query"], mock_event_loop, model_config="metadata only")
execute_agno_tool("test_tool", mock_coroutine_fn, ["query"], run_loop_thread, model_config="metadata only")

# Second call with only metadata should detect potential infinite loop
result2 = execute_agno_tool("test_tool",
mock_coroutine_fn, ["query"],
mock_event_loop,
run_loop_thread,
model_config="metadata only")

# Verify that the second call returned an error about infinite loops
Expand Down Expand Up @@ -412,8 +417,11 @@ def __init__(self, choices):
assert result == "OpenAI response content"

@patch("aiq.plugins.agno.tool_wrapper.tool")
@patch("aiq.plugins.agno.tool_wrapper.asyncio.run_coroutine_threadsafe")
def test_different_calling_styles(self, mock_run_coroutine_threadsafe, mock_tool, mock_function, mock_builder):
def test_different_calling_styles(self,
mock_tool,
mock_function,
mock_builder,
run_loop_thread: asyncio.AbstractEventLoop):
"""Test that execute_agno_tool handles different function calling styles."""
# Mock the tool decorator to return a function that returns its input
mock_tool.return_value = lambda x: x
Expand All @@ -428,11 +436,6 @@ def test_different_calling_styles(self, mock_run_coroutine_threadsafe, mock_tool
process_future = MagicMock()
process_future.result.return_value = "processed_result"

mock_run_coroutine_threadsafe.side_effect = [future1, future2, process_future]

# Create a mock coroutine function
AsyncMock()

# Call the function under test
wrapper_func = agno_tool_wrapper("test_tool", mock_function, mock_builder)

Expand Down
3 changes: 1 addition & 2 deletions packages/aiqtoolkit_mem0ai/tests/test_mem0_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
from unittest.mock import AsyncMock

import pytest
from mem0 import AsyncMemoryClient

from aiq.memory.models import MemoryItem
from aiq.plugins.mem0ai.mem0_editor import Mem0Editor


@pytest.fixture(name="mock_mem0_client")
def mock_mem0_client_fixture() -> AsyncMemoryClient:
def mock_mem0_client_fixture() -> AsyncMock:
"""Fixture to provide a mocked AsyncMemoryClient."""
return AsyncMock()

Expand Down
1 change: 1 addition & 0 deletions packages/aiqtoolkit_test/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
# version when adding a new package. If unsure, default to using `~=` instead of `==`. Does not apply to aiq packages.
# Keep sorted!!!
"aiqtoolkit~=1.2",
"langchain-community~=0.3",
"pytest~=8.3",
]
requires-python = ">=3.11,<3.13"
Expand Down
2 changes: 1 addition & 1 deletion packages/aiqtoolkit_test/src/aiq/test/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ async def embedder_test_provider(config: EmbedderTestConfig, builder: Builder):
@register_embedder_client(config_type=EmbedderTestConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
async def embedder_langchain_test_client(config: EmbedderTestConfig, builder: Builder):

from langchain.embeddings import DeterministicFakeEmbedding
from langchain_community.embeddings import DeterministicFakeEmbedding

yield DeterministicFakeEmbedding(size=config.embedding_size)
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,17 @@ markers = [
"slow: Slow tests",
]
filterwarnings = [
# Add warnings to ignore as a part of pytest here
# Ignore warnings from qdrant-client (used by mem0) with Python 3.12+ of note is that this only happens the first
# time the module is imported and parsed, after that the pyc files in the __pycache__ directory are used which don't
# trigger the warnings. In Python 3.12 this triggers a SyntaxWarning, in Python 3.11 it triggers a DeprecationWarning
# which unfortunately pytest is unable to filter.
# Remove once https://github.com/qdrant/qdrant-client/issues/983 is resolved.
"ignore:^invalid escape sequence:SyntaxWarning"
]
testpaths = ["tests", "examples/*/tests", "packages/*/tests"]
asyncio_mode = "auto"
pytest_plugins = ["aiqtoolkit-test"]
asyncio_default_fixture_loop_scope = "session"


## Pylint configuration begins here

Expand Down
2 changes: 1 addition & 1 deletion src/aiq/cli/commands/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
logger = logging.getLogger(__name__)


class StartCommandGroup(click.MultiCommand):
class StartCommandGroup(click.Group):

# pylint: disable=too-many-positional-arguments
def __init__(
Expand Down
22 changes: 16 additions & 6 deletions src/aiq/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,18 @@ async def run_one(item: EvalInputItem):
return "", []

async with session_manager.run(item.input_obj) as runner:
if not session_manager.workflow.has_single_output:
# raise an error if the workflow has multiple outputs
raise NotImplementedError("Multiple outputs are not supported")

base_output = None
intermediate_future = None

try:

# Start usage stats and intermediate steps collection in parallel
intermediate_future = pull_intermediate()

if session_manager.workflow.has_single_output:
base_output = await runner.result()
else:
# raise an error if the workflow has multiple outputs
raise NotImplementedError("Multiple outputs are not supported")
base_output = await runner.result()
intermediate_steps = await intermediate_future
except NotImplementedError as e:
# raise original error
Expand All @@ -101,6 +104,13 @@ async def run_one(item: EvalInputItem):
logger.exception("Failed to run the workflow: %s", e, exc_info=True)
# stop processing if a workflow error occurs
self.workflow_interrupted = True

# Cancel any coroutines that are still running, avoiding a warning about unawaited coroutines
# (typically one of these two is what raised the exception and the other is still running)
for coro in (base_output, intermediate_future):
if coro is not None:
asyncio.ensure_future(coro).cancel()

stop_event.set()
return

Expand Down
19 changes: 15 additions & 4 deletions src/aiq/observability/async_otel_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import re
import warnings
from contextlib import asynccontextmanager
from contextlib import contextmanager
from typing import Any
Expand All @@ -30,10 +31,20 @@
from aiq.utils.optional_imports import try_import_opentelemetry

try:
from weave.trace.context import weave_client_context
from weave.trace.context.call_context import get_current_call
from weave.trace.context.call_context import set_call_stack
from weave.trace.weave_client import Call
with warnings.catch_warnings():
# Ignore deprecation warnings being triggered by weave. https://github.com/wandb/weave/issues/3666
# and https://github.com/wandb/weave/issues/4533
warnings.filterwarnings("ignore", category=DeprecationWarning, message=r"^`sentry_sdk\.Hub` is deprecated")
warnings.filterwarnings("ignore",
category=DeprecationWarning,
message=r"^Using extra keyword arguments on `Field` is deprecated")
warnings.filterwarnings("ignore",
category=DeprecationWarning,
message=r"^`include` is deprecated and does nothing")
from weave.trace.context import weave_client_context
from weave.trace.context.call_context import get_current_call
from weave.trace.context.call_context import set_call_stack
from weave.trace.weave_client import Call
WEAVE_AVAILABLE = True
except ImportError:
WEAVE_AVAILABLE = False
Expand Down
Loading