Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions tests/mlmodel_strands/_mock_model_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2010 New Relic, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Test setup derived from: https://github.com/strands-agents/sdk-python/blob/main/tests/fixtures/mocked_model_provider.py
# strands Apache 2.0 license: https://github.com/strands-agents/sdk-python/blob/main/LICENSE

import json
from typing import TypedDict

from strands.models import Model


class RedactionMessage(TypedDict):
redactedUserContent: str
redactedAssistantContent: str


class MockedModelProvider(Model):
"""A mock implementation of the Model interface for testing purposes.

This class simulates a model provider by returning pre-defined agent responses
in sequence. It implements the Model interface methods and provides functionality
to stream mock responses as events.
"""

def __init__(self, agent_responses):
self.agent_responses = agent_responses
self.index = 0

def format_chunk(self, event):
return event

def format_request(self, messages, tool_specs=None, system_prompt=None):
return None

def get_config(self):
pass

def update_config(self, **model_config):
pass

async def structured_output(self, output_model, prompt, system_prompt=None, **kwargs):
pass

async def stream(self, messages, tool_specs=None, system_prompt=None):
events = self.map_agent_message_to_events(self.agent_responses[self.index])
for event in events:
yield event

self.index += 1

def map_agent_message_to_events(self, agent_message):
stop_reason = "end_turn"
yield {"messageStart": {"role": "assistant"}}
if agent_message.get("redactedAssistantContent"):
yield {"redactContent": {"redactUserContentMessage": agent_message["redactedUserContent"]}}
yield {"contentBlockStart": {"start": {}}}
yield {"contentBlockDelta": {"delta": {"text": agent_message["redactedAssistantContent"]}}}
yield {"contentBlockStop": {}}
stop_reason = "guardrail_intervened"
else:
for content in agent_message["content"]:
if "reasoningContent" in content:
yield {"contentBlockStart": {"start": {}}}
yield {"contentBlockDelta": {"delta": {"reasoningContent": content["reasoningContent"]}}}
yield {"contentBlockStop": {}}
if "text" in content:
yield {"contentBlockStart": {"start": {}}}
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}}
yield {"contentBlockStop": {}}
if "toolUse" in content:
stop_reason = "tool_use"
yield {
"contentBlockStart": {
"start": {
"toolUse": {
"name": content["toolUse"]["name"],
"toolUseId": content["toolUse"]["toolUseId"],
}
}
}
}
yield {
"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}}
}
yield {"contentBlockStop": {}}

yield {"messageStop": {"stopReason": stop_reason}}
144 changes: 144 additions & 0 deletions tests/mlmodel_strands/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2010 New Relic, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from _mock_model_provider import MockedModelProvider
from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture
from testing_support.ml_testing_utils import set_trace_info

_default_settings = {
"package_reporting.enabled": False, # Turn off package reporting for testing as it causes slowdowns.
"transaction_tracer.explain_threshold": 0.0,
"transaction_tracer.transaction_threshold": 0.0,
"transaction_tracer.stack_trace_threshold": 0.0,
"debug.log_data_collector_payloads": True,
"debug.record_transaction_failure": True,
"ai_monitoring.enabled": True,
}

collector_agent_registration = collector_agent_registration_fixture(
app_name="Python Agent Test (mlmodel_strands)", default_settings=_default_settings
)


@pytest.fixture
def single_tool_model():
model = MockedModelProvider(
[
{
"role": "assistant",
"content": [
{"text": "Calling add_exclamation tool"},
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}},
],
},
{"role": "assistant", "content": [{"text": "Success!"}]},
]
)
return model


@pytest.fixture
def single_tool_model_error():
model = MockedModelProvider(
[
{
"role": "assistant",
"content": [
{"text": "Calling add_exclamation tool"},
# Set arguments to an invalid type to trigger error in tool
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": 12}}},
],
},
{"role": "assistant", "content": [{"text": "Success!"}]},
]
)
return model


@pytest.fixture
def multi_tool_model():
model = MockedModelProvider(
[
{
"role": "assistant",
"content": [
{"text": "Calling add_exclamation tool"},
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}},
],
},
{
"role": "assistant",
"content": [
{"text": "Calling compute_sum tool"},
{"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 5, "b": 3}}},
],
},
{
"role": "assistant",
"content": [
{"text": "Calling add_exclamation tool"},
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Goodbye"}}},
],
},
{
"role": "assistant",
"content": [
{"text": "Calling compute_sum tool"},
{"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 123, "b": 2}}},
],
},
{"role": "assistant", "content": [{"text": "Success!"}]},
]
)
return model


