diff --git a/.github/workflows/backwards_compatibility_tests_workflow.yml b/.github/workflows/backwards_compatibility_tests_workflow.yml new file mode 100644 index 00000000..ffd9bdec --- /dev/null +++ b/.github/workflows/backwards_compatibility_tests_workflow.yml @@ -0,0 +1,49 @@ +name: Backwards Compatibility Tests SearchRelevance +on: + push: + branches: + - "*" + - "feature/**" + pull_request: + branches: + - "*" + - "feature/**" + +jobs: + Get-CI-Image-Tag: + uses: opensearch-project/opensearch-build/.github/workflows/get-ci-image-tag.yml@main + with: + product: opensearch + + Rolling-Upgrade-BWCTests-SearchRelevance: + needs: Get-CI-Image-Tag + strategy: + matrix: + java: [21] + os: [ubuntu-latest] + bwc_version: ["3.3.0-SNAPSHOT"] + opensearch_version: ["3.4.0-SNAPSHOT"] + + name: SearchRelevance Rolling-Upgrade BWC Tests + runs-on: ${{ matrix.os }} + container: + image: ${{ needs.Get-CI-Image-Tag.outputs.ci-image-version-linux }} + options: ${{ needs.Get-CI-Image-Tag.outputs.ci-image-start-options }} + env: + BWC_VERSION_ROLLING_UPGRADE: ${{ matrix.bwc_version }} + + steps: + - name: Run start commands + run: ${{ needs.Get-CI-Image-Tag.outputs.ci-image-start-command }} + - name: Checkout search-relevance + uses: actions/checkout@v4 + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: ${{ matrix.java }} + - name: Run SearchRelevance Rolling-Upgrade BWC Tests + run: | + chown -R 1000:1000 `pwd` + echo "Running rolling-upgrade backwards compatibility tests..." + su `id -un 1000` -c "./gradlew :qa:rolling-upgrade:testRollingUpgrade -Dtests.bwc.version=${{ matrix.bwc_version }} --refresh-dependencies --no-daemon" diff --git a/CHANGELOG.md b/CHANGELOG.md index a9225750..1b19a2c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Features - adds version-based index mapping update support to the Search Relevance plugin [#344](https://github.com/opensearch-project/search-relevance/pull/344) +* LLM Judgement Customized Prompt Template Implementation [#264](https://github.com/opensearch-project/search-relevance/pull/264) ### Enhancements diff --git a/build.gradle b/build.gradle index bf4059e4..7fa2e19c 100644 --- a/build.gradle +++ b/build.gradle @@ -86,6 +86,13 @@ java { } ext { + + default_bwc_version = System.getProperty("bwc.version") + default_bwc_bundle_version= System.getProperty("bwc.bundle.version") + bwcBundleTest = (project.findProperty('customDistributionDownloadType') != null && project.properties['customDistributionDownloadType'] == "bundle") + search_relevance_bwc_version = bwcBundleTest ? System.getProperty("tests.bwc.bundle.version",rootProject.ext.default_bwc_bundle_version): System.getProperty("tests.bwc.version", rootProject.ext.default_bwc_version) + currentBundleVersion = opensearch_version.replace("-SNAPSHOT","") + projectSubstitutions = [:] licenseFile = rootProject.file('LICENSE.txt') noticeFile = rootProject.file('NOTICE.txt') @@ -530,6 +537,8 @@ opensearch_tmp_dir.mkdirs() integTest { systemProperty 'tests.security.manager', 'false' systemProperty 'java.io.tmpdir', opensearch_tmp_dir.absolutePath + // allows integration test classes to access test resource from project root path + systemProperty 'project.root', project.rootDir.absolutePath systemProperty 'buildDir', buildDir.path systemProperty "https", securityEnabled systemProperty "security", securityEnabled diff --git a/formatter/formatting.gradle b/formatter/formatting.gradle index 520fe112..c4d01d74 100644 --- a/formatter/formatting.gradle +++ b/formatter/formatting.gradle @@ -1,4 +1,5 @@ allprojects { + apply plugin: "com.diffplug.spotless" spotless { java { // Normally this isn't necessary, but we have Java sources in diff --git a/gradle.properties b/gradle.properties index 7717686e..4b5ee7ed 100644 --- a/gradle.properties +++ b/gradle.properties @@ -9,3 +9,17 @@ org.gradle.caching=true org.gradle.warning.mode=none org.gradle.parallel=true + +# The BWC version here should always be the latest opensearch version set in +# https://github.com/opensearch-project/OpenSearch/blob/main/libs/core/src/main/java/org/opensearch/Version.java . +# Wired compatibility of OpenSearch works like 3.x version is compatible with 2.(latest-major) version. +# Therefore, to run rolling-upgrade BWC Test on local machine the BWC version here should be set 2.(latest-major). +systemProp.bwc.version=3.3.0-SNAPSHOT +systemProp.bwc.bundle.version=3.2.0 + +# For fixing Spotless check with Java 17 +org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED \ No newline at end of file diff --git a/qa/build.gradle b/qa/build.gradle new file mode 100644 index 00000000..ce45d20f --- /dev/null +++ b/qa/build.gradle @@ -0,0 +1,262 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import org.apache.tools.ant.taskdefs.condition.Os + +import java.util.concurrent.Callable +import java.nio.file.Path + +apply plugin: 'opensearch.testclusters' +apply plugin: 'opensearch.build' +apply plugin: 'opensearch.rest-test' +apply plugin: 'io.freefair.lombok' +apply plugin: 'opensearch.java-agent' + +// Disable a few tasks that come with build +build.enabled = false +integTest.enabled = false +test.enabled = false +assemble.enabled = false +dependenciesInfo.enabled = false +dependencyLicenses.enabled = false +thirdPartyAudit.enabled = false +validateNebulaPom.enabled = false +loggerUsageCheck.enabled = false + +java { + targetCompatibility = JavaVersion.VERSION_21 + sourceCompatibility = JavaVersion.VERSION_21 +} + +configurations { + zipArchive +} + +repositories { + mavenLocal() + maven { url "https://ci.opensearch.org/ci/dbc/snapshots/maven/" } + mavenCentral() + maven { url "https://plugins.gradle.org/m2/" } +} + +def knnJarDirectory = "$rootDir/build/dependencies/opensearch-knn" + +dependencies { + api "org.opensearch:opensearch:${opensearch_version}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" + compileOnly fileTree(dir: knnJarDirectory, include: ["opensearch-knn-${opensearch_build}.jar", "remote-index-build-client-${opensearch_build}.jar"]) + compileOnly group: 'com.google.guava', name: 'guava', version:'33.4.8-jre' + compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.20.0' + // json-path 2.10.0 depends on slf4j 2.0.11, which conflicts with the version used by OpenSearch core. + // Excluding slf4j here since json-path is only used for testing, and logging failures in this context are acceptable. + testRuntimeOnly('com.jayway.jsonpath:json-path:2.10.0') { + // OpenSearch core is using slf4j 1.7.36. Therefore, we cannot change the version here. + exclude group: 'org.slf4j', module: 'slf4j-api' + exclude group: 'net.minidev', module: 'json-smart' + } + testRuntimeOnly group: 'net.minidev', name:'json-smart', version: "${versions.json_smart}" + api "org.apache.logging.log4j:log4j-api:${versions.log4j}" + api "org.apache.logging.log4j:log4j-core:${versions.log4j}" + api "junit:junit:${versions.junit}" + testImplementation "org.opensearch.test:framework:${opensearch_version}" + testImplementation(testFixtures(rootProject)) +} + +ext { + licenseFile = rootProject.file('LICENSE.txt') + noticeFile = rootProject.file('NOTICE.txt') +} + +def tmp_dir = project.file('build/private/artifact_tmp').absoluteFile +tmp_dir.mkdirs() +String default_bwc_version = System.getProperty("bwc.version", rootProject.ext.default_bwc_version) +String search_relevance_bwc_version = System.getProperty("tests.bwc.version", default_bwc_version) +boolean isSnapshot = search_relevance_bwc_version.contains("-SNAPSHOT") +String search_relevance_bwc_version_no_qualifier = isSnapshot ? search_relevance_bwc_version - "-SNAPSHOT" : search_relevance_bwc_version + +String os_platform = "linux" +String artifact_type = "tar" +String file_ext = "tar.gz" + +if (Os.isFamily(Os.FAMILY_WINDOWS)) { + os_platform = "windows" + artifact_type = "zip" + file_ext = "zip" +} + +ext{ + plugins = [provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return configurations.zipArchive.asFileTree.matching{include "**/opensearch-job-scheduler-${opensearch_build}.zip"}.getSingleFile() + } + } + } + }), provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return configurations.zipArchive.asFileTree.matching{include "**/opensearch-ml-plugin-${opensearch_build}.zip"}.getSingleFile() + } + } + } + }), provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return configurations.zipArchive.asFileTree.matching{include "**/opensearch-knn-${opensearch_build}.zip"}.getSingleFile() + } + } + } + }), rootProject.tasks.bundlePlugin.archiveFile] +} + +task deleteTempDirectories { + doFirst { + if (tmp_dir.exists()) { + File[] tempFiles = tmp_dir.listFiles() + if (tempFiles != null) { + for (File child : tempFiles) { + if (child.exists() && child.toString().contains("opensearch-")) { + project.delete(child) + } + } + } + } + } +} + +// Task to pull opensearch artifact from archive +task pullOpensearchArtifact { + dependsOn "deleteTempDirectories" + + doLast{ + ext{ + if (isSnapshot) { + srcUrl = "https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/${search_relevance_bwc_version_no_qualifier}/latest/${os_platform}/x64/${artifact_type}/dist/opensearch/opensearch-${search_relevance_bwc_version_no_qualifier}-${os_platform}-x64.${file_ext}" + } else { + srcUrl = "https://artifacts.opensearch.org/releases/bundle/opensearch/${search_relevance_bwc_version}/opensearch-${search_relevance_bwc_version}-${os_platform}-x64.${file_ext}" + } + } + ant.get( + src: srcUrl, + dest: tmp_dir.absolutePath, + httpusecaches: false + ) + copy { + if (Os.isFamily(Os.FAMILY_WINDOWS)) { + from zipTree(Path.of(tmp_dir.absolutePath, "opensearch-${search_relevance_bwc_version_no_qualifier}-${os_platform}-x64.${file_ext}")) + } else { + from tarTree(Path.of(tmp_dir.absolutePath, "opensearch-${search_relevance_bwc_version_no_qualifier}-${os_platform}-x64.${file_ext}")) + } + into tmp_dir.absolutePath + } + } +} + +// Task to pull ml plugin from archive +task pullMlCommonsBwcPlugin { + doLast { + copy { + from(Path.of(tmp_dir.absolutePath, "opensearch-${search_relevance_bwc_version_no_qualifier}", "plugins", "opensearch-ml")) + into Path.of(tmp_dir.absolutePath, "opensearch-ml") + } + } +} + +// Task to pull KNN plugin from archive +task pullKnnBwcPlugin { + dependsOn "pullOpensearchArtifact" + + doLast { + copy { + from(Path.of(tmp_dir.absolutePath, "opensearch-${search_relevance_bwc_version_no_qualifier}", "plugins", "opensearch-knn")) + into Path.of(tmp_dir.absolutePath, "opensearch-knn") + } + } +} + +// Task to pull job scheduler plugin from archive +task pullJobSchedulerBwcPlugin { + dependsOn "pullKnnBwcPlugin" + doLast { + copy { + from(Path.of(tmp_dir.absolutePath, "opensearch-${search_relevance_bwc_version_no_qualifier}", "plugins", "opensearch-job-scheduler")) + into Path.of(tmp_dir.absolutePath, "opensearch-job-scheduler") + } + } +} + +// Task to pull search relevance plugin from archive +task pullBwcPlugin { + doLast { + copy { + from(Path.of(tmp_dir.absolutePath, "opensearch-${search_relevance_bwc_version_no_qualifier}", "plugins", "opensearch-search-relevance")) + into Path.of(tmp_dir.absolutePath, "opensearch-search-relevance") + } + delete Path.of(tmp_dir.absolutePath, "opensearch-${search_relevance_bwc_version_no_qualifier}"), java.nio.file.Path.of(tmp_dir.absolutePath, "opensearch-${search_relevance_bwc_version_no_qualifier}-${os_platform}-x64.${file_ext}") + } +} + +// Task to zip opensearch-job-scheduler plugin from archive +task zipBwcJobSchedulerPlugin(type: Zip) { + dependsOn "pullJobSchedulerBwcPlugin" + from(Path.of(tmp_dir.absolutePath, "opensearch-job-scheduler")) + destinationDirectory = tmp_dir + archiveFileName = "opensearch-job-scheduler-${search_relevance_bwc_version_no_qualifier}.zip" + doLast { + delete Path.of(tmp_dir.absolutePath, "opensearch-job-scheduler") + } +} + +// Task to zip ml-commons plugin from archive +task zipBwcMlCommonsPlugin(type: Zip) { + dependsOn "pullMlCommonsBwcPlugin" + dependsOn "zipBwcJobSchedulerPlugin" + from(Path.of(tmp_dir.absolutePath, "opensearch-ml")) + destinationDirectory = tmp_dir + archiveFileName = "opensearch-ml-${search_relevance_bwc_version_no_qualifier}.zip" + doLast { + delete Path.of(tmp_dir.absolutePath, "opensearch-ml") + } +} + +// Task to zip knn plugin from archive +task zipBwcKnnPlugin(type: Zip) { + dependsOn "pullKnnBwcPlugin" + dependsOn "zipBwcMlCommonsPlugin" + from(Path.of(tmp_dir.absolutePath, "opensearch-knn")) + destinationDirectory = tmp_dir + archiveFileName = "opensearch-knn-${search_relevance_bwc_version_no_qualifier}.zip" + doLast { + delete Path.of(tmp_dir.absolutePath, "opensearch-knn") + } +} + +// Task to zip search relevance plugin from archive +task zipBwcPlugin(type: Zip) { + dependsOn "zipBwcKnnPlugin" + dependsOn "pullBwcPlugin" + from(Path.of(tmp_dir.absolutePath, "opensearch-search-relevance")) + destinationDirectory = tmp_dir + archiveFileName = "opensearch-search-relevance-${search_relevance_bwc_version_no_qualifier}.zip" + doLast { + delete Path.of(tmp_dir.absolutePath, "opensearch-search-relevance") + } +} + + +task bwcTestSuite { + dependsOn ":qa:rolling-upgrade:testRollingUpgrade" +} diff --git a/qa/rolling-upgrade/build.gradle b/qa/rolling-upgrade/build.gradle new file mode 100644 index 00000000..02de2c48 --- /dev/null +++ b/qa/rolling-upgrade/build.gradle @@ -0,0 +1,192 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +import org.opensearch.gradle.testclusters.StandaloneRestIntegTestTask + +apply from : "$rootDir/qa/build.gradle" + +def ext=rootProject.ext +String baseName = "searchRelevanceBwcCluster-rolling" + +// Creates a test cluster of previous version and loads k-NN plugin of bwcVersion +testClusters { + "${baseName}" { + testDistribution = "ARCHIVE" + jvmArgs("-Xms1g", "-Xmx4g") + numberOfNodes = 3 + if(ext.bwcBundleTest){ + versions = [ext.search_relevance_bwc_version, ext.currentBundleVersion] + def path=ext.opensearch_tmp_dir + nodes.each { node -> + node.extraConfigFile("kirk.pem", file("$path/kirk.pem")) + node.extraConfigFile("kirk-key.pem", file("$path/kirk-key.pem")) + node.extraConfigFile("esnode.pem", file("$path/esnode.pem")) + node.extraConfigFile("esnode-key.pem", file("$path/esnode-key.pem")) + node.extraConfigFile("root-ca.pem", file("$path/root-ca.pem")) + node.setting("plugins.security.disabled", "true") + node.setting("plugins.security.ssl.transport.pemcert_filepath", "esnode.pem") + node.setting("plugins.security.ssl.transport.pemkey_filepath", "esnode-key.pem") + node.setting("plugins.security.ssl.transport.pemtrustedcas_filepath", "root-ca.pem") + node.setting("plugins.security.ssl.transport.enforce_hostname_verification", "false") + node.setting("plugins.security.ssl.http.enabled", "true") + node.setting("plugins.security.ssl.http.pemcert_filepath", "esnode.pem") + node.setting("plugins.security.ssl.http.pemkey_filepath", "esnode-key.pem") + node.setting("plugins.security.ssl.http.pemtrustedcas_filepath", "root-ca.pem") + node.setting("plugins.security.allow_unsafe_democertificates", "true") + node.setting("plugins.security.allow_default_init_securityindex", "true") + node.setting("plugins.security.authcz.admin_dn", "CN=kirk,OU=client,O=client,L=test,C=de") + node.setting("plugins.security.audit.type", "internal_elasticsearch") + node.setting("plugins.security.enable_snapshot_restore_privilege", "true") + node.setting("plugins.security.check_snapshot_restore_write_privileges", "true") + node.setting("plugins.security.restapi.roles_enabled", "[\"all_access\", \"security_rest_api_access\"]") + node.setting("plugins.security.system_indices.enabled", "true") + } + }else{ + versions = [ext.search_relevance_bwc_version, opensearch_version] + plugin(project.tasks.zipBwcJobSchedulerPlugin.archiveFile) + plugin(project.tasks.zipBwcMlCommonsPlugin.archiveFile) + plugin(project.tasks.zipBwcKnnPlugin.archiveFile) + plugin(project.tasks.zipBwcPlugin.archiveFile) + } + setting 'path.repo', "${buildDir}/cluster/shared/repo/${baseName}" + setting 'http.content_type.required', 'true' + } +} + +def versionsBelow3_3 = ["3.0", "3.1", "3.2"] +def versionsBelow3_4 = versionsBelow3_3 + "3.3" + +// Task to run BWC tests against the old cluster +task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { + if(!ext.bwcBundleTest){ + dependsOn "zipBwcPlugin" + } + useCluster testClusters."${baseName}" + systemProperty 'tests.rest.bwcsuite_cluster', 'old_cluster' + systemProperty 'tests.plugin_bwc_version', ext.search_relevance_bwc_version + systemProperty 'tests.skip_delete_model_index', 'true' + + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") + systemProperty 'tests.security.manager', 'false' + + // Excluding LLM Judgment tests because we introduced this feature in 3.3.0 + if (versionsBelow3_3.any { ext.search_relevance_bwc_version.startsWith(it) }) { + filter { + excludeTestsMatching "org.opensearch.searchrelevance.bwc.rolling.LlmJudgmentBWCIT.*" + } + } +} + +// Part of rolling upgrade. Upgrades one node of the old cluster to new OpenSearch version with upgraded plugin version +// This results in a mixed cluster with 2 nodes on the old version and 1 upgraded node. +task testAgainstOneThirdUpgradedCluster(type: StandaloneRestIntegTestTask) { + useCluster testClusters."${baseName}" + dependsOn rootProject.tasks.assemble + dependsOn "testAgainstOldCluster" + + doFirst { + println "${ext.bwcBundleTest}" + // This is added to prevent the cluster from getting stuck in yellow state + println "Waiting for cluster to stabilize after previous test, including data synchronization on replica shards" + println "BWC Test: Upgrading from ${ext.search_relevance_bwc_version} to ${opensearch_version}" + Thread.sleep(10000) // 10 seconds delay + + if(ext.bwcBundleTest){ + testClusters."${baseName}".nextNodeToNextVersion() + }else{ + testClusters."${baseName}".upgradeNodeAndPluginToNextVersion(project.ext.plugins) + } + } + + systemProperty 'tests.rest.bwcsuite_cluster', 'mixed_cluster' + systemProperty 'tests.rest.first_round', 'true' + systemProperty 'tests.skip_delete_model_index', 'true' + systemProperty 'tests.plugin_bwc_version', ext.search_relevance_bwc_version + + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") + systemProperty 'tests.security.manager', 'false' + + // Excluding LLM Judgment tests because we introduced this feature in 3.3.0 + if (versionsBelow3_3.any { ext.search_relevance_bwc_version.startsWith(it) }) { + filter { + excludeTestsMatching "org.opensearch.searchrelevance.bwc.rolling.LlmJudgmentBWCIT.*" + } + } +} + +// Part of rolling upgrade. Upgrades the second node to new OpenSearch version with upgraded plugin version after the +// first node is upgraded. This results in a mixed cluster with 1 node on the old version and 2 upgraded nodes. +task testAgainstTwoThirdsUpgradedCluster(type: StandaloneRestIntegTestTask) { + dependsOn "testAgainstOneThirdUpgradedCluster" + useCluster testClusters."${baseName}" + + doFirst { + // This is added to prevent the cluster from getting stuck in yellow state + println "Waiting for cluster to stabilize after previous test, including data synchronization on replica shards" + println "BWC Test: Upgrading from ${ext.search_relevance_bwc_version} to ${opensearch_version}" + Thread.sleep(10000) // 10 seconds delay + + if(ext.bwcBundleTest){ + testClusters."${baseName}".nextNodeToNextVersion() + }else{ + testClusters."${baseName}".upgradeNodeAndPluginToNextVersion(project.ext.plugins) + } + } + systemProperty 'tests.rest.bwcsuite_cluster', 'mixed_cluster' + systemProperty 'tests.rest.first_round', 'false' + systemProperty 'tests.skip_delete_model_index', 'true' + systemProperty 'tests.plugin_bwc_version', ext.search_relevance_bwc_version + + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") + systemProperty 'tests.security.manager', 'false' + + // Excluding LLM Judgment tests because we introduced this feature in 3.3.0 + if (versionsBelow3_3.any { ext.search_relevance_bwc_version.startsWith(it) }) { + filter { + excludeTestsMatching "org.opensearch.searchrelevance.bwc.rolling.LlmJudgmentBWCIT.*" + } + } +} + +// Part of rolling upgrade. Upgrades the third node to new OpenSearch version with upgraded plugin version after the +// second node is upgraded. This results in a fully upgraded cluster. +task testRollingUpgrade(type: StandaloneRestIntegTestTask) { + dependsOn "testAgainstTwoThirdsUpgradedCluster" + useCluster testClusters."${baseName}" + + doFirst { + // This is added to prevent the cluster from getting stuck in yellow state + println "Waiting for cluster to stabilize after previous test, including data synchronization on replica shards" + Thread.sleep(10000) // 10 seconds delay + + if(ext.bwcBundleTest){ + testClusters."${baseName}".nextNodeToNextVersion() + }else{ + testClusters."${baseName}".upgradeNodeAndPluginToNextVersion(project.ext.plugins) + } + } + + mustRunAfter "testAgainstOneThirdUpgradedCluster" + systemProperty 'tests.rest.bwcsuite_cluster', 'upgraded_cluster' + systemProperty 'tests.skip_delete_model_index', 'true' + systemProperty 'tests.plugin_bwc_version', ext.search_relevance_bwc_version + + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") + systemProperty 'tests.security.manager', 'false' + + // Excluding LLM Judgment tests because we introduced this feature in 3.3.0 + if (versionsBelow3_3.any { ext.search_relevance_bwc_version.startsWith(it) }) { + filter { + excludeTestsMatching "org.opensearch.searchrelevance.bwc.rolling.LlmJudgmentBWCIT.*" + } + } +} diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/searchrelevance/bwc/rolling/AbstractSearchRelevanceRollingUpgradeTestCase.java b/qa/rolling-upgrade/src/test/java/org/opensearch/searchrelevance/bwc/rolling/AbstractSearchRelevanceRollingUpgradeTestCase.java new file mode 100644 index 00000000..acb9a434 --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/searchrelevance/bwc/rolling/AbstractSearchRelevanceRollingUpgradeTestCase.java @@ -0,0 +1,151 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.bwc.rolling; + +import java.util.Locale; + +import org.opensearch.common.settings.Settings; +import org.opensearch.test.rest.OpenSearchRestTestCase; + +/** + * Base class for Search Relevance BWC (Backward Compatibility) tests during rolling upgrades. + * Provides common utilities and cluster state management for testing compatibility across versions. + */ +public abstract class AbstractSearchRelevanceRollingUpgradeTestCase extends OpenSearchRestTestCase { + + private static final String OLD_CLUSTER = "old_cluster"; + private static final String MIXED_CLUSTER = "mixed_cluster"; + private static final String UPGRADED_CLUSTER = "upgraded_cluster"; + + /** + * Enum representing the different cluster states during a rolling upgrade. + */ + protected enum ClusterType { + OLD, + MIXED, + UPGRADED; + + public static ClusterType instance(String value) { + switch (value) { + case OLD_CLUSTER: + return OLD; + case MIXED_CLUSTER: + return MIXED; + case UPGRADED_CLUSTER: + return UPGRADED; + default: + throw new IllegalArgumentException("unknown cluster type: " + value); + } + } + } + + /** + * Gets the current cluster type based on system properties. + * This determines which phase of the rolling upgrade the test is currently executing. + * + * @return The current ClusterType (OLD, MIXED, or UPGRADED) + */ + protected ClusterType getClusterType() { + return ClusterType.instance(System.getProperty("tests.rest.bwcsuite_cluster")); + } + + /** + * Customizes REST client settings to accommodate rolling upgrade scenarios. + * Increases socket timeout to handle delays during cluster transitions. + * + * @return Settings with extended client socket timeout + */ + @Override + protected final Settings restClientSettings() { + return Settings.builder().put(super.restClientSettings()).put(OpenSearchRestTestCase.CLIENT_SOCKET_TIMEOUT, "120s").build(); + } + + /** + * Gets the index name for the test with a prefix to identify BWC test resources. + * + * @return Index name prefixed with "search-relevance-bwc-" + */ + protected String getIndexNameForTest() { + return String.format(Locale.ROOT, "search-relevance-bwc-%s", getTestName().toLowerCase(Locale.ROOT)); + } + + /** + * Gets the query set name for the test with a prefix to identify BWC test resources. + * + * @return Query set name prefixed with "bwc-queryset-" + */ + protected String getQuerySetNameForTest() { + return String.format(Locale.ROOT, "bwc-queryset-%s", getTestName().toLowerCase(Locale.ROOT)); + } + + /** + * Gets the judgment name for the test with a prefix to identify BWC test resources. + * + * @return Judgment name prefixed with "bwc-judgment-" + */ + protected String getJudgmentNameForTest() { + return String.format(Locale.ROOT, "bwc-judgment-%s", getTestName().toLowerCase(Locale.ROOT)); + } + + /** + * Gets the search configuration name for the test with a prefix to identify BWC test resources. + * + * @return Search configuration name prefixed with "bwc-search-config-" + */ + protected String getSearchConfigNameForTest() { + return String.format(Locale.ROOT, "bwc-search-config-%s", getTestName().toLowerCase(Locale.ROOT)); + } + + /** + * Checks if this is the first round of the mixed cluster phase. + * During rolling upgrades, the mixed phase has multiple rounds as nodes are upgraded one by one. + * + * @return true if this is the first mixed cluster round, false otherwise + */ + protected boolean isFirstMixedRound() { + return Boolean.parseBoolean(System.getProperty("tests.rest.first_round", "false")); + } + + /** + * Gets the BWC (backward compatible) version being tested. + * This is the older version that we're upgrading from. + * + * @return The BWC version string + */ + protected String getBWCVersion() { + return System.getProperty("tests.plugin_bwc_version"); + } + + /** + * Preserves indices created during tests across rolling upgrade phases. + * This is essential for BWC testing where data created in OLD cluster + * must be accessible in MIXED and UPGRADED cluster phases. + * + * @return true to preserve indices between test phases + */ + @Override + protected boolean preserveIndicesUponCompletion() { + return true; + } + + @Override + public boolean preserveClusterUponCompletion() { + // Otherwise, the cluster setting to enable ml-common is reset and the model is undeployed + return true; + } + + @Override + protected boolean preserveReposUponCompletion() { + return true; + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } +} diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/searchrelevance/bwc/rolling/LlmJudgmentBWCIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/searchrelevance/bwc/rolling/LlmJudgmentBWCIT.java new file mode 100644 index 00000000..d7408863 --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/searchrelevance/bwc/rolling/LlmJudgmentBWCIT.java @@ -0,0 +1,675 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.bwc.rolling; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; + +/** + * BWC (Backward Compatibility) Integration Test for LLM Judgment functionality. + * + * This test validates that: + * 1. OLD cluster: Creates query sets and judgments using the old format (no promptTemplate, no ratingType) + * 2. MIXED cluster: Can read and process both old and new format data + * 3. UPGRADED cluster: Supports new features (promptTemplate, ratingType) while maintaining old format compatibility + */ +public class LlmJudgmentBWCIT extends AbstractSearchRelevanceRollingUpgradeTestCase { + + private static final String QUERY_SET_ENDPOINT = "/_plugins/_search_relevance/query_sets"; + private static final String JUDGMENT_ENDPOINT = "/_plugins/_search_relevance/judgments"; + private static final String SEARCH_CONFIG_ENDPOINT = "/_plugins/_search_relevance/search_configurations"; + + private static String querySetId; + private static String judgmentId; + private static String searchConfigId; + + /** + * Main BWC test for LLM Judgment functionality. + * Tests backward compatibility during rolling upgrade: + * - OLD: Create resources with old format (no promptTemplate, no ratingType) + * - MIXED: Validate existing resources still work, can create new resources with new format + * - UPGRADED: Full new format support, old format still works + */ + public void testLlmJudgment_RollingUpgrade() throws Exception { + switch (getClusterType()) { + case OLD: + testCreateResourcesWithOldFormat(); + testValidateOldFormatResources(); + break; + case MIXED: + testValidateOldFormatResources(); + if (isFirstMixedRound()) { + testCreateResourcesWithNewFormat(); + } + break; + case UPGRADED: + testValidateAllResources(); + testNewFormatFeatures(); + cleanupResources(); + break; + default: + throw new IllegalStateException("Unknown cluster type: " + getClusterType()); + } + } + + /** + * OLD cluster test: Create resources using old format. + * Old format characteristics: + * - Query set: Only queryText and referenceAnswer (no custom fields) + * - LLM Judgment: No promptTemplate, no llmJudgmentRatingType (uses defaults) + */ + private void testCreateResourcesWithOldFormat() throws Exception { + String indexName = getIndexNameForTest(); + + // Create test index + createTestIndex(indexName); + + // Create search configuration (this hasn't changed) + searchConfigId = createSearchConfiguration(indexName); + assertNotNull("Search configuration should be created", searchConfigId); + + // Create query set with OLD format (no custom fields, just queryText and referenceAnswer) + String querySetName = getQuerySetNameForTest(); + querySetId = createQuerySetOldFormat(querySetName); + assertNotNull("Query set should be created with old format", querySetId); + + // Validate query set was created correctly + Map querySet = getQuerySet(querySetId); + assertEquals("Query set name should match", querySetName, querySet.get("name")); + + // Validate that we can retrieve the resources by name (same approach used in MIXED/UPGRADED cluster) + String searchConfigName = getSearchConfigNameForTest(); + String retrievedQuerySetId = getQuerySetIdByName(querySetName); + String retrievedSearchConfigId = getSearchConfigIdByName(searchConfigName); + + assertEquals("Query set ID should match when retrieved by name", querySetId, retrievedQuerySetId); + assertEquals("Search config ID should match when retrieved by name", searchConfigId, retrievedSearchConfigId); + + // Create LLM judgment with OLD format (no promptTemplate, no llmJudgmentRatingType) + String judgmentName = getJudgmentNameForTest(); + judgmentId = createLlmJudgmentOldFormat(judgmentName, querySetId, searchConfigId); + assertNotNull("LLM judgment should be created with old format", judgmentId); + + // Validate the judgment can be retrieved and has correct OLD format + Map judgment = getLlmJudgment(judgmentId); + assertNotNull("LLM judgment should be retrievable", judgment); + assertEquals("Judgment name should match", judgmentName, judgment.get("name")); + assertEquals("Judgment type should be LLM_JUDGMENT", "LLM_JUDGMENT", judgment.get("type")); + + // Validate OLD format: should NOT have new fields like promptTemplate and llmJudgmentRatingType + Map metadata = (Map) judgment.get("metadata"); + if (metadata != null) { + assertNull("OLD format should not have promptTemplate", metadata.get("promptTemplate")); + assertNull("OLD format should not have llmJudgmentRatingType", metadata.get("llmJudgmentRatingType")); + } + } + + /** + * MIXED cluster test: Validate that old format resources still work. + * Also test creating new format resources if this is the first mixed round. + */ + private void testValidateOldFormatResources() throws Exception { + // Retrieve IDs by name (since static variables don't persist across test phases) + String querySetName = getQuerySetNameForTest(); + String searchConfigName = getSearchConfigNameForTest(); + + querySetId = getQuerySetIdByName(querySetName); + searchConfigId = getSearchConfigIdByName(searchConfigName); + + // Validate query set created in OLD cluster still exists and is readable + Map querySet = getQuerySet(querySetId); + assertNotNull("Query set from OLD cluster should still exist", querySet); + + // Validate search configuration still exists + Map searchConfig = getSearchConfiguration(searchConfigId); + assertNotNull("Search configuration from OLD cluster should still exist", searchConfig); + + // Validate LLM judgment created in OLD cluster still exists and can be retrieved + String judgmentName = getJudgmentNameForTest(); + judgmentId = getLlmJudgmentIdByName(judgmentName); + assertNotNull("LLM judgment from OLD cluster should still exist", judgmentId); + + // Retrieve and validate the old judgment to ensure backward compatibility + Map judgment = getLlmJudgment(judgmentId); + assertNotNull("Old format LLM judgment should be retrievable in MIXED cluster", judgment); + assertEquals("Judgment name should match", judgmentName, judgment.get("name")); + } + + /** + * Test creating resources with new format in MIXED cluster. + */ + private void testCreateResourcesWithNewFormat() throws Exception { + String querySetName = getQuerySetNameForTest() + "-new"; + + // Create query set with NEW format (includes custom fields) + String newQuerySetId = createQuerySetNewFormat(querySetName); + assertNotNull("Query set should be created with new format", newQuerySetId); + + // Validate new format query set + Map querySet = getQuerySet(newQuerySetId); + assertEquals("Query set name should match", querySetName, querySet.get("name")); + + // Create LLM judgment with NEW format (with promptTemplate and ratingType) + String newJudgmentId = createLlmJudgmentNewFormat(newQuerySetId, searchConfigId); + assertNotNull("New format LLM judgment should be created", newJudgmentId); + + // Validate new format judgment can be retrieved + Map newJudgment = getLlmJudgment(newJudgmentId); + assertNotNull("New format LLM judgment should be retrievable", newJudgment); + + // In MIXED cluster, the new format fields might not be stored/returned by old nodes (3.3.0) + // We just verify the judgment was created successfully + // Full validation of new fields will happen in UPGRADED cluster where all nodes support them + } + + /** + * UPGRADED cluster test: Validate all resources work correctly. + * Test new format features like promptTemplate and ratingType. + */ + private void testValidateAllResources() throws Exception { + // Retrieve IDs by name (since static variables don't persist across test phases) + String querySetName = getQuerySetNameForTest(); + String searchConfigName = getSearchConfigNameForTest(); + String judgmentName = getJudgmentNameForTest(); + + querySetId = getQuerySetIdByName(querySetName); + searchConfigId = getSearchConfigIdByName(searchConfigName); + judgmentId = getLlmJudgmentIdByName(judgmentName); + + // Validate old format query set still works + Map oldQuerySet = getQuerySet(querySetId); + assertNotNull("Old format query set should still work in upgraded cluster", oldQuerySet); + + // Validate search configuration still works + Map searchConfig = getSearchConfiguration(searchConfigId); + assertNotNull("Search configuration should still work in upgraded cluster", searchConfig); + + // Validate old format judgment still works + Map oldJudgment = getLlmJudgment(judgmentId); + assertNotNull("Old format LLM judgment should still work in upgraded cluster", oldJudgment); + assertEquals("Judgment name should match", judgmentName, oldJudgment.get("name")); + } + + /** + * Test new format features in UPGRADED cluster. + */ + private void testNewFormatFeatures() throws Exception { + String querySetName = getQuerySetNameForTest() + "-upg"; + + // Create query set with new format including multiple custom fields + String newQuerySetId = createQuerySetWithMultipleCustomFields(querySetName); + assertNotNull("Query set with multiple custom fields should be created", newQuerySetId); + + // Validate the query set has custom fields + Map querySet = getQuerySet(newQuerySetId); + assertEquals("Query set name should match", querySetName, querySet.get("name")); + + // Create LLM judgment with new format and validate it works + String newJudgmentId = createLlmJudgmentNewFormat(newQuerySetId, searchConfigId); + assertNotNull("New format LLM judgment should be created in upgraded cluster", newJudgmentId); + + // Validate new judgment exists and can be retrieved with new format fields + Map newJudgment = getLlmJudgment(newJudgmentId); + assertNotNull("New judgment should exist", newJudgment); + assertEquals("New judgment name should match", "bwc-judgment-new-format", newJudgment.get("name")); + + // Validate NEW format fields are present in UPGRADED cluster + Map newMetadata = (Map) newJudgment.get("metadata"); + assertNotNull("Metadata should exist", newMetadata); + assertNotNull("NEW format should have promptTemplate", newMetadata.get("promptTemplate")); + assertEquals( + "Prompt template should match", + "Query: {{queryText}}\\n\\nDocuments: {{hits}}\\n\\nEvaluate the relevance of the search result.", + newMetadata.get("promptTemplate") + ); + assertNotNull("NEW format should have llmJudgmentRatingType", newMetadata.get("llmJudgmentRatingType")); + assertEquals("Rating type should be SCORE0_1", "SCORE0_1", newMetadata.get("llmJudgmentRatingType")); + } + + /** + * Clean up test resources. + */ + private void cleanupResources() throws Exception { + // Clean up LLM judgments + if (judgmentId != null) { + deleteLlmJudgment(judgmentId); + } + + // Clean up query sets + if (querySetId != null) { + deleteQuerySet(querySetId); + } + + // Clean up search configurations + if (searchConfigId != null) { + deleteSearchConfiguration(searchConfigId); + } + + // Clean up test index + String indexName = getIndexNameForTest(); + deleteIndexSilently(indexName); + } + + // ==================== Helper Methods ==================== + + /** + * Creates a test index for search configuration. + */ + private void createTestIndex(String indexName) throws IOException { + Request request = new Request("PUT", "/" + indexName); + request.setJsonEntity( + "{" + + "\"settings\": {\"index\": {\"number_of_shards\": 1, \"number_of_replicas\": 0}}," + + "\"mappings\": {\"properties\": {\"text\": {\"type\": \"text\"}}}" + + "}" + ); + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + /** + * Creates a search configuration. + */ + private String createSearchConfiguration(String indexName) throws IOException, ParseException { + Request request = new Request("PUT", SEARCH_CONFIG_ENDPOINT); + request.setJsonEntity( + "{" + + "\"name\": \"" + + getSearchConfigNameForTest() + + "\"," + + "\"description\": \"BWC test search configuration\"," + + "\"index\": \"" + + indexName + + "\"," + + "\"query\": \"{\\\"match_all\\\": {}}\"" + + "}" + ); + + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + + Map responseMap = parseResponse(response); + return (String) responseMap.get("search_configuration_id"); + } + + /** + * Creates a query set using OLD format (no custom fields). + * Format: [{queryText: "...", referenceAnswer: "..."}] + */ + private String createQuerySetOldFormat(String name) throws IOException, ParseException { + Request request = new Request("PUT", QUERY_SET_ENDPOINT); + request.setJsonEntity( + "{" + + "\"name\": \"" + + name + + "\"," + + "\"description\": \"BWC test query set - old format\"," + + "\"sampling\": \"manual\"," + + "\"querySetQueries\": [" + + " {\"queryText\": \"What is OpenSearch?\", \"referenceAnswer\": \"OpenSearch is a search and analytics suite\"}," + + " {\"queryText\": \"red shoes\", \"referenceAnswer\": \"High quality leather shoes\"}" + + "]" + + "}" + ); + + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + + Map responseMap = parseResponse(response); + return (String) responseMap.get("query_set_id"); + } + + /** + * Creates a query set using NEW format (with custom fields). + * Format: [{queryText: "...", referenceAnswer: "...", category: "...", expectedScore: "..."}] + */ + private String createQuerySetNewFormat(String name) throws IOException, ParseException { + Request request = new Request("PUT", QUERY_SET_ENDPOINT); + request.setJsonEntity( + "{" + + "\"name\": \"" + + name + + "\"," + + "\"description\": \"BWC test query set - new format\"," + + "\"sampling\": \"manual\"," + + "\"querySetQueries\": [" + + " {\"queryText\": \"What is OpenSearch?\", \"referenceAnswer\": \"OpenSearch is a search suite\", \"category\": \"technology\"}," + + " {\"queryText\": \"red shoes\", \"referenceAnswer\": \"Leather shoes\", \"category\": \"fashion\", \"expectedScore\": \"0.95\"}" + + "]" + + "}" + ); + + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + + Map responseMap = parseResponse(response); + return (String) responseMap.get("query_set_id"); + } + + /** + * Creates a query set with multiple custom fields. + */ + private String createQuerySetWithMultipleCustomFields(String name) throws IOException, ParseException { + Request request = new Request("PUT", QUERY_SET_ENDPOINT); + request.setJsonEntity( + "{" + + "\"name\": \"" + + name + + "\"," + + "\"description\": \"BWC test query set - multiple custom fields\"," + + "\"sampling\": \"manual\"," + + "\"querySetQueries\": [" + + " {" + + " \"queryText\": \"red leather shoes\"," + + " \"referenceAnswer\": \"High quality red leather shoes\"," + + " \"category\": \"footwear\"," + + " \"expectedScore\": \"0.95\"," + + " \"brand\": \"Nike\"," + + " \"priceRange\": \"premium\"" + + " }" + + "]" + + "}" + ); + + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + + Map responseMap = parseResponse(response); + return (String) responseMap.get("query_set_id"); + } + + /** + * Gets a query set by ID. + */ + private Map getQuerySet(String id) throws IOException, ParseException { + Request request = new Request("GET", QUERY_SET_ENDPOINT + "/" + id); + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + Map responseMap = parseResponse(response); + + // Extract the query set from the search response + Map hits = (Map) responseMap.get("hits"); + if (hits != null && hits.get("hits") != null) { + java.util.List> hitsList = (java.util.List>) hits.get("hits"); + if (!hitsList.isEmpty()) { + return (Map) hitsList.get(0).get("_source"); + } + } + return null; + } + + /** + * Gets a search configuration by ID. + */ + private Map getSearchConfiguration(String id) throws IOException, ParseException { + Request request = new Request("GET", SEARCH_CONFIG_ENDPOINT + "/" + id); + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + Map responseMap = parseResponse(response); + + // Extract the search configuration from the search response + Map hits = (Map) responseMap.get("hits"); + if (hits != null && hits.get("hits") != null) { + java.util.List> hitsList = (java.util.List>) hits.get("hits"); + if (!hitsList.isEmpty()) { + return (Map) hitsList.get(0).get("_source"); + } + } + return null; + } + + /** + * Deletes a query set by ID. + */ + private void deleteQuerySet(String id) throws IOException { + Request request = new Request("DELETE", QUERY_SET_ENDPOINT + "/" + id); + client().performRequest(request); + } + + /** + * Deletes a search configuration by ID. + */ + private void deleteSearchConfiguration(String id) throws IOException { + Request request = new Request("DELETE", SEARCH_CONFIG_ENDPOINT + "/" + id); + client().performRequest(request); + } + + /** + * Deletes an index silently (ignoring errors if index doesn't exist). + */ + private void deleteIndexSilently(String indexName) throws IOException { + Request request = new Request("DELETE", "/" + indexName); + try { + client().performRequest(request); + } catch (Exception e) { + // Ignore if index doesn't exist + } + } + + /** + * Gets query set ID by searching for it by name in the index. + * Similar to how neural-search BWC tests get model ID from pipeline. + */ + private String getQuerySetIdByName(String name) throws IOException, ParseException { + // Index name from PluginConstants.QUERY_SET_INDEX = "search-relevance-queryset" + String indexName = "search-relevance-queryset"; + + try { + Request request = new Request("POST", "/" + indexName + "/_search"); + // name is already a keyword field, no need for .keyword suffix + request.setJsonEntity("{" + "\"query\": {" + " \"term\": {" + " \"name\": \"" + name + "\"" + " }" + "}" + "}"); + + Response response = client().performRequest(request); + if (response.getStatusLine().getStatusCode() == 200) { + Map responseMap = parseResponse(response); + + // Extract the ID from the search response + Map hits = (Map) responseMap.get("hits"); + if (hits != null && hits.get("hits") != null) { + java.util.List> hitsList = (java.util.List>) hits.get("hits"); + if (!hitsList.isEmpty()) { + return (String) hitsList.get(0).get("_id"); + } + } + } + } catch (Exception e) { + // Index might not exist yet + logger.debug("Failed to query index {}: {}", indexName, e.getMessage()); + } + return null; + } + + /** + * Gets search configuration ID by searching for it by name in the index. + */ + private String getSearchConfigIdByName(String name) throws IOException, ParseException { + // Index name from PluginConstants.SEARCH_CONFIGURATION_INDEX = "search-relevance-search-config" + String indexName = "search-relevance-search-config"; + + try { + Request request = new Request("POST", "/" + indexName + "/_search"); + // name is already a keyword field, no need for .keyword suffix + request.setJsonEntity("{" + "\"query\": {" + " \"term\": {" + " \"name\": \"" + name + "\"" + " }" + "}" + "}"); + + Response response = client().performRequest(request); + if (response.getStatusLine().getStatusCode() == 200) { + Map responseMap = parseResponse(response); + + // Extract the ID from the search response + Map hits = (Map) responseMap.get("hits"); + if (hits != null && hits.get("hits") != null) { + java.util.List> hitsList = (java.util.List>) hits.get("hits"); + if (!hitsList.isEmpty()) { + return (String) hitsList.get(0).get("_id"); + } + } + } + } catch (Exception e) { + // Index might not exist yet + logger.debug("Failed to query index {}: {}", indexName, e.getMessage()); + } + return null; + } + + /** + * Gets judgment ID by searching for it by name in the index. + */ + private String getLlmJudgmentIdByName(String name) throws IOException, ParseException { + // Index name from PluginConstants.JUDGMENT_INDEX = "search-relevance-judgment" + String indexName = "search-relevance-judgment"; + + try { + Request request = new Request("POST", "/" + indexName + "/_search"); + // name is already a keyword field, no need for .keyword suffix + request.setJsonEntity("{" + "\"query\": {" + " \"term\": {" + " \"name\": \"" + name + "\"" + " }" + "}" + "}"); + + Response response = client().performRequest(request); + if (response.getStatusLine().getStatusCode() == 200) { + Map responseMap = parseResponse(response); + + // Extract the ID from the search response + Map hits = (Map) responseMap.get("hits"); + if (hits != null && hits.get("hits") != null) { + java.util.List> hitsList = (java.util.List>) hits.get("hits"); + if (!hitsList.isEmpty()) { + return (String) hitsList.get(0).get("_id"); + } + } + } + } catch (Exception e) { + // Index might not exist yet + logger.debug("Failed to query index {}: {}", indexName, e.getMessage()); + } + return null; + } + + /** + * Creates an LLM judgment using OLD format (no promptTemplate, no llmJudgmentRatingType). + * Uses default values for these fields. + */ + private String createLlmJudgmentOldFormat(String name, String querySetId, String searchConfigId) throws IOException, ParseException { + Request request = new Request("PUT", JUDGMENT_ENDPOINT); + request.setJsonEntity( + "{" + + "\"name\": \"" + + name + + "\"," + + "\"description\": \"BWC test judgment - old format\"," + + "\"type\": \"LLM_JUDGMENT\"," + + "\"modelId\": \"test-model-id\"," + + "\"querySetId\": \"" + + querySetId + + "\"," + + "\"searchConfigurationList\": [\"" + + searchConfigId + + "\"]," + + "\"size\": 10," + + "\"tokenLimit\": 1000," + + "\"contextFields\": [\"text\"]," + + "\"ignoreFailure\": false" + + "}" + ); + + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + + Map responseMap = parseResponse(response); + return (String) responseMap.get("judgment_id"); + } + + /** + * Creates an LLM judgment using NEW format (with promptTemplate and llmJudgmentRatingType). + */ + private String createLlmJudgmentNewFormat(String querySetId, String searchConfigId) throws IOException, ParseException { + Request request = new Request("PUT", JUDGMENT_ENDPOINT); + request.setJsonEntity( + "{" + + "\"name\": \"bwc-judgment-new-format\"," + + "\"description\": \"BWC test judgment - new format\"," + + "\"type\": \"LLM_JUDGMENT\"," + + "\"modelId\": \"test-model-id\"," + + "\"querySetId\": \"" + + querySetId + + "\"," + + "\"searchConfigurationList\": [\"" + + searchConfigId + + "\"]," + + "\"size\": 10," + + "\"tokenLimit\": 1000," + + "\"contextFields\": [\"text\"]," + + "\"ignoreFailure\": false," + + "\"promptTemplate\": \"Query: {{queryText}}\\\\n\\\\nDocuments: {{hits}}\\\\n\\\\nEvaluate the relevance of the search result.\"," + + "\"llmJudgmentRatingType\": \"SCORE0_1\"," + + "\"overwriteCache\": true" + + "}" + ); + + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + + Map responseMap = parseResponse(response); + return (String) responseMap.get("judgment_id"); + } + + /** + * Gets an LLM judgment by ID. + */ + private Map getLlmJudgment(String id) throws IOException, ParseException { + Request request = new Request("GET", JUDGMENT_ENDPOINT + "/" + id); + Response response = client().performRequest(request); + assertEquals(200, response.getStatusLine().getStatusCode()); + Map responseMap = parseResponse(response); + + // Extract the judgment from the search response + Map hits = (Map) responseMap.get("hits"); + if (hits != null && hits.get("hits") != null) { + java.util.List> hitsList = (java.util.List>) hits.get("hits"); + if (!hitsList.isEmpty()) { + return (Map) hitsList.get(0).get("_source"); + } + } + return null; + } + + /** + * Deletes an LLM judgment by ID. + */ + private void deleteLlmJudgment(String id) throws IOException { + Request request = new Request("DELETE", JUDGMENT_ENDPOINT + "/" + id); + try { + client().performRequest(request); + } catch (Exception e) { + // Ignore if judgment doesn't exist + } + } + + /** + * Parses HTTP response to Map. + */ + private Map parseResponse(Response response) throws IOException, ParseException { + String responseBody = EntityUtils.toString(response.getEntity()); + try ( + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + responseBody + ) + ) { + return parser.map(); + } + } +} diff --git a/settings.gradle b/settings.gradle index 109694a9..ed84b662 100644 --- a/settings.gradle +++ b/settings.gradle @@ -8,3 +8,7 @@ */ rootProject.name = 'opensearch-search-relevance' + +// Include BWC (Backward Compatibility) test modules +include ':qa' +include ':qa:rolling-upgrade' diff --git a/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java b/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java index ad54312f..9a21b9ef 100644 --- a/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java +++ b/src/main/java/org/opensearch/searchrelevance/common/MLConstants.java @@ -10,6 +10,8 @@ import java.util.Locale; import java.util.Map; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; + /** * ML related constants. */ @@ -21,6 +23,25 @@ private MLConstants() {} * ML input field names */ public static final String PARAM_MESSAGES_FIELD = "messages"; + public static final String PROMPT_TEMPLATE = "promptTemplate"; + public static final String LLM_JUDGMENT_RATING_TYPE = "llmJudgmentRatingType"; + public static final String OVERWRITE_CACHE = "overwriteCache"; + + /** + * Prompt template placeholder names. + * These are the special variables that can be used in custom prompt templates. + */ + public static final String PLACEHOLDER_QUERY_TEXT = "queryText"; + public static final String PLACEHOLDER_SEARCH_TEXT = "searchText"; + public static final String PLACEHOLDER_HITS = "hits"; + public static final String PLACEHOLDER_RESULTS = "results"; + public static final String PLACEHOLDER_REFERENCE = "reference"; + public static final String PLACEHOLDER_REFERENCE_ANSWER = "referenceAnswer"; + + /** + * Default prompt template for LLM judgments (simple format without reference data) + */ + public static final String DEFAULT_PROMPT_TEMPLATE = "SearchText: {{searchText}}; Hits: {{hits}}"; /** * ML response field names @@ -29,6 +50,12 @@ private MLConstants() {} public static final String RESPONSE_MESSAGE_FIELD = "message"; public static final String RESPONSE_CONTENT_FIELD = "content"; + /** + * LLM RELEVANT/IRRELEVANT String + */ + public static final String RELEVANT_DECISION_STRING = "RELEVANT"; + public static final String IRRELEVANT_DECISION_STRING = "IRRELEVANT"; + /** * LLM defaulted token limits */ @@ -36,25 +63,79 @@ private MLConstants() {} public static final Integer MAXIMUM_TOKEN_LIMIT = 500000; public static final Integer MINIMUM_TOKEN_LIMIT = 1000; - /** - * Prompt strings that specific for llm-as-a-judge use case. - * TODO: need benchmark for final prompt definition. - */ - public static final String PROMPT_SEARCH_RELEVANCE = escapeJson( + public static final String PROMPT_SEARCH_RELEVANCE_SCORE_0_1_START = escapeJson( "You are an expert search relevance rater. Your task is to evaluate the relevance between search query and results with these criteria:\n" + "- Score 1.0: Perfect match, highly relevant\n" + "- Score 0.7-0.9: Very relevant with minor variations\n" + "- Score 0.4-0.6: Moderately relevant\n" + "- Score 0.1-0.3: Slightly relevant\n" + "- Score 0.0: Completely irrelevant\n" - + "Evaluate based on: exact matches, semantic relevance, and overall context between the SearchText and content in Hits.\n" + ); + + public static final String PROMPT_SEARCH_RELEVANCE_SCORE_BINARY = escapeJson( + "You are an expert search relevance rater. Your task is to evaluate the relevance between search query and results with these criteria:\n" + + "RELEVANT: Perfect match, highly relevant\n" + + "IRRELEVANT: Completely irrelevant\n" + ); + + public static final String PROMPT_SEARCH_RELEVANCE_SCORE_END = escapeJson( + "\nEvaluate based on: exact matches, semantic relevance, and overall context between the SearchText and content in Hits.\n" + "When a reference is provided, evaluate based on the relevance to both SearchText and its reference.\n\n" - + "IMPORTANT: Provide your response ONLY as a JSON array of objects, each with \"id\" and \"rating_score\" fields. " - + "You MUST include a rating for EVERY hit provided, even if the rating is 0. " - + "Do not include any explanation or additional text." + + "IMPORTANT: You MUST include a rating for EVERY hit provided.\n\n" + + "Return ONLY a JSON object in this EXACT format:\n" + + "{\"ratings\": [{\"id\": \"doc_id_here\", \"rating_score\": }]}\n" + + "Do not include any explanation, commentary, or markdown formatting. Return only the JSON object." ); + + /** + * JSON Schema definitions for OpenAI structured output. + * These schemas enforce the output format at the model level. + */ + public static final String RATING_SCORE_NUMERIC_SCHEMA = "{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"id\":{\"type\":\"string\"}," + + "\"rating_score\":{\"type\":\"number\"}" + + "}," + + "\"required\":[\"id\",\"rating_score\"]," + + "\"additionalProperties\":false" + + "}"; + + public static final String RATING_SCORE_BINARY_SCHEMA = "{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"id\":{\"type\":\"string\"}," + + "\"rating_score\":{\"type\":\"string\",\"enum\":[\"RELEVANT\",\"IRRELEVANT\"]}" + + "}," + + "\"required\":[\"id\",\"rating_score\"]," + + "\"additionalProperties\":false" + + "}"; + + public static final String RESPONSE_FORMAT_TEMPLATE = "{" + + "\"type\":\"json_schema\"," + + "\"json_schema\":{" + + "\"name\":\"rating_response\"," + + "\"strict\":true," + + "\"schema\":{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"ratings\":{" + + "\"type\":\"array\"," + + "\"items\":%s" + + "}" + + "}," + + "\"required\":[\"ratings\"]," + + "\"additionalProperties\":false" + + "}" + + "}" + + "}"; + public static final String PROMPT_JSON_MESSAGES_SHELL = "[{\"role\":\"system\",\"content\":\"%s\"}," + "{\"role\":\"user\",\"content\":\"%s\"}]"; + public static final String PROMPT_JSON_MESSAGES_WITH_SCHEMA_SHELL = "{" + + "\"messages\":[{\"role\":\"system\",\"content\":\"%s\"},{\"role\":\"user\",\"content\":\"%s\"}]," + + "\"response_format\":%s" + + "}"; public static final String INPUT_FORMAT_SEARCH = "SearchText - %s; Hits - %s"; public static final String INPUT_FORMAT_SEARCH_WITH_REFERENCE = "SearchText: %s; Reference: %s; Hits: %s"; @@ -65,15 +146,19 @@ public static String escapeJson(String str) { return str.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t"); } - public static String sanitizeLLMResponse(String response) { - if (response == null) return ""; - - // Remove special characters that might cause parsing issues - String cleaned = response.replaceAll("``json", "").replace("`", "").replace("\n", " ").trim(); - if (!cleaned.startsWith("[")) { - cleaned = "[" + cleaned + "]"; + /** + * Get the appropriate response format schema based on rating type. + * @param ratingType The rating type to get the schema for + * @return The complete response_format JSON string with the appropriate schema + */ + public static String getResponseFormatSchema(LLMJudgmentRatingType ratingType) { + String itemSchema; + if (ratingType == LLMJudgmentRatingType.RELEVANT_IRRELEVANT) { + itemSchema = RATING_SCORE_BINARY_SCHEMA; + } else { + itemSchema = RATING_SCORE_NUMERIC_SCHEMA; } - return cleaned; + return String.format(Locale.ROOT, RESPONSE_FORMAT_TEMPLATE, itemSchema); } public static int validateTokenLimit(Map source) { diff --git a/src/main/java/org/opensearch/searchrelevance/dao/JudgmentCacheDao.java b/src/main/java/org/opensearch/searchrelevance/dao/JudgmentCacheDao.java index 7eb0529d..d4a800ec 100644 --- a/src/main/java/org/opensearch/searchrelevance/dao/JudgmentCacheDao.java +++ b/src/main/java/org/opensearch/searchrelevance/dao/JudgmentCacheDao.java @@ -10,6 +10,7 @@ import static org.opensearch.searchrelevance.indices.SearchRelevanceIndices.JUDGMENT_CACHE; import static org.opensearch.searchrelevance.model.JudgmentCache.CONTEXT_FIELDS_STR; import static org.opensearch.searchrelevance.model.JudgmentCache.DOCUMENT_ID; +import static org.opensearch.searchrelevance.model.JudgmentCache.PROMPT_TEMPLATE_ID; import static org.opensearch.searchrelevance.model.JudgmentCache.QUERY_TEXT; import static org.opensearch.searchrelevance.utils.ParserUtils.convertListToSortedStr; @@ -115,22 +116,25 @@ public void upsertJudgmentCache(final JudgmentCache judgmentCache, final ActionL * @param queryText - queryText to be searched * @param documentId - documentId to be searched * @param contextFields - contextFields to be searched + * @param promptTemplateCode - hash of promptTemplate and ratingType * @param listener - async operation */ public SearchResponse getJudgmentCache( String queryText, String documentId, List contextFields, + String promptTemplateCode, ActionListener listener ) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); String contextFieldsStr = contextFields != null ? convertListToSortedStr(contextFields) : ""; LOGGER.debug( - "Building cache search query - queryText: '{}', documentId: '{}', contextFields: '{}'", + "Building cache search query - queryText: '{}', documentId: '{}', contextFields: '{}', promptTemplateCode: '{}'", queryText, documentId, - contextFieldsStr + contextFieldsStr, + promptTemplateCode ); BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() @@ -141,6 +145,10 @@ public SearchResponse getJudgmentCache( boolQuery.must(QueryBuilders.matchQuery(CONTEXT_FIELDS_STR, contextFieldsStr)); } + if (promptTemplateCode != null && !promptTemplateCode.isEmpty()) { + boolQuery.must(QueryBuilders.termQuery(PROMPT_TEMPLATE_ID, promptTemplateCode)); + } + searchSourceBuilder.query(boolQuery); ActionListener wrappedListener = ActionListener.wrap(response -> { diff --git a/src/main/java/org/opensearch/searchrelevance/executors/ExperimentTaskContext.java b/src/main/java/org/opensearch/searchrelevance/executors/ExperimentTaskContext.java index ec0cc681..6ff1e6f0 100644 --- a/src/main/java/org/opensearch/searchrelevance/executors/ExperimentTaskContext.java +++ b/src/main/java/org/opensearch/searchrelevance/executors/ExperimentTaskContext.java @@ -84,11 +84,9 @@ public void scheduleVariantWrite(ExperimentVariant variant, String evaluationId, } } - CompletableFuture.runAsync(() -> { - experimentVariantDao.putExperimentVariantEfficient(variant, ActionListener.wrap(response -> { - log.debug("write successful for variant: {}", variant.getId()); - }, error -> { log.error("write failed for variant {}: {}", variant.getId(), error.getMessage()); })); - }); + experimentVariantDao.putExperimentVariantEfficient(variant, ActionListener.wrap(response -> { + log.debug("write successful for variant: {}", variant.getId()); + }, error -> { log.error("write failed for variant {}: {}", variant.getId(), error.getMessage()); })); } /** diff --git a/src/main/java/org/opensearch/searchrelevance/executors/JudgmentTaskContext.java b/src/main/java/org/opensearch/searchrelevance/executors/JudgmentTaskContext.java index 46556e04..5112f7f5 100644 --- a/src/main/java/org/opensearch/searchrelevance/executors/JudgmentTaskContext.java +++ b/src/main/java/org/opensearch/searchrelevance/executors/JudgmentTaskContext.java @@ -29,7 +29,7 @@ @Log4j2 @Getter public class JudgmentTaskContext { - private final String queryTextWithReference; + private final String queryTextWithCustomInput; private final String modelId; private final List contextFields; private final List searchConfigurations; @@ -47,14 +47,14 @@ public class JudgmentTaskContext { private ActionListener> completionListener; public JudgmentTaskContext( - String queryTextWithReference, + String queryTextWithCustomInput, String modelId, List contextFields, List searchConfigurations, boolean ignoreFailure, ActionListener> completionListener ) { - this.queryTextWithReference = queryTextWithReference; + this.queryTextWithCustomInput = queryTextWithCustomInput; this.modelId = modelId; this.contextFields = contextFields; this.searchConfigurations = searchConfigurations; @@ -72,7 +72,7 @@ public JudgmentTaskContext( log.info( "JudgmentTaskContext initialized for query: {} with {} search configurations", - queryTextWithReference, + queryTextWithCustomInput, searchConfigurations.size() ); } @@ -88,11 +88,11 @@ public void completeSearchTask(boolean success) { successfulTasks.incrementAndGet(); } else { failedTasks.incrementAndGet(); - log.warn("Search task failed for query: {} (ignoreFailure={})", queryTextWithReference, ignoreFailure); + log.warn("Search task failed for query: {} (ignoreFailure={})", queryTextWithCustomInput, ignoreFailure); } if (pendingSearchTasks.decrementAndGet() == 0) { - log.debug("All search tasks completed for query: {}", queryTextWithReference); + log.debug("All search tasks completed for query: {}", queryTextWithCustomInput); } } @@ -103,11 +103,11 @@ public void completeCacheTask(boolean success) { successfulTasks.incrementAndGet(); } else { failedTasks.incrementAndGet(); - log.warn("Cache task failed for query: {} (ignoreFailure={})", queryTextWithReference, ignoreFailure); + log.warn("Cache task failed for query: {} (ignoreFailure={})", queryTextWithCustomInput, ignoreFailure); } if (pendingCacheTasks.decrementAndGet() == 0) { - log.debug("All cache tasks completed for query: {}", queryTextWithReference); + log.debug("All cache tasks completed for query: {}", queryTextWithCustomInput); } } @@ -122,7 +122,7 @@ public void completeJudgment() { log.info( "Judgment completed for query: {} with {} ratings (success: {}, failed: {}, status: {})", - queryTextWithReference, + queryTextWithCustomInput, docIdToScore.size(), successfulTasks.get(), failedTasks.get(), @@ -161,7 +161,7 @@ public JudgmentBatchStatus getStatus() { public void failJudgment(Exception e) { if (hasTerminated.getAndSet(true)) return; - log.error("Judgment failed for query: {} (ignoreFailure={})", queryTextWithReference, ignoreFailure, e); + log.error("Judgment failed for query: {} (ignoreFailure={})", queryTextWithCustomInput, ignoreFailure, e); if (completionListener != null) { completionListener.onFailure(e); } diff --git a/src/main/java/org/opensearch/searchrelevance/executors/LlmJudgmentTaskManager.java b/src/main/java/org/opensearch/searchrelevance/executors/LlmJudgmentTaskManager.java index 4747e925..9bce40aa 100644 --- a/src/main/java/org/opensearch/searchrelevance/executors/LlmJudgmentTaskManager.java +++ b/src/main/java/org/opensearch/searchrelevance/executors/LlmJudgmentTaskManager.java @@ -50,27 +50,27 @@ public LlmJudgmentTaskManager(ThreadPool threadPool) { } public void scheduleTasksAsync( - List queryTextWithReferences, + List queryTextsWithCustomInput, Function> queryProcessor, boolean ignoreFailure, ActionListener>> listener ) { - int totalQueries = queryTextWithReferences.size(); + int totalQueries = queryTextsWithCustomInput.size(); log.info("Scheduling {} query text tasks for concurrent processing", totalQueries); try { - List>> futures = queryTextWithReferences.stream() - .map(queryTextWithReference -> CompletableFuture.supplyAsync(() -> { + List>> futures = queryTextsWithCustomInput.stream() + .map(queryTextWithCustomInput -> CompletableFuture.supplyAsync(() -> { try { rateLimiter.acquire(); try { - return queryProcessor.apply(queryTextWithReference); + return queryProcessor.apply(queryTextWithCustomInput); } finally { rateLimiter.release(); } } catch (Exception e) { - log.warn("Query processing failed, returning empty result for: {}", queryTextWithReference, e); - return JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, Map.of()); + log.warn("Query processing failed, returning empty result for: {}", queryTextWithCustomInput, e); + return JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, Map.of()); } }, threadPool.executor(THREAD_POOL_EXECUTOR_NAME))) .collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/searchrelevance/judgments/JudgmentDataTransformer.java b/src/main/java/org/opensearch/searchrelevance/judgments/JudgmentDataTransformer.java index 007eaf72..92036826 100644 --- a/src/main/java/org/opensearch/searchrelevance/judgments/JudgmentDataTransformer.java +++ b/src/main/java/org/opensearch/searchrelevance/judgments/JudgmentDataTransformer.java @@ -17,9 +17,9 @@ */ public class JudgmentDataTransformer { - public static Map createJudgmentResult(String queryTextWithReference, Map docIdToScore) { + public static Map createJudgmentResult(String queryTextWithCustomInput, Map docIdToScore) { Map judgmentForQuery = new HashMap<>(); - judgmentForQuery.put("query", queryTextWithReference); + judgmentForQuery.put("query", queryTextWithCustomInput); List> docIdRatings = docIdToScore == null ? List.of() @@ -32,7 +32,7 @@ public static Map createJudgmentResult(String queryTextWithRefer return judgmentForQuery; } - public static String extractQueryText(String queryTextWithReference, String delimiter) { - return queryTextWithReference.split(delimiter, 2)[0]; + public static String extractQueryText(String queryTextWithCustomInput, String delimiter) { + return queryTextWithCustomInput.split(delimiter, 2)[0]; } } diff --git a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java index 629e7f17..be99c792 100644 --- a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java +++ b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java @@ -7,12 +7,16 @@ */ package org.opensearch.searchrelevance.judgments; -import static org.opensearch.searchrelevance.common.MLConstants.sanitizeLLMResponse; -import static org.opensearch.searchrelevance.model.QueryWithReference.DELIMITER; +import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; +import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; +import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; import static org.opensearch.searchrelevance.model.builder.SearchRequestBuilder.buildSearchRequest; import static org.opensearch.searchrelevance.utils.ParserUtils.combinedIndexAndDocId; +import static org.opensearch.searchrelevance.utils.ParserUtils.generatePromptTemplateCode; import static org.opensearch.searchrelevance.utils.ParserUtils.generateUniqueId; import static org.opensearch.searchrelevance.utils.ParserUtils.getDocIdFromCompositeKey; +import static org.opensearch.searchrelevance.utils.RatingOutputProcessor.convertRatingScore; +import static org.opensearch.searchrelevance.utils.RatingOutputProcessor.sanitizeLLMResponse; import java.util.ArrayList; import java.util.Collections; @@ -42,10 +46,12 @@ import org.opensearch.searchrelevance.ml.MLAccessor; import org.opensearch.searchrelevance.model.JudgmentCache; import org.opensearch.searchrelevance.model.JudgmentType; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.model.QuerySet; import org.opensearch.searchrelevance.model.SearchConfiguration; import org.opensearch.searchrelevance.stats.events.EventStatName; import org.opensearch.searchrelevance.stats.events.EventStatsManager; +import org.opensearch.searchrelevance.utils.ParserUtils; import org.opensearch.searchrelevance.utils.TimeUtils; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; @@ -107,13 +113,33 @@ private void generateJudgmentRatingInternal(Map metadata, Action int tokenLimit = (int) metadata.get("tokenLimit"); List contextFields = (List) metadata.get("contextFields"); boolean ignoreFailure = (boolean) metadata.get("ignoreFailure"); + String promptTemplate = (String) metadata.get(PROMPT_TEMPLATE); + LLMJudgmentRatingType ratingType = (LLMJudgmentRatingType) metadata.get(LLM_JUDGMENT_RATING_TYPE); + // Default to SCORE0_1 if ratingType is not provided + if (ratingType == null) { + ratingType = LLMJudgmentRatingType.SCORE0_1; + log.debug("No ratingType provided, defaulting to SCORE0_1"); + } + boolean overwriteCache = (boolean) metadata.get(OVERWRITE_CACHE); QuerySet querySet = querySetDao.getQuerySetSync(querySetId); List searchConfigurations = searchConfigurationList.stream() .map(id -> searchConfigurationDao.getSearchConfigurationSync(id)) .collect(Collectors.toList()); - generateLLMJudgmentsAsync(modelId, size, tokenLimit, contextFields, querySet, searchConfigurations, ignoreFailure, listener); + generateLLMJudgmentsAsync( + modelId, + size, + tokenLimit, + contextFields, + querySet, + searchConfigurations, + ignoreFailure, + promptTemplate, + ratingType, + overwriteCache, + listener + ); } catch (Exception e) { log.error("Failed to generate LLM judgments", e); listener.onFailure(new SearchRelevanceException("Failed to generate LLM judgments", e, RestStatus.INTERNAL_SERVER_ERROR)); @@ -128,10 +154,13 @@ private void generateLLMJudgmentsAsync( QuerySet querySet, List searchConfigurations, boolean ignoreFailure, + String promptTemplate, + LLMJudgmentRatingType ratingType, + boolean overwriteCache, ActionListener>> listener ) { - List queryTextWithReferences = querySet.querySetQueries().stream().map(e -> e.queryText()).collect(Collectors.toList()); - int totalQueries = queryTextWithReferences.size(); + List queryTextsWithCustomInput = querySet.querySetQueries().stream().map(e -> e.queryText()).collect(Collectors.toList()); + int totalQueries = queryTextsWithCustomInput.size(); log.info("Starting LLM judgment generation for {} total queries", totalQueries); @@ -141,7 +170,7 @@ private void generateLLMJudgmentsAsync( cacheIndexListener.whenComplete(indexResult -> { log.debug("Judgment cache index creation completed, proceeding with task scheduling"); - taskManager.scheduleTasksAsync(queryTextWithReferences, queryTextWithReference -> { + taskManager.scheduleTasksAsync(queryTextsWithCustomInput, queryTextWithCustomInput -> { try { return processQueryTextAsync( modelId, @@ -149,16 +178,19 @@ private void generateLLMJudgmentsAsync( tokenLimit, contextFields, searchConfigurations, - queryTextWithReference, - ignoreFailure + queryTextWithCustomInput, + ignoreFailure, + promptTemplate, + ratingType, + overwriteCache ); } catch (Exception e) { if (ignoreFailure) { - log.warn("Query processing failed, returning empty result for: {}", queryTextWithReference, e); - return JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, Map.of()); + log.warn("Query processing failed, returning empty result for: {}", queryTextWithCustomInput, e); + return JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, Map.of()); } else { - log.error("Query processing failed for: {}", queryTextWithReference, e); - throw new RuntimeException("Query processing failed: " + queryTextWithReference, e); + log.error("Query processing failed for: {}", queryTextWithCustomInput, e); + throw new RuntimeException("Query processing failed: " + queryTextWithCustomInput, e); } } }, ignoreFailure, ActionListener.wrap(results -> { @@ -185,7 +217,7 @@ private void generateLLMJudgmentsAsync( }, indexError -> { log.warn("Failed to create judgment cache index, proceeding without cache optimization", indexError); - taskManager.scheduleTasksAsync(queryTextWithReferences, queryTextWithReference -> { + taskManager.scheduleTasksAsync(queryTextsWithCustomInput, queryTextWithCustomInput -> { try { return processQueryTextAsync( modelId, @@ -193,16 +225,19 @@ private void generateLLMJudgmentsAsync( tokenLimit, contextFields, searchConfigurations, - queryTextWithReference, - ignoreFailure + queryTextWithCustomInput, + ignoreFailure, + promptTemplate, + ratingType, + overwriteCache ); } catch (Exception e) { if (ignoreFailure) { - log.warn("Query processing failed, returning empty result for: {}", queryTextWithReference, e); - return JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, Map.of()); + log.warn("Query processing failed, returning empty result for: {}", queryTextWithCustomInput, e); + return JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, Map.of()); } else { - log.error("Query processing failed for: {}", queryTextWithReference, e); - throw new RuntimeException("Query processing failed: " + queryTextWithReference, e); + log.error("Query processing failed for: {}", queryTextWithCustomInput, e); + throw new RuntimeException("Query processing failed: " + queryTextWithCustomInput, e); } } }, ignoreFailure, ActionListener.wrap(results -> { @@ -235,49 +270,66 @@ private Map processQueryTextAsync( int tokenLimit, List contextFields, List searchConfigurations, - String queryTextWithReference, - boolean ignoreFailure + String queryTextWithCustomInput, + boolean ignoreFailure, + String promptTemplate, + LLMJudgmentRatingType ratingType, + boolean overwriteCache ) { - log.info("Processing query text judgment: {}", queryTextWithReference); + log.info("Processing query text judgment: {}", queryTextWithCustomInput); ConcurrentMap allHits = new ConcurrentHashMap<>(); ConcurrentMap docIdToScore = new ConcurrentHashMap<>(); - String queryText = queryTextWithReference.split(DELIMITER, 2)[0]; + String queryText = ParserUtils.parseQueryTextWithCustomInput(queryTextWithCustomInput).get("queryText"); try { // Step 1: Execute searches concurrently within this query text task processSearchConfigurationsAsync(searchConfigurations, queryText, size, allHits, ignoreFailure); - // Step 2: Deduplicate from cache + // Step 2: Deduplicate from cache (skip if overwriteCache is true) List docIds = new ArrayList<>(allHits.keySet()); + String index = searchConfigurations.get(0).index(); + String promptTemplateCode = generatePromptTemplateCode(promptTemplate, ratingType); List unprocessedDocIds = deduplicateFromCache( index, - queryTextWithReference, + queryTextWithCustomInput, contextFields, docIds, docIdToScore, - ignoreFailure + ignoreFailure, + promptTemplateCode, + overwriteCache ); // Step 3: Process with LLM if needed if (!unprocessedDocIds.isEmpty()) { - processWithLLM(modelId, queryTextWithReference, tokenLimit, contextFields, unprocessedDocIds, allHits, index, docIdToScore); + processWithLLM( + modelId, + queryTextWithCustomInput, + tokenLimit, + contextFields, + unprocessedDocIds, + allHits, + index, + docIdToScore, + promptTemplate, + ratingType + ); } - Map result = JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, docIdToScore); - log.debug("Query processing completed for: {} with {} ratings", queryTextWithReference, docIdToScore.size()); + Map result = JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); return result; } catch (Exception e) { log.warn( "Query processing failed for: {} with {} ratings collected. Error: {}", - queryTextWithReference, + queryTextWithCustomInput, docIdToScore.size(), e.getMessage(), e ); // Always return a result with whatever ratings we managed to collect - return JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, docIdToScore); + return JudgmentDataTransformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); } } @@ -312,12 +364,19 @@ private void processSearchConfigurationsAsync( private List deduplicateFromCache( String index, - String queryTextWithReference, + String queryTextWithCustomInput, List contextFields, List docIds, ConcurrentMap docIdToScore, - boolean ignoreFailure + boolean ignoreFailure, + String promptTemplateCode, + boolean overwriteCache ) throws Exception { + // If overwriteCache is true, skip cache lookup and return all docIds as unprocessed + if (overwriteCache) { + log.info("overwriteCache flag is enabled, skipping cache lookup for all {} docs", docIds.size()); + return docIds; + } List processedDocIds = Collections.synchronizedList(new ArrayList<>()); AtomicBoolean hasFailure = new AtomicBoolean(false); @@ -325,9 +384,10 @@ private List deduplicateFromCache( String compositeKey = combinedIndexAndDocId(index, docId); CompletableFuture future = new CompletableFuture<>(); judgmentCacheDao.getJudgmentCache( - queryTextWithReference, + queryTextWithCustomInput, compositeKey, contextFields, + promptTemplateCode, ActionListener.wrap(future::complete, future::completeExceptionally) ); @@ -356,13 +416,15 @@ private List deduplicateFromCache( private void processWithLLM( String modelId, - String queryTextWithReference, + String queryTextWithCustomInput, int tokenLimit, List contextFields, List unprocessedDocIds, ConcurrentMap allHits, String index, - ConcurrentMap docIdToScore + ConcurrentMap docIdToScore, + String promptTemplate, + LLMJudgmentRatingType ratingType ) throws Exception { Map unionHits = new HashMap<>(); @@ -375,10 +437,27 @@ private void processWithLLM( } log.info("Processing {} uncached docs with LLM", unionHits.size()); + log.debug("DEBUG: unionHits keys being sent to LLM: {}", unionHits.keySet()); + log.debug("DEBUG: queryTextWithCustomInput: {}", queryTextWithCustomInput); + log.debug("DEBUG: modelId: {}, tokenLimit: {}, ratingType: {}", modelId, tokenLimit, ratingType); + + // Generate promptTemplateCode for cache updates + String promptTemplateCode = generatePromptTemplateCode(promptTemplate, ratingType); // Synchronous LLM call PlainActionFuture> llmFuture = PlainActionFuture.newFuture(); - generateLLMJudgmentForQueryText(modelId, queryTextWithReference, tokenLimit, contextFields, unionHits, new HashMap<>(), llmFuture); + generateLLMJudgmentForQueryText( + modelId, + queryTextWithCustomInput, + tokenLimit, + contextFields, + unionHits, + new HashMap<>(), + promptTemplate, + ratingType, + promptTemplateCode, + llmFuture + ); Map llmResults = llmFuture.actionGet(); docIdToScore.putAll(llmResults); @@ -388,98 +467,135 @@ private void processWithLLM( private void generateLLMJudgmentForQueryText( String modelId, - String queryTextWithReference, + String queryTextWithCustomInput, int tokenLimit, List contextFields, Map unprocessedUnionHits, Map docIdToRating, + String promptTemplate, + LLMJudgmentRatingType ratingType, + String promptTemplateCode, ActionListener> listener ) { log.debug("calculating LLM evaluation with modelId: {} and unprocessed unionHits: {}", modelId, unprocessedUnionHits); log.debug("processed docIdToRating before llm evaluation: {}", docIdToRating); if (unprocessedUnionHits.isEmpty()) { - log.info("All hits found in cache, returning cached results for query: {}", queryTextWithReference); + log.info("All hits found in cache, returning cached results for query: {}", queryTextWithCustomInput); listener.onResponse(docIdToRating); return; } - String[] queryTextRefArr = queryTextWithReference.split(DELIMITER); - String queryText = queryTextRefArr[0]; - String referenceAnswer = queryTextRefArr.length > 1 ? queryTextWithReference.split(DELIMITER, 2)[1] : null; + // Parse queryTextWithCustomInput to extract query and reference data + Map parsedData = ParserUtils.parseQueryTextWithCustomInput(queryTextWithCustomInput); + String queryText = parsedData.remove("queryText"); + Map referenceData = parsedData; // Remaining entries are reference data ConcurrentMap processedRatings = new ConcurrentHashMap<>(docIdToRating); ConcurrentMap>> combinedResponses = new ConcurrentHashMap<>(); AtomicBoolean hasFailure = new AtomicBoolean(false); - mlAccessor.predict(modelId, tokenLimit, queryText, referenceAnswer, unprocessedUnionHits, new ActionListener() { - @Override - public void onResponse(ChunkResult chunkResult) { - try { - // Process all chunks, let query level decide on failures + mlAccessor.predict( + modelId, + tokenLimit, + queryText, + referenceData, + unprocessedUnionHits, + promptTemplate, + ratingType, + new ActionListener() { + @Override + public void onResponse(ChunkResult chunkResult) { + try { + // Process all chunks, let query level decide on failures + + Map succeededChunks = chunkResult.getSucceededChunks(); + for (Map.Entry entry : succeededChunks.entrySet()) { + Integer chunkIndex = entry.getKey(); + if (combinedResponses.containsKey(chunkIndex)) { + continue; + } - Map succeededChunks = chunkResult.getSucceededChunks(); - for (Map.Entry entry : succeededChunks.entrySet()) { - Integer chunkIndex = entry.getKey(); - if (combinedResponses.containsKey(chunkIndex)) { - continue; + log.debug("response before sanitization: {}", entry.getValue()); + String sanitizedResponse = sanitizeLLMResponse(entry.getValue()); + log.debug("response after sanitization: {}", sanitizedResponse); + List> scores = OBJECT_MAPPER.readValue( + sanitizedResponse, + new TypeReference>>() { + } + ); + combinedResponses.put(chunkIndex, scores); } - log.debug("response before sanitization: {}", entry.getValue()); - String sanitizedResponse = sanitizeLLMResponse(entry.getValue()); - log.debug("response after sanitization: {}", sanitizedResponse); - List> scores = OBJECT_MAPPER.readValue( - sanitizedResponse, - new TypeReference>>() { + logFailedChunks(chunkResult); + + if (chunkResult.isLastChunk() && !hasFailure.get()) { + log.info( + "Processing final results for query: {}. Successful chunks: {}, Failed chunks: {}", + queryTextWithCustomInput, + chunkResult.getSuccessfulChunksCount(), + chunkResult.getFailedChunksCount() + ); + + log.debug("DEBUG: combinedResponses size: {}", combinedResponses.size()); + for (List> ratings : combinedResponses.values()) { + log.debug("DEBUG: Processing ratings batch with {} ratings", ratings.size()); + for (Map rating : ratings) { + String compositeKey = (String) rating.get("id"); + Object rawRatingScore = rating.get("rating_score"); + log.debug( + "DEBUG: Processing rating - compositeKey: {}, rawRatingScore: {}", + compositeKey, + rawRatingScore + ); + Double ratingScore = convertRatingScore(rawRatingScore, ratingType); + String docId = getDocIdFromCompositeKey(compositeKey); + log.debug("DEBUG: Converted rating - docId: {}, ratingScore: {}", docId, ratingScore); + processedRatings.put(docId, ratingScore.toString()); + updateJudgmentCache( + compositeKey, + queryTextWithCustomInput, + contextFields, + ratingScore.toString(), + modelId, + promptTemplateCode + ); + } } - ); - combinedResponses.put(chunkIndex, scores); - } - - logFailedChunks(chunkResult); - if (chunkResult.isLastChunk() && !hasFailure.get()) { - log.info( - "Processing final results for query: {}. Successful chunks: {}, Failed chunks: {}", - queryTextWithReference, - chunkResult.getSuccessfulChunksCount(), - chunkResult.getFailedChunksCount() - ); - - for (List> ratings : combinedResponses.values()) { - for (Map rating : ratings) { - String compositeKey = (String) rating.get("id"); - Double ratingScore = ((Number) rating.get("rating_score")).doubleValue(); - String docId = getDocIdFromCompositeKey(compositeKey); - processedRatings.put(docId, ratingScore.toString()); - updateJudgmentCache(compositeKey, queryTextWithReference, contextFields, ratingScore.toString(), modelId); - } + log.debug("DEBUG: Final processedRatings size: {}, ratings: {}", processedRatings.size(), processedRatings); + listener.onResponse(processedRatings); } - - listener.onResponse(processedRatings); + } catch (Exception e) { + handleProcessingError(e, chunkResult.isLastChunk()); } - } catch (Exception e) { - handleProcessingError(e, chunkResult.isLastChunk()); } - } - @Override - public void onFailure(Exception e) { - handleProcessingError(e, true); - } + @Override + public void onFailure(Exception e) { + handleProcessingError(e, true); + } - private void handleProcessingError(Exception e, boolean isLastChunk) { - if (!hasFailure.getAndSet(true)) { - log.error("Failed to process chunk response", e); - listener.onFailure( - new SearchRelevanceException("Failed to process chunk response", e, RestStatus.INTERNAL_SERVER_ERROR) - ); + private void handleProcessingError(Exception e, boolean isLastChunk) { + if (!hasFailure.getAndSet(true)) { + log.error("Failed to process chunk response", e); + listener.onFailure( + new SearchRelevanceException("Failed to process chunk response", e, RestStatus.INTERNAL_SERVER_ERROR) + ); + } } } - }); + ); } - private void updateJudgmentCache(String compositeKey, String queryText, List contextFields, String rating, String modelId) { + private void updateJudgmentCache( + String compositeKey, + String queryText, + List contextFields, + String rating, + String modelId, + String promptTemplateCode + ) { try { JudgmentCache judgmentCache = new JudgmentCache( generateUniqueId(queryText, compositeKey, contextFields), @@ -488,7 +604,8 @@ private void updateJudgmentCache(String compositeKey, String queryText, List createIndexStep = new StepListener<>(); judgmentCacheDao.createIndexIfAbsent(createIndexStep); @@ -548,4 +665,5 @@ private String getContextSource(SearchHit hit, List contextFields) { throw new RuntimeException("Failed to process context source", e); } } + } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java b/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java index 070390be..210ecca6 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLAccessor.java @@ -7,6 +7,7 @@ */ package org.opensearch.searchrelevance.ml; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -14,7 +15,10 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.searchrelevance.utils.RatingOutputProcessor; import lombok.extern.log4j.Log4j2; @@ -38,12 +42,22 @@ public void predict( String modelId, int tokenLimit, String searchText, - String reference, + Map referenceData, Map hits, + String promptTemplate, + LLMJudgmentRatingType ratingType, ActionListener progressListener ) { - List mlInputs = transformer.createMLInputs(tokenLimit, searchText, reference, hits); + log.debug( + "DEBUG: MLAccessor.predict called with modelId: {}, searchText: {}, hits count: {}, ratingType: {}", + modelId, + searchText, + hits.size(), + ratingType + ); + List mlInputs = transformer.createMLInputs(tokenLimit, searchText, referenceData, hits, promptTemplate, ratingType); log.info("Number of chunks: {}", mlInputs.size()); + log.debug("DEBUG: Created {} MLInput chunks", mlInputs.size()); ChunkProcessingContext context = new ChunkProcessingContext(mlInputs.size(), progressListener); @@ -53,10 +67,32 @@ public void predict( } private void processChunk(String modelId, MLInput mlInput, int chunkIndex, ChunkProcessingContext context) { - predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, 0, ActionListener.wrap(response -> { + processChunkWithFallback(modelId, mlInput, chunkIndex, false, context); + } + + private void processChunkWithFallback( + String modelId, + MLInput mlInput, + int chunkIndex, + boolean triedWithoutResponseFormat, + ChunkProcessingContext context + ) { + predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, 0, triedWithoutResponseFormat, ActionListener.wrap(response -> { log.info("Chunk {} processed successfully", chunkIndex); String processedResponse = cleanResponse(response); - context.handleSuccess(chunkIndex, processedResponse); + + // Check if parsing failed (empty ratings array) and we haven't tried without response_format yet + if ("[]".equals(processedResponse) && !triedWithoutResponseFormat) { + log.warn( + "Chunk {} returned empty ratings with response_format. Retrying without response_format for GPT-3.5 compatibility...", + chunkIndex + ); + // Create new MLInput without response_format and retry + MLInput mlInputWithoutFormat = recreateMLInputWithoutResponseFormat(mlInput); + scheduleRetry(() -> processChunkWithFallback(modelId, mlInputWithoutFormat, chunkIndex, true, context), RETRY_DELAY_MS); + } else { + context.handleSuccess(chunkIndex, processedResponse); + } }, e -> { log.error("Chunk {} failed after all retries", chunkIndex, e); context.handleFailure(chunkIndex, e); @@ -64,29 +100,71 @@ private void processChunk(String modelId, MLInput mlInput, int chunkIndex, Chunk } private String cleanResponse(String response) { - return response.substring(1, response.length() - 1); // remove brackets + // Use sanitizeLLMResponse to handle both structured (with response_format) and unstructured responses + // For GPT-4o with response_format: extracts {"ratings": [...]} + // For GPT-3.5 without response_format: parses and sanitizes unstructured JSON + return RatingOutputProcessor.sanitizeLLMResponse(response); } + /** + * Retries prediction with automatic fallback to non-structured output. + * First tries with response_format, then falls back to without response_format if it fails. + * + * @param triedWithoutResponseFormat Tracks if we've already tried without response_format + */ private void predictSingleChunkWithRetry( String modelId, MLInput mlInput, int chunkIndex, int retryCount, + boolean triedWithoutResponseFormat, ActionListener chunkListener ) { predictSingleChunk(modelId, mlInput, new ActionListener() { @Override public void onResponse(String response) { + log.debug( + "DEBUG: Chunk {} received response (length: {}). First 200 chars: {}", + chunkIndex, + response.length(), + response.substring(0, Math.min(200, response.length())) + ); chunkListener.onResponse(response); } @Override public void onFailure(Exception e) { - if (retryCount < MAX_RETRY_NUMBER) { + log.debug( + "DEBUG: Chunk {} failed with error: {}. triedWithoutResponseFormat: {}, retryCount: {}", + chunkIndex, + e.getMessage(), + triedWithoutResponseFormat, + retryCount + ); + // If we haven't tried without response_format yet, try that first before regular retries + if (!triedWithoutResponseFormat) { + log.warn( + "Chunk {} failed with response_format. Retrying without response_format for GPT-3.5 compatibility...", + chunkIndex + ); + log.debug("DEBUG: Creating MLInput without response_format for chunk {}", chunkIndex); + + // Create new MLInput without response_format + MLInput mlInputWithoutFormat = recreateMLInputWithoutResponseFormat(mlInput); + + long delay = RETRY_DELAY_MS; + scheduleRetry( + () -> predictSingleChunkWithRetry(modelId, mlInputWithoutFormat, chunkIndex, 0, true, chunkListener), + delay + ); + } else if (retryCount < MAX_RETRY_NUMBER) { log.warn("Chunk {} failed, attempt {}/{}. Retrying...", chunkIndex, retryCount + 1, MAX_RETRY_NUMBER); long delay = RETRY_DELAY_MS * (long) Math.pow(2, retryCount); - scheduleRetry(() -> predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, retryCount + 1, chunkListener), delay); + scheduleRetry( + () -> predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, retryCount + 1, true, chunkListener), + delay + ); } else { chunkListener.onFailure(e); } @@ -94,16 +172,45 @@ public void onFailure(Exception e) { }); } + /** + * Recreates MLInput without response_format parameter for models that don't support it (e.g., GPT-3.5). + */ + private MLInput recreateMLInputWithoutResponseFormat(MLInput originalInput) { + // Extract the parameters from the original input and rebuild without response_format + RemoteInferenceInputDataSet originalDataSet = (RemoteInferenceInputDataSet) originalInput.getInputDataset(); + Map originalParams = originalDataSet.getParameters(); + + // Create new parameters map without response_format + Map newParams = new HashMap<>(); + for (Map.Entry entry : originalParams.entrySet()) { + if (!"response_format".equals(entry.getKey())) { + newParams.put(entry.getKey(), entry.getValue()); + } + } + + return MLInput.builder().algorithm(originalInput.getAlgorithm()).inputDataset(new RemoteInferenceInputDataSet(newParams)).build(); + } + private void scheduleRetry(Runnable runnable, long delayMs) { CompletableFuture.delayedExecutor(delayMs, TimeUnit.MILLISECONDS).execute(runnable); } public void predictSingleChunk(String modelId, MLInput mlInput, ActionListener listener) { - mlClient.predict( - modelId, - mlInput, - ActionListener.wrap(mlOutput -> listener.onResponse(transformer.extractResponseContent(mlOutput)), listener::onFailure) + log.debug("DEBUG: predictSingleChunk called with modelId: {}", modelId); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map params = dataset.getParameters(); + log.debug( + "DEBUG: MLInput parameters - has response_format: {}, has messages: {}", + params.containsKey("response_format"), + params.containsKey("messages") ); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + log.debug("DEBUG: ML prediction succeeded, extracting response content"); + listener.onResponse(transformer.extractResponseContent(mlOutput)); + }, e -> { + log.debug("DEBUG: ML prediction failed with error: {}", e.getMessage()); + listener.onFailure(e); + })); } } diff --git a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java index f23bcb46..c1141521 100644 --- a/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java +++ b/src/main/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformer.java @@ -7,13 +7,16 @@ */ package org.opensearch.searchrelevance.ml; -import static org.opensearch.searchrelevance.common.MLConstants.INPUT_FORMAT_SEARCH; -import static org.opensearch.searchrelevance.common.MLConstants.INPUT_FORMAT_SEARCH_WITH_REFERENCE; import static org.opensearch.searchrelevance.common.MLConstants.PARAM_MESSAGES_FIELD; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_JSON_MESSAGES_SHELL; -import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE; +import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_0_1_START; +import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_BINARY; +import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_SEARCH_RELEVANCE_SCORE_END; +import static org.opensearch.searchrelevance.common.MLConstants.RATING_SCORE_BINARY_SCHEMA; +import static org.opensearch.searchrelevance.common.MLConstants.RATING_SCORE_NUMERIC_SCHEMA; import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CHOICES_FIELD; import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_CONTENT_FIELD; +import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_FORMAT_TEMPLATE; import static org.opensearch.searchrelevance.common.MLConstants.RESPONSE_MESSAGE_FIELD; import static org.opensearch.searchrelevance.common.MLConstants.escapeJson; @@ -23,7 +26,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.util.CollectionUtils; @@ -35,6 +37,7 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import lombok.extern.log4j.Log4j2; @@ -44,7 +47,14 @@ @Log4j2 public class MLInputOutputTransformer { - public List createMLInputs(int tokenLimit, String searchText, String reference, Map hits) { + public List createMLInputs( + int tokenLimit, + String searchText, + Map referenceData, + Map hits, + String promptTemplate, + LLMJudgmentRatingType ratingType + ) { List mlInputs = new ArrayList<>(); Map currentChunk = new HashMap<>(); @@ -52,14 +62,14 @@ public List createMLInputs(int tokenLimit, String searchText, String re Map tempChunk = new HashMap<>(currentChunk); tempChunk.put(entry.getKey(), entry.getValue()); - String messages = formatMessages(searchText, reference, tempChunk); + String messages = buildMessagesArray(searchText, referenceData, tempChunk, promptTemplate, ratingType); int totalTokens = TokenizerUtil.countTokens(messages); if (totalTokens > tokenLimit) { if (currentChunk.isEmpty()) { - mlInputs.add(handleOversizedEntry(entry, searchText, reference, tokenLimit)); + mlInputs.add(handleOversizedEntry(entry, searchText, referenceData, tokenLimit, promptTemplate, ratingType)); } else { - mlInputs.add(createMLInput(searchText, reference, currentChunk)); + mlInputs.add(createMLInput(searchText, referenceData, currentChunk, promptTemplate, ratingType)); currentChunk = new HashMap<>(); currentChunk.put(entry.getKey(), entry.getValue()); } @@ -69,43 +79,117 @@ public List createMLInputs(int tokenLimit, String searchText, String re } if (!currentChunk.isEmpty()) { - mlInputs.add(createMLInput(searchText, reference, currentChunk)); + mlInputs.add(createMLInput(searchText, referenceData, currentChunk, promptTemplate, ratingType)); } return mlInputs; } - private MLInput handleOversizedEntry(Map.Entry entry, String searchText, String reference, int tokenLimit) { + private MLInput handleOversizedEntry( + Map.Entry entry, + String searchText, + Map referenceData, + int tokenLimit, + String promptTemplate, + LLMJudgmentRatingType ratingType + ) { log.warn("Entry with key {} causes total tokens to exceed limit of {}", entry.getKey(), tokenLimit); Map testChunk = Map.of(entry.getKey(), entry.getValue()); - String testMessages = formatMessages(searchText, reference, testChunk); + String testMessages = buildMessagesArray(searchText, referenceData, testChunk, promptTemplate, ratingType); int excessTokens = TokenizerUtil.countTokens(testMessages) - tokenLimit; int currentTokens = TokenizerUtil.countTokens(entry.getValue()); String truncatedValue = TokenizerUtil.truncateString(entry.getValue(), Math.max(1, currentTokens - excessTokens)); Map singleEntryChunk = Map.of(entry.getKey(), truncatedValue); - return createMLInput(searchText, reference, singleEntryChunk); + return createMLInput(searchText, referenceData, singleEntryChunk, promptTemplate, ratingType); } - public MLInput createMLInput(String searchText, String reference, Map hits) { + public MLInput createMLInput( + String searchText, + Map referenceData, + Map hits, + String promptTemplate, + LLMJudgmentRatingType ratingType + ) { + return createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, true); + } + + /** + * Creates MLInput with optional response_format parameter. + * Some models (like GPT-3.5) don't support response_format, so we can disable it for fallback. + * + * @param includeResponseFormat If true, includes response_format parameter; if false, excludes it + */ + public MLInput createMLInput( + String searchText, + Map referenceData, + Map hits, + String promptTemplate, + LLMJudgmentRatingType ratingType, + boolean includeResponseFormat + ) { Map parameters = new HashMap<>(); - parameters.put(PARAM_MESSAGES_FIELD, formatMessages(searchText, reference, hits)); + String messagesArray = buildMessagesArray(searchText, referenceData, hits, promptTemplate, ratingType); + + parameters.put(PARAM_MESSAGES_FIELD, messagesArray); + + // Only add response_format if requested (for models that support it) + if (includeResponseFormat) { + String responseFormat = getResponseFormat(ratingType); + parameters.put("response_format", responseFormat); + } + return MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(new RemoteInferenceInputDataSet(parameters)).build(); } - public String formatMessages(String searchText, String reference, Map hits) { + private String buildMessagesArray( + String searchText, + Map referenceData, + Map hits, + String promptTemplate, + LLMJudgmentRatingType ratingType + ) { try { String hitsJson = buildHitsJson(hits); - String userContent = buildUserContent(searchText, reference, hitsJson); - return String.format(Locale.ROOT, PROMPT_JSON_MESSAGES_SHELL, PROMPT_SEARCH_RELEVANCE, escapeJson(userContent)); + String userContent = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, promptTemplate); + String systemPrompt = getSystemPrompt(ratingType); + return String.format(Locale.ROOT, PROMPT_JSON_MESSAGES_SHELL, systemPrompt, escapeJson(userContent)); } catch (IOException e) { log.error("Error converting hits to JSON string", e); throw new IllegalArgumentException("Failed to process hits", e); } } + private static String getSystemPrompt(LLMJudgmentRatingType ratingType) { + String systemPromptStart; + String systemPromptEnd = PROMPT_SEARCH_RELEVANCE_SCORE_END; + switch (ratingType) { + case LLMJudgmentRatingType.SCORE0_1: + systemPromptStart = PROMPT_SEARCH_RELEVANCE_SCORE_0_1_START; + break; + default: + systemPromptStart = PROMPT_SEARCH_RELEVANCE_SCORE_BINARY; + } + return systemPromptStart + systemPromptEnd; + } + + private static String getResponseFormat(LLMJudgmentRatingType ratingType) { + String schema; + switch (ratingType) { + case LLMJudgmentRatingType.SCORE0_1: + schema = RATING_SCORE_NUMERIC_SCHEMA; + break; + case LLMJudgmentRatingType.RELEVANT_IRRELEVANT: + schema = RATING_SCORE_BINARY_SCHEMA; + break; + default: + schema = RATING_SCORE_NUMERIC_SCHEMA; + } + return String.format(Locale.ROOT, RESPONSE_FORMAT_TEMPLATE, schema); + } + private String buildHitsJson(Map hits) throws IOException { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { builder.startArray(); @@ -120,14 +204,6 @@ private String buildHitsJson(Map hits) throws IOException { } } - private String buildUserContent(String searchText, String reference, String hitsJson) { - if (Objects.isNull(reference) || reference.isEmpty()) { - return String.format(Locale.ROOT, INPUT_FORMAT_SEARCH, searchText, hitsJson); - } else { - return String.format(Locale.ROOT, INPUT_FORMAT_SEARCH_WITH_REFERENCE, searchText, reference, hitsJson); - } - } - public String extractResponseContent(MLOutput mlOutput) { if (!(mlOutput instanceof ModelTensorOutput)) { throw new IllegalArgumentException("Expected ModelTensorOutput, but got " + mlOutput.getClass().getSimpleName()); diff --git a/src/main/java/org/opensearch/searchrelevance/ml/UserPromptFactory.java b/src/main/java/org/opensearch/searchrelevance/ml/UserPromptFactory.java new file mode 100644 index 00000000..fd9711b4 --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/ml/UserPromptFactory.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.ml; + +import static org.opensearch.searchrelevance.common.MLConstants.INPUT_FORMAT_SEARCH; +import static org.opensearch.searchrelevance.common.MLConstants.INPUT_FORMAT_SEARCH_WITH_REFERENCE; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_HITS; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_QUERY_TEXT; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_REFERENCE; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_REFERENCE_ANSWER; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_RESULTS; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_SEARCH_TEXT; + +import java.util.Locale; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Factory class for building user prompts with template variable replacement. + * Handles both custom prompt templates and default formats. + */ +public class UserPromptFactory { + + private static final Pattern TEMPLATE_VARIABLE_PATTERN = Pattern.compile("\\{\\{([^}]+)\\}\\}"); + + private UserPromptFactory() {} + + /** + * Build user content for the LLM prompt. + * If promptTemplate is provided, replaces template variables with actual values. + * If promptTemplate is null/empty, uses default INPUT_FORMAT_SEARCH or INPUT_FORMAT_SEARCH_WITH_REFERENCE. + * + * @param searchText The search query text + * @param referenceData Map of reference data (e.g., {"referenceAnswer": "value", "category": "value"}) + * @param hitsJson The JSON string representation of search hits + * @param promptTemplate Optional custom prompt template with {{variable}} placeholders + * @return The formatted user content string + */ + public static String buildUserContent(String searchText, Map referenceData, String hitsJson, String promptTemplate) { + // If no template provided, use default format + if (promptTemplate == null || promptTemplate.trim().isEmpty()) { + return buildDefaultUserContent(searchText, referenceData, hitsJson); + } + + // Replace template variables + return replaceTemplateVariables(promptTemplate, searchText, referenceData, hitsJson); + } + + /** + * Build default user content using INPUT_FORMAT_SEARCH or INPUT_FORMAT_SEARCH_WITH_REFERENCE. + */ + private static String buildDefaultUserContent(String searchText, Map referenceData, String hitsJson) { + if (referenceData == null || referenceData.isEmpty()) { + return String.format(Locale.ROOT, INPUT_FORMAT_SEARCH, searchText, hitsJson); + } else { + // Use referenceAnswer if available, otherwise use all reference data as a single string + String referenceValue = getReferenceValue(referenceData); + return String.format(Locale.ROOT, INPUT_FORMAT_SEARCH_WITH_REFERENCE, searchText, referenceValue, hitsJson); + } + } + + /** + * Get reference value from referenceData map. + * Prioritizes "referenceAnswer" key, falls back to concatenating all values. + */ + private static String getReferenceValue(Map referenceData) { + if (referenceData.containsKey(PLACEHOLDER_REFERENCE_ANSWER)) { + return referenceData.get(PLACEHOLDER_REFERENCE_ANSWER); + } + // Fallback: concatenate all values with delimiter + return String.join("; ", referenceData.values()); + } + + /** + * Replace template variables in the prompt template with actual values. + * Supports placeholders like {{variable_name}}. + * + * Supported variables: + * - {{queryText}} or {{searchText}} - replaced with the search query + * - {{reference}} or {{referenceAnswer}} - replaced with reference answer if available + * - {{hits}} or {{results}} - replaced with the JSON string of search hits + * - {{key_name}} - any key from referenceData map (e.g., {{category}}, {{expectedScore}}) + * + * @param template The template string with {{variable}} placeholders + * @param searchText The search query text + * @param referenceData Map of reference data + * @param hitsJson The JSON string representation of search hits + * @return The template with all placeholders replaced + */ + private static String replaceTemplateVariables(String template, String searchText, Map referenceData, String hitsJson) { + if (template == null || template.isEmpty()) { + return ""; + } + + String result = template; + Matcher matcher = TEMPLATE_VARIABLE_PATTERN.matcher(template); + + while (matcher.find()) { + String variableName = matcher.group(1).trim(); + String replacement = getVariableValue(variableName, searchText, referenceData, hitsJson); + result = result.replace("{{" + variableName + "}}", replacement); + } + + return result; + } + + /** + * Get the value for a template variable. + */ + private static String getVariableValue(String variableName, String searchText, Map referenceData, String hitsJson) { + // Handle queryText/searchText + if (PLACEHOLDER_QUERY_TEXT.equals(variableName) || PLACEHOLDER_SEARCH_TEXT.equals(variableName)) { + return searchText != null ? searchText : ""; + } + + // Handle hits/results + if (PLACEHOLDER_HITS.equals(variableName) || PLACEHOLDER_RESULTS.equals(variableName)) { + return hitsJson != null ? hitsJson : ""; + } + + // Handle reference/referenceAnswer + if (PLACEHOLDER_REFERENCE.equals(variableName) || PLACEHOLDER_REFERENCE_ANSWER.equals(variableName)) { + if (referenceData != null && referenceData.containsKey(PLACEHOLDER_REFERENCE_ANSWER)) { + return referenceData.get(PLACEHOLDER_REFERENCE_ANSWER); + } + return ""; + } + + // Handle any custom key from referenceData + if (referenceData != null && referenceData.containsKey(variableName)) { + return referenceData.get(variableName); + } + + // Variable not found, return empty string + return ""; + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/model/JudgmentCache.java b/src/main/java/org/opensearch/searchrelevance/model/JudgmentCache.java index 21525aac..de1b8266 100644 --- a/src/main/java/org/opensearch/searchrelevance/model/JudgmentCache.java +++ b/src/main/java/org/opensearch/searchrelevance/model/JudgmentCache.java @@ -23,6 +23,7 @@ public class JudgmentCache implements ToXContentObject { public static final String TIME_STAMP = "timestamp"; public static final String RATING = "rating"; public static final String MODEL_ID = "modelId"; + public static final String PROMPT_TEMPLATE_ID = "encodedPromptTemplate"; /** * Identifier of the system index @@ -34,6 +35,7 @@ public class JudgmentCache implements ToXContentObject { private String contextFieldsStr; private String rating; private String modelId; + private String promptTemplateId; public JudgmentCache( String id, @@ -42,7 +44,8 @@ public JudgmentCache( String documentId, List contextFields, String rating, - String modelId + String modelId, + String promptTemplateId ) { this.id = id; this.timestamp = timestamp; @@ -51,6 +54,7 @@ public JudgmentCache( this.contextFieldsStr = convertListToSortedStr(contextFields); this.rating = rating; this.modelId = modelId; + this.promptTemplateId = promptTemplateId; } @Override @@ -63,6 +67,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(CONTEXT_FIELDS_STR, this.contextFieldsStr); xContentBuilder.field(RATING, this.rating.trim()); xContentBuilder.field(MODEL_ID, this.modelId.trim()); + xContentBuilder.field(PROMPT_TEMPLATE_ID, this.promptTemplateId.trim()); return xContentBuilder.endObject(); } diff --git a/src/main/java/org/opensearch/searchrelevance/model/LLMJudgmentRatingType.java b/src/main/java/org/opensearch/searchrelevance/model/LLMJudgmentRatingType.java new file mode 100644 index 00000000..5503fe7e --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/model/LLMJudgmentRatingType.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.model; + +import java.io.IOException; +import java.util.Arrays; +import java.util.stream.Collectors; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +public enum LLMJudgmentRatingType implements Writeable { + SCORE0_1, + RELEVANT_IRRELEVANT; + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + public static LLMJudgmentRatingType readFromStream(StreamInput in) throws IOException { + return in.readEnum(LLMJudgmentRatingType.class); + } + + /** + * Get a comma-separated string of all valid rating type values. + * @return String containing all valid enum values + */ + public static String getValidValues() { + return Arrays.stream(LLMJudgmentRatingType.values()).map(Enum::name).collect(Collectors.joining(", ")); + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/model/QueryWithReference.java b/src/main/java/org/opensearch/searchrelevance/model/QueryWithReference.java index a92670e7..79a2407a 100644 --- a/src/main/java/org/opensearch/searchrelevance/model/QueryWithReference.java +++ b/src/main/java/org/opensearch/searchrelevance/model/QueryWithReference.java @@ -8,6 +8,8 @@ package org.opensearch.searchrelevance.model; import java.io.IOException; +import java.util.Collections; +import java.util.Map; import java.util.Objects; import org.opensearch.core.common.io.stream.StreamInput; @@ -16,32 +18,42 @@ public class QueryWithReference implements Writeable { private final String queryText; - private final String referenceAnswer; + private final Map customizedKeyValueMap; public final static String DELIMITER = "#"; - public QueryWithReference(String queryText, String referenceAnswer) { + public QueryWithReference(String queryText, Map customizedKeyValueMap) { this.queryText = queryText; - this.referenceAnswer = referenceAnswer; + this.customizedKeyValueMap = customizedKeyValueMap != null ? customizedKeyValueMap : Collections.emptyMap(); } public QueryWithReference(StreamInput in) throws IOException { this.queryText = in.readString(); - this.referenceAnswer = in.readString(); + boolean hasCustomizedKeyValueMap = in.readBoolean(); + if (hasCustomizedKeyValueMap) { + this.customizedKeyValueMap = in.readMap(StreamInput::readString, StreamInput::readString); + } else { + this.customizedKeyValueMap = Collections.emptyMap(); + } } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(queryText); - out.writeString(referenceAnswer); + if (customizedKeyValueMap != null && !customizedKeyValueMap.isEmpty()) { + out.writeBoolean(true); + out.writeMap(customizedKeyValueMap, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } } public String getQueryText() { return queryText; } - public String getReferenceAnswer() { - return referenceAnswer; + public Map getCustomizedKeyValueMap() { + return customizedKeyValueMap; } @Override @@ -49,16 +61,16 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; QueryWithReference that = (QueryWithReference) o; - return Objects.equals(queryText, that.queryText) && Objects.equals(referenceAnswer, that.referenceAnswer); + return Objects.equals(queryText, that.queryText) && Objects.equals(customizedKeyValueMap, that.customizedKeyValueMap); } @Override public int hashCode() { - return Objects.hash(queryText, referenceAnswer); + return Objects.hash(queryText, customizedKeyValueMap); } @Override public String toString() { - return "QueryWithReference{" + "queryText='" + queryText + '\'' + ", referenceAnswer='" + referenceAnswer + '\'' + '}'; + return "QueryWithReference{" + "queryText='" + queryText + '\'' + ", customizedKeyValueMap=" + customizedKeyValueMap + '}'; } } diff --git a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java index beb590e2..f82ebbbb 100644 --- a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java +++ b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java @@ -9,6 +9,10 @@ import static java.util.Collections.singletonList; import static org.opensearch.rest.RestRequest.Method.PUT; +import static org.opensearch.searchrelevance.common.MLConstants.DEFAULT_PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; +import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; +import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; import static org.opensearch.searchrelevance.common.MLConstants.validateTokenLimit; import static org.opensearch.searchrelevance.common.MetricsConstants.MODEL_ID; import static org.opensearch.searchrelevance.common.PluginConstants.CLICK_MODEL; @@ -28,6 +32,7 @@ import java.io.IOException; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; @@ -44,6 +49,7 @@ import org.opensearch.rest.RestRequest; import org.opensearch.searchrelevance.exception.SearchRelevanceException; import org.opensearch.searchrelevance.model.JudgmentType; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.settings.SearchRelevanceSettingsAccessor; import org.opensearch.searchrelevance.transport.judgment.PutImportJudgmentRequest; import org.opensearch.searchrelevance.transport.judgment.PutJudgmentAction; @@ -126,6 +132,40 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli int tokenLimit = validateTokenLimit(source); List contextFields = ParserUtils.convertObjToList(source, CONTEXT_FIELDS); + + // Prompt template - validate and use simple default if not provided + String promptTemplate = (String) source.get(PROMPT_TEMPLATE); + + // Validate prompt template contains required {{hits}} or {{results}} placeholder + TextValidationUtil.ValidationResult promptValidation = TextValidationUtil.validatePromptTemplate(promptTemplate); + if (!promptValidation.isValid()) { + throw new SearchRelevanceException(promptValidation.getErrorMessage(), RestStatus.BAD_REQUEST); + } + + if (promptTemplate == null || promptTemplate.trim().isEmpty()) { + promptTemplate = DEFAULT_PROMPT_TEMPLATE; + } + + // Rating type - can be null, will be validated at processor level + String llmJudgmentRatingTypeStr = (String) source.get(LLM_JUDGMENT_RATING_TYPE); + LLMJudgmentRatingType llmJudgmentRatingType = null; + if (llmJudgmentRatingTypeStr != null) { + try { + llmJudgmentRatingType = LLMJudgmentRatingType.valueOf(llmJudgmentRatingTypeStr); + } catch (IllegalArgumentException e) { + throw new SearchRelevanceException( + String.format( + Locale.ROOT, + "Invalid RatingType: '%s'. Valid values are: %s", + llmJudgmentRatingTypeStr, + LLMJudgmentRatingType.getValidValues() + ), + RestStatus.BAD_REQUEST + ); + } + } + boolean overwriteCache = Optional.ofNullable((Boolean) source.get(OVERWRITE_CACHE)).orElse(Boolean.FALSE); + createRequest = new PutLlmJudgmentRequest( type, name, @@ -136,7 +176,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli size, tokenLimit, contextFields, - ignoreFailure + ignoreFailure, + promptTemplate, + llmJudgmentRatingType, + overwriteCache ); } case UBI_JUDGMENT -> { diff --git a/src/main/java/org/opensearch/searchrelevance/rest/RestPutQuerySetAction.java b/src/main/java/org/opensearch/searchrelevance/rest/RestPutQuerySetAction.java index 5c2b5ec1..666f9155 100644 --- a/src/main/java/org/opensearch/searchrelevance/rest/RestPutQuerySetAction.java +++ b/src/main/java/org/opensearch/searchrelevance/rest/RestPutQuerySetAction.java @@ -91,27 +91,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli if (rawQueries.size() > settingsAccessor.getMaxQuerySetAllowed()) { return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.FORBIDDEN, "Query Set Limit Exceeded.")); } + + // Validate and parse each query using the utility method try { querySetQueries = rawQueries.stream().map(obj -> { - Map queryMap = (Map) obj; - String queryText = queryMap.get("queryText"); - String referenceAnswer = queryMap.getOrDefault("referenceAnswer", ""); - - // Validate queryText - TextValidationUtil.ValidationResult queryTextValidation = TextValidationUtil.validateText(queryText); - if (!queryTextValidation.isValid()) { - throw new IllegalArgumentException("Invalid queryText: " + queryTextValidation.getErrorMessage()); - } + Map queryMap = (Map) obj; + TextValidationUtil.QueryValidationResult validationResult = TextValidationUtil.validateAndParseQuery(queryMap); - // Validate referenceAnswer if it's not empty - if (!referenceAnswer.isEmpty()) { - TextValidationUtil.ValidationResult referenceAnswerValidation = TextValidationUtil.validateText(referenceAnswer); - if (!referenceAnswerValidation.isValid()) { - throw new IllegalArgumentException("Invalid referenceAnswer: " + referenceAnswerValidation.getErrorMessage()); - } + if (!validationResult.isValid()) { + throw new IllegalArgumentException(validationResult.getErrorMessage()); } - return new QueryWithReference(queryText, referenceAnswer); + return validationResult.getQueryWithReference(); }).collect(Collectors.toList()); } catch (IllegalArgumentException e) { return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, e.getMessage())); diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java index 64eb9f04..0c3c3df6 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java @@ -7,6 +7,9 @@ */ package org.opensearch.searchrelevance.transport.judgment; +import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; +import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; +import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; import static org.opensearch.searchrelevance.common.MetricsConstants.MODEL_ID; import static org.opensearch.searchrelevance.ubi.UbiValidator.checkUbiIndicesExist; @@ -103,6 +106,9 @@ private Map buildMetadata(PutJudgmentRequest request) { metadata.put("tokenLimit", llmRequest.getTokenLimit()); metadata.put("contextFields", llmRequest.getContextFields()); metadata.put("ignoreFailure", llmRequest.isIgnoreFailure()); + metadata.put(PROMPT_TEMPLATE, llmRequest.getPromptTemplate()); + metadata.put(LLM_JUDGMENT_RATING_TYPE, llmRequest.getLlmJudgmentRatingType()); + metadata.put(OVERWRITE_CACHE, llmRequest.isOverwriteCache()); } case UBI_JUDGMENT -> { if (!checkUbiIndicesExist(clusterService)) { diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java index be29ef4b..24328e9b 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java @@ -13,6 +13,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.searchrelevance.model.JudgmentType; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import reactor.util.annotation.NonNull; @@ -41,6 +42,21 @@ public class PutLlmJudgmentRequest extends PutJudgmentRequest { */ private boolean ignoreFailure; + /** + * Customized prompt template input by customers. + */ + private String promptTemplate; // contains place_holder with vals defined in QuerySet + + /** + * Output type defined for prefilled prompt and JSON output processor + */ + private LLMJudgmentRatingType llmJudgmentRatingType; + + /** + * Flag to indicate whether to use judgment cache + */ + private boolean overwriteCache; + public PutLlmJudgmentRequest( @NonNull JudgmentType type, @NonNull String name, @@ -51,7 +67,10 @@ public PutLlmJudgmentRequest( int size, int tokenLimit, List contextFields, - boolean ignoreFailure + boolean ignoreFailure, + String promptTemplate, + LLMJudgmentRatingType llmJudgmentRatingType, + boolean overwriteCache ) { super(type, name, description); this.modelId = modelId; @@ -61,6 +80,9 @@ public PutLlmJudgmentRequest( this.tokenLimit = tokenLimit; this.contextFields = contextFields; this.ignoreFailure = ignoreFailure; + this.promptTemplate = promptTemplate; + this.llmJudgmentRatingType = llmJudgmentRatingType; + this.overwriteCache = overwriteCache; } public PutLlmJudgmentRequest(StreamInput in) throws IOException { @@ -72,6 +94,9 @@ public PutLlmJudgmentRequest(StreamInput in) throws IOException { this.tokenLimit = in.readOptionalInt(); this.contextFields = in.readOptionalStringList(); this.ignoreFailure = Boolean.TRUE.equals(in.readOptionalBoolean()); // by defaulted as false if not provided + this.promptTemplate = in.readOptionalString(); + this.llmJudgmentRatingType = in.readOptionalWriteable(LLMJudgmentRatingType::readFromStream); + this.overwriteCache = Boolean.TRUE.equals(in.readOptionalBoolean()); } @Override @@ -84,6 +109,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInt(tokenLimit); out.writeOptionalStringArray(contextFields.toArray(new String[0])); out.writeOptionalBoolean(ignoreFailure); + out.writeOptionalString(promptTemplate); + out.writeOptionalWriteable(llmJudgmentRatingType); + out.writeOptionalBoolean(overwriteCache); } public String getModelId() { @@ -114,4 +142,16 @@ public boolean isIgnoreFailure() { return ignoreFailure; } + public String getPromptTemplate() { + return promptTemplate; + } + + public LLMJudgmentRatingType getLlmJudgmentRatingType() { + return llmJudgmentRatingType; + } + + public boolean isOverwriteCache() { + return overwriteCache; + } + } diff --git a/src/main/java/org/opensearch/searchrelevance/transport/queryset/PutQuerySetTransportAction.java b/src/main/java/org/opensearch/searchrelevance/transport/queryset/PutQuerySetTransportAction.java index 91b04766..0753de68 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/queryset/PutQuerySetTransportAction.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/queryset/PutQuerySetTransportAction.java @@ -29,7 +29,11 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + public class PutQuerySetTransportAction extends HandledTransportAction { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private final ClusterService clusterService; private final QuerySetDao querySetDao; @@ -72,24 +76,37 @@ protected void doExecute(Task task, PutQuerySetRequest request, ActionListener convertQuerySetQueriesList(List queryWithReferenceList) { return queryWithReferenceList.stream().map(queryWithReference -> { - String queryText; - if (queryWithReference.getReferenceAnswer() != null && !queryWithReference.getReferenceAnswer().isEmpty()) { - queryText = String.join(DELIMITER, queryWithReference.getQueryText(), queryWithReference.getReferenceAnswer()); - } else { - queryText = queryWithReference.getQueryText(); + StringBuilder queryTextBuilder = new StringBuilder(queryWithReference.getQueryText()); + + // Append customizedKeyValueMap as JSON format + if (queryWithReference.getCustomizedKeyValueMap() != null && !queryWithReference.getCustomizedKeyValueMap().isEmpty()) { + try { + queryTextBuilder.append(DELIMITER); + queryTextBuilder.append(OBJECT_MAPPER.writeValueAsString(queryWithReference.getCustomizedKeyValueMap())); + } catch (JsonProcessingException e) { + throw new SearchRelevanceException( + "Failed to serialize custom fields to JSON: " + e.getMessage(), + RestStatus.INTERNAL_SERVER_ERROR + ); + } } - return QuerySetEntry.Builder.builder().queryText(queryText).build(); + + return QuerySetEntry.Builder.builder().queryText(queryTextBuilder.toString()).build(); }).collect(Collectors.toList()); } } diff --git a/src/main/java/org/opensearch/searchrelevance/utils/ParserUtils.java b/src/main/java/org/opensearch/searchrelevance/utils/ParserUtils.java index 5970c3ee..342d9785 100644 --- a/src/main/java/org/opensearch/searchrelevance/utils/ParserUtils.java +++ b/src/main/java/org/opensearch/searchrelevance/utils/ParserUtils.java @@ -9,18 +9,31 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.rest.RestRequest; +import org.opensearch.searchrelevance.model.QueryWithReference; import org.opensearch.searchrelevance.model.SearchParams; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + public class ParserUtils { + private static final Logger LOGGER = LogManager.getLogger(ParserUtils.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final String SHA_256_ALGORITHM = "SHA-256"; + public static SearchParams parseSearchParams(RestRequest request) throws IOException { SearchParams.Builder builder = SearchParams.builder(); @@ -127,7 +140,81 @@ public static String combinedIndexAndDocId(String index, String docId) { } public static String getDocIdFromCompositeKey(String compositeKey) { - return compositeKey.split("::")[1]; + // Handle both composite keys (index::docId) and plain docIds + // LLM may return just docId instead of the full composite key + if (compositeKey.contains("::")) { + return compositeKey.split("::")[1]; + } + return compositeKey; + } + + /** + * Generate a hash code from prompt template and rating type + * @param promptTemplate the prompt template string + * @param ratingType the rating type enum (can be null) + * @return SHA-256 hash as hexadecimal string + */ + public static String generatePromptTemplateCode(String promptTemplate, Object ratingType) { + try { + String input = (promptTemplate != null ? promptTemplate : "") + "::" + (ratingType != null ? ratingType.toString() : ""); + MessageDigest digest = MessageDigest.getInstance(SHA_256_ALGORITHM); + byte[] hash = digest.digest(input.getBytes(StandardCharsets.UTF_8)); + + // Convert to hexadecimal string + StringBuilder hexString = new StringBuilder(); + for (byte b : hash) { + String hex = Integer.toHexString(0xff & b); + if (hex.length() == 1) { + hexString.append('0'); + } + hexString.append(hex); + } + return hexString.toString(); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("SHA-256 algorithm not available", e); + } + } + + /** + * Parse query text with custom input to extract query and reference data. + * Supports two formats: + * - Current format: "queryText#{"key1":"value1","key2":"value2"}" (JSON) + * - Legacy format: "queryText#referenceAnswer" (plain text) + * + * @param queryTextWithCustomInput the query text with optional custom input + * @return a map with "queryText" and optional reference data entries + */ + public static Map parseQueryTextWithCustomInput(String queryTextWithCustomInput) { + Map result = new HashMap<>(); + String[] queryTextRefArr = queryTextWithCustomInput.split(QueryWithReference.DELIMITER, 2); + String queryText = queryTextRefArr[0]; + result.put("queryText", queryText); + + if (queryTextRefArr.length > 1 && !queryTextRefArr[1].isEmpty()) { + String referenceContent = queryTextRefArr[1]; + + // Try to parse as JSON first (current format) + if (referenceContent.trim().startsWith("{") && referenceContent.trim().endsWith("}")) { + try { + Map jsonMap = OBJECT_MAPPER.readValue(referenceContent, new TypeReference>() { + }); + result.putAll(jsonMap); + return result; + } catch (Exception e) { + LOGGER.debug( + "Failed to parse reference content as JSON, falling back to legacy format. Content: '{}', Error: {}", + referenceContent, + e.getMessage() + ); + // Not valid JSON, fall through to legacy format + } + } + + // Legacy format: queryText#referenceAnswer + result.put("referenceAnswer", referenceContent); + } + + return result; } } diff --git a/src/main/java/org/opensearch/searchrelevance/utils/RatingOutputProcessor.java b/src/main/java/org/opensearch/searchrelevance/utils/RatingOutputProcessor.java new file mode 100644 index 00000000..e718948a --- /dev/null +++ b/src/main/java/org/opensearch/searchrelevance/utils/RatingOutputProcessor.java @@ -0,0 +1,362 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.utils; + +import static org.opensearch.searchrelevance.common.MLConstants.IRRELEVANT_DECISION_STRING; +import static org.opensearch.searchrelevance.common.MLConstants.RELEVANT_DECISION_STRING; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Processor for handling LLM rating outputs with structured JSON parsing. + * When using OpenAI's structured output feature, responses should already be properly formatted JSON. + */ +public class RatingOutputProcessor { + + private static final Logger log = LogManager.getLogger(RatingOutputProcessor.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private RatingOutputProcessor() {} + + /** + * Parse and extract the ratings array from LLM output. + * Handles both structured output (GPT-4o with response_format) and unstructured output (GPT-3.5). + * + * For structured output: {"ratings": [{"id": "...", "rating_score": ...}, ...]} + * For unstructured output: Extracts JSON from markdown code blocks or embedded JSON patterns + * + * @param response The raw LLM response + * @return JSON array string containing the ratings + */ + public static String sanitizeLLMResponse(String response) { + if (response == null || response.trim().isEmpty()) { + return "[]"; + } + + try { + // Try to parse as structured JSON first (GPT-4o with response_format) + JsonNode rootNode = OBJECT_MAPPER.readTree(response); + + // Extract the "ratings" array if it exists + if (rootNode.has("ratings")) { + JsonNode ratingsArray = rootNode.get("ratings"); + if (ratingsArray.isArray()) { + return ratingsArray.toString(); + } + } + + // If the response is already an array, return it as-is + if (rootNode.isArray()) { + return rootNode.toString(); + } + + // If response is a single object, wrap it in an array + if (rootNode.isObject()) { + return "[" + response + "]"; + } + + return "[]"; + } catch (JsonProcessingException e) { + // If JSON parsing fails, try to extract JSON from unstructured text (GPT-3.5) + return extractJsonFromUnstructuredText(response); + } + } + + /** + * Extracts JSON from unstructured text responses (for models like GPT-3.5 that don't support structured output). + * Handles markdown code blocks and embedded JSON patterns. + */ + private static String extractJsonFromUnstructuredText(String response) { + if (response == null || response.trim().isEmpty()) { + log.debug("Empty or null response, returning empty array"); + return "[]"; + } + + log.debug("Attempting to extract JSON from unstructured text. Response length: {}", response.length()); + + // Try to extract JSON from markdown code blocks (```json ... ``` or ``` ... ```) + String jsonContent = extractFromMarkdownCodeBlock(response); + if (jsonContent != null) { + log.debug("Found markdown code block, attempting to parse"); + try { + JsonNode node = OBJECT_MAPPER.readTree(jsonContent); + if (node.has("ratings") && node.get("ratings").isArray()) { + log.debug("Successfully extracted ratings array from code block"); + return node.get("ratings").toString(); + } + if (node.isArray()) { + log.debug("Successfully extracted array from code block"); + return node.toString(); + } + } catch (JsonProcessingException e) { + log.debug("Failed to parse JSON from code block: {}", e.getMessage()); + // Continue to next extraction method + } + } + + // Try to find JSON object or array patterns in the text + jsonContent = extractJsonPattern(response); + if (jsonContent != null) { + log.debug("Found JSON pattern, attempting to parse. Length: {}", jsonContent.length()); + try { + JsonNode node = OBJECT_MAPPER.readTree(jsonContent); + if (node.has("ratings") && node.get("ratings").isArray()) { + log.debug("Successfully extracted ratings array from pattern"); + return node.get("ratings").toString(); + } + if (node.isArray()) { + log.debug("Successfully extracted array from pattern"); + return node.toString(); + } + // If it's an object with ratings, extract it + if (node.isObject()) { + log.debug("Wrapping object in array"); + return "[" + jsonContent + "]"; + } + } catch (JsonProcessingException e) { + log.warn("Failed to parse extracted JSON pattern. Error: {}. Extracted content: {}", e.getMessage(), jsonContent); + // Parsing failed, return empty array + } + } else { + log.warn( + "No JSON pattern found in response. Response preview: {}", + response.length() > 200 ? response.substring(0, 200) + "..." : response + ); + } + + return "[]"; + } + + /** + * Extracts content from markdown code blocks. + */ + private static String extractFromMarkdownCodeBlock(String text) { + // Match ```json ... ``` or ``` ... ``` + java.util.regex.Pattern pattern = java.util.regex.Pattern.compile("```(?:json)?\\s*\\n?([\\s\\S]*?)```"); + java.util.regex.Matcher matcher = pattern.matcher(text); + if (matcher.find()) { + return matcher.group(1).trim(); + } + return null; + } + + /** + * Extracts JSON object or array patterns from text. + * Looks for the first occurrence of a JSON structure, prioritizing arrays if they appear first. + */ + private static String extractJsonPattern(String text) { + int startObj = text.indexOf('{'); + int startArr = text.indexOf('['); + + // Determine which JSON structure appears first + if (startArr != -1 && (startObj == -1 || startArr < startObj)) { + // Array appears first or object not found + int endArr = findMatchingBracket(text, startArr); + if (endArr != -1) { + return text.substring(startArr, endArr + 1); + } + } + + // Try to extract object if array extraction failed or object appears first + if (startObj != -1) { + int endObj = findMatchingBrace(text, startObj); + if (endObj != -1) { + return text.substring(startObj, endObj + 1); + } + } + + // Fallback: try array again if object extraction failed + if (startArr != -1) { + int endArr = findMatchingBracket(text, startArr); + if (endArr != -1) { + return text.substring(startArr, endArr + 1); + } + } + + return null; + } + + /** + * Finds the matching closing brace for an opening brace using a state machine + * that properly handles strings and escaped characters. + * + * This is a heuristic approach since we don't have access to a full JSON parser state, + * but it handles most common LLM response patterns correctly. + * + * @param text The text to search + * @param start The index of the opening brace + * @return The index of the matching closing brace, or -1 if not found + */ + private static int findMatchingBrace(String text, int start) { + int count = 0; + boolean inString = false; + char stringQuote = 0; // Track which quote character started the string (" or ') + boolean escaped = false; + + for (int i = start; i < text.length(); i++) { + char c = text.charAt(i); + + // Handle escape sequences + if (escaped) { + escaped = false; + continue; + } + + if (c == '\\') { + escaped = true; + continue; + } + + // Handle string boundaries + if (c == '"' || c == '\'') { + if (!inString) { + // Entering a string + inString = true; + stringQuote = c; + } else if (c == stringQuote) { + // Exiting a string (must match the opening quote) + inString = false; + stringQuote = 0; + } + continue; + } + + // Only count braces outside of strings + if (!inString) { + if (c == '{') { + count++; + } else if (c == '}') { + count--; + if (count == 0) { + return i; + } + } + } + } + + log.debug("Failed to find matching brace. Final count: {}, inString: {}", count, inString); + return -1; // No matching brace found + } + + /** + * Finds the matching closing bracket for an opening bracket using a state machine + * that properly handles strings and escaped characters. + * + * This is a heuristic approach since we don't have access to a full JSON parser state, + * but it handles most common LLM response patterns correctly. + * + * @param text The text to search + * @param start The index of the opening bracket + * @return The index of the matching closing bracket, or -1 if not found + */ + private static int findMatchingBracket(String text, int start) { + int count = 0; + boolean inString = false; + char stringQuote = 0; // Track which quote character started the string (" or ') + boolean escaped = false; + + for (int i = start; i < text.length(); i++) { + char c = text.charAt(i); + + // Handle escape sequences + if (escaped) { + escaped = false; + continue; + } + + if (c == '\\') { + escaped = true; + continue; + } + + // Handle string boundaries + if (c == '"' || c == '\'') { + if (!inString) { + // Entering a string + inString = true; + stringQuote = c; + } else if (c == stringQuote) { + // Exiting a string (must match the opening quote) + inString = false; + stringQuote = 0; + } + continue; + } + + // Only count brackets outside of strings + if (!inString) { + if (c == '[') { + count++; + } else if (c == ']') { + count--; + if (count == 0) { + return i; + } + } + } + } + + log.debug("Failed to find matching bracket. Final count: {}, inString: {}", count, inString); + return -1; // No matching bracket found + } + + /** + * Convert rating score from LLM response to double value. + * For RELEVANT_IRRELEVANT type: converts "RELEVANT" to 1.0 and "IRRELEVANT" to 0.0 + * For SCORE0_1 type: parses the number value to double + * + * Public for testing purposes. + * + * @param ratingScoreObj The rating_score object from LLM response + * @param ratingType The judgment rating type + * @return The rating score as a double value + */ + public static Double convertRatingScore(Object ratingScoreObj, LLMJudgmentRatingType ratingType) { + // Check for null rating score + if (ratingScoreObj == null) { + throw new IllegalArgumentException( + "Missing rating_score field in LLM response. Ensure the prompt template asks the LLM to return JSON with 'rating_score' field." + ); + } + + if (ratingType == LLMJudgmentRatingType.RELEVANT_IRRELEVANT) { + // Handle binary string ratings + if (!(ratingScoreObj instanceof String)) { + throw new IllegalArgumentException( + "Invalid rating_score type for RELEVANT_IRRELEVANT. Expected String but got: " + + ratingScoreObj.getClass().getSimpleName() + ); + } + String ratingStr = (String) ratingScoreObj; + if (RELEVANT_DECISION_STRING.equals(ratingStr)) { + return 1.0; + } else if (IRRELEVANT_DECISION_STRING.equals(ratingStr)) { + return 0.0; + } else { + throw new IllegalArgumentException("Invalid binary rating value: " + ratingStr + ". Expected RELEVANT or IRRELEVANT"); + } + } else { + // Handle numeric ratings (SCORE0_1) + if (!(ratingScoreObj instanceof Number)) { + throw new IllegalArgumentException( + "Invalid rating_score type for SCORE0_1. Expected Number but got: " + + ratingScoreObj.getClass().getSimpleName() + + ". Value: " + + ratingScoreObj + ); + } + return ((Number) ratingScoreObj).doubleValue(); + } + } +} diff --git a/src/main/java/org/opensearch/searchrelevance/utils/TextValidationUtil.java b/src/main/java/org/opensearch/searchrelevance/utils/TextValidationUtil.java index 697a0a1d..a6a3d6bc 100644 --- a/src/main/java/org/opensearch/searchrelevance/utils/TextValidationUtil.java +++ b/src/main/java/org/opensearch/searchrelevance/utils/TextValidationUtil.java @@ -7,12 +7,27 @@ */ package org.opensearch.searchrelevance.utils; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_HITS; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_QUERY_TEXT; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_RESULTS; +import static org.opensearch.searchrelevance.common.MLConstants.PLACEHOLDER_SEARCH_TEXT; + +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.searchrelevance.model.QueryWithReference; + public class TextValidationUtil { private static final int DEFAULT_MAX_TEXT_LENGTH = 2000; private static final int MAX_NAME_LENGTH = 50; private static final int MAX_DESCRIPTION_LENGTH = 250; + private static final int MAX_PROMPT_TEMPLATE_LENGTH = 10000; // Characters that could break JSON or cause security issues private static final String DANGEROUS_CHARS_PATTERN = "[\"\\\\<>]+"; // Excludes quotes, backslashes, and HTML tags + // Characters that could break QuerySet parsing logic + // Newline (\n), delimiter (#), and colon (:) are reserved for the format: "queryText#\nkey: value" + private static final String QUERYSET_RESERVED_CHARS_PATTERN = "[\\r\\n#:]+"; // Excludes newline, carriage return, #, and colon public static class ValidationResult { private final boolean valid; @@ -89,4 +104,250 @@ public static ValidationResult validateDescription(String description) { return validateText(description, MAX_DESCRIPTION_LENGTH); } + /** + * Validates QuerySet field values (queryText and custom field values). + * Checks for reserved characters that would break the QuerySet parsing logic: + * - Newline (\n) - used to separate key-value pairs in the new format + * - Hash (#) - used as delimiter between queryText and custom fields + * - Colon (:) - used to separate keys from values in the new format + * + * @param text The text to validate + * @return ValidationResult indicating if the text is valid for QuerySet + */ + public static ValidationResult validateQuerySetValue(String text) { + return validateQuerySetValue(text, DEFAULT_MAX_TEXT_LENGTH); + } + + /** + * Validates QuerySet field values with a specified maximum length. + * Checks for reserved characters that would break the QuerySet parsing logic: + * - Newline (\n) - used to separate key-value pairs in the new format + * - Hash (#) - used as delimiter between queryText and custom fields + * - Colon (:) - used to separate keys from values in the new format + * + * @param text The text to validate + * @param maxLength The maximum allowed length + * @return ValidationResult indicating if the text is valid for QuerySet + */ + public static ValidationResult validateQuerySetValue(String text, int maxLength) { + if (text == null) { + return new ValidationResult(false, "Text cannot be null"); + } + + if (text.isEmpty()) { + return new ValidationResult(false, "Text cannot be empty"); + } + + if (text.length() > maxLength) { + return new ValidationResult(false, "Text exceeds maximum length of " + maxLength + " characters"); + } + + if (text.matches(".*" + DANGEROUS_CHARS_PATTERN + ".*")) { + return new ValidationResult(false, "Text contains invalid characters (quotes, backslashes, or HTML tags are not allowed)"); + } + + // Check for reserved characters - use contains() for better detection including newlines + if (text.contains("\n") || text.contains("\r") || text.contains("#") || text.contains(":")) { + return new ValidationResult(false, "Text contains reserved characters (newline, #, or : are not allowed in QuerySet values)"); + } + + return new ValidationResult(true, null); + } + + /** + * Validates QuerySet custom field keys. + * Keys have additional restrictions to ensure they are valid identifiers. + * + * @param key The key to validate + * @return ValidationResult indicating if the key is valid + */ + public static ValidationResult validateQuerySetKey(String key) { + if (key == null) { + return new ValidationResult(false, "Key cannot be null"); + } + + if (key.isEmpty()) { + return new ValidationResult(false, "Key cannot be empty"); + } + + if (key.length() > MAX_NAME_LENGTH) { + return new ValidationResult(false, "Key exceeds maximum length of " + MAX_NAME_LENGTH + " characters"); + } + + // Keys should not contain reserved characters - use contains() for better detection including newlines + if (key.contains("\n") || key.contains("\r") || key.contains("#") || key.contains(":")) { + return new ValidationResult(false, "Key contains reserved characters (newline, #, or : are not allowed in QuerySet keys)"); + } + + // Keys should not contain whitespace (except single spaces within the key, not at start/end) + if (key.trim().length() != key.length()) { + return new ValidationResult(false, "Key cannot have leading or trailing whitespace"); + } + + // Reserved key name + if ("queryText".equals(key)) { + return new ValidationResult(false, "Key 'queryText' is reserved and cannot be used as a custom field name"); + } + + return new ValidationResult(true, null); + } + + /** + * Result class for QueryWithReference validation + */ + public static class QueryValidationResult { + private final boolean valid; + private final String errorMessage; + private final QueryWithReference queryWithReference; + + private QueryValidationResult(boolean valid, String errorMessage, QueryWithReference queryWithReference) { + this.valid = valid; + this.errorMessage = errorMessage; + this.queryWithReference = queryWithReference; + } + + public static QueryValidationResult success(QueryWithReference queryWithReference) { + return new QueryValidationResult(true, null, queryWithReference); + } + + public static QueryValidationResult failure(String errorMessage) { + return new QueryValidationResult(false, errorMessage, null); + } + + public boolean isValid() { + return valid; + } + + public String getErrorMessage() { + return errorMessage; + } + + public QueryWithReference getQueryWithReference() { + return queryWithReference; + } + } + + /** + * Validates that a prompt template contains the required placeholders and meets formatting requirements. + * - Must contain {{hits}} or {{results}} to provide documents to the LLM for rating + * - Must contain {{queryText}} or {{searchText}} to provide the search query + * - Must not contain the reserved delimiter character (#) + * - Must not exceed maximum length + * + * @param promptTemplate The prompt template to validate + * @return ValidationResult indicating if the template is valid + */ + public static ValidationResult validatePromptTemplate(String promptTemplate) { + if (promptTemplate == null || promptTemplate.trim().isEmpty()) { + // Null/empty templates are allowed - they will use defaults + return new ValidationResult(true, null); + } + + // Check length + if (promptTemplate.length() > MAX_PROMPT_TEMPLATE_LENGTH) { + return new ValidationResult(false, "Prompt template exceeds maximum length of " + MAX_PROMPT_TEMPLATE_LENGTH + " characters"); + } + + // Check for reserved delimiter character + if (promptTemplate.contains(QueryWithReference.DELIMITER)) { + return new ValidationResult( + false, + "Prompt template cannot contain the reserved delimiter character '" + + QueryWithReference.DELIMITER + + "' which is used to separate query text from custom fields" + ); + } + + // Check if template contains {{hits}} or {{results}} placeholder + boolean hasHits = promptTemplate.contains("{{" + PLACEHOLDER_HITS + "}}") + || promptTemplate.contains("{{" + PLACEHOLDER_RESULTS + "}}"); + if (!hasHits) { + return new ValidationResult( + false, + String.format( + Locale.ROOT, + "Prompt template must include either {{%s}} or {{%s}} placeholder to provide documents for rating. " + + "Example: 'Query: {{%s}}\\n\\nDocuments: {{%s}}'", + PLACEHOLDER_HITS, + PLACEHOLDER_RESULTS, + PLACEHOLDER_QUERY_TEXT, + PLACEHOLDER_HITS + ) + ); + } + + // Check if template contains {{queryText}} or {{searchText}} placeholder + boolean hasQuery = promptTemplate.contains("{{" + PLACEHOLDER_QUERY_TEXT + "}}") + || promptTemplate.contains("{{" + PLACEHOLDER_SEARCH_TEXT + "}}"); + if (!hasQuery) { + return new ValidationResult( + false, + String.format( + Locale.ROOT, + "Prompt template must include either {{%s}} or {{%s}} placeholder to provide the search query. " + + "Example: 'Query: {{%s}}\\n\\nDocuments: {{%s}}'", + PLACEHOLDER_QUERY_TEXT, + PLACEHOLDER_SEARCH_TEXT, + PLACEHOLDER_QUERY_TEXT, + PLACEHOLDER_HITS + ) + ); + } + + return new ValidationResult(true, null); + } + + /** + * Validates and parses a query map into a QueryWithReference object. + * Extracts queryText and validates all fields including custom key-value pairs. + * + * @param queryMap The raw query map from the request + * @return QueryValidationResult containing either the validated QueryWithReference or an error message + */ + public static QueryValidationResult validateAndParseQuery(Map queryMap) { + if (queryMap == null) { + return QueryValidationResult.failure("Query object cannot be null"); + } + + // Extract queryText + Object queryTextObj = queryMap.get("queryText"); + if (queryTextObj == null) { + return QueryValidationResult.failure("queryText is required"); + } + String queryText = String.valueOf(queryTextObj); + + // Validate queryText + ValidationResult queryTextValidation = validateQuerySetValue(queryText); + if (!queryTextValidation.isValid()) { + return QueryValidationResult.failure("Invalid queryText: " + queryTextValidation.getErrorMessage()); + } + + // Create customizedKeyValueMap with all entries except queryText, converting values to strings + Map customizedKeyValueMap = new HashMap<>(); + for (Map.Entry entry : queryMap.entrySet()) { + if (!"queryText".equals(entry.getKey()) && entry.getValue() != null) { + String key = entry.getKey(); + String value = String.valueOf(entry.getValue()); + + // Validate key + ValidationResult keyValidation = validateQuerySetKey(key); + if (!keyValidation.isValid()) { + return QueryValidationResult.failure("Invalid field name '" + key + "': " + keyValidation.getErrorMessage()); + } + + // Validate value (if not empty) + if (!value.isEmpty()) { + ValidationResult valueValidation = validateQuerySetValue(value); + if (!valueValidation.isValid()) { + return QueryValidationResult.failure("Invalid value for field '" + key + "': " + valueValidation.getErrorMessage()); + } + } + + customizedKeyValueMap.put(key, value); + } + } + + return QueryValidationResult.success(new QueryWithReference(queryText, customizedKeyValueMap)); + } + } diff --git a/src/main/resources/mappings/judgment_cache.json b/src/main/resources/mappings/judgment_cache.json index 09a7aaee..61fa52b8 100644 --- a/src/main/resources/mappings/judgment_cache.json +++ b/src/main/resources/mappings/judgment_cache.json @@ -1,6 +1,6 @@ { "_meta": { - "schema_version": 0 + "schema_version": 1 }, "properties": { "id": { "type": "keyword" }, @@ -8,6 +8,8 @@ "querySet": { "type": "keyword" }, "documentId": { "type": "keyword" }, "contextFieldsStr": { "type": "keyword" }, - "rating": { "type": "keyword" } + "rating": { "type": "keyword" }, + "modelId": { "type": "keyword"}, + "encodedPromptTemplate": { "type": "keyword"} } } diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java new file mode 100644 index 00000000..91cca1c3 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentTemplateIT.java @@ -0,0 +1,411 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.action.judgment; + +import static org.opensearch.searchrelevance.common.MLConstants.DEFAULT_PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.common.PluginConstants.JUDGMENTS_URL; +import static org.opensearch.searchrelevance.common.PluginConstants.JUDGMENT_INDEX; +import static org.opensearch.searchrelevance.common.PluginConstants.QUERYSETS_URL; +import static org.opensearch.searchrelevance.common.PluginConstants.SEARCH_CONFIGURATIONS_URL; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.client.Response; +import org.opensearch.rest.RestRequest; +import org.opensearch.searchrelevance.BaseSearchRelevanceIT; +import org.opensearch.test.OpenSearchIntegTestCase; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import com.google.common.collect.ImmutableList; + +import lombok.SneakyThrows; + +/** + * Integration tests for LLM Judgment Template functionality. + * Tests the new fields: promptTemplate, llmJudgmentRatingType, and overwriteCache. + */ +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE) +public class LlmJudgmentTemplateIT extends BaseSearchRelevanceIT { + + private static final String TEST_INDEX = "test_llm_products"; + + @SneakyThrows + public void testLlmJudgmentWithPromptTemplate_thenSuccessful() { + // Step 1: Create test index + String indexConfig = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateTestIndex.json").toURI())); + createIndexWithConfiguration(TEST_INDEX, indexConfig); + + // Step 2: Bulk ingest test documents + String bulkData = Files.readString(Path.of(classLoader.getResource("llmjudgment/BulkIngestProducts.json").toURI())); + bulkIngest(TEST_INDEX, bulkData); + + // Step 3: Create query set with custom fields + String querySetBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateQuerySetWithCustomFields.json").toURI())); + Response querySetResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + QUERYSETS_URL, + null, + toHttpEntity(querySetBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map querySetResult = entityAsMap(querySetResponse); + String querySetId = querySetResult.get("query_set_id").toString(); + assertNotNull(querySetId); + + // Step 4: Create search configuration + String searchConfigBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateSearchConfiguration.json").toURI())); + searchConfigBody = replacePlaceholders(searchConfigBody, Map.of("index", TEST_INDEX)); + Response searchConfigResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + SEARCH_CONFIGURATIONS_URL, + null, + toHttpEntity(searchConfigBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map searchConfigResult = entityAsMap(searchConfigResponse); + String searchConfigId = searchConfigResult.get("search_configuration_id").toString(); + assertNotNull(searchConfigId); + + // Step 5: Create LLM judgment with promptTemplate + String llmJudgmentBody = Files.readString( + Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentWithPromptTemplate.json").toURI()) + ); + llmJudgmentBody = replacePlaceholders(llmJudgmentBody, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + Response llmJudgmentResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(llmJudgmentBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map llmJudgmentResult = entityAsMap(llmJudgmentResponse); + String judgmentId = llmJudgmentResult.get("judgment_id").toString(); + assertNotNull(judgmentId); + + // Step 6: Wait for judgment processing to complete + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Step 7: Verify the judgment + String getJudgmentUrl = String.join("/", JUDGMENT_INDEX, "_doc", judgmentId); + Response getJudgmentResponse = makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + getJudgmentUrl, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map judgmentDoc = entityAsMap(getJudgmentResponse); + assertNotNull(judgmentDoc); + assertEquals(judgmentId, judgmentDoc.get("_id")); + + Map source = (Map) judgmentDoc.get("_source"); + assertNotNull(source); + assertEquals("LLM_JUDGMENT", source.get("type")); + assertNotNull(source.get("status")); // Should be COMPLETED or IN_PROGRESS + + // Verify metadata contains new fields + Map metadata = (Map) source.get("metadata"); + assertNotNull(metadata); + assertNotNull(metadata.get("promptTemplate")); + assertTrue(((String) metadata.get("promptTemplate")).contains("{{queryText}}")); + assertNotNull(metadata.get("llmJudgmentRatingType")); + assertEquals("SCORE0_1", metadata.get("llmJudgmentRatingType")); + assertNotNull(metadata.get("overwriteCache")); + + // Verify judgmentRatings format + List> judgmentRatings = (List>) source.get("judgmentRatings"); + assertNotNull(judgmentRatings); + + // If there are judgment ratings, verify custom input format with delimiter + // Note: Ratings may be empty if no actual ML model is configured + if (!judgmentRatings.isEmpty()) { + Map firstRating = judgmentRatings.get(0); + String queryText = (String) firstRating.get("query"); + assertNotNull(queryText); + assertTrue(queryText.contains("#\n")); // Custom delimiter + assertTrue(queryText.contains("category:")); + assertTrue(queryText.contains("referenceAnswer:")); + } + } + + @SneakyThrows + public void testLlmJudgmentWithDifferentRatingTypes_thenSuccessful() { + // Create query set + String querySetBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateQuerySetSimple.json").toURI())); + Response querySetResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + QUERYSETS_URL, + null, + toHttpEntity(querySetBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map querySetResult = entityAsMap(querySetResponse); + String querySetId = querySetResult.get("query_set_id").toString(); + + // Create search configuration + String searchConfigBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateSearchConfiguration.json").toURI())); + searchConfigBody = replacePlaceholders(searchConfigBody, Map.of("index", TEST_INDEX)); + Response searchConfigResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + SEARCH_CONFIGURATIONS_URL, + null, + toHttpEntity(searchConfigBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map searchConfigResult = entityAsMap(searchConfigResponse); + String searchConfigId = searchConfigResult.get("search_configuration_id").toString(); + + // Test SCORE0_1 rating type + String score01Body = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentScore01.json").toURI())); + score01Body = replacePlaceholders(score01Body, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + Response score01Response = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(score01Body), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map score01Result = entityAsMap(score01Response); + String judgmentId01 = score01Result.get("judgment_id").toString(); + assertNotNull(judgmentId01); + + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Verify SCORE0_1 + String getJudgment01Url = String.join("/", JUDGMENT_INDEX, "_doc", judgmentId01); + Response getJudgment01Response = makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + getJudgment01Url, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map judgment01Doc = entityAsMap(getJudgment01Response); + Map source01 = (Map) judgment01Doc.get("_source"); + Map metadata01 = (Map) source01.get("metadata"); + assertEquals("SCORE0_1", metadata01.get("llmJudgmentRatingType")); + + // Test RELEVANT_IRRELEVANT rating type + String binaryBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentBinary.json").toURI())); + binaryBody = replacePlaceholders(binaryBody, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + Response binaryResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(binaryBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map binaryResult = entityAsMap(binaryResponse); + String judgmentIdBinary = binaryResult.get("judgment_id").toString(); + assertNotNull(judgmentIdBinary); + + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Verify RELEVANT_IRRELEVANT + String getJudgmentBinaryUrl = String.join("/", JUDGMENT_INDEX, "_doc", judgmentIdBinary); + Response getJudgmentBinaryResponse = makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + getJudgmentBinaryUrl, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map judgmentBinaryDoc = entityAsMap(getJudgmentBinaryResponse); + Map sourceBinary = (Map) judgmentBinaryDoc.get("_source"); + Map metadataBinary = (Map) sourceBinary.get("metadata"); + assertEquals("RELEVANT_IRRELEVANT", metadataBinary.get("llmJudgmentRatingType")); + } + + @SneakyThrows + public void testLlmJudgmentWithOverwriteCache_thenSuccessful() { + // Create query set + String querySetBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateQuerySetSimple.json").toURI())); + Response querySetResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + QUERYSETS_URL, + null, + toHttpEntity(querySetBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map querySetResult = entityAsMap(querySetResponse); + String querySetId = querySetResult.get("query_set_id").toString(); + + // Create search configuration + String searchConfigBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateSearchConfiguration.json").toURI())); + searchConfigBody = replacePlaceholders(searchConfigBody, Map.of("index", TEST_INDEX)); + Response searchConfigResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + SEARCH_CONFIGURATIONS_URL, + null, + toHttpEntity(searchConfigBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map searchConfigResult = entityAsMap(searchConfigResponse); + String searchConfigId = searchConfigResult.get("search_configuration_id").toString(); + + // Test with overwriteCache = true + String overwriteTrueBody = Files.readString( + Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentOverwriteTrue.json").toURI()) + ); + overwriteTrueBody = replacePlaceholders(overwriteTrueBody, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + Response overwriteTrueResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(overwriteTrueBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map overwriteTrueResult = entityAsMap(overwriteTrueResponse); + String judgmentIdTrue = overwriteTrueResult.get("judgment_id").toString(); + assertNotNull(judgmentIdTrue); + + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Verify overwriteCache = true + String getJudgmentTrueUrl = String.join("/", JUDGMENT_INDEX, "_doc", judgmentIdTrue); + Response getJudgmentTrueResponse = makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + getJudgmentTrueUrl, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map judgmentTrueDoc = entityAsMap(getJudgmentTrueResponse); + Map sourceTrue = (Map) judgmentTrueDoc.get("_source"); + Map metadataTrue = (Map) sourceTrue.get("metadata"); + assertEquals(true, metadataTrue.get("overwriteCache")); + + // Test with overwriteCache = false + String overwriteFalseBody = Files.readString( + Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentOverwriteFalse.json").toURI()) + ); + overwriteFalseBody = replacePlaceholders(overwriteFalseBody, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + Response overwriteFalseResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(overwriteFalseBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map overwriteFalseResult = entityAsMap(overwriteFalseResponse); + String judgmentIdFalse = overwriteFalseResult.get("judgment_id").toString(); + assertNotNull(judgmentIdFalse); + + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Verify overwriteCache = false + String getJudgmentFalseUrl = String.join("/", JUDGMENT_INDEX, "_doc", judgmentIdFalse); + Response getJudgmentFalseResponse = makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + getJudgmentFalseUrl, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map judgmentFalseDoc = entityAsMap(getJudgmentFalseResponse); + Map sourceFalse = (Map) judgmentFalseDoc.get("_source"); + Map metadataFalse = (Map) sourceFalse.get("metadata"); + assertEquals(false, metadataFalse.get("overwriteCache")); + } + + @SneakyThrows + public void testLlmJudgmentWithoutOptionalFields_thenSuccessfulWithDefaults() { + // Create query set + String querySetBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateQuerySetSimple.json").toURI())); + Response querySetResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + QUERYSETS_URL, + null, + toHttpEntity(querySetBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map querySetResult = entityAsMap(querySetResponse); + String querySetId = querySetResult.get("query_set_id").toString(); + + // Create search configuration + String searchConfigBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateSearchConfiguration.json").toURI())); + searchConfigBody = replacePlaceholders(searchConfigBody, Map.of("index", TEST_INDEX)); + Response searchConfigResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + SEARCH_CONFIGURATIONS_URL, + null, + toHttpEntity(searchConfigBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map searchConfigResult = entityAsMap(searchConfigResponse); + String searchConfigId = searchConfigResult.get("search_configuration_id").toString(); + + // Create LLM judgment WITHOUT promptTemplate, llmJudgmentRatingType, overwriteCache + String minimalBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentMinimal.json").toURI())); + minimalBody = replacePlaceholders(minimalBody, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + Response minimalResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(minimalBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map minimalResult = entityAsMap(minimalResponse); + String judgmentId = minimalResult.get("judgment_id").toString(); + assertNotNull(judgmentId); + + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Verify defaults + String getJudgmentUrl = String.join("/", JUDGMENT_INDEX, "_doc", judgmentId); + Response getJudgmentResponse = makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + getJudgmentUrl, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map judgmentDoc = entityAsMap(getJudgmentResponse); + Map source = (Map) judgmentDoc.get("_source"); + Map metadata = (Map) source.get("metadata"); + + // promptTemplate should have the default value when not provided + Object promptTemplate = metadata.get("promptTemplate"); + assertNotNull("promptTemplate should not be null when not provided", promptTemplate); + assertEquals("promptTemplate should have default value", DEFAULT_PROMPT_TEMPLATE, promptTemplate); + + // llmJudgmentRatingType should have a default or be null + Object ratingType = metadata.get("llmJudgmentRatingType"); + // Either null or has a default value + + // overwriteCache should default to false + Object overwriteCache = metadata.get("overwriteCache"); + assertTrue(overwriteCache == null || overwriteCache.equals(false)); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java index 73e80d57..adf9b2f7 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java @@ -14,8 +14,10 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.searchrelevance.model.JudgmentType; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; import org.opensearch.searchrelevance.transport.judgment.PutImportJudgmentRequest; import org.opensearch.searchrelevance.transport.judgment.PutJudgmentRequest; +import org.opensearch.searchrelevance.transport.judgment.PutLlmJudgmentRequest; import org.opensearch.searchrelevance.transport.judgment.PutUbiJudgmentRequest; import org.opensearch.test.OpenSearchTestCase; @@ -76,4 +78,71 @@ public void testImportJudgementStream() throws IOException { assertEquals("B077ZJXCTS", ratings.get("docId")); assertEquals("0.700", ratings.get("rating")); } + + public void testLlmJudgmentRequestStreams() throws IOException { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test_name", + "test_description", + "test_model_id", + "test_query_set_id", + List.of("config1", "config2"), + 10, + 1000, + List.of("field1", "field2"), + false, + "test_prompt_template", + LLMJudgmentRatingType.SCORE0_1, + true + ); + + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + StreamInput in = StreamInput.wrap(output.bytes().toBytesRef().bytes); + PutLlmJudgmentRequest serialized = new PutLlmJudgmentRequest(in); + + assertEquals("test_name", serialized.getName()); + assertEquals(JudgmentType.LLM_JUDGMENT, serialized.getType()); + assertEquals("test_description", serialized.getDescription()); + assertEquals("test_model_id", serialized.getModelId()); + assertEquals("test_query_set_id", serialized.getQuerySetId()); + assertEquals(List.of("config1", "config2"), serialized.getSearchConfigurationList()); + assertEquals(10, serialized.getSize()); + assertEquals(1000, serialized.getTokenLimit()); + assertEquals(List.of("field1", "field2"), serialized.getContextFields()); + assertEquals(false, serialized.isIgnoreFailure()); + assertEquals("test_prompt_template", serialized.getPromptTemplate()); + assertEquals(LLMJudgmentRatingType.SCORE0_1, serialized.getLlmJudgmentRatingType()); + assertEquals(true, serialized.isOverwriteCache()); + } + + public void testLlmJudgmentRequestStreamsWithNullOptionalFields() throws IOException { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + JudgmentType.LLM_JUDGMENT, + "test_name", + "test_description", + "test_model_id", + "test_query_set_id", + List.of("config1"), + 5, + 500, + List.of("field1"), + true, + null, + null, + false + ); + + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + StreamInput in = StreamInput.wrap(output.bytes().toBytesRef().bytes); + PutLlmJudgmentRequest serialized = new PutLlmJudgmentRequest(in); + + assertEquals("test_name", serialized.getName()); + assertEquals(JudgmentType.LLM_JUDGMENT, serialized.getType()); + assertEquals("test_description", serialized.getDescription()); + assertNull(serialized.getPromptTemplate()); + assertNull(serialized.getLlmJudgmentRatingType()); + assertEquals(false, serialized.isOverwriteCache()); + } } diff --git a/src/test/java/org/opensearch/searchrelevance/action/queryset/PutQuerySetActionTests.java b/src/test/java/org/opensearch/searchrelevance/action/queryset/PutQuerySetActionTests.java index e25ac3fd..3b5e68c2 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/queryset/PutQuerySetActionTests.java +++ b/src/test/java/org/opensearch/searchrelevance/action/queryset/PutQuerySetActionTests.java @@ -9,6 +9,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -38,8 +39,8 @@ public void testRequestValidation() { private List getQuerySetQueries() { List querySetQueries = new ArrayList<>(); - querySetQueries.add(new QueryWithReference("apple", "")); - querySetQueries.add(new QueryWithReference("banana", "")); + querySetQueries.add(new QueryWithReference("apple", new HashMap<>())); + querySetQueries.add(new QueryWithReference("banana", new HashMap<>())); return querySetQueries; } } diff --git a/src/test/java/org/opensearch/searchrelevance/common/MLConstantsTests.java b/src/test/java/org/opensearch/searchrelevance/common/MLConstantsTests.java new file mode 100644 index 00000000..8cacb883 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/common/MLConstantsTests.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.common; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.test.OpenSearchTestCase; + +public class MLConstantsTests extends OpenSearchTestCase { + + public void testValidateTokenLimit_ValidInteger() { + Map source = new HashMap<>(); + source.put("tokenLimit", 2000); + + int result = MLConstants.validateTokenLimit(source); + assertEquals(2000, result); + } + + public void testValidateTokenLimit_ValidString() { + Map source = new HashMap<>(); + source.put("tokenLimit", "3000"); + + int result = MLConstants.validateTokenLimit(source); + assertEquals(3000, result); + } + + public void testValidateTokenLimit_MissingField() { + Map source = new HashMap<>(); + + int result = MLConstants.validateTokenLimit(source); + assertEquals((int) MLConstants.DEFAULTED_TOKEN_LIMIT, result); + } + + public void testValidateTokenLimit_BelowMinimum() { + Map source = new HashMap<>(); + source.put("tokenLimit", 500); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> MLConstants.validateTokenLimit(source)); + assertTrue(exception.getMessage().contains("must be between")); + } + + public void testValidateTokenLimit_AboveMaximum() { + Map source = new HashMap<>(); + source.put("tokenLimit", 600000); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> MLConstants.validateTokenLimit(source)); + assertTrue(exception.getMessage().contains("must be between")); + } + + public void testValidateTokenLimit_InvalidType() { + Map source = new HashMap<>(); + source.put("tokenLimit", new Object()); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> MLConstants.validateTokenLimit(source)); + assertTrue(exception.getMessage().contains("Invalid tokenLimit type")); + } + + public void testEscapeJson_NullInput() { + String result = MLConstants.escapeJson(null); + assertEquals("", result); + } + + public void testEscapeJson_WithSpecialCharacters() { + String input = "Line1\nLine2\tTab\"Quote\\Backslash\rReturn"; + String result = MLConstants.escapeJson(input); + + assertTrue(result.contains("\\n")); + assertTrue(result.contains("\\t")); + assertTrue(result.contains("\\\"")); + assertTrue(result.contains("\\\\")); + assertTrue(result.contains("\\r")); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/executors/JudgmentDataTransformerTests.java b/src/test/java/org/opensearch/searchrelevance/executors/JudgmentDataTransformerTests.java index c6c8c934..60d91878 100644 --- a/src/test/java/org/opensearch/searchrelevance/executors/JudgmentDataTransformerTests.java +++ b/src/test/java/org/opensearch/searchrelevance/executors/JudgmentDataTransformerTests.java @@ -25,14 +25,14 @@ public void setUp() throws Exception { public void testCreateJudgmentResultWithRatings() { // Arrange - String queryTextWithReference = "laptop||Professional laptop for business"; + String queryTextWithCustomInput = "laptop||Professional laptop for business"; Map docIdToScore = Map.of("doc1", "0.9", "doc2", "0.7", "doc3", "0.5"); // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert - assertEquals(queryTextWithReference, result.get("query")); + assertEquals(queryTextWithCustomInput, result.get("query")); List> ratings = (List>) result.get("ratings"); assertEquals(3, ratings.size()); @@ -54,14 +54,14 @@ public void testCreateJudgmentResultWithRatings() { public void testCreateJudgmentResultWithEmptyRatings() { // Arrange - String queryTextWithReference = "laptop||Professional laptop for business"; + String queryTextWithCustomInput = "laptop||Professional laptop for business"; Map docIdToScore = Map.of(); // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert - assertEquals(queryTextWithReference, result.get("query")); + assertEquals(queryTextWithCustomInput, result.get("query")); List> ratings = (List>) result.get("ratings"); assertEquals(0, ratings.size()); @@ -69,14 +69,14 @@ public void testCreateJudgmentResultWithEmptyRatings() { public void testCreateJudgmentResultWithNullRatings() { // Arrange - String queryTextWithReference = "laptop||Professional laptop for business"; + String queryTextWithCustomInput = "laptop||Professional laptop for business"; Map docIdToScore = null; // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert - assertEquals(queryTextWithReference, result.get("query")); + assertEquals(queryTextWithCustomInput, result.get("query")); List> ratings = (List>) result.get("ratings"); assertEquals(0, ratings.size()); @@ -84,14 +84,14 @@ public void testCreateJudgmentResultWithNullRatings() { public void testCreateJudgmentResultWithQueryOnly() { // Arrange - String queryTextWithReference = "laptop"; + String queryTextWithCustomInput = "laptop"; Map docIdToScore = Map.of("doc1", "0.8"); // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert - assertEquals(queryTextWithReference, result.get("query")); + assertEquals(queryTextWithCustomInput, result.get("query")); List> ratings = (List>) result.get("ratings"); assertEquals(1, ratings.size()); @@ -101,11 +101,11 @@ public void testCreateJudgmentResultWithQueryOnly() { public void testCreateJudgmentResultRatingStructure() { // Arrange - String queryTextWithReference = "test query"; + String queryTextWithCustomInput = "test query"; Map docIdToScore = Map.of("testDoc", "0.95"); // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert List> ratings = (List>) result.get("ratings"); @@ -120,11 +120,11 @@ public void testCreateJudgmentResultRatingStructure() { public void testCreateJudgmentResultMultipleRatingsOrder() { // Arrange - String queryTextWithReference = "test query"; + String queryTextWithCustomInput = "test query"; Map docIdToScore = Map.of("docA", "0.1", "docB", "0.2", "docC", "0.3"); // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert List> ratings = (List>) result.get("ratings"); @@ -140,14 +140,14 @@ public void testCreateJudgmentResultMultipleRatingsOrder() { public void testCreateJudgmentResultWithSpecialCharacters() { // Arrange - String queryTextWithReference = "special||query with \"quotes\" and 'apostrophes'"; + String queryTextWithCustomInput = "special||query with \"quotes\" and 'apostrophes'"; Map docIdToScore = Map.of("doc-with-dash", "0.6"); // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert - assertEquals(queryTextWithReference, result.get("query")); + assertEquals(queryTextWithCustomInput, result.get("query")); List> ratings = (List>) result.get("ratings"); assertEquals(1, ratings.size()); @@ -157,11 +157,11 @@ public void testCreateJudgmentResultWithSpecialCharacters() { public void testCreateJudgmentResultWithZeroRating() { // Arrange - String queryTextWithReference = "test query"; + String queryTextWithCustomInput = "test query"; Map docIdToScore = Map.of("doc1", "0.0"); // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert List> ratings = (List>) result.get("ratings"); @@ -172,11 +172,11 @@ public void testCreateJudgmentResultWithZeroRating() { public void testCreateJudgmentResultWithMaxRating() { // Arrange - String queryTextWithReference = "test query"; + String queryTextWithCustomInput = "test query"; Map docIdToScore = Map.of("doc1", "1.0"); // Act - Map result = transformer.createJudgmentResult(queryTextWithReference, docIdToScore); + Map result = transformer.createJudgmentResult(queryTextWithCustomInput, docIdToScore); // Assert List> ratings = (List>) result.get("ratings"); diff --git a/src/test/java/org/opensearch/searchrelevance/executors/JudgmentTaskContextTests.java b/src/test/java/org/opensearch/searchrelevance/executors/JudgmentTaskContextTests.java index f4478e36..2e45d58d 100644 --- a/src/test/java/org/opensearch/searchrelevance/executors/JudgmentTaskContextTests.java +++ b/src/test/java/org/opensearch/searchrelevance/executors/JudgmentTaskContextTests.java @@ -24,7 +24,7 @@ public class JudgmentTaskContextTests extends OpenSearchTestCase { public void testTaskContextInitialization() { // Arrange - String queryTextWithReference = "laptop#Professional laptop for business"; + String queryTextWithCustomInput = "laptop#Professional laptop for business"; String modelId = "test-model-id"; List contextFields = List.of("name", "description"); List searchConfigurations = List.of(mock(SearchConfiguration.class)); @@ -33,7 +33,7 @@ public void testTaskContextInitialization() { // Act JudgmentTaskContext context = new JudgmentTaskContext( - queryTextWithReference, + queryTextWithCustomInput, modelId, contextFields, searchConfigurations, @@ -42,7 +42,7 @@ public void testTaskContextInitialization() { ); // Assert - assertEquals(queryTextWithReference, context.getQueryTextWithReference()); + assertEquals(queryTextWithCustomInput, context.getQueryTextWithCustomInput()); assertEquals(modelId, context.getModelId()); assertEquals(contextFields, context.getContextFields()); assertEquals(searchConfigurations, context.getSearchConfigurations()); diff --git a/src/test/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesTests.java b/src/test/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesTests.java index 3cc8a55d..753d1a6f 100644 --- a/src/test/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesTests.java +++ b/src/test/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesTests.java @@ -77,13 +77,4 @@ public void testSchemaVersionParsedFromMapping() { } } - /** - * Test that all indices currently have schema_version = 0 - * (This test documents the current state and should be updated when versions are bumped) - */ - public void testAllIndicesHaveSchemaVersionZero() { - for (SearchRelevanceIndices index : SearchRelevanceIndices.values()) { - assertEquals("Index " + index.getIndexName() + " should have schema_version = 0", 0, index.getSchemaVersion()); - } - } } diff --git a/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorRatingConversionTests.java b/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorRatingConversionTests.java new file mode 100644 index 00000000..16fd6ed8 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorRatingConversionTests.java @@ -0,0 +1,258 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.judgments; + +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.searchrelevance.utils.RatingOutputProcessor; +import org.opensearch.test.OpenSearchTestCase; + +/** + * Unit tests for RatingOutputProcessor's convertRatingScore method. + * These tests verify the conversion logic for different rating types. + */ +public class LlmJudgmentsProcessorRatingConversionTests extends OpenSearchTestCase { + + /** + * Helper method to call the convertRatingScore method + */ + private Double invokeConvertRatingScore(Object ratingScoreObj, LLMJudgmentRatingType ratingType) { + return RatingOutputProcessor.convertRatingScore(ratingScoreObj, ratingType); + } + + // ============================================ + // SCORE0_1 Rating Type Tests + // ============================================ + + /** + * Test convertRatingScore for SCORE0_1 with Double input + */ + public void testConvertRatingScore_SCORE0_1_WithDouble() throws Exception { + Double result = invokeConvertRatingScore(0.9, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should convert Double 0.9 correctly", 0.9, result, 0.0001); + } + + /** + * Test convertRatingScore for SCORE0_1 with Integer input + */ + public void testConvertRatingScore_SCORE0_1_WithInteger() throws Exception { + Double result = invokeConvertRatingScore(1, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should convert Integer 1 to 1.0", 1.0, result, 0.0001); + } + + /** + * Test convertRatingScore for SCORE0_1 with Float input + */ + public void testConvertRatingScore_SCORE0_1_WithFloat() throws Exception { + Double result = invokeConvertRatingScore(0.75f, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should convert Float 0.75 correctly", 0.75, result, 0.0001); + } + + /** + * Test convertRatingScore for SCORE0_1 with boundary values + */ + public void testConvertRatingScore_SCORE0_1_BoundaryValues() throws Exception { + // Minimum value + Double min = invokeConvertRatingScore(0.0, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should handle 0.0", 0.0, min, 0.0001); + + // Maximum value + Double max = invokeConvertRatingScore(1.0, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should handle 1.0", 1.0, max, 0.0001); + + // Mid value + Double mid = invokeConvertRatingScore(0.5, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should handle 0.5", 0.5, mid, 0.0001); + } + + /** + * Test convertRatingScore for SCORE0_1 with various numeric types + */ + public void testConvertRatingScore_SCORE0_1_VariousNumericTypes() throws Exception { + // Long + Double fromLong = invokeConvertRatingScore(1L, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should convert Long", 1.0, fromLong, 0.0001); + + // Short + Short shortVal = 0; + Double fromShort = invokeConvertRatingScore(shortVal, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should convert Short", 0.0, fromShort, 0.0001); + + // Byte + Byte byteVal = 1; + Double fromByte = invokeConvertRatingScore(byteVal, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should convert Byte", 1.0, fromByte, 0.0001); + } + + // ============================================ + // RELEVANT_IRRELEVANT Rating Type Tests + // ============================================ + + /** + * Test convertRatingScore for RELEVANT_IRRELEVANT with "RELEVANT" + */ + public void testConvertRatingScore_RELEVANT_IRRELEVANT_Relevant() throws Exception { + Double result = invokeConvertRatingScore("RELEVANT", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + assertEquals("RELEVANT should convert to 1.0", 1.0, result, 0.0001); + } + + /** + * Test convertRatingScore for RELEVANT_IRRELEVANT with "IRRELEVANT" + */ + public void testConvertRatingScore_RELEVANT_IRRELEVANT_Irrelevant() throws Exception { + Double result = invokeConvertRatingScore("IRRELEVANT", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + assertEquals("IRRELEVANT should convert to 0.0", 0.0, result, 0.0001); + } + + /** + * Test convertRatingScore for RELEVANT_IRRELEVANT with invalid value + */ + public void testConvertRatingScore_RELEVANT_IRRELEVANT_InvalidValue() { + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> { + invokeConvertRatingScore("MAYBE", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + }); + + assertTrue("Error message should mention invalid value", exception.getMessage().contains("Invalid binary rating value")); + assertTrue("Error message should mention MAYBE", exception.getMessage().contains("MAYBE")); + } + + /** + * Test convertRatingScore for RELEVANT_IRRELEVANT with case-sensitive values + */ + public void testConvertRatingScore_RELEVANT_IRRELEVANT_CaseSensitive() { + // Lowercase "relevant" should fail (case-sensitive) + IllegalArgumentException lowercase = expectThrows(IllegalArgumentException.class, () -> { + invokeConvertRatingScore("relevant", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + }); + assertNotNull("Lowercase should throw exception", lowercase); + + // Mixed case should fail + IllegalArgumentException mixedCase = expectThrows(IllegalArgumentException.class, () -> { + invokeConvertRatingScore("Relevant", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + }); + assertNotNull("Mixed case should throw exception", mixedCase); + } + + /** + * Test convertRatingScore for RELEVANT_IRRELEVANT with null value + */ + public void testConvertRatingScore_RELEVANT_IRRELEVANT_NullValue() { + Exception exception = expectThrows( + Exception.class, + () -> { invokeConvertRatingScore(null, LLMJudgmentRatingType.RELEVANT_IRRELEVANT); } + ); + assertNotNull("Should throw exception for null", exception); + } + + /** + * Test convertRatingScore for RELEVANT_IRRELEVANT with numeric value (wrong type) + */ + public void testConvertRatingScore_RELEVANT_IRRELEVANT_WrongType() { + Exception exception = expectThrows( + Exception.class, + () -> { invokeConvertRatingScore(1.0, LLMJudgmentRatingType.RELEVANT_IRRELEVANT); } + ); + assertNotNull("Should throw exception for numeric value", exception); + } + + // ============================================ + // Edge Cases and Error Handling + // ============================================ + + /** + * Test convertRatingScore with null rating type + * When ratingType is null, it falls through to the else clause and treats it as numeric (SCORE0_1) + */ + public void testConvertRatingScore_NullRatingType() throws Exception { + Double result = invokeConvertRatingScore(0.9, null); + assertEquals("Null rating type should default to numeric conversion", 0.9, result, 0.0001); + } + + /** + * Test convertRatingScore for SCORE0_1 with null value + */ + public void testConvertRatingScore_SCORE0_1_NullValue() { + Exception exception = expectThrows(Exception.class, () -> { invokeConvertRatingScore(null, LLMJudgmentRatingType.SCORE0_1); }); + assertNotNull("Should throw exception for null value", exception); + } + + /** + * Test that SCORE0_1 accepts values outside 0-1 range (no validation) + * Note: The method doesn't validate range, only converts the value + */ + public void testConvertRatingScore_SCORE0_1_OutOfRangeValues() throws Exception { + // Negative value + Double negative = invokeConvertRatingScore(-0.5, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should accept negative value", -0.5, negative, 0.0001); + + // Value greater than 1 + Double overOne = invokeConvertRatingScore(1.5, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should accept value > 1", 1.5, overOne, 0.0001); + + // Large value + Double large = invokeConvertRatingScore(100.0, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Should accept large value", 100.0, large, 0.0001); + } + + // ============================================ + // Real-world Scenario Tests + // ============================================ + + /** + * Test conversion with typical LLM responses for SCORE0_1 + */ + public void testConvertRatingScore_RealWorld_SCORE0_1() throws Exception { + // LLM typically returns doubles between 0 and 1 + Double highRelevance = invokeConvertRatingScore(0.95, LLMJudgmentRatingType.SCORE0_1); + assertEquals("High relevance score", 0.95, highRelevance, 0.0001); + + Double mediumRelevance = invokeConvertRatingScore(0.6, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Medium relevance score", 0.6, mediumRelevance, 0.0001); + + Double lowRelevance = invokeConvertRatingScore(0.2, LLMJudgmentRatingType.SCORE0_1); + assertEquals("Low relevance score", 0.2, lowRelevance, 0.0001); + + Double noRelevance = invokeConvertRatingScore(0.0, LLMJudgmentRatingType.SCORE0_1); + assertEquals("No relevance score", 0.0, noRelevance, 0.0001); + } + + /** + * Test conversion with typical LLM responses for RELEVANT_IRRELEVANT + */ + public void testConvertRatingScore_RealWorld_RELEVANT_IRRELEVANT() throws Exception { + // LLM returns "RELEVANT" or "IRRELEVANT" strings + Double relevant = invokeConvertRatingScore("RELEVANT", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + assertEquals("RELEVANT converts to 1.0", 1.0, relevant, 0.0001); + + Double irrelevant = invokeConvertRatingScore("IRRELEVANT", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + assertEquals("IRRELEVANT converts to 0.0", 0.0, irrelevant, 0.0001); + + // Verify these can be directly used as rating strings + assertEquals("1.0", relevant.toString()); + assertEquals("0.0", irrelevant.toString()); + } + + /** + * Test that converted values can be properly used as strings + */ + public void testConvertRatingScore_StringConversion() throws Exception { + // SCORE0_1 to string + Double score = invokeConvertRatingScore(0.85, LLMJudgmentRatingType.SCORE0_1); + String scoreStr = score.toString(); + assertEquals("Should convert to string correctly", "0.85", scoreStr); + + // RELEVANT to string (should be "1.0") + Double relevant = invokeConvertRatingScore("RELEVANT", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + String relevantStr = relevant.toString(); + assertEquals("RELEVANT as string should be 1.0", "1.0", relevantStr); + + // IRRELEVANT to string (should be "0.0") + Double irrelevant = invokeConvertRatingScore("IRRELEVANT", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + String irrelevantStr = irrelevant.toString(); + assertEquals("IRRELEVANT as string should be 0.0", "0.0", irrelevantStr); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorTests.java b/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorTests.java new file mode 100644 index 00000000..932ae064 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorTests.java @@ -0,0 +1,225 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.judgments; + +import static org.mockito.Mockito.when; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.searchrelevance.dao.JudgmentCacheDao; +import org.opensearch.searchrelevance.dao.QuerySetDao; +import org.opensearch.searchrelevance.dao.SearchConfigurationDao; +import org.opensearch.searchrelevance.ml.MLAccessor; +import org.opensearch.searchrelevance.model.JudgmentType; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.searchrelevance.settings.SearchRelevanceSettingsAccessor; +import org.opensearch.searchrelevance.stats.events.EventStatsManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; + +/** + * Unit tests for LlmJudgmentsProcessor focusing on prompt templates and rating types. + */ +public class LlmJudgmentsProcessorTests extends OpenSearchTestCase { + + private LlmJudgmentsProcessor processor; + private ThreadPool threadPool; + + @Mock + private MLAccessor mockMLAccessor; + + @Mock + private QuerySetDao mockQuerySetDao; + + @Mock + private SearchConfigurationDao mockSearchConfigurationDao; + + @Mock + private JudgmentCacheDao mockJudgmentCacheDao; + + @Mock + private Client mockClient; + + @Mock + private SearchRelevanceSettingsAccessor mockSettingsAccessor; + + private EventStatsManager eventStatsManager; + + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + + // Configure the mock settings accessor + when(mockSettingsAccessor.isStatsEnabled()).thenReturn(false); + + // Initialize and configure EventStatsManager with our mock + eventStatsManager = EventStatsManager.instance(); + eventStatsManager.initialize(mockSettingsAccessor); + + // Create a real thread pool for testing + threadPool = new TestThreadPool("test-thread-pool"); + + processor = new LlmJudgmentsProcessor( + mockMLAccessor, + mockQuerySetDao, + mockSearchConfigurationDao, + mockJudgmentCacheDao, + mockClient, + threadPool + ); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + + public void testGetJudgmentType() { + assertEquals(JudgmentType.LLM_JUDGMENT, processor.getJudgmentType()); + } + + // ============================================ + // Metadata Validation Tests + // ============================================ + + public void testMetadata_AllRatingTypes() { + // Test that all rating types are valid values for metadata + Map metadata = createBasicMetadata(); + + // SCORE0_1 + metadata.put("llmJudgmentRatingType", LLMJudgmentRatingType.SCORE0_1); + assertNotNull("SCORE0_1 should be valid", metadata.get("llmJudgmentRatingType")); + + // RELEVANT_IRRELEVANT + metadata.put("llmJudgmentRatingType", LLMJudgmentRatingType.RELEVANT_IRRELEVANT); + assertNotNull("RELEVANT_IRRELEVANT should be valid", metadata.get("llmJudgmentRatingType")); + } + + public void testMetadata_DefaultRatingTypeIsNull() { + // Test that null rating type in metadata is acceptable + Map metadata = createBasicMetadata(); + metadata.put("llmJudgmentRatingType", null); + + // This should not throw any exception + assertNull("Rating type can be null", metadata.get("llmJudgmentRatingType")); + } + + public void testMetadata_PromptTemplateVariations() { + // Test various prompt template values + Map metadata = createBasicMetadata(); + + // Custom template + String customTemplate = "Rate relevance from 0 to 1"; + metadata.put("promptTemplate", customTemplate); + assertEquals("Custom template should be set", customTemplate, metadata.get("promptTemplate")); + + // Empty template + metadata.put("promptTemplate", ""); + assertEquals("Empty template should be set", "", metadata.get("promptTemplate")); + + // Null template + metadata.put("promptTemplate", null); + assertNull("Null template should be allowed", metadata.get("promptTemplate")); + } + + public void testMetadata_CombinedRatingTypeAndPrompt() { + // Test that metadata can hold both rating type and prompt template + Map metadata = new HashMap<>(); + + metadata.put("llmJudgmentRatingType", LLMJudgmentRatingType.SCORE0_1); + metadata.put("promptTemplate", "Custom prompt for 0-1 scale"); + + assertEquals(LLMJudgmentRatingType.SCORE0_1, metadata.get("llmJudgmentRatingType")); + assertEquals("Custom prompt for 0-1 scale", metadata.get("promptTemplate")); + } + + public void testMetadata_RequiredFields() { + // Test that basic metadata contains all required fields + Map metadata = createBasicMetadata(); + + assertTrue("Metadata should contain querySetId", metadata.containsKey("querySetId")); + assertTrue("Metadata should contain searchConfigurationList", metadata.containsKey("searchConfigurationList")); + assertTrue("Metadata should contain size", metadata.containsKey("size")); + assertTrue("Metadata should contain modelId", metadata.containsKey("modelId")); + assertTrue("Metadata should contain tokenLimit", metadata.containsKey("tokenLimit")); + assertTrue("Metadata should contain contextFields", metadata.containsKey("contextFields")); + assertTrue("Metadata should contain ignoreFailure", metadata.containsKey("ignoreFailure")); + assertTrue("Metadata should contain overwriteCache", metadata.containsKey("overwriteCache")); + } + + // ============================================ + // Rating Type Enum Tests + // ============================================ + + public void testRatingTypeEnum_AllValues() { + // Verify all expected rating types exist + LLMJudgmentRatingType[] ratingTypes = LLMJudgmentRatingType.values(); + + assertEquals("Should have exactly 2 rating types", 2, ratingTypes.length); + + boolean hasSCORE0_1 = false; + boolean hasRELEVANT_IRRELEVANT = false; + + for (LLMJudgmentRatingType type : ratingTypes) { + if (type == LLMJudgmentRatingType.SCORE0_1) hasSCORE0_1 = true; + if (type == LLMJudgmentRatingType.RELEVANT_IRRELEVANT) hasRELEVANT_IRRELEVANT = true; + } + + assertTrue("Should have SCORE0_1", hasSCORE0_1); + assertTrue("Should have RELEVANT_IRRELEVANT", hasRELEVANT_IRRELEVANT); + } + + public void testRatingTypeEnum_GetValidValues() { + // Test that getValidValues() returns all rating types + String validValues = LLMJudgmentRatingType.getValidValues(); + + assertTrue("Valid values should contain SCORE0_1", validValues.contains("SCORE0_1")); + assertTrue("Valid values should contain RELEVANT_IRRELEVANT", validValues.contains("RELEVANT_IRRELEVANT")); + } + + // ============================================ + // Helper Methods + // ============================================ + + private Map createBasicMetadata() { + Map metadata = new HashMap<>(); + metadata.put("querySetId", "test-query-set"); + metadata.put("searchConfigurationList", List.of("test-config")); + metadata.put("size", 10); + metadata.put("modelId", "test-model"); + metadata.put("tokenLimit", 4000); + metadata.put("contextFields", List.of("title", "description")); + metadata.put("ignoreFailure", false); + metadata.put("promptTemplate", "Default prompt template"); + metadata.put("llmJudgmentRatingType", LLMJudgmentRatingType.SCORE0_1); + metadata.put("overwriteCache", false); + return metadata; + } + + private void setupMocksForSuccessfulExecution() { + // Since LlmJudgmentsProcessor uses complex async operations and thread pool, + // we just verify that the methods don't throw exceptions with valid inputs. + // The actual processing logic is tested through integration tests. + + // For unit tests, we're primarily testing: + // 1. Default rating type behavior + // 2. Handling of different rating types + // 3. Handling of different prompt templates + // 4. No exceptions are thrown for valid inputs + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java new file mode 100644 index 00000000..966780df --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorIntegrationTests.java @@ -0,0 +1,157 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.test.OpenSearchTestCase; + +/** + * Integration tests for MLAccessor focusing on: + * - First attempt success with response_format (GPT-4o scenario) + * - Response processing with structured outputs + * + * Note: Tests for retry logic and fallback behavior (GPT-3.5 compatibility) are documented + * in TESTING_GPT35_FALLBACK.md as manual tests because they require delayed retries which + * create thread leaks in the OpenSearch test framework. The retry mechanism uses + * CompletableFuture.delayedExecutor which creates daemon threads that cannot be properly + * cleaned up within test execution. + * + * Covered by unit tests: + * - MLInputOutputTransformerTests: Verifies response_format parameter is correctly included/excluded + * - RatingOutputProcessorTests: Verifies both structured and unstructured response parsing + */ +public class MLAccessorIntegrationTests extends OpenSearchTestCase { + + /** + * Note: GPT-3.5 fallback testing is documented in TESTING_GPT35_FALLBACK.md as "Scenario 2" + * This scenario requires triggering scheduleRetry which creates CompletableFuture threads that leak. + * Coverage is provided by: + * - Unit tests: MLInputOutputTransformerTests verifies response_format parameter handling + * - Manual tests: Real OpenAI GPT-3.5 API integration testing + */ + + /** + * Test that MLAccessor works correctly on first attempt when model supports response_format. + * This simulates GPT-4o model with structured output support. + */ + public void testFirstAttemptSuccess_WhenModelSupportsResponseFormat() throws Exception { + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + MLAccessor mlAccessor = new MLAccessor(mlClient); + + AtomicInteger attemptCount = new AtomicInteger(0); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + + // Mock ML client - succeeds on first attempt with response_format + doAnswer(invocation -> { + MLInput mlInput = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); + + attemptCount.incrementAndGet(); + + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map params = dataset.getParameters(); + + // Verify response_format is present + assertTrue("Should have response_format", params.containsKey("response_format")); + + // Return structured output + String structuredResponse = "{\"ratings\":[{\"id\":\"doc1\",\"rating_score\":0.9}]}"; + MLOutput mockOutput = createMockMLOutput(structuredResponse); + listener.onResponse(mockOutput); + + return null; + }).when(mlClient).predict(any(), any(MLInput.class), any()); + + // Execute prediction + Map hits = Map.of("doc1", "test content"); + mlAccessor.predict( + "gpt-4o-mini", + 4000, + "test query", + new HashMap<>(), + hits, + "Test prompt", + LLMJudgmentRatingType.SCORE0_1, + ActionListener.wrap(chunkResult -> { + result.set(chunkResult); + latch.countDown(); + }, e -> latch.countDown()) + ); + + assertTrue("Should complete", latch.await(10, TimeUnit.SECONDS)); + + // Verify only one attempt was made + assertEquals("Should only need one attempt", 1, attemptCount.get()); + + // Verify successful result + ChunkResult chunkResult = result.get(); + assertNotNull(chunkResult); + assertEquals(1, chunkResult.getSuccessfulChunksCount()); + assertEquals(0, chunkResult.getFailedChunksCount()); + } + + /** + * Note: Binary rating (RELEVANT/IRRELEVANT) fallback testing is documented in + * TESTING_GPT35_FALLBACK.md as "Scenario 3". This test would trigger scheduleRetry + * creating thread leaks. Coverage is provided by: + * - Unit tests: MLInputOutputTransformerTests.testCreateMLInput_BinaryRatingWithoutResponseFormat + * - Unit tests: RatingOutputProcessorTests verifies RELEVANT→1.0, IRRELEVANT→0.0 conversion + * - Manual tests: Real OpenAI API integration testing + */ + + /** + * Note: Testing retry exhaustion (all attempts fail) is documented in TESTING_GPT35_FALLBACK.md + * as a manual test scenario because it requires delayed retries which create thread leaks in tests. + * The retry logic with exponential backoff uses CompletableFuture.delayedExecutor which creates + * daemon threads that cannot be properly cleaned up in the OpenSearch test framework. + */ + + // ============================================ + // Helper Methods + // ============================================ + + /** + * Creates a mock MLOutput with the given JSON response. + */ + private MLOutput createMockMLOutput(String jsonResponse) { + Map dataMap = new HashMap<>(); + List> choices = new ArrayList<>(); + Map choice = new HashMap<>(); + Map message = new HashMap<>(); + message.put("content", jsonResponse); + choice.put("message", message); + choices.add(choice); + dataMap.put("choices", choices); + + ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build(); + return ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorTests.java index e9c95405..020a94d5 100644 --- a/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorTests.java +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLAccessorTests.java @@ -128,4 +128,63 @@ public void testMessageFormattingWithSpecialCharacters() throws Exception { JsonNode jsonNode = OBJECT_MAPPER.readTree(messagesJson); assertNotNull("JSON should not be null", jsonNode); } + + /** + * Test that cleanResponse does not corrupt valid JSON from OpenAI structured output. + * This is a regression test for the bug where cleanResponse was stripping characters + * from valid JSON, causing it to be unparseable. + */ + public void testCleanResponsePreservesValidJson() throws Exception { + // Valid JSON response from OpenAI structured output + String validJsonResponse = "{\"ratings\":[{\"id\":\"1\",\"rating_score\":0.9}]}"; + + // cleanResponse should return the response as-is + // (We can't directly test the private method, but we verify the concept) + JsonNode jsonNode = OBJECT_MAPPER.readTree(validJsonResponse); + assertNotNull("JSON should be parseable", jsonNode); + assertTrue("JSON should have ratings array", jsonNode.has("ratings")); + assertTrue("Ratings should be an array", jsonNode.get("ratings").isArray()); + assertEquals("Should have one rating", 1, jsonNode.get("ratings").size()); + + JsonNode rating = jsonNode.get("ratings").get(0); + assertEquals("ID should be preserved", "1", rating.get("id").asText()); + assertEquals("Rating score should be preserved", 0.9, rating.get("rating_score").asDouble(), 0.001); + } + + /** + * Test various valid JSON formats that should be preserved by cleanResponse + */ + public void testCleanResponseVariousFormats() throws Exception { + // Test empty ratings array + String emptyRatings = "{\"ratings\":[]}"; + JsonNode node1 = OBJECT_MAPPER.readTree(emptyRatings); + assertNotNull("Empty ratings should be valid JSON", node1); + assertEquals("Should have empty ratings array", 0, node1.get("ratings").size()); + + // Test multiple ratings + String multipleRatings = "{\"ratings\":[{\"id\":\"1\",\"rating_score\":0.9},{\"id\":\"2\",\"rating_score\":0.5}]}"; + JsonNode node2 = OBJECT_MAPPER.readTree(multipleRatings); + assertNotNull("Multiple ratings should be valid JSON", node2); + assertEquals("Should have two ratings", 2, node2.get("ratings").size()); + + // Test with composite keys + String compositeKeys = "{\"ratings\":[{\"id\":\"test_products::1\",\"rating_score\":1.0}]}"; + JsonNode node3 = OBJECT_MAPPER.readTree(compositeKeys); + assertNotNull("Composite keys should be valid JSON", node3); + assertEquals("Composite key should be preserved", "test_products::1", node3.get("ratings").get(0).get("id").asText()); + } + + /** + * Test that malformed responses from LLM would be handled + * (This tests the sanitization logic in RatingOutputProcessor, not cleanResponse) + */ + public void testMalformedJsonHandling() { + // These would be handled by sanitizeLLMResponse, not cleanResponse + String withCodeBlock = "```json\n{\"ratings\":[{\"id\":\"1\",\"rating_score\":0.9}]}\n```"; + String withText = "Here are the ratings:\n{\"ratings\":[{\"id\":\"1\",\"rating_score\":0.9}]}"; + + // Both contain valid JSON that should be extractable by sanitization + assertTrue("Code block should contain valid JSON", withCodeBlock.contains("{\"ratings\"")); + assertTrue("Text response should contain valid JSON", withText.contains("{\"ratings\"")); + } } diff --git a/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerTests.java b/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerTests.java new file mode 100644 index 00000000..c968a576 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/MLInputOutputTransformerTests.java @@ -0,0 +1,243 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.ml; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.searchrelevance.model.LLMJudgmentRatingType; +import org.opensearch.test.OpenSearchTestCase; + +/** + * Tests for MLInputOutputTransformer focusing on response_format parameter handling. + */ +public class MLInputOutputTransformerTests extends OpenSearchTestCase { + + private MLInputOutputTransformer transformer; + + @Override + public void setUp() throws Exception { + super.setUp(); + transformer = new MLInputOutputTransformer(); + } + + // ============================================ + // Response Format Parameter Tests + // ============================================ + + public void testCreateMLInput_WithResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); + hits.put("doc1", "test content"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, true); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + // Should include response_format parameter + assertTrue("response_format parameter should be present", parameters.containsKey("response_format")); + assertNotNull("response_format should not be null", parameters.get("response_format")); + assertTrue("response_format should contain json_schema", parameters.get("response_format").contains("json_schema")); + } + + public void testCreateMLInput_WithoutResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); + hits.put("doc1", "test content"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, false); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + // Should NOT include response_format parameter + assertFalse("response_format parameter should not be present for GPT-3.5 compatibility", parameters.containsKey("response_format")); + // Messages parameter should still be present + assertTrue("messages parameter should be present", parameters.containsKey("messages")); + } + + public void testCreateMLInput_DefaultIncludesResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); + hits.put("doc1", "test content"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + + // Using the method without includeResponseFormat parameter (default = true) + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + // Default should include response_format + assertTrue("Default behavior should include response_format", parameters.containsKey("response_format")); + } + + // ============================================ + // Different Rating Types with Response Format + // ============================================ + + public void testCreateMLInput_BinaryRatingWithResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); + hits.put("doc1", "test content"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.RELEVANT_IRRELEVANT; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, true); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + assertTrue("response_format parameter should be present", parameters.containsKey("response_format")); + String responseFormat = parameters.get("response_format"); + // Binary rating should use string enum schema + assertTrue("Binary rating should use enum schema", responseFormat.contains("enum")); + assertTrue("Binary rating should include RELEVANT", responseFormat.contains("RELEVANT")); + assertTrue("Binary rating should include IRRELEVANT", responseFormat.contains("IRRELEVANT")); + } + + public void testCreateMLInput_BinaryRatingWithoutResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); + hits.put("doc1", "test content"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.RELEVANT_IRRELEVANT; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, false); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + // Should NOT include response_format for GPT-3.5 compatibility + assertFalse("response_format should not be present", parameters.containsKey("response_format")); + } + + public void testCreateMLInput_NumericRatingWithResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); + hits.put("doc1", "test content"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, true); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + assertTrue("response_format parameter should be present", parameters.containsKey("response_format")); + String responseFormat = parameters.get("response_format"); + // Numeric rating should use number type + assertTrue("Numeric rating should use number type", responseFormat.contains("\"type\":\"number\"")); + } + + // ============================================ + // Multiple Hits Scenarios + // ============================================ + + public void testCreateMLInput_MultipleHitsWithResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); + hits.put("doc1", "content 1"); + hits.put("doc2", "content 2"); + hits.put("doc3", "content 3"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, true); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + assertTrue("response_format should be present even with multiple hits", parameters.containsKey("response_format")); + assertNotNull("messages parameter should not be null", parameters.get("messages")); + assertFalse("messages parameter should not be empty", parameters.get("messages").isEmpty()); + } + + public void testCreateMLInput_MultipleHitsWithoutResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); + hits.put("doc1", "content 1"); + hits.put("doc2", "content 2"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, false); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + assertFalse("response_format should not be present", parameters.containsKey("response_format")); + assertNotNull("messages parameter should not be null", parameters.get("messages")); + assertFalse("messages parameter should not be empty", parameters.get("messages").isEmpty()); + } + + // ============================================ + // Edge Cases + // ============================================ + + public void testCreateMLInput_EmptyHitsWithResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + Map hits = new HashMap<>(); // Empty hits + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, true); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + // Should still have response_format even with empty hits + assertTrue("response_format should be present even with empty hits", parameters.containsKey("response_format")); + } + + public void testCreateMLInput_WithReferenceDataAndResponseFormat() { + String searchText = "test query"; + Map referenceData = new HashMap<>(); + referenceData.put("reference", "Expected answer"); + Map hits = new HashMap<>(); + hits.put("doc1", "test content"); + String promptTemplate = "Test prompt"; + LLMJudgmentRatingType ratingType = LLMJudgmentRatingType.SCORE0_1; + + MLInput mlInput = transformer.createMLInput(searchText, referenceData, hits, promptTemplate, ratingType, true); + + assertNotNull(mlInput); + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + Map parameters = dataset.getParameters(); + + assertTrue("response_format should be present", parameters.containsKey("response_format")); + assertNotNull("messages parameter should not be null", parameters.get("messages")); + assertFalse("messages parameter should not be empty", parameters.get("messages").isEmpty()); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/ml/UserPromptFactoryTests.java b/src/test/java/org/opensearch/searchrelevance/ml/UserPromptFactoryTests.java new file mode 100644 index 00000000..faaf7113 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/ml/UserPromptFactoryTests.java @@ -0,0 +1,395 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.ml; + +import static org.opensearch.searchrelevance.common.MLConstants.INPUT_FORMAT_SEARCH; +import static org.opensearch.searchrelevance.common.MLConstants.INPUT_FORMAT_SEARCH_WITH_REFERENCE; + +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.test.OpenSearchTestCase; + +/** + * Unit tests for UserPromptFactory focusing on template variable replacement. + */ +public class UserPromptFactoryTests extends OpenSearchTestCase { + + // ============================================ + // Default Format Tests (No Template Provided) + // ============================================ + + public void testBuildUserContent_NoTemplate_NoReferenceData() { + // Test default format when no template and no reference data + String searchText = "What is OpenSearch?"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"},{\"id\":\"2\",\"source\":\"doc2\"}]"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, null); + + String expected = String.format(Locale.ROOT, INPUT_FORMAT_SEARCH, searchText, hitsJson); + assertEquals("Should use INPUT_FORMAT_SEARCH when no reference data", expected, result); + } + + public void testBuildUserContent_NoTemplate_WithReferenceData() { + // Test default format when no template but reference data exists + String searchText = "What is OpenSearch?"; + Map referenceData = new HashMap<>(); + referenceData.put("referenceAnswer", "OpenSearch is a search and analytics suite"); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"}]"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, null); + + String expected = String.format( + Locale.ROOT, + INPUT_FORMAT_SEARCH_WITH_REFERENCE, + searchText, + "OpenSearch is a search and analytics suite", + hitsJson + ); + assertEquals("Should use INPUT_FORMAT_SEARCH_WITH_REFERENCE when reference data exists", expected, result); + } + + public void testBuildUserContent_NoTemplate_MultipleReferenceFields() { + // Test default format with multiple reference fields (should concatenate) + String searchText = "red shoes"; + Map referenceData = new HashMap<>(); + referenceData.put("color", "red"); + referenceData.put("category", "footwear"); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"}]"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, null); + + // Should concatenate all values with "; " delimiter + assertTrue("Should contain search text", result.contains(searchText)); + assertTrue("Should contain hitsJson", result.contains(hitsJson)); + // Should use one of the reference values + assertTrue("Should contain reference data", result.contains("red") || result.contains("footwear")); + } + + public void testBuildUserContent_EmptyTemplate() { + // Test that empty template falls back to default format + String searchText = "test query"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"}]"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, ""); + + String expected = String.format(Locale.ROOT, INPUT_FORMAT_SEARCH, searchText, hitsJson); + assertEquals("Empty template should use default format", expected, result); + } + + public void testBuildUserContent_WhitespaceTemplate() { + // Test that whitespace-only template falls back to default format + String searchText = "test query"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"}]"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, " "); + + String expected = String.format(Locale.ROOT, INPUT_FORMAT_SEARCH, searchText, hitsJson); + assertEquals("Whitespace template should use default format", expected, result); + } + + // ============================================ + // Template Variable Replacement Tests + // ============================================ + + public void testBuildUserContent_Template_QueryVariable() { + // Test replacement of {{queryText}} variable + String searchText = "What is OpenSearch?"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"}]"; + String template = "User query: {{queryText}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace {{queryText}} with searchText", "User query: What is OpenSearch?", result); + } + + public void testBuildUserContent_Template_SearchTextVariable() { + // Test replacement of {{searchText}} variable + String searchText = "red shoes"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"}]"; + String template = "Search: {{searchText}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace {{searchText}} with searchText", "Search: red shoes", result); + } + + public void testBuildUserContent_Template_HitsVariable() { + // Test replacement of {{hits}} variable + String searchText = "test"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"}]"; + String template = "Results: {{hits}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace {{hits}} with hitsJson", "Results: " + hitsJson, result); + } + + public void testBuildUserContent_Template_ResultsVariable() { + // Test replacement of {{results}} variable (alias for hits) + String searchText = "test"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Search results: {{results}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace {{results}} with hitsJson", "Search results: " + hitsJson, result); + } + + public void testBuildUserContent_Template_ReferenceVariable() { + // Test replacement of {{reference}} variable + String searchText = "test"; + Map referenceData = new HashMap<>(); + referenceData.put("referenceAnswer", "This is the reference answer"); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Reference: {{reference}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace {{reference}} with referenceAnswer", "Reference: This is the reference answer", result); + } + + public void testBuildUserContent_Template_ReferenceAnswerVariable() { + // Test replacement of {{referenceAnswer}} variable + String searchText = "test"; + Map referenceData = new HashMap<>(); + referenceData.put("referenceAnswer", "Expected answer"); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Expected: {{referenceAnswer}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace {{referenceAnswer}} with referenceAnswer", "Expected: Expected answer", result); + } + + public void testBuildUserContent_Template_CustomField() { + // Test replacement of custom field from referenceData + String searchText = "test"; + Map referenceData = new HashMap<>(); + referenceData.put("category", "electronics"); + referenceData.put("brand", "Sony"); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Category: {{category}}, Brand: {{brand}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace custom fields", "Category: electronics, Brand: Sony", result); + } + + public void testBuildUserContent_Template_MultipleVariables() { + // Test replacement of multiple variables in one template + String searchText = "What is OpenSearch?"; + Map referenceData = new HashMap<>(); + referenceData.put("referenceAnswer", "OpenSearch is a search suite"); + referenceData.put("category", "technology"); + String hitsJson = "[{\"id\":\"1\",\"source\":\"doc1\"}]"; + String template = "Query: {{queryText}}\nReference: {{referenceAnswer}}\nCategory: {{category}}\nResults: {{hits}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + String expected = "Query: What is OpenSearch?\n" + + "Reference: OpenSearch is a search suite\n" + + "Category: technology\n" + + "Results: " + + hitsJson; + assertEquals("Should replace all variables", expected, result); + } + + public void testBuildUserContent_Template_UnknownVariable() { + // Test that unknown variables are replaced with empty string + String searchText = "test"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Query: {{queryText}}, Unknown: {{unknownField}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace unknown variable with empty string", "Query: test, Unknown: ", result); + } + + public void testBuildUserContent_Template_NoReferenceAnswer() { + // Test {{reference}} when referenceAnswer doesn't exist + String searchText = "test"; + Map referenceData = new HashMap<>(); + referenceData.put("category", "tech"); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Query: {{queryText}}, Reference: {{reference}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should replace missing reference with empty string", "Query: test, Reference: ", result); + } + + public void testBuildUserContent_Template_NullReferenceData() { + // Test template with null referenceData + String searchText = "test"; + Map referenceData = null; + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Query: {{queryText}}, Results: {{hits}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should handle null referenceData", "Query: test, Results: " + hitsJson, result); + } + + public void testBuildUserContent_Template_SameVariableMultipleTimes() { + // Test using the same variable multiple times + String searchText = "OpenSearch"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "{{queryText}} is awesome. {{queryText}} is open source. What is {{queryText}}?"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals( + "Should replace all occurrences of same variable", + "OpenSearch is awesome. OpenSearch is open source. What is OpenSearch?", + result + ); + } + + public void testBuildUserContent_Template_VariableWithSpaces() { + // Test that variables with spaces are NOT replaced (trimming happens but replacement doesn't match) + // This is current behavior - the matcher extracts and trims the variable name, + // but the replacement looks for the exact original pattern with spaces + String searchText = "test"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Query: {{ query }}, Results: {{ hits }}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + // Current behavior: variables with spaces are left as-is because replacement doesn't match + assertEquals("Variables with spaces should be left as-is (current behavior)", template, result); + } + + public void testBuildUserContent_Template_ComplexRealWorldExample() { + // Test a complex real-world template + String searchText = "red leather shoes"; + Map referenceData = new HashMap<>(); + referenceData.put("referenceAnswer", "High quality red leather shoes with rubber sole"); + referenceData.put("expectedScore", "0.9"); + referenceData.put("category", "footwear"); + String hitsJson = "[{\"id\":\"doc1\",\"source\":\"Red shoes\"},{\"id\":\"doc2\",\"source\":\"Leather boots\"}]"; + String template = "Given the search query: {{queryText}}\n\n" + + "Expected answer: {{referenceAnswer}}\n" + + "Expected relevance score: {{expectedScore}}\n" + + "Product category: {{category}}\n\n" + + "Please rate the following search results:\n" + + "{{hits}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + String expected = "Given the search query: red leather shoes\n\n" + + "Expected answer: High quality red leather shoes with rubber sole\n" + + "Expected relevance score: 0.9\n" + + "Product category: footwear\n\n" + + "Please rate the following search results:\n" + + hitsJson; + assertEquals("Should handle complex real-world template", expected, result); + } + + public void testBuildUserContent_Template_EmptySearchText() { + // Test with empty search text + String searchText = ""; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Query: {{queryText}}, Results: {{hits}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should handle empty search text", "Query: , Results: " + hitsJson, result); + } + + public void testBuildUserContent_Template_NullSearchText() { + // Test with null search text + String searchText = null; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Query: {{queryText}}, Results: {{hits}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should handle null search text", "Query: , Results: " + hitsJson, result); + } + + public void testBuildUserContent_Template_EmptyHitsJson() { + // Test with empty hits JSON + String searchText = "test"; + Map referenceData = new HashMap<>(); + String hitsJson = ""; + String template = "Query: {{queryText}}, Results: {{hits}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should handle empty hits JSON", "Query: test, Results: ", result); + } + + public void testBuildUserContent_Template_NullHitsJson() { + // Test with null hits JSON + String searchText = "test"; + Map referenceData = new HashMap<>(); + String hitsJson = null; + String template = "Query: {{queryText}}, Results: {{hits}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should handle null hits JSON", "Query: test, Results: ", result); + } + + public void testBuildUserContent_Template_SpecialCharactersInValues() { + // Test with special characters in values + String searchText = "test \"quoted\" & special "; + Map referenceData = new HashMap<>(); + referenceData.put("referenceAnswer", "Answer with 'quotes' & symbols"); + String hitsJson = "[{\"id\":\"1\",\"source\":\"data\"}]"; + String template = "Query: {{queryText}}\nReference: {{referenceAnswer}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals( + "Should handle special characters", + "Query: test \"quoted\" & special \nReference: Answer with 'quotes' & symbols", + result + ); + } + + public void testBuildUserContent_Template_NoVariables() { + // Test template with no variables (static text) + String searchText = "test"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "This is a static prompt with no variables."; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + assertEquals("Should return template as-is when no variables", template, result); + } + + public void testBuildUserContent_Template_MalformedVariables() { + // Test template with malformed variables + String searchText = "test"; + Map referenceData = new HashMap<>(); + String hitsJson = "[{\"id\":\"1\"}]"; + String template = "Query: {query} or {{query or query}} or {{ or {{}}"; + + String result = UserPromptFactory.buildUserContent(searchText, referenceData, hitsJson, template); + + // Malformed variables should be left as-is + assertTrue("Should not replace malformed variables", result.contains("{query}")); + assertTrue("Should handle empty variable", result.contains("{{}}")); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/rest/RestPutJudgmentActionTests.java b/src/test/java/org/opensearch/searchrelevance/rest/RestPutJudgmentActionTests.java index f0e25a1e..2a81548a 100644 --- a/src/test/java/org/opensearch/searchrelevance/rest/RestPutJudgmentActionTests.java +++ b/src/test/java/org/opensearch/searchrelevance/rest/RestPutJudgmentActionTests.java @@ -43,6 +43,22 @@ public class RestPutJudgmentActionTests extends SearchRelevanceRestTestCase { + "\"ignoreFailure\": false" + "}"; + private static final String LLM_JUDGMENT_CONTENT_WITH_NEW_FIELDS = "{" + + "\"name\": \"test_name\"," + + "\"description\": \"test_description\"," + + "\"type\": \"LLM_JUDGMENT\"," + + "\"modelId\": \"test_model_id\"," + + "\"querySetId\": \"test_query_set_id\"," + + "\"searchConfigurationList\": [\"config1\", \"config2\"]," + + "\"size\": 10," + + "\"tokenLimit\": 1000," + + "\"contextFields\": [\"field1\", \"field2\"]," + + "\"ignoreFailure\": false," + + "\"promptTemplate\": \"Query: {{queryText}}\\\\n\\\\nDocuments: {{hits}}\"," + + "\"llmJudgmentRatingType\": \"SCORE0_1\"," + + "\"overwriteCache\": true" + + "}"; + private static final String UBI_JUDGMENT_CONTENT = "{" + "\"name\": \"test_name\"," + "\"description\": \"test_description\"," @@ -233,4 +249,99 @@ public void testPutJudgment_Failure() throws Exception { verify(channel).sendResponse(responseCaptor.capture()); assertEquals(RestStatus.INTERNAL_SERVER_ERROR, responseCaptor.getValue().status()); } + + public void testPutLlmJudgment_WithNewFields_Success() throws Exception { + // Setup + when(settingsAccessor.isWorkbenchEnabled()).thenReturn(true); + RestRequest request = createPutRestRequestWithContent(LLM_JUDGMENT_CONTENT_WITH_NEW_FIELDS, "judgment"); + when(channel.request()).thenReturn(request); + + // Mock index response + IndexResponse mockIndexResponse = mock(IndexResponse.class); + when(mockIndexResponse.getId()).thenReturn("test_id"); + + // Capture the request to verify new fields + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutLlmJudgmentRequest.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mockIndexResponse); + return null; + }).when(client).execute(eq(PutJudgmentAction.INSTANCE), requestCaptor.capture(), any()); + + // Execute + restPutJudgmentAction.handleRequest(request, channel, client); + + // Verify response + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); + verify(channel).sendResponse(responseCaptor.capture()); + assertEquals(RestStatus.OK, responseCaptor.getValue().status()); + + // Verify new fields in the captured request + PutLlmJudgmentRequest capturedRequest = requestCaptor.getValue(); + assertEquals("Query: {{queryText}}\\n\\nDocuments: {{hits}}", capturedRequest.getPromptTemplate()); + assertEquals("SCORE0_1", capturedRequest.getLlmJudgmentRatingType().name()); + assertEquals(true, capturedRequest.isOverwriteCache()); + } + + public void testPutLlmJudgment_InvalidRatingType() throws Exception { + // Setup + when(settingsAccessor.isWorkbenchEnabled()).thenReturn(true); + String content = "{" + + "\"name\": \"test_name\"," + + "\"description\": \"test_description\"," + + "\"type\": \"LLM_JUDGMENT\"," + + "\"modelId\": \"test_model_id\"," + + "\"querySetId\": \"test_query_set_id\"," + + "\"searchConfigurationList\": [\"config1\", \"config2\"]," + + "\"size\": 10," + + "\"tokenLimit\": 1000," + + "\"contextFields\": [\"field1\", \"field2\"]," + + "\"ignoreFailure\": false," + + "\"llmJudgmentRatingType\": \"INVALID_RATING_TYPE\"" + + "}"; + RestRequest request = createPutRestRequestWithContent(content, "judgment"); + when(channel.request()).thenReturn(request); + + // Execute and verify + SearchRelevanceException exception = expectThrows( + SearchRelevanceException.class, + () -> restPutJudgmentAction.handleRequest(request, channel, client) + ); + assertTrue(exception.getMessage().contains("Invalid RatingType")); + assertTrue(exception.getMessage().contains("INVALID_RATING_TYPE")); + assertTrue(exception.getMessage().contains("Valid values are")); + assertTrue(exception.getMessage().contains("SCORE0_1")); + assertTrue(exception.getMessage().contains("RELEVANT_IRRELEVANT")); + assertEquals(RestStatus.BAD_REQUEST, exception.status()); + } + + public void testPutLlmJudgment_InvalidPromptTemplate_MissingHitsPlaceholder() throws Exception { + // Setup + when(settingsAccessor.isWorkbenchEnabled()).thenReturn(true); + String content = "{" + + "\"name\": \"test_name\"," + + "\"description\": \"test_description\"," + + "\"type\": \"LLM_JUDGMENT\"," + + "\"modelId\": \"test_model_id\"," + + "\"querySetId\": \"test_query_set_id\"," + + "\"searchConfigurationList\": [\"config1\", \"config2\"]," + + "\"size\": 10," + + "\"tokenLimit\": 1000," + + "\"contextFields\": [\"field1\", \"field2\"]," + + "\"ignoreFailure\": false," + + "\"promptTemplate\": \"Query: {{queryText}}\\\\nRate relevance from 0.0 to 1.0\"" + + "}"; + RestRequest request = createPutRestRequestWithContent(content, "judgment"); + when(channel.request()).thenReturn(request); + + // Execute and verify + SearchRelevanceException exception = expectThrows( + SearchRelevanceException.class, + () -> restPutJudgmentAction.handleRequest(request, channel, client) + ); + assertTrue(exception.getMessage().contains("must include either {{hits}} or {{results}} placeholder")); + assertTrue(exception.getMessage().contains("Example:")); + assertEquals(RestStatus.BAD_REQUEST, exception.status()); + } } diff --git a/src/test/java/org/opensearch/searchrelevance/rest/RestPutQuerySetActionTests.java b/src/test/java/org/opensearch/searchrelevance/rest/RestPutQuerySetActionTests.java index 424e4a10..12901479 100644 --- a/src/test/java/org/opensearch/searchrelevance/rest/RestPutQuerySetActionTests.java +++ b/src/test/java/org/opensearch/searchrelevance/rest/RestPutQuerySetActionTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.*; import java.io.IOException; +import java.util.Map; import org.mockito.ArgumentCaptor; import org.opensearch.action.index.IndexResponse; @@ -189,6 +190,161 @@ public void testPrepareRequest_InvalidReferenceAnswer() throws Exception { verify(channel).sendResponse(responseCaptor.capture()); assertEquals(RestStatus.BAD_REQUEST, responseCaptor.getValue().status()); String response = responseCaptor.getValue().content().utf8ToString(); - assertTrue("Response should contain 'Invalid referenceAnswer': " + response, response.contains("Invalid referenceAnswer")); + assertTrue( + "Response should contain error about invalid referenceAnswer value: " + response, + response.contains("referenceAnswer") && response.contains("invalid characters") + ); + } + + public void testPrepareRequest_WithNumericExpectedScore() throws Exception { + // Test that numeric values like expectedScore are properly converted to strings + String content = "{" + + "\"name\": \"test_name\"," + + "\"description\": \"test_description\"," + + "\"sampling\": \"manual\"," + + "\"querySetQueries\": [" + + " {\"queryText\": \"test\", \"expectedScore\": 1.0}" + + "]" + + "}"; + + // Setup + when(settingsAccessor.isWorkbenchEnabled()).thenReturn(true); + when(settingsAccessor.getMaxQuerySetAllowed()).thenReturn(1000); + RestRequest request = createPutRestRequestWithContent(content, "query_sets"); + when(channel.request()).thenReturn(request); + + // Mock index response + IndexResponse mockIndexResponse = mock(IndexResponse.class); + when(mockIndexResponse.getId()).thenReturn("test_id"); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutQuerySetRequest.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mockIndexResponse); + return null; + }).when(client).execute(eq(PutQuerySetAction.INSTANCE), requestCaptor.capture(), any()); + + // Execute + restPutQuerySetAction.handleRequest(request, channel, client); + + // Verify the expectedScore was converted to string "1.0" + PutQuerySetRequest capturedRequest = requestCaptor.getValue(); + assertEquals("1.0", capturedRequest.getQuerySetQueries().get(0).getCustomizedKeyValueMap().get("expectedScore")); + } + + public void testPrepareRequest_WithBooleanValue() throws Exception { + // Test that boolean values are properly converted to strings + String content = "{" + + "\"name\": \"test_name\"," + + "\"description\": \"test_description\"," + + "\"sampling\": \"manual\"," + + "\"querySetQueries\": [" + + " {\"queryText\": \"test\", \"isRelevant\": true}" + + "]" + + "}"; + + // Setup + when(settingsAccessor.isWorkbenchEnabled()).thenReturn(true); + when(settingsAccessor.getMaxQuerySetAllowed()).thenReturn(1000); + RestRequest request = createPutRestRequestWithContent(content, "query_sets"); + when(channel.request()).thenReturn(request); + + // Mock index response + IndexResponse mockIndexResponse = mock(IndexResponse.class); + when(mockIndexResponse.getId()).thenReturn("test_id"); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutQuerySetRequest.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mockIndexResponse); + return null; + }).when(client).execute(eq(PutQuerySetAction.INSTANCE), requestCaptor.capture(), any()); + + // Execute + restPutQuerySetAction.handleRequest(request, channel, client); + + // Verify the boolean was converted to string "true" + PutQuerySetRequest capturedRequest = requestCaptor.getValue(); + assertEquals("true", capturedRequest.getQuerySetQueries().get(0).getCustomizedKeyValueMap().get("isRelevant")); + } + + public void testPrepareRequest_WithIntegerValue() throws Exception { + // Test that integer values are properly converted to strings + String content = "{" + + "\"name\": \"test_name\"," + + "\"description\": \"test_description\"," + + "\"sampling\": \"manual\"," + + "\"querySetQueries\": [" + + " {\"queryText\": \"test\", \"rank\": 5}" + + "]" + + "}"; + + // Setup + when(settingsAccessor.isWorkbenchEnabled()).thenReturn(true); + when(settingsAccessor.getMaxQuerySetAllowed()).thenReturn(1000); + RestRequest request = createPutRestRequestWithContent(content, "query_sets"); + when(channel.request()).thenReturn(request); + + // Mock index response + IndexResponse mockIndexResponse = mock(IndexResponse.class); + when(mockIndexResponse.getId()).thenReturn("test_id"); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutQuerySetRequest.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mockIndexResponse); + return null; + }).when(client).execute(eq(PutQuerySetAction.INSTANCE), requestCaptor.capture(), any()); + + // Execute + restPutQuerySetAction.handleRequest(request, channel, client); + + // Verify the integer was converted to string "5" + PutQuerySetRequest capturedRequest = requestCaptor.getValue(); + assertEquals("5", capturedRequest.getQuerySetQueries().get(0).getCustomizedKeyValueMap().get("rank")); + } + + public void testPrepareRequest_WithMixedTypes() throws Exception { + // Test that multiple different types are all properly converted to strings + String content = "{" + + "\"name\": \"test_name\"," + + "\"description\": \"test_description\"," + + "\"sampling\": \"manual\"," + + "\"querySetQueries\": [" + + " {\"queryText\": \"test\", \"expectedScore\": 1.5, \"rank\": 3, \"isRelevant\": true, \"category\": \"product\"}" + + "]" + + "}"; + + // Setup + when(settingsAccessor.isWorkbenchEnabled()).thenReturn(true); + when(settingsAccessor.getMaxQuerySetAllowed()).thenReturn(1000); + RestRequest request = createPutRestRequestWithContent(content, "query_sets"); + when(channel.request()).thenReturn(request); + + // Mock index response + IndexResponse mockIndexResponse = mock(IndexResponse.class); + when(mockIndexResponse.getId()).thenReturn("test_id"); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutQuerySetRequest.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mockIndexResponse); + return null; + }).when(client).execute(eq(PutQuerySetAction.INSTANCE), requestCaptor.capture(), any()); + + // Execute + restPutQuerySetAction.handleRequest(request, channel, client); + + // Verify all types were converted to strings + PutQuerySetRequest capturedRequest = requestCaptor.getValue(); + Map customMap = capturedRequest.getQuerySetQueries().get(0).getCustomizedKeyValueMap(); + assertEquals("1.5", customMap.get("expectedScore")); + assertEquals("3", customMap.get("rank")); + assertEquals("true", customMap.get("isRelevant")); + assertEquals("product", customMap.get("category")); } } diff --git a/src/test/java/org/opensearch/searchrelevance/util/TextValidationUtilTests.java b/src/test/java/org/opensearch/searchrelevance/util/TextValidationUtilTests.java index 9623f55c..aee23322 100644 --- a/src/test/java/org/opensearch/searchrelevance/util/TextValidationUtilTests.java +++ b/src/test/java/org/opensearch/searchrelevance/util/TextValidationUtilTests.java @@ -117,4 +117,419 @@ public void testValidateDescription() { assertFalse(result.isValid()); assertEquals("Text exceeds maximum length of 250 characters", result.getErrorMessage()); } + + // ============================================ + // QuerySet Value Validation Tests + // ============================================ + + public void testValidateQuerySetValue_ValidValues() { + // Test valid values that don't contain reserved characters + List validValues = List.of( + "What is OpenSearch?", + "red shoes", + "High quality leather shoes", + "OpenSearch is a search and analytics suite", + "Category footwear", + "Expected score 0.95", + "user@example.com", + "path/to/resource", + "100%", + "$price", + "value=123", + "a+b", + "item1;item2" + ); + + for (String value : validValues) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(value); + assertTrue("Value should be valid: " + value, result.isValid()); + assertNull("Error message should be null for valid value: " + value, result.getErrorMessage()); + } + } + + public void testValidateQuerySetValue_ReservedCharacter_Newline() { + // Test that newline character is rejected + String valueWithNewline = "text with\nnewline"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(valueWithNewline); + assertFalse("Value with newline should be invalid", result.isValid()); + assertEquals("Text contains reserved characters (newline, #, or : are not allowed in QuerySet values)", result.getErrorMessage()); + } + + public void testValidateQuerySetValue_ReservedCharacter_Hash() { + // Test that hash character is rejected + String valueWithHash = "text with # hash"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(valueWithHash); + assertFalse("Value with # should be invalid", result.isValid()); + assertEquals("Text contains reserved characters (newline, #, or : are not allowed in QuerySet values)", result.getErrorMessage()); + } + + public void testValidateQuerySetValue_ReservedCharacter_Colon() { + // Test that colon character is rejected + String valueWithColon = "text with: colon"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(valueWithColon); + assertFalse("Value with : should be invalid", result.isValid()); + assertEquals("Text contains reserved characters (newline, #, or : are not allowed in QuerySet values)", result.getErrorMessage()); + } + + public void testValidateQuerySetValue_MultipleReservedCharacters() { + // Test values with multiple reserved characters + List invalidValues = List.of( + "query#text", + "key: value", + "line1\nline2", + "query#\nkey: value", + "text#with:multiple\nreserved" + ); + + for (String value : invalidValues) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(value); + assertFalse("Value should be invalid: " + value, result.isValid()); + assertEquals( + "Text contains reserved characters (newline, #, or : are not allowed in QuerySet values)", + result.getErrorMessage() + ); + } + } + + public void testValidateQuerySetValue_NullAndEmpty() { + // Test null value + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(null); + assertFalse(result.isValid()); + assertEquals("Text cannot be null", result.getErrorMessage()); + + // Test empty value + result = TextValidationUtil.validateQuerySetValue(""); + assertFalse(result.isValid()); + assertEquals("Text cannot be empty", result.getErrorMessage()); + } + + public void testValidateQuerySetValue_DangerousCharacters() { + // Test that dangerous characters are still caught + List dangerousValues = List.of("text with \"quotes\"", "text with \\backslash", "text with "); + + for (String value : dangerousValues) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(value); + assertFalse("Value with dangerous char should be invalid: " + value, result.isValid()); + assertTrue( + "Error should mention dangerous characters", + result.getErrorMessage().contains("invalid characters (quotes, backslashes, or HTML tags") + ); + } + } + + public void testValidateQuerySetValue_MaxLength() { + String validValue = "a".repeat(2000); + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(validValue); + assertTrue(result.isValid()); + assertNull(result.getErrorMessage()); + + String invalidValue = "a".repeat(2001); + result = TextValidationUtil.validateQuerySetValue(invalidValue); + assertFalse(result.isValid()); + assertEquals("Text exceeds maximum length of 2000 characters", result.getErrorMessage()); + } + + // ============================================ + // QuerySet Key Validation Tests + // ============================================ + + public void testValidateQuerySetKey_ValidKeys() { + // Test valid keys + List validKeys = List.of( + "referenceAnswer", + "category", + "brand", + "price", + "expectedScore", + "productCategory", + "targetAudience", + "priceRange", + "color", + "size", + "metadata" + ); + + for (String key : validKeys) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetKey(key); + assertTrue("Key should be valid: " + key, result.isValid()); + assertNull("Error message should be null for valid key: " + key, result.getErrorMessage()); + } + } + + public void testValidateQuerySetKey_ReservedKeyName() { + // Test that "queryText" is a reserved key name + String reservedKey = "queryText"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetKey(reservedKey); + assertFalse("'queryText' should be a reserved key", result.isValid()); + assertEquals("Key 'queryText' is reserved and cannot be used as a custom field name", result.getErrorMessage()); + } + + public void testValidateQuerySetKey_ReservedCharacters() { + // Test keys with reserved characters + List invalidKeys = List.of("key#with#hash", "key:with:colon", "key\nwith\nnewline", "key#with:multiple\nreserved"); + + for (String key : invalidKeys) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetKey(key); + assertFalse("Key with reserved char should be invalid: " + key, result.isValid()); + assertEquals("Key contains reserved characters (newline, #, or : are not allowed in QuerySet keys)", result.getErrorMessage()); + } + } + + public void testValidateQuerySetKey_LeadingTrailingWhitespace() { + // Test keys with leading/trailing whitespace + List keysWithWhitespace = List.of(" leadingSpace", "trailingSpace ", " both ", "\tkey", "key\t"); + + for (String key : keysWithWhitespace) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetKey(key); + assertFalse("Key with whitespace should be invalid: '" + key + "'", result.isValid()); + assertEquals("Key cannot have leading or trailing whitespace", result.getErrorMessage()); + } + } + + public void testValidateQuerySetKey_ValidWithInternalWhitespace() { + // Test that keys can have internal whitespace + String keyWithSpace = "expected score"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetKey(keyWithSpace); + assertTrue("Key with internal whitespace should be valid", result.isValid()); + assertNull(result.getErrorMessage()); + } + + public void testValidateQuerySetKey_NullAndEmpty() { + // Test null key + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetKey(null); + assertFalse(result.isValid()); + assertEquals("Key cannot be null", result.getErrorMessage()); + + // Test empty key + result = TextValidationUtil.validateQuerySetKey(""); + assertFalse(result.isValid()); + assertEquals("Key cannot be empty", result.getErrorMessage()); + } + + public void testValidateQuerySetKey_MaxLength() { + String validKey = "a".repeat(50); + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetKey(validKey); + assertTrue(result.isValid()); + assertNull(result.getErrorMessage()); + + String invalidKey = "a".repeat(51); + result = TextValidationUtil.validateQuerySetKey(invalidKey); + assertFalse(result.isValid()); + assertEquals("Key exceeds maximum length of 50 characters", result.getErrorMessage()); + } + + // ============================================ + // Integration Test: Validation Flow + // ============================================ + + public void testQuerySetValidation_CompleteFlow() { + // Simulate a complete QuerySet entry validation + String queryText = "What is OpenSearch?"; + String referenceAnswerKey = "referenceAnswer"; + String referenceAnswerValue = "OpenSearch is a search and analytics suite"; + String categoryKey = "category"; + String categoryValue = "technology"; + + // Validate queryText + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(queryText); + assertTrue("QueryText should be valid", result.isValid()); + + // Validate referenceAnswer key + result = TextValidationUtil.validateQuerySetKey(referenceAnswerKey); + assertTrue("ReferenceAnswer key should be valid", result.isValid()); + + // Validate referenceAnswer value + result = TextValidationUtil.validateQuerySetValue(referenceAnswerValue); + assertTrue("ReferenceAnswer value should be valid", result.isValid()); + + // Validate category key + result = TextValidationUtil.validateQuerySetKey(categoryKey); + assertTrue("Category key should be valid", result.isValid()); + + // Validate category value + result = TextValidationUtil.validateQuerySetValue(categoryValue); + assertTrue("Category value should be valid", result.isValid()); + } + + public void testQuerySetValidation_InvalidScenarios() { + // Test invalid queryText with reserved character + String invalidQueryText = "query#with#hash"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validateQuerySetValue(invalidQueryText); + assertFalse("QueryText with # should be invalid", result.isValid()); + + // Test invalid key name (reserved) + result = TextValidationUtil.validateQuerySetKey("queryText"); + assertFalse("Reserved key 'queryText' should be invalid", result.isValid()); + + // Test invalid value with colon + String invalidValue = "value: with colon"; + result = TextValidationUtil.validateQuerySetValue(invalidValue); + assertFalse("Value with : should be invalid", result.isValid()); + + // Test invalid key with newline + String invalidKey = "key\nwith\nnewline"; + result = TextValidationUtil.validateQuerySetKey(invalidKey); + assertFalse("Key with newline should be invalid", result.isValid()); + } + + // ============================================ + // Prompt Template Validation Tests + // ============================================ + + public void testValidatePromptTemplate_WithHitsPlaceholder() { + // Test valid templates with {{hits}} placeholder and query placeholders + List validTemplates = List.of( + "Query: {{queryText}}\n\nDocuments: {{hits}}", + "Rate these documents: {{hits}}\nQuery: {{queryText}}", + "Query: {{queryText}}\nCategory: {{category}}\nDocuments: {{hits}}", + "{{queryText}} - {{hits}} - {{referenceAnswer}}", + "Search: {{searchText}}\nResults: {{hits}}" + ); + + for (String template : validTemplates) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(template); + assertTrue("Template with {{hits}} should be valid: " + template, result.isValid()); + assertNull("Error message should be null for valid template", result.getErrorMessage()); + } + } + + public void testValidatePromptTemplate_WithResultsPlaceholder() { + // Test valid templates with {{results}} placeholder and query placeholders + List validTemplates = List.of( + "Query: {{queryText}}\n\nDocuments: {{results}}", + "Rate these documents: {{results}}\nQuery: {{queryText}}", + "Query: {{queryText}}\nCategory: {{category}}\nDocuments: {{results}}", + "Search: {{searchText}}\nDocs: {{results}}" + ); + + for (String template : validTemplates) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(template); + assertTrue("Template with {{results}} should be valid: " + template, result.isValid()); + assertNull("Error message should be null for valid template", result.getErrorMessage()); + } + } + + public void testValidatePromptTemplate_NullOrEmpty() { + // Null and empty templates are allowed (will use defaults) + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(null); + assertTrue("Null template should be valid (uses defaults)", result.isValid()); + assertNull(result.getErrorMessage()); + + result = TextValidationUtil.validatePromptTemplate(""); + assertTrue("Empty template should be valid (uses defaults)", result.isValid()); + assertNull(result.getErrorMessage()); + + result = TextValidationUtil.validatePromptTemplate(" "); + assertTrue("Whitespace-only template should be valid (uses defaults)", result.isValid()); + assertNull(result.getErrorMessage()); + } + + public void testValidatePromptTemplate_MissingHitsPlaceholder() { + // Test templates missing both {{hits}} and {{results}} placeholders + List invalidTemplates = List.of( + "Query: {{queryText}}", + "Rate relevance from 0.0 to 1.0\nQuery: {{queryText}}\nCategory: {{category}}", + "{{queryText}} - {{referenceAnswer}}", + "Query: {{query}}\nReference: {{reference}}" + ); + + for (String template : invalidTemplates) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(template); + assertFalse("Template without {{hits}} or {{results}} should be invalid: " + template, result.isValid()); + assertTrue( + "Error should mention missing hits placeholder", + result.getErrorMessage().contains("must include either {{hits}} or {{results}} placeholder") + ); + assertTrue("Error should provide example", result.getErrorMessage().contains("Example:")); + } + } + + public void testValidatePromptTemplate_MissingQueryPlaceholder() { + // Test templates missing queryText/searchText placeholders + List invalidTemplates = List.of( + "Documents: {{hits}}", + "Rate these documents: {{hits}}\nCategory: {{category}}", + "{{hits}} - {{referenceAnswer}}", + "Results: {{results}}\nReference: {{reference}}" + ); + + for (String template : invalidTemplates) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(template); + assertFalse("Template without query placeholder should be invalid: " + template, result.isValid()); + assertTrue( + "Error should mention missing query placeholder", + result.getErrorMessage().contains("must include either {{queryText}} or {{searchText}} placeholder") + ); + assertTrue("Error should provide example", result.getErrorMessage().contains("Example:")); + } + } + + public void testValidatePromptTemplate_MissingBothPlaceholders() { + // Test template missing both required placeholders + String template = "Just some plain text without placeholders"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(template); + assertFalse("Template without any placeholders should be invalid", result.isValid()); + // Should fail on the first check (hits/results) + assertTrue( + "Error should mention missing hits placeholder", + result.getErrorMessage().contains("must include either {{hits}} or {{results}} placeholder") + ); + } + + public void testValidatePromptTemplate_CaseSensitive() { + // Test that placeholder matching is case-sensitive + List invalidTemplates = List.of( + "Query: {{queryText}}\nDocuments: {{HITS}}", + "Query: {{queryText}}\nDocuments: {{Hits}}", + "Query: {{queryText}}\nDocuments: {{Results}}", + "Query: {{queryText}}\nDocuments: {{RESULTS}}" + ); + + for (String template : invalidTemplates) { + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(template); + assertFalse("Case-sensitive: " + template + " should be invalid", result.isValid()); + } + } + + public void testValidatePromptTemplate_BothPlaceholders() { + // Test that template can have both {{hits}} and {{results}} (though unusual) + String template = "Query: {{queryText}}\nPrimary: {{hits}}\nAlternate: {{results}}"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(template); + assertTrue("Template with both hits and results placeholders should be valid", result.isValid()); + assertNull(result.getErrorMessage()); + } + + public void testValidatePromptTemplate_ContainsDelimiter() { + // Test that template cannot contain the reserved delimiter character (#) + String template = "Query: {{queryText}}#Documents: {{hits}}"; + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(template); + assertFalse("Template with delimiter character should be invalid", result.isValid()); + assertTrue( + "Error should mention delimiter character", + result.getErrorMessage().contains("reserved delimiter character") && result.getErrorMessage().contains("#") + ); + } + + public void testValidatePromptTemplate_ExceedsMaxLength() { + // Test that template cannot exceed maximum length (10000 characters) + StringBuilder longTemplate = new StringBuilder("Query: {{queryText}}\nDocuments: {{hits}}\n"); + while (longTemplate.length() < 10001) { + longTemplate.append("This is a very long template that exceeds the maximum allowed length. "); + } + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(longTemplate.toString()); + assertFalse("Template exceeding max length should be invalid", result.isValid()); + assertTrue("Error should mention maximum length", result.getErrorMessage().contains("exceeds maximum length")); + assertTrue("Error should mention 10000 characters", result.getErrorMessage().contains("10000")); + } + + public void testValidatePromptTemplate_ValidLongTemplate() { + // Test that a long but valid template (under 10000 characters) is accepted + StringBuilder longTemplate = new StringBuilder("Query: {{queryText}}\nDocuments: {{hits}}\n"); + while (longTemplate.length() < 9990) { + longTemplate.append("This is a long template. "); + } + TextValidationUtil.ValidationResult result = TextValidationUtil.validatePromptTemplate(longTemplate.toString()); + assertTrue("Valid long template should be accepted", result.isValid()); + assertNull(result.getErrorMessage()); + } } diff --git a/src/test/java/org/opensearch/searchrelevance/utils/ParserUtilsTests.java b/src/test/java/org/opensearch/searchrelevance/utils/ParserUtilsTests.java new file mode 100644 index 00000000..988f3c9c --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/utils/ParserUtilsTests.java @@ -0,0 +1,362 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.utils; + +import java.util.Map; + +import org.opensearch.test.OpenSearchTestCase; + +/** + * Unit tests for ParserUtils + */ +public class ParserUtilsTests extends OpenSearchTestCase { + + /** + * Test getDocIdFromCompositeKey with standard composite key format (index::docId) + */ + public void testGetDocIdFromCompositeKeyWithCompositeFormat() { + String compositeKey = "test_products::123"; + String docId = ParserUtils.getDocIdFromCompositeKey(compositeKey); + assertEquals("Should extract docId from composite key", "123", docId); + } + + /** + * Test getDocIdFromCompositeKey with multiple :: separators + * Note: split("::") without limit splits on all occurrences, + * so this extracts the second element, not everything after first :: + */ + public void testGetDocIdFromCompositeKeyWithMultipleSeparators() { + String compositeKey = "index::with::colons::docId123"; + String docId = ParserUtils.getDocIdFromCompositeKey(compositeKey); + // split("::") returns ["index", "with", "colons", "docId123"], so [1] = "with" + assertEquals("Should extract second element", "with", docId); + } + + /** + * Test getDocIdFromCompositeKey with plain docId (no ::) + * This is a regression test for the bug where LLM returns plain docIds + * instead of composite keys, causing ArrayIndexOutOfBoundsException + */ + public void testGetDocIdFromCompositeKeyWithPlainDocId() { + String plainDocId = "123"; + String docId = ParserUtils.getDocIdFromCompositeKey(plainDocId); + assertEquals("Should return plain docId as-is", "123", docId); + } + + /** + * Test getDocIdFromCompositeKey with various plain docId formats + */ + public void testGetDocIdFromCompositeKeyVariousPlainFormats() { + // Numeric docId + assertEquals("1", ParserUtils.getDocIdFromCompositeKey("1")); + + // Alphanumeric docId + assertEquals("abc123", ParserUtils.getDocIdFromCompositeKey("abc123")); + + // UUID-like docId + assertEquals("550e8400-e29b-41d4-a716-446655440000", ParserUtils.getDocIdFromCompositeKey("550e8400-e29b-41d4-a716-446655440000")); + + // DocId with hyphens (but no ::) + assertEquals("doc-123-456", ParserUtils.getDocIdFromCompositeKey("doc-123-456")); + } + + /** + * Test getDocIdFromCompositeKey with edge cases + */ + public void testGetDocIdFromCompositeKeyEdgeCases() { + // DocId with special characters + String specialChars = "index::doc_id-123.test"; + String result3 = ParserUtils.getDocIdFromCompositeKey(specialChars); + assertEquals("Should preserve special characters", "doc_id-123.test", result3); + + // DocId with numbers + String withNumbers = "products::12345"; + String result4 = ParserUtils.getDocIdFromCompositeKey(withNumbers); + assertEquals("Should extract numeric docId", "12345", result4); + } + + /** + * Test combinedIndexAndDocId creates proper composite keys + */ + public void testCombinedIndexAndDocId() { + String compositeKey = ParserUtils.combinedIndexAndDocId("test_index", "doc123"); + assertEquals("Should create composite key with :: separator", "test_index::doc123", compositeKey); + + // Verify round-trip + String extractedDocId = ParserUtils.getDocIdFromCompositeKey(compositeKey); + assertEquals("Should extract original docId", "doc123", extractedDocId); + } + + /** + * Test combinedIndexAndDocId with special characters + */ + public void testCombinedIndexAndDocIdWithSpecialChars() { + String compositeKey = ParserUtils.combinedIndexAndDocId("my-index_123", "doc-456.test"); + assertEquals("Should handle special characters", "my-index_123::doc-456.test", compositeKey); + + String extractedDocId = ParserUtils.getDocIdFromCompositeKey(compositeKey); + assertEquals("Should extract docId with special chars", "doc-456.test", extractedDocId); + } + + // ============================================ + // parseQueryTextWithCustomInput Tests + // ============================================ + + public void testParseQueryTextWithCustomInput_QueryOnly() { + // Test with only query text, no reference data + String input = "What is OpenSearch?"; + Map result = ParserUtils.parseQueryTextWithCustomInput(input); + + assertEquals("Query text should be parsed", "What is OpenSearch?", result.get("queryText")); + assertEquals("Should only contain queryText", 1, result.size()); + } + + public void testParseQueryTextWithCustomInput_JsonFormat() { + // Test current JSON format: queryText#{"key1":"value1","key2":"value2"} + String input = "What is OpenSearch?#{\"referenceAnswer\":\"OpenSearch is a search suite\",\"category\":\"technology\"}"; + Map result = ParserUtils.parseQueryTextWithCustomInput(input); + + assertEquals("Query text should be parsed", "What is OpenSearch?", result.get("queryText")); + assertEquals("Reference answer should be parsed", "OpenSearch is a search suite", result.get("referenceAnswer")); + assertEquals("Category should be parsed", "technology", result.get("category")); + assertEquals("Should contain queryText, referenceAnswer, and category", 3, result.size()); + } + + public void testParseQueryTextWithCustomInput_JsonFormatMultipleFields() { + // Test JSON format with multiple custom fields + String input = + "red shoes#{\"referenceAnswer\":\"High quality leather shoes\",\"color\":\"red\",\"brand\":\"Nike\",\"price\":\"120\"}"; + Map result = ParserUtils.parseQueryTextWithCustomInput(input); + + assertEquals("Query text should be parsed", "red shoes", result.get("queryText")); + assertEquals("Reference answer should be parsed", "High quality leather shoes", result.get("referenceAnswer")); + assertEquals("Color should be parsed", "red", result.get("color")); + assertEquals("Brand should be parsed", "Nike", result.get("brand")); + assertEquals("Price should be parsed", "120", result.get("price")); + assertEquals("Should contain 5 entries", 5, result.size()); + } + + public void testParseQueryTextWithCustomInput_LegacyPlainFormat() { + // Test legacy plain format: queryText#referenceAnswer + String input = "What is OpenSearch?#OpenSearch is a community-driven, open source search and analytics suite"; + Map result = ParserUtils.parseQueryTextWithCustomInput(input); + + assertEquals("Query text should be parsed", "What is OpenSearch?", result.get("queryText")); + assertEquals( + "Reference answer should be parsed", + "OpenSearch is a community-driven, open source search and analytics suite", + result.get("referenceAnswer") + ); + assertEquals("Should contain queryText and referenceAnswer", 2, result.size()); + } + + public void testParseQueryTextWithCustomInput_EmptyReferenceContent() { + // Test with delimiter but empty content after it + String input = "test query#"; + Map result = ParserUtils.parseQueryTextWithCustomInput(input); + + assertEquals("Query text should be parsed", "test query", result.get("queryText")); + assertEquals("Should only contain queryText", 1, result.size()); + } + + public void testParseQueryTextWithCustomInput_JsonFormatWithSpecialCharacters() { + // Test JSON format with special characters in values (colons, quotes, etc.) + String input = "test query#{\"url\":\"https://example.com:8080\",\"description\":\"Product with \\\"quotes\\\"\"}"; + Map result = ParserUtils.parseQueryTextWithCustomInput(input); + + assertEquals("Query text should be parsed", "test query", result.get("queryText")); + assertEquals("URL with colons should be parsed", "https://example.com:8080", result.get("url")); + assertEquals("Description with quotes should be parsed", "Product with \"quotes\"", result.get("description")); + assertEquals("Should contain 3 entries", 3, result.size()); + } + + // ============================================ + // QuerySetEntry Format Integration Tests + // ============================================ + + public void testQuerySetEntry_OldFormat_SingleReferenceAnswer() { + // Test old QuerySetEntry format: "queryText#referenceAnswer" + // This simulates the legacy format where queryText contains both query and reference answer + String querySetEntry = "What is OpenSearch?#OpenSearch is a community-driven, open source search and analytics suite"; + + Map result = ParserUtils.parseQueryTextWithCustomInput(querySetEntry); + + assertEquals("Query text should be extracted", "What is OpenSearch?", result.get("queryText")); + assertEquals( + "Reference answer should be extracted", + "OpenSearch is a community-driven, open source search and analytics suite", + result.get("referenceAnswer") + ); + assertEquals("Should contain queryText and referenceAnswer", 2, result.size()); + + // Verify this can be used for ML processing + String queryText = result.remove("queryText"); + Map referenceData = result; // Remaining entries are reference data + + assertEquals("Query text should be ready for ML", "What is OpenSearch?", queryText); + assertEquals("Reference data should contain referenceAnswer", 1, referenceData.size()); + assertTrue("Reference data should have referenceAnswer key", referenceData.containsKey("referenceAnswer")); + } + + public void testQuerySetEntry_JsonFormat_MultipleCustomFields() { + // Test new QuerySetEntry format from PutQuerySetTransportAction (JSON format) + String querySetEntry = + "red shoes#{\"referenceAnswer\":\"High quality red leather shoes\",\"color\":\"red\",\"brand\":\"Nike\",\"price\":\"120\"}"; + + Map result = ParserUtils.parseQueryTextWithCustomInput(querySetEntry); + + assertEquals("Query text should be extracted", "red shoes", result.get("queryText")); + assertEquals("Reference answer should be extracted", "High quality red leather shoes", result.get("referenceAnswer")); + assertEquals("Color should be extracted", "red", result.get("color")); + assertEquals("Brand should be extracted", "Nike", result.get("brand")); + assertEquals("Price should be extracted", "120", result.get("price")); + assertEquals("Should contain all fields", 5, result.size()); + + // Verify this can be used for ML processing + String queryText = result.remove("queryText"); + Map referenceData = result; // Remaining entries are reference data + + assertEquals("Query text should be ready for ML", "red shoes", queryText); + assertEquals("Reference data should contain all custom fields", 4, referenceData.size()); + assertTrue("Reference data should have referenceAnswer", referenceData.containsKey("referenceAnswer")); + assertTrue("Reference data should have color", referenceData.containsKey("color")); + assertTrue("Reference data should have brand", referenceData.containsKey("brand")); + assertTrue("Reference data should have price", referenceData.containsKey("price")); + } + + public void testQuerySetEntry_JsonFormat_OnlyReferenceAnswer() { + // Test JSON format with only referenceAnswer (no other custom fields) + String querySetEntry = "What is OpenSearch?#{\"referenceAnswer\":\"OpenSearch is a search and analytics suite\"}"; + + Map result = ParserUtils.parseQueryTextWithCustomInput(querySetEntry); + + assertEquals("Query text should be extracted", "What is OpenSearch?", result.get("queryText")); + assertEquals("Reference answer should be extracted", "OpenSearch is a search and analytics suite", result.get("referenceAnswer")); + assertEquals("Should contain queryText and referenceAnswer", 2, result.size()); + + // Verify this can be used for ML processing + String queryText = result.remove("queryText"); + Map referenceData = result; + + assertEquals("Query text should be ready for ML", "What is OpenSearch?", queryText); + assertEquals("Reference data should contain only referenceAnswer", 1, referenceData.size()); + } + + public void testQuerySetEntry_JsonFormat_NoReferenceAnswerOnlyCustomFields() { + // Test JSON format with custom fields but no referenceAnswer + String querySetEntry = "test query#{\"category\":\"technology\",\"expectedScore\":\"0.9\",\"difficulty\":\"medium\"}"; + + Map result = ParserUtils.parseQueryTextWithCustomInput(querySetEntry); + + assertEquals("Query text should be extracted", "test query", result.get("queryText")); + assertEquals("Category should be extracted", "technology", result.get("category")); + assertEquals("Expected score should be extracted", "0.9", result.get("expectedScore")); + assertEquals("Difficulty should be extracted", "medium", result.get("difficulty")); + assertFalse("Should not have referenceAnswer", result.containsKey("referenceAnswer")); + assertEquals("Should contain queryText and 3 custom fields", 4, result.size()); + + // Verify this can be used for ML processing + String queryText = result.remove("queryText"); + Map referenceData = result; + + assertEquals("Query text should be ready for ML", "test query", queryText); + assertEquals("Reference data should contain custom fields", 3, referenceData.size()); + assertFalse("Reference data should not have referenceAnswer", referenceData.containsKey("referenceAnswer")); + } + + public void testQuerySetEntry_OldFormat_EmptyReferenceAnswer() { + // Test old format with empty reference answer + String querySetEntry = "What is OpenSearch?#"; + + Map result = ParserUtils.parseQueryTextWithCustomInput(querySetEntry); + + assertEquals("Query text should be extracted", "What is OpenSearch?", result.get("queryText")); + assertEquals("Should only contain queryText", 1, result.size()); + + // Verify this can be used for ML processing + String queryText = result.remove("queryText"); + Map referenceData = result; + + assertEquals("Query text should be ready for ML", "What is OpenSearch?", queryText); + assertTrue("Reference data should be empty", referenceData.isEmpty()); + } + + public void testQuerySetEntry_NoDelimiter_QueryOnly() { + // Test entry with no delimiter (just query text) + String querySetEntry = "What is OpenSearch?"; + + Map result = ParserUtils.parseQueryTextWithCustomInput(querySetEntry); + + assertEquals("Query text should be extracted", "What is OpenSearch?", result.get("queryText")); + assertEquals("Should only contain queryText", 1, result.size()); + + // Verify this can be used for ML processing + String queryText = result.remove("queryText"); + Map referenceData = result; + + assertEquals("Query text should be ready for ML", "What is OpenSearch?", queryText); + assertTrue("Reference data should be empty", referenceData.isEmpty()); + } + + public void testQuerySetEntry_BackwardCompatibility_LegacyToJson() { + // Test that legacy plain format and new JSON format both work + String legacyFormatEntry = "test query#expected answer"; + String jsonFormatEntry = "test query#{\"referenceAnswer\":\"expected answer\"}"; + + Map legacyResult = ParserUtils.parseQueryTextWithCustomInput(legacyFormatEntry); + Map jsonResult = ParserUtils.parseQueryTextWithCustomInput(jsonFormatEntry); + + // Both should extract the same queryText + assertEquals("Query text should match", legacyResult.get("queryText"), jsonResult.get("queryText")); + + // Both should have referenceAnswer + assertEquals("Both should have referenceAnswer", legacyResult.get("referenceAnswer"), jsonResult.get("referenceAnswer")); + + // Both should have the same size + assertEquals("Both should have same number of entries", legacyResult.size(), jsonResult.size()); + } + + public void testQuerySetEntry_JsonFormat_RealWorldExample() { + // Test real-world example from PutQuerySetTransportAction (JSON format) + String querySetEntry = + "red leather shoes#{\"referenceAnswer\":\"High quality red leather shoes with rubber sole and comfortable insole\"," + + "\"expectedRelevanceScore\":\"0.95\"," + + "\"productCategory\":\"footwear\"," + + "\"targetAudience\":\"adults\"," + + "\"priceRange\":\"premium\"}"; + + Map result = ParserUtils.parseQueryTextWithCustomInput(querySetEntry); + + // Verify all fields are extracted + assertEquals("Query text should be extracted", "red leather shoes", result.get("queryText")); + assertEquals( + "Reference answer should be extracted", + "High quality red leather shoes with rubber sole and comfortable insole", + result.get("referenceAnswer") + ); + assertEquals("Expected score should be extracted", "0.95", result.get("expectedRelevanceScore")); + assertEquals("Category should be extracted", "footwear", result.get("productCategory")); + assertEquals("Target audience should be extracted", "adults", result.get("targetAudience")); + assertEquals("Price range should be extracted", "premium", result.get("priceRange")); + assertEquals("Should contain all 6 fields", 6, result.size()); + + // Verify this can be used for ML processing and UserPromptFactory + String queryText = result.remove("queryText"); + Map referenceData = result; + + assertEquals("Query text should be ready", "red leather shoes", queryText); + assertEquals("Reference data should have 5 custom fields", 5, referenceData.size()); + + // All these fields can now be used in UserPromptFactory with template variables + assertTrue("Should have all fields for template replacement", referenceData.containsKey("referenceAnswer")); + assertTrue("Should have expectedRelevanceScore", referenceData.containsKey("expectedRelevanceScore")); + assertTrue("Should have productCategory", referenceData.containsKey("productCategory")); + assertTrue("Should have targetAudience", referenceData.containsKey("targetAudience")); + assertTrue("Should have priceRange", referenceData.containsKey("priceRange")); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/utils/RatingOutputProcessorTests.java b/src/test/java/org/opensearch/searchrelevance/utils/RatingOutputProcessorTests.java new file mode 100644 index 00000000..669731b0 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/utils/RatingOutputProcessorTests.java @@ -0,0 +1,470 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.utils; + +import org.opensearch.test.OpenSearchTestCase; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Unit tests for RatingOutputProcessor with focus on GPT-3.5 unstructured output handling. + */ +public class RatingOutputProcessorTests extends OpenSearchTestCase { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + public void testStructuredOutputWithRatingsArray() throws Exception { + // GPT-4o with response_format: {"ratings": [...]} + String response = "{\"ratings\": [{\"id\": \"doc1\", \"rating_score\": 4}, {\"id\": \"doc2\", \"rating_score\": 5}]}"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(2, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + assertEquals(4, resultNode.get(0).get("rating_score").asInt()); + } + + public void testDirectJsonArray() throws Exception { + // Already an array + String response = "[{\"id\": \"doc1\", \"rating_score\": 3}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + } + + public void testMarkdownCodeBlockWithJson() throws Exception { + // GPT-3.5 response with markdown code block + String response = "Here are the ratings:\n\n```json\n{\"ratings\": [{\"id\": \"doc1\", \"rating_score\": 4}]}\n```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + public void testMarkdownCodeBlockWithoutJsonTag() throws Exception { + // GPT-3.5 response with markdown code block without 'json' tag + String response = "Here are the ratings:\n\n```\n{\"ratings\": [{\"id\": \"doc1\", \"rating_score\": 5}]}\n```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + } + + public void testEmbeddedJsonInText() throws Exception { + // GPT-3.5 response with JSON embedded in prose + String response = + "Based on the query, here is my evaluation: {\"ratings\": [{\"id\": \"doc1\", \"rating_score\": 3}]} as requested."; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + } + + public void testEmbeddedJsonArray() throws Exception { + // GPT-3.5 response with JSON array embedded in text + String response = "The ratings are: [{\"id\": \"doc1\", \"rating_score\": 4}, {\"id\": \"doc2\", \"rating_score\": 2}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(2, resultNode.size()); + } + + public void testComplexUnstructuredResponse() throws Exception { + // Realistic GPT-3.5 response + String response = "I'll rate each document based on relevance:\n\n" + + "```json\n" + + "{\n" + + " \"ratings\": [\n" + + " {\"id\": \"query1_doc1\", \"rating_score\": 4},\n" + + " {\"id\": \"query1_doc2\", \"rating_score\": 5},\n" + + " {\"id\": \"query1_doc3\", \"rating_score\": 2}\n" + + " ]\n" + + "}\n" + + "```\n\n" + + "These ratings reflect the relevance of each document."; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(3, resultNode.size()); + assertEquals("query1_doc1", resultNode.get(0).get("id").asText()); + assertEquals(4, resultNode.get(0).get("rating_score").asInt()); + } + + public void testEmptyResponse() throws Exception { + String result = RatingOutputProcessor.sanitizeLLMResponse(""); + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(0, resultNode.size()); + } + + public void testNullResponse() throws Exception { + String result = RatingOutputProcessor.sanitizeLLMResponse(null); + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(0, resultNode.size()); + } + + public void testUnparseableText() throws Exception { + // Pure text with no JSON + String response = "This is just plain text without any JSON structure."; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(0, resultNode.size()); + } + + public void testMultipleJsonObjectsSelectsFirst() throws Exception { + // Multiple JSON objects - should select the first valid one + String response = "{\"ratings\": [{\"id\": \"doc1\", \"rating_score\": 4}]} and also {\"other\": \"data\"}"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + public void testArrayAppearsBeforeObject() throws Exception { + // Array appears before object - should extract array + String response = "Result: [{\"id\": \"doc1\", \"rating_score\": 4}] or {\"ratings\": [...]}"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + public void testArrayWithMultipleElementsInText() throws Exception { + // This is the scenario that was failing - array with 2 elements embedded in text + String response = + "Here are the results: [{\"id\": \"doc1\", \"rating_score\": 4}, {\"id\": \"doc2\", \"rating_score\": 2}] as requested"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(2, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + assertEquals("doc2", resultNode.get(1).get("id").asText()); + } + + public void testNestedArrayInObject() throws Exception { + // Object with nested array - should extract the ratings array + String response = "Text before {\"meta\": \"data\", \"ratings\": [{\"id\": \"doc1\", \"rating_score\": 5}]} text after"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + public void testMultipleArraysSelectsFirst() throws Exception { + // Multiple arrays - should select the first one + String response = "First: [{\"id\": \"doc1\", \"rating_score\": 4}] Second: [{\"id\": \"doc2\", \"rating_score\": 3}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + public void testObjectBeforeArrayInText() throws Exception { + // Realistic case: Object appears first in prose, then array + String response = "Status: {\"status\": \"ok\"}. Here are the ratings: [{\"id\": \"doc1\", \"rating_score\": 4}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + // Should extract the first valid JSON structure (the object), + // and since it doesn't have ratings field, it wraps it in an array + assertTrue(resultNode.isArray()); + // Will extract the first object and wrap it + assertEquals(1, resultNode.size()); + } + + public void testComplexNestedStructure() throws Exception { + // Complex structure with nested objects and arrays + String response = + "The LLM response:\n```json\n{\n \"explanation\": \"analysis\",\n \"ratings\": [\n {\"id\": \"q1_d1\", \"rating_score\": 5},\n {\"id\": \"q1_d2\", \"rating_score\": 3},\n {\"id\": \"q1_d3\", \"rating_score\": 1}\n ]\n}\n```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(3, resultNode.size()); + assertEquals("q1_d1", resultNode.get(0).get("id").asText()); + assertEquals(5, resultNode.get(0).get("rating_score").asInt()); + } + + public void testArrayWithNoRatingsKey() throws Exception { + // Direct array without "ratings" wrapper - common GPT-3.5 format + String response = + "[{\"id\": \"doc1\", \"rating_score\": 4}, {\"id\": \"doc2\", \"rating_score\": 2}, {\"id\": \"doc3\", \"rating_score\": 5}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(3, resultNode.size()); + } + + public void testMalformedJsonReturnsEmpty() throws Exception { + // Malformed JSON should return empty array + String response = "Text with {broken json [that doesn't close properly"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(0, resultNode.size()); + } + + public void testProseWithCodeBlockContainingArray() throws Exception { + // GPT-3.5 style response with explanation and code block + String response = "I've evaluated each document based on relevance.\n\n" + + "```\n" + + "[{\"id\": \"doc1\", \"rating_score\": 0.9}, {\"id\": \"doc2\", \"rating_score\": 0.5}]\n" + + "```\n\n" + + "The first document is highly relevant."; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(2, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + // ============================================ + // Tests for improved state machine - handling braces/brackets inside strings + // ============================================ + + public void testJsonWithBracesInsideStrings() throws Exception { + // JSON object with braces inside string values - state machine should handle correctly + String response = "{\"ratings\": [{\"id\": \"doc1\", \"comment\": \"This {has} braces\", \"rating_score\": 4}]}"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + assertEquals("This {has} braces", resultNode.get(0).get("comment").asText()); + } + + public void testJsonWithBracketsInsideStrings() throws Exception { + // JSON with brackets inside string values + String response = "[{\"id\": \"doc1\", \"title\": \"Array [1,2,3] reference\", \"rating_score\": 3}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("Array [1,2,3] reference", resultNode.get(0).get("title").asText()); + } + + public void testJsonWithEscapedQuotesInStrings() throws Exception { + // JSON with escaped quotes - state machine should handle properly + String response = "[{\"id\": \"doc1\", \"text\": \"He said \\\"hello\\\"\", \"rating_score\": 5}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("He said \"hello\"", resultNode.get(0).get("text").asText()); + } + + public void testJsonWithComplexEscapedContent() throws Exception { + // JSON with multiple escape sequences and special characters + String response = "{\"ratings\": [{\"id\": \"doc1\", \"note\": \"Path: C:\\\\Users\\\\file.txt\", \"rating_score\": 4}]}"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("Path: C:\\Users\\file.txt", resultNode.get(0).get("note").asText()); + } + + public void testJsonWithMixedQuotes() throws Exception { + // JSON with both single and double quotes in strings (JSON standard requires double quotes for keys) + String response = "[{\"id\": \"doc1\", \"content\": \"It's a good match\", \"rating_score\": 4}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("It's a good match", resultNode.get(0).get("content").asText()); + } + + // ============================================ + // Tests for different line endings (CRLF vs LF) + // ============================================ + + public void testMarkdownCodeBlockWithCRLF() throws Exception { + // Windows-style line endings (CRLF) + String response = "Here are the ratings:\r\n\r\n```json\r\n" + + "{\"ratings\": [{\"id\": \"doc1\", \"rating_score\": 4}]}\r\n" + + "```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + public void testJsonWithMixedLineEndings() throws Exception { + // Mixed CRLF and LF + String response = "Result:\n\r```\r\n[{\"id\": \"doc1\", \"rating_score\": 5}]\n```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + } + + // ============================================ + // Tests for multiple code blocks and other language tags + // ============================================ + + public void testMultipleCodeBlocksSelectsFirst() throws Exception { + // Multiple code blocks - should extract from the first one + String response = "First block:\n```json\n[{\"id\": \"doc1\", \"rating_score\": 4}]\n```\n\n" + + "Second block:\n```json\n[{\"id\": \"doc2\", \"rating_score\": 3}]\n```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + public void testCodeBlockWithPythonTag() throws Exception { + // Code block with 'python' tag instead of 'json' - should still extract JSON + String response = "Here's the output:\n```python\n" + "[{\"id\": \"doc1\", \"rating_score\": 4}]\n" + "```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + // This will fail to extract from markdown (non-json tag), but should fall back to pattern extraction + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + // May be empty or may extract depending on fallback - at least should not crash + } + + public void testCodeBlockWithJavaScriptTag() throws Exception { + // Code block with 'javascript' tag - fallback to pattern extraction + String response = "```javascript\n{\"ratings\": [{\"id\": \"doc1\", \"rating_score\": 5}]}\n```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + // Should fall back to pattern extraction and still work + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + } + + public void testExplanationBeforeCodeBlock() throws Exception { + // Realistic: Long explanation before the actual JSON + String response = "Let me explain my reasoning for these ratings:\n\n" + + "Document 1 appears highly relevant because it contains...\n" + + "Document 2 is less relevant due to...\n\n" + + "Here are my final ratings:\n\n" + + "```json\n" + + "{\"ratings\": [{\"id\": \"doc1\", \"rating_score\": 5}, {\"id\": \"doc2\", \"rating_score\": 2}]}\n" + + "```"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(2, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + // ============================================ + // Tests for inline JSON and edge cases + // ============================================ + + public void testInlineJsonWithSurroundingText() throws Exception { + // Inline JSON with lots of surrounding prose + String response = "After analyzing the query and documents, I believe the ratings should be " + + "[{\"id\": \"doc1\", \"rating_score\": 4}, {\"id\": \"doc2\", \"rating_score\": 3}] " + + "because these scores reflect the relevance accurately."; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(2, resultNode.size()); + } + + public void testJsonWithNestedObjectsAndArrays() throws Exception { + // Complex nested structure that state machine should handle + String response = + "{\"ratings\": [{\"id\": \"doc1\", \"details\": {\"score\": 5, \"factors\": [\"a\", \"b\"]}, \"rating_score\": 5}]}"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("doc1", resultNode.get(0).get("id").asText()); + } + + public void testMalformedJsonWithExtraComma() throws Exception { + // Common LLM mistake: trailing comma + String response = "[{\"id\": \"doc1\", \"rating_score\": 4,}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + // Jackson should fail to parse this, should return empty array + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + // Will likely be empty due to parse failure + } + + public void testJsonWithUnicodeCharacters() throws Exception { + // JSON with unicode characters + String response = "[{\"id\": \"doc1\", \"title\": \"Café résumé\", \"rating_score\": 4}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(1, resultNode.size()); + assertEquals("Café résumé", resultNode.get(0).get("title").asText()); + } + + public void testJsonArrayWithEmptyObjects() throws Exception { + // Edge case: array with empty objects + String response = "[{}, {\"id\": \"doc1\", \"rating_score\": 3}]"; + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(2, resultNode.size()); + } + + public void testVeryLongJsonResponse() throws Exception { + // Simulate a large response with many ratings + StringBuilder sb = new StringBuilder("["); + for (int i = 0; i < 100; i++) { + if (i > 0) { + sb.append(","); + } + sb.append("{\"id\": \"doc").append(i).append("\", \"rating_score\": ").append(i % 5).append("}"); + } + sb.append("]"); + String response = sb.toString(); + + String result = RatingOutputProcessor.sanitizeLLMResponse(response); + + JsonNode resultNode = OBJECT_MAPPER.readTree(result); + assertTrue(resultNode.isArray()); + assertEquals(100, resultNode.size()); + } +} diff --git a/src/test/resources/llmjudgment/BulkIngestProducts.json b/src/test/resources/llmjudgment/BulkIngestProducts.json new file mode 100644 index 00000000..fc6fe28f --- /dev/null +++ b/src/test/resources/llmjudgment/BulkIngestProducts.json @@ -0,0 +1,10 @@ +{"index":{"_index":"test_llm_products","_id":"1"}} +{"name":"Dell Laptop","description":"High performance laptop for professionals","category":"electronics","price":1200.00} +{"index":{"_index":"test_llm_products","_id":"2"}} +{"name":"Office Chair","description":"Ergonomic office chair with lumbar support","category":"furniture","price":299.99} +{"index":{"_index":"test_llm_products","_id":"3"}} +{"name":"Espresso Machine","description":"Premium coffee maker for home baristas","category":"kitchen","price":499.99} +{"index":{"_index":"test_llm_products","_id":"4"}} +{"name":"Running Shoes","description":"Comfortable athletic shoes for runners","category":"sports","price":129.99} +{"index":{"_index":"test_llm_products","_id":"5"}} +{"name":"MacBook Pro","description":"Apple laptop with M3 chip for developers","category":"electronics","price":2499.00} diff --git a/src/test/resources/llmjudgment/CreateLlmJudgmentBinary.json b/src/test/resources/llmjudgment/CreateLlmJudgmentBinary.json new file mode 100644 index 00000000..7057e75b --- /dev/null +++ b/src/test/resources/llmjudgment/CreateLlmJudgmentBinary.json @@ -0,0 +1,14 @@ +{ + "name": "LLM Judgment Binary", + "type": "LLM_JUDGMENT", + "querySetId": "{{querySetId}}", + "searchConfigurationList": ["{{searchConfigId}}"], + "modelId": "test_model_id", + "size": 5, + "tokenLimit": 4000, + "contextFields": ["name", "description"], + "ignoreFailure": false, + "llmJudgmentRatingType": "RELEVANT_IRRELEVANT", + "promptTemplate": "Query: {{queryText}}\n\nDocuments: {{hits}}\n\nIs this document relevant? Answer RELEVANT or IRRELEVANT.", + "overwriteCache": false +} diff --git a/src/test/resources/llmjudgment/CreateLlmJudgmentMinimal.json b/src/test/resources/llmjudgment/CreateLlmJudgmentMinimal.json new file mode 100644 index 00000000..29000f7e --- /dev/null +++ b/src/test/resources/llmjudgment/CreateLlmJudgmentMinimal.json @@ -0,0 +1,11 @@ +{ + "name": "LLM Judgment Minimal", + "type": "LLM_JUDGMENT", + "querySetId": "{{querySetId}}", + "searchConfigurationList": ["{{searchConfigId}}"], + "modelId": "test_model_id", + "size": 5, + "tokenLimit": 4000, + "contextFields": ["name", "description"], + "ignoreFailure": false +} diff --git a/src/test/resources/llmjudgment/CreateLlmJudgmentOverwriteFalse.json b/src/test/resources/llmjudgment/CreateLlmJudgmentOverwriteFalse.json new file mode 100644 index 00000000..66a0be91 --- /dev/null +++ b/src/test/resources/llmjudgment/CreateLlmJudgmentOverwriteFalse.json @@ -0,0 +1,14 @@ +{ + "name": "LLM Judgment Overwrite Cache False", + "type": "LLM_JUDGMENT", + "querySetId": "{{querySetId}}", + "searchConfigurationList": ["{{searchConfigId}}"], + "modelId": "test_model_id", + "size": 5, + "tokenLimit": 4000, + "contextFields": ["name", "description"], + "ignoreFailure": false, + "llmJudgmentRatingType": "SCORE0_1", + "promptTemplate": "Query: {{queryText}}\n\nDocuments: {{hits}}\n\nRate relevance 0-1", + "overwriteCache": false +} diff --git a/src/test/resources/llmjudgment/CreateLlmJudgmentOverwriteTrue.json b/src/test/resources/llmjudgment/CreateLlmJudgmentOverwriteTrue.json new file mode 100644 index 00000000..817dea94 --- /dev/null +++ b/src/test/resources/llmjudgment/CreateLlmJudgmentOverwriteTrue.json @@ -0,0 +1,14 @@ +{ + "name": "LLM Judgment Overwrite Cache True", + "type": "LLM_JUDGMENT", + "querySetId": "{{querySetId}}", + "searchConfigurationList": ["{{searchConfigId}}"], + "modelId": "test_model_id", + "size": 5, + "tokenLimit": 4000, + "contextFields": ["name", "description"], + "ignoreFailure": false, + "llmJudgmentRatingType": "SCORE0_1", + "promptTemplate": "Query: {{queryText}}\n\nDocuments: {{hits}}\n\nRate relevance 0-1", + "overwriteCache": true +} diff --git a/src/test/resources/llmjudgment/CreateLlmJudgmentScore01.json b/src/test/resources/llmjudgment/CreateLlmJudgmentScore01.json new file mode 100644 index 00000000..e2edffda --- /dev/null +++ b/src/test/resources/llmjudgment/CreateLlmJudgmentScore01.json @@ -0,0 +1,14 @@ +{ + "name": "LLM Judgment SCORE0_1", + "type": "LLM_JUDGMENT", + "querySetId": "{{querySetId}}", + "searchConfigurationList": ["{{searchConfigId}}"], + "modelId": "test_model_id", + "size": 5, + "tokenLimit": 4000, + "contextFields": ["name", "description"], + "ignoreFailure": false, + "llmJudgmentRatingType": "SCORE0_1", + "promptTemplate": "Query: {{queryText}}\n\nDocuments: {{hits}}\n\nRate the relevance from 0.0 to 1.0", + "overwriteCache": false +} diff --git a/src/test/resources/llmjudgment/CreateLlmJudgmentWithPromptTemplate.json b/src/test/resources/llmjudgment/CreateLlmJudgmentWithPromptTemplate.json new file mode 100644 index 00000000..3076b86f --- /dev/null +++ b/src/test/resources/llmjudgment/CreateLlmJudgmentWithPromptTemplate.json @@ -0,0 +1,14 @@ +{ + "name": "LLM Judgment with Prompt Template", + "type": "LLM_JUDGMENT", + "querySetId": "{{querySetId}}", + "searchConfigurationList": ["{{searchConfigId}}"], + "modelId": "test_model_id", + "size": 5, + "tokenLimit": 4000, + "contextFields": ["name", "description"], + "ignoreFailure": false, + "llmJudgmentRatingType": "SCORE0_1", + "promptTemplate": "Given the query {{queryText}} and reference answer {{referenceAnswer}}, rate the relevance of these search results {{hits}} on a scale of 0-1.", + "overwriteCache": false +} diff --git a/src/test/resources/llmjudgment/CreateQuerySetSimple.json b/src/test/resources/llmjudgment/CreateQuerySetSimple.json new file mode 100644 index 00000000..f69222b2 --- /dev/null +++ b/src/test/resources/llmjudgment/CreateQuerySetSimple.json @@ -0,0 +1,12 @@ +{ + "name": "Simple Query Set", + "description": "Simple query set for testing", + "querySetQueries": [ + { + "queryText": "laptop" + }, + { + "queryText": "chair" + } + ] +} diff --git a/src/test/resources/llmjudgment/CreateQuerySetWithCustomFields.json b/src/test/resources/llmjudgment/CreateQuerySetWithCustomFields.json new file mode 100644 index 00000000..73e4fd70 --- /dev/null +++ b/src/test/resources/llmjudgment/CreateQuerySetWithCustomFields.json @@ -0,0 +1,16 @@ +{ + "name": "LLM Judgment Test Query Set", + "description": "Query set for testing LLM judgment with custom fields", + "querySetQueries": [ + { + "queryText": "laptop", + "category": "electronics", + "referenceAnswer": "A portable computer for professionals" + }, + { + "queryText": "coffee maker", + "category": "kitchen", + "referenceAnswer": "An appliance for brewing coffee at home" + } + ] +} diff --git a/src/test/resources/llmjudgment/CreateSearchConfiguration.json b/src/test/resources/llmjudgment/CreateSearchConfiguration.json new file mode 100644 index 00000000..7c4f91b9 --- /dev/null +++ b/src/test/resources/llmjudgment/CreateSearchConfiguration.json @@ -0,0 +1,6 @@ +{ + "name": "Products Multi-Field Search", + "description": "Search both name and description fields", + "index": "{{index}}", + "query": "{\"query\": {\"multi_match\": {\"query\": \"%SearchText%\", \"fields\": [\"name\", \"description\"]}}}" +} diff --git a/src/test/resources/llmjudgment/CreateTestIndex.json b/src/test/resources/llmjudgment/CreateTestIndex.json new file mode 100644 index 00000000..08fd711d --- /dev/null +++ b/src/test/resources/llmjudgment/CreateTestIndex.json @@ -0,0 +1,18 @@ +{ + "mappings": { + "properties": { + "name": { + "type": "text" + }, + "description": { + "type": "text" + }, + "category": { + "type": "keyword" + }, + "price": { + "type": "float" + } + } + } +}