From 286c2e45ecd6b7d7f3a4df5b2e219f3bb54b54a9 Mon Sep 17 00:00:00 2001 From: Fen Qin Date: Thu, 20 Nov 2025 22:23:34 +0000 Subject: [PATCH 1/6] add general llm connector factory Signed-off-by: Fen Qin --- docs/llm-model/claude/README.md | 43 ++++++ docs/llm-model/claude/connector_validate.sh | 102 ++++++++++++++ docs/llm-model/cohere/README.md | 68 ++++++++++ docs/llm-model/cohere/connector_validate.sh | 105 +++++++++++++++ docs/llm-model/deepseek/README.md | 76 +++++++++++ docs/llm-model/deepseek/connector_validate.sh | 101 ++++++++++++++ docs/llm-model/openai/README.md | 82 +++++++++++ docs/llm-model/openai/connector_validate.sh | 104 ++++++++++++++ .../ml/MLInputOutputTransformer.java | 24 ++-- .../ml/connector/ClaudeConnector.java | 45 +++++++ .../ml/connector/CohereConnector.java | 33 +++++ .../ml/connector/ConnectorType.java | 33 +++++ .../ml/connector/DeepSeekConnector.java | 19 +++ .../ml/connector/LLMConnector.java | 40 ++++++ .../ml/connector/LLMConnectorFactory.java | 30 +++++ .../ml/connector/OpenAIConnector.java | 41 ++++++ ...LInputOutputTransformerConnectorTests.java | 123 +++++++++++++++++ .../ml/connector/LLMConnectorTests.java | 127 ++++++++++++++++++ 18 files changed, 1187 insertions(+), 9 deletions(-) create mode 100644 docs/llm-model/claude/README.md create mode 100644 docs/llm-model/claude/connector_validate.sh create mode 100644 docs/llm-model/cohere/README.md create mode 100644 docs/llm-model/cohere/connector_validate.sh create mode 100644 docs/llm-model/deepseek/README.md create mode 100644 docs/llm-model/deepseek/connector_validate.sh create mode 100644 docs/llm-model/openai/README.md create mode 100644 docs/llm-model/openai/connector_validate.sh create mode 100644 src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java create mode 100644 src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java create mode 100644 src/main/java/org/opensearch/searchrelevance/ml/connector/ConnectorType.java create 
mode 100644 src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java create mode 100644 src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java create mode 100644 src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorFactory.java create mode 100644 src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java create mode 100644 src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerConnectorTests.java create mode 100644 src/test/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorTests.java diff --git a/docs/llm-model/claude/README.md b/docs/llm-model/claude/README.md new file mode 100644 index 00000000..7af71a01 --- /dev/null +++ b/docs/llm-model/claude/README.md @@ -0,0 +1,43 @@ +# Claude 3.5 Haiku OpenSearch ML Connector + +This directory contains the setup and validation scripts for integrating Claude 3.5 Haiku with OpenSearch ML for search relevance rating. + +## Prerequisites + +- OpenSearch running on `http://localhost:9200` +- AWS credentials configured (`aws configure`) +- `jq` installed for JSON parsing + +## Quick Start + +```bash +chmod +x connector_validate.sh +./connector_validate.sh +``` + +## Configuration + +### Model Details +- **Model**: `us.anthropic.claude-3-5-haiku-20241022-v1:0` (inference profile) +- **Region**: `us-east-1` +- **Max Tokens**: 4000 +- **Protocol**: AWS SigV4 + +### Message Format +```json +{ + "parameters": { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Your prompt here" + } + ] + } + ] + } +} +``` \ No newline at end of file diff --git a/docs/llm-model/claude/connector_validate.sh b/docs/llm-model/claude/connector_validate.sh new file mode 100644 index 00000000..f1ac8c39 --- /dev/null +++ b/docs/llm-model/claude/connector_validate.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +# Claude 3.5 Haiku Connector Validation Script +set -e + +OPENSEARCH_URL="http://localhost:9200" 
+CONNECTOR_NAME="claude-3-5-haiku-working" +MODEL_NAME="claude-3-5-haiku-working-model" + +# Get AWS credentials +export AWS_ACCESS_KEY_ID=$(aws configure get aws_access_key_id) +export AWS_SECRET_ACCESS_KEY=$(aws configure get aws_secret_access_key) +export AWS_SESSION_TOKEN=$(aws configure get aws_session_token) + +echo "Creating Claude 3.5 Haiku connector..." + +# Create connector +CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/_create" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${CONNECTOR_NAME}'", + "description": "Claude 3.5 Haiku connector for search relevance rating", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "'$AWS_ACCESS_KEY_ID'", + "secret_key": "'$AWS_SECRET_ACCESS_KEY'", + "session_token": "'$AWS_SESSION_TOKEN'" + }, + "parameters": { + "region": "us-east-1", + "service_name": "bedrock", + "model": "us.anthropic.claude-3-5-haiku-20241022-v1:0" + }, + "client_config": { + "max_connection": 1, + "connection_timeout": 60000, + "read_timeout": 60000, + "retry_backoff_millis": 3000, + "retry_timeout_seconds": 60, + "max_retry_times": 2 + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{\"anthropic_version\": \"bedrock-2023-05-31\", \"max_tokens\": 2000, \"messages\": ${parameters.messages}}" + } + ] +}') + +CONNECTOR_ID=$(echo $CONNECTOR_RESPONSE | jq -r '.connector_id') +echo "Connector created with ID: $CONNECTOR_ID" + +# Register model +echo "Registering model..." 
+MODEL_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/_register" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${MODEL_NAME}'", + "function_name": "remote", + "description": "Claude 3.5 Haiku model for search relevance rating", + "connector_id": "'$CONNECTOR_ID'" +}') + +MODEL_ID=$(echo $MODEL_RESPONSE | jq -r '.model_id') +echo "Model registered with ID: $MODEL_ID" + +# Deploy model +echo "Deploying model..." +curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_deploy" > /dev/null +echo "Model deployed successfully" + +# Test prediction +echo "Testing prediction..." +TEST_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Rate search relevance 0.0-1.0. Return JSON only: [{\"id\":\"001\",\"rating_score\":0.9}]. Rate ALL hits.\n\nSearchText - banana; Reference - banana smoothie; Hits - [{\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana\", \"price\": 1.99, \"description\": \"this is a banana\"}, \"_id\": \"003\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"apple\", \"price\": 0.99, \"description\": \"fresh apple\"}, \"_id\": \"004\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana smoothie\", \"price\": 3.99, \"description\": \"fresh banana smoothie\"}, \"_id\": \"005\"}]" + } + ] + } + ] + } +}') + +echo "Test completed successfully!" 
+echo "Response: $(echo $TEST_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.content[0].text')" +echo "" +echo "Connector ID: $CONNECTOR_ID" +echo "Model ID: $MODEL_ID" \ No newline at end of file diff --git a/docs/llm-model/cohere/README.md b/docs/llm-model/cohere/README.md new file mode 100644 index 00000000..65007e7f --- /dev/null +++ b/docs/llm-model/cohere/README.md @@ -0,0 +1,68 @@ +# Cohere Command R OpenSearch ML Connector + +This directory contains the setup and validation scripts for integrating Cohere Command R with OpenSearch ML via AWS Bedrock. + +## Prerequisites + +- OpenSearch running on `http://localhost:9200` +- AWS credentials configured (`aws configure`) +- `jq` installed for JSON parsing + +## Quick Start + +```bash +chmod +x connector_validate.sh +./connector_validate.sh +``` + +## Configuration + +### Model Details +- **Model**: `cohere.command-r-v1:0` (via Bedrock) +- **Region**: `us-east-1` +- **Max Tokens**: 1000 +- **Protocol**: AWS SigV4 +- **Streaming**: Supported + +### Message Format +```json +{ + "parameters": { + "message": "Your message here" + } +} +``` + +## Available Cohere Models + +### Chat Models (via Bedrock) +- `cohere.command-r-v1:0` - Command R +- `cohere.command-r-plus-v1:0` - Command R+ (more capable) + +### Embedding Models +- `cohere.embed-v4:0` - Embed v4 (multimodal) +- `cohere.embed-english-v3` - English embeddings +- `cohere.embed-multilingual-v3` - Multilingual embeddings + +### Rerank Model +- `cohere.rerank-v3-5:0` - Rerank 3.5 + +## Response Format + +```json +{ + "inference_results": [{ + "output": [{ + "name": "response", + "dataAsMap": { + "response_id": "...", + "text": "Model response here", + "generation_id": "...", + "chat_history": [...], + "finish_reason": "COMPLETE" + } + }], + "status_code": 200 + }] +} +``` \ No newline at end of file diff --git a/docs/llm-model/cohere/connector_validate.sh b/docs/llm-model/cohere/connector_validate.sh new file mode 100644 index 00000000..09685e87 --- 
/dev/null +++ b/docs/llm-model/cohere/connector_validate.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +# Cohere Command R Connector Validation Script +set -e + +OPENSEARCH_URL="http://localhost:9200" +CONNECTOR_NAME="cohere-command-r-bedrock" +MODEL_NAME="cohere-command-r-bedrock-model" + +# Get AWS credentials +export AWS_ACCESS_KEY_ID=$(aws configure get aws_access_key_id) +export AWS_SECRET_ACCESS_KEY=$(aws configure get aws_secret_access_key) +export AWS_SESSION_TOKEN=$(aws configure get aws_session_token) + +echo "Creating Cohere Command R connector via Bedrock..." + +# Create connector +CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/_create" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${CONNECTOR_NAME}'", + "description": "Cohere Command R via Bedrock for chat", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "'$AWS_ACCESS_KEY_ID'", + "secret_key": "'$AWS_SECRET_ACCESS_KEY'", + "session_token": "'$AWS_SESSION_TOKEN'" + }, + "parameters": { + "region": "us-east-1", + "service_name": "bedrock", + "model": "cohere.command-r-v1:0" + }, + "client_config": { + "max_connection": 1, + "connection_timeout": 60000, + "read_timeout": 60000, + "retry_backoff_millis": 3000, + "retry_timeout_seconds": 60, + "max_retry_times": 2 + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{\"message\": \"${parameters.message}\", \"max_tokens\": 1000}" + } + ] +}') + +CONNECTOR_ID=$(echo $CONNECTOR_RESPONSE | jq -r '.connector_id') +echo "Connector created with ID: $CONNECTOR_ID" + +# Register model +echo "Registering model..." 
+MODEL_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/_register" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${MODEL_NAME}'", + "function_name": "remote", + "description": "Cohere Command R model via Bedrock", + "connector_id": "'$CONNECTOR_ID'" +}') + +MODEL_ID=$(echo $MODEL_RESPONSE | jq -r '.model_id') +echo "Model registered with ID: $MODEL_ID" + +# Deploy model +echo "Deploying model..." +curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_deploy" > /dev/null +echo "Model deployed successfully" + +# Test basic chat +echo "Testing basic chat..." +CHAT_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "message": "Hello, respond with just the word success" + } +}') + +echo "Basic chat test completed!" +echo "Response: $(echo $CHAT_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.text')" + +# Test search relevance rating +echo "Testing search relevance rating..." +RATING_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "message": "Rate search relevance 0.0-1.0. Return JSON only: [{\"id\":\"001\",\"rating_score\":0.9}]. Rate ALL hits.\n\nSearchText - banana; Reference - banana smoothie; Hits - [{\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana\", \"price\": 1.99, \"description\": \"this is a banana\"}, \"_id\": \"003\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"apple\", \"price\": 0.99, \"description\": \"fresh apple\"}, \"_id\": \"004\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana smoothie\", \"price\": 3.99, \"description\": \"fresh banana smoothie\"}, \"_id\": \"005\"}]" + } +}') + +echo "Search relevance rating test completed!" 
+echo "Response: $(echo $RATING_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.text')" +echo "" +echo "Connector ID: $CONNECTOR_ID" +echo "Model ID: $MODEL_ID" \ No newline at end of file diff --git a/docs/llm-model/deepseek/README.md b/docs/llm-model/deepseek/README.md new file mode 100644 index 00000000..418055f3 --- /dev/null +++ b/docs/llm-model/deepseek/README.md @@ -0,0 +1,76 @@ +# DeepSeek Chat OpenSearch ML Connector + +This directory contains the setup and validation scripts for integrating DeepSeek Chat with OpenSearch ML. + +## Prerequisites + +- OpenSearch running on `http://localhost:9200` +- DeepSeek API key +- `jq` installed for JSON parsing + +## Quick Start + +```bash +chmod +x connector_validate.sh +./connector_validate.sh +``` + +## Configuration + +### Model Details +- **Model**: `deepseek-chat` +- **Endpoint**: `api.deepseek.com` +- **Protocol**: HTTP +- **API Version**: v1 + +### Message Format +```json +{ + "parameters": { + "messages": [ + { + "role": "user", + "content": "Your message here" + } + ] + } +} +``` + +## Available Models + +### Chat Models +- `deepseek-chat` - General purpose chat model +- `deepseek-coder` - Code-focused model (if available) + +## Response Format + +```json +{ + "inference_results": [{ + "output": [{ + "name": "response", + "dataAsMap": { + "id": "...", + "object": "chat.completion", + "created": 1761688945, + "model": "deepseek-chat", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Model response here" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 12, + "completion_tokens": 1, + "total_tokens": 13 + } + } + }], + "status_code": 200 + }] +} +``` \ No newline at end of file diff --git a/docs/llm-model/deepseek/connector_validate.sh b/docs/llm-model/deepseek/connector_validate.sh new file mode 100644 index 00000000..fbbfd390 --- /dev/null +++ b/docs/llm-model/deepseek/connector_validate.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +# DeepSeek Chat 
Connector Validation Script +set -e + +OPENSEARCH_URL="http://localhost:9200" +CONNECTOR_NAME="DeepSeek Chat" +MODEL_NAME="deepseek-chat-model" +DEEPSEEK_API_KEY="" + +echo "Creating DeepSeek Chat connector..." + +# Create connector +CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/_create" \ +-H "Content-Type: application/json" \ +-d "{ + \"name\": \"${CONNECTOR_NAME}\", + \"description\": \"DeepSeek Chat connector for conversational AI\", + \"version\": \"1\", + \"protocol\": \"http\", + \"parameters\": { + \"endpoint\": \"api.deepseek.com\", + \"model\": \"deepseek-chat\" + }, + \"credential\": { + \"deepSeek_key\": \"${DEEPSEEK_API_KEY}\" + }, + \"actions\": [ + { + \"action_type\": \"predict\", + \"method\": \"POST\", + \"url\": \"https://\${parameters.endpoint}/v1/chat/completions\", + \"headers\": { + \"Content-Type\": \"application/json\", + \"Authorization\": \"Bearer \${credential.deepSeek_key}\" + }, + \"request_body\": \"{ \\\"model\\\": \\\"\${parameters.model}\\\", \\\"messages\\\": \${parameters.messages} }\" + } + ] +}") + +CONNECTOR_ID=$(echo $CONNECTOR_RESPONSE | jq -r '.connector_id') +echo "Connector created with ID: $CONNECTOR_ID" + +# Register model +echo "Registering model..." +MODEL_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/_register" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${MODEL_NAME}'", + "function_name": "remote", + "description": "DeepSeek chat model for conversational AI", + "connector_id": "'$CONNECTOR_ID'" +}') + +MODEL_ID=$(echo $MODEL_RESPONSE | jq -r '.model_id') +echo "Model registered with ID: $MODEL_ID" + +# Deploy model +echo "Deploying model..." +curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_deploy" > /dev/null +echo "Model deployed successfully" + +# Test basic chat +echo "Testing basic chat..." 
+CHAT_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "user", + "content": "Hello, respond with just the word success" + } + ] + } +}') + +echo "Basic chat test completed!" +echo "Response: $(echo $CHAT_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.choices[0].message.content')" + +# Test search relevance rating +echo "Testing search relevance rating..." +RATING_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "user", + "content": "Rate search relevance 0.0-1.0. Return JSON only: [{\"id\":\"001\",\"rating_score\":0.9}]. Rate ALL hits.\n\nSearchText - banana; Reference - banana smoothie; Hits - [{\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana\", \"price\": 1.99, \"description\": \"this is a banana\"}, \"_id\": \"003\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"apple\", \"price\": 0.99, \"description\": \"fresh apple\"}, \"_id\": \"004\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana smoothie\", \"price\": 3.99, \"description\": \"fresh banana smoothie\"}, \"_id\": \"005\"}]" + } + ] + } +}') + +echo "Search relevance rating test completed!" +echo "Response: $(echo $RATING_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.choices[0].message.content')" +echo "" +echo "Connector ID: $CONNECTOR_ID" +echo "Model ID: $MODEL_ID" \ No newline at end of file diff --git a/docs/llm-model/openai/README.md b/docs/llm-model/openai/README.md new file mode 100644 index 00000000..23ee0edf --- /dev/null +++ b/docs/llm-model/openai/README.md @@ -0,0 +1,82 @@ +# OpenAI GPT OpenSearch ML Connector + +This directory contains the setup and validation scripts for integrating OpenAI GPT models with OpenSearch ML. 
+ +## Prerequisites + +- OpenSearch running on `http://localhost:9200` +- Valid OpenAI API key +- `jq` installed for JSON parsing + +## Quick Start + +```bash +# Update API key in connector_validate.sh +chmod +x connector_validate.sh +./connector_validate.sh +``` + +## Configuration + +### Model Details +- **Model**: `gpt-5-nano` (configurable) +- **Endpoint**: `api.openai.com` +- **Protocol**: HTTP +- **Action Type**: `/v1/chat/completion` + +### Available Models +- `gpt-4o` - Latest GPT-4 Omni model +- `gpt-4o-mini` - Smaller, faster GPT-4 variant +- `gpt-4-turbo` - GPT-4 Turbo +- `gpt-3.5-turbo` - GPT-3.5 Turbo +- `gpt-5-nano` - GPT-5 Nano (if available) + +### Message Format +```json +{ + "parameters": { + "messages": [ + { + "role": "system", + "content": "System prompt here" + }, + { + "role": "user", + "content": "User message here" + } + ] + } +} +``` + +## Response Format + +```json +{ + "inference_results": [{ + "output": [{ + "name": "response", + "dataAsMap": { + "id": "chatcmpl-...", + "object": "chat.completion", + "created": 1761689000, + "model": "gpt-5-nano", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Model response here" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150 + } + } + }], + "status_code": 200 + }] +} +``` diff --git a/docs/llm-model/openai/connector_validate.sh b/docs/llm-model/openai/connector_validate.sh new file mode 100644 index 00000000..d40605fd --- /dev/null +++ b/docs/llm-model/openai/connector_validate.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +# OpenAI GPT Connector Validation Script +set -e + +OPENSEARCH_URL="http://localhost:9200" +CONNECTOR_NAME="mfenqin-batch-test" +MODEL_NAME="openai-gpt-model" +OPENAI_API_KEY="" + +echo "Creating OpenAI GPT connector..." 
+ +# Create connector +CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/_create" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${CONNECTOR_NAME}'", + "description": "OpenAI GPT connector for search relevance rating", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-5-nano" + }, + "credential": { + "openAI_key": "'${OPENAI_API_KEY}'" + }, + "actions": [ + { + "action_type": "batch_predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions", + "headers": { + "Authorization": "Bearer ${credential.openAI_key}" + }, + "request_body": "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }" + } + ] +}') + +CONNECTOR_ID=$(echo $CONNECTOR_RESPONSE | jq -r '.connector_id') +echo "Connector created with ID: $CONNECTOR_ID" + +# Register model +echo "Registering model..." +MODEL_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/_register" \ +-H "Content-Type: application/json" \ +-d '{ + "name": "'${MODEL_NAME}'", + "function_name": "remote", + "description": "OpenAI GPT model for search relevance rating", + "connector_id": "'$CONNECTOR_ID'" +}') + +MODEL_ID=$(echo $MODEL_RESPONSE | jq -r '.model_id') +echo "Model registered with ID: $MODEL_ID" + +# Deploy model +echo "Deploying model..." +curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_deploy" > /dev/null +echo "Model deployed successfully" + +# Test basic chat +echo "Testing basic chat..." +CHAT_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "user", + "content": "Hello, respond with just the word success" + } + ] + } +}') + +echo "Basic chat test completed!" 
+echo "Response: $(echo $CHAT_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.choices[0].message.content')" + +# Test search relevance rating +echo "Testing search relevance rating..." +RATING_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/models/${MODEL_ID}/_predict" \ +-H "Content-Type: application/json" \ +-d '{ + "parameters": { + "messages": [ + { + "role": "system", + "content": "You are an expert search relevance rater. Your task is to evaluate the relevance between search query and results with these criteria:\n- Score 1.0: Perfect match, highly relevant\n- Score 0.7-0.9: Very relevant with minor variations\n- Score 0.4-0.6: Moderately relevant\n- Score 0.1-0.3: Slightly relevant\n- Score 0.0: Completely irrelevant\nEvaluate based on: exact matches, semantic relevance, and overall context between the SearchText and content in Hits.\nWhen a reference is provided, evaluate based on the relevance to both SearchText and its reference.\n\nIMPORTANT: Provide your response ONLY as a JSON array of objects, each with \"id\" and \"rating_score\" fields. You MUST include a rating for EVERY hit provided, even if the rating is 0. Do not include any explanation or additional text. Example format: [{\"id\": \"001\", \"rating_score\": 0.9}, {\"id\": \"002\", \"rating_score\": 0.5}, {\"id\": \"003\", \"rating_score\": 0.0}]" + }, + { + "role": "user", + "content": "SearchText - banana; Reference - banana smoothie; Hits - [{\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana\", \"price\": 1.99, \"description\": \"this is a banana\"}, \"_id\": \"003\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"apple\", \"price\": 0.99, \"description\": \"fresh apple\"}, \"_id\": \"004\"}, {\"_index\": \"sample_index03\", \"_source\": {\"name\": \"banana smoothie\", \"price\": 3.99, \"description\": \"fresh banana smoothie\"}, \"_id\": \"005\"}]" + } + ] + } +}') + +echo "Search relevance rating test completed!" 
+echo "Response: $(echo $RATING_RESPONSE | jq -r '.inference_results[0].output[0].dataAsMap.choices[0].message.content')" +echo "" +echo "Connector ID: $CONNECTOR_ID" +echo "Model ID: $MODEL_ID" diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java index c1141521..d0375d48 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java @@ -8,17 +8,12 @@ package org.opensearch.searchrelevance.ml; import static org.opensearch.searchrelevance.common.MLConstants.PARAM_MESSAGES_FIELD; -import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_JSON_MESSAGES_SHELL; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_0_1_START; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_BINARY; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_END; import static org.opensearch.searchrelevance.common.MLConstants.RATING_SCORE_BINARY_SCHEMA; import static org.opensearch.searchrelevance.common.MLConstants.RATING_SCORE_NUMERIC_SCHEMA; -import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CHOICES_FIELD; -import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CONTENT_FIELD; import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_FORMAT_TEMPLATE; -import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_MESSAGE_FIELD; -import static org.opensearch.searchrelevance.common.MLConstants.escapeJson; import java.io.IOException; import java.util.ArrayList; @@ -37,6 +32,8 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import 
org.opensearch.searchrelevance.ml.connector.LLMConnector; +import org.opensearch.searchrelevance.ml.connector.OpenAIConnector; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import lombok.extern.log4j.Log4j2; @@ -47,6 +44,17 @@ @Log4j2 public class MLInputOutputTransformer { + private final LLMConnector connector; + + public MLInputOutputTransformer() { + // Default to OpenAI for backward compatibility + this.connector = new OpenAIConnector(); + } + + public MLInputOutputTransformer(LLMConnector connector) { + this.connector = connector; + } + public List createMLInputs( int tokenLimit, String searchText, @@ -155,7 +163,7 @@ private String buildMessagesArray( String hitsJson = buildHitsJson(hits); String userContent = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, promptTemplate); String systemPrompt = getSystemPrompt(ratingType); - return String.format(Locale.ROOT, PROMPT_JSON_MESSAGES_SHELL, systemPrompt, escapeJson(userContent)); + return connector.formatPrompt(systemPrompt, userContent); } catch (IOException e) { log.error("Error converting hits to JSON string", e); throw new IllegalArgumentException("Failed to process hits", e); @@ -221,8 +229,6 @@ public String extractResponseContent(MLOutput mlOutput) { ModelTensor tensor = tensorOutputList.get(0).getMlModelTensors().get(0); Map dataMap = tensor.getDataAsMap(); - Map choices = (Map) ((List) dataMap.get(RESPONSE_CHOICES_FIELD)).get(0); - Map message = (Map) choices.get(RESPONSE_MESSAGE_FIELD); - return (String) message.get(RESPONSE_CONTENT_FIELD); + return connector.extractResponse(dataMap); } } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java new file mode 100644 index 00000000..b4695087 --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: 
Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.ml.connector; + +import static org.opensearch.searchrelevance.common.MLConstants.escapeJson; + +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Claude-specific connector implementation + */ +public class ClaudeConnector implements LLMConnector { + + @Override + public String formatPrompt(String systemContent, String userContent) { + String combinedContent = systemContent + "\n\n" + userContent; + return String.format( + Locale.ROOT, + "[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"%s\"}]}]", + escapeJson(combinedContent) + ); + } + + @Override + public String extractResponse(Map rawResponse) { + List content = (List) rawResponse.get("content"); + if (content != null && !content.isEmpty()) { + Map textContent = (Map) content.get(0); + return (String) textContent.get("text"); + } + return ""; + } + + @Override + public ConnectorType getType() { + return ConnectorType.CLAUDE; + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java new file mode 100644 index 00000000..4a9d9531 --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml.connector; + +import static org.opensearch.searchrelevance.common.MLConstants.escapeJson; + +import java.util.Map; + +/** + * Cohere-specific connector implementation + */ +public class CohereConnector implements LLMConnector { + + @Override + public String formatPrompt(String systemContent, String userContent) { + return escapeJson(systemContent) + "\\n\\n" + escapeJson(userContent); + } + + @Override + public String extractResponse(Map rawResponse) { + return (String) rawResponse.get("text"); + } + + @Override + public ConnectorType getType() { + return ConnectorType.COHERE; + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/ConnectorType.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/ConnectorType.java new file mode 100644 index 00000000..0340e94c --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/ConnectorType.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml.connector; + +/** + * Enum representing different LLM connector types + */ +public enum ConnectorType { + OPENAI("openai"), + CLAUDE("claude"), + COHERE("cohere"), + DEEPSEEK("deepseek"); + + private final String value; + + ConnectorType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return value; + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java new file mode 100644 index 00000000..8dddfaa2 --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.ml.connector; + +/** + * DeepSeek-specific connector implementation (uses OpenAI-compatible format) + */ +public class DeepSeekConnector extends OpenAIConnector { + + @Override + public ConnectorType getType() { + return ConnectorType.DEEPSEEK; + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java new file mode 100644 index 00000000..b8c28185 --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml.connector; + +import java.util.Map; + +/** + * Interface for LLM connector implementations that handle different LLM providers + */ +public interface LLMConnector { + + /** + * Formats the prompt according to the specific LLM provider's requirements + * + * @param systemContent The system message content + * @param userContent The user message content + * @return Formatted prompt string ready for the LLM API + */ + String formatPrompt(String systemContent, String userContent); + + /** + * Extracts the response text from the raw LLM API response + * + * @param rawResponse The raw response map from the LLM API + * @return Extracted response text + */ + String extractResponse(Map rawResponse); + + /** + * Returns the connector type + * + * @return The ConnectorType enum value + */ + ConnectorType getType(); +} diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorFactory.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorFactory.java new file mode 100644 index 00000000..a30b7d48 --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorFactory.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml.connector; + +/** + * Factory class for creating LLM connector instances + */ +public class LLMConnectorFactory { + + /** + * Creates an LLM connector instance based on the connector type + * + * @param type The connector type + * @return LLMConnector instance + * @throws IllegalArgumentException if connector type is not supported + */ + public static LLMConnector create(ConnectorType type) { + return switch (type) { + case OPENAI -> new OpenAIConnector(); + case CLAUDE -> new ClaudeConnector(); + case COHERE -> new CohereConnector(); + case DEEPSEEK -> new DeepSeekConnector(); + }; + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java new file mode 100644 index 00000000..33963286 --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml.connector; + +import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_JSON_MESSAGES_SHELL; +import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CHOICES_FIELD; +import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CONTENT_FIELD; +import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_MESSAGE_FIELD; +import static org.opensearch.searchrelevance.common.MLConstants.escapeJson; + +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * OpenAI-specific connector implementation + */ +public class OpenAIConnector implements LLMConnector { + + @Override + public String formatPrompt(String systemContent, String userContent) { + return String.format(Locale.ROOT, PROMPT_JSON_MESSAGES_SHELL, escapeJson(systemContent), escapeJson(userContent)); + } + + @Override + public String extractResponse(Map rawResponse) { + Map choices = (Map) ((List) rawResponse.get(RESPONSE_CHOICES_FIELD)).get(0); + Map message = (Map) choices.get(RESPONSE_MESSAGE_FIELD); + return (String) message.get(RESPONSE_CONTENT_FIELD); + } + + @Override + public ConnectorType getType() { + return ConnectorType.OPENAI; + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerConnectorTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerConnectorTests.java new file mode 100644 index 00000000..0cc26cfc --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerConnectorTests.java @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; + +import org.mockito.stubbing.Answer; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchrelevance.ml.connector.ClaudeConnector; +import org.opensearch.searchrelevance.ml.connector.CohereConnector; +import org.opensearch.searchrelevance.ml.connector.DeepSeekConnector; +import org.opensearch.searchrelevance.ml.connector.LLMConnector; +import org.opensearch.searchrelevance.ml.connector.OpenAIConnector; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.test.OpenSearchTestCase; + +public class MLInputOutputTransformerConnectorTests extends OpenSearchTestCase { + + public void testDefaultConstructorUsesOpenAI() { + MLInputOutputTransformer transformer = new MLInputOutputTransformer(); + assertNotNull(transformer); + // Default behavior should work (uses OpenAI connector internally) + } + + public void testConnectorConstructor() { + LLMConnector claudeConnector = new ClaudeConnector(); + MLInputOutputTransformer transformer = new MLInputOutputTransformer(claudeConnector); + assertNotNull(transformer); + } + + public void testExtractResponseContentWithDifferentConnectors() { + // Test OpenAI response extraction + testResponseExtraction(new OpenAIConnector(), createOpenAIResponse(), "OpenAI response"); + + // Test Claude response extraction + testResponseExtraction(new ClaudeConnector(), createClaudeResponse(), "Claude response"); + + // Test Cohere response extraction + testResponseExtraction(new CohereConnector(), createCohereResponse(), "Cohere 
response"); + + // Test DeepSeek response extraction (uses OpenAI format but returns DEEPSEEK type) + testResponseExtraction(new DeepSeekConnector(), createDeepSeekResponse(), "DeepSeek response"); + } + + private void testResponseExtraction(LLMConnector connector, MLOutput mlOutput, String expectedResponse) { + MLInputOutputTransformer transformer = new MLInputOutputTransformer(connector); + String result = transformer.extractResponseContent(mlOutput); + assertEquals(expectedResponse, result); + } + + private MLOutput createOpenAIResponse() { + return createMockMLOutput(Map.of("choices", List.of(Map.of("message", Map.of("content", "OpenAI response"))))); + } + + private MLOutput createClaudeResponse() { + return createMockMLOutput(Map.of("content", List.of(Map.of("text", "Claude response")))); + } + + private MLOutput createCohereResponse() { + return createMockMLOutput(Map.of("text", "Cohere response")); + } + + private MLOutput createDeepSeekResponse() { + return createMockMLOutput(Map.of("choices", List.of(Map.of("message", Map.of("content", "DeepSeek response"))))); + } + + @SuppressWarnings("unchecked") + private MLOutput createMockMLOutput(Map dataMap) { + ModelTensor tensor = mock(ModelTensor.class); + when(tensor.getDataAsMap()).thenAnswer((Answer>) invocation -> dataMap); + + ModelTensors modelTensors = mock(ModelTensors.class); + when(modelTensors.getMlModelTensors()).thenReturn(List.of(tensor)); + + ModelTensorOutput output = mock(ModelTensorOutput.class); + when(output.getMlModelOutputs()).thenReturn(List.of(modelTensors)); + + return output; + } + + public void testCreateMLInputWithDifferentConnectors() { + String searchText = "test query"; + Map referenceData = Map.of("ref", "reference"); + Map hits = Map.of("hit1", "result1"); + String promptTemplate = "{{searchText}} {{hits}}"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.RELEVANT_IRRELEVANT; + + // Test with different connectors + testMLInputCreation(new OpenAIConnector(), 
searchText, referenceData, hits, promptTemplate, ratingType); + testMLInputCreation(new ClaudeConnector(), searchText, referenceData, hits, promptTemplate, ratingType); + testMLInputCreation(new CohereConnector(), searchText, referenceData, hits, promptTemplate, ratingType); + testMLInputCreation(new DeepSeekConnector(), searchText, referenceData, hits, promptTemplate, ratingType); + } + + private void testMLInputCreation( + LLMConnector connector, + String searchText, + Map referenceData, + Map hits, + String promptTemplate, + LLMJudgmentRatingType ratingType + ) { + MLInputOutputTransformer transformer = new MLInputOutputTransformer(connector); + + // This should not throw an exception + var mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType); + assertNotNull(mlInput); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorTests.java b/src/test/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorTests.java new file mode 100644 index 00000000..6fda4211 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/connector/LLMConnectorTests.java @@ -0,0 +1,127 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml.connector; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.List; +import java.util.Map; + +import org.opensearch.test.OpenSearchTestCase; + +public class LLMConnectorTests extends OpenSearchTestCase { + + public void testOpenAIConnector() { + LLMConnector connector = new OpenAIConnector(); + + // Test connector type + assertEquals(ConnectorType.OPENAI, connector.getType()); + + // Test prompt formatting + String formatted = connector.formatPrompt("System prompt", "User message"); + assertEquals("[{\"role\":\"system\",\"content\":\"System prompt\"},{\"role\":\"user\",\"content\":\"User message\"}]", formatted); + + // Test response extraction + Map response = Map.of("choices", List.of(Map.of("message", Map.of("content", "AI response")))); + String extracted = connector.extractResponse(response); + assertEquals("AI response", extracted); + } + + public void testClaudeConnector() { + LLMConnector connector = new ClaudeConnector(); + + // Test connector type + assertEquals(ConnectorType.CLAUDE, connector.getType()); + + // Test prompt formatting + String formatted = connector.formatPrompt("System prompt", "User message"); + assertEquals("[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"System prompt\\n\\nUser message\"}]}]", formatted); + + // Test response extraction + Map response = Map.of("content", List.of(Map.of("text", "Claude response"))); + String extracted = connector.extractResponse(response); + assertEquals("Claude response", extracted); + } + + public void testCohereConnector() { + LLMConnector connector = new CohereConnector(); + + // Test connector type + assertEquals(ConnectorType.COHERE, connector.getType()); + + // Test prompt formatting + String formatted = connector.formatPrompt("System prompt", "User message"); + assertEquals("System prompt\\n\\nUser message", formatted); + + // Test response extraction + Map response = 
Map.of("text", "Cohere response"); + String extracted = connector.extractResponse(response); + assertEquals("Cohere response", extracted); + } + + public void testConnectorFactory() { + // Test all connector types + LLMConnector openai = LLMConnectorFactory.create(ConnectorType.OPENAI); + assertNotNull(openai); + assertEquals(ConnectorType.OPENAI, openai.getType()); + + LLMConnector claude = LLMConnectorFactory.create(ConnectorType.CLAUDE); + assertNotNull(claude); + assertEquals(ConnectorType.CLAUDE, claude.getType()); + + LLMConnector cohere = LLMConnectorFactory.create(ConnectorType.COHERE); + assertNotNull(cohere); + assertEquals(ConnectorType.COHERE, cohere.getType()); + + LLMConnector deepseek = LLMConnectorFactory.create(ConnectorType.DEEPSEEK); + assertNotNull(deepseek); + assertEquals(ConnectorType.DEEPSEEK, deepseek.getType()); // DeepSeek returns correct type + } + + public void testJsonEscaping() { + LLMConnector connector = new OpenAIConnector(); + + // Test with special characters that need escaping + String formatted = connector.formatPrompt("System \"quoted\" text", "User's message with \n newline"); + assertEquals( + "[{\"role\":\"system\",\"content\":\"System \\\"quoted\\\" text\"},{\"role\":\"user\",\"content\":\"User's message with \\n newline\"}]", + formatted + ); + } + + public void testEmptyClaudeResponse() { + LLMConnector connector = new ClaudeConnector(); + + // Test empty content array + Map emptyResponse = Map.of("content", List.of()); + String extracted = connector.extractResponse(emptyResponse); + assertEquals("", extracted); + + // Test null content + Map nullResponse = Map.of("other", "value"); + String extractedNull = connector.extractResponse(nullResponse); + assertEquals("", extractedNull); + } + + public void testDeepSeekConnector() { + LLMConnector connector = LLMConnectorFactory.create(ConnectorType.DEEPSEEK); + + // Test connector type (should return DEEPSEEK) + assertEquals(ConnectorType.DEEPSEEK, connector.getType()); + + 
// Test prompt formatting (same as OpenAI format) + String formatted = connector.formatPrompt("System prompt", "User message"); + assertEquals("[{\"role\":\"system\",\"content\":\"System prompt\"},{\"role\":\"user\",\"content\":\"User message\"}]", formatted); + + // Test response extraction (same as OpenAI format) + Map response = Map.of("choices", List.of(Map.of("message", Map.of("content", "DeepSeek response")))); + String extracted = connector.extractResponse(response); + assertEquals("DeepSeek response", extracted); + } +} From e9b373998de93f47ee05960a5d5c59282ca000f1 Mon Sep 17 00:00:00 2001 From: Fen Qin Date: Thu, 20 Nov 2025 23:36:08 +0000 Subject: [PATCH 2/6] parse conntectType from llm judgment rest api Signed-off-by: Fen Qin --- .../searchrelevance/common/MLConstants.java | 1 + .../judgments/LlmJudgmentContext.java | 154 +++++++++++++++ .../judgments/LlmJudgmentsProcessor.java | 185 +++++++----------- .../searchrelevance/ml/MLAccessor.java | 19 +- .../rest/RestPutJudgmentAction.java | 30 ++- .../judgment/PutJudgmentTransportAction.java | 2 + .../judgment/PutLlmJudgmentRequest.java | 17 +- .../judgment/PutJudgmentActionTests.java | 7 +- .../ml/MLAccessorIntegrationTests.java | 2 + .../judgment/PutLlmJudgmentRequestTests.java | 87 ++++++++ 10 files changed, 384 insertions(+), 120 deletions(-) create mode 100644 src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentContext.java create mode 100644 src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java diff --git a/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java b/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java index 9a21b9ef..61ad71f5 100644 --- a/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java +++ b/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java @@ -23,6 +23,7 @@ private MLConstants() {} * ML input field names */ public static final String PARAM_MESSAGES_FIELD = 
/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */
package org.opensearch.searchrelevance.judgments;

import java.util.List;

import org.opensearch.searchrelevance.ml.connector.ConnectorType;
import org.opensearch.searchrelevance.model.LLMJudgmentRatingType;
import org.opensearch.searchrelevance.model.SearchConfiguration;

/**
 * Immutable context object bundling the parameters of one LLM judgment
 * run, built via {@link Builder}. Replaces a long positional parameter
 * list threaded through the judgment pipeline.
 *
 * NOTE(review): build() performs no validation; callers are expected to
 * set at least modelId and searchConfigurations — confirm upstream.
 */
public class LlmJudgmentContext {
    private final String modelId;
    private final int size;
    private final int tokenLimit;
    private final List<String> contextFields;
    private final List<SearchConfiguration> searchConfigurations;
    private final boolean ignoreFailure;
    private final String promptTemplate;
    private final LLMJudgmentRatingType ratingType;
    private final boolean overwriteCache;
    private final ConnectorType connectorType;

    private LlmJudgmentContext(Builder builder) {
        this.modelId = builder.modelId;
        this.size = builder.size;
        this.tokenLimit = builder.tokenLimit;
        this.contextFields = builder.contextFields;
        this.searchConfigurations = builder.searchConfigurations;
        this.ignoreFailure = builder.ignoreFailure;
        this.promptTemplate = builder.promptTemplate;
        this.ratingType = builder.ratingType;
        this.overwriteCache = builder.overwriteCache;
        this.connectorType = builder.connectorType;
    }

    public static Builder builder() {
        return new Builder();
    }

    public String getModelId() {
        return modelId;
    }

    public int getSize() {
        return size;
    }

    public int getTokenLimit() {
        return tokenLimit;
    }

    public List<String> getContextFields() {
        return contextFields;
    }

    public List<SearchConfiguration> getSearchConfigurations() {
        return searchConfigurations;
    }

    public boolean isIgnoreFailure() {
        return ignoreFailure;
    }

    public String getPromptTemplate() {
        return promptTemplate;
    }

    public LLMJudgmentRatingType getRatingType() {
        return ratingType;
    }

    public boolean isOverwriteCache() {
        return overwriteCache;
    }

    public ConnectorType getConnectorType() {
        return connectorType;
    }

    /**
     * Fluent builder for {@link LlmJudgmentContext}. Not thread-safe;
     * intended for single-threaded construction.
     */
    public static class Builder {
        private String modelId;
        private int size;
        private int tokenLimit;
        private List<String> contextFields;
        private List<SearchConfiguration> searchConfigurations;
        private boolean ignoreFailure;
        private String promptTemplate;
        private LLMJudgmentRatingType ratingType;
        private boolean overwriteCache;
        private ConnectorType connectorType;

        public Builder modelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        public Builder size(int size) {
            this.size = size;
            return this;
        }

        public Builder tokenLimit(int tokenLimit) {
            this.tokenLimit = tokenLimit;
            return this;
        }

        public Builder contextFields(List<String> contextFields) {
            this.contextFields = contextFields;
            return this;
        }

        public Builder searchConfigurations(List<SearchConfiguration> searchConfigurations) {
            this.searchConfigurations = searchConfigurations;
            return this;
        }

        public Builder ignoreFailure(boolean ignoreFailure) {
            this.ignoreFailure = ignoreFailure;
            return this;
        }

        public Builder promptTemplate(String promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        public Builder ratingType(LLMJudgmentRatingType ratingType) {
            this.ratingType = ratingType;
            return this;
        }

        public Builder overwriteCache(boolean overwriteCache) {
            this.overwriteCache = overwriteCache;
            return this;
        }

        public Builder connectorType(ConnectorType connectorType) {
            this.connectorType = connectorType;
            return this;
        }

        public LlmJudgmentContext build() {
            return new LlmJudgmentContext(this);
        }
    }
}
metadata.get(OVERWRITE_CACHE); + // Extract connectorType from metadata, default to OpenAI if not provided + ConnectorType connectorType = ConnectorType.OPENAI; // default + String connectorTypeStr = (String) metadata.get(CONNECTOR_TYPE); + if (connectorTypeStr != null) { + try { + connectorType = ConnectorType.valueOf(connectorTypeStr.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + log.warn("Invalid connectorType '{}' in metadata, defaulting to OpenAI", connectorTypeStr); + } + } + QuerySet querySet = querySetDao.getQuerySetSync(querySetId); List searchConfigurations = searchConfigurationList.stream() .map(id -> searchConfigurationDao.getSearchConfigurationSync(id)) .collect(Collectors.toList()); - generateLLMJudgmentsAsync( - modelId, - size, - tokenLimit, - contextFields, - querySet, - searchConfigurations, - ignoreFailure, - promptTemplate, - ratingType, - overwriteCache, - listener - ); + // Build context object + LlmJudgmentContext context = LlmJudgmentContext.builder() + .modelId(modelId) + .size(size) + .tokenLimit(tokenLimit) + .contextFields(contextFields) + .searchConfigurations(searchConfigurations) + .ignoreFailure(ignoreFailure) + .promptTemplate(promptTemplate) + .ratingType(ratingType) + .overwriteCache(overwriteCache) + .connectorType(connectorType) + .build(); + + generateLLMJudgmentsAsync(context, querySet, listener); } catch (Exception e) { log.error("Failed to generate LLM judgments", e); listener.onFailure(new SearchRelevanceException("Failed to generate LLM judgments", e, RestStatus.INTERNAL_SERVER_ERROR)); @@ -147,16 +163,8 @@ private void generateJudgmentRatingInternal(Map metadata, Action } private void generateLLMJudgmentsAsync( - String modelId, - int size, - int tokenLimit, - List contextFields, + LlmJudgmentContext context, QuerySet querySet, - List searchConfigurations, - boolean ignoreFailure, - String promptTemplate, - LLMJudgmentRatingType ratingType, - boolean overwriteCache, ActionListener>> listener ) { 
List queryTextsWithCustomInput = querySet.querySetQueries().stream().map(e -> e.queryText()).collect(Collectors.toList()); @@ -172,20 +180,9 @@ private void generateLLMJudgmentsAsync( taskManager.scheduleTasksAsync(queryTextsWithCustomInput, queryTextWithCustomInput -> { try { - return processQueryTextAsync( - modelId, - size, - tokenLimit, - contextFields, - searchConfigurations, - queryTextWithCustomInput, - ignoreFailure, - promptTemplate, - ratingType, - overwriteCache - ); + return processQueryTextAsync(context, queryTextWithCustomInput); } catch (Exception e) { - if (ignoreFailure) { + if (context.isIgnoreFailure()) { log.warn("Query processing failed, returning empty result for: {}", queryTextWithCustomInput, e); return JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, Map.of()); } else { @@ -193,7 +190,7 @@ private void generateLLMJudgmentsAsync( throw new RuntimeException("Query processing failed: " + queryTextWithCustomInput, e); } } - }, ignoreFailure, ActionListener.wrap(results -> { + }, context.isIgnoreFailure(), ActionListener.wrap((List> results) -> { int processedQueries = results.size(); int successQueries = (int) results.stream().mapToLong(result -> { List> ratings = (List>) result.get("ratings"); @@ -219,20 +216,9 @@ private void generateLLMJudgmentsAsync( taskManager.scheduleTasksAsync(queryTextsWithCustomInput, queryTextWithCustomInput -> { try { - return processQueryTextAsync( - modelId, - size, - tokenLimit, - contextFields, - searchConfigurations, - queryTextWithCustomInput, - ignoreFailure, - promptTemplate, - ratingType, - overwriteCache - ); + return processQueryTextAsync(context, queryTextWithCustomInput); } catch (Exception e) { - if (ignoreFailure) { + if (context.isIgnoreFailure()) { log.warn("Query processing failed, returning empty result for: {}", queryTextWithCustomInput, e); return JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, Map.of()); } else { @@ -240,7 +226,7 @@ private void 
generateLLMJudgmentsAsync( throw new RuntimeException("Query processing failed: " + queryTextWithCustomInput, e); } } - }, ignoreFailure, ActionListener.wrap(results -> { + }, context.isIgnoreFailure(), ActionListener.wrap((List> results) -> { int processedQueries = results.size(); int successQueries = (int) results.stream().mapToLong(result -> { List> ratings = (List>) result.get("ratings"); @@ -264,18 +250,7 @@ private void generateLLMJudgmentsAsync( }); } - private Map processQueryTextAsync( - String modelId, - int size, - int tokenLimit, - List contextFields, - List searchConfigurations, - String queryTextWithCustomInput, - boolean ignoreFailure, - String promptTemplate, - LLMJudgmentRatingType ratingType, - boolean overwriteCache - ) { + private Map processQueryTextAsync(LlmJudgmentContext context, String queryTextWithCustomInput) { log.info("Processing query text judgment: {}", queryTextWithCustomInput); ConcurrentMap allHits = new ConcurrentHashMap<>(); @@ -284,38 +259,33 @@ private Map processQueryTextAsync( try { // Step 1: Execute searches concurrently within this query text task - processSearchConfigurationsAsync(searchConfigurations, queryText, size, allHits, ignoreFailure); + processSearchConfigurationsAsync( + context.getSearchConfigurations(), + queryText, + context.getSize(), + allHits, + context.isIgnoreFailure() + ); // Step 2: Deduplicate from cache (skip if overwriteCache is true) List docIds = new ArrayList<>(allHits.keySet()); - String index = searchConfigurations.get(0).index(); - String promptTemplateCode = generatePromptTemplateCode(promptTemplate, ratingType); + String index = context.getSearchConfigurations().get(0).index(); + String promptTemplateCode = generatePromptTemplateCode(context.getPromptTemplate(), context.getRatingType()); List unprocessedDocIds = deduplicateFromCache( index, queryTextWithCustomInput, - contextFields, + context.getContextFields(), docIds, docIdToScore, - ignoreFailure, + context.isIgnoreFailure(), 
promptTemplateCode, - overwriteCache + context.isOverwriteCache() ); // Step 3: Process with LLM if needed if (!unprocessedDocIds.isEmpty()) { - processWithLLM( - modelId, - queryTextWithCustomInput, - tokenLimit, - contextFields, - unprocessedDocIds, - allHits, - index, - docIdToScore, - promptTemplate, - ratingType - ); + processWithLLM(context, queryTextWithCustomInput, unprocessedDocIds, allHits, index, docIdToScore); } Map result = JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); @@ -415,16 +385,12 @@ private List deduplicateFromCache( } private void processWithLLM( - String modelId, + LlmJudgmentContext context, String queryTextWithCustomInput, - int tokenLimit, - List contextFields, List unprocessedDocIds, ConcurrentMap allHits, String index, - ConcurrentMap docIdToScore, - String promptTemplate, - LLMJudgmentRatingType ratingType + ConcurrentMap docIdToScore ) throws Exception { Map unionHits = new HashMap<>(); @@ -432,32 +398,26 @@ private void processWithLLM( for (String docId : unprocessedDocIds) { SearchHit hit = allHits.get(docId); String compositeKey = combinedIndexAndDocId(index, docId); - String contextSource = getContextSource(hit, contextFields); + String contextSource = getContextSource(hit, context.getContextFields()); unionHits.put(compositeKey, contextSource); } log.info("Processing {} uncached docs with LLM", unionHits.size()); log.debug("DEBUG: unionHits keys being sent to LLM: {}", unionHits.keySet()); log.debug("DEBUG: queryTextWithCustomInput: {}", queryTextWithCustomInput); - log.debug("DEBUG: modelId: {}, tokenLimit: {}, ratingType: {}", modelId, tokenLimit, ratingType); + log.debug( + "DEBUG: modelId: {}, tokenLimit: {}, ratingType: {}", + context.getModelId(), + context.getTokenLimit(), + context.getRatingType() + ); // Generate promptTemplateCode for cache updates - String promptTemplateCode = generatePromptTemplateCode(promptTemplate, ratingType); + String promptTemplateCode = 
generatePromptTemplateCode(context.getPromptTemplate(), context.getRatingType()); // Synchronous LLM call PlainActionFuture> llmFuture = PlainActionFuture.newFuture(); - generateLLMJudgmentForQueryText( - modelId, - queryTextWithCustomInput, - tokenLimit, - contextFields, - unionHits, - new HashMap<>(), - promptTemplate, - ratingType, - promptTemplateCode, - llmFuture - ); + generateLLMJudgmentForQueryText(context, queryTextWithCustomInput, unionHits, new HashMap<>(), promptTemplateCode, llmFuture); Map llmResults = llmFuture.actionGet(); docIdToScore.putAll(llmResults); @@ -466,18 +426,14 @@ private void processWithLLM( } private void generateLLMJudgmentForQueryText( - String modelId, + LlmJudgmentContext context, String queryTextWithCustomInput, - int tokenLimit, - List contextFields, Map unprocessedUnionHits, Map docIdToRating, - String promptTemplate, - LLMJudgmentRatingType ratingType, String promptTemplateCode, ActionListener> listener ) { - log.debug("calculating LLM evaluation with modelId: {} and unprocessed unionHits: {}", modelId, unprocessedUnionHits); + log.debug("calculating LLM evaluation with modelId: {} and unprocessed unionHits: {}", context.getModelId(), unprocessedUnionHits); log.debug("processed docIdToRating before llm evaluation: {}", docIdToRating); if (unprocessedUnionHits.isEmpty()) { @@ -496,13 +452,14 @@ private void generateLLMJudgmentForQueryText( AtomicBoolean hasFailure = new AtomicBoolean(false); mlAccessor.predict( - modelId, - tokenLimit, + context.getModelId(), + context.getTokenLimit(), queryText, referenceData, unprocessedUnionHits, - promptTemplate, - ratingType, + context.getPromptTemplate(), + context.getRatingType(), + context.getConnectorType(), new ActionListener() { @Override public void onResponse(ChunkResult chunkResult) { @@ -548,16 +505,16 @@ public void onResponse(ChunkResult chunkResult) { compositeKey, rawRatingScore ); - Double ratingScore = convertRatingScore(rawRatingScore, ratingType); + Double ratingScore = 
convertRatingScore(rawRatingScore, context.getRatingType()); String docId = getDocIdFromCompositeKey(compositeKey); log.debug("DEBUG: Converted rating - docId: {}, ratingScore: {}", docId, ratingScore); processedRatings.put(docId, ratingScore.toString()); updateJudgmentCache( compositeKey, queryTextWithCustomInput, - contextFields, + context.getContextFields(), ratingScore.toString(), - modelId, + context.getModelId(), promptTemplateCode ); } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java b/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java index 210ecca6..ebea1e0c 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java @@ -17,6 +17,8 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; +import org.opensearch.searchrelevance.ml.connector.LLMConnectorFactory; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.utils.RatingOutputProcessor; @@ -46,16 +48,29 @@ public void predict( Map hits, String promptTemplate, LLMJudgmentRatingType ratingType, + ConnectorType connectorType, ActionListener progressListener ) { log.debug( - "DEBUG: MLAccessor.predict called with modelId: {}, searchText: {}, hits count: {}, ratingType: {}", + "DEBUG: MLAccessor.predict called with modelId: {}, searchText: {}, hits count: {}, ratingType: {}, connectorType: {}", modelId, searchText, hits.size(), + ratingType, + connectorType + ); + + // Create transformer with appropriate connector + MLInputOutputTransformer connectorTransformer = new MLInputOutputTransformer(LLMConnectorFactory.create(connectorType)); + + List mlInputs = connectorTransformer.createMLInputs( + tokenLimit, + searchText, + referenceData, + 
hits, + promptTemplate, ratingType ); - List mlInputs = transformer.createMLInputs(tokenLimit, searchText, referenceData, hits, promptTemplate, ratingType); log.info("Number of chunks: {}", mlInputs.size()); log.debug("DEBUG: Created {} MLInput chunks", mlInputs.size()); diff --git a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java index f82ebbbb..e0a5a4ac 100644 --- a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java +++ b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java @@ -9,6 +9,7 @@ import static java.util.Collections.singletonList; import static org.opensearch.rest.RestRequest.Method.PUT; +import static org.opensearch.searchrelevance.common.MLConstants.CONNECTOR_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.DEFAULT_PROMPT_TEMPLATE; import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; @@ -48,6 +49,7 @@ import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.searchrelevance.exception.SearchRelevanceException; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.JudgmentType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.settings.SearchRelevanceSettingsAccessor; @@ -166,6 +168,31 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } boolean overwriteCache = Optional.ofNullable((Boolean) source.get(OVERWRITE_CACHE)).orElse(Boolean.FALSE); + // Parse connectorType - optional, defaults to OpenAI + ConnectorType connectorType = ConnectorType.OPENAI; // default + String connectorTypeStr = (String) source.get(CONNECTOR_TYPE); + if (connectorTypeStr != null) { + try { + 
connectorType = ConnectorType.valueOf(connectorTypeStr.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw new SearchRelevanceException( + String.format( + Locale.ROOT, + "Invalid connectorType: '%s'. Valid values are: %s", + connectorTypeStr, + String.join( + ", ", + ConnectorType.OPENAI.name(), + ConnectorType.CLAUDE.name(), + ConnectorType.COHERE.name(), + ConnectorType.DEEPSEEK.name() + ) + ), + RestStatus.BAD_REQUEST + ); + } + } + createRequest = new PutLlmJudgmentRequest( type, name, @@ -179,7 +206,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ignoreFailure, promptTemplate, llmJudgmentRatingType, - overwriteCache + overwriteCache, + connectorType ); } case UBI_JUDGMENT -> { diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java index 0c3c3df6..96b3c3b5 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java @@ -7,6 +7,7 @@ */ package org.opensearch.searchrelevance.transport.judgment; +import static org.opensearch.searchrelevance.common.MLConstants.CONNECTOR_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; @@ -100,6 +101,7 @@ private Map buildMetadata(PutJudgmentRequest request) { case LLM_JUDGMENT -> { PutLlmJudgmentRequest llmRequest = (PutLlmJudgmentRequest) request; metadata.put(MODEL_ID, llmRequest.getModelId()); + metadata.put(CONNECTOR_TYPE, llmRequest.getConnectorType().getValue()); metadata.put("querySetId", llmRequest.getQuerySetId()); metadata.put("size", llmRequest.getSize()); 
metadata.put("searchConfigurationList", llmRequest.getSearchConfigurationList()); diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java index 24328e9b..913d9101 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java @@ -12,6 +12,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.JudgmentType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; @@ -57,6 +58,11 @@ public class PutLlmJudgmentRequest extends PutJudgmentRequest { */ private boolean overwriteCache; + /** + * LLM connector type for formatting prompts and extracting responses + */ + private ConnectorType connectorType; + public PutLlmJudgmentRequest( @NonNull JudgmentType type, @NonNull String name, @@ -70,7 +76,8 @@ public PutLlmJudgmentRequest( boolean ignoreFailure, String promptTemplate, LLMJudgmentRatingType llmJudgmentRatingType, - boolean overwriteCache + boolean overwriteCache, + ConnectorType connectorType ) { super(type, name, description); this.modelId = modelId; @@ -83,6 +90,7 @@ public PutLlmJudgmentRequest( this.promptTemplate = promptTemplate; this.llmJudgmentRatingType = llmJudgmentRatingType; this.overwriteCache = overwriteCache; + this.connectorType = connectorType != null ? 
connectorType : ConnectorType.OPENAI; // default to OpenAI } public PutLlmJudgmentRequest(StreamInput in) throws IOException { @@ -97,6 +105,8 @@ public PutLlmJudgmentRequest(StreamInput in) throws IOException { this.promptTemplate = in.readOptionalString(); this.llmJudgmentRatingType = in.readOptionalWriteable(LLMJudgmentRatingType::readFromStream); this.overwriteCache = Boolean.TRUE.equals(in.readOptionalBoolean()); + String connectorTypeStr = in.readOptionalString(); + this.connectorType = connectorTypeStr != null ? ConnectorType.valueOf(connectorTypeStr) : ConnectorType.OPENAI; } @Override @@ -112,6 +122,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(promptTemplate); out.writeOptionalWriteable(llmJudgmentRatingType); out.writeOptionalBoolean(overwriteCache); + out.writeOptionalString(connectorType != null ? connectorType.name() : null); } public String getModelId() { @@ -154,4 +165,8 @@ public boolean isOverwriteCache() { return overwriteCache; } + public ConnectorType getConnectorType() { + return connectorType; + } + } diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java index adf9b2f7..8cd5f8e8 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java @@ -13,6 +13,7 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.JudgmentType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.transport.judgment.PutImportJudgmentRequest; @@ -93,7 +94,8 @@ public void testLlmJudgmentRequestStreams() throws IOException { false, 
"test_prompt_template", LLMJudgmentRatingType.SCORE0_1, - true + true, + ConnectorType.OPENAI ); BytesStreamOutput output = new BytesStreamOutput(); @@ -130,7 +132,8 @@ public void testLlmJudgmentRequestStreamsWithNullOptionalFields() throws IOExcep true, null, null, - false + false, + ConnectorType.OPENAI ); BytesStreamOutput output = new BytesStreamOutput(); diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java index 966780df..fac73814 100644 --- a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java @@ -28,6 +28,7 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.test.OpenSearchTestCase; @@ -99,6 +100,7 @@ public void testFirstAttemptSuccess_WhenModelSupportsResponseFormat() throws Exc hits, "Test prompt", LLMJudgmentRatingType.SCORE0_1, + ConnectorType.OPENAI, ActionListener.wrap(chunkResult -> { result.set(chunkResult); latch.countDown(); diff --git a/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java b/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java new file mode 100644 index 00000000..495c9322 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.transport.judgment; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.List; + +import org.opensearch.searchrelevance.ml.connector.ConnectorType; +import org.opensearch.searchrelevance.model.JudgmentType; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.test.OpenSearchTestCase; + +public class PutLlmJudgmentRequestTests extends OpenSearchTestCase { + + public void testConnectorTypeDefaultsToOpenAI() { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + null // connectorType is null + ); + + assertEquals(ConnectorType.OPENAI, request.getConnectorType()); + } + + public void testConnectorTypeIsPreserved() { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + ConnectorType.CLAUDE + ); + + assertEquals(ConnectorType.CLAUDE, request.getConnectorType()); + } + + public void testAllConnectorTypes() { + for (ConnectorType type : ConnectorType.values()) { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + type + ); + + assertEquals(type, request.getConnectorType()); + assertNotNull(request.getConnectorType()); + } + } +} From 
9cb2d0d98b31daec6790b0f3115272d605082cef Mon Sep 17 00:00:00 2001 From: Fen Qin Date: Fri, 21 Nov 2025 23:44:45 +0000 Subject: [PATCH 3/6] Add basic rate limit support to LLM judgment requests Signed-off-by: Fen Qin --- .../searchrelevance/common/MLConstants.java | 36 +++- .../judgments/LlmJudgmentContext.java | 164 +++++------------ .../judgments/LlmJudgmentsProcessor.java | 174 ++++++++++++++---- .../ml/MLInputOutputTransformer.java | 28 ++- .../ml/connector/ClaudeConnector.java | 5 + .../ml/connector/CohereConnector.java | 5 + .../ml/connector/DeepSeekConnector.java | 5 + .../ml/connector/LLMConnector.java | 7 + .../ml/connector/OpenAIConnector.java | 5 + .../rest/RestPutJudgmentAction.java | 8 +- .../judgment/PutJudgmentTransportAction.java | 30 +-- .../judgment/PutLlmJudgmentRequest.java | 16 +- .../judgment/PutJudgmentActionTests.java | 6 +- .../judgment/PutLlmJudgmentRequestTests.java | 53 +++++- 14 files changed, 350 insertions(+), 192 deletions(-) diff --git a/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java b/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java index 61ad71f5..a12afbc8 100644 --- a/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java +++ b/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java @@ -24,10 +24,17 @@ private MLConstants() {} */ public static final String PARAM_MESSAGES_FIELD = "messages"; public static final String CONNECTOR_TYPE = "connectorType"; + public static final String RATE_LIMIT = "rateLimit"; public static final String PROMPT_TEMPLATE = "promptTemplate"; public static final String LLM_JUDGMENT_RATING_TYPE = "llmJudgmentRatingType"; public static final String OVERWRITE_CACHE = "overwriteCache"; + /** + * Default prompt template for LLM judgments + */ + public static final String DEFAULT_PROMPT_TEMPLATE = + "Rate the relevance of the search results to the query. SearchText: {{searchText}}; Results: {{hits}}"; + /** * Prompt template placeholder names. 
* These are the special variables that can be used in custom prompt templates. @@ -39,11 +46,6 @@ private MLConstants() {} public static final String PLACEHOLDER_REFERENCE = "reference"; public static final String PLACEHOLDER_REFERENCE_ANSWER = "referenceAnswer"; - /** - * Default prompt template for LLM judgments (simple format without reference data) - */ - public static final String DEFAULT_PROMPT_TEMPLATE = "SearchText: {{searchText}}; Hits: {{hits}}"; - /** * ML response field names */ @@ -200,4 +202,28 @@ public static int validateTokenLimit(Map source) { } } + /** + * Parses rateLimit value from an object, ensuring it's non-negative + * + * @param rateLimitObj The object to parse (can be Number, String, or null) + * @return Parsed rate limit value, or 0 if null/invalid + */ + public static long parseRateLimit(Object rateLimitObj) { + if (rateLimitObj == null) { + return 0L; + } + + try { + long rateLimit; + if (rateLimitObj instanceof Number) { + rateLimit = ((Number) rateLimitObj).longValue(); + } else { + rateLimit = Long.parseLong(rateLimitObj.toString()); + } + return Math.max(0, rateLimit); // ensure non-negative + } catch (NumberFormatException e) { + return 0L; // default to 0 on parse error + } + } + } diff --git a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentContext.java b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentContext.java index 05a6640d..6afcd0fb 100644 --- a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentContext.java +++ b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentContext.java @@ -7,16 +7,36 @@ */ package org.opensearch.searchrelevance.judgments; +import java.io.IOException; import java.util.List; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import 
org.opensearch.searchrelevance.model.SearchConfiguration; +import lombok.Builder; +import lombok.Getter; + /** * Context object to hold LLM judgment parameters */ -public class LlmJudgmentContext { +@Getter +@Builder +public class LlmJudgmentContext implements ToXContentObject { + public static final String MODEL_ID = "modelId"; + public static final String SIZE = "size"; + public static final String TOKEN_LIMIT = "tokenLimit"; + public static final String CONTEXT_FIELDS = "contextFields"; + public static final String SEARCH_CONFIGURATIONS = "searchConfigurations"; + public static final String IGNORE_FAILURE = "ignoreFailure"; + public static final String PROMPT_TEMPLATE = "promptTemplate"; + public static final String RATING_TYPE = "ratingType"; + public static final String OVERWRITE_CACHE = "overwriteCache"; + public static final String CONNECTOR_TYPE = "connectorType"; + public static final String RATE_LIMIT = "rateLimit"; + private final String modelId; private final int size; private final int tokenLimit; @@ -27,128 +47,32 @@ public class LlmJudgmentContext { private final LLMJudgmentRatingType ratingType; private final boolean overwriteCache; private final ConnectorType connectorType; - - private LlmJudgmentContext(Builder builder) { - this.modelId = builder.modelId; - this.size = builder.size; - this.tokenLimit = builder.tokenLimit; - this.contextFields = builder.contextFields; - this.searchConfigurations = builder.searchConfigurations; - this.ignoreFailure = builder.ignoreFailure; - this.promptTemplate = builder.promptTemplate; - this.ratingType = builder.ratingType; - this.overwriteCache = builder.overwriteCache; - this.connectorType = builder.connectorType; - } - - public static Builder builder() { - return new Builder(); - } - - public String getModelId() { - return modelId; - } - - public int getSize() { - return size; - } - - public int getTokenLimit() { - return tokenLimit; - } - - public List getContextFields() { - return contextFields; - } - - 
public List getSearchConfigurations() { - return searchConfigurations; - } - - public boolean isIgnoreFailure() { - return ignoreFailure; - } - - public String getPromptTemplate() { - return promptTemplate; - } - - public LLMJudgmentRatingType getRatingType() { - return ratingType; - } - - public boolean isOverwriteCache() { - return overwriteCache; - } - - public ConnectorType getConnectorType() { - return connectorType; - } - - public static class Builder { - private String modelId; - private int size; - private int tokenLimit; - private List contextFields; - private List searchConfigurations; - private boolean ignoreFailure; - private String promptTemplate; - private LLMJudgmentRatingType ratingType; - private boolean overwriteCache; - private ConnectorType connectorType; - - public Builder modelId(String modelId) { - this.modelId = modelId; - return this; - } - - public Builder size(int size) { - this.size = size; - return this; - } - - public Builder tokenLimit(int tokenLimit) { - this.tokenLimit = tokenLimit; - return this; - } - - public Builder contextFields(List contextFields) { - this.contextFields = contextFields; - return this; + private final long rateLimit; // milliseconds between requests + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID, modelId); + builder.field(SIZE, size); + builder.field(TOKEN_LIMIT, tokenLimit); + if (contextFields != null) { + builder.field(CONTEXT_FIELDS, contextFields); } - - public Builder searchConfigurations(List searchConfigurations) { - this.searchConfigurations = searchConfigurations; - return this; - } - - public Builder ignoreFailure(boolean ignoreFailure) { - this.ignoreFailure = ignoreFailure; - return this; + if (searchConfigurations != null) { + builder.field(SEARCH_CONFIGURATIONS, searchConfigurations); } - - public Builder promptTemplate(String promptTemplate) { - this.promptTemplate = promptTemplate; 
- return this; - } - - public Builder ratingType(LLMJudgmentRatingType ratingType) { - this.ratingType = ratingType; - return this; + builder.field(IGNORE_FAILURE, ignoreFailure); + if (promptTemplate != null) { + builder.field(PROMPT_TEMPLATE, promptTemplate); } - - public Builder overwriteCache(boolean overwriteCache) { - this.overwriteCache = overwriteCache; - return this; - } - - public Builder connectorType(ConnectorType connectorType) { - this.connectorType = connectorType; - return this; - } - - public LlmJudgmentContext build() { - return new LlmJudgmentContext(this); + // Always include ratingType, use default if null + String ratingTypeValue = (ratingType != null) ? ratingType.name() : LLMJudgmentRatingType.SCORE0_1.name(); + builder.field(RATING_TYPE, ratingTypeValue); + builder.field(OVERWRITE_CACHE, overwriteCache); + if (connectorType != null) { + builder.field(CONNECTOR_TYPE, connectorType.name()); } + builder.field(RATE_LIMIT, rateLimit); + return builder.endObject(); } } diff --git a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java index 77017df5..b8395073 100644 --- a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java +++ b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java @@ -11,6 +11,8 @@ import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.common.MLConstants.RATE_LIMIT; +import static org.opensearch.searchrelevance.common.MLConstants.parseRateLimit; import static org.opensearch.searchrelevance.model.builder.SearchRequestBuilder.buildSearchRequest; import static 
org.opensearch.searchrelevance.utils.ParserUtils.combinedIndexAndDocId; import static org.opensearch.searchrelevance.utils.ParserUtils.generatePromptTemplateCode; @@ -108,52 +110,17 @@ public void generateJudgmentRating(Map metadata, ActionListener< private void generateJudgmentRatingInternal(Map metadata, ActionListener>> listener) { try { EventStatsManager.increment(EventStatName.LLM_JUDGMENT_RATING_GENERATIONS); + String querySetId = (String) metadata.get("querySetId"); List searchConfigurationList = (List) metadata.get("searchConfigurationList"); - int size = (int) metadata.get("size"); - - String modelId = (String) metadata.get("modelId"); - int tokenLimit = (int) metadata.get("tokenLimit"); - List contextFields = (List) metadata.get("contextFields"); - boolean ignoreFailure = (boolean) metadata.get("ignoreFailure"); - String promptTemplate = (String) metadata.get(PROMPT_TEMPLATE); - LLMJudgmentRatingType ratingType = (LLMJudgmentRatingType) metadata.get(LLM_JUDGMENT_RATING_TYPE); - // Default to SCORE0_1 if ratingType is not provided - if (ratingType == null) { - ratingType = LLMJudgmentRatingType.SCORE0_1; - log.debug("No ratingType provided, defaulting to SCORE0_1"); - } - boolean overwriteCache = (boolean) metadata.get(OVERWRITE_CACHE); - - // Extract connectorType from metadata, default to OpenAI if not provided - ConnectorType connectorType = ConnectorType.OPENAI; // default - String connectorTypeStr = (String) metadata.get(CONNECTOR_TYPE); - if (connectorTypeStr != null) { - try { - connectorType = ConnectorType.valueOf(connectorTypeStr.toUpperCase(Locale.ROOT)); - } catch (IllegalArgumentException e) { - log.warn("Invalid connectorType '{}' in metadata, defaulting to OpenAI", connectorTypeStr); - } - } QuerySet querySet = querySetDao.getQuerySetSync(querySetId); List searchConfigurations = searchConfigurationList.stream() .map(id -> searchConfigurationDao.getSearchConfigurationSync(id)) .collect(Collectors.toList()); - // Build context object - 
LlmJudgmentContext context = LlmJudgmentContext.builder() - .modelId(modelId) - .size(size) - .tokenLimit(tokenLimit) - .contextFields(contextFields) - .searchConfigurations(searchConfigurations) - .ignoreFailure(ignoreFailure) - .promptTemplate(promptTemplate) - .ratingType(ratingType) - .overwriteCache(overwriteCache) - .connectorType(connectorType) - .build(); + // Build context from metadata + LlmJudgmentContext context = buildContextFromMetadata(metadata, searchConfigurations); generateLLMJudgmentsAsync(context, querySet, listener); } catch (Exception e) { @@ -162,6 +129,136 @@ private void generateJudgmentRatingInternal(Map metadata, Action } } + private LlmJudgmentContext buildContextFromMetadata(Map metadata, List searchConfigurations) { + // Check if we have a pre-built context (new approach) + Object contextObj = metadata.get("llmJudgmentContext"); + if (contextObj != null) { + if (contextObj instanceof LlmJudgmentContext) { + // Direct object case (in-memory) + LlmJudgmentContext baseContext = (LlmJudgmentContext) contextObj; + return LlmJudgmentContext.builder() + .modelId(baseContext.getModelId()) + .size(baseContext.getSize()) + .tokenLimit(baseContext.getTokenLimit()) + .contextFields(baseContext.getContextFields()) + .searchConfigurations(searchConfigurations) + .ignoreFailure(baseContext.isIgnoreFailure()) + .promptTemplate(baseContext.getPromptTemplate()) + .ratingType(baseContext.getRatingType()) + .overwriteCache(baseContext.isOverwriteCache()) + .connectorType(baseContext.getConnectorType()) + .rateLimit(baseContext.getRateLimit()) + .build(); + } else if (contextObj instanceof Map) { + // Deserialized from OpenSearch as Map + return buildContextFromMap((Map) contextObj, searchConfigurations); + } + } + + // Fallback to legacy metadata parsing for backward compatibility + return buildContextFromLegacyMetadata(metadata, searchConfigurations); + } + + private LlmJudgmentContext buildContextFromMap(Map contextMap, List searchConfigurations) { + 
String modelId = (String) contextMap.get(LlmJudgmentContext.MODEL_ID); + Integer size = (Integer) contextMap.get(LlmJudgmentContext.SIZE); + Integer tokenLimit = (Integer) contextMap.get(LlmJudgmentContext.TOKEN_LIMIT); + List contextFields = (List) contextMap.get(LlmJudgmentContext.CONTEXT_FIELDS); + Boolean ignoreFailure = (Boolean) contextMap.get(LlmJudgmentContext.IGNORE_FAILURE); + String promptTemplate = (String) contextMap.get(LlmJudgmentContext.PROMPT_TEMPLATE); + Boolean overwriteCache = (Boolean) contextMap.get(LlmJudgmentContext.OVERWRITE_CACHE); + + Long rateLimit = 1000L; + Object rateLimitObj = contextMap.get(LlmJudgmentContext.RATE_LIMIT); + if (rateLimitObj instanceof Number) { + rateLimit = ((Number) rateLimitObj).longValue(); + } + + // Parse enum values with proper defaults + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + String ratingTypeStr = (String) contextMap.get(LlmJudgmentContext.RATING_TYPE); + if (ratingTypeStr != null && !ratingTypeStr.isEmpty()) { + try { + ratingType = LLMJudgmentRatingType.valueOf(ratingTypeStr); + } catch (IllegalArgumentException e) { + log.warn("Invalid ratingType '{}' in context, defaulting to SCORE0_1", ratingTypeStr); + } + } + log.debug("Using ratingType: {} for judgment processing", ratingType); + + ConnectorType connectorType = ConnectorType.OPENAI; + String connectorTypeStr = (String) contextMap.get(LlmJudgmentContext.CONNECTOR_TYPE); + if (connectorTypeStr != null && !connectorTypeStr.isEmpty()) { + try { + connectorType = ConnectorType.valueOf(connectorTypeStr); + } catch (IllegalArgumentException e) { + log.warn("Invalid connectorType '{}' in context, defaulting to OPENAI", connectorTypeStr); + } + } + + return LlmJudgmentContext.builder() + .modelId(modelId) + .size(size != null ? size : 5) + .tokenLimit(tokenLimit != null ? tokenLimit : 1000) + .contextFields(contextFields != null ? 
contextFields : new ArrayList<>()) + .searchConfigurations(searchConfigurations) + .ignoreFailure(ignoreFailure != null ? ignoreFailure : false) + .promptTemplate( + promptTemplate != null + ? promptTemplate + : "Rate the relevance of the search results to the query. SearchText: {{searchText}}; Results: {{hits}}" + ) + .ratingType(ratingType) + .overwriteCache(overwriteCache != null ? overwriteCache : false) + .connectorType(connectorType) + .rateLimit(rateLimit) + .build(); + } + + private LlmJudgmentContext buildContextFromLegacyMetadata( + Map metadata, + List searchConfigurations + ) { + String modelId = (String) metadata.get("modelId"); + int size = (int) metadata.get("size"); + int tokenLimit = (int) metadata.get("tokenLimit"); + List contextFields = (List) metadata.get("contextFields"); + boolean ignoreFailure = (boolean) metadata.get("ignoreFailure"); + String promptTemplate = (String) metadata.get(PROMPT_TEMPLATE); + LLMJudgmentRatingType ratingType = (LLMJudgmentRatingType) metadata.get(LLM_JUDGMENT_RATING_TYPE); + if (ratingType == null) { + ratingType = LLMJudgmentRatingType.SCORE0_1; + log.debug("No ratingType provided, defaulting to SCORE0_1"); + } + boolean overwriteCache = (boolean) metadata.get(OVERWRITE_CACHE); + + ConnectorType connectorType = ConnectorType.OPENAI; + String connectorTypeStr = (String) metadata.get(CONNECTOR_TYPE); + if (connectorTypeStr != null) { + try { + connectorType = ConnectorType.valueOf(connectorTypeStr.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + log.warn("Invalid connectorType '{}' in metadata, defaulting to OpenAI", connectorTypeStr); + } + } + + long rateLimit = parseRateLimit(metadata.get(RATE_LIMIT)); + + return LlmJudgmentContext.builder() + .modelId(modelId) + .size(size) + .tokenLimit(tokenLimit) + .contextFields(contextFields) + .searchConfigurations(searchConfigurations) + .ignoreFailure(ignoreFailure) + .promptTemplate(promptTemplate) + .ratingType(ratingType) + 
.overwriteCache(overwriteCache) + .connectorType(connectorType) + .rateLimit(rateLimit) + .build(); + } + private void generateLLMJudgmentsAsync( LlmJudgmentContext context, QuerySet querySet, @@ -460,6 +557,7 @@ private void generateLLMJudgmentForQueryText( context.getPromptTemplate(), context.getRatingType(), context.getConnectorType(), + context.getRateLimit(), new ActionListener() { @Override public void onResponse(ChunkResult chunkResult) { diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java index d0375d48..87962662 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java @@ -7,7 +7,6 @@ */ package org.opensearch.searchrelevance.ml; -import static org.opensearch.searchrelevance.common.MLConstants.PARAM_MESSAGES_FIELD; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_0_1_START; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_BINARY; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_END; @@ -22,6 +21,8 @@ import java.util.Locale; import java.util.Map; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.xcontent.XContentBuilder; @@ -36,13 +37,11 @@ import org.opensearch.searchrelevance.ml.connector.OpenAIConnector; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; -import lombok.extern.log4j.Log4j2; - /** * Handles ML input/output transformations for search relevance predictions */ -@Log4j2 public class MLInputOutputTransformer { + private static final Logger log = 
LogManager.getLogger(MLInputOutputTransformer.class); private final LLMConnector connector; @@ -141,7 +140,9 @@ public MLInput createMLInput( Map parameters = new HashMap<>(); String messagesArray = buildMessagesArray(searchText, referenceData, hits, promptTemplate, ratingType); - parameters.put(PARAM_MESSAGES_FIELD, messagesArray); + // Use connector-specific parameter name + String paramName = connector.getMessageParameterName(); + parameters.put(paramName, messagesArray); // Only add response_format if requested (for models that support it) if (includeResponseFormat) { @@ -173,8 +174,14 @@ private String buildMessagesArray( private static String getSystemPrompt(LLMJudgmentRatingType ratingType) { String systemPromptStart; String systemPromptEnd = PROMPT_SEARCH_RELEVANCE_SCORE_END; + + // Handle null ratingType with default + if (ratingType == null) { + ratingType = LLMJudgmentRatingType.SCORE0_1; + } + switch (ratingType) { - case LLMJudgmentRatingType.SCORE0_1: + case SCORE0_1: systemPromptStart = PROMPT_SEARCH_RELEVANCE_SCORE_0_1_START; break; default: @@ -184,12 +191,17 @@ private static String getSystemPrompt(LLMJudgmentRatingType ratingType) { } private static String getResponseFormat(LLMJudgmentRatingType ratingType) { + // Handle null ratingType with default + if (ratingType == null) { + ratingType = LLMJudgmentRatingType.SCORE0_1; + } + String schema; switch (ratingType) { - case LLMJudgmentRatingType.SCORE0_1: + case SCORE0_1: schema = RATING_SCORE_NUMERIC_SCHEMA; break; - case LLMJudgmentRatingType.RELEVANT_IRRELEVANT: + case RELEVANT_IRRELEVANT: schema = RATING_SCORE_BINARY_SCHEMA; break; default: diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java index b4695087..057f35ec 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java +++ 
b/src/main/java/org/opensearch/searchrelevance/ml/connector/ClaudeConnector.java @@ -42,4 +42,9 @@ public String extractResponse(Map rawResponse) { public ConnectorType getType() { return ConnectorType.CLAUDE; } + + @Override + public String getMessageParameterName() { + return "messages"; // Claude uses plural "messages" + } } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java index 4a9d9531..0c194051 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/CohereConnector.java @@ -30,4 +30,9 @@ public String extractResponse(Map rawResponse) { public ConnectorType getType() { return ConnectorType.COHERE; } + + @Override + public String getMessageParameterName() { + return "message"; // Cohere uses singular "message" + } } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java index 8dddfaa2..4b7ad1d0 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/DeepSeekConnector.java @@ -16,4 +16,9 @@ public class DeepSeekConnector extends OpenAIConnector { public ConnectorType getType() { return ConnectorType.DEEPSEEK; } + + @Override + public String getMessageParameterName() { + return "messages"; // DeepSeek uses plural "messages" + } } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java index b8c28185..2ec803d8 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/LLMConnector.java @@ -37,4 +37,11 @@ public interface 
LLMConnector { * @return The ConnectorType enum value */ ConnectorType getType(); + + /** + * Returns the parameter name used for messages in ML input + * + * @return The parameter name (e.g., "messages" for Claude/OpenAI, "message" for Cohere) + */ + String getMessageParameterName(); } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java b/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java index 33963286..4810ee91 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/connector/OpenAIConnector.java @@ -38,4 +38,9 @@ public String extractResponse(Map rawResponse) { public ConnectorType getType() { return ConnectorType.OPENAI; } + + @Override + public String getMessageParameterName() { + return "messages"; // OpenAI uses plural "messages" + } } diff --git a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java index e0a5a4ac..6f34bb6d 100644 --- a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java +++ b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java @@ -14,6 +14,8 @@ import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.common.MLConstants.RATE_LIMIT; +import static org.opensearch.searchrelevance.common.MLConstants.parseRateLimit; import static org.opensearch.searchrelevance.common.MLConstants.validateTokenLimit; import static org.opensearch.searchrelevance.common.MetricsConstants.MODEL_ID; import static org.opensearch.searchrelevance.common.PluginConstants.CLICK_MODEL; @@ -193,6 +195,9 @@ protected 
RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } + // Parse rateLimit - optional, defaults to 0 + long rateLimit = parseRateLimit(source.get(RATE_LIMIT)); + createRequest = new PutLlmJudgmentRequest( type, name, @@ -207,7 +212,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli promptTemplate, llmJudgmentRatingType, overwriteCache, - connectorType + connectorType, + rateLimit ); } case UBI_JUDGMENT -> { diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java index 96b3c3b5..8c8b7bbb 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java @@ -7,11 +7,6 @@ */ package org.opensearch.searchrelevance.transport.judgment; -import static org.opensearch.searchrelevance.common.MLConstants.CONNECTOR_TYPE; -import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; -import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; -import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; -import static org.opensearch.searchrelevance.common.MetricsConstants.MODEL_ID; import static org.opensearch.searchrelevance.ubi.UbiValidator.checkUbiIndicesExist; import java.util.ArrayList; @@ -33,6 +28,7 @@ import org.opensearch.searchrelevance.exception.SearchRelevanceException; import org.opensearch.searchrelevance.judgments.BaseJudgmentsProcessor; import org.opensearch.searchrelevance.judgments.JudgmentsProcessorFactory; +import org.opensearch.searchrelevance.judgments.LlmJudgmentContext; import org.opensearch.searchrelevance.model.AsyncStatus; import org.opensearch.searchrelevance.model.Judgment; import org.opensearch.searchrelevance.utils.TimeUtils; @@ 
-99,18 +95,24 @@ private Map buildMetadata(PutJudgmentRequest request) { Map metadata = new HashMap<>(); switch (request.getType()) { case LLM_JUDGMENT -> { + // Use structured context for complex LLM parameters PutLlmJudgmentRequest llmRequest = (PutLlmJudgmentRequest) request; - metadata.put(MODEL_ID, llmRequest.getModelId()); - metadata.put(CONNECTOR_TYPE, llmRequest.getConnectorType().getValue()); + LlmJudgmentContext context = LlmJudgmentContext.builder() + .modelId(llmRequest.getModelId()) + .connectorType(llmRequest.getConnectorType()) + .rateLimit(llmRequest.getRateLimit()) + .size(llmRequest.getSize()) + .tokenLimit(llmRequest.getTokenLimit()) + .contextFields(llmRequest.getContextFields()) + .ignoreFailure(llmRequest.isIgnoreFailure()) + .promptTemplate(llmRequest.getPromptTemplate()) + .ratingType(llmRequest.getLlmJudgmentRatingType()) + .overwriteCache(llmRequest.isOverwriteCache()) + .build(); + + metadata.put("llmJudgmentContext", context); metadata.put("querySetId", llmRequest.getQuerySetId()); - metadata.put("size", llmRequest.getSize()); metadata.put("searchConfigurationList", llmRequest.getSearchConfigurationList()); - metadata.put("tokenLimit", llmRequest.getTokenLimit()); - metadata.put("contextFields", llmRequest.getContextFields()); - metadata.put("ignoreFailure", llmRequest.isIgnoreFailure()); - metadata.put(PROMPT_TEMPLATE, llmRequest.getPromptTemplate()); - metadata.put(LLM_JUDGMENT_RATING_TYPE, llmRequest.getLlmJudgmentRatingType()); - metadata.put(OVERWRITE_CACHE, llmRequest.isOverwriteCache()); } case UBI_JUDGMENT -> { if (!checkUbiIndicesExist(clusterService)) { diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java index 913d9101..41abc192 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java +++ 
b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java @@ -63,6 +63,11 @@ public class PutLlmJudgmentRequest extends PutJudgmentRequest { */ private ConnectorType connectorType; + /** + * Rate limit in milliseconds between requests (0 = no limit) + */ + private long rateLimit; + public PutLlmJudgmentRequest( @NonNull JudgmentType type, @NonNull String name, @@ -77,7 +82,8 @@ public PutLlmJudgmentRequest( String promptTemplate, LLMJudgmentRatingType llmJudgmentRatingType, boolean overwriteCache, - ConnectorType connectorType + ConnectorType connectorType, + long rateLimit ) { super(type, name, description); this.modelId = modelId; @@ -91,6 +97,7 @@ public PutLlmJudgmentRequest( this.llmJudgmentRatingType = llmJudgmentRatingType; this.overwriteCache = overwriteCache; this.connectorType = connectorType != null ? connectorType : ConnectorType.OPENAI; // default to OpenAI + this.rateLimit = Math.max(0, rateLimit); // ensure non-negative } public PutLlmJudgmentRequest(StreamInput in) throws IOException { @@ -107,6 +114,8 @@ public PutLlmJudgmentRequest(StreamInput in) throws IOException { this.overwriteCache = Boolean.TRUE.equals(in.readOptionalBoolean()); String connectorTypeStr = in.readOptionalString(); this.connectorType = connectorTypeStr != null ? ConnectorType.valueOf(connectorTypeStr) : ConnectorType.OPENAI; + Long rateLimitValue = in.readOptionalLong(); + this.rateLimit = rateLimitValue != null ? rateLimitValue : 0L; // default to 0 (no limit) if not provided } @Override @@ -123,6 +132,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(llmJudgmentRatingType); out.writeOptionalBoolean(overwriteCache); out.writeOptionalString(connectorType != null ? 
connectorType.name() : null); + out.writeOptionalLong(rateLimit); } public String getModelId() { @@ -169,4 +179,8 @@ public ConnectorType getConnectorType() { return connectorType; } + public long getRateLimit() { + return rateLimit; + } + } diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java index 8cd5f8e8..ca85b2ee 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java @@ -95,7 +95,8 @@ public void testLlmJudgmentRequestStreams() throws IOException { "test_prompt_template", LLMJudgmentRatingType.SCORE0_1, true, - ConnectorType.OPENAI + ConnectorType.OPENAI, + 1000L ); BytesStreamOutput output = new BytesStreamOutput(); @@ -133,7 +134,8 @@ public void testLlmJudgmentRequestStreamsWithNullOptionalFields() throws IOExcep null, null, false, - ConnectorType.OPENAI + ConnectorType.OPENAI, + 1000L ); BytesStreamOutput output = new BytesStreamOutput(); diff --git a/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java b/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java index 495c9322..9bbc5d3c 100644 --- a/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java +++ b/src/test/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequestTests.java @@ -34,7 +34,8 @@ public void testConnectorTypeDefaultsToOpenAI() { "{{searchText}} {{hits}}", LLMJudgmentRatingType.SCORE0_1, false, - null // connectorType is null + null, // connectorType is null + 1000L ); assertEquals(ConnectorType.OPENAI, request.getConnectorType()); @@ -55,7 +56,8 @@ public void testConnectorTypeIsPreserved() { "{{searchText}} {{hits}}", LLMJudgmentRatingType.SCORE0_1, false, - 
ConnectorType.CLAUDE + ConnectorType.CLAUDE, + 1000L ); assertEquals(ConnectorType.CLAUDE, request.getConnectorType()); @@ -77,11 +79,56 @@ public void testAllConnectorTypes() { "{{searchText}} {{hits}}", LLMJudgmentRatingType.SCORE0_1, false, - type + type, + 1000L ); assertEquals(type, request.getConnectorType()); assertNotNull(request.getConnectorType()); } } + + public void testRateLimitDefaultsToZero() { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + ConnectorType.OPENAI, + 0L // explicitly set rateLimit to 0 + ); + + assertEquals(0L, request.getRateLimit()); + } + + public void testNegativeRateLimitBecomesZero() { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test-judgment", + "Test description", + "test-model-id", + "test-queryset-id", + List.of("test-config"), + 10, + 1000, + List.of("field1"), + false, + "{{searchText}} {{hits}}", + LLMJudgmentRatingType.SCORE0_1, + false, + ConnectorType.OPENAI, + -100L // negative rateLimit + ); + + assertEquals(0L, request.getRateLimit()); // should be clamped to 0 + } } From 460151315ebe9550a21419f93fff971e90f90483 Mon Sep 17 00:00:00 2001 From: Fen Qin Date: Sat, 22 Nov 2025 00:40:11 +0000 Subject: [PATCH 4/6] add AdaptiveRateLimiter Signed-off-by: Fen Qin --- docs/llm-model/claude/connector_validate.sh | 2 +- docs/llm-model/cohere/connector_validate.sh | 2 +- docs/llm-model/openai/connector_validate.sh | 0 .../judgments/LlmJudgmentsProcessor.java | 111 ++------ .../ml/AdaptiveRateLimiter.java | 238 ++++++++++++++++++ .../searchrelevance/ml/MLAccessor.java | 177 +++++++++---- .../plugin/SearchRelevancePlugin.java | 7 + .../judgment/PutJudgmentTransportAction.java | 74 ++++-- 
.../judgment/LlmJudgmentTemplateIT.java | 41 +-- .../ml/AdaptiveRateLimiterTests.java | 79 ++++++ .../ml/MLAccessorIntegrationTests.java | 77 +++++- 11 files changed, 621 insertions(+), 187 deletions(-) mode change 100644 => 100755 docs/llm-model/claude/connector_validate.sh mode change 100644 => 100755 docs/llm-model/cohere/connector_validate.sh mode change 100644 => 100755 docs/llm-model/openai/connector_validate.sh create mode 100644 src/main/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiter.java create mode 100644 src/test/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiterTests.java diff --git a/docs/llm-model/claude/connector_validate.sh b/docs/llm-model/claude/connector_validate.sh old mode 100644 new mode 100755 index f1ac8c39..07f78ccc --- a/docs/llm-model/claude/connector_validate.sh +++ b/docs/llm-model/claude/connector_validate.sh @@ -33,7 +33,7 @@ CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/ "model": "us.anthropic.claude-3-5-haiku-20241022-v1:0" }, "client_config": { - "max_connection": 1, + "max_connection": 2, "connection_timeout": 60000, "read_timeout": 60000, "retry_backoff_millis": 3000, diff --git a/docs/llm-model/cohere/connector_validate.sh b/docs/llm-model/cohere/connector_validate.sh old mode 100644 new mode 100755 index 09685e87..3624a66c --- a/docs/llm-model/cohere/connector_validate.sh +++ b/docs/llm-model/cohere/connector_validate.sh @@ -33,7 +33,7 @@ CONNECTOR_RESPONSE=$(curl -s -X POST "${OPENSEARCH_URL}/_plugins/_ml/connectors/ "model": "cohere.command-r-v1:0" }, "client_config": { - "max_connection": 1, + "max_connection": 2, "connection_timeout": 60000, "read_timeout": 60000, "retry_backoff_millis": 3000, diff --git a/docs/llm-model/openai/connector_validate.sh b/docs/llm-model/openai/connector_validate.sh old mode 100644 new mode 100755 diff --git a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java 
b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java index b8395073..ae30dc0f 100644 --- a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java +++ b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java @@ -8,6 +8,7 @@ package org.opensearch.searchrelevance.judgments; import static org.opensearch.searchrelevance.common.MLConstants.CONNECTOR_TYPE; +import static org.opensearch.searchrelevance.common.MLConstants.DEFAULT_PROMPT_TEMPLATE; import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; @@ -130,110 +131,28 @@ private void generateJudgmentRatingInternal(Map metadata, Action } private LlmJudgmentContext buildContextFromMetadata(Map metadata, List searchConfigurations) { - // Check if we have a pre-built context (new approach) - Object contextObj = metadata.get("llmJudgmentContext"); - if (contextObj != null) { - if (contextObj instanceof LlmJudgmentContext) { - // Direct object case (in-memory) - LlmJudgmentContext baseContext = (LlmJudgmentContext) contextObj; - return LlmJudgmentContext.builder() - .modelId(baseContext.getModelId()) - .size(baseContext.getSize()) - .tokenLimit(baseContext.getTokenLimit()) - .contextFields(baseContext.getContextFields()) - .searchConfigurations(searchConfigurations) - .ignoreFailure(baseContext.isIgnoreFailure()) - .promptTemplate(baseContext.getPromptTemplate()) - .ratingType(baseContext.getRatingType()) - .overwriteCache(baseContext.isOverwriteCache()) - .connectorType(baseContext.getConnectorType()) - .rateLimit(baseContext.getRateLimit()) - .build(); - } else if (contextObj instanceof Map) { - // Deserialized from OpenSearch as Map - return buildContextFromMap((Map) contextObj, searchConfigurations); - } - } - - // Fallback to legacy metadata 
parsing for backward compatibility - return buildContextFromLegacyMetadata(metadata, searchConfigurations); - } - - private LlmJudgmentContext buildContextFromMap(Map contextMap, List searchConfigurations) { - String modelId = (String) contextMap.get(LlmJudgmentContext.MODEL_ID); - Integer size = (Integer) contextMap.get(LlmJudgmentContext.SIZE); - Integer tokenLimit = (Integer) contextMap.get(LlmJudgmentContext.TOKEN_LIMIT); - List contextFields = (List) contextMap.get(LlmJudgmentContext.CONTEXT_FIELDS); - Boolean ignoreFailure = (Boolean) contextMap.get(LlmJudgmentContext.IGNORE_FAILURE); - String promptTemplate = (String) contextMap.get(LlmJudgmentContext.PROMPT_TEMPLATE); - Boolean overwriteCache = (Boolean) contextMap.get(LlmJudgmentContext.OVERWRITE_CACHE); - - Long rateLimit = 1000L; - Object rateLimitObj = contextMap.get(LlmJudgmentContext.RATE_LIMIT); - if (rateLimitObj instanceof Number) { - rateLimit = ((Number) rateLimitObj).longValue(); - } - - // Parse enum values with proper defaults - LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; - String ratingTypeStr = (String) contextMap.get(LlmJudgmentContext.RATING_TYPE); - if (ratingTypeStr != null && !ratingTypeStr.isEmpty()) { - try { - ratingType = LLMJudgmentRatingType.valueOf(ratingTypeStr); - } catch (IllegalArgumentException e) { - log.warn("Invalid ratingType '{}' in context, defaulting to SCORE0_1", ratingTypeStr); - } - } - log.debug("Using ratingType: {} for judgment processing", ratingType); - - ConnectorType connectorType = ConnectorType.OPENAI; - String connectorTypeStr = (String) contextMap.get(LlmJudgmentContext.CONNECTOR_TYPE); - if (connectorTypeStr != null && !connectorTypeStr.isEmpty()) { - try { - connectorType = ConnectorType.valueOf(connectorTypeStr); - } catch (IllegalArgumentException e) { - log.warn("Invalid connectorType '{}' in context, defaulting to OPENAI", connectorTypeStr); - } - } - - return LlmJudgmentContext.builder() - .modelId(modelId) - .size(size != 
null ? size : 5) - .tokenLimit(tokenLimit != null ? tokenLimit : 1000) - .contextFields(contextFields != null ? contextFields : new ArrayList<>()) - .searchConfigurations(searchConfigurations) - .ignoreFailure(ignoreFailure != null ? ignoreFailure : false) - .promptTemplate( - promptTemplate != null - ? promptTemplate - : "Rate the relevance of the search results to the query. SearchText: {{searchText}}; Results: {{hits}}" - ) - .ratingType(ratingType) - .overwriteCache(overwriteCache != null ? overwriteCache : false) - .connectorType(connectorType) - .rateLimit(rateLimit) - .build(); - } - - private LlmJudgmentContext buildContextFromLegacyMetadata( - Map metadata, - List searchConfigurations - ) { String modelId = (String) metadata.get("modelId"); - int size = (int) metadata.get("size"); - int tokenLimit = (int) metadata.get("tokenLimit"); + Integer sizeObj = (Integer) metadata.get("size"); + Integer tokenLimitObj = (Integer) metadata.get("tokenLimit"); List contextFields = (List) metadata.get("contextFields"); - boolean ignoreFailure = (boolean) metadata.get("ignoreFailure"); + Boolean ignoreFailureObj = (Boolean) metadata.get("ignoreFailure"); String promptTemplate = (String) metadata.get(PROMPT_TEMPLATE); LLMJudgmentRatingType ratingType = (LLMJudgmentRatingType) metadata.get(LLM_JUDGMENT_RATING_TYPE); + Boolean overwriteCacheObj = (Boolean) metadata.get(OVERWRITE_CACHE); + String connectorTypeStr = (String) metadata.get(CONNECTOR_TYPE); + + // Apply defaults for null values + int size = sizeObj != null ? sizeObj : 5; + int tokenLimit = tokenLimitObj != null ? tokenLimitObj : 1000; + boolean ignoreFailure = ignoreFailureObj != null ? ignoreFailureObj : false; + boolean overwriteCache = overwriteCacheObj != null ? 
overwriteCacheObj : false; + if (ratingType == null) { ratingType = LLMJudgmentRatingType.SCORE0_1; log.debug("No ratingType provided, defaulting to SCORE0_1"); } - boolean overwriteCache = (boolean) metadata.get(OVERWRITE_CACHE); ConnectorType connectorType = ConnectorType.OPENAI; - String connectorTypeStr = (String) metadata.get(CONNECTOR_TYPE); if (connectorTypeStr != null) { try { connectorType = ConnectorType.valueOf(connectorTypeStr.toUpperCase(Locale.ROOT)); @@ -248,10 +167,10 @@ private LlmJudgmentContext buildContextFromLegacyMetadata( .modelId(modelId) .size(size) .tokenLimit(tokenLimit) - .contextFields(contextFields) + .contextFields(contextFields != null ? contextFields : new ArrayList<>()) .searchConfigurations(searchConfigurations) .ignoreFailure(ignoreFailure) - .promptTemplate(promptTemplate) + .promptTemplate(promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE) .ratingType(ratingType) .overwriteCache(overwriteCache) .connectorType(connectorType) diff --git a/src/main/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiter.java b/src/main/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiter.java new file mode 100644 index 00000000..c5bf6dd5 --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiter.java @@ -0,0 +1,238 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml; + +import java.util.Locale; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; + +/** + * Adaptive rate limiter that learns optimal rates per model and handles circuit breaking + */ +public class AdaptiveRateLimiter { + private static final Logger log = LogManager.getLogger(AdaptiveRateLimiter.class); + + private final ConcurrentMap rateLimitStates = new ConcurrentHashMap<>(); + private volatile ScheduledExecutorService cleanupScheduler; + private final Object schedulerLock = new Object(); + + public AdaptiveRateLimiter() { + // Lazy initialization - don't create threads until actually needed + } + + private ScheduledExecutorService getOrCreateScheduler() { + if (cleanupScheduler == null) { + synchronized (schedulerLock) { + if (cleanupScheduler == null) { + cleanupScheduler = Executors.newSingleThreadScheduledExecutor(new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r, "adaptive-rate-limiter-cleanup"); + t.setDaemon(true); + return t; + } + }); + // Schedule cleanup every hour + cleanupScheduler.scheduleAtFixedRate(this::cleanupOldEntries, 1, 1, TimeUnit.HOURS); + } + } + } + return cleanupScheduler; + } + + public CompletableFuture applyRateLimit(String modelId, ConnectorType connectorType, long userRateLimit) { + String key = getKey(modelId, connectorType); + RateLimitState state = rateLimitStates.computeIfAbsent(key, k -> new RateLimitState(userRateLimit)); + + // Circuit breaker: Stop trying if model seems dead + if (state.shouldStopTrying()) { + 
return CompletableFuture.failedFuture(new RuntimeException("Model appears to be unavailable: " + modelId)); + } + + long delayMs = state.calculateDelay(); + + if (delayMs <= 0) { + return CompletableFuture.completedFuture(null); + } + + log.debug("Applying rate limit for {}: {}ms delay", key, delayMs); + + // Non-blocking delay using our managed executor + CompletableFuture future = new CompletableFuture<>(); + getOrCreateScheduler().schedule(() -> future.complete(null), delayMs, TimeUnit.MILLISECONDS); + return future; + } + + public void recordResult(String modelId, ConnectorType connectorType, boolean success, Throwable error) { + String key = getKey(modelId, connectorType); + RateLimitState state = rateLimitStates.get(key); + + if (state != null) { + if (success) { + state.onSuccess(); + } else if (isRateLimitError(error)) { + state.onRateLimit(); + } else if (isModelUnavailableError(error)) { + state.onModelUnavailable(); + } else { + state.onOtherError(); + } + } + } + + private String getKey(String modelId, ConnectorType connectorType) { + return modelId + ":" + connectorType.getValue(); + } + + private boolean isRateLimitError(Throwable error) { + if (error == null) return false; + String message = error.getMessage().toLowerCase(Locale.ROOT); + return message.contains("rate limit") + || message.contains("throttling") + || message.contains("too many requests") + || message.contains("high request rate") + || message.contains("acquire operation took longer") + || message.contains("connection from the pool"); + } + + private boolean isModelUnavailableError(Throwable error) { + if (error == null) return false; + String message = error.getMessage().toLowerCase(Locale.ROOT); + return message.contains("model not found") || message.contains("service unavailable") || message.contains("internal server error"); + } + + private void cleanupOldEntries() { + long cutoff = System.currentTimeMillis() - 3600_000; // 1 hour + java.util.concurrent.atomic.AtomicInteger removed 
= new java.util.concurrent.atomic.AtomicInteger(0); + + rateLimitStates.entrySet().removeIf(entry -> { + if (entry.getValue().lastRequestTime < cutoff) { + removed.incrementAndGet(); + return true; + } + return false; + }); + + if (removed.get() > 0) { + log.debug("Cleaned up {} old rate limit entries", removed.get()); + } + } + + public void scheduleTask(Runnable task, long delayMs) { + getOrCreateScheduler().schedule(task, delayMs, TimeUnit.MILLISECONDS); + } + + public void shutdown() { + synchronized (schedulerLock) { + if (cleanupScheduler != null) { + cleanupScheduler.shutdown(); + try { + if (!cleanupScheduler.awaitTermination(5, TimeUnit.SECONDS)) { + cleanupScheduler.shutdownNow(); + } + } catch (InterruptedException e) { + cleanupScheduler.shutdownNow(); + Thread.currentThread().interrupt(); + } + cleanupScheduler = null; + } + } + } + + private static class RateLimitState { + private volatile long currentDelayMs; + private volatile long lastRequestTime; + private volatile long lastSuccessTime; + private volatile int consecutiveSuccesses; + private volatile int consecutiveFailures; + private final long initialDelayMs; + + // Conservative parameters + private static final double BACKOFF_MULTIPLIER = 2.0; + private static final double RECOVERY_FACTOR = 0.9; + private static final int SUCCESSES_BEFORE_RECOVERY = 5; + private static final long MAX_DELAY_MS = 300_000; // 5 minutes + private static final int MAX_CONSECUTIVE_FAILURES = 10; + private static final long MODEL_DEAD_THRESHOLD_MS = 1800_000; // 30 minutes + private static final long CIRCUIT_OPEN_DURATION_MS = 300_000; // 5 minutes + + public RateLimitState(long initialDelayMs) { + this.initialDelayMs = Math.max(0, initialDelayMs); + this.currentDelayMs = this.initialDelayMs; + this.lastSuccessTime = System.currentTimeMillis(); + this.lastRequestTime = System.currentTimeMillis(); + } + + public long calculateDelay() { + lastRequestTime = System.currentTimeMillis(); + + // Circuit breaker: Longer delay 
if too many failures + if (isCircuitOpen()) { + return CIRCUIT_OPEN_DURATION_MS; + } + + return currentDelayMs; + } + + public boolean isCircuitOpen() { + return consecutiveFailures >= MAX_CONSECUTIVE_FAILURES; + } + + public boolean shouldStopTrying() { + long timeSinceLastSuccess = System.currentTimeMillis() - lastSuccessTime; + return timeSinceLastSuccess > MODEL_DEAD_THRESHOLD_MS; + } + + public void onSuccess() { + consecutiveFailures = 0; + consecutiveSuccesses++; + lastSuccessTime = System.currentTimeMillis(); + + // Conservative recovery - only after many successes + if (consecutiveSuccesses >= SUCCESSES_BEFORE_RECOVERY) { + currentDelayMs = Math.max( + initialDelayMs, // Never go below user's initial setting + (long) (currentDelayMs * RECOVERY_FACTOR) + ); + consecutiveSuccesses = 0; + log.debug("Rate limit recovered to {}ms", currentDelayMs); + } + } + + public void onRateLimit() { + consecutiveSuccesses = 0; + consecutiveFailures++; + + // Aggressive backoff on rate limit + long oldDelay = currentDelayMs; + currentDelayMs = Math.min(MAX_DELAY_MS, (long) (currentDelayMs * BACKOFF_MULTIPLIER)); + + log.debug("Rate limit hit, increased delay from {}ms to {}ms", oldDelay, currentDelayMs); + } + + public void onModelUnavailable() { + consecutiveSuccesses = 0; + consecutiveFailures++; + log.debug("Model unavailable error, consecutive failures: {}", consecutiveFailures); + } + + public void onOtherError() { + // Don't change rate limiting for non-rate-limit errors + consecutiveSuccesses = 0; + } + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java b/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java index ebea1e0c..5e3b740e 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java @@ -10,8 +10,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import 
java.util.concurrent.TimeUnit; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -31,6 +29,7 @@ public class MLAccessor { private final MachineLearningNodeClient mlClient; private final MLInputOutputTransformer transformer; + private final AdaptiveRateLimiter rateLimiter; private static final int MAX_RETRY_NUMBER = 3; private static final long RETRY_DELAY_MS = 1000; @@ -38,6 +37,16 @@ public class MLAccessor { public MLAccessor(MachineLearningNodeClient mlClient) { this.mlClient = mlClient; this.transformer = new MLInputOutputTransformer(); + this.rateLimiter = new AdaptiveRateLimiter(); + } + + /** + * Shutdown the MLAccessor and clean up resources + */ + public void shutdown() { + if (rateLimiter != null) { + rateLimiter.shutdown(); + } } public void predict( @@ -49,15 +58,17 @@ public void predict( String promptTemplate, LLMJudgmentRatingType ratingType, ConnectorType connectorType, + long rateLimit, ActionListener progressListener ) { log.debug( - "DEBUG: MLAccessor.predict called with modelId: {}, searchText: {}, hits count: {}, ratingType: {}, connectorType: {}", + "DEBUG: MLAccessor.predict called with modelId: {}, searchText: {}, hits count: {}, ratingType: {}, connectorType: {}, rateLimit: {}ms", modelId, searchText, hits.size(), ratingType, - connectorType + connectorType, + rateLimit ); // Create transformer with appropriate connector @@ -77,12 +88,19 @@ public void predict( ChunkProcessingContext context = new ChunkProcessingContext(mlInputs.size(), progressListener); for (int i = 0; i < mlInputs.size(); i++) { - processChunk(modelId, mlInputs.get(i), i, context); + processChunk(modelId, mlInputs.get(i), i, connectorType, rateLimit, context); } } - private void processChunk(String modelId, MLInput mlInput, int chunkIndex, ChunkProcessingContext context) { - processChunkWithFallback(modelId, mlInput, chunkIndex, false, context); + private void processChunk( + String modelId, + MLInput mlInput, 
+ int chunkIndex, + ConnectorType connectorType, + long rateLimit, + ChunkProcessingContext context + ) { + processChunkWithFallback(modelId, mlInput, chunkIndex, false, connectorType, rateLimit, context); } private void processChunkWithFallback( @@ -90,28 +108,43 @@ private void processChunkWithFallback( MLInput mlInput, int chunkIndex, boolean triedWithoutResponseFormat, + ConnectorType connectorType, + long rateLimit, ChunkProcessingContext context ) { - predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, 0, triedWithoutResponseFormat, ActionListener.wrap(response -> { - log.info("Chunk {} processed successfully", chunkIndex); - String processedResponse = cleanResponse(response); - - // Check if parsing failed (empty ratings array) and we haven't tried without response_format yet - if ("[]".equals(processedResponse) && !triedWithoutResponseFormat) { - log.warn( - "Chunk {} returned empty ratings with response_format. Retrying without response_format for GPT-3.5 compatibility...", - chunkIndex - ); - // Create new MLInput without response_format and retry - MLInput mlInputWithoutFormat = recreateMLInputWithoutResponseFormat(mlInput); - scheduleRetry(() -> processChunkWithFallback(modelId, mlInputWithoutFormat, chunkIndex, true, context), RETRY_DELAY_MS); - } else { - context.handleSuccess(chunkIndex, processedResponse); - } - }, e -> { - log.error("Chunk {} failed after all retries", chunkIndex, e); - context.handleFailure(chunkIndex, e); - })); + predictSingleChunkWithRetry( + modelId, + mlInput, + chunkIndex, + 0, + triedWithoutResponseFormat, + connectorType, + rateLimit, + ActionListener.wrap(response -> { + log.info("Chunk {} processed successfully", chunkIndex); + String processedResponse = cleanResponse(response); + + // Check if parsing failed (empty ratings array) and we haven't tried without response_format yet + // Only apply this retry logic for OpenAI connectors + if ("[]".equals(processedResponse) && !triedWithoutResponseFormat && 
connectorType == ConnectorType.OPENAI) { + log.warn( + "Chunk {} returned empty ratings with response_format. Retrying without response_format for GPT-3.5 compatibility...", + chunkIndex + ); + // Create new MLInput without response_format and retry + MLInput mlInputWithoutFormat = recreateMLInputWithoutResponseFormat(mlInput); + scheduleRetry( + () -> processChunkWithFallback(modelId, mlInputWithoutFormat, chunkIndex, true, connectorType, rateLimit, context), + RETRY_DELAY_MS + ); + } else { + context.handleSuccess(chunkIndex, processedResponse); + } + }, e -> { + log.error("Chunk {} failed after all retries", chunkIndex, e); + context.handleFailure(chunkIndex, e); + }) + ); } private String cleanResponse(String response) { @@ -133,9 +166,11 @@ private void predictSingleChunkWithRetry( int chunkIndex, int retryCount, boolean triedWithoutResponseFormat, + ConnectorType connectorType, + long rateLimit, ActionListener chunkListener ) { - predictSingleChunk(modelId, mlInput, new ActionListener() { + predictSingleChunk(modelId, mlInput, connectorType, rateLimit, new ActionListener() { @Override public void onResponse(String response) { log.debug( @@ -156,8 +191,8 @@ public void onFailure(Exception e) { triedWithoutResponseFormat, retryCount ); - // If we haven't tried without response_format yet, try that first before regular retries - if (!triedWithoutResponseFormat) { + // Only try response_format fallback for OpenAI connectors + if (!triedWithoutResponseFormat && connectorType == ConnectorType.OPENAI) { log.warn( "Chunk {} failed with response_format. 
Retrying without response_format for GPT-3.5 compatibility...", chunkIndex @@ -169,7 +204,16 @@ public void onFailure(Exception e) { long delay = RETRY_DELAY_MS; scheduleRetry( - () -> predictSingleChunkWithRetry(modelId, mlInputWithoutFormat, chunkIndex, 0, true, chunkListener), + () -> predictSingleChunkWithRetry( + modelId, + mlInputWithoutFormat, + chunkIndex, + 0, + true, + connectorType, + rateLimit, + chunkListener + ), delay ); } else if (retryCount < MAX_RETRY_NUMBER) { @@ -177,7 +221,16 @@ public void onFailure(Exception e) { long delay = RETRY_DELAY_MS * (long) Math.pow(2, retryCount); scheduleRetry( - () -> predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, retryCount + 1, true, chunkListener), + () -> predictSingleChunkWithRetry( + modelId, + mlInput, + chunkIndex, + retryCount + 1, + true, + connectorType, + rateLimit, + chunkListener + ), delay ); } else { @@ -207,25 +260,55 @@ private MLInput recreateMLInputWithoutResponseFormat(MLInput originalInput) { } private void scheduleRetry(Runnable runnable, long delayMs) { - CompletableFuture.delayedExecutor(delayMs, TimeUnit.MILLISECONDS).execute(runnable); + // Use the rate limiter's managed scheduler for retries to avoid thread leaks + rateLimiter.scheduleTask(runnable, delayMs); } - public void predictSingleChunk(String modelId, MLInput mlInput, ActionListener listener) { - log.debug("DEBUG: predictSingleChunk called with modelId: {}", modelId); - RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); - Map params = dataset.getParameters(); + public void predictSingleChunk( + String modelId, + MLInput mlInput, + ConnectorType connectorType, + long rateLimit, + ActionListener listener + ) { log.debug( - "DEBUG: MLInput parameters - has response_format: {}, has messages: {}", - params.containsKey("response_format"), - params.containsKey("messages") + "DEBUG: predictSingleChunk called with modelId: {}, connectorType: {}, rateLimit: {}ms", + modelId, + 
connectorType, + rateLimit ); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - log.debug("DEBUG: ML prediction succeeded, extracting response content"); - listener.onResponse(transformer.extractResponseContent(mlOutput)); - }, e -> { - log.debug("DEBUG: ML prediction failed with error: {}", e.getMessage()); - listener.onFailure(e); - })); + + // Apply rate limiting before making the prediction + rateLimiter.applyRateLimit(modelId, connectorType, rateLimit).whenComplete((result, rateLimitError) -> { + if (rateLimitError != null) { + log.error("Rate limiting failed for modelId: {}", modelId, rateLimitError); + listener.onFailure(new Exception(rateLimitError)); + return; + } + + // Create connector-specific transformer + MLInputOutputTransformer connectorTransformer = new MLInputOutputTransformer(LLMConnectorFactory.create(connectorType)); + + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map params = dataset.getParameters(); + log.debug( + "DEBUG: MLInput parameters - has response_format: {}, has messages: {}", + params.containsKey("response_format"), + params.containsKey("messages") + ); + + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + log.debug("DEBUG: ML prediction succeeded, extracting response content"); + // Record successful result for rate limiter learning + rateLimiter.recordResult(modelId, connectorType, true, null); + listener.onResponse(connectorTransformer.extractResponseContent(mlOutput)); + }, e -> { + log.debug("DEBUG: ML prediction failed with error: {}", e.getMessage()); + // Record failed result for rate limiter learning + rateLimiter.recordResult(modelId, connectorType, false, e); + listener.onFailure(e); + })); + }); } } diff --git a/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java b/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java index 35202ed0..6d5cdc4b 100644 --- 
a/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java +++ b/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java @@ -393,4 +393,11 @@ public List> getSettings() { public List> getExecutorBuilders(Settings settings) { return List.of(SearchRelevanceExecutor.getExecutorBuilder(settings)); } + + @Override + public void close() throws IOException { + if (mlAccessor != null) { + mlAccessor.shutdown(); + } + } } diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java index 8c8b7bbb..f984549d 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java @@ -7,6 +7,12 @@ */ package org.opensearch.searchrelevance.transport.judgment; +import static org.opensearch.searchrelevance.common.MLConstants.CONNECTOR_TYPE; +import static org.opensearch.searchrelevance.common.MLConstants.DEFAULT_PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; +import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; +import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.common.MLConstants.RATE_LIMIT; import static org.opensearch.searchrelevance.ubi.UbiValidator.checkUbiIndicesExist; import java.util.ArrayList; @@ -28,9 +34,10 @@ import org.opensearch.searchrelevance.exception.SearchRelevanceException; import org.opensearch.searchrelevance.judgments.BaseJudgmentsProcessor; import org.opensearch.searchrelevance.judgments.JudgmentsProcessorFactory; -import org.opensearch.searchrelevance.judgments.LlmJudgmentContext; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import 
org.opensearch.searchrelevance.model.AsyncStatus; import org.opensearch.searchrelevance.model.Judgment; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.utils.TimeUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -42,6 +49,20 @@ public class PutJudgmentTransportAction extends HandledTransportAction buildMetadata(PutJudgmentRequest request) { Map metadata = new HashMap<>(); switch (request.getType()) { case LLM_JUDGMENT -> { - // Use structured context for complex LLM parameters PutLlmJudgmentRequest llmRequest = (PutLlmJudgmentRequest) request; - LlmJudgmentContext context = LlmJudgmentContext.builder() - .modelId(llmRequest.getModelId()) - .connectorType(llmRequest.getConnectorType()) - .rateLimit(llmRequest.getRateLimit()) - .size(llmRequest.getSize()) - .tokenLimit(llmRequest.getTokenLimit()) - .contextFields(llmRequest.getContextFields()) - .ignoreFailure(llmRequest.isIgnoreFailure()) - .promptTemplate(llmRequest.getPromptTemplate()) - .ratingType(llmRequest.getLlmJudgmentRatingType()) - .overwriteCache(llmRequest.isOverwriteCache()) - .build(); - - metadata.put("llmJudgmentContext", context); - metadata.put("querySetId", llmRequest.getQuerySetId()); - metadata.put("searchConfigurationList", llmRequest.getSearchConfigurationList()); + + // Store flat metadata fields for compatibility + metadata.put(QUERY_SET_ID, llmRequest.getQuerySetId()); + metadata.put(SEARCH_CONFIGURATION_LIST, llmRequest.getSearchConfigurationList()); + metadata.put(MODEL_ID, llmRequest.getModelId()); + metadata.put(SIZE, llmRequest.getSize()); + metadata.put(TOKEN_LIMIT, llmRequest.getTokenLimit()); + metadata.put(CONTEXT_FIELDS, llmRequest.getContextFields()); + metadata.put(IGNORE_FAILURE, llmRequest.isIgnoreFailure()); + metadata.put( + PROMPT_TEMPLATE, + llmRequest.getPromptTemplate() != null ? 
llmRequest.getPromptTemplate() : DEFAULT_PROMPT_TEMPLATE + ); + metadata.put( + LLM_JUDGMENT_RATING_TYPE, + llmRequest.getLlmJudgmentRatingType() != null ? llmRequest.getLlmJudgmentRatingType() : LLMJudgmentRatingType.SCORE0_1 + ); + metadata.put(OVERWRITE_CACHE, llmRequest.isOverwriteCache()); + metadata.put( + CONNECTOR_TYPE, + llmRequest.getConnectorType() != null ? llmRequest.getConnectorType().name() : ConnectorType.OPENAI.name() + ); + metadata.put(RATE_LIMIT, llmRequest.getRateLimit()); } case UBI_JUDGMENT -> { if (!checkUbiIndicesExist(clusterService)) { throw new SearchRelevanceException("UBI is not initialized", RestStatus.CONFLICT); } - ; PutUbiJudgmentRequest ubiRequest = (PutUbiJudgmentRequest) request; - metadata.put("clickModel", ubiRequest.getClickModel()); - metadata.put("maxRank", ubiRequest.getMaxRank()); - metadata.put("startDate", ubiRequest.getStartDate()); - metadata.put("endDate", ubiRequest.getEndDate()); + metadata.put(CLICK_MODEL, ubiRequest.getClickModel()); + metadata.put(MAX_RANK, ubiRequest.getMaxRank()); + metadata.put(START_DATE, ubiRequest.getStartDate()); + metadata.put(END_DATE, ubiRequest.getEndDate()); } case IMPORT_JUDGMENT -> { PutImportJudgmentRequest importRequest = (PutImportJudgmentRequest) request; - metadata.put("judgmentRatings", importRequest.getJudgmentRatings()); + metadata.put(JUDGMENT_RATINGS, importRequest.getJudgmentRatings()); } } return metadata; diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java index 91cca1c3..57cc92bd 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java @@ -118,14 +118,16 @@ public void testLlmJudgmentWithPromptTemplate_thenSuccessful() { assertEquals("LLM_JUDGMENT", source.get("type")); 
assertNotNull(source.get("status")); // Should be COMPLETED or IN_PROGRESS - // Verify metadata contains new fields + // Verify metadata contains structured context Map metadata = (Map) source.get("metadata"); assertNotNull(metadata); - assertNotNull(metadata.get("promptTemplate")); - assertTrue(((String) metadata.get("promptTemplate")).contains("{{queryText}}")); - assertNotNull(metadata.get("llmJudgmentRatingType")); - assertEquals("SCORE0_1", metadata.get("llmJudgmentRatingType")); - assertNotNull(metadata.get("overwriteCache")); + Map context = (Map) metadata.get("llmJudgmentContext"); + assertNotNull(context); + assertNotNull(context.get("promptTemplate")); + assertTrue(((String) context.get("promptTemplate")).contains("{{queryText}}")); + assertNotNull(context.get("ratingType")); + assertEquals("SCORE0_1", context.get("ratingType")); + assertNotNull(context.get("overwriteCache")); // Verify judgmentRatings format List> judgmentRatings = (List>) source.get("judgmentRatings"); @@ -202,7 +204,8 @@ public void testLlmJudgmentWithDifferentRatingTypes_thenSuccessful() { Map judgment01Doc = entityAsMap(getJudgment01Response); Map source01 = (Map) judgment01Doc.get("_source"); Map metadata01 = (Map) source01.get("metadata"); - assertEquals("SCORE0_1", metadata01.get("llmJudgmentRatingType")); + Map context01 = (Map) metadata01.get("llmJudgmentContext"); + assertEquals("SCORE0_1", context01.get("ratingType")); // Test RELEVANT_IRRELEVANT rating type String binaryBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentBinary.json").toURI())); @@ -234,7 +237,8 @@ public void testLlmJudgmentWithDifferentRatingTypes_thenSuccessful() { Map judgmentBinaryDoc = entityAsMap(getJudgmentBinaryResponse); Map sourceBinary = (Map) judgmentBinaryDoc.get("_source"); Map metadataBinary = (Map) sourceBinary.get("metadata"); - assertEquals("RELEVANT_IRRELEVANT", metadataBinary.get("llmJudgmentRatingType")); + Map contextBinary = (Map) 
metadataBinary.get("llmJudgmentContext"); + assertEquals("RELEVANT_IRRELEVANT", contextBinary.get("ratingType")); } @SneakyThrows @@ -298,7 +302,8 @@ public void testLlmJudgmentWithOverwriteCache_thenSuccessful() { Map judgmentTrueDoc = entityAsMap(getJudgmentTrueResponse); Map sourceTrue = (Map) judgmentTrueDoc.get("_source"); Map metadataTrue = (Map) sourceTrue.get("metadata"); - assertEquals(true, metadataTrue.get("overwriteCache")); + Map contextTrue = (Map) metadataTrue.get("llmJudgmentContext"); + assertEquals(true, contextTrue.get("overwriteCache")); // Test with overwriteCache = false String overwriteFalseBody = Files.readString( @@ -332,7 +337,8 @@ public void testLlmJudgmentWithOverwriteCache_thenSuccessful() { Map judgmentFalseDoc = entityAsMap(getJudgmentFalseResponse); Map sourceFalse = (Map) judgmentFalseDoc.get("_source"); Map metadataFalse = (Map) sourceFalse.get("metadata"); - assertEquals(false, metadataFalse.get("overwriteCache")); + Map contextFalse = (Map) metadataFalse.get("llmJudgmentContext"); + assertEquals(false, contextFalse.get("overwriteCache")); } @SneakyThrows @@ -394,18 +400,21 @@ public void testLlmJudgmentWithoutOptionalFields_thenSuccessfulWithDefaults() { Map judgmentDoc = entityAsMap(getJudgmentResponse); Map source = (Map) judgmentDoc.get("_source"); Map metadata = (Map) source.get("metadata"); + Map context = (Map) metadata.get("llmJudgmentContext"); // promptTemplate should have the default value when not provided - Object promptTemplate = metadata.get("promptTemplate"); + Object promptTemplate = context.get("promptTemplate"); assertNotNull("promptTemplate should not be null when not provided", promptTemplate); assertEquals("promptTemplate should have default value", DEFAULT_PROMPT_TEMPLATE, promptTemplate); - // llmJudgmentRatingType should have a default or be null - Object ratingType = metadata.get("llmJudgmentRatingType"); - // Either null or has a default value + // ratingType should have a default value + Object 
ratingType = context.get("ratingType"); + assertNotNull("ratingType should not be null", ratingType); + assertEquals("ratingType should have default value", "SCORE0_1", ratingType); // overwriteCache should default to false - Object overwriteCache = metadata.get("overwriteCache"); - assertTrue(overwriteCache == null || overwriteCache.equals(false)); + Object overwriteCache = context.get("overwriteCache"); + assertNotNull("overwriteCache should not be null", overwriteCache); + assertEquals("overwriteCache should default to false", false, overwriteCache); } } diff --git a/src/test/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiterTests.java b/src/test/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiterTests.java new file mode 100644 index 00000000..1a4de45f --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/AdaptiveRateLimiterTests.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.ml; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.opensearch.searchrelevance.ml.connector.ConnectorType; +import org.opensearch.test.OpenSearchTestCase; + +public class AdaptiveRateLimiterTests extends OpenSearchTestCase { + + private AdaptiveRateLimiter rateLimiter; + + @Override + public void setUp() throws Exception { + super.setUp(); + rateLimiter = new AdaptiveRateLimiter(); + } + + @Override + public void tearDown() throws Exception { + if (rateLimiter != null) { + rateLimiter.shutdown(); + } + super.tearDown(); + } + + public void testNoRateLimitWhenZero() throws Exception { + long startTime = System.currentTimeMillis(); + + CompletableFuture future = rateLimiter.applyRateLimit("test-model", ConnectorType.OPENAI, 0L); + future.get(1, TimeUnit.SECONDS); + + long elapsed = System.currentTimeMillis() - startTime; + assertTrue("Should complete immediately with 0 rate limit", elapsed < 100); + } + + public void testRateLimitApplied() throws Exception { + // Test that non-zero rate limit creates a delay future (don't wait for it to avoid thread leaks) + CompletableFuture future = rateLimiter.applyRateLimit("test-model", ConnectorType.CLAUDE, 100L); + assertNotNull("Should return a future for rate limiting", future); + assertFalse("Future should not be completed immediately for non-zero rate limit", future.isDone()); + } + + public void testSuccessRecording() { + // Should not throw exception + rateLimiter.recordResult("test-model", ConnectorType.OPENAI, true, null); + } + + public void testRateLimitErrorRecording() { + Exception rateLimitError = new RuntimeException("Rate limit exceeded"); + rateLimiter.recordResult("test-model", ConnectorType.CLAUDE, false, rateLimitError); + } + + public void testModelUnavailableError() { + Exception unavailableError = new RuntimeException("Model not found"); + rateLimiter.recordResult("test-model", ConnectorType.COHERE, false, 
unavailableError); + } + + public void testCircuitBreakerAfterManyFailures() { + String modelId = "failing-model"; + ConnectorType connectorType = ConnectorType.CLAUDE; + + // Record many failures to trigger circuit breaker + for (int i = 0; i < 15; i++) { + rateLimiter.recordResult(modelId, connectorType, false, new RuntimeException("Service unavailable")); + } + + // Test that circuit breaker creates a future (don't wait to avoid thread leaks) + CompletableFuture future = rateLimiter.applyRateLimit(modelId, connectorType, 100L); + assertNotNull("Should return a future even with circuit breaker", future); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java index fac73814..a31ff7cd 100644 --- a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java @@ -49,6 +49,16 @@ */ public class MLAccessorIntegrationTests extends OpenSearchTestCase { + private MLAccessor mlAccessor; + + @Override + public void tearDown() throws Exception { + if (mlAccessor != null) { + mlAccessor.shutdown(); + } + super.tearDown(); + } + /** * Note: GPT-3.5 fallback testing is documented in TESTING_GPT35_FALLBACK.md as "Scenario 2" * This scenario requires triggering scheduleRetry which creates CompletableFuture threads that leak. 
@@ -63,7 +73,7 @@ public class MLAccessorIntegrationTests extends OpenSearchTestCase { */ public void testFirstAttemptSuccess_WhenModelSupportsResponseFormat() throws Exception { MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); - MLAccessor mlAccessor = new MLAccessor(mlClient); + mlAccessor = new MLAccessor(mlClient); AtomicInteger attemptCount = new AtomicInteger(0); CountDownLatch latch = new CountDownLatch(1); @@ -101,7 +111,8 @@ public void testFirstAttemptSuccess_WhenModelSupportsResponseFormat() throws Exc "Test prompt", LLMJudgmentRatingType.SCORE0_1, ConnectorType.OPENAI, - ActionListener.wrap(chunkResult -> { + 1000L, + ActionListener.wrap((ChunkResult chunkResult) -> { result.set(chunkResult); latch.countDown(); }, e -> latch.countDown()) @@ -135,6 +146,68 @@ public void testFirstAttemptSuccess_WhenModelSupportsResponseFormat() throws Exc * daemon threads that cannot be properly cleaned up in the OpenSearch test framework. */ + /** + * Test that non-OpenAI connectors (Claude, Cohere, DeepSeek) don't trigger response_format retry logic. + * This verifies the fix for the issue where all connector types were attempting response_format fallback. 
+ */ + public void testNonOpenAIConnectors_DoNotUseResponseFormatRetry() throws Exception { + // Test each non-OpenAI connector type + ConnectorType[] nonOpenAIConnectors = { ConnectorType.CLAUDE, ConnectorType.COHERE, ConnectorType.DEEPSEEK }; + + for (ConnectorType connectorType : nonOpenAIConnectors) { + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + mlAccessor = new MLAccessor(mlClient); + + AtomicInteger attemptCount = new AtomicInteger(0); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference error = new AtomicReference<>(); + + // Mock ML client to always fail - this should NOT trigger response_format retry for non-OpenAI + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + attemptCount.incrementAndGet(); + + // Fail immediately - non-OpenAI connectors should not retry with response_format removal + listener.onFailure(new RuntimeException("Simulated failure for " + connectorType)); + return null; + }).when(mlClient).predict(any(), any(MLInput.class), any()); + + // Execute prediction + Map hits = Map.of("doc1", "test content"); + mlAccessor.predict( + "test-model", + 4000, + "test query", + new HashMap<>(), + hits, + "Test prompt", + LLMJudgmentRatingType.SCORE0_1, + connectorType, + 1000L, + ActionListener.wrap((ChunkResult chunkResult) -> { + // Should not succeed + latch.countDown(); + }, e -> { + error.set(e); + latch.countDown(); + }) + ); + + assertTrue("Should complete for " + connectorType, latch.await(15, TimeUnit.SECONDS)); + + // Verify failure occurred (as expected) + assertNotNull("Should have failed for " + connectorType, error.get()); + + // Key assertion: For non-OpenAI connectors, should attempt regular retries but not response_format retry + // The fix ensures that only OpenAI connectors attempt the response_format fallback + assertTrue("Should attempt at least once for " + connectorType, attemptCount.get() >= 1); + // With regular retry logic (3 attempts), we expect 
exactly 4 attempts (1 initial + 3 retries) + assertEquals("Should attempt exactly 4 times (1 initial + 3 retries) for " + connectorType, 4, attemptCount.get()); + + mlAccessor.shutdown(); + } + } + // ============================================ // Helper Methods // ============================================ From 9deee6f46d7b8a85eb5ea73fd60ca6ebdf4cb185 Mon Sep 17 00:00:00 2001 From: Fen Qin Date: Tue, 16 Dec 2025 20:50:16 +0000 Subject: [PATCH 5/6] fix tests after rebasing Signed-off-by: Fen Qin --- .../ml/MLInputOutputTransformer.java | 5 +-- .../judgment/LlmJudgmentTemplateIT.java | 33 ++++++++----------- .../ml/MLAccessorIntegrationTests.java | 10 ++++-- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java index 87962662..f01e7ce0 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java @@ -33,6 +33,7 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchrelevance.ml.connector.ConnectorType; import org.opensearch.searchrelevance.ml.connector.LLMConnector; import org.opensearch.searchrelevance.ml.connector.OpenAIConnector; import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; @@ -144,8 +145,8 @@ public MLInput createMLInput( String paramName = connector.getMessageParameterName(); parameters.put(paramName, messagesArray); - // Only add response_format if requested (for models that support it) - if (includeResponseFormat) { + // Only add response_format if requested and connector is OpenAI (only OpenAI supports response_format) + if (includeResponseFormat && connector.getType() == ConnectorType.OPENAI) { String 
responseFormat = getResponseFormat(ratingType); parameters.put("response_format", responseFormat); } diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java index 57cc92bd..aba85fdd 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java @@ -121,13 +121,11 @@ public void testLlmJudgmentWithPromptTemplate_thenSuccessful() { // Verify metadata contains structured context Map metadata = (Map) source.get("metadata"); assertNotNull(metadata); - Map context = (Map) metadata.get("llmJudgmentContext"); - assertNotNull(context); - assertNotNull(context.get("promptTemplate")); - assertTrue(((String) context.get("promptTemplate")).contains("{{queryText}}")); - assertNotNull(context.get("ratingType")); - assertEquals("SCORE0_1", context.get("ratingType")); - assertNotNull(context.get("overwriteCache")); + assertNotNull(metadata.get("promptTemplate")); + assertTrue(((String) metadata.get("promptTemplate")).contains("{{queryText}}")); + assertNotNull(metadata.get("llmJudgmentRatingType")); + assertEquals("SCORE0_1", metadata.get("llmJudgmentRatingType").toString()); + assertNotNull(metadata.get("overwriteCache")); // Verify judgmentRatings format List> judgmentRatings = (List>) source.get("judgmentRatings"); @@ -204,8 +202,7 @@ public void testLlmJudgmentWithDifferentRatingTypes_thenSuccessful() { Map judgment01Doc = entityAsMap(getJudgment01Response); Map source01 = (Map) judgment01Doc.get("_source"); Map metadata01 = (Map) source01.get("metadata"); - Map context01 = (Map) metadata01.get("llmJudgmentContext"); - assertEquals("SCORE0_1", context01.get("ratingType")); + assertEquals("SCORE0_1", metadata01.get("llmJudgmentRatingType").toString()); // Test RELEVANT_IRRELEVANT rating type String binaryBody = 
Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentBinary.json").toURI())); @@ -237,8 +234,7 @@ public void testLlmJudgmentWithDifferentRatingTypes_thenSuccessful() { Map judgmentBinaryDoc = entityAsMap(getJudgmentBinaryResponse); Map sourceBinary = (Map) judgmentBinaryDoc.get("_source"); Map metadataBinary = (Map) sourceBinary.get("metadata"); - Map contextBinary = (Map) metadataBinary.get("llmJudgmentContext"); - assertEquals("RELEVANT_IRRELEVANT", contextBinary.get("ratingType")); + assertEquals("RELEVANT_IRRELEVANT", metadataBinary.get("llmJudgmentRatingType").toString()); } @SneakyThrows @@ -302,8 +298,7 @@ public void testLlmJudgmentWithOverwriteCache_thenSuccessful() { Map judgmentTrueDoc = entityAsMap(getJudgmentTrueResponse); Map sourceTrue = (Map) judgmentTrueDoc.get("_source"); Map metadataTrue = (Map) sourceTrue.get("metadata"); - Map contextTrue = (Map) metadataTrue.get("llmJudgmentContext"); - assertEquals(true, contextTrue.get("overwriteCache")); + assertEquals(true, metadataTrue.get("overwriteCache")); // Test with overwriteCache = false String overwriteFalseBody = Files.readString( @@ -337,8 +332,7 @@ public void testLlmJudgmentWithOverwriteCache_thenSuccessful() { Map judgmentFalseDoc = entityAsMap(getJudgmentFalseResponse); Map sourceFalse = (Map) judgmentFalseDoc.get("_source"); Map metadataFalse = (Map) sourceFalse.get("metadata"); - Map contextFalse = (Map) metadataFalse.get("llmJudgmentContext"); - assertEquals(false, contextFalse.get("overwriteCache")); + assertEquals(false, metadataFalse.get("overwriteCache")); } @SneakyThrows @@ -400,20 +394,19 @@ public void testLlmJudgmentWithoutOptionalFields_thenSuccessfulWithDefaults() { Map judgmentDoc = entityAsMap(getJudgmentResponse); Map source = (Map) judgmentDoc.get("_source"); Map metadata = (Map) source.get("metadata"); - Map context = (Map) metadata.get("llmJudgmentContext"); // promptTemplate should have the default value when not provided - Object promptTemplate 
= context.get("promptTemplate"); + Object promptTemplate = metadata.get("promptTemplate"); assertNotNull("promptTemplate should not be null when not provided", promptTemplate); assertEquals("promptTemplate should have default value", DEFAULT_PROMPT_TEMPLATE, promptTemplate); // ratingType should have a default value - Object ratingType = context.get("ratingType"); + Object ratingType = metadata.get("llmJudgmentRatingType"); assertNotNull("ratingType should not be null", ratingType); - assertEquals("ratingType should have default value", "SCORE0_1", ratingType); + assertEquals("ratingType should have default value", "SCORE0_1", ratingType.toString()); // overwriteCache should default to false - Object overwriteCache = context.get("overwriteCache"); + Object overwriteCache = metadata.get("overwriteCache"); assertNotNull("overwriteCache should not be null", overwriteCache); assertEquals("overwriteCache should default to false", false, overwriteCache); } diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java index a31ff7cd..75d56b91 100644 --- a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java @@ -174,6 +174,7 @@ public void testNonOpenAIConnectors_DoNotUseResponseFormatRetry() throws Excepti // Execute prediction Map hits = Map.of("doc1", "test content"); + AtomicReference result = new AtomicReference<>(); mlAccessor.predict( "test-model", 4000, @@ -185,7 +186,7 @@ public void testNonOpenAIConnectors_DoNotUseResponseFormatRetry() throws Excepti connectorType, 1000L, ActionListener.wrap((ChunkResult chunkResult) -> { - // Should not succeed + result.set(chunkResult); latch.countDown(); }, e -> { error.set(e); @@ -195,8 +196,11 @@ public void testNonOpenAIConnectors_DoNotUseResponseFormatRetry() throws Excepti assertTrue("Should complete 
for " + connectorType, latch.await(15, TimeUnit.SECONDS)); - // Verify failure occurred (as expected) - assertNotNull("Should have failed for " + connectorType, error.get()); + // Verify all chunks failed (as expected) + ChunkResult chunkResult = result.get(); + assertNotNull("Should have result for " + connectorType, chunkResult); + assertEquals("All chunks should have failed for " + connectorType, 0, chunkResult.getSuccessfulChunksCount()); + assertTrue("Should have failed chunks for " + connectorType, chunkResult.getFailedChunksCount() > 0); // Key assertion: For non-OpenAI connectors, should attempt regular retries but not response_format retry // The fix ensures that only OpenAI connectors attempt the response_format fallback From 4d17397d794ed36ef9ed85e3b0bd23a220d4c591 Mon Sep 17 00:00:00 2001 From: Fen Qin Date: Tue, 16 Dec 2025 22:20:25 +0000 Subject: [PATCH 6/6] add CHANGELOG.md Signed-off-by: Fen Qin --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b19a2c8..b9b97090 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Features - adds version-based index mapping update support to the Search Relevance plugin [#344](https://github.com/opensearch-project/search-relevance/pull/344) * LLM Judgement Customized Prompt Template Implementation [#264](https://github.com/opensearch-project/search-relevance/pull/264) +* Support multiple LLM providers with proper rate limit [#285](https://github.com/opensearch-project/search-relevance/pull/285) ### Enhancements