Skip to content

Commit 5474d1b

Browse files
committed
Adding query planning tool search template validation and integration tests
Signed-off-by: Joshua Palis <[email protected]>
1 parent 281c430 commit 5474d1b

File tree

4 files changed

+202
-15
lines changed

4 files changed

+202
-15
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.opensearch.ml.common.utils.ToolUtils;
2828
import org.opensearch.transport.client.Client;
2929

30+
import com.google.gson.reflect.TypeToken;
31+
3032
import lombok.Getter;
3133
import lombok.Setter;
3234

@@ -46,13 +48,16 @@ public class QueryPlanningTool implements WithModelTool {
4648
public static final String USER_PROMPT_FIELD = "user_prompt";
4749
public static final String INDEX_MAPPING_FIELD = "index_mapping";
4850
public static final String QUERY_FIELDS_FIELD = "query_fields";
49-
private static final String GENERATION_TYPE_FIELD = "generation_type";
51+
public static final String GENERATION_TYPE_FIELD = "generation_type";
5052
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
51-
private static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
52-
private static final String SEARCH_TEMPLATES_FIELD = "search_templates";
53+
public static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
54+
public static final String SEARCH_TEMPLATES_FIELD = "search_templates";
5355
public static final String TEMPLATE_FIELD = "template";
56+
private static final String TEMPLATE_ID_FIELD = "template_id";
57+
private static final String TEMPLATE_DESCRIPTION_FIELD = "template_description";
5458
private static final String DEFAULT_SYSTEM_PROMPT =
5559
"You are an OpenSearch Query DSL generation assistant, translating natural language questions to OpenSeach DSL Queries";
60+
5661
@Getter
5762
private final String generationType;
5863
@Getter
@@ -112,7 +117,12 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
112117
// Retrieve search template by ID
113118
GetStoredScriptRequest getStoredScriptRequest = new GetStoredScriptRequest(templateId);
114119
client.admin().cluster().getStoredScript(getStoredScriptRequest, ActionListener.wrap(getStoredScriptResponse -> {
115-
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
120+
if (getStoredScriptResponse.getSource() == null) {
121+
// Edge case where stored scripts arent synced, default search template should be used
122+
parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE);
123+
} else {
124+
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
125+
}
116126
executeQueryPlanning(parameters, listener);
117127
}, e -> { listener.onFailure(e); }));
118128
}
@@ -233,14 +243,38 @@ public QueryPlanningTool create(Map<String, Object> map) {
233243
throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'");
234244
} else {
235245
// array is parsed as a json string
236-
searchTemplates = gson.toJson((String) map.get(SEARCH_TEMPLATES_FIELD));
237-
246+
String searchTemplatesJson = (String) map.get(SEARCH_TEMPLATES_FIELD);
247+
validateSearchTemplates(searchTemplatesJson);
248+
searchTemplates = gson.toJson(searchTemplatesJson);
238249
}
239250
}
240251

241252
return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
242253
}
243254

