diff --git a/tests/mlmodel_strands/_mock_model_provider.py b/tests/mlmodel_strands/_mock_model_provider.py new file mode 100644 index 0000000000..e4c9e79930 --- /dev/null +++ b/tests/mlmodel_strands/_mock_model_provider.py @@ -0,0 +1,99 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test setup derived from: https://github.com/strands-agents/sdk-python/blob/main/tests/fixtures/mocked_model_provider.py +# strands Apache 2.0 license: https://github.com/strands-agents/sdk-python/blob/main/LICENSE + +import json +from typing import TypedDict + +from strands.models import Model + + +class RedactionMessage(TypedDict): + redactedUserContent: str + redactedAssistantContent: str + + +class MockedModelProvider(Model): + """A mock implementation of the Model interface for testing purposes. + + This class simulates a model provider by returning pre-defined agent responses + in sequence. It implements the Model interface methods and provides functionality + to stream mock responses as events. + """ + + def __init__(self, agent_responses): + self.agent_responses = agent_responses + self.index = 0 + + def format_chunk(self, event): + return event + + def format_request(self, messages, tool_specs=None, system_prompt=None): + return None + + def get_config(self): + pass + + def update_config(self, **model_config): + pass + + async def structured_output(self, output_model, prompt, system_prompt=None, **kwargs): + pass + + async def stream(self, messages, tool_specs=None, system_prompt=None): + events = self.map_agent_message_to_events(self.agent_responses[self.index]) + for event in events: + yield event + + self.index += 1 + + def map_agent_message_to_events(self, agent_message): + stop_reason = "end_turn" + yield {"messageStart": {"role": "assistant"}} + if agent_message.get("redactedAssistantContent"): + yield {"redactContent": {"redactUserContentMessage": agent_message["redactedUserContent"]}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": agent_message["redactedAssistantContent"]}}} + yield {"contentBlockStop": {}} + stop_reason = "guardrail_intervened" + else: + for content in agent_message["content"]: + if "reasoningContent" in content: + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"reasoningContent": content["reasoningContent"]}}} + yield {"contentBlockStop": {}} + if "text" in content: + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": content["text"]}}} + yield {"contentBlockStop": {}} + if "toolUse" in content: + stop_reason = "tool_use" + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "name": content["toolUse"]["name"], + "toolUseId": content["toolUse"]["toolUseId"], + } + } + } + } + yield { + "contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}} + } + yield {"contentBlockStop": {}} + + yield {"messageStop": {"stopReason": stop_reason}} diff --git a/tests/mlmodel_strands/conftest.py b/tests/mlmodel_strands/conftest.py new file mode 100644 index 0000000000..b810161f6a --- /dev/null +++ b/tests/mlmodel_strands/conftest.py @@ -0,0 +1,144 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from _mock_model_provider import MockedModelProvider +from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture +from testing_support.ml_testing_utils import set_trace_info + +_default_settings = { + "package_reporting.enabled": False, # Turn off package reporting for testing as it causes slowdowns. + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "ai_monitoring.enabled": True, +} + +collector_agent_registration = collector_agent_registration_fixture( + app_name="Python Agent Test (mlmodel_strands)", default_settings=_default_settings +) + + +@pytest.fixture +def single_tool_model(): + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + {"text": "Calling add_exclamation tool"}, + {"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}}, + ], + }, + {"role": "assistant", "content": [{"text": "Success!"}]}, + ] + ) + return model + + +@pytest.fixture +def single_tool_model_error(): + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + {"text": "Calling add_exclamation tool"}, + # Set arguments to an invalid type to trigger error in tool + {"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": 12}}}, + ], + }, + {"role": "assistant", "content": [{"text": "Success!"}]}, + ] + ) + return model + + +@pytest.fixture +def multi_tool_model(): + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + {"text": "Calling add_exclamation tool"}, + {"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Calling compute_sum tool"}, + {"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 5, "b": 3}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Calling add_exclamation tool"}, + {"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Goodbye"}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Calling compute_sum tool"}, + {"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 123, "b": 2}}}, + ], + }, + {"role": "assistant", "content": [{"text": "Success!"}]}, + ] + ) + return model + + +@pytest.fixture +def multi_tool_model_error(): + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + {"text": "Calling add_exclamation tool"}, + {"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Calling compute_sum tool"}, + {"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 5, "b": 3}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Calling add_exclamation tool"}, + {"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Goodbye"}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Calling compute_sum tool"}, + # Set insufficient arguments to trigger error in tool + {"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 123}}}, + ], + }, + {"role": "assistant", "content": [{"text": "Success!"}]}, + ] + ) + return model diff --git a/tests/mlmodel_strands/test_simple.py b/tests/mlmodel_strands/test_simple.py new file mode 100644 index 0000000000..ae24003fab --- /dev/null +++ b/tests/mlmodel_strands/test_simple.py @@ -0,0 +1,36 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from strands import Agent, tool + +from newrelic.api.background_task import background_task + + +# Example tool for testing purposes +@tool +def add_exclamation(message: str) -> str: + return f"{message}!" + + +# TODO: Remove this file once all real tests are in place + + +@background_task() +def test_simple_run_agent(set_trace_info, single_tool_model): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + + response = my_agent("Run the tools.") + assert response.message["content"][0]["text"] == "Success!" + assert response.metrics.tool_metrics["add_exclamation"].success_count == 1 diff --git a/tox.ini b/tox.ini index 39148b657f..ace7839db3 100644 --- a/tox.ini +++ b/tox.ini @@ -182,6 +182,7 @@ envlist = python-logger_structlog-{py38,py39,py310,py311,py312,py313,py314,pypy311}-structloglatest, python-mlmodel_autogen-{py310,py311,py312,py313,py314,pypy311}-autogen061, python-mlmodel_autogen-{py310,py311,py312,py313,py314,pypy311}-autogenlatest, + python-mlmodel_strands-{py310,py311,py312,py313}-strandslatest, python-mlmodel_gemini-{py39,py310,py311,py312,py313,py314}, python-mlmodel_langchain-{py39,py310,py311,py312,py313}, ;; Package not ready for Python 3.14 (type annotations not updated) @@ -440,6 +441,8 @@ deps = mlmodel_langchain: faiss-cpu mlmodel_langchain: mock mlmodel_langchain: asyncio + mlmodel_strands: strands-agents[openai] + mlmodel_strands: strands-agents-tools logger_loguru-logurulatest: loguru logger_structlog-structloglatest: structlog messagebroker_pika-pikalatest: pika @@ -510,6 +513,7 @@ changedir = application_celery: tests/application_celery component_djangorestframework: tests/component_djangorestframework component_flask_rest: tests/component_flask_rest + component_graphenedjango: tests/component_graphenedjango component_graphqlserver: tests/component_graphqlserver component_tastypie: tests/component_tastypie coroutines_asyncio: tests/coroutines_asyncio @@ -521,17 +525,17 @@ changedir = datastore_cassandradriver: tests/datastore_cassandradriver datastore_elasticsearch: tests/datastore_elasticsearch datastore_firestore: tests/datastore_firestore - datastore_oracledb: tests/datastore_oracledb datastore_memcache: tests/datastore_memcache + datastore_motor: tests/datastore_motor datastore_mysql: tests/datastore_mysql datastore_mysqldb: tests/datastore_mysqldb + datastore_oracledb: tests/datastore_oracledb datastore_postgresql: tests/datastore_postgresql datastore_psycopg: tests/datastore_psycopg datastore_psycopg2: tests/datastore_psycopg2 datastore_psycopg2cffi: tests/datastore_psycopg2cffi datastore_pylibmc: tests/datastore_pylibmc datastore_pymemcache: tests/datastore_pymemcache - datastore_motor: tests/datastore_motor datastore_pymongo: tests/datastore_pymongo datastore_pymssql: tests/datastore_pymssql datastore_pymysql: tests/datastore_pymysql @@ -539,8 +543,8 @@ changedir = datastore_pysolr: tests/datastore_pysolr datastore_redis: tests/datastore_redis datastore_rediscluster: tests/datastore_rediscluster - datastore_valkey: tests/datastore_valkey datastore_sqlite: tests/datastore_sqlite + datastore_valkey: tests/datastore_valkey external_aiobotocore: tests/external_aiobotocore external_botocore: tests/external_botocore external_feedparser: tests/external_feedparser @@ -561,7 +565,6 @@ changedir = framework_fastapi: tests/framework_fastapi framework_flask: tests/framework_flask framework_graphene: tests/framework_graphene - component_graphenedjango: tests/component_graphenedjango framework_graphql: tests/framework_graphql framework_grpc: tests/framework_grpc framework_pyramid: tests/framework_pyramid @@ -581,6 +584,7 @@ changedir = mlmodel_langchain: tests/mlmodel_langchain mlmodel_openai: tests/mlmodel_openai mlmodel_sklearn: tests/mlmodel_sklearn + mlmodel_strands: tests/mlmodel_strands template_genshi: tests/template_genshi template_jinja2: tests/template_jinja2 template_mako: tests/template_mako