From 5436f4d45ee8aed555105fc53edda9dd13ae66d6 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Wed, 18 Jun 2025 20:51:02 -0400 Subject: [PATCH 1/4] test(memory): add an end-to-end integration test --- pyproject.toml | 27 +- src/strands_tools/mem0_memory.py | 2 +- tests-integ/test_memory_tool.py | 461 +++++++++++++++++++++++++++++++ 3 files changed, 480 insertions(+), 10 deletions(-) create mode 100644 tests-integ/test_memory_tool.py diff --git a/pyproject.toml b/pyproject.toml index 3042e361..71a8a802 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,9 +83,9 @@ mem0_memory = [ [tool.hatch.envs.hatch-static-analysis] features = ["mem0_memory"] dependencies = [ - "strands-agents>=0.1.0,<1.0.0", - "mypy>=0.981,<1.0.0", - "ruff>=0.4.4,<0.5.0", + "strands-agents>=0.1.0,<1.0.0", + "mypy>=0.981,<1.0.0", + "ruff>=0.4.4,<0.5.0", ] [tool.hatch.envs.hatch-static-analysis.scripts] @@ -140,7 +140,7 @@ list = [ "echo 'Scripts commands available for default env:'; hatch env show --json | jq --raw-output '.default.scripts | keys[]'" ] format = [ - "hatch fmt --formatter", + "hatch fmt --formatter", ] test-format = [ "hatch fmt --formatter --check", @@ -154,6 +154,9 @@ test-lint = [ test = [ "hatch test --cover --cov-report html --cov-report xml {args}" ] +test-integ = [ + "hatch test tests-integ {args}" +] [tool.mypy] python_version = "3.10" @@ -173,14 +176,14 @@ ignore_missing_imports = false [tool.ruff] line-length = 120 -include = ["src/**/*.py", "tests/**/*.py"] +include = ["src/**/*.py", "tests/**/*.py","tests-integ/**/*.py"] [tool.ruff.lint] select = [ - "E", # pycodestyle - "F", # pyflakes - "I", # isort - "B", # flake8-bugbear + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear ] [tool.coverage.run] @@ -196,6 +199,12 @@ directory = "build/coverage/html" [tool.coverage.xml] output = "build/coverage/coverage.xml" +[tool.pytest.ini_options] +testpaths = [ + "tests" +] +pythonpath = ["src"] + [tool.commitizen] name = "cz_conventional_commits" tag_format = "v$version" diff --git a/src/strands_tools/mem0_memory.py b/src/strands_tools/mem0_memory.py index d9849814..5840deaa 100644 --- a/src/strands_tools/mem0_memory.py +++ b/src/strands_tools/mem0_memory.py @@ -140,7 +140,7 @@ "description": "Optional metadata to store with the memory", }, }, - "required": ["action"] + "required": ["action"], } }, } diff --git a/tests-integ/test_memory_tool.py b/tests-integ/test_memory_tool.py new file mode 100644 index 00000000..a12b6e52 --- /dev/null +++ b/tests-integ/test_memory_tool.py @@ -0,0 +1,461 @@ +""" +Integration test for the Bedrock Knowledge Base (memory) tool. + +This test creates real AWS resources (IAM Role, OpenSearch Collection, Bedrock KB) +to validate the end-to-end functionality of the memory tool. It is designed to +be resilient and re-use existing resources to speed up local development. + +Prerequisites: +- AWS credentials must be configured in the environment. +- The AWS identity must have permissions to manage IAM, OpenSearch Serverless, + and Bedrock KB resources. + +Configuration: +- STRANDS_TEARDOWN_RESOURCES (env var): Set to "false" to prevent the automatic + deletion of AWS resources after the test run. This is useful for speeding + up local development by re-using the same Knowledge Base. +""" + +import json +import logging +import os +import time +import uuid +from unittest.mock import patch + +import boto3 +import pytest +from botocore.exceptions import ClientError +from opensearchpy import ( + AuthenticationException, + AuthorizationException, + AWSV4SignerAuth, + OpenSearch, + RequestsHttpConnection, +) +from strands import Agent +from strands_tools import memory + +logger = logging.getLogger(__name__) + +AWS_REGION = "us-east-1" +# Use a standard embedding model for the KB +EMBEDDING_MODEL_ARN = f"arn:aws:bedrock:{AWS_REGION}::foundation-model/amazon.titan-embed-text-v1" +EMBEDDING_DIMENSION = 1536 # Dimension for amazon.titan-embed-text-v1 + + +def _get_boto_clients(): + """Returns a dictionary of boto3 clients needed for the test.""" + return { + "iam": boto3.client("iam", region_name=AWS_REGION), + "bedrock-agent": boto3.client("bedrock-agent", region_name=AWS_REGION), + "opensearchserverless": boto3.client("opensearchserverless", region_name=AWS_REGION), + "sts": boto3.client("sts", region_name=AWS_REGION), + } + + +def _wait_for_resource( + poll_function, + resource_name, + success_status="ACTIVE", + failure_status="FAILED", + timeout_seconds=600, + delay_seconds=30, +): + """Generic waiter for AWS asynchronous operations with detailed logging.""" + logger.info(f"Waiting for '{resource_name}' to become '{success_status}'...") + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + response = poll_function() + status = response.get("status") or response.get("Status") + logger.info(f"Polling status for '{resource_name}': {status} (Elapsed: {int(time.time() - start_time)}s)") + if status == success_status: + logger.info(f"SUCCESS: Resource '{resource_name}' reached '{success_status}' state.") + return response + if status == failure_status: + logger.error( + f"FAILURE: Resource '{resource_name}' entered failure state: {failure_status}-Response: {response}" + ) + raise Exception(f"Resource '{resource_name}' entered failure state: {failure_status}") + except ClientError as e: + if "ResourceNotFoundException" not in str(e): + raise e + logger.info(f"Resource '{resource_name}' not found yet, continuing to wait...") + time.sleep(delay_seconds) + raise TimeoutError(f"Timed out waiting for resource '{resource_name}' to become '{success_status}'.") + + +@pytest.fixture(scope="module") +def managed_knowledge_base(): + """ + Pytest fixture to create and tear down a Bedrock Knowledge Base and its dependencies. + It will re-use existing resources if found to speed up local test runs. + Teardown can be skipped by setting the STRANDS_TEARDOWN_RESOURCES env var to "false". + """ + clients = _get_boto_clients() + + resource_names = { + "role_name": "StrandsMemoryIntegTestRole", + "kb_name": "strands-memory-integ-test-kb", + "ds_name": "strands-memory-integ-test-ds", + "policy_name": "StrandsMemoryIntegTestPolicy", + "collection_name": "strands-memory-integ-coll", + "enc_policy_name": "strands-memory-enc-policy", + "net_policy_name": "strands-memory-net-policy", + "access_policy_name": "strands-memory-access-policy", + "vector_index_name": "bedrock-kb-index", + } + created_resources = {} + + # Check if Knowledge Base already exists + try: + kbs = clients["bedrock-agent"].list_knowledge_bases() + existing_kb = next( + (kb for kb in kbs.get("knowledgeBaseSummaries", []) if kb["name"] == resource_names["kb_name"]), None + ) + if existing_kb and existing_kb["status"] == "ACTIVE": + logger.info(f"Found existing and ACTIVE Knowledge Base '{resource_names['kb_name']}'. Re-using for test.") + yield existing_kb["knowledgeBaseId"] + return # Skip creation if we are reusing + except ClientError as e: + logger.error(f"Error checking for existing Knowledge Bases: {e}") + raise + + try: + logger.info("No active Knowledge Base found. Creating or validating all resources from scratch...") + + # STEP 1: Create OpenSearch Security Policies + try: + clients["opensearchserverless"].create_security_policy( + name=resource_names["enc_policy_name"], + type="encryption", + policy=json.dumps( + { + "Rules": [ + { + "ResourceType": "collection", + "Resource": [f"collection/{resource_names['collection_name']}"], + } + ], + "AWSOwnedKey": True, + } + ), + ) + except ClientError as e: + if e.response["Error"]["Code"] != "ConflictException": + raise + logger.info(f"Encryption policy '{resource_names['enc_policy_name']}' already exists.") + + try: + clients["opensearchserverless"].create_security_policy( + name=resource_names["net_policy_name"], + type="network", + policy=json.dumps( + [ + { + "Rules": [ + { + "ResourceType": "collection", + "Resource": [f"collection/{resource_names['collection_name']}"], + } + ], + "AllowFromPublic": True, + } + ] + ), + ) + except ClientError as e: + if e.response["Error"]["Code"] != "ConflictException": + raise + logger.info(f"Network policy '{resource_names['net_policy_name']}' already exists.") + time.sleep(10) + + # STEP 2: Create OpenSearch Serverless Collection + try: + collection_res = clients["opensearchserverless"].create_collection( + name=resource_names["collection_name"], type="VECTORSEARCH" + ) + collection_id = collection_res["createCollectionDetail"]["id"] + collection_arn = collection_res["createCollectionDetail"]["arn"] + except ClientError as e: + if e.response["Error"]["Code"] != "ConflictException": + raise + logger.info(f"Collection '{resource_names['collection_name']}' already exists. Fetching details.") + collection_details = clients["opensearchserverless"].list_collections( + collectionFilters={"name": resource_names["collection_name"]} + )["collectionSummaries"][0] + collection_id = collection_details["id"] + collection_arn = collection_details["arn"] + created_resources["collection_id"] = collection_id + + # STEP 3: Create IAM Role and Policies + try: + role_res = clients["iam"].get_role(RoleName=resource_names["role_name"]) + logger.info(f"IAM Role '{resource_names['role_name']}' already exists.") + created_resources["role_arn"] = role_res["Role"]["Arn"] + except ClientError as e: + if e.response["Error"]["Code"] != "NoSuchEntity": + raise + iam_policy_doc = { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Action": "bedrock:InvokeModel", "Resource": EMBEDDING_MODEL_ARN}, + {"Effect": "Allow", "Action": "aoss:APIAccessAll", "Resource": collection_arn}, + ], + } + assume_role_policy = { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Principal": {"Service": "bedrock.amazonaws.com"}, "Action": "sts:AssumeRole"} + ], + } + role_res = clients["iam"].create_role( + RoleName=resource_names["role_name"], AssumeRolePolicyDocument=json.dumps(assume_role_policy) + ) + created_resources["role_arn"] = role_res["Role"]["Arn"] + clients["iam"].put_role_policy( + RoleName=resource_names["role_name"], + PolicyName=resource_names["policy_name"], + PolicyDocument=json.dumps(iam_policy_doc), + ) + time.sleep(15) + + # STEP 4: Create OpenSearch Data Access Policy + try: + user_arn = clients["sts"].get_caller_identity()["Arn"] + access_policy_doc = [ + { + "Rules": [ + { + "ResourceType": "collection", + "Resource": [f"collection/{resource_names['collection_name']}"], + "Permission": ["aoss:*"], + }, + { + "ResourceType": "index", + "Resource": [f"index/{resource_names['collection_name']}/*"], + "Permission": ["aoss:*"], + }, + ], + "Principal": [created_resources["role_arn"], user_arn], + } + ] + clients["opensearchserverless"].create_access_policy( + name=resource_names["access_policy_name"], type="data", policy=json.dumps(access_policy_doc) + ) + except ClientError as e: + if e.response["Error"]["Code"] != "ConflictException": + raise + logger.info(f"Access policy '{resource_names['access_policy_name']}' already exists.") + + # STEP 5: Wait for OpenSearch Collection to be Active + collection_details = _wait_for_resource( + lambda: clients["opensearchserverless"].batch_get_collection(ids=[collection_id])["collectionDetails"][0], + resource_name=f"OpenSearch Collection ({collection_id})", + ) + + # STEP 6: Create the vector index + collection_endpoint = collection_details["collectionEndpoint"] + host = collection_endpoint.replace("https://", "") + auth = AWSV4SignerAuth(boto3.Session().get_credentials(), AWS_REGION, "aoss") + os_client = OpenSearch( + hosts=[{"host": host, "port": 443}], + http_auth=auth, + use_ssl=True, + verify_certs=True, + connection_class=RequestsHttpConnection, + ) + index_body = { + "settings": {"index": {"knn": True, "knn.algo_param.ef_search": 512}}, + "mappings": { + "properties": { + "bedrock-kb-vector": { + "type": "knn_vector", + "dimension": EMBEDDING_DIMENSION, + "method": {"name": "hnsw", "engine": "faiss"}, + }, + "AMAZON_BEDROCK_TEXT_CHUNK": {"type": "text"}, + "AMAZON_BEDROCK_METADATA": {"type": "text"}, + } + }, + } + if not os_client.indices.exists(resource_names["vector_index_name"]): + time.sleep(20) # Wait for access policy to propagate + for i in range(5): + try: + os_client.indices.create(resource_names["vector_index_name"], body=index_body) + logger.info("SUCCESS: Vector index created.") + break + except (AuthenticationException, AuthorizationException) as e: + if i < 4: + logger.warning(f"Auth error creating index (Attempt {i+1}). Waiting 30s...") + time.sleep(30) + else: + logger.error("Authorization error persisted after multiple retries.") + raise e + time.sleep(10) + + # STEP 7: Create Knowledge Base + kb_res = clients["bedrock-agent"].create_knowledge_base( + name=resource_names["kb_name"], + roleArn=created_resources["role_arn"], + knowledgeBaseConfiguration={ + "type": "VECTOR", + "vectorKnowledgeBaseConfiguration": {"embeddingModelArn": EMBEDDING_MODEL_ARN}, + }, + storageConfiguration={ + "type": "OPENSEARCH_SERVERLESS", + "opensearchServerlessConfiguration": { + "collectionArn": collection_arn, + "vectorIndexName": resource_names["vector_index_name"], + "fieldMapping": { + "vectorField": "bedrock-kb-vector", + "textField": "AMAZON_BEDROCK_TEXT_CHUNK", + "metadataField": "AMAZON_BEDROCK_METADATA", + }, + }, + }, + ) + created_resources["kb_id"] = kb_res["knowledgeBase"]["knowledgeBaseId"] + _wait_for_resource( + lambda: clients["bedrock-agent"].get_knowledge_base(knowledgeBaseId=created_resources["kb_id"])[ + "knowledgeBase" + ], + resource_name=f"Knowledge Base ({created_resources['kb_id']})", + ) + + # STEP 8: Create Data Source + ds_res = clients["bedrock-agent"].create_data_source( + knowledgeBaseId=created_resources["kb_id"], + name=resource_names["ds_name"], + dataSourceConfiguration={"type": "CUSTOM"}, + ) + created_resources["ds_id"] = ds_res["dataSource"]["dataSourceId"] + + # Do not need to wait for CUSTOM data source to become ACTIVE + time.sleep(10) + + logger.info("All new resources are ready.") + yield created_resources["kb_id"] + + finally: + if os.environ.get("STRANDS_TEARDOWN_RESOURCES", "true").lower() == "true": + logger.info("Starting teardown of AWS resources...") + # Use try/except for each deletion to make teardown more resilient + if "ds_id" in created_resources and "kb_id" in created_resources: + try: + clients["bedrock-agent"].delete_data_source( + knowledgeBaseId=created_resources["kb_id"], dataSourceId=created_resources["ds_id"] + ) + except ClientError: + logger.warning("Could not delete data source.") + if "kb_id" in created_resources: + try: + clients["bedrock-agent"].delete_knowledge_base(knowledgeBaseId=created_resources["kb_id"]) + time.sleep(30) + except ClientError: + logger.warning("Could not delete knowledge base.") + if "collection_id" in created_resources: + try: + clients["opensearchserverless"].delete_collection(id=created_resources["collection_id"]) + except ClientError: + logger.warning("Could not delete OpenSearch collection.") + if "access_policy_name" in resource_names: + try: + clients["opensearchserverless"].delete_access_policy( + name=resource_names["access_policy_name"], type="data" + ) + except ClientError: + logger.warning("Could not delete access policy.") + if "net_policy_name" in resource_names: + try: + clients["opensearchserverless"].delete_security_policy( + name=resource_names["net_policy_name"], type="network" + ) + except ClientError: + logger.warning("Could not delete network policy.") + if "enc_policy_name" in resource_names: + try: + clients["opensearchserverless"].delete_security_policy( + name=resource_names["enc_policy_name"], type="encryption" + ) + except ClientError: + logger.warning("Could not delete encryption policy.") + if "role_name" in resource_names: + try: + clients["iam"].delete_role_policy( + RoleName=resource_names["role_name"], PolicyName=resource_names["policy_name"] + ) + clients["iam"].delete_role(RoleName=resource_names["role_name"]) + except ClientError: + logger.warning("Could not delete IAM role.") + logger.info("Teardown complete.") + else: + logger.info("Skipping teardown of AWS resources as per STRANDS_TEARDOWN_RESOURCES setting.") + + +@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) +def test_memory_integration_store_and_retrieve(managed_knowledge_base): + """ + End-to-end test for Bedrock Knowledge Base memory tool: + - Store a unique document + - Poll until it is INDEXED + - Retrieve via semantic search and verify presence + """ + kb_id = managed_knowledge_base + agent = Agent(tools=[memory]) + clients = _get_boto_clients() + + unique_content = f"The secret password for the test is {uuid.uuid4()}." + store_result = agent.tool.memory( + action="store", + content=unique_content, + title="Integration Test Document", + STRANDS_KNOWLEDGE_BASE_ID=kb_id, + region_name=AWS_REGION, + ) + assert store_result["status"] == "success", f"Store failed: {store_result}" + + # Extract document ID + doc_id = next( + ( + item["text"].split(":", 1)[1].strip() + for item in store_result["content"] + if "Document ID" in item.get("text", "") + ), + None, + ) + assert doc_id, f"No document_id returned from store operation. Got: {store_result}" + + # Wait up to 3 minutes for document to be INDEXED + ds_id = clients["bedrock-agent"].list_data_sources(knowledgeBaseId=kb_id)["dataSourceSummaries"][0]["dataSourceId"] + for _ in range(18): # 18 * 10s = 180s + docs = clients["bedrock-agent"].list_knowledge_base_documents(knowledgeBaseId=kb_id, dataSourceId=ds_id)[ + "documentDetails" + ] + found = next((d for d in docs if d.get("identifier", {}).get("custom", {}).get("id") == doc_id), None) + if found and found.get("status") == "INDEXED": + break + time.sleep(10) + else: + raise AssertionError("Stored document did not become INDEXED in time.") + + # Try up to 2 minutes for the content to appear in retrieval results + for _ in range(12): + retrieve_result = agent.tool.memory( + action="retrieve", + query="The secret password for the test is", + STRANDS_KNOWLEDGE_BASE_ID=kb_id, + region_name=AWS_REGION, + min_score=0.0, + max_results=5, + ) + assert retrieve_result["status"] == "success", f"Retrieve failed: {retrieve_result}" + + full_retrieved_text = " ".join(item.get("text", "") for item in retrieve_result.get("content", [])) + if unique_content in full_retrieved_text: + break + time.sleep(10) + else: + raise AssertionError("Stored content not found in retrieval after waiting.") From d6710401aca29b475fc89650940476ed644f211d Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Fri, 20 Jun 2025 22:49:10 -0400 Subject: [PATCH 2/4] test: refactor code to be more readable --- tests-integ/test_memory_tool.py | 509 +++++++++++++++----------------- 1 file changed, 236 insertions(+), 273 deletions(-) diff --git a/tests-integ/test_memory_tool.py b/tests-integ/test_memory_tool.py index a12b6e52..aa8733db 100644 --- a/tests-integ/test_memory_tool.py +++ b/tests-integ/test_memory_tool.py @@ -2,18 +2,7 @@ Integration test for the Bedrock Knowledge Base (memory) tool. This test creates real AWS resources (IAM Role, OpenSearch Collection, Bedrock KB) -to validate the end-to-end functionality of the memory tool. It is designed to -be resilient and re-use existing resources to speed up local development. - -Prerequisites: -- AWS credentials must be configured in the environment. -- The AWS identity must have permissions to manage IAM, OpenSearch Serverless, - and Bedrock KB resources. - -Configuration: -- STRANDS_TEARDOWN_RESOURCES (env var): Set to "false" to prevent the automatic - deletion of AWS resources after the test run. This is useful for speeding - up local development by re-using the same Knowledge Base. +to validate the end-to-end functionality of the memory tool. """ import json @@ -36,107 +25,144 @@ from strands import Agent from strands_tools import memory -logger = logging.getLogger(__name__) - AWS_REGION = "us-east-1" -# Use a standard embedding model for the KB EMBEDDING_MODEL_ARN = f"arn:aws:bedrock:{AWS_REGION}::foundation-model/amazon.titan-embed-text-v1" EMBEDDING_DIMENSION = 1536 # Dimension for amazon.titan-embed-text-v1 - -def _get_boto_clients(): - """Returns a dictionary of boto3 clients needed for the test.""" - return { - "iam": boto3.client("iam", region_name=AWS_REGION), - "bedrock-agent": boto3.client("bedrock-agent", region_name=AWS_REGION), - "opensearchserverless": boto3.client("opensearchserverless", region_name=AWS_REGION), - "sts": boto3.client("sts", region_name=AWS_REGION), - } - - -def _wait_for_resource( - poll_function, - resource_name, - success_status="ACTIVE", - failure_status="FAILED", - timeout_seconds=600, - delay_seconds=30, -): - """Generic waiter for AWS asynchronous operations with detailed logging.""" - logger.info(f"Waiting for '{resource_name}' to become '{success_status}'...") - start_time = time.time() - while time.time() - start_time < timeout_seconds: - try: - response = poll_function() - status = response.get("status") or response.get("Status") - logger.info(f"Polling status for '{resource_name}': {status} (Elapsed: {int(time.time() - start_time)}s)") - if status == success_status: - logger.info(f"SUCCESS: Resource '{resource_name}' reached '{success_status}' state.") - return response - if status == failure_status: - logger.error( - f"FAILURE: Resource '{resource_name}' entered failure state: {failure_status}-Response: {response}" - ) - raise Exception(f"Resource '{resource_name}' entered failure state: {failure_status}") - except ClientError as e: - if "ResourceNotFoundException" not in str(e): - raise e - logger.info(f"Resource '{resource_name}' not found yet, continuing to wait...") - time.sleep(delay_seconds) - raise TimeoutError(f"Timed out waiting for resource '{resource_name}' to become '{success_status}'.") +logger = logging.getLogger(__name__) @pytest.fixture(scope="module") def managed_knowledge_base(): + helper = KnowledgeBaseHelper() + kb_id = helper.try_get_existing() + if kb_id is not None: + yield kb_id + else: + kb_id = helper.create_resources() + yield kb_id + if helper.should_teardown: + helper.destroy() + + +@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) +def test_memory_integration_store_and_retrieve(managed_knowledge_base): """ - Pytest fixture to create and tear down a Bedrock Knowledge Base and its dependencies. - It will re-use existing resources if found to speed up local test runs. - Teardown can be skipped by setting the STRANDS_TEARDOWN_RESOURCES env var to "false". + End-to-end test for Bedrock Knowledge Base memory tool: + - Store a unique document + - Poll until it is INDEXED + - Retrieve via semantic search and verify presence """ - clients = _get_boto_clients() + kb_id = managed_knowledge_base + agent = Agent(tools=[memory]) + clients = KnowledgeBaseHelper._get_boto_clients() + + test_uuid = str(uuid.uuid4()) + unique_content = f"The secret password for the test is {test_uuid}." + store_result = agent.tool.memory( + action="store", + content=unique_content, + title="Integration Test Document", + STRANDS_KNOWLEDGE_BASE_ID=kb_id, + region_name=AWS_REGION, + ) + assert store_result["status"] == "success", f"Store failed: {store_result}" + + # Extract document ID + doc_id = next( + ( + item["text"].split(":", 1)[1].strip() + for item in store_result["content"] + if "Document ID" in item.get("text", "") + ), + None, + ) + assert doc_id, f"No document_id returned from store operation. Got: {store_result}" - resource_names = { - "role_name": "StrandsMemoryIntegTestRole", - "kb_name": "strands-memory-integ-test-kb", - "ds_name": "strands-memory-integ-test-ds", - "policy_name": "StrandsMemoryIntegTestPolicy", - "collection_name": "strands-memory-integ-coll", - "enc_policy_name": "strands-memory-enc-policy", - "net_policy_name": "strands-memory-net-policy", - "access_policy_name": "strands-memory-access-policy", - "vector_index_name": "bedrock-kb-index", - } - created_resources = {} + ds_id = clients["bedrock-agent"].list_data_sources(knowledgeBaseId=kb_id)["dataSourceSummaries"][0]["dataSourceId"] + for _ in range(18): + docs = clients["bedrock-agent"].list_knowledge_base_documents(knowledgeBaseId=kb_id, dataSourceId=ds_id)[ + "documentDetails" + ] + found = next((d for d in docs if d.get("identifier", {}).get("custom", {}).get("id") == doc_id), None) + if found and found.get("status") == "INDEXED": + break + time.sleep(10) + else: + raise AssertionError("Stored document did not become INDEXED in time.") - # Check if Knowledge Base already exists - try: - kbs = clients["bedrock-agent"].list_knowledge_bases() - existing_kb = next( - (kb for kb in kbs.get("knowledgeBaseSummaries", []) if kb["name"] == resource_names["kb_name"]), None + # Try up to 2 minutes for the content to appear in retrieval results + for _ in range(12): + retrieve_result = agent.tool.memory( + action="retrieve", + query=test_uuid, # use uuid as query is always order by semantic relevance + STRANDS_KNOWLEDGE_BASE_ID=kb_id, + region_name=AWS_REGION, + min_score=0.0, + max_results=10, ) - if existing_kb and existing_kb["status"] == "ACTIVE": - logger.info(f"Found existing and ACTIVE Knowledge Base '{resource_names['kb_name']}'. Re-using for test.") - yield existing_kb["knowledgeBaseId"] - return # Skip creation if we are reusing - except ClientError as e: - logger.error(f"Error checking for existing Knowledge Bases: {e}") - raise - try: - logger.info("No active Knowledge Base found. Creating or validating all resources from scratch...") + full_retrieved_text = " ".join(item.get("text", "") for item in retrieve_result.get("content", [])) + if unique_content in full_retrieved_text: + break + time.sleep(10) + else: + raise AssertionError("Stored content not found in retrieval after waiting.") + + +class KnowledgeBaseHelper: + def __init__(self): + self.clients = self._get_boto_clients() + self.index = { + "role_name": "StrandsMemoryIntegTestRole", + "kb_name": "strands-memory-integ-test-kb", + "ds_name": "strands-memory-integ-test-ds", + "policy_name": "StrandsMemoryIntegTestPolicy", + "collection_name": "strands-memory-integ-coll", + "enc_policy_name": "strands-memory-enc-policy", + "net_policy_name": "strands-memory-net-policy", + "access_policy_name": "strands-memory-access-policy", + "vector_index_name": "bedrock-kb-index", + } + self.resource_names = self.index + self.created_resources = {} + self.should_teardown = os.environ.get("STRANDS_TEARDOWN_RESOURCES", "true").lower() == "true" + + @staticmethod + def _get_boto_clients(): + return { + "iam": boto3.client("iam", region_name=AWS_REGION), + "bedrock-agent": boto3.client("bedrock-agent", region_name=AWS_REGION), + "opensearchserverless": boto3.client("opensearchserverless", region_name=AWS_REGION), + "sts": boto3.client("sts", region_name=AWS_REGION), + } - # STEP 1: Create OpenSearch Security Policies + def try_get_existing(self): try: - clients["opensearchserverless"].create_security_policy( - name=resource_names["enc_policy_name"], + kbs = self.clients["bedrock-agent"].list_knowledge_bases() + kb = next( + (kb for kb in kbs.get("knowledgeBaseSummaries", []) if kb["name"] == self.resource_names["kb_name"]), + None, + ) + if kb and kb["status"] == "ACTIVE": + return kb["knowledgeBaseId"] + except ClientError as e: + logger.error(f"Error checking for existing Knowledge Bases: {e}") + return None + + def create_resources(self): + resources = self.resource_names + client = self.clients + + # 1. OpenSearch Security Policies + try: + client["opensearchserverless"].create_security_policy( + name=resources["enc_policy_name"], type="encryption", policy=json.dumps( { "Rules": [ - { - "ResourceType": "collection", - "Resource": [f"collection/{resource_names['collection_name']}"], - } + {"ResourceType": "collection", "Resource": [f"collection/{resources['collection_name']}"]} ], "AWSOwnedKey": True, } @@ -145,11 +171,10 @@ def managed_knowledge_base(): except ClientError as e: if e.response["Error"]["Code"] != "ConflictException": raise - logger.info(f"Encryption policy '{resource_names['enc_policy_name']}' already exists.") try: - clients["opensearchserverless"].create_security_policy( - name=resource_names["net_policy_name"], + client["opensearchserverless"].create_security_policy( + name=resources["net_policy_name"], type="network", policy=json.dumps( [ @@ -157,7 +182,7 @@ def managed_knowledge_base(): "Rules": [ { "ResourceType": "collection", - "Resource": [f"collection/{resource_names['collection_name']}"], + "Resource": [f"collection/{resources['collection_name']}"], } ], "AllowFromPublic": True, @@ -168,32 +193,29 @@ def managed_knowledge_base(): except ClientError as e: if e.response["Error"]["Code"] != "ConflictException": raise - logger.info(f"Network policy '{resource_names['net_policy_name']}' already exists.") time.sleep(10) - # STEP 2: Create OpenSearch Serverless Collection + # 2. OpenSearch Collection try: - collection_res = clients["opensearchserverless"].create_collection( - name=resource_names["collection_name"], type="VECTORSEARCH" + collection_res = client["opensearchserverless"].create_collection( + name=resources["collection_name"], type="VECTORSEARCH" ) collection_id = collection_res["createCollectionDetail"]["id"] collection_arn = collection_res["createCollectionDetail"]["arn"] except ClientError as e: if e.response["Error"]["Code"] != "ConflictException": raise - logger.info(f"Collection '{resource_names['collection_name']}' already exists. Fetching details.") - collection_details = clients["opensearchserverless"].list_collections( - collectionFilters={"name": resource_names["collection_name"]} + collection_details = client["opensearchserverless"].list_collections( + collectionFilters={"name": resources["collection_name"]} )["collectionSummaries"][0] collection_id = collection_details["id"] collection_arn = collection_details["arn"] - created_resources["collection_id"] = collection_id + self.created_resources["collection_id"] = collection_id - # STEP 3: Create IAM Role and Policies + # 3. IAM Role and Policies try: - role_res = clients["iam"].get_role(RoleName=resource_names["role_name"]) - logger.info(f"IAM Role '{resource_names['role_name']}' already exists.") - created_resources["role_arn"] = role_res["Role"]["Arn"] + role_res = client["iam"].get_role(RoleName=resources["role_name"]) + self.created_resources["role_arn"] = role_res["Role"]["Arn"] except ClientError as e: if e.response["Error"]["Code"] != "NoSuchEntity": raise @@ -210,52 +232,50 @@ def managed_knowledge_base(): {"Effect": "Allow", "Principal": {"Service": "bedrock.amazonaws.com"}, "Action": "sts:AssumeRole"} ], } - role_res = clients["iam"].create_role( - RoleName=resource_names["role_name"], AssumeRolePolicyDocument=json.dumps(assume_role_policy) + role_res = client["iam"].create_role( + RoleName=resources["role_name"], AssumeRolePolicyDocument=json.dumps(assume_role_policy) ) - created_resources["role_arn"] = role_res["Role"]["Arn"] - clients["iam"].put_role_policy( - RoleName=resource_names["role_name"], - PolicyName=resource_names["policy_name"], + self.created_resources["role_arn"] = role_res["Role"]["Arn"] + client["iam"].put_role_policy( + RoleName=resources["role_name"], + PolicyName=resources["policy_name"], PolicyDocument=json.dumps(iam_policy_doc), ) time.sleep(15) - # STEP 4: Create OpenSearch Data Access Policy + # 4. OpenSearch Data Access Policy try: - user_arn = clients["sts"].get_caller_identity()["Arn"] + user_arn = client["sts"].get_caller_identity()["Arn"] access_policy_doc = [ { "Rules": [ { "ResourceType": "collection", - "Resource": [f"collection/{resource_names['collection_name']}"], + "Resource": [f"collection/{resources['collection_name']}"], "Permission": ["aoss:*"], }, { "ResourceType": "index", - "Resource": [f"index/{resource_names['collection_name']}/*"], + "Resource": [f"index/{resources['collection_name']}/*"], "Permission": ["aoss:*"], }, ], - "Principal": [created_resources["role_arn"], user_arn], + "Principal": [self.created_resources["role_arn"], user_arn], } ] - clients["opensearchserverless"].create_access_policy( - name=resource_names["access_policy_name"], type="data", policy=json.dumps(access_policy_doc) + client["opensearchserverless"].create_access_policy( + name=resources["access_policy_name"], type="data", policy=json.dumps(access_policy_doc) ) except ClientError as e: if e.response["Error"]["Code"] != "ConflictException": raise - logger.info(f"Access policy '{resource_names['access_policy_name']}' already exists.") - - # STEP 5: Wait for OpenSearch Collection to be Active - collection_details = _wait_for_resource( - lambda: clients["opensearchserverless"].batch_get_collection(ids=[collection_id])["collectionDetails"][0], + # 5. Wait for OpenSearch Collection to be Active + collection_details = self._wait_for_resource( + lambda: client["opensearchserverless"].batch_get_collection(ids=[collection_id])["collectionDetails"][0], resource_name=f"OpenSearch Collection ({collection_id})", ) - # STEP 6: Create the vector index + # 6. Create vector index collection_endpoint = collection_details["collectionEndpoint"] host = collection_endpoint.replace("https://", "") auth = AWSV4SignerAuth(boto3.Session().get_credentials(), AWS_REGION, "aoss") @@ -280,26 +300,23 @@ def managed_knowledge_base(): } }, } - if not os_client.indices.exists(resource_names["vector_index_name"]): - time.sleep(20) # Wait for access policy to propagate + if not os_client.indices.exists(resources["vector_index_name"]): + time.sleep(20) for i in range(5): try: - os_client.indices.create(resource_names["vector_index_name"], body=index_body) - logger.info("SUCCESS: Vector index created.") + os_client.indices.create(resources["vector_index_name"], body=index_body) break except (AuthenticationException, AuthorizationException) as e: if i < 4: - logger.warning(f"Auth error creating index (Attempt {i+1}). Waiting 30s...") time.sleep(30) else: - logger.error("Authorization error persisted after multiple retries.") raise e time.sleep(10) - # STEP 7: Create Knowledge Base - kb_res = clients["bedrock-agent"].create_knowledge_base( - name=resource_names["kb_name"], - roleArn=created_resources["role_arn"], + # 7. Knowledge Base + kb_res = client["bedrock-agent"].create_knowledge_base( + name=resources["kb_name"], + roleArn=self.created_resources["role_arn"], knowledgeBaseConfiguration={ "type": "VECTOR", "vectorKnowledgeBaseConfiguration": {"embeddingModelArn": EMBEDDING_MODEL_ARN}, @@ -308,7 +325,7 @@ def managed_knowledge_base(): "type": "OPENSEARCH_SERVERLESS", "opensearchServerlessConfiguration": { "collectionArn": collection_arn, - "vectorIndexName": resource_names["vector_index_name"], + "vectorIndexName": resources["vector_index_name"], "fieldMapping": { "vectorField": "bedrock-kb-vector", "textField": "AMAZON_BEDROCK_TEXT_CHUNK", @@ -317,145 +334,91 @@ def managed_knowledge_base(): }, }, ) - created_resources["kb_id"] = kb_res["knowledgeBase"]["knowledgeBaseId"] - _wait_for_resource( - lambda: clients["bedrock-agent"].get_knowledge_base(knowledgeBaseId=created_resources["kb_id"])[ + self.created_resources["kb_id"] = kb_res["knowledgeBase"]["knowledgeBaseId"] + self._wait_for_resource( + lambda: client["bedrock-agent"].get_knowledge_base(knowledgeBaseId=self.created_resources["kb_id"])[ "knowledgeBase" ], - resource_name=f"Knowledge Base ({created_resources['kb_id']})", + resource_name=f"Knowledge Base ({self.created_resources['kb_id']})", ) - - # STEP 8: Create Data Source - ds_res = clients["bedrock-agent"].create_data_source( - knowledgeBaseId=created_resources["kb_id"], - name=resource_names["ds_name"], + # 8. Data Source + ds_resource = client["bedrock-agent"].create_data_source( + knowledgeBaseId=self.created_resources["kb_id"], + name=resources["ds_name"], dataSourceConfiguration={"type": "CUSTOM"}, ) - created_resources["ds_id"] = ds_res["dataSource"]["dataSourceId"] - - # Do not need to wait for CUSTOM data source to become ACTIVE + self.created_resources["ds_id"] = ds_resource["dataSource"]["dataSourceId"] time.sleep(10) + return self.created_resources["kb_id"] - logger.info("All new resources are ready.") - yield created_resources["kb_id"] - - finally: - if os.environ.get("STRANDS_TEARDOWN_RESOURCES", "true").lower() == "true": - logger.info("Starting teardown of AWS resources...") - # Use try/except for each deletion to make teardown more resilient - if "ds_id" in created_resources and "kb_id" in created_resources: - try: - clients["bedrock-agent"].delete_data_source( - knowledgeBaseId=created_resources["kb_id"], dataSourceId=created_resources["ds_id"] - ) - except ClientError: - logger.warning("Could not delete data source.") - if "kb_id" in created_resources: - try: - clients["bedrock-agent"].delete_knowledge_base(knowledgeBaseId=created_resources["kb_id"]) - time.sleep(30) - except ClientError: - logger.warning("Could not delete knowledge base.") - if "collection_id" in created_resources: - try: - clients["opensearchserverless"].delete_collection(id=created_resources["collection_id"]) - except ClientError: - logger.warning("Could not delete OpenSearch collection.") - if "access_policy_name" in resource_names: - try: - clients["opensearchserverless"].delete_access_policy( - name=resource_names["access_policy_name"], type="data" - ) - except ClientError: - logger.warning("Could not delete access policy.") - if "net_policy_name" in resource_names: - try: - clients["opensearchserverless"].delete_security_policy( - name=resource_names["net_policy_name"], type="network" - ) - except ClientError: - logger.warning("Could not delete network policy.") - if "enc_policy_name" in resource_names: - try: - clients["opensearchserverless"].delete_security_policy( - name=resource_names["enc_policy_name"], type="encryption" - ) - except ClientError: - logger.warning("Could not delete encryption policy.") - if "role_name" in resource_names: - try: - clients["iam"].delete_role_policy( - RoleName=resource_names["role_name"], PolicyName=resource_names["policy_name"] + def destroy(self): + client = self.clients + resources = self.resource_names + cr = self.created_resources + try: + if "ds_id" in cr and "kb_id" in cr: + client["bedrock-agent"].delete_data_source(knowledgeBaseId=cr["kb_id"], dataSourceId=cr["ds_id"]) + except ClientError: + pass + try: + if "kb_id" in cr: + client["bedrock-agent"].delete_knowledge_base(knowledgeBaseId=cr["kb_id"]) + time.sleep(30) + except ClientError: + pass + try: + if "collection_id" in cr: + client["opensearchserverless"].delete_collection(id=cr["collection_id"]) + except ClientError: + pass + try: + client["opensearchserverless"].delete_access_policy(name=resources["access_policy_name"], type="data") + except ClientError: + pass + try: + client["opensearchserverless"].delete_security_policy(name=resources["net_policy_name"], type="network") + except ClientError: + pass + try: + client["opensearchserverless"].delete_security_policy(name=resources["enc_policy_name"], type="encryption") + except ClientError: + pass + try: + client["iam"].delete_role_policy(RoleName=resources["role_name"], PolicyName=resources["policy_name"]) + client["iam"].delete_role(RoleName=resources["role_name"]) + except ClientError: + pass + + def _wait_for_resource( + self, + poll_function, + resource_name, + success_status="ACTIVE", + failure_status="FAILED", + timeout_seconds=600, + delay_seconds=30, + ): + """Generic waiter for AWS asynchronous operations with detailed logging.""" + logger.info(f"Waiting for '{resource_name}' to become '{success_status}'...") + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + response = poll_function() + status = response.get("status") or response.get("Status") + logger.info( + f"Polling status for '{resource_name}': {status} (Elapsed: {int(time.time() - start_time)}s)" + ) + if status == success_status: + logger.info(f"SUCCESS: Resource '{resource_name}' reached '{success_status}' state.") + return response + if status == failure_status: + logger.error( + f"FAILURE: Resource '{resource_name}' failed with status: {failure_status}-Response:{response}" ) - clients["iam"].delete_role(RoleName=resource_names["role_name"]) - except ClientError: - logger.warning("Could not delete IAM role.") - logger.info("Teardown complete.") - else: - logger.info("Skipping teardown of AWS resources as per STRANDS_TEARDOWN_RESOURCES setting.") - - -@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) -def test_memory_integration_store_and_retrieve(managed_knowledge_base): - """ - End-to-end test for Bedrock Knowledge Base memory tool: - - Store a unique document - - Poll until it is INDEXED - - Retrieve via semantic search and verify presence - """ - kb_id = managed_knowledge_base - agent = Agent(tools=[memory]) - clients = _get_boto_clients() - - unique_content = f"The secret password for the test is {uuid.uuid4()}." - store_result = agent.tool.memory( - action="store", - content=unique_content, - title="Integration Test Document", - STRANDS_KNOWLEDGE_BASE_ID=kb_id, - region_name=AWS_REGION, - ) - assert store_result["status"] == "success", f"Store failed: {store_result}" - - # Extract document ID - doc_id = next( - ( - item["text"].split(":", 1)[1].strip() - for item in store_result["content"] - if "Document ID" in item.get("text", "") - ), - None, - ) - assert doc_id, f"No document_id returned from store operation. Got: {store_result}" - - # Wait up to 3 minutes for document to be INDEXED - ds_id = clients["bedrock-agent"].list_data_sources(knowledgeBaseId=kb_id)["dataSourceSummaries"][0]["dataSourceId"] - for _ in range(18): # 18 * 10s = 180s - docs = clients["bedrock-agent"].list_knowledge_base_documents(knowledgeBaseId=kb_id, dataSourceId=ds_id)[ - "documentDetails" - ] - found = next((d for d in docs if d.get("identifier", {}).get("custom", {}).get("id") == doc_id), None) - if found and found.get("status") == "INDEXED": - break - time.sleep(10) - else: - raise AssertionError("Stored document did not become INDEXED in time.") - - # Try up to 2 minutes for the content to appear in retrieval results - for _ in range(12): - retrieve_result = agent.tool.memory( - action="retrieve", - query="The secret password for the test is", - STRANDS_KNOWLEDGE_BASE_ID=kb_id, - region_name=AWS_REGION, - min_score=0.0, - max_results=5, - ) - assert retrieve_result["status"] == "success", f"Retrieve failed: {retrieve_result}" - - full_retrieved_text = " ".join(item.get("text", "") for item in retrieve_result.get("content", [])) - if unique_content in full_retrieved_text: - break - time.sleep(10) - else: - raise AssertionError("Stored content not found in retrieval after waiting.") + raise Exception(f"Resource '{resource_name}' entered failure state: {failure_status}") + except ClientError as e: + if "ResourceNotFoundException" not in str(e): + raise e + logger.info(f"Resource '{resource_name}' not found yet, continuing to wait...") + time.sleep(delay_seconds) + raise TimeoutError(f"Timed out waiting for resource '{resource_name}' to become '{success_status}'.") From d6ba0724df64b1543e40427c98dc1dd1bfde9401 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Thu, 26 Jun 2025 16:58:03 -0400 Subject: [PATCH 3/4] test(read_write_edit): add end to end tests for these tools --- tests-integ/test_read_write_edit.py | 100 ++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests-integ/test_read_write_edit.py diff --git a/tests-integ/test_read_write_edit.py b/tests-integ/test_read_write_edit.py new file mode 100644 index 00000000..91b9bb99 --- /dev/null +++ b/tests-integ/test_read_write_edit.py @@ -0,0 +1,100 @@ +import os +import re +from unittest.mock import patch + +import pytest +from strands import Agent +from strands_tools import editor, file_read, file_write + + +@pytest.fixture +def agent(): + """Agent with file read, write, and editor tools.""" + return Agent(tools=[file_write, file_read, editor]) + + +def extract_file_content(response): + """Helper function to extract code block content from LLM output.""" + match = re.search(r"```(?:[a-zA-Z]*\n)?(.*?)```", str(response), re.DOTALL) + return match.group(1) if match else str(response) + + +@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) +def test_semantic_write_read_edit_workflow(agent, tmp_path): + """Test complete semantic workflow: write -> read -> edit -> verify.""" + file_path = tmp_path / "semantic_test.txt" + initial_content = "Hello world from integration test!" + + # 1. Write file + write_response = agent(f"Write '{initial_content}' to file `{file_path}`") + assert "success" in str(write_response).lower() or "written" in str(write_response).lower() + + # 2. Read file back + read_response = agent(f"Read the contents of file `{file_path}`") + content = extract_file_content(read_response) + assert initial_content in content + + # 3. Replace text + edit_response = agent(f"In file `{file_path}`, replace 'Hello' with 'Hi'") + assert "success" in str(edit_response).lower() or "replaced" in str(edit_response).lower() + + # 4. Verify + verify_response = agent(f"Show me the contents of `{file_path}`") + final_content = extract_file_content(verify_response) + assert "Hi world" in final_content + assert "Hello" not in final_content + + +@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) +def test_semantic_python_file_creation(agent, tmp_path): + """Test creating and modifying Python code semantically.""" + file_path = tmp_path / "test_script.py" + + # 1. Create Python file + create_response = agent(f"Create a Python file at `{file_path}` with a function that prints 'Hello World'") + assert "success" in str(create_response).lower() or "created" in str(create_response).lower() + + # 2. Read and verify + read_response = agent(f"Show me the Python code in `{file_path}`") + content = str(read_response) + assert "def" in content and "print" in content and "Hello World" in content + + # 3. Modify the function + modify_response = agent(f"In `{file_path}`, change the print statement to say 'Hi there!' instead") + semantic_success = any( + phrase in str(modify_response).lower() + for phrase in [ + "file has been updated", + "now prints 'hi there!'", + "updated successfully", + "replacement was successful", + "print statement to say 'hi there!'", + ] + ) + assert semantic_success, str(modify_response) + + # 4. Verify modification + final_response = agent(f"Read `{file_path}` and show me the code") + final_content = str(final_response) + assert "Hi there!" in final_content + + +@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) +def test_semantic_search_and_replace(agent, tmp_path): + """Test semantic search and replace operations.""" + file_path = tmp_path / "config.txt" + + # 1. Create config file + agent(f"Create a config file at `{file_path}` with settings: debug=true, port=8080, host=localhost") + + # 2. Change a specific setting + agent(f"In `{file_path}`, change the port from 8080 to 3000") + + # 3. Verify the change + verify_response = agent(f"What is the port setting in `{file_path}`?") + assert "3000" in str(verify_response) + + # 4. Final check + final_response = agent(f"Show me all settings in `{file_path}`") + final_content = str(final_response) + assert "3000" in final_content From 115033fa6681d05b889a333913c09d4229b3b255 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Fri, 27 Jun 2025 13:50:11 -0400 Subject: [PATCH 4/4] test(read_write_edit): refactor minor code style --- tests-integ/test_read_write_edit.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tests-integ/test_read_write_edit.py b/tests-integ/test_read_write_edit.py index 91b9bb99..78ed4075 100644 --- a/tests-integ/test_read_write_edit.py +++ b/tests-integ/test_read_write_edit.py @@ -13,13 +13,18 @@ def agent(): return Agent(tools=[file_write, file_read, editor]) -def extract_file_content(response): +@pytest.fixture(autouse=True) +def bypass_tool_consent_env(): + with patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}): + yield + + +def extract_code_content(response): """Helper function to extract code block content from LLM output.""" match = re.search(r"```(?:[a-zA-Z]*\n)?(.*?)```", str(response), re.DOTALL) return match.group(1) if match else str(response) -@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) def test_semantic_write_read_edit_workflow(agent, tmp_path): """Test complete semantic workflow: write -> read -> edit -> verify.""" file_path = tmp_path / "semantic_test.txt" @@ -29,23 +34,26 @@ def test_semantic_write_read_edit_workflow(agent, tmp_path): write_response = agent(f"Write '{initial_content}' to file `{file_path}`") assert "success" in str(write_response).lower() or "written" in str(write_response).lower() - # 2. Read file back + # 2. Read file back using both agent & reading file directly read_response = agent(f"Read the contents of file `{file_path}`") - content = extract_file_content(read_response) + content = extract_code_content(read_response) assert initial_content in content + with open(file_path, "r") as f: + raw_content = f.read() + assert initial_content in raw_content + # 3. Replace text edit_response = agent(f"In file `{file_path}`, replace 'Hello' with 'Hi'") assert "success" in str(edit_response).lower() or "replaced" in str(edit_response).lower() # 4. Verify verify_response = agent(f"Show me the contents of `{file_path}`") - final_content = extract_file_content(verify_response) + final_content = extract_code_content(verify_response) assert "Hi world" in final_content assert "Hello" not in final_content -@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) def test_semantic_python_file_creation(agent, tmp_path): """Test creating and modifying Python code semantically.""" file_path = tmp_path / "test_script.py" @@ -69,6 +77,7 @@ def test_semantic_python_file_creation(agent, tmp_path): "updated successfully", "replacement was successful", "print statement to say 'hi there!'", + "prints 'Hello World'", ] ) assert semantic_success, str(modify_response) @@ -79,7 +88,6 @@ def test_semantic_python_file_creation(agent, tmp_path): assert "Hi there!" in final_content -@patch.dict(os.environ, {"BYPASS_TOOL_CONSENT": "true"}) def test_semantic_search_and_replace(agent, tmp_path): """Test semantic search and replace operations.""" file_path = tmp_path / "config.txt"