diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchConnector.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchConnector.java index 98cf2a640f19..468fe2154070 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchConnector.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchConnector.java @@ -22,6 +22,7 @@ import io.trino.spi.connector.ConnectorSplitManager; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.SystemTable; +import io.trino.spi.ptf.ConnectorTableFunction; import io.trino.spi.transaction.IsolationLevel; import javax.inject.Inject; @@ -40,6 +41,7 @@ public class ElasticsearchConnector private final ElasticsearchSplitManager splitManager; private final ElasticsearchPageSourceProvider pageSourceProvider; private final NodesSystemTable nodesSystemTable; + private final Set connectorTableFunctions; @Inject public ElasticsearchConnector( @@ -47,13 +49,15 @@ public ElasticsearchConnector( ElasticsearchMetadata metadata, ElasticsearchSplitManager splitManager, ElasticsearchPageSourceProvider pageSourceProvider, - NodesSystemTable nodesSystemTable) + NodesSystemTable nodesSystemTable, + Set connectorTableFunctions) { this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); this.nodesSystemTable = requireNonNull(nodesSystemTable, "nodesSystemTable is null"); + this.connectorTableFunctions = ImmutableSet.copyOf(requireNonNull(connectorTableFunctions, "connectorTableFunctions is null")); } @Override @@ -87,6 +91,12 @@ public Set getSystemTables() return ImmutableSet.of(nodesSystemTable); } + @Override + public Set getTableFunctions() + { + return connectorTableFunctions; + } + @Override public final void shutdown() { diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchConnectorModule.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchConnectorModule.java index d8bf4b832cc7..78ea5ea90de6 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchConnectorModule.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchConnectorModule.java @@ -17,7 +17,10 @@ import com.google.inject.Scopes; import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.elasticsearch.client.ElasticsearchClient; +import io.trino.plugin.elasticsearch.ptf.RawQuery; +import io.trino.spi.ptf.ConnectorTableFunction; +import static com.google.inject.multibindings.Multibinder.newSetBinder; import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.airlift.configuration.ConditionalModule.conditionalModule; import static io.airlift.configuration.ConfigBinder.configBinder; @@ -46,6 +49,8 @@ protected void setup(Binder binder) newOptionalBinder(binder, AwsSecurityConfig.class); newOptionalBinder(binder, PasswordConfig.class); + newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(RawQuery.class).in(Scopes.SINGLETON); + install(conditionalModule( ElasticsearchConfig.class, config -> config.getSecurity() diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java index b32a0c81ee16..b5fb548fddef 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java @@ -42,25 +42,30 @@ import io.trino.plugin.elasticsearch.decoders.TinyintDecoder; import io.trino.plugin.elasticsearch.decoders.VarbinaryDecoder; import io.trino.plugin.elasticsearch.decoders.VarcharDecoder; +import io.trino.plugin.elasticsearch.ptf.RawQuery.RawQueryFunctionHandle; import io.trino.spi.TrinoException; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.connector.ColumnSchema; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.ConnectorTableProperties; +import io.trino.spi.connector.ConnectorTableSchema; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; +import io.trino.spi.connector.TableFunctionApplicationResult; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Constant; import io.trino.spi.expression.Variable; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.type.ArrayType; import io.trino.spi.type.RowType; import io.trino.spi.type.StandardTypes; @@ -164,6 +169,8 @@ public ElasticsearchTableHandle getTableHandle(ConnectorSession session, SchemaT Optional query = Optional.empty(); ElasticsearchTableHandle.Type type = SCAN; if (parts.length == 2) { + // TODO this query pass-through mechanism is deprecated in favor of the `raw_query` table function. + // it should be eventually removed: https://github.com/trinodb/trino/issues/13050 if (table.endsWith(PASSTHROUGH_QUERY_SUFFIX)) { table = table.substring(0, table.length() - PASSTHROUGH_QUERY_SUFFIX.length()); byte[] decoded; @@ -675,6 +682,24 @@ private static boolean isPassthroughQuery(ElasticsearchTableHandle table) return table.getType().equals(QUERY); } + @Override + public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + if (!(handle instanceof RawQueryFunctionHandle)) { + return Optional.empty(); + } + + ConnectorTableHandle tableHandle = ((RawQueryFunctionHandle) handle).getTableHandle(); + ConnectorTableSchema tableSchema = getTableSchema(session, tableHandle); + Map columnHandlesByName = getColumnHandles(session, tableHandle); + List columnHandles = tableSchema.getColumns().stream() + .map(ColumnSchema::getName) + .map(columnHandlesByName::get) + .collect(toImmutableList()); + + return Optional.of(new TableFunctionApplicationResult<>(tableHandle, columnHandles)); + } + private static class InternalTableMetadata { private final SchemaTableName tableName; diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ptf/RawQuery.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ptf/RawQuery.java new file mode 100644 index 000000000000..80547cc2aaf9 --- /dev/null +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ptf/RawQuery.java @@ -0,0 +1,144 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.elasticsearch.ptf; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.airlift.slice.Slice; +import io.trino.plugin.elasticsearch.ElasticsearchColumnHandle; +import io.trino.plugin.elasticsearch.ElasticsearchMetadata; +import io.trino.plugin.elasticsearch.ElasticsearchTableHandle; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ColumnSchema; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.connector.ConnectorTableSchema; +import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.ptf.AbstractConnectorTableFunction; +import io.trino.spi.ptf.Argument; +import io.trino.spi.ptf.ConnectorTableFunction; +import io.trino.spi.ptf.ConnectorTableFunctionHandle; +import io.trino.spi.ptf.Descriptor; +import io.trino.spi.ptf.ScalarArgument; +import io.trino.spi.ptf.ScalarArgumentSpecification; +import io.trino.spi.ptf.TableFunctionAnalysis; + +import javax.inject.Inject; +import javax.inject.Provider; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.elasticsearch.ElasticsearchTableHandle.Type.QUERY; +import static io.trino.spi.ptf.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; + +public class RawQuery + implements Provider +{ + public static final String SCHEMA_NAME = "system"; + public static final String NAME = "raw_query"; + + private final ElasticsearchMetadata metadata; + + @Inject + public RawQuery(ElasticsearchMetadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public ConnectorTableFunction get() + { + return new RawQueryFunction(metadata); + } + + public static class RawQueryFunction + extends AbstractConnectorTableFunction + { + private final ElasticsearchMetadata metadata; + + public RawQueryFunction(ElasticsearchMetadata metadata) + { + super( + SCHEMA_NAME, + NAME, + List.of( + ScalarArgumentSpecification.builder() + .name("SCHEMA") + .type(VARCHAR) + .build(), + ScalarArgumentSpecification.builder() + .name("INDEX") + .type(VARCHAR) + .build(), + ScalarArgumentSpecification.builder() + .name("QUERY") + .type(VARCHAR) + .build()), + GENERIC_TABLE); + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + String schema = ((Slice) ((ScalarArgument) arguments.get("SCHEMA")).getValue()).toStringUtf8(); + String index = ((Slice) ((ScalarArgument) arguments.get("INDEX")).getValue()).toStringUtf8(); + String query = ((Slice) ((ScalarArgument) arguments.get("QUERY")).getValue()).toStringUtf8(); + + ElasticsearchTableHandle tableHandle = new ElasticsearchTableHandle(QUERY, schema, index, Optional.of(query)); + ConnectorTableSchema tableSchema = metadata.getTableSchema(session, tableHandle); + Map columnsByName = metadata.getColumnHandles(session, tableHandle); + List columns = tableSchema.getColumns().stream() + .map(ColumnSchema::getName) + .map(columnsByName::get) + .collect(toImmutableList()); + + Descriptor returnedType = new Descriptor(columns.stream() + .map(ElasticsearchColumnHandle.class::cast) + .map(column -> new Descriptor.Field(column.getName(), Optional.of(column.getType()))) + .collect(toList())); + + RawQueryFunctionHandle handle = new RawQueryFunctionHandle(tableHandle); + + return TableFunctionAnalysis.builder() + .returnedType(returnedType) + .handle(handle) + .build(); + } + } + + public static class RawQueryFunctionHandle + implements ConnectorTableFunctionHandle + { + private final ElasticsearchTableHandle tableHandle; + + @JsonCreator + public RawQueryFunctionHandle(@JsonProperty("tableHandle") ElasticsearchTableHandle tableHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + } + + @JsonProperty + public ConnectorTableHandle getTableHandle() + { + return tableHandle; + } + } +} diff --git a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java index 05ef17e43ee1..d7b00953d0e7 100644 --- a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java +++ b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.io.BaseEncoding; import com.google.common.net.HostAndPort; +import io.trino.Session; import io.trino.spi.type.VarcharType; import io.trino.sql.planner.plan.LimitNode; import io.trino.testing.AbstractTestQueries; @@ -1803,6 +1804,21 @@ public void testPassthroughQuery() "FROM data", BaseEncoding.base32().encode(query.getBytes(UTF_8))), "VALUES (60000, 449872500)"); + // assert that query pass-through returns the same result as the raw_query table function + assertThat(query(format("WITH data(r) AS (" + + " SELECT CAST(json_parse(result) AS ROW(aggregations ROW(max_orderkey ROW(value BIGINT), sum_orderkey ROW(value BIGINT)))) " + + " FROM \"orders$query:%s\") " + + "SELECT r.aggregations.max_orderkey.value, r.aggregations.sum_orderkey.value " + + "FROM data", BaseEncoding.base32().encode(query.getBytes(UTF_8))))) + .matches(format("WITH data(r) AS (" + + " SELECT CAST(json_parse(result) AS ROW(aggregations ROW(max_orderkey ROW(value BIGINT), sum_orderkey ROW(value BIGINT)))) " + + " FROM TABLE(elasticsearch.system.raw_query(" + + " schema => 'tpch', " + + " index => 'orders', " + + " query => '%s'))) " + + "SELECT r.aggregations.max_orderkey.value, r.aggregations.sum_orderkey.value " + + "FROM data", query)); + assertQueryFails( "SELECT * FROM \"orders$query:invalid-base32-encoding\"", "Elasticsearch query for 'orders' is not base32-encoded correctly"); @@ -1860,6 +1876,53 @@ public void testMissingIndex() assertTableDoesNotExist("nonexistent_table"); } + @Test + public void testQueryTableFunction() + { + // select single record + assertQuery("SELECT json_query(result, 'lax $[0][0].hits.hits._source') " + + "FROM TABLE(elasticsearch.system.raw_query(" + + "schema => 'tpch', " + + "index => 'nation', " + + "query => '{\"query\": {\"match\": {\"name\": \"ALGERIA\"}}}')) t(result)", + "VALUES '{\"nationkey\":0,\"name\":\"ALGERIA\",\"regionkey\":0,\"comment\":\" haggle. carefully final deposits detect slyly agai\"}'"); + + // parameters + Session session = Session.builder(getSession()) + .addPreparedStatement( + "my_query", + "SELECT json_query(result, 'lax $[0][0].hits.hits._source') FROM TABLE(elasticsearch.system.raw_query(schema => ?, index => ?, query => ?))") + .build(); + assertQuery( + session, + "EXECUTE my_query USING 'tpch', 'nation', '{\"query\": {\"match\": {\"name\": \"ALGERIA\"}}}'", + "VALUES '{\"nationkey\":0,\"name\":\"ALGERIA\",\"regionkey\":0,\"comment\":\" haggle. carefully final deposits detect slyly agai\"}'"); + + // select multiple records by range. Use array wrapper to wrap multiple results + assertQuery("SELECT array_sort(CAST(json_parse(json_query(result, 'lax $[0][0].hits.hits._source.name' WITH ARRAY WRAPPER)) AS array(varchar))) " + + "FROM TABLE(elasticsearch.system.raw_query(" + + "schema => 'tpch', " + + "index => 'nation', " + + "query => '{\"query\": {\"range\": {\"nationkey\": {\"gte\": 0,\"lte\": 3}}}}')) t(result)", + "VALUES ARRAY['ALGERIA', 'ARGENTINA', 'BRAZIL', 'CANADA']"); + + // no matches + assertQuery("SELECT json_query(result, 'lax $[0][0].hits.hits') " + + "FROM TABLE(elasticsearch.system.raw_query(" + + "schema => 'tpch', " + + "index => 'nation', " + + "query => '{\"query\": {\"match\": {\"name\": \"UTOPIA\"}}}')) t(result)", + "VALUES '[]'"); + + // syntax error + assertThatThrownBy(() -> query("SELECT * " + + "FROM TABLE(elasticsearch.system.raw_query(" + + "schema => 'tpch', " + + "index => 'nation', " + + "query => 'wrong syntax')) t(result)")) + .hasMessageContaining("json_parse_exception"); + } + protected void assertTableDoesNotExist(String name) { assertQueryReturnsEmptyResult(format("SELECT * FROM information_schema.columns WHERE table_name = '%s'", name));