@pytest.fixture
def multi_tool_model_error():
model = MockedModelProvider(
[
{
"role": "assistant",
"content": [
{"text": "Calling add_exclamation tool"},
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}},
],
},
{
"role": "assistant",
"content": [
{"text": "Calling compute_sum tool"},
{"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 5, "b": 3}}},
],
},
{
"role": "assistant",
"content": [
{"text": "Calling add_exclamation tool"},
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Goodbye"}}},
],
},
{
"role": "assistant",
"content": [
{"text": "Calling compute_sum tool"},
# Set insufficient arguments to trigger error in tool
{"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 123}}},
],
},
{"role": "assistant", "content": [{"text": "Success!"}]},
]
)
return model
36 changes: 36 additions & 0 deletions tests/mlmodel_strands/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2010 New Relic, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from strands import Agent, tool

from newrelic.api.background_task import background_task


# Example tool for testing purposes
@tool
def add_exclamation(message: str) -> str:
return f"{message}!"


# TODO: Remove this file once all real tests are in place


@background_task()
def test_simple_run_agent(set_trace_info, single_tool_model):
set_trace_info()
my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation])

response = my_agent("Run the tools.")
assert response.message["content"][0]["text"] == "Success!"
assert response.metrics.tool_metrics["add_exclamation"].success_count == 1
12 changes: 8 additions & 4 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ envlist =
python-logger_structlog-{py38,py39,py310,py311,py312,py313,py314,pypy311}-structloglatest,
python-mlmodel_autogen-{py310,py311,py312,py313,py314,pypy311}-autogen061,
python-mlmodel_autogen-{py310,py311,py312,py313,py314,pypy311}-autogenlatest,
python-mlmodel_strands-{py310,py311,py312,py313}-strandslatest,
python-mlmodel_gemini-{py39,py310,py311,py312,py313,py314},
python-mlmodel_langchain-{py39,py310,py311,py312,py313},
;; Package not ready for Python 3.14 (type annotations not updated)
Expand Down Expand Up @@ -440,6 +441,8 @@ deps =
mlmodel_langchain: faiss-cpu
mlmodel_langchain: mock
mlmodel_langchain: asyncio
mlmodel_strands: strands-agents[openai]
mlmodel_strands: strands-agents-tools
logger_loguru-logurulatest: loguru
logger_structlog-structloglatest: structlog
messagebroker_pika-pikalatest: pika
Expand Down Expand Up @@ -510,6 +513,7 @@ changedir =
application_celery: tests/application_celery
component_djangorestframework: tests/component_djangorestframework
component_flask_rest: tests/component_flask_rest
component_graphenedjango: tests/component_graphenedjango
component_graphqlserver: tests/component_graphqlserver
component_tastypie: tests/component_tastypie
coroutines_asyncio: tests/coroutines_asyncio
Expand All @@ -521,26 +525,26 @@ changedir =
datastore_cassandradriver: tests/datastore_cassandradriver
datastore_elasticsearch: tests/datastore_elasticsearch
datastore_firestore: tests/datastore_firestore
datastore_oracledb: tests/datastore_oracledb
datastore_memcache: tests/datastore_memcache
datastore_motor: tests/datastore_motor
datastore_mysql: tests/datastore_mysql
datastore_mysqldb: tests/datastore_mysqldb
datastore_oracledb: tests/datastore_oracledb
datastore_postgresql: tests/datastore_postgresql
datastore_psycopg: tests/datastore_psycopg
datastore_psycopg2: tests/datastore_psycopg2
datastore_psycopg2cffi: tests/datastore_psycopg2cffi
datastore_pylibmc: tests/datastore_pylibmc
datastore_pymemcache: tests/datastore_pymemcache
datastore_motor: tests/datastore_motor
datastore_pymongo: tests/datastore_pymongo
datastore_pymssql: tests/datastore_pymssql
datastore_pymysql: tests/datastore_pymysql
datastore_pyodbc: tests/datastore_pyodbc
datastore_pysolr: tests/datastore_pysolr
datastore_redis: tests/datastore_redis
datastore_rediscluster: tests/datastore_rediscluster
datastore_valkey: tests/datastore_valkey
datastore_sqlite: tests/datastore_sqlite
datastore_valkey: tests/datastore_valkey
external_aiobotocore: tests/external_aiobotocore
external_botocore: tests/external_botocore
external_feedparser: tests/external_feedparser
Expand All @@ -561,7 +565,6 @@ changedir =
framework_fastapi: tests/framework_fastapi
framework_flask: tests/framework_flask
framework_graphene: tests/framework_graphene
component_graphenedjango: tests/component_graphenedjango
framework_graphql: tests/framework_graphql
framework_grpc: tests/framework_grpc
framework_pyramid: tests/framework_pyramid
Expand All @@ -581,6 +584,7 @@ changedir =
mlmodel_langchain: tests/mlmodel_langchain
mlmodel_openai: tests/mlmodel_openai
mlmodel_sklearn: tests/mlmodel_sklearn
mlmodel_strands: tests/mlmodel_strands
template_genshi: tests/template_genshi
template_jinja2: tests/template_jinja2
template_mako: tests/template_mako
Expand Down
Loading