diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b19a2c8..b9b97090 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Features - adds version-based index mapping update support to the Search Relevance plugin [#344](https://github.com/opensearch-project/search-relevance/pull/344) * LLM Judgement Customized Prompt Template Implementation [#264](https://github.com/opensearch-project/search-relevance/pull/264) +* Support multiple LLM providers with proper rate limit [#285](https://github.com/opensearch-project/search-relevance/pull/285) ### Enhancements diff --git a/docs/llm-model/claude/README.md b/docs/llm-model/claude/README.md new file mode 100644 index 00000000..7af71a01 --- /dev/null +++ b/docs/llm-model/claude/README.md @@ -0,0 +1,43 @@ +# Claude 3.5 Haiku OpenSearch ML Connector + +This directory contains the setup and validation scripts for integrating Claude 3.5 Haiku with OpenSearch ML for search relevance rating. 
+ +## Prerequisites + +- OpenSearch running on `http://localhost:9200` +- AWS credentials configured (`aws configure`) +- `jq` installed for JSON parsing + +## Quick Start + +```bash +chmod +x connector_validate.sh +./connector_validate.sh +``` + +## Configuration + +### Model Details +- **Model**: `us.anthropic.claude-3-5-haiku-20241022-v1:0` (inference profile) +- **Region**: `us-east-1` +- **Max Tokens**: 2000 +- **Protocol**: AWS SigV4 + +### Message Format +```json +{ + "parameters": { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Your prompt here" + } + ] + } + ] + } +} +``` \ No newline at end of file diff --git a/docs/llm-model/claude/connector_validate.sh b/docs/llm-model/claude/connector_validate.sh new file mode 100755 index 00000000..07f78ccc --- /dev/null +++ b/docs/llm-model/claude/connector_validate.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +# Claude 3.5 Haiku Connector Validation Script +set -e + +OPENSEARCH_URL="http://localhost:9200" +CONNECTOR_NAME="claude-3-5-haiku-working" +MODEL_NAME="claude-3-5-haiku-working-model" + +# Get AWS credentials +export AWS_ACCESS_KEY_ID=$(aws configure get aws_access_key_id) +export AWS_SECRET_ACCESS_KEY=$(aws configure get aws_secret_access_key) +export AWS_SESSION_TOKEN=$(aws configure get aws_session_token) + +echo "Creating Claude 3.5 Haiku connector..." 
+ +# Create connector +CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/_create" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${CONNECTOR_NAME}'", + "description": "Claude 3.5 Haiku connector for search relevance rating", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "'$AWS_ACCESS_KEY_ID'", + "secret_key": "'$AWS_SECRET_ACCESS_KEY'", + "session_token": "'$AWS_SESSION_TOKEN'" + }, + "parameters": { + "region": "us-east-1", + "service_name": "bedrock", + "model": "us.anthropic.claude-3-5-haiku-20241022-v1:0" + }, + "client_config": { + "max_connection": 2, + "connection_timeout": 60000, + "read_timeout": 60000, + "retry_backoff_millis": 3000, + "retry_timeout_seconds": 60, + "max_retry_times": 2 + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{\"anthropic_version\": \"bedrock-2023-05-31\", \"max_tokens\": 2000, \"messages\": ${parameters.messages}}" + } + ] +}') + +CONNECTOR_ID=$(echo $CONNECTOR_RESPONSE | jq -r '.connector_id') +echo "Connector created with ID: $CONNECTOR_ID" + +# Register model +echo "Registering model..." +MODEL_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/_register" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${MODEL_NAME}'", + "function_name": "remote", + "description": "Claude 3.5 Haiku model for search relevance rating", + "connector_id": "'$CONNECTOR_ID'" +}') + +MODEL_ID=$(echo $MODEL_RESPONSE | jq -r '.model_id') +echo "Model registered with ID: $MODEL_ID" + +# Deploy model +echo "Deploying model..." +curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_deploy" > /dev/null +echo "Model deployed successfully" + +# Test prediction +echo "Testing prediction..." 
+TEST_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Rate search relevance 0.0-1.0. Return JSON only: [{\"id\":\"001\",\"rating_score\":0.9}]. Rate ALL hits.\n\nSearchText - banana; Reference - banana smoothie; Hits - [{\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana\", \"price\": 1.99, \"description\": \"this is a banana\"}, \"_id\": \"003\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"apple\", \"price\": 0.99, \"description\": \"fresh apple\"}, \"_id\": \"004\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana smoothie\", \"price\": 3.99, \"description\": \"fresh banana smoothie\"}, \"_id\": \"005\"}]" + } + ] + } + ] + } +}') + +echo "Test completed successfully!" +echo "Response: $(echo $TEST_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.content[0].text')" +echo "" +echo "Connector ID: $CONNECTOR_ID" +echo "Model ID: $MODEL_ID" \ No newline at end of file diff --git a/docs/llm-model/cohere/README.md b/docs/llm-model/cohere/README.md new file mode 100644 index 00000000..65007e7f --- /dev/null +++ b/docs/llm-model/cohere/README.md @@ -0,0 +1,68 @@ +# Cohere Command R OpenSearch ML Connector + +This directory contains the setup and validation scripts for integrating Cohere Command R with OpenSearch ML via AWS Bedrock. 
+ +## Prerequisites + +- OpenSearch running on `http://localhost:9200` +- AWS credentials configured (`aws configure`) +- `jq` installed for JSON parsing + +## Quick Start + +```bash +chmod +x connector_validate.sh +./connector_validate.sh +``` + +## Configuration + +### Model Details +- **Model**: `cohere.command-r-v1:0` (via Bedrock) +- **Region**: `us-east-1` +- **Max Tokens**: 1000 +- **Protocol**: AWS SigV4 +- **Streaming**: Supported + +### Message Format +```json +{ + "parameters": { + "message": "Your message here" + } +} +``` + +## Available Cohere Models + +### Chat Models (via Bedrock) +- `cohere.command-r-v1:0` - Command R +- `cohere.command-r-plus-v1:0` - Command R+ (more capable) + +### Embedding Models +- `cohere.embed-v4:0` - Embed v4 (multimodal) +- `cohere.embed-english-v3` - English embeddings +- `cohere.embed-multilingual-v3` - Multilingual embeddings + +### Rerank Model +- `cohere.rerank-v3-5:0` - Rerank 3.5 + +## Response Format + +```json +{ + "inference_results": [{ + "output": [{ + "name": "response", + "dataAsMap": { + "response_id": "...", + "text": "Model response here", + "generation_id": "...", + "chat_history": [...], + "finish_reason": "COMPLETE" + } + }], + "status_code": 200 + }] +} +``` \ No newline at end of file diff --git a/docs/llm-model/cohere/connector_validate.sh b/docs/llm-model/cohere/connector_validate.sh new file mode 100755 index 00000000..3624a66c --- /dev/null +++ b/docs/llm-model/cohere/connector_validate.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +# Cohere Command R Connector Validation Script +set -e + +OPENSEARCH_URL="http://localhost:9200" +CONNECTOR_NAME="cohere-command-r-bedrock" +MODEL_NAME="cohere-command-r-bedrock-model" + +# Get AWS credentials +export AWS_ACCESS_KEY_ID=$(aws configure get aws_access_key_id) +export AWS_SECRET_ACCESS_KEY=$(aws configure get aws_secret_access_key) +export AWS_SESSION_TOKEN=$(aws configure get aws_session_token) + +echo "Creating Cohere Command R connector via Bedrock..." 
+ +# Create connector +CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/_create" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${CONNECTOR_NAME}'", + "description": "Cohere Command R via Bedrock for chat", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "'$AWS_ACCESS_KEY_ID'", + "secret_key": "'$AWS_SECRET_ACCESS_KEY'", + "session_token": "'$AWS_SESSION_TOKEN'" + }, + "parameters": { + "region": "us-east-1", + "service_name": "bedrock", + "model": "cohere.command-r-v1:0" + }, + "client_config": { + "max_connection": 2, + "connection_timeout": 60000, + "read_timeout": 60000, + "retry_backoff_millis": 3000, + "retry_timeout_seconds": 60, + "max_retry_times": 2 + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{\"message\": \"${parameters.message}\", \"max_tokens\": 1000}" + } + ] +}') + +CONNECTOR_ID=$(echo $CONNECTOR_RESPONSE | jq -r '.connector_id') +echo "Connector created with ID: $CONNECTOR_ID" + +# Register model +echo "Registering model..." +MODEL_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/_register" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${MODEL_NAME}'", + "function_name": "remote", + "description": "Cohere Command R model via Bedrock", + "connector_id": "'$CONNECTOR_ID'" +}') + +MODEL_ID=$(echo $MODEL_RESPONSE | jq -r '.model_id') +echo "Model registered with ID: $MODEL_ID" + +# Deploy model +echo "Deploying model..." +curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_deploy" > /dev/null +echo "Model deployed successfully" + +# Test basic chat +echo "Testing basic chat..." 
+CHAT_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "message": "Hello, respond with just the word success" + } +}') + +echo "Basic chat test completed!" +echo "Response: $(echo $CHAT_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.text')" + +# Test search relevance rating +echo "Testing search relevance rating..." +RATING_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "message": "Rate search relevance 0.0-1.0. Return JSON only: [{\"id\":\"001\",\"rating_score\":0.9}]. Rate ALL hits.\n\nSearchText - banana; Reference - banana smoothie; Hits - [{\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana\", \"price\": 1.99, \"description\": \"this is a banana\"}, \"_id\": \"003\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"apple\", \"price\": 0.99, \"description\": \"fresh apple\"}, \"_id\": \"004\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana smoothie\", \"price\": 3.99, \"description\": \"fresh banana smoothie\"}, \"_id\": \"005\"}]" + } +}') + +echo "Search relevance rating test completed!" +echo "Response: $(echo $RATING_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.text')" +echo "" +echo "Connector ID: $CONNECTOR_ID" +echo "Model ID: $MODEL_ID" \ No newline at end of file diff --git a/docs/llm-model/deepseek/README.md b/docs/llm-model/deepseek/README.md new file mode 100644 index 00000000..418055f3 --- /dev/null +++ b/docs/llm-model/deepseek/README.md @@ -0,0 +1,76 @@ +# DeepSeek Chat OpenSearch ML Connector + +This directory contains the setup and validation scripts for integrating DeepSeek Chat with OpenSearch ML. 
+ +## Prerequisites + +- OpenSearch running on `http://localhost:9200` +- DeepSeek API key (set `DEEPSEEK_API_KEY` in `connector_validate.sh` before running) +- `jq` installed for JSON parsing + +## Quick Start + +```bash +chmod +x connector_validate.sh +./connector_validate.sh +``` + +## Configuration + +### Model Details +- **Model**: `deepseek-chat` +- **Endpoint**: `api.deepseek.com` +- **Protocol**: HTTP +- **API Version**: v1 + +### Message Format +```json +{ + "parameters": { + "messages": [ + { + "role": "user", + "content": "Your message here" + } + ] + } +} +``` + +## Available Models + +### Chat Models +- `deepseek-chat` - General purpose chat model +- `deepseek-coder` - Code-focused model (if available) + +## Response Format + +```json +{ + "inference_results": [{ + "output": [{ + "name": "response", + "dataAsMap": { + "id": "...", + "object": "chat.completion", + "created": 1761688945, + "model": "deepseek-chat", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Model response here" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 12, + "completion_tokens": 1, + "total_tokens": 13 + } + } + }], + "status_code": 200 + }] +} +``` \ No newline at end of file diff --git a/docs/llm-model/deepseek/connector_validate.sh b/docs/llm-model/deepseek/connector_validate.sh new file mode 100644 index 00000000..fbbfd390 --- /dev/null +++ b/docs/llm-model/deepseek/connector_validate.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +# DeepSeek Chat Connector Validation Script +set -e + +OPENSEARCH_URL="http://localhost:9200" +CONNECTOR_NAME="DeepSeek Chat" +MODEL_NAME="deepseek-chat-model" +DEEPSEEK_API_KEY="" + +echo "Creating DeepSeek Chat connector..." 
+ +# Create connector +CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/_create" \ +-H "Content-Type: application/json" \ +-d "{ + \"name\": \"${CONNECTOR_NAME}\", + \"description\": \"DeepSeek Chat connector for conversational AI\", + \"version\": \"1\", + \"protocol\": \"http\", + \"parameters\": { + \"endpoint\": \"api.deepseek.com\", + \"model\": \"deepseek-chat\" + }, + \"credential\": { + \"deepSeek_key\": \"${DEEPSEEK_API_KEY}\" + }, + \"actions\": [ + { + \"action_type\": \"predict\", + \"method\": \"POST\", + \"url\": \"https://\${parameters.endpoint}/v1/chat/completions\", + \"headers\": { + \"Content-Type\": \"application/json\", + \"Authorization\": \"Bearer \${credential.deepSeek_key}\" + }, + \"request_body\": \"{ \\\"model\\\": \\\"\${parameters.model}\\\", \\\"messages\\\": \${parameters.messages} }\" + } + ] +}") + +CONNECTOR_ID=$(echo $CONNECTOR_RESPONSE | jq -r '.connector_id') +echo "Connector created with ID: $CONNECTOR_ID" + +# Register model +echo "Registering model..." +MODEL_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/_register" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${MODEL_NAME}'", + "function_name": "remote", + "description": "DeepSeek chat model for conversational AI", + "connector_id": "'$CONNECTOR_ID'" +}') + +MODEL_ID=$(echo $MODEL_RESPONSE | jq -r '.model_id') +echo "Model registered with ID: $MODEL_ID" + +# Deploy model +echo "Deploying model..." +curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_deploy" > /dev/null +echo "Model deployed successfully" + +# Test basic chat +echo "Testing basic chat..." +CHAT_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "user", + "content": "Hello, respond with just the word success" + } + ] + } +}') + +echo "Basic chat test completed!" 
+echo "Response: $(echo $CHAT_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.choices[0].message.content')" + +# Test search relevance rating +echo "Testing search relevance rating..." +RATING_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "user", + "content": "Rate search relevance 0.0-1.0. Return JSON only: [{\"id\":\"001\",\"rating_score\":0.9}]. Rate ALL hits.\n\nSearchText - banana; Reference - banana smoothie; Hits - [{\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana\", \"price\": 1.99, \"description\": \"this is a banana\"}, \"_id\": \"003\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"apple\", \"price\": 0.99, \"description\": \"fresh apple\"}, \"_id\": \"004\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana smoothie\", \"price\": 3.99, \"description\": \"fresh banana smoothie\"}, \"_id\": \"005\"}]" + } + ] + } +}') + +echo "Search relevance rating test completed!" +echo "Response: $(echo $RATING_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.choices[0].message.content')" +echo "" +echo "Connector ID: $CONNECTOR_ID" +echo "Model ID: $MODEL_ID" \ No newline at end of file diff --git a/docs/llm-model/openai/README.md b/docs/llm-model/openai/README.md new file mode 100644 index 00000000..23ee0edf --- /dev/null +++ b/docs/llm-model/openai/README.md @@ -0,0 +1,82 @@ +# OpenAI GPT OpenSearch ML Connector + +This directory contains the setup and validation scripts for integrating OpenAI GPT models with OpenSearch ML. 
+ +## Prerequisites + +- OpenSearch running on `http://localhost:9200` +- Valid OpenAI API key +- `jq` installed for JSON parsing + +## Quick Start + +```bash +# Update API key in connector_validate.sh +chmod +x connector_validate.sh +./connector_validate.sh +``` + +## Configuration + +### Model Details +- **Model**: `gpt-5-nano` (configurable) +- **Endpoint**: `api.openai.com` +- **Protocol**: HTTP +- **API Path**: `/v1/chat/completions` + +### Available Models +- `gpt-4o` - Latest GPT-4 Omni model +- `gpt-4o-mini` - Smaller, faster GPT-4 variant +- `gpt-4-turbo` - GPT-4 Turbo +- `gpt-3.5-turbo` - GPT-3.5 Turbo +- `gpt-5-nano` - GPT-5 Nano (if available) + +### Message Format +```json +{ + "parameters": { + "messages": [ + { + "role": "system", + "content": "System prompt here" + }, + { + "role": "user", + "content": "User message here" + } + ] + } +} +``` + +## Response Format + +```json +{ + "inference_results": [{ + "output": [{ + "name": "response", + "dataAsMap": { + "id": "chatcmpl-...", + "object": "chat.completion", + "created": 1761689000, + "model": "gpt-5-nano", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Model response here" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150 + } + } + }], + "status_code": 200 + }] +} +``` diff --git a/docs/llm-model/openai/connector_validate.sh b/docs/llm-model/openai/connector_validate.sh new file mode 100755 index 00000000..d40605fd --- /dev/null +++ b/docs/llm-model/openai/connector_validate.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +# OpenAI GPT Connector Validation Script +set -e + +OPENSEARCH_URL="http://localhost:9200" +CONNECTOR_NAME="mfenqin-batch-test" +MODEL_NAME="openai-gpt-model" +OPENAI_API_KEY="" + +echo "Creating OpenAI GPT connector..." 
+ +# Create connector (action_type must be predict: the tests below call the _predict API) +CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/_create" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${CONNECTOR_NAME}'", + "description": "OpenAI GPT connector for search relevance rating", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-5-nano" + }, + "credential": { + "openAI_key": "'${OPENAI_API_KEY}'" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions", + "headers": { + "Authorization": "Bearer ${credential.openAI_key}" + }, + "request_body": "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }" + } + ] +}') + +CONNECTOR_ID=$(echo "$CONNECTOR_RESPONSE" | jq -r '.connector_id') +echo "Connector created with ID: $CONNECTOR_ID" + +# Register model +echo "Registering model..." +MODEL_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/_register" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${MODEL_NAME}'", + "function_name": "remote", + "description": "OpenAI GPT model for search relevance rating", + "connector_id": "'$CONNECTOR_ID'" +}') + +MODEL_ID=$(echo "$MODEL_RESPONSE" | jq -r '.model_id') +echo "Model registered with ID: $MODEL_ID" + +# Deploy model +echo "Deploying model..." +curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_deploy" > /dev/null +echo "Model deployed successfully" + +# Test basic chat +echo "Testing basic chat..." +CHAT_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "user", + "content": "Hello, respond with just the word success" + } + ] + } +}') + +echo "Basic chat test completed!" 
+echo "Response: $(echo $CHAT_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.choices[0].message.content')" + +# Test search relevance rating +echo "Testing search relevance rating..." +RATING_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "system", + "content": "You are an expert search relevance rater. Your task is to evaluate the relevance between search query and results with these criteria:\n- Score 1.0: Perfect match, highly relevant\n- Score 0.7-0.9: Very relevant with minor variations\n- Score 0.4-0.6: Moderately relevant\n- Score 0.1-0.3: Slightly relevant\n- Score 0.0: Completely irrelevant\nEvaluate based on: exact matches, semantic relevance, and overall context between the SearchText and content in Hits.\nWhen a reference is provided, evaluate based on the relevance to both SearchText and its reference.\n\nIMPORTANT: Provide your response ONLY as a JSON array of objects, each with \"id\" and \"rating_score\" fields. You MUST include a rating for EVERY hit provided, even if the rating is 0. Do not include any explanation or additional text. Example format: [{\"id\": \"001\", \"rating_score\": 0.9}, {\"id\": \"002\", \"rating_score\": 0.5}, {\"id\": \"003\", \"rating_score\": 0.0}]" + }, + { + "role": "user", + "content": "SearchText - banana; Reference - banana smoothie; Hits - [{\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana\", \"price\": 1.99, \"description\": \"this is a banana\"}, \"_id\": \"003\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"apple\", \"price\": 0.99, \"description\": \"fresh apple\"}, \"_id\": \"004\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana smoothie\", \"price\": 3.99, \"description\": \"fresh banana smoothie\"}, \"_id\": \"005\"}]" + } + ] + } +}') + +echo "Search relevance rating test completed!" 
+echo "Response: $(echo $RATING_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.choices[0].message.content')" +echo "" +echo "Connector ID: $CONNECTOR_ID" +echo "Model ID: $MODEL_ID" diff --git a/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java b/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java index 9a21b9ef..a12afbc8 100644 --- a/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java +++ b/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java @@ -23,10 +23,18 @@ private MLConstants() {} * ML input field names */ public static final String PARAM_MESSAGES_FIELD = "messages"; + public static final String CONNECTOR_TYPE = "connectorType"; + public static final String RATE_LIMIT = "rateLimit"; public static final String PROMPT_TEMPLATE = "promptTemplate"; public static final String LLM_JUDGMENT_RATING_TYPE = "llmJudgmentRatingType"; public static final String OVERWRITE_CACHE = "overwriteCache"; + /** + * Default prompt template for LLM judgments + */ + public static final String DEFAULT_PROMPT_TEMPLATE = + "Rate the relevance of the search results to the query. SearchText: {{searchText}}; Results: {{hits}}"; + /** * Prompt template placeholder names. * These are the special variables that can be used in custom prompt templates. 
@@ -38,11 +46,6 @@ private MLConstants() {} public static final String PLACEHOLDER_REFERENCE = "reference"; public static final String PLACEHOLDER_REFERENCE_ANSWER = "referenceAnswer"; - /** - * Default prompt template for LLM judgments (simple format without reference data) - */ - public static final String DEFAULT_PROMPT_TEMPLATE = "SearchText: {{searchText}}; Hits: {{hits}}"; - /** * ML response field names */ @@ -199,4 +202,28 @@ public static int validateTokenLimit(Map source) { } } + /** + * Parses rateLimit value from an object, ensuring it's non-negative + * + * @param rateLimitObj The object to parse (can be Number, String, or null) + * @return Parsed rate limit value, or 0 if null/invalid + */ + public static long parseRateLimit(Object rateLimitObj) { + if (rateLimitObj == null) { + return 0L; + } + + try { + long rateLimit; + if (rateLimitObj instanceof Number) { + rateLimit = ((Number) rateLimitObj).longValue(); + } else { + rateLimit = Long.parseLong(rateLimitObj.toString()); + } + return Math.max(0, rateLimit); // ensure non-negative + } catch (NumberFormatException e) { + return 0L; // default to 0 on parse error + } + } + } diff --git a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentContext.java b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentContext.java new file mode 100644 index 00000000..6afcd0fb --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentContext.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.judgments; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.searchrelevance.model.SearchConfiguration; + +import lombok.Builder; +import lombok.Getter; + +/** + * Context object to hold LLM judgment parameters + */ +@Getter +@Builder +public class LlmJudgmentContext implements ToXContentObject { + public static final String MODEL_ID = "modelId"; + public static final String SIZE = "size"; + public static final String TOKEN_LIMIT = "tokenLimit"; + public static final String CONTEXT_FIELDS = "contextFields"; + public static final String SEARCH_CONFIGURATIONS = "searchConfigurations"; + public static final String IGNORE_FAILURE = "ignoreFailure"; + public static final String PROMPT_TEMPLATE = "promptTemplate"; + public static final String RATING_TYPE = "ratingType"; + public static final String OVERWRITE_CACHE = "overwriteCache"; + public static final String CONNECTOR_TYPE = "connectorType"; + public static final String RATE_LIMIT = "rateLimit"; + + private final String modelId; + private final int size; + private final int tokenLimit; + private final List contextFields; + private final List searchConfigurations; + private final boolean ignoreFailure; + private final String promptTemplate; + private final LLMJudgmentRatingType ratingType; + private final boolean overwriteCache; + private final ConnectorType connectorType; + private final long rateLimit; // milliseconds between requests + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID, modelId); + builder.field(SIZE, size); + builder.field(TOKEN_LIMIT, tokenLimit); + if (contextFields != null) { 
+ builder.field(CONTEXT_FIELDS, contextFields); + } + if (searchConfigurations != null) { + builder.field(SEARCH_CONFIGURATIONS, searchConfigurations); + } + builder.field(IGNORE_FAILURE, ignoreFailure); + if (promptTemplate != null) { + builder.field(PROMPT_TEMPLATE, promptTemplate); + } + // Always include ratingType, use default if null + String ratingTypeValue = (ratingType != null) ? ratingType.name() : LLMJudgmentRatingType.SCORE0_1.name(); + builder.field(RATING_TYPE, ratingTypeValue); + builder.field(OVERWRITE_CACHE, overwriteCache); + if (connectorType != null) { + builder.field(CONNECTOR_TYPE, connectorType.name()); + } + builder.field(RATE_LIMIT, rateLimit); + return builder.endObject(); + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java index be99c792..ae30dc0f 100644 --- a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java +++ b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java @@ -7,9 +7,13 @@ */ package org.opensearch.searchrelevance.judgments; +import static org.opensearch.searchrelevance.common.MLConstants.CONNECTOR_TYPE; +import static org.opensearch.searchrelevance.common.MLConstants.DEFAULT_PROMPT_TEMPLATE; import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.common.MLConstants.RATE_LIMIT; +import static org.opensearch.searchrelevance.common.MLConstants.parseRateLimit; import static org.opensearch.searchrelevance.model.builder.SearchRequestBuilder.buildSearchRequest; import static org.opensearch.searchrelevance.utils.ParserUtils.combinedIndexAndDocId; import static 
org.opensearch.searchrelevance.utils.ParserUtils.generatePromptTemplateCode; @@ -22,6 +26,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -44,6 +49,7 @@ import org.opensearch.searchrelevance.executors.LlmJudgmentTaskManager; import org.opensearch.searchrelevance.ml.ChunkResult; import org.opensearch.searchrelevance.ml.MLAccessor; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.JudgmentCache; import org.opensearch.searchrelevance.model.JudgmentType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; @@ -105,58 +111,76 @@ public void generateJudgmentRating(Map metadata, ActionListener< private void generateJudgmentRatingInternal(Map metadata, ActionListener>> listener) { try { EventStatsManager.increment(EventStatName.LLM_JUDGMENT_RATING_GENERATIONS); + String querySetId = (String) metadata.get("querySetId"); List searchConfigurationList = (List) metadata.get("searchConfigurationList"); - int size = (int) metadata.get("size"); - - String modelId = (String) metadata.get("modelId"); - int tokenLimit = (int) metadata.get("tokenLimit"); - List contextFields = (List) metadata.get("contextFields"); - boolean ignoreFailure = (boolean) metadata.get("ignoreFailure"); - String promptTemplate = (String) metadata.get(PROMPT_TEMPLATE); - LLMJudgmentRatingType ratingType = (LLMJudgmentRatingType) metadata.get(LLM_JUDGMENT_RATING_TYPE); - // Default to SCORE0_1 if ratingType is not provided - if (ratingType == null) { - ratingType = LLMJudgmentRatingType.SCORE0_1; - log.debug("No ratingType provided, defaulting to SCORE0_1"); - } - boolean overwriteCache = (boolean) metadata.get(OVERWRITE_CACHE); QuerySet querySet = querySetDao.getQuerySetSync(querySetId); List searchConfigurations = searchConfigurationList.stream() .map(id -> 
searchConfigurationDao.getSearchConfigurationSync(id)) .collect(Collectors.toList()); - generateLLMJudgmentsAsync( - modelId, - size, - tokenLimit, - contextFields, - querySet, - searchConfigurations, - ignoreFailure, - promptTemplate, - ratingType, - overwriteCache, - listener - ); + // Build context from metadata + LlmJudgmentContext context = buildContextFromMetadata(metadata, searchConfigurations); + + generateLLMJudgmentsAsync(context, querySet, listener); } catch (Exception e) { log.error("Failed to generate LLM judgments", e); listener.onFailure(new SearchRelevanceException("Failed to generate LLM judgments", e, RestStatus.INTERNAL_SERVER_ERROR)); } } + private LlmJudgmentContext buildContextFromMetadata(Map metadata, List searchConfigurations) { + String modelId = (String) metadata.get("modelId"); + Integer sizeObj = (Integer) metadata.get("size"); + Integer tokenLimitObj = (Integer) metadata.get("tokenLimit"); + List contextFields = (List) metadata.get("contextFields"); + Boolean ignoreFailureObj = (Boolean) metadata.get("ignoreFailure"); + String promptTemplate = (String) metadata.get(PROMPT_TEMPLATE); + LLMJudgmentRatingType ratingType = (LLMJudgmentRatingType) metadata.get(LLM_JUDGMENT_RATING_TYPE); + Boolean overwriteCacheObj = (Boolean) metadata.get(OVERWRITE_CACHE); + String connectorTypeStr = (String) metadata.get(CONNECTOR_TYPE); + + // Apply defaults for null values + int size = sizeObj != null ? sizeObj : 5; + int tokenLimit = tokenLimitObj != null ? tokenLimitObj : 1000; + boolean ignoreFailure = ignoreFailureObj != null ? ignoreFailureObj : false; + boolean overwriteCache = overwriteCacheObj != null ? 
overwriteCacheObj : false; + + if (ratingType == null) { + ratingType = LLMJudgmentRatingType.SCORE0_1; + log.debug("No ratingType provided, defaulting to SCORE0_1"); + } + + ConnectorType connectorType = ConnectorType.OPENAI; + if (connectorTypeStr != null) { + try { + connectorType = ConnectorType.valueOf(connectorTypeStr.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + log.warn("Invalid connectorType '{}' in metadata, defaulting to OpenAI", connectorTypeStr); + } + } + + long rateLimit = parseRateLimit(metadata.get(RATE_LIMIT)); + + return LlmJudgmentContext.builder() + .modelId(modelId) + .size(size) + .tokenLimit(tokenLimit) + .contextFields(contextFields != null ? contextFields : new ArrayList<>()) + .searchConfigurations(searchConfigurations) + .ignoreFailure(ignoreFailure) + .promptTemplate(promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE) + .ratingType(ratingType) + .overwriteCache(overwriteCache) + .connectorType(connectorType) + .rateLimit(rateLimit) + .build(); + } + private void generateLLMJudgmentsAsync( - String modelId, - int size, - int tokenLimit, - List contextFields, + LlmJudgmentContext context, QuerySet querySet, - List searchConfigurations, - boolean ignoreFailure, - String promptTemplate, - LLMJudgmentRatingType ratingType, - boolean overwriteCache, ActionListener>> listener ) { List queryTextsWithCustomInput = querySet.querySetQueries().stream().map(e -> e.queryText()).collect(Collectors.toList()); @@ -172,20 +196,9 @@ private void generateLLMJudgmentsAsync( taskManager.scheduleTasksAsync(queryTextsWithCustomInput, queryTextWithCustomInput -> { try { - return processQueryTextAsync( - modelId, - size, - tokenLimit, - contextFields, - searchConfigurations, - queryTextWithCustomInput, - ignoreFailure, - promptTemplate, - ratingType, - overwriteCache - ); + return processQueryTextAsync(context, queryTextWithCustomInput); } catch (Exception e) { - if (ignoreFailure) { + if (context.isIgnoreFailure()) { 
log.warn("Query processing failed, returning empty result for: {}", queryTextWithCustomInput, e); return JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, Map.of()); } else { @@ -193,7 +206,7 @@ private void generateLLMJudgmentsAsync( throw new RuntimeException("Query processing failed: " + queryTextWithCustomInput, e); } } - }, ignoreFailure, ActionListener.wrap(results -> { + }, context.isIgnoreFailure(), ActionListener.wrap((List> results) -> { int processedQueries = results.size(); int successQueries = (int) results.stream().mapToLong(result -> { List> ratings = (List>) result.get("ratings"); @@ -219,20 +232,9 @@ private void generateLLMJudgmentsAsync( taskManager.scheduleTasksAsync(queryTextsWithCustomInput, queryTextWithCustomInput -> { try { - return processQueryTextAsync( - modelId, - size, - tokenLimit, - contextFields, - searchConfigurations, - queryTextWithCustomInput, - ignoreFailure, - promptTemplate, - ratingType, - overwriteCache - ); + return processQueryTextAsync(context, queryTextWithCustomInput); } catch (Exception e) { - if (ignoreFailure) { + if (context.isIgnoreFailure()) { log.warn("Query processing failed, returning empty result for: {}", queryTextWithCustomInput, e); return JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, Map.of()); } else { @@ -240,7 +242,7 @@ private void generateLLMJudgmentsAsync( throw new RuntimeException("Query processing failed: " + queryTextWithCustomInput, e); } } - }, ignoreFailure, ActionListener.wrap(results -> { + }, context.isIgnoreFailure(), ActionListener.wrap((List> results) -> { int processedQueries = results.size(); int successQueries = (int) results.stream().mapToLong(result -> { List> ratings = (List>) result.get("ratings"); @@ -264,18 +266,7 @@ private void generateLLMJudgmentsAsync( }); } - private Map processQueryTextAsync( - String modelId, - int size, - int tokenLimit, - List contextFields, - List searchConfigurations, - String queryTextWithCustomInput, - 
boolean ignoreFailure, - String promptTemplate, - LLMJudgmentRatingType ratingType, - boolean overwriteCache - ) { + private Map processQueryTextAsync(LlmJudgmentContext context, String queryTextWithCustomInput) { log.info("Processing query text judgment: {}", queryTextWithCustomInput); ConcurrentMap allHits = new ConcurrentHashMap<>(); @@ -284,38 +275,33 @@ private Map processQueryTextAsync( try { // Step 1: Execute searches concurrently within this query text task - processSearchConfigurationsAsync(searchConfigurations, queryText, size, allHits, ignoreFailure); + processSearchConfigurationsAsync( + context.getSearchConfigurations(), + queryText, + context.getSize(), + allHits, + context.isIgnoreFailure() + ); // Step 2: Deduplicate from cache (skip if overwriteCache is true) List docIds = new ArrayList<>(allHits.keySet()); - String index = searchConfigurations.get(0).index(); - String promptTemplateCode = generatePromptTemplateCode(promptTemplate, ratingType); + String index = context.getSearchConfigurations().get(0).index(); + String promptTemplateCode = generatePromptTemplateCode(context.getPromptTemplate(), context.getRatingType()); List unprocessedDocIds = deduplicateFromCache( index, queryTextWithCustomInput, - contextFields, + context.getContextFields(), docIds, docIdToScore, - ignoreFailure, + context.isIgnoreFailure(), promptTemplateCode, - overwriteCache + context.isOverwriteCache() ); // Step 3: Process with LLM if needed if (!unprocessedDocIds.isEmpty()) { - processWithLLM( - modelId, - queryTextWithCustomInput, - tokenLimit, - contextFields, - unprocessedDocIds, - allHits, - index, - docIdToScore, - promptTemplate, - ratingType - ); + processWithLLM(context, queryTextWithCustomInput, unprocessedDocIds, allHits, index, docIdToScore); } Map result = JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); @@ -415,16 +401,12 @@ private List deduplicateFromCache( } private void processWithLLM( - String modelId, + 
LlmJudgmentContext context, String queryTextWithCustomInput, - int tokenLimit, - List contextFields, List unprocessedDocIds, ConcurrentMap allHits, String index, - ConcurrentMap docIdToScore, - String promptTemplate, - LLMJudgmentRatingType ratingType + ConcurrentMap docIdToScore ) throws Exception { Map unionHits = new HashMap<>(); @@ -432,32 +414,26 @@ private void processWithLLM( for (String docId : unprocessedDocIds) { SearchHit hit = allHits.get(docId); String compositeKey = combinedIndexAndDocId(index, docId); - String contextSource = getContextSource(hit, contextFields); + String contextSource = getContextSource(hit, context.getContextFields()); unionHits.put(compositeKey, contextSource); } log.info("Processing {} uncached docs with LLM", unionHits.size()); log.debug("DEBUG: unionHits keys being sent to LLM: {}", unionHits.keySet()); log.debug("DEBUG: queryTextWithCustomInput: {}", queryTextWithCustomInput); - log.debug("DEBUG: modelId: {}, tokenLimit: {}, ratingType: {}", modelId, tokenLimit, ratingType); + log.debug( + "DEBUG: modelId: {}, tokenLimit: {}, ratingType: {}", + context.getModelId(), + context.getTokenLimit(), + context.getRatingType() + ); // Generate promptTemplateCode for cache updates - String promptTemplateCode = generatePromptTemplateCode(promptTemplate, ratingType); + String promptTemplateCode = generatePromptTemplateCode(context.getPromptTemplate(), context.getRatingType()); // Synchronous LLM call PlainActionFuture> llmFuture = PlainActionFuture.newFuture(); - generateLLMJudgmentForQueryText( - modelId, - queryTextWithCustomInput, - tokenLimit, - contextFields, - unionHits, - new HashMap<>(), - promptTemplate, - ratingType, - promptTemplateCode, - llmFuture - ); + generateLLMJudgmentForQueryText(context, queryTextWithCustomInput, unionHits, new HashMap<>(), promptTemplateCode, llmFuture); Map llmResults = llmFuture.actionGet(); docIdToScore.putAll(llmResults); @@ -466,18 +442,14 @@ private void processWithLLM( } private void 
generateLLMJudgmentForQueryText( - String modelId, + LlmJudgmentContext context, String queryTextWithCustomInput, - int tokenLimit, - List contextFields, Map unprocessedUnionHits, Map docIdToRating, - String promptTemplate, - LLMJudgmentRatingType ratingType, String promptTemplateCode, ActionListener> listener ) { - log.debug("calculating LLM evaluation with modelId: {} and unprocessed unionHits: {}", modelId, unprocessedUnionHits); + log.debug("calculating LLM evaluation with modelId: {} and unprocessed unionHits: {}", context.getModelId(), unprocessedUnionHits); log.debug("processed docIdToRating before llm evaluation: {}", docIdToRating); if (unprocessedUnionHits.isEmpty()) { @@ -496,13 +468,15 @@ private void generateLLMJudgmentForQueryText( AtomicBoolean hasFailure = new AtomicBoolean(false); mlAccessor.predict( - modelId, - tokenLimit, + context.getModelId(), + context.getTokenLimit(), queryText, referenceData, unprocessedUnionHits, - promptTemplate, - ratingType, + context.getPromptTemplate(), + context.getRatingType(), + context.getConnectorType(), + context.getRateLimit(), new ActionListener() { @Override public void onResponse(ChunkResult chunkResult) { @@ -548,16 +522,16 @@ public void onResponse(ChunkResult chunkResult) { compositeKey, rawRatingScore ); - Double ratingScore = convertRatingScore(rawRatingScore, ratingType); + Double ratingScore = convertRatingScore(rawRatingScore, context.getRatingType()); String docId = getDocIdFromCompositeKey(compositeKey); log.debug("DEBUG: Converted rating - docId: {}, ratingScore: {}", docId, ratingScore); processedRatings.put(docId, ratingScore.toString()); updateJudgmentCache( compositeKey, queryTextWithCustomInput, - contextFields, + context.getContextFields(), ratingScore.toString(), - modelId, + context.getModelId(), promptTemplateCode ); } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiter.java b/src/main/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiter.java new file 
mode 100644
index 00000000..c5bf6dd5
--- /dev/null
+++ b/src/main/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiter.java
@@ -0,0 +1,307 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.searchrelevance.ml;
+
+import java.util.Locale;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.searchrelevance.ml.connector.ConnectorType;
+
+/**
+ * Adaptive rate limiter that learns optimal rates per model and handles circuit breaking.
+ *
+ * State is tracked per (modelId, connectorType) key. The delay starts at the caller-supplied
+ * rate limit, grows on rate-limit errors and shrinks back toward the initial value after a
+ * run of successes. Stale entries are evicted by an hourly cleanup task once the scheduler is running.
+ */
+public class AdaptiveRateLimiter {
+    private static final Logger log = LogManager.getLogger(AdaptiveRateLimiter.class);
+
+    /** Entries whose last admitted request is older than this are evicted by the cleanup task. */
+    private static final long ENTRY_TTL_MS = 3600_000; // 1 hour
+
+    private final ConcurrentMap<String, RateLimitState> rateLimitStates = new ConcurrentHashMap<>();
+    private volatile ScheduledExecutorService cleanupScheduler;
+    private final Object schedulerLock = new Object();
+
+    public AdaptiveRateLimiter() {
+        // Lazy initialization - don't create threads until actually needed
+    }
+
+    /**
+     * Returns the shared scheduler, creating it (and the hourly cleanup job) on first use.
+     * The result is held in a local so a concurrent shutdown() cannot turn it into null.
+     */
+    private ScheduledExecutorService getOrCreateScheduler() {
+        ScheduledExecutorService scheduler = cleanupScheduler;
+        if (scheduler == null) {
+            synchronized (schedulerLock) {
+                scheduler = cleanupScheduler;
+                if (scheduler == null) {
+                    scheduler = Executors.newSingleThreadScheduledExecutor(new ThreadFactory() {
+                        @Override
+                        public Thread newThread(Runnable r) {
+                            Thread t = new Thread(r, "adaptive-rate-limiter-cleanup");
+                            t.setDaemon(true);
+                            return t;
+                        }
+                    });
+                    // Schedule cleanup every hour
+                    scheduler.scheduleAtFixedRate(this::cleanupOldEntries, 1, 1, TimeUnit.HOURS);
+                    cleanupScheduler = scheduler;
+                }
+            }
+        }
+        return scheduler;
+    }
+
+    /**
+     * Waits out the current delay for the given model before a prediction is sent.
+     *
+     * @param modelId model the request is for
+     * @param connectorType connector used to reach the model
+     * @param userRateLimit initial delay in milliseconds; only used when the key is first seen
+     * @return a future that completes once the delay has elapsed, or fails when the model is
+     *         considered unavailable or the delay could not be scheduled
+     */
+    public CompletableFuture<Void> applyRateLimit(String modelId, ConnectorType connectorType, long userRateLimit) {
+        String key = getKey(modelId, connectorType);
+        RateLimitState state = rateLimitStates.computeIfAbsent(key, k -> new RateLimitState(userRateLimit));
+
+        // Circuit breaker: Stop trying if model seems dead
+        if (state.shouldStopTrying()) {
+            return CompletableFuture.failedFuture(new RuntimeException("Model appears to be unavailable: " + modelId));
+        }
+
+        long delayMs = state.calculateDelay();
+
+        if (delayMs <= 0) {
+            return CompletableFuture.completedFuture(null);
+        }
+
+        log.debug("Applying rate limit for {}: {}ms delay", key, delayMs);
+
+        // Non-blocking delay using our managed executor
+        CompletableFuture<Void> future = new CompletableFuture<>();
+        try {
+            getOrCreateScheduler().schedule(() -> future.complete(null), delayMs, TimeUnit.MILLISECONDS);
+        } catch (RejectedExecutionException e) {
+            // Scheduler was shut down underneath us; report it instead of leaving the future pending
+            future.completeExceptionally(e);
+        }
+        return future;
+    }
+
+    /**
+     * Feeds the outcome of a prediction back into the state for the model, if one exists.
+     *
+     * @param modelId model the request was sent to
+     * @param connectorType connector used to reach the model
+     * @param success whether the prediction succeeded
+     * @param error failure cause when {@code success} is false; may be null
+     */
+    public void recordResult(String modelId, ConnectorType connectorType, boolean success, Throwable error) {
+        String key = getKey(modelId, connectorType);
+        RateLimitState state = rateLimitStates.get(key);
+
+        if (state != null) {
+            if (success) {
+                state.onSuccess();
+            } else if (isRateLimitError(error)) {
+                state.onRateLimit();
+            } else if (isModelUnavailableError(error)) {
+                state.onModelUnavailable();
+            } else {
+                state.onOtherError();
+            }
+        }
+    }
+
+    private String getKey(String modelId, ConnectorType connectorType) {
+        return modelId + ":" + connectorType.getValue();
+    }
+
+    /**
+     * Lower-cased message of the error, or an empty string when the error or its message is null,
+     * so the classifiers below never dereference a null message.
+     */
+    private static String lowerCaseMessage(Throwable error) {
+        if (error == null || error.getMessage() == null) {
+            return "";
+        }
+        return error.getMessage().toLowerCase(Locale.ROOT);
+    }
+
+    private boolean isRateLimitError(Throwable error) {
+        String message = lowerCaseMessage(error);
+        return message.contains("rate limit")
+            || message.contains("throttling")
+            || message.contains("too many requests")
+            || message.contains("high request rate")
+            || message.contains("acquire operation took longer")
+            || message.contains("connection from the pool");
+    }
+
+    private boolean isModelUnavailableError(Throwable error) {
+        String message = lowerCaseMessage(error);
+        return message.contains("model not found") || message.contains("service unavailable") || message.contains("internal server error");
+    }
+
+    /** Drops state for keys that have not admitted a request within ENTRY_TTL_MS. */
+    private void cleanupOldEntries() {
+        long cutoff = System.currentTimeMillis() - ENTRY_TTL_MS;
+        AtomicInteger removed = new AtomicInteger(0);
+
+        rateLimitStates.entrySet().removeIf(entry -> {
+            if (entry.getValue().lastRequestTime < cutoff) {
+                removed.incrementAndGet();
+                return true;
+            }
+            return false;
+        });
+
+        if (removed.get() > 0) {
+            log.debug("Cleaned up {} old rate limit entries", removed.get());
+        }
+    }
+
+    /**
+     * Runs a task on the shared scheduler after the given delay.
+     *
+     * @param task work to run
+     * @param delayMs delay in milliseconds
+     */
+    public void scheduleTask(Runnable task, long delayMs) {
+        getOrCreateScheduler().schedule(task, delayMs, TimeUnit.MILLISECONDS);
+    }
+
+    /** Stops the scheduler, waiting up to five seconds for termination before forcing it. */
+    public void shutdown() {
+        synchronized (schedulerLock) {
+            if (cleanupScheduler != null) {
+                cleanupScheduler.shutdown();
+                try {
+                    if (!cleanupScheduler.awaitTermination(5, TimeUnit.SECONDS)) {
+                        cleanupScheduler.shutdownNow();
+                    }
+                } catch (InterruptedException e) {
+                    cleanupScheduler.shutdownNow();
+                    Thread.currentThread().interrupt();
+                }
+                cleanupScheduler = null;
+            }
+        }
+    }
+
+    /**
+     * Mutable rate-limit state for one (modelId, connectorType) key. Mutators are synchronized
+     * because the counters are read-modify-write; fields stay volatile for lock-free reads.
+     */
+    private static class RateLimitState {
+        private volatile long currentDelayMs;
+        private volatile long lastRequestTime;
+        private volatile long lastSuccessTime;
+        private volatile int consecutiveSuccesses;
+        private volatile int consecutiveFailures;
+        private final long initialDelayMs;
+
+        // Conservative parameters
+        private static final double BACKOFF_MULTIPLIER = 2.0;
+        private static final double RECOVERY_FACTOR = 0.9;
+        private static final int SUCCESSES_BEFORE_RECOVERY = 5;
+        private static final long MAX_DELAY_MS = 300_000; // 5 minutes
+        private static final long MIN_BACKOFF_DELAY_MS = 1_000; // floor applied after a rate-limit error
+        private static final int MAX_CONSECUTIVE_FAILURES = 10;
+        private static final long MODEL_DEAD_THRESHOLD_MS = 1800_000; // 30 minutes
+        private static final long CIRCUIT_OPEN_DURATION_MS = 300_000; // 5 minutes
+
+        public RateLimitState(long initialDelayMs) {
+            this.initialDelayMs = Math.max(0, initialDelayMs);
+            this.currentDelayMs = this.initialDelayMs;
+            this.lastSuccessTime = System.currentTimeMillis();
+            this.lastRequestTime = System.currentTimeMillis();
+        }
+
+        public synchronized long calculateDelay() {
+            lastRequestTime = System.currentTimeMillis();
+
+            // Circuit breaker: Longer delay if too many failures
+            if (isCircuitOpen()) {
+                return CIRCUIT_OPEN_DURATION_MS;
+            }
+
+            return currentDelayMs;
+        }
+
+        public boolean isCircuitOpen() {
+            return consecutiveFailures >= MAX_CONSECUTIVE_FAILURES;
+        }
+
+        public boolean shouldStopTrying() {
+            // lastSuccessTime only advances on success, so an idle but healthy model also looks
+            // "old". Only treat the model as dead when failures were recorded since that success.
+            // NOTE(review): once this returns true applyRateLimit rejects before sending anything,
+            // so the state only clears when cleanupOldEntries evicts the entry.
+            if (consecutiveFailures == 0) {
+                return false;
+            }
+            long timeSinceLastSuccess = System.currentTimeMillis() - lastSuccessTime;
+            return timeSinceLastSuccess > MODEL_DEAD_THRESHOLD_MS;
+        }
+
+        public synchronized void onSuccess() {
+            consecutiveFailures = 0;
+            consecutiveSuccesses++;
+            lastSuccessTime = System.currentTimeMillis();
+
+            // Conservative recovery - only after many successes
+            if (consecutiveSuccesses >= SUCCESSES_BEFORE_RECOVERY) {
+                currentDelayMs = Math.max(
+                    initialDelayMs, // Never go below user's initial setting
+                    (long) (currentDelayMs * RECOVERY_FACTOR)
+                );
+                consecutiveSuccesses = 0;
+                log.debug("Rate limit recovered to {}ms", currentDelayMs);
+            }
+        }
+
+        public synchronized void onRateLimit() {
+            consecutiveSuccesses = 0;
+            consecutiveFailures++;
+
+            // Aggressive backoff on rate limit. The floor keeps the backoff effective when the
+            // current delay is 0, where multiplying alone would leave it at 0.
+            long oldDelay = currentDelayMs;
+            long backedOff = Math.max(MIN_BACKOFF_DELAY_MS, (long) (oldDelay * BACKOFF_MULTIPLIER));
+            currentDelayMs = Math.min(MAX_DELAY_MS, backedOff);
+
+            log.debug("Rate limit hit, increased delay from {}ms to {}ms", oldDelay, currentDelayMs);
+        }
+
+        public synchronized void onModelUnavailable() {
+            consecutiveSuccesses = 0;
+            consecutiveFailures++;
+            log.debug("Model unavailable error, consecutive failures: {}", consecutiveFailures);
+        }
+
+        public synchronized void onOtherError() {
+            // Don't change rate limiting for non-rate-limit errors
+            consecutiveSuccesses = 0;
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java
b/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java index 210ecca6..5e3b740e 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java @@ -10,13 +10,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; +import org.opensearch.searchrelevance.ml.connector.LLMConnectorFactory; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.utils.RatingOutputProcessor; @@ -29,6 +29,7 @@ public class MLAccessor { private final MachineLearningNodeClient mlClient; private final MLInputOutputTransformer transformer; + private final AdaptiveRateLimiter rateLimiter; private static final int MAX_RETRY_NUMBER = 3; private static final long RETRY_DELAY_MS = 1000; @@ -36,6 +37,16 @@ public class MLAccessor { public MLAccessor(MachineLearningNodeClient mlClient) { this.mlClient = mlClient; this.transformer = new MLInputOutputTransformer(); + this.rateLimiter = new AdaptiveRateLimiter(); + } + + /** + * Shutdown the MLAccessor and clean up resources + */ + public void shutdown() { + if (rateLimiter != null) { + rateLimiter.shutdown(); + } } public void predict( @@ -46,28 +57,50 @@ public void predict( Map hits, String promptTemplate, LLMJudgmentRatingType ratingType, + ConnectorType connectorType, + long rateLimit, ActionListener progressListener ) { log.debug( - "DEBUG: MLAccessor.predict called with modelId: {}, searchText: {}, hits count: {}, ratingType: {}", + "DEBUG: MLAccessor.predict called with modelId: {}, searchText: {}, hits 
count: {}, ratingType: {}, connectorType: {}, rateLimit: {}ms", modelId, searchText, hits.size(), + ratingType, + connectorType, + rateLimit + ); + + // Create transformer with appropriate connector + MLInputOutputTransformer connectorTransformer = new MLInputOutputTransformer(LLMConnectorFactory.create(connectorType)); + + List mlInputs = connectorTransformer.createMLInputs( + tokenLimit, + searchText, + referenceData, + hits, + promptTemplate, ratingType ); - List mlInputs = transformer.createMLInputs(tokenLimit, searchText, referenceData, hits, promptTemplate, ratingType); log.info("Number of chunks: {}", mlInputs.size()); log.debug("DEBUG: Created {} MLInput chunks", mlInputs.size()); ChunkProcessingContext context = new ChunkProcessingContext(mlInputs.size(), progressListener); for (int i = 0; i < mlInputs.size(); i++) { - processChunk(modelId, mlInputs.get(i), i, context); + processChunk(modelId, mlInputs.get(i), i, connectorType, rateLimit, context); } } - private void processChunk(String modelId, MLInput mlInput, int chunkIndex, ChunkProcessingContext context) { - processChunkWithFallback(modelId, mlInput, chunkIndex, false, context); + private void processChunk( + String modelId, + MLInput mlInput, + int chunkIndex, + ConnectorType connectorType, + long rateLimit, + ChunkProcessingContext context + ) { + processChunkWithFallback(modelId, mlInput, chunkIndex, false, connectorType, rateLimit, context); } private void processChunkWithFallback( @@ -75,28 +108,43 @@ private void processChunkWithFallback( MLInput mlInput, int chunkIndex, boolean triedWithoutResponseFormat, + ConnectorType connectorType, + long rateLimit, ChunkProcessingContext context ) { - predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, 0, triedWithoutResponseFormat, ActionListener.wrap(response -> { - log.info("Chunk {} processed successfully", chunkIndex); - String processedResponse = cleanResponse(response); - - // Check if parsing failed (empty ratings array) and we haven't tried 
without response_format yet - if ("[]".equals(processedResponse) && !triedWithoutResponseFormat) { - log.warn( - "Chunk {} returned empty ratings with response_format. Retrying without response_format for GPT-3.5 compatibility...", - chunkIndex - ); - // Create new MLInput without response_format and retry - MLInput mlInputWithoutFormat = recreateMLInputWithoutResponseFormat(mlInput); - scheduleRetry(() -> processChunkWithFallback(modelId, mlInputWithoutFormat, chunkIndex, true, context), RETRY_DELAY_MS); - } else { - context.handleSuccess(chunkIndex, processedResponse); - } - }, e -> { - log.error("Chunk {} failed after all retries", chunkIndex, e); - context.handleFailure(chunkIndex, e); - })); + predictSingleChunkWithRetry( + modelId, + mlInput, + chunkIndex, + 0, + triedWithoutResponseFormat, + connectorType, + rateLimit, + ActionListener.wrap(response -> { + log.info("Chunk {} processed successfully", chunkIndex); + String processedResponse = cleanResponse(response); + + // Check if parsing failed (empty ratings array) and we haven't tried without response_format yet + // Only apply this retry logic for OpenAI connectors + if ("[]".equals(processedResponse) && !triedWithoutResponseFormat && connectorType == ConnectorType.OPENAI) { + log.warn( + "Chunk {} returned empty ratings with response_format. 
Retrying without response_format for GPT-3.5 compatibility...", + chunkIndex + ); + // Create new MLInput without response_format and retry + MLInput mlInputWithoutFormat = recreateMLInputWithoutResponseFormat(mlInput); + scheduleRetry( + () -> processChunkWithFallback(modelId, mlInputWithoutFormat, chunkIndex, true, connectorType, rateLimit, context), + RETRY_DELAY_MS + ); + } else { + context.handleSuccess(chunkIndex, processedResponse); + } + }, e -> { + log.error("Chunk {} failed after all retries", chunkIndex, e); + context.handleFailure(chunkIndex, e); + }) + ); } private String cleanResponse(String response) { @@ -118,9 +166,11 @@ private void predictSingleChunkWithRetry( int chunkIndex, int retryCount, boolean triedWithoutResponseFormat, + ConnectorType connectorType, + long rateLimit, ActionListener chunkListener ) { - predictSingleChunk(modelId, mlInput, new ActionListener() { + predictSingleChunk(modelId, mlInput, connectorType, rateLimit, new ActionListener() { @Override public void onResponse(String response) { log.debug( @@ -141,8 +191,8 @@ public void onFailure(Exception e) { triedWithoutResponseFormat, retryCount ); - // If we haven't tried without response_format yet, try that first before regular retries - if (!triedWithoutResponseFormat) { + // Only try response_format fallback for OpenAI connectors + if (!triedWithoutResponseFormat && connectorType == ConnectorType.OPENAI) { log.warn( "Chunk {} failed with response_format. 
Retrying without response_format for GPT-3.5 compatibility...", chunkIndex @@ -154,7 +204,16 @@ public void onFailure(Exception e) { long delay = RETRY_DELAY_MS; scheduleRetry( - () -> predictSingleChunkWithRetry(modelId, mlInputWithoutFormat, chunkIndex, 0, true, chunkListener), + () -> predictSingleChunkWithRetry( + modelId, + mlInputWithoutFormat, + chunkIndex, + 0, + true, + connectorType, + rateLimit, + chunkListener + ), delay ); } else if (retryCount < MAX_RETRY_NUMBER) { @@ -162,7 +221,16 @@ public void onFailure(Exception e) { long delay = RETRY_DELAY_MS * (long) Math.pow(2, retryCount); scheduleRetry( - () -> predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, retryCount + 1, true, chunkListener), + () -> predictSingleChunkWithRetry( + modelId, + mlInput, + chunkIndex, + retryCount + 1, + true, + connectorType, + rateLimit, + chunkListener + ), delay ); } else { @@ -192,25 +260,55 @@ private MLInput recreateMLInputWithoutResponseFormat(MLInput originalInput) { } private void scheduleRetry(Runnable runnable, long delayMs) { - CompletableFuture.delayedExecutor(delayMs, TimeUnit.MILLISECONDS).execute(runnable); + // Use the rate limiter's managed scheduler for retries to avoid thread leaks + rateLimiter.scheduleTask(runnable, delayMs); } - public void predictSingleChunk(String modelId, MLInput mlInput, ActionListener listener) { - log.debug("DEBUG: predictSingleChunk called with modelId: {}", modelId); - RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); - Map params = dataset.getParameters(); + public void predictSingleChunk( + String modelId, + MLInput mlInput, + ConnectorType connectorType, + long rateLimit, + ActionListener listener + ) { log.debug( - "DEBUG: MLInput parameters - has response_format: {}, has messages: {}", - params.containsKey("response_format"), - params.containsKey("messages") + "DEBUG: predictSingleChunk called with modelId: {}, connectorType: {}, rateLimit: {}ms", + modelId, + 
connectorType, + rateLimit ); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - log.debug("DEBUG: ML prediction succeeded, extracting response content"); - listener.onResponse(transformer.extractResponseContent(mlOutput)); - }, e -> { - log.debug("DEBUG: ML prediction failed with error: {}", e.getMessage()); - listener.onFailure(e); - })); + + // Apply rate limiting before making the prediction + rateLimiter.applyRateLimit(modelId, connectorType, rateLimit).whenComplete((result, rateLimitError) -> { + if (rateLimitError != null) { + log.error("Rate limiting failed for modelId: {}", modelId, rateLimitError); + listener.onFailure(new Exception(rateLimitError)); + return; + } + + // Create connector-specific transformer + MLInputOutputTransformer connectorTransformer = new MLInputOutputTransformer(LLMConnectorFactory.create(connectorType)); + + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map params = dataset.getParameters(); + log.debug( + "DEBUG: MLInput parameters - has response_format: {}, has messages: {}", + params.containsKey("response_format"), + params.containsKey("messages") + ); + + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + log.debug("DEBUG: ML prediction succeeded, extracting response content"); + // Record successful result for rate limiter learning + rateLimiter.recordResult(modelId, connectorType, true, null); + listener.onResponse(connectorTransformer.extractResponseContent(mlOutput)); + }, e -> { + log.debug("DEBUG: ML prediction failed with error: {}", e.getMessage()); + // Record failed result for rate limiter learning + rateLimiter.recordResult(modelId, connectorType, false, e); + listener.onFailure(e); + })); + }); } } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java index c1141521..f01e7ce0 100644 --- 
a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java @@ -7,18 +7,12 @@ */ package org.opensearch.searchrelevance.ml; -import static org.opensearch.searchrelevance.common.MLConstants.PARAM_MESSAGES_FIELD; -import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_JSON_MESSAGES_SHELL; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_0_1_START; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_BINARY; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_END; import static org.opensearch.searchrelevance.common.MLConstants.RATING_SCORE_BINARY_SCHEMA; import static org.opensearch.searchrelevance.common.MLConstants.RATING_SCORE_NUMERIC_SCHEMA; -import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CHOICES_FIELD; -import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CONTENT_FIELD; import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_FORMAT_TEMPLATE; -import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_MESSAGE_FIELD; -import static org.opensearch.searchrelevance.common.MLConstants.escapeJson; import java.io.IOException; import java.util.ArrayList; @@ -27,6 +21,8 @@ import java.util.Locale; import java.util.Map; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.xcontent.XContentBuilder; @@ -37,15 +33,27 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; +import 
org.opensearch.searchrelevance.ml.connector.LLMConnector; +import org.opensearch.searchrelevance.ml.connector.OpenAIConnector; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; -import lombok.extern.log4j.Log4j2; - /** * Handles ML input/output transformations for search relevance predictions */ -@Log4j2 public class MLInputOutputTransformer { + private static final Logger log = LogManager.getLogger(MLInputOutputTransformer.class); + + private final LLMConnector connector; + + public MLInputOutputTransformer() { + // Default to OpenAI for backward compatibility + this.connector = new OpenAIConnector(); + } + + public MLInputOutputTransformer(LLMConnector connector) { + this.connector = connector; + } public List createMLInputs( int tokenLimit, @@ -133,10 +141,12 @@ public MLInput createMLInput( Map parameters = new HashMap<>(); String messagesArray = buildMessagesArray(searchText, referenceData, hits, promptTemplate, ratingType); - parameters.put(PARAM_MESSAGES_FIELD, messagesArray); + // Use connector-specific parameter name + String paramName = connector.getMessageParameterName(); + parameters.put(paramName, messagesArray); - // Only add response_format if requested (for models that support it) - if (includeResponseFormat) { + // Only add response_format if requested and connector is OpenAI (only OpenAI supports response_format) + if (includeResponseFormat && connector.getType() == ConnectorType.OPENAI) { String responseFormat = getResponseFormat(ratingType); parameters.put("response_format", responseFormat); } @@ -155,7 +165,7 @@ private String buildMessagesArray( String hitsJson = buildHitsJson(hits); String userContent = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, promptTemplate); String systemPrompt = getSystemPrompt(ratingType); - return String.format(Locale.ROOT, PROMPT_JSON_MESSAGES_SHELL, systemPrompt, escapeJson(userContent)); + return connector.formatPrompt(systemPrompt, userContent); } catch (IOException e) 
{ log.error("Error converting hits to JSON string", e); throw new IllegalArgumentException("Failed to process hits", e); @@ -165,8 +175,14 @@ private String buildMessagesArray( private static String getSystemPrompt(LLMJudgmentRatingType ratingType) { String systemPromptStart; String systemPromptEnd = PROMPT_SEARCH_RELEVANCE_SCORE_END; + + // Handle null ratingType with default + if (ratingType == null) { + ratingType = LLMJudgmentRatingType.SCORE0_1; + } + switch (ratingType) { - case LLMJudgmentRatingType.SCORE0_1: + case SCORE0_1: systemPromptStart = PROMPT_SEARCH_RELEVANCE_SCORE_0_1_START; break; default: @@ -176,12 +192,17 @@ private static String getSystemPrompt(LLMJudgmentRatingType ratingType) { } private static String getResponseFormat(LLMJudgmentRatingType ratingType) { + // Handle null ratingType with default + if (ratingType == null) { + ratingType = LLMJudgmentRatingType.SCORE0_1; + } + String schema; switch (ratingType) { - case LLMJudgmentRatingType.SCORE0_1: + case SCORE0_1: schema = RATING_SCORE_NUMERIC_SCHEMA; break; - case LLMJudgmentRatingType.RELEVANT_IRRELEVANT: + case RELEVANT_IRRELEVANT: schema = RATING_SCORE_BINARY_SCHEMA; break; default: @@ -221,8 +242,6 @@ public String extractResponseContent(MLOutput mlOutput) { ModelTensor tensor = tensorOutputList.get(0).getMlModelTensors().get(0); Map dataMap = tensor.getDataAsMap(); - Map choices = (Map) ((List) dataMap.get(RESPONSE_CHOICES_FIELD)).get(0); - Map message = (Map) choices.get(RESPONSE_MESSAGE_FIELD); - return (String) message.get(RESPONSE_CONTENT_FIELD); + return connector.extractResponse(dataMap); } } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java new file mode 100644 index 00000000..057f35ec --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The 
OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.searchrelevance.ml.connector;
+
+import static org.opensearch.searchrelevance.common.MLConstants.escapeJson;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+/**
+ * Claude-specific connector implementation.
+ *
+ * The prompt is sent as a single user-role message holding one text block, and the
+ * reply is read from the "content" list of the response map.
+ */
+public class ClaudeConnector implements LLMConnector {
+
+    /**
+     * Builds the JSON messages array for Claude.
+     *
+     * @param systemContent system prompt; placed in front of the user text
+     * @param userContent user prompt
+     * @return JSON array string containing one user message with one text block
+     */
+    @Override
+    public String formatPrompt(String systemContent, String userContent) {
+        // Only a user-role message is emitted, so the system prompt is prepended to the user text.
+        String combinedContent = systemContent + "\n\n" + userContent;
+        return String.format(
+            Locale.ROOT,
+            "[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"%s\"}]}]",
+            escapeJson(combinedContent)
+        );
+    }
+
+    /**
+     * Returns the text of the first content block that carries a string "text" entry.
+     *
+     * @param rawResponse response map; may be null
+     * @return the extracted text, or an empty string when the map is null, has no
+     *         "content" list, or no block in the list holds a string "text" value
+     */
+    @Override
+    public String extractResponse(Map rawResponse) {
+        if (rawResponse == null) {
+            return "";
+        }
+        Object content = rawResponse.get("content");
+        if (!(content instanceof List)) {
+            return "";
+        }
+        // Skip blocks that are not maps or have no text instead of failing on the first one.
+        for (Object block : (List<?>) content) {
+            if (block instanceof Map) {
+                Object text = ((Map<?, ?>) block).get("text");
+                if (text instanceof String) {
+                    return (String) text;
+                }
+            }
+        }
+        return "";
+    }
+
+    @Override
+    public ConnectorType getType() {
+        return ConnectorType.CLAUDE;
+    }
+
+    @Override
+    public String getMessageParameterName() {
+        return "messages"; // Claude uses plural "messages"
+    }
+}
diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java
new file mode 100644
index 00000000..0c194051
--- /dev/null
+++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java
@@ -0,0 +1,38 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.searchrelevance.ml.connector;
+
+import static org.opensearch.searchrelevance.common.MLConstants.escapeJson;
+
+import java.util.Map;
+
+/**
+ * Cohere-specific connector implementation.
+ *
+ * The prompt is a single string rather than a messages array, and the reply is read
+ * from the "text" entry of the response map.
+ */
+public class CohereConnector implements LLMConnector {
+
+    /**
+     * Joins the escaped system and user prompts into one string.
+     *
+     * @param systemContent system prompt
+     * @param userContent user prompt
+     * @return escaped system text, a separator, then escaped user text
+     */
+    @Override
+    public String formatPrompt(String systemContent, String userContent) {
+        // The separator is the literal characters backslash-n twice, i.e. already in escaped form
+        // like the two parts around it. NOTE(review): presumably the value is spliced into a JSON
+        // request template by the connector - confirm.
+        return escapeJson(systemContent) + "\\n\\n" + escapeJson(userContent);
+    }
+
+    /**
+     * Reads the reply text from the "text" entry.
+     *
+     * @param rawResponse response map; may be null
+     * @return the text, or an empty string when the map is null or "text" is missing
+     *         or not a string (same fallback as ClaudeConnector)
+     */
+    @Override
+    public String extractResponse(Map rawResponse) {
+        if (rawResponse == null) {
+            return "";
+        }
+        Object text = rawResponse.get("text");
+        return text instanceof String ? (String) text : "";
+    }
+
+    @Override
+    public ConnectorType getType() {
+        return ConnectorType.COHERE;
+    }
+
+    @Override
+    public String getMessageParameterName() {
+        return "message"; // Cohere uses singular "message"
+    }
+}
diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/ConnectorType.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/ConnectorType.java
new file mode 100644
index 00000000..0340e94c
--- /dev/null
+++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/ConnectorType.java
@@ -0,0 +1,33 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.searchrelevance.ml.connector;
+
+/**
+ * Supported LLM providers. Each constant carries a lower-case identifier
+ * that is exposed through both getValue() and toString().
+ */
+public enum ConnectorType {
+    OPENAI("openai"),
+    CLAUDE("claude"),
+    COHERE("cohere"),
+    DEEPSEEK("deepseek");
+
+    /** Lower-case identifier of the provider. */
+    private final String label;
+
+    ConnectorType(String label) {
+        this.label = label;
+    }
+
+    /**
+     * @return the lower-case identifier given to the constant
+     */
+    public String getValue() {
+        return label;
+    }
+
+    /**
+     * @return the same identifier as getValue(), not the constant name
+     */
+    @Override
+    public String toString() {
+        return getValue();
+    }
+}
diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java
new file mode 100644
index 00000000..4b7ad1d0
--- /dev/null
+++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java
@@ -0,0 +1,24 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.searchrelevance.ml.connector;
+
+/**
+ * DeepSeek connector: inherits prompt formatting and response extraction from OpenAIConnector
+ */
+public class DeepSeekConnector extends OpenAIConnector {
+
+    @Override
+    public ConnectorType getType() {
+        return ConnectorType.DEEPSEEK;
+    }
+
+    @Override
+    public String getMessageParameterName() {
+        return "messages"; // plural, same key the OpenAI parent returns
+    }
+}
diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java
new file mode 100644
index 00000000..2ec803d8
--- /dev/null
+++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java
@@ -0,0 +1,47 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.searchrelevance.ml.connector;
+
+import java.util.Map;
+
+/**
+ * Contract every LLM provider connector implements: how the prompt is serialised,
+ * under which parameter name it is sent, and how the reply text is pulled out of
+ * the provider's response.
+ */
+public interface LLMConnector {
+
+    /**
+     * Identifies the provider this connector talks to.
+     *
+     * @return the matching ConnectorType constant
+     */
+    ConnectorType getType();
+
+    /**
+     * Name of the ML input parameter that carries the formatted prompt.
+     *
+     * @return the parameter name (e.g., "messages" for Claude/OpenAI, "message" for Cohere)
+     */
+    String getMessageParameterName();
+
+    /**
+     * Serialises the two prompt parts into the form the provider expects.
+     *
+     * @param systemContent the system message content
+     * @param userContent the user message content
+     * @return the formatted prompt string, ready to be sent to the LLM API
+     */
+    String formatPrompt(String systemContent, String userContent);
+
+    /**
+     * Pulls the reply text out of the provider's raw response.
+     *
+     * @param rawResponse the raw response map from the LLM API
+     * @return the extracted response text
+     */
+    String extractResponse(Map rawResponse);
+}
diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorFactory.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorFactory.java
new file mode 100644
index 00000000..a30b7d48
--- /dev/null
+++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorFactory.java
@@ -0,0 +1,30 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.searchrelevance.ml.connector;
+
+/**
+ * Factory class for creating LLM connector instances
+ */
+public class LLMConnectorFactory {
+
+    /**
+     * Creates a new LLM connector instance for the given connector type.
+     *
+     * @param type The connector type; must not be null
+     * @return a new LLMConnector instance for that type
+     * @throws IllegalArgumentException if connector type is null
+     */
+    public static LLMConnector create(ConnectorType type) {
+        // A switch on a null enum value throws NullPointerException; reject null explicitly
+        // so the failure matches the documented IllegalArgumentException.
+        if (type == null) {
+            throw new IllegalArgumentException("Connector type must not be null");
+        }
+        return switch (type) {
+            case OPENAI -> new OpenAIConnector();
+            case CLAUDE -> new ClaudeConnector();
+            case COHERE -> new CohereConnector();
+            case DEEPSEEK -> new DeepSeekConnector();
+        };
+    }
+}
diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java
new file mode 100644
index 00000000..4810ee91
--- /dev/null
+++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java
@@ -0,0 +1,46 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.searchrelevance.ml.connector;
+
+import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_JSON_MESSAGES_SHELL;
+import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CHOICES_FIELD;
+import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CONTENT_FIELD;
+import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_MESSAGE_FIELD;
+import static org.opensearch.searchrelevance.common.MLConstants.escapeJson;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+/**
+ * Connector for OpenAI-style chat completions: the prompt is rendered through
+ * PROMPT_JSON_MESSAGES_SHELL and the reply is read from the first entry of the
+ * choices list.
+ */
+public class OpenAIConnector implements LLMConnector {
+
+    /**
+     * Escapes both prompt parts and substitutes them into the messages shell.
+     */
+    @Override
+    public String formatPrompt(String systemContent, String userContent) {
+        String escapedSystem = escapeJson(systemContent);
+        String escapedUser = escapeJson(userContent);
+        return String.format(Locale.ROOT, PROMPT_JSON_MESSAGES_SHELL, escapedSystem, escapedUser);
+    }
+
+    /**
+     * Walks choices[0] -> message -> content and returns the content as a string.
+     * No defensive checks are made: a missing level fails with the underlying runtime exception.
+     */
+    @Override
+    public String extractResponse(Map rawResponse) {
+        List<?> choiceList = (List<?>) rawResponse.get(RESPONSE_CHOICES_FIELD);
+        Map<?, ?> firstChoice = (Map<?, ?>) choiceList.get(0);
+        Map<?, ?> message = (Map<?, ?>) firstChoice.get(RESPONSE_MESSAGE_FIELD);
+        Object content = message.get(RESPONSE_CONTENT_FIELD);
+        return (String) content;
+    }
+
+    @Override
+    public ConnectorType getType() {
+        return ConnectorType.OPENAI;
+    }
+
+    @Override
+    public String getMessageParameterName() {
+        return "messages"; // OpenAI uses plural "messages"
+    }
+}
diff --git a/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java b/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java index 35202ed0..6d5cdc4b 100644 --- a/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java +++ b/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java @@ -393,4 +393,11 @@ public List> getSettings() { public List> getExecutorBuilders(Settings settings) { return List.of(SearchRelevanceExecutor.getExecutorBuilder(settings)); } + + @Override + public void close() throws IOException { + if (mlAccessor != null) { +
mlAccessor.shutdown(); + } + } } diff --git a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java index f82ebbbb..6f34bb6d 100644 --- a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java +++ b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java @@ -9,10 +9,13 @@ import static java.util.Collections.singletonList; import static org.opensearch.rest.RestRequest.Method.PUT; +import static org.opensearch.searchrelevance.common.MLConstants.CONNECTOR_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.DEFAULT_PROMPT_TEMPLATE; import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.common.MLConstants.RATE_LIMIT; +import static org.opensearch.searchrelevance.common.MLConstants.parseRateLimit; import static org.opensearch.searchrelevance.common.MLConstants.validateTokenLimit; import static org.opensearch.searchrelevance.common.MetricsConstants.MODEL_ID; import static org.opensearch.searchrelevance.common.PluginConstants.CLICK_MODEL; @@ -48,6 +51,7 @@ import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.searchrelevance.exception.SearchRelevanceException; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.JudgmentType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.settings.SearchRelevanceSettingsAccessor; @@ -166,6 +170,34 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } boolean overwriteCache = Optional.ofNullable((Boolean) 
source.get(OVERWRITE_CACHE)).orElse(Boolean.FALSE); + // Parse connectorType - optional, defaults to OpenAI + ConnectorType connectorType = ConnectorType.OPENAI; // default + String connectorTypeStr = (String) source.get(CONNECTOR_TYPE); + if (connectorTypeStr != null) { + try { + connectorType = ConnectorType.valueOf(connectorTypeStr.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw new SearchRelevanceException( + String.format( + Locale.ROOT, + "Invalid connectorType: '%s'. Valid values are: %s", + connectorTypeStr, + String.join( + ", ", + ConnectorType.OPENAI.name(), + ConnectorType.CLAUDE.name(), + ConnectorType.COHERE.name(), + ConnectorType.DEEPSEEK.name() + ) + ), + RestStatus.BAD_REQUEST + ); + } + } + + // Parse rateLimit - optional, defaults to 0 + long rateLimit = parseRateLimit(source.get(RATE_LIMIT)); + createRequest = new PutLlmJudgmentRequest( type, name, @@ -179,7 +211,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ignoreFailure, promptTemplate, llmJudgmentRatingType, - overwriteCache + overwriteCache, + connectorType, + rateLimit ); } case UBI_JUDGMENT -> { diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java index 0c3c3df6..f984549d 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java @@ -7,10 +7,12 @@ */ package org.opensearch.searchrelevance.transport.judgment; +import static org.opensearch.searchrelevance.common.MLConstants.CONNECTOR_TYPE; +import static org.opensearch.searchrelevance.common.MLConstants.DEFAULT_PROMPT_TEMPLATE; import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static 
org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; -import static org.opensearch.searchrelevance.common.MetricsConstants.MODEL_ID; +import static org.opensearch.searchrelevance.common.MLConstants.RATE_LIMIT; import static org.opensearch.searchrelevance.ubi.UbiValidator.checkUbiIndicesExist; import java.util.ArrayList; @@ -32,8 +34,10 @@ import org.opensearch.searchrelevance.exception.SearchRelevanceException; import org.opensearch.searchrelevance.judgments.BaseJudgmentsProcessor; import org.opensearch.searchrelevance.judgments.JudgmentsProcessorFactory; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.AsyncStatus; import org.opensearch.searchrelevance.model.Judgment; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.utils.TimeUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -45,6 +49,20 @@ public class PutJudgmentTransportAction extends HandledTransportAction buildMetadata(PutJudgmentRequest request) { switch (request.getType()) { case LLM_JUDGMENT -> { PutLlmJudgmentRequest llmRequest = (PutLlmJudgmentRequest) request; + + // Store flat metadata fields for compatibility + metadata.put(QUERY_SET_ID, llmRequest.getQuerySetId()); + metadata.put(SEARCH_CONFIGURATION_LIST, llmRequest.getSearchConfigurationList()); metadata.put(MODEL_ID, llmRequest.getModelId()); - metadata.put("querySetId", llmRequest.getQuerySetId()); - metadata.put("size", llmRequest.getSize()); - metadata.put("searchConfigurationList", llmRequest.getSearchConfigurationList()); - metadata.put("tokenLimit", llmRequest.getTokenLimit()); - metadata.put("contextFields", llmRequest.getContextFields()); - metadata.put("ignoreFailure", llmRequest.isIgnoreFailure()); - metadata.put(PROMPT_TEMPLATE, llmRequest.getPromptTemplate()); - 
metadata.put(LLM_JUDGMENT_RATING_TYPE, llmRequest.getLlmJudgmentRatingType()); + metadata.put(SIZE, llmRequest.getSize()); + metadata.put(TOKEN_LIMIT, llmRequest.getTokenLimit()); + metadata.put(CONTEXT_FIELDS, llmRequest.getContextFields()); + metadata.put(IGNORE_FAILURE, llmRequest.isIgnoreFailure()); + metadata.put( + PROMPT_TEMPLATE, + llmRequest.getPromptTemplate() != null ? llmRequest.getPromptTemplate() : DEFAULT_PROMPT_TEMPLATE + ); + metadata.put( + LLM_JUDGMENT_RATING_TYPE, + llmRequest.getLlmJudgmentRatingType() != null ? llmRequest.getLlmJudgmentRatingType() : LLMJudgmentRatingType.SCORE0_1 + ); metadata.put(OVERWRITE_CACHE, llmRequest.isOverwriteCache()); + metadata.put( + CONNECTOR_TYPE, + llmRequest.getConnectorType() != null ? llmRequest.getConnectorType().name() : ConnectorType.OPENAI.name() + ); + metadata.put(RATE_LIMIT, llmRequest.getRateLimit()); } case UBI_JUDGMENT -> { if (!checkUbiIndicesExist(clusterService)) { throw new SearchRelevanceException("UBI is not initialized", RestStatus.CONFLICT); } - ; PutUbiJudgmentRequest ubiRequest = (PutUbiJudgmentRequest) request; - metadata.put("clickModel", ubiRequest.getClickModel()); - metadata.put("maxRank", ubiRequest.getMaxRank()); - metadata.put("startDate", ubiRequest.getStartDate()); - metadata.put("endDate", ubiRequest.getEndDate()); + metadata.put(CLICK_MODEL, ubiRequest.getClickModel()); + metadata.put(MAX_RANK, ubiRequest.getMaxRank()); + metadata.put(START_DATE, ubiRequest.getStartDate()); + metadata.put(END_DATE, ubiRequest.getEndDate()); } case IMPORT_JUDGMENT -> { PutImportJudgmentRequest importRequest = (PutImportJudgmentRequest) request; - metadata.put("judgmentRatings", importRequest.getJudgmentRatings()); + metadata.put(JUDGMENT_RATINGS, importRequest.getJudgmentRatings()); } } return metadata; diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java 
b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java index 24328e9b..41abc192 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java @@ -12,6 +12,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.JudgmentType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; @@ -57,6 +58,16 @@ public class PutLlmJudgmentRequest extends PutJudgmentRequest { */ private boolean overwriteCache; + /** + * LLM connector type for formatting prompts and extracting responses + */ + private ConnectorType connectorType; + + /** + * Rate limit in milliseconds between requests (0 = no limit) + */ + private long rateLimit; + public PutLlmJudgmentRequest( @NonNull JudgmentType type, @NonNull String name, @@ -70,7 +81,9 @@ public PutLlmJudgmentRequest( boolean ignoreFailure, String promptTemplate, LLMJudgmentRatingType llmJudgmentRatingType, - boolean overwriteCache + boolean overwriteCache, + ConnectorType connectorType, + long rateLimit ) { super(type, name, description); this.modelId = modelId; @@ -83,6 +96,8 @@ public PutLlmJudgmentRequest( this.promptTemplate = promptTemplate; this.llmJudgmentRatingType = llmJudgmentRatingType; this.overwriteCache = overwriteCache; + this.connectorType = connectorType != null ? 
connectorType : ConnectorType.OPENAI; // default to OpenAI + this.rateLimit = Math.max(0, rateLimit); // ensure non-negative } public PutLlmJudgmentRequest(StreamInput in) throws IOException { @@ -97,6 +112,10 @@ public PutLlmJudgmentRequest(StreamInput in) throws IOException { this.promptTemplate = in.readOptionalString(); this.llmJudgmentRatingType = in.readOptionalWriteable(LLMJudgmentRatingType::readFromStream); this.overwriteCache = Boolean.TRUE.equals(in.readOptionalBoolean()); + String connectorTypeStr = in.readOptionalString(); + this.connectorType = connectorTypeStr != null ? ConnectorType.valueOf(connectorTypeStr) : ConnectorType.OPENAI; + Long rateLimitValue = in.readOptionalLong(); + this.rateLimit = rateLimitValue != null ? rateLimitValue : 0L; // default to 0 (no limit) if not provided } @Override @@ -112,6 +131,8 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(promptTemplate); out.writeOptionalWriteable(llmJudgmentRatingType); out.writeOptionalBoolean(overwriteCache); + out.writeOptionalString(connectorType != null ? 
connectorType.name() : null); + out.writeOptionalLong(rateLimit); } public String getModelId() { @@ -154,4 +175,12 @@ public boolean isOverwriteCache() { return overwriteCache; } + public ConnectorType getConnectorType() { + return connectorType; + } + + public long getRateLimit() { + return rateLimit; + } + } diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java index 91cca1c3..aba85fdd 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java @@ -118,13 +118,13 @@ public void testLlmJudgmentWithPromptTemplate_thenSuccessful() { assertEquals("LLM_JUDGMENT", source.get("type")); assertNotNull(source.get("status")); // Should be COMPLETED or IN_PROGRESS - // Verify metadata contains new fields + // Verify metadata contains structured context Map metadata = (Map) source.get("metadata"); assertNotNull(metadata); assertNotNull(metadata.get("promptTemplate")); assertTrue(((String) metadata.get("promptTemplate")).contains("{{queryText}}")); assertNotNull(metadata.get("llmJudgmentRatingType")); - assertEquals("SCORE0_1", metadata.get("llmJudgmentRatingType")); + assertEquals("SCORE0_1", metadata.get("llmJudgmentRatingType").toString()); assertNotNull(metadata.get("overwriteCache")); // Verify judgmentRatings format @@ -202,7 +202,7 @@ public void testLlmJudgmentWithDifferentRatingTypes_thenSuccessful() { Map judgment01Doc = entityAsMap(getJudgment01Response); Map source01 = (Map) judgment01Doc.get("_source"); Map metadata01 = (Map) source01.get("metadata"); - assertEquals("SCORE0_1", metadata01.get("llmJudgmentRatingType")); + assertEquals("SCORE0_1", metadata01.get("llmJudgmentRatingType").toString()); // Test RELEVANT_IRRELEVANT rating type String binaryBody = 
Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentBinary.json").toURI())); @@ -234,7 +234,7 @@ public void testLlmJudgmentWithDifferentRatingTypes_thenSuccessful() { Map judgmentBinaryDoc = entityAsMap(getJudgmentBinaryResponse); Map sourceBinary = (Map) judgmentBinaryDoc.get("_source"); Map metadataBinary = (Map) sourceBinary.get("metadata"); - assertEquals("RELEVANT_IRRELEVANT", metadataBinary.get("llmJudgmentRatingType")); + assertEquals("RELEVANT_IRRELEVANT", metadataBinary.get("llmJudgmentRatingType").toString()); } @SneakyThrows @@ -400,12 +400,14 @@ public void testLlmJudgmentWithoutOptionalFields_thenSuccessfulWithDefaults() { assertNotNull("promptTemplate should not be null when not provided", promptTemplate); assertEquals("promptTemplate should have default value", DEFAULT_PROMPT_TEMPLATE, promptTemplate); - // llmJudgmentRatingType should have a default or be null + // ratingType should have a default value Object ratingType = metadata.get("llmJudgmentRatingType"); - // Either null or has a default value + assertNotNull("ratingType should not be null", ratingType); + assertEquals("ratingType should have default value", "SCORE0_1", ratingType.toString()); // overwriteCache should default to false Object overwriteCache = metadata.get("overwriteCache"); - assertTrue(overwriteCache == null || overwriteCache.equals(false)); + assertNotNull("overwriteCache should not be null", overwriteCache); + assertEquals("overwriteCache should default to false", false, overwriteCache); } } diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java index adf9b2f7..ca85b2ee 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java @@ -13,6 +13,7 @@ import 
org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.JudgmentType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.transport.judgment.PutImportJudgmentRequest; @@ -93,7 +94,9 @@ public void testLlmJudgmentRequestStreams() throws IOException { false, "test_prompt_template", LLMJudgmentRatingType.SCORE0_1, - true + true, + ConnectorType.OPENAI, + 1000L ); BytesStreamOutput output = new BytesStreamOutput(); @@ -130,7 +133,9 @@ public void testLlmJudgmentRequestStreamsWithNullOptionalFields() throws IOExcep true, null, null, - false + false, + ConnectorType.OPENAI, + 1000L ); BytesStreamOutput output = new BytesStreamOutput(); diff --git a/src/test/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiterTests.java b/src/test/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiterTests.java new file mode 100644 index 00000000..1a4de45f --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiterTests.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */
+package org.opensearch.searchrelevance.ml;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+
+import org.opensearch.searchrelevance.ml.connector.ConnectorType;
+import org.opensearch.test.OpenSearchTestCase;
+
+/**
+ * Unit tests for AdaptiveRateLimiter: delay futures returned by applyRateLimit and
+ * result recording via recordResult.
+ */
+public class AdaptiveRateLimiterTests extends OpenSearchTestCase {
+
+    private AdaptiveRateLimiter rateLimiter;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        rateLimiter = new AdaptiveRateLimiter();
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        // Run the framework's cleanup even if shutting the limiter down throws.
+        try {
+            if (rateLimiter != null) {
+                rateLimiter.shutdown();
+            }
+        } finally {
+            super.tearDown();
+        }
+    }
+
+    public void testNoRateLimitWhenZero() throws Exception {
+        // nanoTime is monotonic; currentTimeMillis can jump when the wall clock is adjusted.
+        long startNanos = System.nanoTime();
+
+        CompletableFuture<?> future = rateLimiter.applyRateLimit("test-model", ConnectorType.OPENAI, 0L);
+        future.get(1, TimeUnit.SECONDS);
+
+        long elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
+        assertTrue("Should complete immediately with 0 rate limit", elapsed < 100);
+    }
+
+    public void testRateLimitApplied() throws Exception {
+        // Test that non-zero rate limit creates a delay future (don't wait for it to avoid thread leaks)
+        CompletableFuture<?> future = rateLimiter.applyRateLimit("test-model", ConnectorType.CLAUDE, 100L);
+        assertNotNull("Should return a future for rate limiting", future);
+        assertFalse("Future should not be completed immediately for non-zero rate limit", future.isDone());
+    }
+
+    public void testSuccessRecording() {
+        // Should not throw exception
+        rateLimiter.recordResult("test-model", ConnectorType.OPENAI, true, null);
+    }
+
+    public void testRateLimitErrorRecording() {
+        // Passes only if recording a failure does not throw
+        Exception rateLimitError = new RuntimeException("Rate limit exceeded");
+        rateLimiter.recordResult("test-model", ConnectorType.CLAUDE, false, rateLimitError);
+    }
+
+    public void testModelUnavailableError() {
+        // Passes only if recording a failure does not throw
+        Exception unavailableError = new RuntimeException("Model not found");
+        rateLimiter.recordResult("test-model", ConnectorType.COHERE, false, unavailableError);
+    }
+
+    public void testCircuitBreakerAfterManyFailures() {
+        String modelId = "failing-model";
+        ConnectorType connectorType = ConnectorType.CLAUDE;
+
+        // Record many failures to trigger circuit breaker
+        for (int i = 0; i < 15; i++) {
+            rateLimiter.recordResult(modelId, connectorType, false, new RuntimeException("Service unavailable"));
+        }
+
+        // Test that circuit breaker creates a future (don't wait to avoid thread leaks)
+        CompletableFuture<?> future = rateLimiter.applyRateLimit(modelId, connectorType, 100L);
+        assertNotNull("Should return a future even with circuit breaker", future);
+    }
+}
diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java index 966780df..75d56b91 100644 --- a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java @@ -28,6 +28,7 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.test.OpenSearchTestCase; @@ -48,6 +49,16 @@ */ public class MLAccessorIntegrationTests extends OpenSearchTestCase { + private MLAccessor mlAccessor; + + @Override + public void tearDown() throws Exception { + if (mlAccessor != null) { + mlAccessor.shutdown(); + } + super.tearDown(); + } + /** * Note: GPT-3.5 fallback testing is documented in TESTING_GPT35_FALLBACK.md as "Scenario 2" * This scenario requires triggering scheduleRetry which creates CompletableFuture threads that leak.
@@ -62,7 +73,7 @@ public class MLAccessorIntegrationTests extends OpenSearchTestCase { */ public void testFirstAttemptSuccess_WhenModelSupportsResponseFormat() throws Exception { MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); - MLAccessor mlAccessor = new MLAccessor(mlClient); + mlAccessor = new MLAccessor(mlClient); AtomicInteger attemptCount = new AtomicInteger(0); CountDownLatch latch = new CountDownLatch(1); @@ -99,7 +110,9 @@ public void testFirstAttemptSuccess_WhenModelSupportsResponseFormat() throws Exc hits, "Test prompt", LLMJudgmentRatingType.SCORE0_1, - ActionListener.wrap(chunkResult -> { + ConnectorType.OPENAI, + 1000L, + ActionListener.wrap((ChunkResult chunkResult) -> { result.set(chunkResult); latch.countDown(); }, e -> latch.countDown()) @@ -133,6 +146,72 @@ public void testFirstAttemptSuccess_WhenModelSupportsResponseFormat() throws Exc * daemon threads that cannot be properly cleaned up in the OpenSearch test framework. */ + /** + * Test that non-OpenAI connectors (Claude, Cohere, DeepSeek) don't trigger response_format retry logic. + * This verifies the fix for the issue where all connector types were attempting response_format fallback. 
+ */ + public void testNonOpenAIConnectors_DoNotUseResponseFormatRetry() throws Exception { + // Test each non-OpenAI connector type + ConnectorType[] nonOpenAIConnectors = { ConnectorType.CLAUDE, ConnectorType.COHERE, ConnectorType.DEEPSEEK }; + + for (ConnectorType connectorType : nonOpenAIConnectors) { + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + mlAccessor = new MLAccessor(mlClient); + + AtomicInteger attemptCount = new AtomicInteger(0); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference<Exception> error = new AtomicReference<>(); + + // Mock ML client to always fail - this should NOT trigger response_format retry for non-OpenAI + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + attemptCount.incrementAndGet(); + + // Fail immediately - non-OpenAI connectors should not retry with response_format removal + listener.onFailure(new RuntimeException("Simulated failure for " + connectorType)); + return null; + }).when(mlClient).predict(any(), any(MLInput.class), any()); + + // Execute prediction + Map hits = Map.of("doc1", "test content"); + AtomicReference<ChunkResult> result = new AtomicReference<>(); + mlAccessor.predict( + "test-model", + 4000, + "test query", + new HashMap<>(), + hits, + "Test prompt", + LLMJudgmentRatingType.SCORE0_1, + connectorType, + 1000L, + ActionListener.wrap((ChunkResult chunkResult) -> { + result.set(chunkResult); + latch.countDown(); + }, e -> { + error.set(e); + latch.countDown(); + }) + ); + + assertTrue("Should complete for " + connectorType, latch.await(15, TimeUnit.SECONDS)); + + // Verify all chunks failed (as expected) + ChunkResult chunkResult = result.get(); + assertNotNull("Should have result for " + connectorType + " (listener error: " + error.get() + ")", chunkResult); + assertEquals("All chunks should have failed for " + connectorType, 0, chunkResult.getSuccessfulChunksCount()); + assertTrue("Should have failed chunks for " + connectorType, chunkResult.getFailedChunksCount() > 0); + + // Key assertion: For 
non-OpenAI connectors, should attempt regular retries but not response_format retry + // The fix ensures that only OpenAI connectors attempt the response_format fallback + assertTrue("Should attempt at least once for " + connectorType, attemptCount.get() >= 1); + // With regular retry logic (3 attempts), we expect exactly 4 attempts (1 initial + 3 retries) + assertEquals("Should attempt exactly 4 times (1 initial + 3 retries) for " + connectorType, 4, attemptCount.get()); + // Release this iteration's accessor here: the field is reassigned on the next pass, so tearDown() only sees the last one + mlAccessor.shutdown(); + } + } + // ============================================ // Helper Methods // ============================================ diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerConnectorTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerConnectorTests.java new file mode 100644 index 00000000..0cc26cfc --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerConnectorTests.java @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; + +import org.mockito.stubbing.Answer; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchrelevance.ml.connector.ClaudeConnector; +import org.opensearch.searchrelevance.ml.connector.CohereConnector; +import org.opensearch.searchrelevance.ml.connector.DeepSeekConnector; +import org.opensearch.searchrelevance.ml.connector.LLMConnector; +import org.opensearch.searchrelevance.ml.connector.OpenAIConnector; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.test.OpenSearchTestCase; +/** Tests {@link MLInputOutputTransformer} with each connector implementation, using mocked ML outputs. */ +public class MLInputOutputTransformerConnectorTests extends OpenSearchTestCase { + /** The no-argument constructor must produce a non-null transformer. */ + public void testDefaultConstructorUsesOpenAI() { + MLInputOutputTransformer transformer = new MLInputOutputTransformer(); + assertNotNull(transformer); + // Default behavior should work (uses OpenAI connector internally) + } + /** The connector-taking constructor must accept a ClaudeConnector. */ + public void testConnectorConstructor() { + LLMConnector claudeConnector = new ClaudeConnector(); + MLInputOutputTransformer transformer = new MLInputOutputTransformer(claudeConnector); + assertNotNull(transformer); + } + /** Each connector must extract the reply text from its own provider-specific response shape. */ + public void testExtractResponseContentWithDifferentConnectors() { + // Test OpenAI response extraction + testResponseExtraction(new OpenAIConnector(), createOpenAIResponse(), "OpenAI response"); + + // Test Claude response extraction + testResponseExtraction(new ClaudeConnector(), createClaudeResponse(), "Claude response"); + + // Test Cohere response extraction + testResponseExtraction(new CohereConnector(), createCohereResponse(), "Cohere 
response"); + + // Test DeepSeek response extraction (uses OpenAI format but returns DEEPSEEK type) + testResponseExtraction(new DeepSeekConnector(), createDeepSeekResponse(), "DeepSeek response"); + } + /** Asserts that a transformer built on {@code connector} extracts {@code expectedResponse} from {@code mlOutput}. */ + private void testResponseExtraction(LLMConnector connector, MLOutput mlOutput, String expectedResponse) { + MLInputOutputTransformer transformer = new MLInputOutputTransformer(connector); + String result = transformer.extractResponseContent(mlOutput); + assertEquals(expectedResponse, result); + } + /** OpenAI-shaped payload: choices, then message, then content. */ + private MLOutput createOpenAIResponse() { + return createMockMLOutput(Map.of("choices", List.of(Map.of("message", Map.of("content", "OpenAI response"))))); + } + /** Claude-shaped payload: a content list whose entry carries the text. */ + private MLOutput createClaudeResponse() { + return createMockMLOutput(Map.of("content", List.of(Map.of("text", "Claude response")))); + } + /** Cohere-shaped payload: a top-level text entry. */ + private MLOutput createCohereResponse() { + return createMockMLOutput(Map.of("text", "Cohere response")); + } + /** DeepSeek payload, shaped like the OpenAI one. */ + private MLOutput createDeepSeekResponse() { + return createMockMLOutput(Map.of("choices", List.of(Map.of("message", Map.of("content", "DeepSeek response"))))); + } + /** Builds a mocked ModelTensorOutput holding one ModelTensors with one tensor whose getDataAsMap() returns {@code dataMap}. */ + @SuppressWarnings("unchecked") + private MLOutput createMockMLOutput(Map<String, ?> dataMap) { + ModelTensor tensor = mock(ModelTensor.class); + when(tensor.getDataAsMap()).thenAnswer((Answer<Map<String, ?>>) invocation -> dataMap); + + ModelTensors modelTensors = mock(ModelTensors.class); + when(modelTensors.getMlModelTensors()).thenReturn(List.of(tensor)); + + ModelTensorOutput output = mock(ModelTensorOutput.class); + when(output.getMlModelOutputs()).thenReturn(List.of(modelTensors)); + + return output; + } + /** createMLInput must return a non-null input for every connector. */ + public void testCreateMLInputWithDifferentConnectors() { + String searchText = "test query"; + Map referenceData = Map.of("ref", "reference"); + Map hits = Map.of("hit1", "result1"); + String promptTemplate = "{{searchText}} {{hits}}"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.RELEVANT_IRRELEVANT; + + // Test with different connectors + testMLInputCreation(new OpenAIConnector(), 
searchText, referenceData, hits, promptTemplate, ratingType); + testMLInputCreation(new ClaudeConnector(), searchText, referenceData, hits, promptTemplate, ratingType); + testMLInputCreation(new CohereConnector(), searchText, referenceData, hits, promptTemplate, ratingType); + testMLInputCreation(new DeepSeekConnector(), searchText, referenceData, hits, promptTemplate, ratingType); + } + /** Builds an ML input through {@code connector} and asserts that it is non-null. */ + private void testMLInputCreation( + LLMConnector connector, + String searchText, + Map referenceData, + Map hits, + String promptTemplate, + LLMJudgmentRatingType ratingType + ) { + MLInputOutputTransformer transformer = new MLInputOutputTransformer(connector); + + // This should not throw an exception + var mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType); + assertNotNull(mlInput); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorTests.java b/src/test/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorTests.java new file mode 100644 index 00000000..6fda4211 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorTests.java @@ -0,0 +1,127 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml.connector; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.List; +import java.util.Map; + +import org.opensearch.test.OpenSearchTestCase; + /** Tests the type, prompt formatting and response extraction of each LLMConnector, plus LLMConnectorFactory. */ +public class LLMConnectorTests extends OpenSearchTestCase { + /** OpenAI: system and user prompts become two chat messages; the reply is read from choices, message, content. */ + public void testOpenAIConnector() { + LLMConnector connector = new OpenAIConnector(); + + // Test connector type + assertEquals(ConnectorType.OPENAI, connector.getType()); + + // Test prompt formatting + String formatted = connector.formatPrompt("System prompt", "User message"); + assertEquals("[{\"role\":\"system\",\"content\":\"System prompt\"},{\"role\":\"user\",\"content\":\"User message\"}]", formatted); + + // Test response extraction + Map response = Map.of("choices", List.of(Map.of("message", Map.of("content", "AI response")))); + String extracted = connector.extractResponse(response); + assertEquals("AI response", extracted); + } + /** Claude: both prompts are merged into one user message with a single text part; the reply is read from content, text. */ + public void testClaudeConnector() { + LLMConnector connector = new ClaudeConnector(); + + // Test connector type + assertEquals(ConnectorType.CLAUDE, connector.getType()); + + // Test prompt formatting + String formatted = connector.formatPrompt("System prompt", "User message"); + assertEquals("[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"System prompt\\n\\nUser message\"}]}]", formatted); + + // Test response extraction + Map response = Map.of("content", List.of(Map.of("text", "Claude response"))); + String extracted = connector.extractResponse(response); + assertEquals("Claude response", extracted); + } + /** Cohere: the prompt is one plain string whose separator is expected as literal backslash-n pairs; the reply is read from "text". */ + public void testCohereConnector() { + LLMConnector connector = new CohereConnector(); + + // Test connector type + assertEquals(ConnectorType.COHERE, connector.getType()); + + // Test prompt formatting + String formatted = connector.formatPrompt("System prompt", "User message"); + assertEquals("System prompt\\n\\nUser message", formatted); + + // Test response extraction + Map response = 
Map.of("text", "Cohere response"); + String extracted = connector.extractResponse(response); + assertEquals("Cohere response", extracted); + } + /** The factory must return a non-null connector reporting the requested type for OPENAI, CLAUDE, COHERE and DEEPSEEK. */ + public void testConnectorFactory() { + // Test all connector types + LLMConnector openai = LLMConnectorFactory.create(ConnectorType.OPENAI); + assertNotNull(openai); + assertEquals(ConnectorType.OPENAI, openai.getType()); + + LLMConnector claude = LLMConnectorFactory.create(ConnectorType.CLAUDE); + assertNotNull(claude); + assertEquals(ConnectorType.CLAUDE, claude.getType()); + + LLMConnector cohere = LLMConnectorFactory.create(ConnectorType.COHERE); + assertNotNull(cohere); + assertEquals(ConnectorType.COHERE, cohere.getType()); + + LLMConnector deepseek = LLMConnectorFactory.create(ConnectorType.DEEPSEEK); + assertNotNull(deepseek); + assertEquals(ConnectorType.DEEPSEEK, deepseek.getType()); // DeepSeek returns correct type + } + /** Double quotes and a newline inside prompts must come out backslash-escaped in the formatted JSON. */ + public void testJsonEscaping() { + LLMConnector connector = new OpenAIConnector(); + + // Test with special characters that need escaping + String formatted = connector.formatPrompt("System \"quoted\" text", "User's message with \n newline"); + assertEquals( + "[{\"role\":\"system\",\"content\":\"System \\\"quoted\\\" text\"},{\"role\":\"user\",\"content\":\"User's message with \\n newline\"}]", + formatted + ); + } + /** An empty "content" list, or a response without a "content" key, must yield an empty string. */ + public void testEmptyClaudeResponse() { + LLMConnector connector = new ClaudeConnector(); + + // Test empty content array + Map emptyResponse = Map.of("content", List.of()); + String extracted = connector.extractResponse(emptyResponse); + assertEquals("", extracted); + + // Test missing content key (the map has no "content" entry at all) + Map nullResponse = Map.of("other", "value"); + String extractedNull = connector.extractResponse(nullResponse); + assertEquals("", extractedNull); + } + /** DeepSeek: same expected prompt and response shape as the OpenAI test, but the connector reports DEEPSEEK as its type. */ + public void testDeepSeekConnector() { + LLMConnector connector = LLMConnectorFactory.create(ConnectorType.DEEPSEEK); + + // Test connector type (should return DEEPSEEK) + assertEquals(ConnectorType.DEEPSEEK, connector.getType()); + + 
// Test prompt formatting (same as OpenAI format) + String formatted = connector.formatPrompt("System prompt", "User message"); + assertEquals("[{\"role\":\"system\",\"content\":\"System prompt\"},{\"role\":\"user\",\"content\":\"User message\"}]", formatted); + + // Test response extraction (same as OpenAI format) + Map response = Map.of("choices", List.of(Map.of("message", Map.of("content", "DeepSeek response")))); + String extracted = connector.extractResponse(response); + assertEquals("DeepSeek response", extracted); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java b/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java new file mode 100644 index 00000000..9bbc5d3c --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.transport.judgment; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.List; + +import org.opensearch.searchrelevance.ml.connector.ConnectorType; +import org.opensearch.searchrelevance.model.JudgmentType; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.test.OpenSearchTestCase; + /** Tests how PutLlmJudgmentRequest reports the connector type and rate limit passed to its constructor. */ +public class PutLlmJudgmentRequestTests extends OpenSearchTestCase { + /** A null connector type must be reported as OPENAI. */ + public void testConnectorTypeDefaultsToOpenAI() { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + null, // connectorType is null + 1000L + ); + + assertEquals(ConnectorType.OPENAI, request.getConnectorType()); + } + /** An explicit connector type (CLAUDE) must be returned unchanged. */ + public void testConnectorTypeIsPreserved() { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + ConnectorType.CLAUDE, + 1000L + ); + + assertEquals(ConnectorType.CLAUDE, request.getConnectorType()); + } + /** Every ConnectorType value must be returned unchanged by getConnectorType(). */ + public void testAllConnectorTypes() { + for (ConnectorType type : ConnectorType.values()) { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + type, + 1000L + ); + + assertEquals(type, request.getConnectorType()); + assertNotNull(request.getConnectorType()); + } + } + /** A rate limit of 0 is reported as 0; the value is passed explicitly, so no implicit default is exercised here. */ + 
public void testRateLimitDefaultsToZero() { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + ConnectorType.OPENAI, + 0L // explicitly set rateLimit to 0 + ); + + assertEquals(0L, request.getRateLimit()); + } + /** A negative rate limit (-100) must be reported as 0. */ + public void testNegativeRateLimitBecomesZero() { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + ConnectorType.OPENAI, + -100L // negative rateLimit + ); + + assertEquals(0L, request.getRateLimit()); // should be clamped to 0 + } +}