4848import org .elasticsearch .xpack .esql .expression .function .fulltext .QueryString ;
4949import org .elasticsearch .xpack .esql .expression .function .grouping .Bucket ;
5050import org .elasticsearch .xpack .esql .expression .function .scalar .convert .ToInteger ;
51+ import org .elasticsearch .xpack .esql .expression .function .scalar .string .Concat ;
5152import org .elasticsearch .xpack .esql .expression .function .scalar .string .Substring ;
5253import org .elasticsearch .xpack .esql .expression .predicate .operator .arithmetic .Add ;
5354import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .Equals ;
6768import org .elasticsearch .xpack .esql .plan .logical .OrderBy ;
6869import org .elasticsearch .xpack .esql .plan .logical .Row ;
6970import org .elasticsearch .xpack .esql .plan .logical .UnresolvedRelation ;
71+ import org .elasticsearch .xpack .esql .plan .logical .inference .Completion ;
7072import org .elasticsearch .xpack .esql .plan .logical .inference .Rerank ;
7173import org .elasticsearch .xpack .esql .plan .logical .local .EsqlProject ;
7274import org .elasticsearch .xpack .esql .plugin .EsqlPlugin ;
8991import static org .elasticsearch .xpack .esql .EsqlTestUtils .TEST_VERIFIER ;
9092import static org .elasticsearch .xpack .esql .EsqlTestUtils .as ;
9193import static org .elasticsearch .xpack .esql .EsqlTestUtils .configuration ;
94+ import static org .elasticsearch .xpack .esql .EsqlTestUtils .getAttributeByName ;
9295import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsConstant ;
9396import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsIdentifier ;
9497import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsPattern ;
98+ import static org .elasticsearch .xpack .esql .EsqlTestUtils .referenceAttribute ;
9599import static org .elasticsearch .xpack .esql .EsqlTestUtils .withDefaultLimitWarning ;
96100import static org .elasticsearch .xpack .esql .analysis .Analyzer .NO_FIELDS ;
97101import static org .elasticsearch .xpack .esql .analysis .AnalyzerTestUtils .analyze ;
@@ -3050,7 +3054,7 @@ public void testResolveRerankInferenceId() {
30503054
30513055 {
30523056 LogicalPlan plan = analyze (
3053- " FROM books METADATA _score | RERANK \" italian food recipe\" ON title WITH `reranking-inference-id`" ,
3057+ "FROM books METADATA _score | RERANK \" italian food recipe\" ON title WITH `reranking-inference-id`" ,
30543058 "mapping-books.json"
30553059 );
30563060 Rerank rerank = as (as (plan , Limit .class ).child (), Rerank .class );
@@ -3120,16 +3124,13 @@ public void testResolveRerankFields() {
31203124 Filter filter = as (drop .child (), Filter .class );
31213125 EsRelation relation = as (filter .child (), EsRelation .class );
31223126
3123- Attribute titleAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "title" )). findFirst (). get ( );
3124- assertThat (titleAttribute , notNullValue ());
3127+ Attribute titleAttribute = getAttributeByName ( relation .output (), "title" );
3128+ assertThat (getAttributeByName ( relation . output (), "title" ) , notNullValue ());
31253129
31263130 assertThat (rerank .queryText (), equalTo (string ("italian food recipe" )));
31273131 assertThat (rerank .inferenceId (), equalTo (string ("reranking-inference-id" )));
31283132 assertThat (rerank .rerankFields (), equalTo (List .of (alias ("title" , titleAttribute ))));
3129- assertThat (
3130- rerank .scoreAttribute (),
3131- equalTo (relation .output ().stream ().filter (attr -> attr .name ().equals (MetadataAttribute .SCORE )).findFirst ().get ())
3132- );
3133+ assertThat (rerank .scoreAttribute (), equalTo (getAttributeByName (relation .output (), MetadataAttribute .SCORE )));
31333134 }
31343135
31353136 {
@@ -3149,15 +3150,11 @@ public void testResolveRerankFields() {
31493150 assertThat (rerank .inferenceId (), equalTo (string ("reranking-inference-id" )));
31503151
31513152 assertThat (rerank .rerankFields (), hasSize (3 ));
3152- Attribute titleAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "title" )). findFirst (). get ( );
3153+ Attribute titleAttribute = getAttributeByName ( relation .output (), "title" );
31533154 assertThat (titleAttribute , notNullValue ());
31543155 assertThat (rerank .rerankFields ().get (0 ), equalTo (alias ("title" , titleAttribute )));
31553156
3156- Attribute descriptionAttribute = relation .output ()
3157- .stream ()
3158- .filter (attribute -> attribute .name ().equals ("description" ))
3159- .findFirst ()
3160- .get ();
3157+ Attribute descriptionAttribute = getAttributeByName (relation .output (), "description" );
31613158 assertThat (descriptionAttribute , notNullValue ());
31623159 Alias descriptionAlias = rerank .rerankFields ().get (1 );
31633160 assertThat (descriptionAlias .name (), equalTo ("description" ));
@@ -3166,13 +3163,11 @@ public void testResolveRerankFields() {
31663163 equalTo (List .of (descriptionAttribute , literal (0 ), literal (100 )))
31673164 );
31683165
3169- Attribute yearAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "year" )). findFirst (). get ( );
3166+ Attribute yearAttribute = getAttributeByName ( relation .output (), "year" );
31703167 assertThat (yearAttribute , notNullValue ());
31713168 assertThat (rerank .rerankFields ().get (2 ), equalTo (alias ("yearRenamed" , yearAttribute )));
3172- assertThat (
3173- rerank .scoreAttribute (),
3174- equalTo (relation .output ().stream ().filter (attr -> attr .name ().equals (MetadataAttribute .SCORE )).findFirst ().get ())
3175- );
3169+
3170+ assertThat (rerank .scoreAttribute (), equalTo (getAttributeByName (relation .output (), MetadataAttribute .SCORE )));
31763171 }
31773172
31783173 {
@@ -3204,11 +3199,7 @@ public void testResolveRerankScoreField() {
32043199 Filter filter = as (rerank .child (), Filter .class );
32053200 EsRelation relation = as (filter .child (), EsRelation .class );
32063201
3207- Attribute metadataScoreAttribute = relation .output ()
3208- .stream ()
3209- .filter (attr -> attr .name ().equals (MetadataAttribute .SCORE ))
3210- .findFirst ()
3211- .get ();
3202+ Attribute metadataScoreAttribute = getAttributeByName (relation .output (), MetadataAttribute .SCORE );
32123203 assertThat (rerank .scoreAttribute (), equalTo (metadataScoreAttribute ));
32133204 assertThat (rerank .output (), hasItem (metadataScoreAttribute ));
32143205 }
@@ -3232,6 +3223,116 @@ public void testResolveRerankScoreField() {
32323223 }
32333224 }
32343225
3226+ public void testResolveCompletionInferenceId () {
3227+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3228+
3229+ LogicalPlan plan = analyze ("""
3230+ FROM books METADATA _score
3231+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3232+ """ , "mapping-books.json" );
3233+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3234+ assertThat (completion .inferenceId (), equalTo (string ("completion-inference-id" )));
3235+ }
3236+
3237+ public void testResolveCompletionInferenceIdInvalidTaskType () {
3238+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3239+
3240+ assertError (
3241+ """
3242+ FROM books METADATA _score
3243+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `reranking-inference-id`
3244+ """ ,
3245+ "mapping-books.json" ,
3246+ new QueryParams (),
3247+ "cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command."
3248+ + " Only inference endpoints with the task type [completion] are supported"
3249+ );
3250+ }
3251+
3252+ public void testResolveCompletionInferenceMissingInferenceId () {
3253+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3254+
3255+ assertError ("""
3256+ FROM books METADATA _score
3257+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `unknown-inference-id`
3258+ """ , "mapping-books.json" , new QueryParams (), "unresolved inference [unknown-inference-id]" );
3259+ }
3260+
3261+ public void testResolveCompletionInferenceIdResolutionError () {
3262+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3263+
3264+ assertError ("""
3265+ FROM books METADATA _score
3266+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `error-inference-id`
3267+ """ , "mapping-books.json" , new QueryParams (), "error with inference resolution" );
3268+ }
3269+
3270+ public void testResolveCompletionTargetField () {
3271+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3272+
3273+ LogicalPlan plan = analyze ("""
3274+ FROM books METADATA _score
3275+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id` AS translation
3276+ """ , "mapping-books.json" );
3277+
3278+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3279+ assertThat (completion .targetField (), equalTo (referenceAttribute ("translation" , DataType .TEXT )));
3280+ }
3281+
3282+ public void testResolveCompletionDefaultTargetField () {
3283+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3284+
3285+ LogicalPlan plan = analyze ("""
3286+ FROM books METADATA _score
3287+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3288+ """ , "mapping-books.json" );
3289+
3290+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3291+ assertThat (completion .targetField (), equalTo (referenceAttribute ("completion" , DataType .TEXT )));
3292+ }
3293+
3294+ public void testResolveCompletionPrompt () {
3295+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3296+
3297+ LogicalPlan plan = analyze ("""
3298+ FROM books METADATA _score
3299+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3300+ """ , "mapping-books.json" );
3301+
3302+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3303+ EsRelation esRelation = as (completion .child (), EsRelation .class );
3304+
3305+ assertThat (
3306+ as (completion .prompt (), Concat .class ).children (),
3307+ equalTo (List .of (string ("Translate the following text in French\n " ), getAttributeByName (esRelation .output (), "description" )))
3308+ );
3309+ }
3310+
3311+ public void testResolveCompletionPromptInvalidType () {
3312+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3313+
3314+ assertError ("""
3315+ FROM books METADATA _score
3316+ | COMPLETION LENGTH(description) WITH `completion-inference-id`
3317+ """ , "mapping-books.json" , new QueryParams (), "prompt must be of type [text] but is [integer]" );
3318+ }
3319+
3320+ public void testResolveCompletionOutputField () {
3321+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3322+
3323+ LogicalPlan plan = analyze ("""
3324+ FROM books METADATA _score
3325+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id` AS description
3326+ """ , "mapping-books.json" );
3327+
3328+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3329+ assertThat (completion .targetField (), equalTo (referenceAttribute ("description" , DataType .TEXT )));
3330+
3331+ EsRelation esRelation = as (completion .child (), EsRelation .class );
3332+ assertThat (getAttributeByName (completion .output (), "description" ), equalTo (completion .targetField ()));
3333+ assertThat (getAttributeByName (esRelation .output (), "description" ), not (equalTo (completion .targetField ())));
3334+ }
3335+
32353336 @ Override
32363337 protected IndexAnalyzers createDefaultIndexAnalyzers () {
32373338 return super .createDefaultIndexAnalyzers ();
0 commit comments