255+
private void validateSearchTemplates(Object searchTemplatesObj) {
256+
List<Map<String, String>> templates = gson.fromJson(searchTemplatesObj.toString(), new TypeToken<List<Map<String, String>>>() {
257+
}.getType());
258+
259+
for (Map<String, String> template : templates) {
260+
validateTemplateFields(template);
261+
}
262+
}
263+
264+
private void validateTemplateFields(Map<String, String> template) {
265+
// Validate templateId
266+
String templateId = template.get(TEMPLATE_ID_FIELD);
267+
if (templateId == null || templateId.trim().isEmpty()) {
268+
throw new IllegalArgumentException("search_templates field entries must have a template_id");
269+
}
270+
271+
// Validate templateDescription
272+
String templateDescription = template.get(TEMPLATE_DESCRIPTION_FIELD);
273+
if (templateDescription == null || templateDescription.trim().isEmpty()) {
274+
throw new IllegalArgumentException("search_templates field entries must have a template_description");
275+
}
276+
}
277+
244278
@Override
245279
public String getDefaultDescription() {
246280
return DEFAULT_DESCRIPTION;

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,38 @@ public void testFactoryCreate() {
9595
assertEquals(QueryPlanningTool.TYPE, tool.getName());
9696
}
9797

98+
@Test
99+
public void testCreateWithInvalidSearchTemplatesDescription() throws IllegalArgumentException {
100+
Map<String, Object> params = new HashMap<>();
101+
params.put("generation_type", "user_templates");
102+
params.put(MODEL_ID_FIELD, "test_model_id");
103+
params
104+
.put(
105+
SYSTEM_PROMPT_FIELD,
106+
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
107+
);
108+
params.put("query_text", "help me find some books related to wind");
109+
params.put("search_templates", "[{'template_id': 'template_id', 'template_des': 'test_description'}]");
110+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
111+
assertEquals("search_templates field entries must have a template_description", exception.getMessage());
112+
}
113+
114+
@Test
115+
public void testCreateWithInvalidSearchTemplatesID() throws IllegalArgumentException {
116+
Map<String, Object> params = new HashMap<>();
117+
params.put("generation_type", "user_templates");
118+
params.put(MODEL_ID_FIELD, "test_model_id");
119+
params
120+
.put(
121+
SYSTEM_PROMPT_FIELD,
122+
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
123+
);
124+
params.put("query_text", "help me find some books related to wind");
125+
params.put("search_templates", "[{'templateid': 'template_id', 'template_description': 'test_description'}]");
126+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
127+
assertEquals("search_templates field entries must have a template_id", exception.getMessage());
128+
}
129+
98130
@Test
99131
public void testRun() throws ExecutionException, InterruptedException {
100132
String matchQueryString = "{\"query\":{\"match\":{\"title\":\"wind\"}}}";

plugin/build.gradle

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ dependencies {
6868

6969
implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "2.30.18"
7070

71-
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
71+
zipArchive("org.opensearch.plugin:opensearch-job-scheduler:${opensearch_build}")
72+
zipArchive("org.opensearch.plugin:opensearch-knn:${opensearch_build}")
73+
zipArchive("org.opensearch.plugin:neural-search:${opensearch_build}")
7274
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
7375
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
7476
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
@@ -249,16 +251,38 @@ testClusters.integTest {
249251
}
250252
plugin(project.tasks.bundlePlugin.archiveFile)
251253
plugin(provider(new Callable<RegularFile>(){
252-
@Override
253-
RegularFile call() throws Exception {
254-
return new RegularFile() {
255-
@Override
256-
File getAsFile() {
257-
return configurations.zipArchive.asFileTree.getSingleFile()
254+
@Override
255+
RegularFile call() throws Exception {
256+
return new RegularFile() {
257+
@Override
258+
File getAsFile() {
259+
return configurations.zipArchive.asFileTree.matching{include "**/opensearch-job-scheduler-${opensearch_build}.zip"}.getSingleFile()
260+
}
258261
}
259262
}
260-
}
261-
}))
263+
}))
264+
plugin(provider(new Callable<RegularFile>(){
265+
@Override
266+
RegularFile call() throws Exception {
267+
return new RegularFile() {
268+
@Override
269+
File getAsFile() {
270+
return configurations.zipArchive.asFileTree.matching{include "**/opensearch-knn-${opensearch_build}.zip"}.getSingleFile()
271+
}
272+
}
273+
}
274+
}))
275+
plugin(provider(new Callable<RegularFile>(){
276+
@Override
277+
RegularFile call() throws Exception {
278+
return new RegularFile() {
279+
@Override
280+
File getAsFile() {
281+
return configurations.zipArchive.asFileTree.matching{include "**/neural-search-${opensearch_build}.zip"}.getSingleFile()
282+
}
283+
}
284+
}
285+
}))
262286

263287
nodes.each { node ->
264288
def plugins = node.plugins

plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
package org.opensearch.ml.rest;
77

88
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
9+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.GENERATION_TYPE_FIELD;
910
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
11+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.SEARCH_TEMPLATES_FIELD;
12+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.USER_SEARCH_TEMPLATES_TYPE_FIELD;
1013

1114
import java.io.IOException;
1215
import java.util.List;
@@ -95,6 +98,50 @@ public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
9598
deleteAgent(agentId);
9699
}
97100

101+
@Test
102+
public void testAgentWithQueryPlanningTool_SearchTemplates() throws IOException {
103+
if (OPENAI_KEY == null) {
104+
return;
105+
}
106+
107+
// Create Search Templates
108+
String templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"match\":{\"type\":\"{{type}}\"}}}}}";
109+
Response response = createSearchTemplate("type_search_template", templateBody);
110+
templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"term\":{\"type\":\"{{type}}\"}}}}}";
111+
response = createSearchTemplate("type_search_template_2", templateBody);
112+
113+
// Register agent with search template IDs
114+
String agentName = "Test_AgentWithQueryPlanningTool_SearchTemplates";
115+
String searchTemplates = "[{"
116+
+ "\"template_id\":\"type_search_template\","
117+
+ "\"template_description\":\"this templates searches for flowers that match the given type this uses a match query\""
118+
+ "},{"
119+
+ "\"template_id\":\"type_search_template_2\","
120+
+ "\"template_description\":\"this templates searches for flowers that match the given type this uses a term query\""
121+
+ "},{"
122+
+ "\"template_id\":\"brand_search_template\","
123+
+ "\"template_description\":\"this templates searches for products that match the given brand\""
124+
+ "}]";
125+
String agentId = registerQueryPlanningAgentWithSearchTemplates(agentName, queryPlanningModelId, searchTemplates);
126+
assertNotNull(agentId);
127+
128+
String query = "{\"parameters\": {\"query_text\": \"List 5 iris flowers of type setosa\"}}";
129+
Response agentResponse = executeAgent(agentId, query);
130+
String responseBody = TestHelper.httpEntityToString(agentResponse.getEntity());
131+
132+
Map<String, Object> responseMap = gson.fromJson(responseBody, Map.class);
133+
134+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
135+
Map<String, Object> firstResult = inferenceResults.get(0);
136+
List<Map<String, Object>> outputArray = (List<Map<String, Object>>) firstResult.get("output");
137+
Map<String, Object> output = (Map<String, Object>) outputArray.get(0);
138+
String result = output.get("result").toString();
139+
140+
assertTrue(result.contains("query"));
141+
assertTrue(result.contains("term"));
142+
deleteAgent(agentId);
143+
}
144+
98145
private String registerAgentWithQueryPlanningTool(String agentName, String modelId) throws IOException {
99146
MLToolSpec listIndexTool = MLToolSpec
100147
.builder()
@@ -125,6 +172,44 @@ private String registerAgentWithQueryPlanningTool(String agentName, String model
125172
return registerAgent(agentName, agent);
126173
}
127174

175+
private String registerQueryPlanningAgentWithSearchTemplates(String agentName, String modelId, String searchTemplates)
176+
throws IOException {
177+
MLToolSpec listIndexTool = MLToolSpec
178+
.builder()
179+
.type("ListIndexTool")
180+
.name("MyListIndexTool")
181+
.description("A tool for list indices")
182+
.parameters(Map.of("index", IRIS_INDEX, "question", "what fields are in the index?"))
183+
.includeOutputInAgentResponse(true)
184+
.build();
185+
186+
MLToolSpec queryPlanningTool = MLToolSpec
187+
.builder()
188+
.type("QueryPlanningTool")
189+
.name("MyQueryPlanningTool")
190+
.description("A tool for planning queries")
191+
.parameters(
192+
Map
193+
.ofEntries(
194+
Map.entry(MODEL_ID_FIELD, modelId),
195+
Map.entry(GENERATION_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD),
196+
Map.entry(SEARCH_TEMPLATES_FIELD, searchTemplates)
197+
)
198+
)
199+
.includeOutputInAgentResponse(true)
200+
.build();
201+
202+
MLAgent agent = MLAgent
203+
.builder()
204+
.name(agentName)
205+
.type("flow")
206+
.description("Test agent with QueryPlanningTool")
207+
.tools(List.of(listIndexTool, queryPlanningTool))
208+
.build();
209+
210+
return registerAgent(agentName, agent);
211+
}
212+
128213
private String registerQueryPlanningModel() throws IOException, InterruptedException {
129214
String openaiModelName = "openai gpt-4o model " + randomAlphaOfLength(5);
130215
return registerRemoteModel(openaiConnectorEntity, openaiModelName, true);
@@ -177,6 +262,18 @@ private Response executeAgent(String agentId, String query) throws IOException {
177262
);
178263
}
179264

265+
private Response createSearchTemplate(String templateName, String templateBody) throws IOException {
266+
return TestHelper
267+
.makeRequest(
268+
client(),
269+
"PUT",
270+
"/_scripts/" + templateName,
271+
null,
272+
new StringEntity(templateBody),
273+
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
274+
);
275+
}
276+
180277
private void deleteAgent(String agentId) throws IOException {
181278
TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of());
182279
}

0 commit comments

Comments
 (0)