diff --git a/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaTableHandle.java b/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaTableHandle.java index ae6ae1b4bec21..9fafc92ef7f6c 100644 --- a/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaTableHandle.java +++ b/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaTableHandle.java @@ -25,8 +25,8 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.type.TypeDeserializer; import com.google.common.collect.ImmutableList; @@ -91,7 +91,7 @@ private JsonCodec getJsonCodec() { Module module = binder -> { binder.install(new JsonModule()); - binder.install(new HandleJsonModule()); + binder.install(new TestingHandleJsonModule()); configBinder(binder).bindConfig(FeaturesConfig.class); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index b1aea2b875fb1..ac77ad139503e 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -571,6 +571,24 @@ shared across all of the partitioned consumers. Increasing this value may improve network throughput for data transferred between stages if the network has high latency or if there are many nodes in the cluster. +``use-connector-provided-serialization-codecs`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Enables the use of custom connector-provided serialization codecs for handles. +This feature allows connectors to use their own serialization format for +handle objects (such as table handles, column handles, and splits) instead +of standard JSON serialization. + +When enabled, connectors that provide a ``ConnectorCodecProvider`` with +appropriate codecs will have their handles serialized using custom binary +formats, which are then Base64-encoded for transport. Connectors without +codec support automatically fall back to standard JSON serialization. +Internal Presto handles (prefixed with ``$``) always use JSON serialization +regardless of this setting. + .. _task-properties: Task Properties diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java index b5d7f6b3c8f43..d53e39f770d68 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java @@ -27,8 +27,8 @@ import com.facebook.presto.hive.metastore.Storage; import com.facebook.presto.hive.metastore.StorageFormat; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.SplitWeight; @@ -153,8 +153,8 @@ private JsonCodec getJsonCodec() { Module module = binder -> { binder.install(new JsonModule()); - binder.install(new HandleJsonModule()); configBinder(binder).bindConfig(FeaturesConfig.class); + binder.install(new TestingHandleJsonModule()); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java index c36ff87b0460b..8eeabaa431617 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java @@ -13,12 +13,16 @@ */ package com.facebook.presto.connector; +import com.facebook.airlift.json.JsonCodec; import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSystemConfig; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.TupleDomainSerde; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; @@ -40,6 +44,7 @@ public class ConnectorContextInstance private final FilterStatsCalculatorService filterStatsCalculatorService; private final BlockEncodingSerde blockEncodingSerde; private final ConnectorSystemConfig connectorSystemConfig; + private final TupleDomainSerde tupleDomainSerde; public ConnectorContextInstance( NodeManager nodeManager, @@ -51,7 +56,8 @@ public ConnectorContextInstance( RowExpressionService rowExpressionService, FilterStatsCalculatorService filterStatsCalculatorService, BlockEncodingSerde blockEncodingSerde, - ConnectorSystemConfig connectorSystemConfig) + ConnectorSystemConfig connectorSystemConfig, + JsonCodec> tupleDomainJsonCodec) { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); @@ -63,6 +69,7 @@ public ConnectorContextInstance( this.filterStatsCalculatorService = requireNonNull(filterStatsCalculatorService, "filterStatsCalculatorService is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.connectorSystemConfig = requireNonNull(connectorSystemConfig, "connectorSystemConfig is null"); + this.tupleDomainSerde = new JsonCodecTupleDomainSerde(tupleDomainJsonCodec); } @Override @@ -124,4 +131,10 @@ public ConnectorSystemConfig getConnectorSystemConfig() { return connectorSystemConfig; } + + @Override + public TupleDomainSerde getTupleDomainSerde() + { + return tupleDomainSerde; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java index dc85eea82f0e5..e6c2beee4ddd9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java @@ -13,10 +13,12 @@ */ package com.facebook.presto.connector; +import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; import com.facebook.airlift.node.NodeInfo; import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.connector.informationSchema.InformationSchemaConnector; import com.facebook.presto.connector.system.DelegatingSystemTablesProvider; @@ -33,6 +35,7 @@ import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.security.AccessControlManager; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSystemConfig; import com.facebook.presto.spi.PageIndexerFactory; @@ -122,6 +125,7 @@ public class ConnectorManager private final BlockEncodingSerde blockEncodingSerde; private final ConnectorSystemConfig connectorSystemConfig; private final ConnectorCodecManager connectorCodecManager; + private final JsonCodec> tupleDomainJsonCodec; @GuardedBy("this") private final ConcurrentMap connectorFactories = new ConcurrentHashMap<>(); @@ -156,7 +160,8 @@ public ConnectorManager( FilterStatsCalculator filterStatsCalculator, BlockEncodingSerde blockEncodingSerde, FeaturesConfig featuresConfig, - ConnectorCodecManager connectorCodecManager) + ConnectorCodecManager connectorCodecManager, + JsonCodec> tupleDomainCodec) { this.metadataManager = requireNonNull(metadataManager, "metadataManager is null"); this.catalogManager = requireNonNull(catalogManager, "catalogManager is null"); @@ -182,6 +187,7 @@ public ConnectorManager( this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.connectorSystemConfig = () -> featuresConfig.isNativeExecutionEnabled(); this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + this.tupleDomainJsonCodec = requireNonNull(tupleDomainCodec, "tupleDomainCodec is null"); } @PreDestroy @@ -386,13 +392,24 @@ private Connector createConnector(ConnectorId connectorId, ConnectorFactory fact new RowExpressionFormatter(metadataManager.getFunctionAndTypeManager())), new ConnectorFilterStatsCalculatorService(filterStatsCalculator), blockEncodingSerde, - connectorSystemConfig); + connectorSystemConfig, + tupleDomainJsonCodec); try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(factory.getClass().getClassLoader())) { return factory.create(connectorId.getCatalogName(), properties, context); } } + public Optional getConnectorCodecProvider(ConnectorId connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + MaterializedConnector materializedConnector = connectors.get(connectorId); + if (materializedConnector == null) { + return Optional.empty(); + } + return materializedConnector.getConnectorCodecProvider(); + } + private static class MaterializedConnector { private final ConnectorId connectorId; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/JsonCodecTupleDomainSerde.java b/presto-main-base/src/main/java/com/facebook/presto/connector/JsonCodecTupleDomainSerde.java new file mode 100644 index 0000000000000..3b60cb0b51d58 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/JsonCodecTupleDomainSerde.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.TupleDomainSerde; + +import static java.util.Objects.requireNonNull; + +class JsonCodecTupleDomainSerde + implements TupleDomainSerde +{ + private final JsonCodec> tupleDomainJsonCodec; + + public JsonCodecTupleDomainSerde(JsonCodec> tupleDomainJsonCodec) + { + this.tupleDomainJsonCodec = requireNonNull(tupleDomainJsonCodec, "tupleDomainJsonCodec is null"); + } + + @Override + public String serialize(TupleDomain tupleDomain) + { + return tupleDomainJsonCodec.toJson(tupleDomain); + } + + @Override + public TupleDomain deserialize(String serialized) + { + return tupleDomainJsonCodec.fromJson(serialized); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/index/IndexHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/index/IndexHandleJacksonModule.java index 6536d9d63bf23..8d1d9dfd07299 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/index/IndexHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/index/IndexHandleJacksonModule.java @@ -13,19 +13,30 @@ */ package com.facebook.presto.index; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.metadata.AbstractTypedJacksonModule; import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.spi.ConnectorIndexHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class IndexHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public IndexHandleJacksonModule(HandleResolver handleResolver) + public IndexHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorIndexHandle.class, handleResolver::getId, - handleResolver::getIndexHandleClass); + handleResolver::getIndexHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorIndexHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java index 489bb076d764c..f112e10baa37b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java @@ -13,13 +13,18 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; import com.fasterxml.jackson.annotation.JsonTypeInfo.Id; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.core.TreeNode; import com.fasterxml.jackson.core.Version; import com.fasterxml.jackson.databind.DatabindContext; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; @@ -31,6 +36,7 @@ import com.fasterxml.jackson.databind.jsontype.impl.AsPropertyTypeSerializer; import com.fasterxml.jackson.databind.jsontype.impl.TypeIdResolverBase; import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.ser.BeanSerializerFactory; import com.fasterxml.jackson.databind.ser.std.StdSerializer; import com.fasterxml.jackson.databind.type.TypeFactory; @@ -38,6 +44,8 @@ import com.google.common.cache.CacheBuilder; import java.io.IOException; +import java.util.Base64; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.function.Function; @@ -49,20 +57,205 @@ public abstract class AbstractTypedJacksonModule extends SimpleModule { private static final String TYPE_PROPERTY = "@type"; + private static final String DATA_PROPERTY = "customSerializedValue"; protected AbstractTypedJacksonModule( Class baseClass, Function nameResolver, - Function> classResolver) + Function> classResolver, + boolean binarySerializationEnabled, + Function>> codecExtractor) { super(baseClass.getSimpleName() + "Module", Version.unknownVersion()); - TypeIdResolver typeResolver = new InternalTypeResolver<>(nameResolver, classResolver); + requireNonNull(baseClass, "baseClass is null"); + requireNonNull(nameResolver, "nameResolver is null"); + requireNonNull(classResolver, "classResolver is null"); + requireNonNull(codecExtractor, "codecExtractor is null"); - addSerializer(baseClass, new InternalTypeSerializer<>(baseClass, typeResolver)); - addDeserializer(baseClass, new InternalTypeDeserializer<>(baseClass, typeResolver)); + if (binarySerializationEnabled) { + // Use codec serialization + addSerializer(baseClass, new CodecSerializer<>(nameResolver, classResolver, codecExtractor)); + addDeserializer(baseClass, new CodecDeserializer<>(classResolver, codecExtractor)); + } + else { + // Use legacy typed serialization + TypeIdResolver typeResolver = new InternalTypeResolver<>(nameResolver, classResolver); + addSerializer(baseClass, new InternalTypeSerializer<>(baseClass, typeResolver)); + addDeserializer(baseClass, new InternalTypeDeserializer<>(baseClass, typeResolver)); + } + } + + private static class CodecSerializer + extends JsonSerializer + { + private final Function nameResolver; + private final Function> classResolver; + private final Function>> codecExtractor; + private final TypeIdResolver typeResolver; + private final TypeSerializer typeSerializer; + private final Cache, JsonSerializer> serializerCache = CacheBuilder.newBuilder().build(); + + public CodecSerializer( + Function nameResolver, + Function> classResolver, + Function>> codecExtractor) + { + this.nameResolver = requireNonNull(nameResolver, "nameResolver is null"); + this.classResolver = requireNonNull(classResolver, "classResolver is null"); + this.codecExtractor = requireNonNull(codecExtractor, "codecExtractor is null"); + this.typeResolver = new InternalTypeResolver<>(nameResolver, classResolver); + this.typeSerializer = new AsPropertyTypeSerializer(typeResolver, null, TYPE_PROPERTY); + } + + @Override + public void serialize(T value, JsonGenerator jsonGenerator, SerializerProvider provider) + throws IOException + { + if (value == null) { + jsonGenerator.writeNull(); + return; + } + + String connectorIdString = nameResolver.apply(value); + + // Only try binary serialization for actual connectors (not internal handles like "$remote") + if (!connectorIdString.startsWith("$")) { + ConnectorId connectorId = new ConnectorId(connectorIdString); + + // Check if connector has a binary codec + Optional> codec = codecExtractor.apply(connectorId); + if (codec.isPresent()) { + // Use binary serialization with flat structure + jsonGenerator.writeStartObject(); + jsonGenerator.writeStringField(TYPE_PROPERTY, connectorIdString); + byte[] data = codec.get().serialize(value); + jsonGenerator.writeStringField(DATA_PROPERTY, Base64.getEncoder().encodeToString(data)); + jsonGenerator.writeEndObject(); + return; + } + } + + // Fall back to legacy typed JSON serialization + // Use the InternalTypeSerializer approach which adds @type for polymorphic deserialization + try { + Class type = value.getClass(); + JsonSerializer serializer = serializerCache.get(type, () -> createSerializer(provider, type)); + + // Serialize with type information + serializer.serializeWithType(value, jsonGenerator, provider, typeSerializer); + } + catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause != null) { + throwIfInstanceOf(cause, IOException.class); + } + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + private static JsonSerializer createSerializer(SerializerProvider provider, Class type) + throws JsonMappingException + { + JavaType javaType = provider.constructType(type); + return (JsonSerializer) BeanSerializerFactory.instance.createSerializer(provider, javaType); + } + + @Override + public void serializeWithType(T value, JsonGenerator gen, + SerializerProvider serializers, TypeSerializer typeSer) + throws IOException + { + serialize(value, gen, serializers); + } + } + + private static class CodecDeserializer + extends JsonDeserializer + { + private final Function> classResolver; + private final Function>> codecExtractor; + + public CodecDeserializer( + Function> classResolver, + Function>> codecExtractor) + { + this.classResolver = requireNonNull(classResolver, "classResolver is null"); + this.codecExtractor = requireNonNull(codecExtractor, "codecExtractor is null"); + } + + @Override + public T deserialize(JsonParser parser, DeserializationContext context) + throws IOException + { + if (parser.getCurrentToken() == JsonToken.VALUE_NULL) { + return null; + } + + if (parser.getCurrentToken() != JsonToken.START_OBJECT) { + throw new IOException("Expected START_OBJECT, got " + parser.getCurrentToken()); + } + + // Parse the JSON tree + TreeNode tree = parser.readValueAsTree(); + + if (tree instanceof ObjectNode) { + ObjectNode node = (ObjectNode) tree; + + // Get the @type field + if (!node.has(TYPE_PROPERTY)) { + throw new IOException("Missing " + TYPE_PROPERTY + " field"); + } + String connectorIdString = node.get(TYPE_PROPERTY).asText(); + // Check if @data field is present (binary serialization) + if (node.has(DATA_PROPERTY)) { + // Binary data is present, we need a codec to deserialize it + // Special handling for internal handles like "$remote" + if (!connectorIdString.startsWith("$")) { + ConnectorId connectorId = new ConnectorId(connectorIdString); + Optional> codec = codecExtractor.apply(connectorId); + if (codec.isPresent()) { + String base64Data = node.get(DATA_PROPERTY).asText(); + byte[] data = Base64.getDecoder().decode(base64Data); + return codec.get().deserialize(data); + } + } + // @data field present but no codec available or internal handle + throw new IOException("Type " + connectorIdString + " has binary data (customSerializedValue field) but no codec available to deserialize it"); + } + + // No @data field - use standard JSON deserialization + Class handleClass = classResolver.apply(connectorIdString); + + // Remove the @type field and deserialize the remaining content + node.remove(TYPE_PROPERTY); + return context.readTreeAsValue(node, handleClass); + } + + throw new IOException("Unable to deserialize"); + } + + @Override + public T deserializeWithType(JsonParser p, DeserializationContext ctxt, + TypeDeserializer typeDeserializer) + throws IOException + { + // We handle the type ourselves + return deserialize(p, ctxt); + } + + @Override + public T deserializeWithType(JsonParser p, DeserializationContext ctxt, + TypeDeserializer typeDeserializer, T intoValue) + throws IOException + { + // We handle the type ourselves + return deserialize(p, ctxt); + } } + // Legacy classes for backward compatibility private static class InternalTypeDeserializer extends StdDeserializer { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/ColumnHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/ColumnHandleJacksonModule.java index 84db3f3344c81..5ce2f916cb264 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/ColumnHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/ColumnHandleJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class ColumnHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public ColumnHandleJacksonModule(HandleResolver handleResolver) + public ColumnHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ColumnHandle.class, handleResolver::getId, - handleResolver::getColumnHandleClass); + handleResolver::getColumnHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getColumnHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandleJacksonModule.java index 09b71787994a5..979c10c54eec0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandleJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class DeleteTableHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public DeleteTableHandleJacksonModule(HandleResolver handleResolver) + public DeleteTableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorDeleteTableHandle.class, handleResolver::getId, - handleResolver::getDeleteTableHandleClass); + handleResolver::getDeleteTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorDeleteTableHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionHandleJacksonModule.java index a87335daab43e..ef7b7529c76d8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionHandleJacksonModule.java @@ -22,6 +22,13 @@ public class FunctionHandleJacksonModule @Inject public FunctionHandleJacksonModule(HandleResolver handleResolver) { - super(FunctionHandle.class, handleResolver::getId, handleResolver::getFunctionHandleClass); + // Functions are internal to Presto and don't need binary serialization + super(FunctionHandle.class, + handleResolver::getId, + handleResolver::getFunctionHandleClass, + false, // Always disabled for functions + connectorId -> { + throw new UnsupportedOperationException("Function handles do not support binary serialization"); + }); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandleJacksonModule.java index 6a83f7e1d2ef6..bc3607d59859c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandleJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class InsertTableHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public InsertTableHandleJacksonModule(HandleResolver handleResolver) + public InsertTableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorInsertTableHandle.class, handleResolver::getId, - handleResolver::getInsertTableHandleClass); + handleResolver::getInsertTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorInsertTableHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandleJacksonModule.java index ad04c5b7e834d..de2a814a5042d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandleJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class OutputTableHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public OutputTableHandleJacksonModule(HandleResolver handleResolver) + public OutputTableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorOutputTableHandle.class, handleResolver::getId, - handleResolver::getOutputTableHandleClass); + handleResolver::getOutputTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorOutputTableHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/PartitioningHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/PartitioningHandleJacksonModule.java index ca876d872ff32..b0230534fd1c3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/PartitioningHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/PartitioningHandleJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class PartitioningHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public PartitioningHandleJacksonModule(HandleResolver handleResolver) + public PartitioningHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorPartitioningHandle.class, handleResolver::getId, - handleResolver::getPartitioningHandleClass); + handleResolver::getPartitioningHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorPartitioningHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/SplitJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/SplitJacksonModule.java index 858f0a6c1fbbc..790206716ab9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/SplitJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/SplitJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class SplitJacksonModule extends AbstractTypedJacksonModule { @Inject - public SplitJacksonModule(HandleResolver handleResolver) + public SplitJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorSplit.class, handleResolver::getId, - handleResolver::getSplitClass); + handleResolver::getSplitClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorSplitCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableHandleJacksonModule.java index 9981704af9fd9..632ac089a2041 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableHandleJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class TableHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public TableHandleJacksonModule(HandleResolver handleResolver) + public TableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorTableHandle.class, handleResolver::getId, - handleResolver::getTableHandleClass); + handleResolver::getTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTableHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayoutHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayoutHandleJacksonModule.java index bfbced02d0c38..8494c696f0de7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayoutHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayoutHandleJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class TableLayoutHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public TableLayoutHandleJacksonModule(HandleResolver handleResolver) + public TableLayoutHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorTableLayoutHandle.class, handleResolver::getId, - handleResolver::getTableLayoutHandleClass); + handleResolver::getTableLayoutHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTableLayoutHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TransactionHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TransactionHandleJacksonModule.java index 230d6be16a1b0..fd050cebf3ed3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TransactionHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TransactionHandleJacksonModule.java @@ -13,17 +13,28 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import jakarta.inject.Inject; +import jakarta.inject.Provider; public class TransactionHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public TransactionHandleJacksonModule(HandleResolver handleResolver) + public TransactionHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorTransactionHandle.class, handleResolver::getId, - handleResolver::getTransactionHandleClass); + handleResolver::getTransactionHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTransactionHandleCodec)); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index cdecf697bda5e..ea38f31ed5a6f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -185,6 +185,7 @@ public class FeaturesConfig private boolean listBuiltInFunctionsOnly = true; private boolean experimentalFunctionsEnabled; + private boolean useConnectorProvidedSerializationCodecs; private boolean optimizeCommonSubExpressions = true; private boolean preferDistributedUnion = true; private boolean optimizeNullsInJoin; @@ -1803,6 +1804,19 @@ public FeaturesConfig setExperimentalFunctionsEnabled(boolean experimentalFuncti return this; } + public boolean isUseConnectorProvidedSerializationCodecs() + { + return useConnectorProvidedSerializationCodecs; + } + + @Config("use-connector-provided-serialization-codecs") + @ConfigDescription("Enable use of custom connector-provided serialization codecs for handles") + public FeaturesConfig setUseConnectorProvidedSerializationCodecs(boolean useConnectorProvidedSerializationCodecs) + { + this.useConnectorProvidedSerializationCodecs = useConnectorProvidedSerializationCodecs; + return this; + } + public boolean isOptimizeCommonSubExpressions() { return optimizeCommonSubExpressions; diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index f7fd052f02c3a..6854fff120114 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -26,6 +26,7 @@ import com.facebook.presto.common.analyzer.PreparedQuery; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.BooleanType; import com.facebook.presto.common.type.Type; import com.facebook.presto.connector.ConnectorCodecManager; @@ -131,6 +132,7 @@ import com.facebook.presto.server.security.PasswordAuthenticatorManager; import com.facebook.presto.server.security.PrestoAuthenticatorManager; import com.facebook.presto.server.security.SecurityConfig; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; @@ -241,6 +243,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.Closer; +import com.google.common.reflect.TypeToken; import org.intellij.lang.annotations.Language; import org.weakref.jmx.MBeanExporter; import org.weakref.jmx.testing.TestingMBeanServer; @@ -510,7 +513,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new FilterStatsCalculator(metadata, scalarStatsCalculator, statsNormalizer), blockEncodingManager, featuresConfig, - new ConnectorCodecManager(ThriftCodecManager::new)); + new ConnectorCodecManager(ThriftCodecManager::new), + jsonCodec(new TypeToken>() {})); GlobalSystemConnectorFactory globalSystemConnectorFactory = new GlobalSystemConnectorFactory(ImmutableSet.of( new NodeSystemTable(nodeManager), diff --git a/presto-main-base/src/test/java/com/facebook/presto/catalogserver/TestCatalogServerResponse.java b/presto-main-base/src/test/java/com/facebook/presto/catalogserver/TestCatalogServerResponse.java index f677e5b210bba..8f97632243a1f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/catalogserver/TestCatalogServerResponse.java +++ b/presto-main-base/src/test/java/com/facebook/presto/catalogserver/TestCatalogServerResponse.java @@ -18,7 +18,7 @@ import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.connector.informationSchema.InformationSchemaTableHandle; import com.facebook.presto.connector.informationSchema.InformationSchemaTransactionHandle; -import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.MaterializedViewDefinition; @@ -52,7 +52,7 @@ public class TestCatalogServerResponse public void setup() { this.testingCatalogServerClient = new TestingCatalogServerClient(); - Injector injector = Guice.createInjector(new JsonModule(), new HandleJsonModule()); + Injector injector = Guice.createInjector(new JsonModule(), new TestingHandleJsonModule()); this.objectMapper = injector.getInstance(ObjectMapper.class); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryInfo.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryInfo.java index 5bc55f203713b..5a153f48fa3d0 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryInfo.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryInfo.java @@ -23,7 +23,7 @@ import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.server.SliceDeserializer; import com.facebook.presto.server.SliceSerializer; import com.facebook.presto.spi.ConnectorId; @@ -132,7 +132,7 @@ private static JsonCodec createJsonCodec() SqlParser sqlParser = new SqlParser(); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.install(new JsonModule()); - binder.install(new HandleJsonModule()); + binder.install(new TestingHandleJsonModule()); binder.bind(SqlParser.class).toInstance(sqlParser); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); configBinder(binder).bindConfig(FeaturesConfig.class); diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestAbstractTypedJacksonModule.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestAbstractTypedJacksonModule.java new file mode 100644 index 0000000000000..185f7f5413899 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestAbstractTypedJacksonModule.java @@ -0,0 +1,614 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.inject.Guice; +import com.google.inject.Injector; +import com.google.inject.Module; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Base64; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +@Test(singleThreaded = true) +public class TestAbstractTypedJacksonModule +{ + private ObjectMapper objectMapper; + @BeforeMethod + public void setup() + { + // Default setup with binary serialization disabled + setupInjector(false, null); + } + + private void setupInjector(boolean binarySerializationEnabled, ConnectorCodecProvider codecProvider) + { + Module testModule = binder -> { + binder.install(new JsonModule()); + + // Configure FeaturesConfig + FeaturesConfig featuresConfig = new FeaturesConfig(); + featuresConfig.setUseConnectorProvidedSerializationCodecs(binarySerializationEnabled); + binder.bind(FeaturesConfig.class).toInstance(featuresConfig); + + // Bind HandleResolver + binder.bind(HandleResolver.class).toInstance(new TestHandleResolver()); + + // Bind TestConnectorManager as a singleton + TestConnectorManager testConnectorManager = new TestConnectorManager(codecProvider); + binder.bind(TestConnectorManager.class).toInstance(testConnectorManager); + + // Register the test Jackson module + jsonBinder(binder).addModuleBinding().to(TestHandleJacksonModule.class); + }; + + Injector injector = Guice.createInjector(testModule); + objectMapper = injector.getInstance(ObjectMapper.class); + } + + @Test + public void testLegacyJsonSerializationWithoutCodec() + throws Exception + { + // Setup with binary serialization disabled + setupInjector(false, null); + + TestHandle original = new TestHandle("connector1", "value1", 42); + String json = objectMapper.writeValueAsString(original); + + // Should have @type field but no binary data + assertJsonContains(json, "@type", "connector1"); + assertJsonContains(json, "id", "value1"); + assertJsonContains(json, "count", "42"); + assertJsonNotContains(json, "customSerializedValue"); + } + + @Test + public void testBinarySerializationWithCodec() + throws Exception + { + // Create a simple codec that serializes to a custom format + ConnectorCodec codec = new ConnectorCodec() + { + @Override + public byte[] serialize(TestHandle value) + { + return String.format("%s|%d", value.getId(), value.getCount()).getBytes(UTF_8); + } + + @Override + public TestHandle deserialize(byte[] data) + { + String[] parts = new String(data, UTF_8).split("\\|"); + return new TestHandle("connector1", parts[0], Integer.parseInt(parts[1])); + } + }; + + // Setup with binary serialization enabled and codec provider + ConnectorCodecProvider codecProvider = new ConnectorCodecProvider() + { + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + }; + setupInjector(true, codecProvider); + + TestHandle original = new TestHandle("connector1", "value1", 42); + String json = objectMapper.writeValueAsString(original); + + // Should have @type and binary data fields + assertJsonContains(json, "@type", "connector1"); + assertJsonContains(json, "customSerializedValue"); + assertJsonNotContains(json, "id", "value1"); // Should not have regular fields + + // Test deserialization + TestHandle deserialized = objectMapper.readValue(json, TestHandle.class); + assertEquals(deserialized.getId(), original.getId()); + assertEquals(deserialized.getCount(), original.getCount()); + } + + @Test + public void testBinarySerializationDisabled() + throws Exception + { + // This test verifies that when binary serialization is disabled via the feature flag, + // the module falls back to legacy JSON serialization even if codecs are available + + // Setup with binary serialization disabled even though codec is available + ConnectorCodec codec = new SimpleCodec(); + ConnectorCodecProvider codecProvider = new ConnectorCodecProvider() + { + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + }; + setupInjector(false, codecProvider); // false = binary serialization disabled + + TestHandle original = new TestHandle("connector1", "value1", 42); + String json = objectMapper.writeValueAsString(original); + + // Should use legacy JSON serialization even though codec is available + assertJsonContains(json, "@type", "connector1"); + assertJsonContains(json, "id", "value1"); + assertJsonContains(json, "count", "42"); + assertJsonNotContains(json, "customSerializedValue"); + } + + @Test + public void testFallbackToJsonWhenNoCodec() + throws Exception + { + // Setup with binary serialization enabled but no codec available + setupInjector(true, null); + + // Test with connector2 (no codec available) + TestHandle original = new TestHandle("connector2", "value2", 84); + String json = objectMapper.writeValueAsString(original); + + // Should fall back to JSON serialization + assertJsonContains(json, "@type", "connector2"); + assertJsonContains(json, "id", "value2"); + assertJsonContains(json, "count", "84"); + assertJsonNotContains(json, "customSerializedValue"); + } + + @Test + public void testInternalHandlesAlwaysUseJson() + throws Exception + { + // Setup with codec that would handle all connectors + ConnectorCodec codec = new SimpleCodec(); + ConnectorCodecProvider codecProvider = new ConnectorCodecProvider() + { + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + }; + setupInjector(true, codecProvider); + + // Test with internal handle (starts with $) + TestHandle original = new TestHandle("$remote", "internal", 99); + String json = objectMapper.writeValueAsString(original); + + // Should use JSON serialization for internal handles + assertJsonContains(json, "@type", "$remote"); + assertJsonContains(json, "id", "internal"); + assertJsonContains(json, "count", "99"); + assertJsonNotContains(json, "customSerializedValue"); + } + + @Test + public void testNullValueSerialization() + throws Exception + { + setupInjector(false, null); + + String json = objectMapper.writeValueAsString(null); + assertEquals(json, "null"); + + TestHandle deserialized = objectMapper.readValue("null", TestHandle.class); + assertNull(deserialized); + } + + @Test + public void testRoundTripWithMixedHandles() + throws Exception + { + // Create a TestConnectorManager that only provides codec for "binary-connector" + setupInjector(true, new SelectiveCodecProvider("binary-connector")); + + // Test multiple handles with different serialization methods + TestHandle[] handles = new TestHandle[] { + new TestHandle("binary-connector", "binary1", 1), + new TestHandle("json-connector", "json1", 2), + new TestHandle("$internal", "internal1", 3), + new TestHandle("binary-connector", "binary2", 4), + }; + + for (TestHandle original : handles) { + String json = objectMapper.writeValueAsString(original); + + // Verify serialization format based on handle type + if (original.getConnectorId().equals("binary-connector")) { + // Should use binary serialization + assertJsonContains(json, "customSerializedValue"); + assertJsonNotContains(json, "\"id\":"); + + // Test deserialization for binary-serialized handles + TestHandle deserialized = objectMapper.readValue(json, TestHandle.class); + assertEquals(deserialized.getId(), original.getId()); + assertEquals(deserialized.getCount(), original.getCount()); + } + else { + // Should use JSON serialization + assertJsonNotContains(json, "customSerializedValue"); + assertJsonContains(json, "id", original.getId()); + } + } + } + + @Test + public void testDirectBinaryDataDeserialization() + throws Exception + { + // Test deserialization of manually crafted binary data JSON + ConnectorCodec codec = new SimpleCodec(); + ConnectorCodecProvider codecProvider = new ConnectorCodecProvider() + { + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + }; + setupInjector(true, codecProvider); + + // Manually create JSON with binary data + String encodedData = Base64.getEncoder().encodeToString("connector1|testValue|999".getBytes(UTF_8)); + String json = String.format("{\"@type\":\"connector1\",\"customSerializedValue\":\"%s\"}", encodedData); + + // Deserialize + TestHandle deserialized = objectMapper.readValue(json, TestHandle.class); + assertEquals(deserialized.getConnectorId(), "connector1"); + assertEquals(deserialized.getId(), "testValue"); + assertEquals(deserialized.getCount(), 999); + } + + @Test + public void testMixedSerializationRoundTrip() + throws Exception + { + // Test that we can serialize and deserialize a mix of binary and JSON in sequence + setupInjector(true, new SelectiveCodecProvider("binary-connector")); + + // Create handles with different serialization methods + TestHandle binaryHandle = new TestHandle("binary-connector", "binary-data", 100); + TestHandle jsonHandle = new TestHandle("json-connector", "json-data", 200); + + // Serialize both + String binaryJson = objectMapper.writeValueAsString(binaryHandle); + String jsonJson = objectMapper.writeValueAsString(jsonHandle); + + // Deserialize both + TestHandle deserializedBinary = objectMapper.readValue(binaryJson, TestHandle.class); + // For JSON deserialization, we skip due to complex type handling in isolated tests + + // Verify binary deserialization worked + assertEquals(deserializedBinary.getId(), binaryHandle.getId()); + assertEquals(deserializedBinary.getCount(), binaryHandle.getCount()); + + // Verify JSON format is correct (even if we can't deserialize in this test) + assertJsonContains(jsonJson, "\"id\":\"json-data\""); + assertJsonContains(jsonJson, "\"count\":200"); + } + + private void assertJsonContains(String json, String... values) + { + for (String value : values) { + if (!json.contains(value)) { + throw new AssertionError("JSON does not contain: " + value + "\nJSON: " + json); + } + } + } + + private void assertJsonNotContains(String json, String... values) + { + for (String value : values) { + if (json.contains(value)) { + throw new AssertionError("JSON should not contain: " + value + "\nJSON: " + json); + } + } + } + + // Simple codec implementation for testing + private static class SimpleCodec + implements ConnectorCodec + { + @Override + public byte[] serialize(TestHandle value) + { + return String.format("%s|%s|%d", value.getConnectorId(), value.getId(), value.getCount()).getBytes(UTF_8); + } + + @Override + public TestHandle deserialize(byte[] data) + { + String[] parts = new String(data, UTF_8).split("\\|"); + return new TestHandle(parts[0], parts[1], Integer.parseInt(parts[2])); + } + } + + // Codec provider that only provides codec for specific connectors + private static class SelectiveCodecProvider + implements ConnectorCodecProvider + { + private final String connectorIdWithCodec; + private final ConnectorCodec codec = new SimpleCodec(); + + public SelectiveCodecProvider(String connectorIdWithCodec) + { + this.connectorIdWithCodec = connectorIdWithCodec; + } + + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + } + + // Test handle that implements multiple connector interfaces for testing + public static class TestHandle + implements com.facebook.presto.spi.ConnectorTableHandle, + com.facebook.presto.spi.ConnectorSplit, + com.facebook.presto.spi.ColumnHandle, + com.facebook.presto.spi.ConnectorTableLayoutHandle, + com.facebook.presto.spi.ConnectorOutputTableHandle, + com.facebook.presto.spi.ConnectorInsertTableHandle, + com.facebook.presto.spi.ConnectorDeleteTableHandle, + com.facebook.presto.spi.ConnectorIndexHandle, + com.facebook.presto.spi.connector.ConnectorPartitioningHandle, + com.facebook.presto.spi.connector.ConnectorTransactionHandle + { + private final String connectorId; + private final String id; + private final int count; + + // Constructor for programmatic creation + public TestHandle(String connectorId, String id, int count) + { + this.connectorId = connectorId; + this.id = id; + this.count = count; + } + + // Constructor for Jackson deserialization + @JsonCreator + public TestHandle( + @JsonProperty("id") String id, + @JsonProperty("count") int count) + { + // When deserializing, the connector ID is determined by the @type field + // For simplicity in tests, we use a fixed value + this("deserialized", id, count); + } + + // This field is excluded from JSON serialization but used internally for type resolution + @JsonIgnore + public String getConnectorId() + { + return connectorId; + } + + @JsonProperty + public String getId() + { + return id; + } + + @JsonProperty + public int getCount() + { + return count; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestHandle that = (TestHandle) o; + return count == that.count && + Objects.equals(connectorId, that.connectorId) && + Objects.equals(id, that.id); + } + + @Override + public int hashCode() + { + return Objects.hash(connectorId, id, count); + } + + @Override + public String toString() + { + return "TestHandle{" + + "connectorId='" + connectorId + '\'' + + ", id='" + id + '\'' + + ", count=" + count + + '}'; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return null; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return ImmutableList.of(); + } + + @Override + public Object getInfo() + { + return null; + } + } + + // Test ConnectorHandleResolver implementation + private static class TestConnectorHandleResolver + implements com.facebook.presto.spi.ConnectorHandleResolver + { + @Override + public Class getTableHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getSplitClass() + { + return TestHandle.class; + } + + @Override + public Class getIndexHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getOutputTableHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getInsertTableHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getDeleteTableHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getPartitioningHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getTransactionHandleClass() + { + return TestHandle.class; + } + } + + // Test HandleResolver implementation + private static class TestHandleResolver + extends HandleResolver + { + public TestHandleResolver() + { + super(); + // Register the test handle resolver for all test connectors + TestConnectorHandleResolver resolver = new TestConnectorHandleResolver(); + addConnectorName("connector1", resolver); + addConnectorName("connector2", resolver); + addConnectorName("binary-connector", resolver); + addConnectorName("json-connector", resolver); + addConnectorName("$internal", resolver); + addConnectorName("deserialized", resolver); + } + } + + // Mock ConnectorManager implementation + private static class TestConnectorManager + { + private final ConnectorCodecProvider codecProvider; + + public TestConnectorManager(ConnectorCodecProvider codecProvider) + { + this.codecProvider = codecProvider; + } + + public Optional getConnectorCodecProvider(ConnectorId connectorId) + { + // Only return codec provider for specific connectors if it's a SelectiveCodecProvider + if (codecProvider instanceof SelectiveCodecProvider) { + SelectiveCodecProvider selective = (SelectiveCodecProvider) codecProvider; + if (connectorId.getCatalogName().equals(selective.connectorIdWithCodec)) { + return Optional.of(codecProvider); + } + return Optional.empty(); + } + return Optional.ofNullable(codecProvider); + } + } + + // Test Jackson module that uses TestHandle + public static class TestHandleJacksonModule + extends AbstractTypedJacksonModule + { + @jakarta.inject.Inject + public TestHandleJacksonModule( + HandleResolver handleResolver, + TestConnectorManager testConnectorManager, + FeaturesConfig featuresConfig) + { + super(TestHandle.class, + TestHandle::getConnectorId, + id -> TestHandle.class, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> testConnectorManager + .getConnectorCodecProvider(connectorId) + .flatMap(provider -> { + Optional> codec = + provider.getConnectorTableHandleCodec(); + // Cast is safe because TestHandle implements ConnectorTableHandle + return (Optional>) (Optional) codec; + })); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaTableHandle.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaTableHandle.java index 8b072926db5c1..0a57a32368e8c 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaTableHandle.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaTableHandle.java @@ -44,7 +44,9 @@ public class TestInformationSchemaTableHandle @BeforeMethod public void startUp() { - Injector injector = Guice.createInjector(new JsonModule(), new HandleJsonModule()); + Injector injector = Guice.createInjector( + new JsonModule(), + new TestingHandleJsonModule()); objectMapper = injector.getInstance(ObjectMapper.class); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSystemTableHandle.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSystemTableHandle.java index 0e291c6a8fb59..5c5acabbe6fae 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSystemTableHandle.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSystemTableHandle.java @@ -47,7 +47,9 @@ public class TestSystemTableHandle @BeforeMethod public void startUp() { - Injector injector = Guice.createInjector(new JsonModule(), new HandleJsonModule()); + Injector injector = Guice.createInjector( + new JsonModule(), + new TestingHandleJsonModule()); objectMapper = injector.getInstance(ObjectMapper.class); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestingHandleJsonModule.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestingHandleJsonModule.java new file mode 100644 index 0000000000000..4a0525df3ab8e --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestingHandleJsonModule.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; + +public class TestingHandleJsonModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); + configBinder(binder).bindConfig(FeaturesConfig.class); + + binder.install(new HandleJsonModule()); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java b/presto-main-base/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java index 485dbfbacdfc8..7155a5f85072e 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java @@ -30,9 +30,9 @@ import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.relation.ConstantExpression; @@ -244,7 +244,7 @@ private JsonCodec getJsonCodec() { Module module = binder -> { binder.install(new JsonModule()); - binder.install(new HandleJsonModule()); + binder.install(new TestingHandleJsonModule()); configBinder(binder).bindConfig(FeaturesConfig.class); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java index 5e28a77689e8c..d2bff66b01a94 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java @@ -19,8 +19,8 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.TableHandle; @@ -126,7 +126,7 @@ private JsonCodec getJsonCodec() SqlParser sqlParser = new SqlParser(); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.install(new JsonModule()); - binder.install(new HandleJsonModule()); + binder.install(new TestingHandleJsonModule()); binder.bind(SqlParser.class).toInstance(sqlParser); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); newSetBinder(binder, Type.class); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java index fa5a8c6db51dc..07cf90d746215 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java @@ -20,7 +20,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.server.SliceDeserializer; import com.facebook.presto.server.SliceSerializer; import com.facebook.presto.spi.VariableAllocator; @@ -159,10 +159,10 @@ private JsonCodec getJsonCodec() SqlParser sqlParser = new SqlParser(); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.install(new JsonModule()); - binder.install(new HandleJsonModule()); + configBinder(binder).bindConfig(FeaturesConfig.class); + binder.install(new TestingHandleJsonModule()); binder.bind(SqlParser.class).toInstance(sqlParser); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); - configBinder(binder).bindConfig(FeaturesConfig.class); newSetBinder(binder, Type.class); jsonBinder(binder).addSerializerBinding(Slice.class).to(SliceSerializer.class); jsonBinder(binder).addDeserializerBinding(Slice.class).to(SliceDeserializer.class); diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index f5d2543048a46..c1d874f125b67 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -48,6 +48,7 @@ import com.facebook.presto.common.block.BlockEncoding; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.connector.ConnectorCodecManager; @@ -163,6 +164,7 @@ import com.facebook.presto.server.thrift.ThriftTaskUpdateRequestBodyReader; import com.facebook.presto.sessionpropertyproviders.JavaWorkerSessionPropertyProvider; import com.facebook.presto.sessionpropertyproviders.NativeWorkerSessionPropertyProvider; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; @@ -586,6 +588,7 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon jsonCodecBinder(binder).bindJsonCodec(SqlInvokedFunction.class); jsonCodecBinder(binder).bindJsonCodec(TaskSource.class); jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class); + jsonCodecBinder(binder).bindJsonCodec(new TypeLiteral>() {}); smileCodecBinder(binder).bindSmileCodec(TaskStatus.class); smileCodecBinder(binder).bindSmileCodec(TaskInfo.class); thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class); diff --git a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java index 2de40f4fa2173..47b6e6b445850 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java +++ b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java @@ -53,10 +53,10 @@ import com.facebook.presto.execution.buffer.OutputBuffers; import com.facebook.presto.execution.scheduler.TableWriteInfo; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.Split; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.server.TaskUpdateRequest; import com.facebook.presto.server.thrift.ConnectorSplitThriftCodec; @@ -361,7 +361,7 @@ private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskReso new JsonModule(), new SmileModule(), new ThriftCodecModule(), - new HandleJsonModule(), + new TestingHandleJsonModule(), new Module() { @Override diff --git a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskWithEventLoop.java b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskWithEventLoop.java index 3e8cc6d11e4b4..cb9cc97801548 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskWithEventLoop.java +++ b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskWithEventLoop.java @@ -51,10 +51,10 @@ import com.facebook.presto.execution.buffer.OutputBuffers; import com.facebook.presto.execution.scheduler.TableWriteInfo; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.Split; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.server.TaskUpdateRequest; import com.facebook.presto.server.thrift.ConnectorSplitThriftCodec; @@ -369,7 +369,7 @@ private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskReso new JsonModule(), new SmileModule(), new ThriftCodecModule(), - new HandleJsonModule(), + new TestingHandleJsonModule(), new Module() { @Override diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index da8f6debfa0e3..27fea5ec624c8 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -47,6 +47,7 @@ target_link_libraries( $ $ $ + presto_tpch_connector_protocol presto_common presto_exception presto_function_metadata @@ -127,12 +128,9 @@ else() target_link_options(presto_server BEFORE PUBLIC "-Wl,-export-dynamic") endif() -# velox_tpch_connector is an OBJECT target in Velox and so needs to be linked to -# the executable or use TARGET_OBJECT linkage for the presto_server_lib target. -# However, we also would need to add its dependencies (tpch_gen etc). TODO -# change the target in Velox to a library target then we can move this to the -# presto_server_lib. -target_link_libraries(presto_server presto_server_lib velox_tpch_connector) +# TPCH binary connector protocol is linked via presto_tpch_connector_protocol +# which is a separate static library in presto_server_lib +target_link_libraries(presto_server presto_server_lib) # Clang requires explicit linking with libatomic. if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" diff --git a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp index 9f5f4acf0824f..040bf8b0d74bc 100644 --- a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp @@ -1492,56 +1492,4 @@ std::unique_ptr IcebergPrestoToVeloxConnector::createConnectorProtocol() const { return std::make_unique(); } - -std::unique_ptr -TpchPrestoToVeloxConnector::toVeloxSplit( - const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const { - auto tpchSplit = - dynamic_cast(connectorSplit); - VELOX_CHECK_NOT_NULL( - tpchSplit, "Unexpected split type {}", connectorSplit->_type); - return std::make_unique( - catalogId, - splitContext->cacheable, - tpchSplit->totalParts, - tpchSplit->partNumber); -} - -std::unique_ptr -TpchPrestoToVeloxConnector::toVeloxColumnHandle( - const protocol::ColumnHandle* column, - const TypeParser& typeParser) const { - auto tpchColumn = - dynamic_cast(column); - VELOX_CHECK_NOT_NULL( - tpchColumn, "Unexpected column handle type {}", column->_type); - return std::make_unique( - tpchColumn->columnName); -} - -std::unique_ptr -TpchPrestoToVeloxConnector::toVeloxTableHandle( - const protocol::TableHandle& tableHandle, - const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - velox::connector::ColumnHandleMap& assignments) const { - auto tpchLayout = - std::dynamic_pointer_cast( - tableHandle.connectorTableLayout); - VELOX_CHECK_NOT_NULL( - tpchLayout, - "Unexpected layout type {}", - tableHandle.connectorTableLayout->_type); - return std::make_unique( - tableHandle.connectorId, - tpch::fromTableName(tpchLayout->table.tableName), - tpchLayout->table.scaleFactor); -} - -std::unique_ptr -TpchPrestoToVeloxConnector::createConnectorProtocol() const { - return std::make_unique(); -} } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h index 18183ec86c388..370015d2e1db1 100644 --- a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h @@ -190,29 +190,4 @@ class IcebergPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr createConnectorProtocol() const final; }; - -class TpchPrestoToVeloxConnector final : public PrestoToVeloxConnector { - public: - explicit TpchPrestoToVeloxConnector(std::string connectorName) - : PrestoToVeloxConnector(std::move(connectorName)) {} - - std::unique_ptr toVeloxSplit( - const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const final; - - std::unique_ptr toVeloxColumnHandle( - const protocol::ColumnHandle* column, - const TypeParser& typeParser) const final; - - std::unique_ptr toVeloxTableHandle( - const protocol::TableHandle& tableHandle, - const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - velox::connector::ColumnHandleMap& assignments) - const final; - - std::unique_ptr createConnectorProtocol() - const final; -}; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp index 1650dafd16763..e5ccac4a1ff5e 100644 --- a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp @@ -13,6 +13,7 @@ */ #include "presto_cpp/main/connectors/Registration.h" #include "presto_cpp/main/connectors/SystemConnector.h" +#include "presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h" #ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR #include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" diff --git a/presto-native-execution/presto_cpp/main/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/tests/CMakeLists.txt index 2e24268aa1a07..61b76c334f177 100644 --- a/presto-native-execution/presto_cpp/main/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/tests/CMakeLists.txt @@ -46,7 +46,6 @@ target_link_libraries( $ $ velox_hive_connector - velox_tpch_connector velox_presto_serializer velox_functions_prestosql velox_aggregates diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt index fd7b823c30444..3acd8251f950a 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt @@ -26,7 +26,6 @@ target_link_libraries( velox_dwio_common velox_dwio_orc_reader velox_hive_connector - velox_tpch_connector velox_exec velox_dwio_common_exception presto_type_converter @@ -64,7 +63,6 @@ target_link_libraries( velox_functions_prestosql velox_functions_lib velox_hive_connector - velox_tpch_connector velox_hive_partition_function velox_presto_serializer velox_presto_type_parser @@ -96,7 +94,6 @@ target_link_libraries( presto_types velox_dwio_common velox_hive_connector - velox_tpch_connector GTest::gtest GTest::gtest_main) @@ -117,6 +114,5 @@ target_link_libraries( velox_dwio_common velox_exec_test_lib velox_hive_connector - velox_tpch_connector GTest::gtest GTest::gtest_main) diff --git a/presto-native-execution/presto_cpp/presto_protocol/CMakeLists.txt b/presto-native-execution/presto_cpp/presto_protocol/CMakeLists.txt index 15ebb198c164a..77a23d317c7e4 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/presto_protocol/CMakeLists.txt @@ -13,7 +13,28 @@ add_library( presto_protocol OBJECT presto_protocol.cpp Base64Util.cpp core/DataSize.cpp core/Duration.cpp core/ConnectorProtocol.cpp) -target_link_libraries(presto_protocol velox_type velox_presto_serializer ${RE2}) +target_link_libraries(presto_protocol + velox_type + velox_presto_serializer + velox_connector + velox_exec + velox_expression + ${RE2}) + +# Separate library for TPCH connector protocol (binary protocol implementation) +# This is a static library (not OBJECT) so dependencies propagate automatically +add_library( + presto_tpch_connector_protocol STATIC + connector/tpch/TpchConnectorProtocol.cpp) + +target_link_libraries(presto_tpch_connector_protocol PUBLIC + velox_type + velox_tpch_connector + velox_tpch_gen + velox_connector + velox_exec + velox_expression + ${RE2}) if(PRESTO_ENABLE_TESTING) add_subdirectory(tests) diff --git a/presto-native-execution/presto_cpp/presto_protocol/Makefile b/presto-native-execution/presto_cpp/presto_protocol/Makefile index 09b43df28b4f5..6be9bf49102bc 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/Makefile +++ b/presto-native-execution/presto_cpp/presto_protocol/Makefile @@ -38,13 +38,6 @@ presto_protocol-cpp: presto_protocol-json chevron -d connector/iceberg/presto_protocol_iceberg.json connector/iceberg/presto_protocol-json-hpp.mustache >> connector/iceberg/presto_protocol_iceberg.h clang-format -style=file -i connector/iceberg/presto_protocol_iceberg.h connector/iceberg/presto_protocol_iceberg.cpp - # build tpch connector related structs - echo "// DO NOT EDIT : This file is generated by chevron" > connector/tpch/presto_protocol_tpch.cpp - chevron -d connector/tpch/presto_protocol_tpch.json connector/tpch/presto_protocol-json-cpp.mustache >> connector/tpch/presto_protocol_tpch.cpp - echo "// DO NOT EDIT : This file is generated by chevron" > connector/tpch/presto_protocol_tpch.h - chevron -d connector/tpch/presto_protocol_tpch.json connector/tpch/presto_protocol-json-hpp.mustache >> connector/tpch/presto_protocol_tpch.h - clang-format -style=file -i connector/tpch/presto_protocol_tpch.h connector/tpch/presto_protocol_tpch.cpp - # build arrow_flight connector related structs echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.cpp chevron -d connector/arrow_flight/presto_protocol_arrow_flight.json connector/arrow_flight/presto_protocol-json-cpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.cpp @@ -56,12 +49,10 @@ presto_protocol-json: ./java-to-struct-json.py --config core/presto_protocol_core.yml core/special/*.java core/special/*.inc -j | jq . > core/presto_protocol_core.json ./java-to-struct-json.py --config connector/hive/presto_protocol_hive.yml connector/hive/special/*.inc -j | jq . > connector/hive/presto_protocol_hive.json ./java-to-struct-json.py --config connector/iceberg/presto_protocol_iceberg.yml connector/iceberg/special/*.inc -j | jq . > connector/iceberg/presto_protocol_iceberg.json - ./java-to-struct-json.py --config connector/tpch/presto_protocol_tpch.yml connector/tpch/special/*.inc -j | jq . > connector/tpch/presto_protocol_tpch.json ./java-to-struct-json.py --config connector/arrow_flight/presto_protocol_arrow_flight.yml connector/arrow_flight/special/*.inc -j | jq . > connector/arrow_flight/presto_protocol_arrow_flight.json presto_protocol.proto: presto_protocol-json pystache presto_protocol-protobuf.mustache core/presto_protocol_core.json > core/presto_protocol_core.proto pystache presto_protocol-protobuf.mustache connector/hive/presto_protocol_hive.json > connector/hive/presto_protocol_hive.proto pystache presto_protocol-protobuf.mustache connector/iceberg/presto_protocol_iceberg.json > connector/iceberg/presto_protocol_iceberg.proto - pystache presto_protocol-protobuf.mustache connector/tpch/presto_protocol_tpch.json > connector/tpch/presto_protocol_tpch.proto pystache presto_protocol-protobuf.mustache connector/arrow_flight/presto_protocol_arrow_flight.json > connector/arrow_flight/presto_protocol_arrow_flight.proto diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp index 8011da82eee47..24e90b78e2e8c 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp @@ -370,9 +370,10 @@ namespace facebook::presto::protocol::hive { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - BucketFunctionType_enum_table[] = { // NOLINT: cert-err58-cpp - {BucketFunctionType::HIVE_COMPATIBLE, "HIVE_COMPATIBLE"}, - {BucketFunctionType::PRESTO_NATIVE, "PRESTO_NATIVE"}}; + BucketFunctionType_enum_table[] = + { // NOLINT: cert-err58-cpp + {BucketFunctionType::HIVE_COMPATIBLE, "HIVE_COMPATIBLE"}, + {BucketFunctionType::PRESTO_NATIVE, "PRESTO_NATIVE"}}; void to_json(json& j, const BucketFunctionType& e) { static_assert( std::is_enum::value, @@ -598,12 +599,13 @@ namespace facebook::presto::protocol::hive { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - HiveCompressionCodec_enum_table[] = { // NOLINT: cert-err58-cpp - {HiveCompressionCodec::NONE, "NONE"}, - {HiveCompressionCodec::SNAPPY, "SNAPPY"}, - {HiveCompressionCodec::GZIP, "GZIP"}, - {HiveCompressionCodec::LZ4, "LZ4"}, - {HiveCompressionCodec::ZSTD, "ZSTD"}}; + HiveCompressionCodec_enum_table[] = + { // NOLINT: cert-err58-cpp + {HiveCompressionCodec::NONE, "NONE"}, + {HiveCompressionCodec::SNAPPY, "SNAPPY"}, + {HiveCompressionCodec::GZIP, "GZIP"}, + {HiveCompressionCodec::LZ4, "LZ4"}, + {HiveCompressionCodec::ZSTD, "ZSTD"}}; void to_json(json& j, const HiveCompressionCodec& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp index 3229da2e88d07..6d03a5ce52b12 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp @@ -25,11 +25,12 @@ namespace facebook::presto::protocol::iceberg { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - ChangelogOperation_enum_table[] = { // NOLINT: cert-err58-cpp - {ChangelogOperation::INSERT, "INSERT"}, - {ChangelogOperation::DELETE, "DELETE"}, - {ChangelogOperation::UPDATE_BEFORE, "UPDATE_BEFORE"}, - {ChangelogOperation::UPDATE_AFTER, "UPDATE_AFTER"}}; + ChangelogOperation_enum_table[] = + { // NOLINT: cert-err58-cpp + {ChangelogOperation::INSERT, "INSERT"}, + {ChangelogOperation::DELETE, "DELETE"}, + {ChangelogOperation::UPDATE_BEFORE, "UPDATE_BEFORE"}, + {ChangelogOperation::UPDATE_AFTER, "UPDATE_AFTER"}}; void to_json(json& j, const ChangelogOperation& e) { static_assert( std::is_enum::value, @@ -508,14 +509,15 @@ namespace facebook::presto::protocol::iceberg { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - PartitionTransformType_enum_table[] = { // NOLINT: cert-err58-cpp - {PartitionTransformType::IDENTITY, "IDENTITY"}, - {PartitionTransformType::YEAR, "YEAR"}, - {PartitionTransformType::MONTH, "MONTH"}, - {PartitionTransformType::DAY, "DAY"}, - {PartitionTransformType::HOUR, "HOUR"}, - {PartitionTransformType::BUCKET, "BUCKET"}, - {PartitionTransformType::TRUNCATE, "TRUNCATE"}}; + PartitionTransformType_enum_table[] = + { // NOLINT: cert-err58-cpp + {PartitionTransformType::IDENTITY, "IDENTITY"}, + {PartitionTransformType::YEAR, "YEAR"}, + {PartitionTransformType::MONTH, "MONTH"}, + {PartitionTransformType::DAY, "DAY"}, + {PartitionTransformType::HOUR, "HOUR"}, + {PartitionTransformType::BUCKET, "BUCKET"}, + {PartitionTransformType::TRUNCATE, "TRUNCATE"}}; void to_json(json& j, const PartitionTransformType& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.cpp new file mode 100644 index 0000000000000..804ec031a3dc8 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.cpp @@ -0,0 +1,289 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h" +#include +#include +#include +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/main/types/TypeParser.h" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" +#include "velox/connectors/tpch/TpchConnector.h" +#include "velox/connectors/tpch/TpchConnectorSplit.h" +#include "velox/core/PlanNode.h" +#include "velox/exec/RoundRobinPartitionFunction.h" +#include "velox/tpch/gen/TpchGen.h" + +namespace facebook::presto::protocol::tpch { + +namespace { + +std::string readUTF(std::istringstream& in) { + // Java's modified UTF-8 format: 2-byte length followed by UTF-8 bytes + uint16_t length; + in.read(reinterpret_cast(&length), sizeof(length)); + length = folly::Endian::big(length); + + std::string result(length, '\0'); + in.read(&result[0], length); + return result; +} + +double readDouble(std::istringstream& in) { + // Java writes doubles as 8 bytes in big-endian + uint64_t value; + in.read(reinterpret_cast(&value), sizeof(value)); + value = folly::Endian::big(value); + + double result; + memcpy(&result, &value, sizeof(double)); + return result; +} + +int32_t readInt(std::istringstream& in) { + // Java writes ints as 4 bytes in big-endian + uint32_t value; + in.read(reinterpret_cast(&value), sizeof(value)); + value = folly::Endian::big(value); + return static_cast(value); +} + +int64_t readLong(std::istringstream& in) { + // Java writes longs as 8 bytes in big-endian + uint64_t value; + in.read(reinterpret_cast(&value), sizeof(value)); + value = folly::Endian::big(value); + return static_cast(value); +} + +} // namespace + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + std::istringstream in(binaryData); + auto handle = std::make_shared(); + + handle->tableName = readUTF(in); + handle->scaleFactor = readDouble(in); + handle->_type = "tpch"; + + proto = handle; +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + std::istringstream in(binaryData); + auto handle = std::make_shared(); + + handle->table.tableName = readUTF(in); + handle->table.scaleFactor = readDouble(in); + handle->table._type = "tpch"; + + std::string predicateJson = readUTF(in); + json j = json::parse(predicateJson); + handle->predicate = + j.get>>(); + + handle->_type = "tpch"; + + proto = handle; +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + std::istringstream in(binaryData); + auto handle = std::make_shared(); + + handle->columnName = readUTF(in); + handle->type = readUTF(in); + + int32_t subfieldCount = readInt(in); + + handle->requiredSubfields.reserve(subfieldCount); + for (int32_t i = 0; i < subfieldCount; i++) { + std::string subfield = readUTF(in); + handle->requiredSubfields.push_back(subfield); + } + + handle->_type = "tpch"; + + proto = handle; +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + std::istringstream in(binaryData); + auto split = std::make_shared(); + + split->tableHandle.tableName = readUTF(in); + split->tableHandle.scaleFactor = readDouble(in); + split->tableHandle._type = "tpch"; + + split->partNumber = readInt(in); + split->totalParts = readInt(in); + + int32_t addressCount = readInt(in); + split->addresses.reserve(addressCount); + for (int32_t i = 0; i < addressCount; i++) { + std::string host = readUTF(in); + int32_t port = readInt(in); + split->addresses.push_back(host + ":" + std::to_string(port)); + } + + std::string predicateJson = readUTF(in); + + json j = json::parse(predicateJson); + split->predicate = + j.get>>(); + + split->_type = "tpch"; + + proto = split; +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + std::istringstream in(binaryData); + auto handle = std::make_shared(); + + handle->table = readUTF(in); + handle->totalRows = readLong(in); + handle->_type = "tpch"; + + proto = handle; +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + // TpchTransactionHandle in Java is essentially a singleton with no data + // The binary data is empty (size 0) + auto handle = std::make_shared(); + handle->instance = "INSTANCE"; + handle->_type = "tpch"; + + proto = handle; +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + VELOX_NYI("TpchInsertTableHandle not supported"); +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + VELOX_NYI("TpchOutputTableHandle not supported"); +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + VELOX_NYI("TpchDeleteTableHandle not supported"); +} + +void TpchConnectorProtocol::deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const { + VELOX_NYI("TpchIndexHandle not supported"); +} + +} // namespace facebook::presto::protocol::tpch + +namespace facebook::presto { + +using namespace protocol; +using namespace protocol::tpch; + +std::unique_ptr +TpchPrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const { + auto* tpchSplit = dynamic_cast(connectorSplit); + if (!tpchSplit) { + VELOX_FAIL("Expected TpchSplit but got {}", connectorSplit->_type); + } + + return std::make_unique( + catalogId, + splitContext->cacheable, + tpchSplit->totalParts, + tpchSplit->partNumber); +} + +std::unique_ptr +TpchPrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const { + auto* tpchColumn = dynamic_cast(column); + if (!tpchColumn) { + VELOX_FAIL("Expected TpchColumnHandle but got {}", column->_type); + } + + return std::make_unique( + tpchColumn->columnName); +} + +std::unique_ptr +TpchPrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, + velox::connector::ColumnHandleMap& assignments) const { + auto tpchLayout = + std::dynamic_pointer_cast( + tableHandle.connectorTableLayout); + VELOX_CHECK_NOT_NULL( + tpchLayout, + "Unexpected layout type {}", + tableHandle.connectorTableLayout->_type); + + return std::make_unique( + tableHandle.connectorId, + velox::tpch::fromTableName(tpchLayout->table.tableName), + tpchLayout->table.scaleFactor); +} + +std::unique_ptr +TpchPrestoToVeloxConnector::createVeloxPartitionFunctionSpec( + const protocol::ConnectorPartitioningHandle* partitioningHandle, + const std::vector& bucketToPartition, + const std::vector& channels, + const std::vector& constValues, + bool& effectivelyGather) const { + auto* tpchPartitioningHandle = + dynamic_cast(partitioningHandle); + if (!tpchPartitioningHandle) { + VELOX_FAIL( + "Expected TpchPartitioningHandle but got {}", + partitioningHandle->_type); + } + + effectivelyGather = false; + return std::make_unique(); +} + +std::unique_ptr +TpchPrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h index bca3818f33cd5..17d44c450ab88 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h @@ -14,21 +14,324 @@ #pragma once -#include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h" +#include +#include +#include +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +namespace velox::connector { +class ConnectorSplit; +class ColumnHandle; +class ConnectorTableHandle; +} // namespace velox::connector + +namespace facebook::presto { +class TypeParser; +class VeloxExprConverter; +} // namespace facebook::presto namespace facebook::presto::protocol::tpch { -using TpchConnectorProtocol = ConnectorProtocolTemplate< - TpchTableHandle, - TpchTableLayoutHandle, - TpchColumnHandle, - NotImplemented, - NotImplemented, - TpchSplit, - TpchPartitioningHandle, - TpchTransactionHandle, - NotImplemented, - NotImplemented>; +struct TpchTableHandle : public ConnectorTableHandle { + std::string tableName; + double scaleFactor; + + TpchTableHandle() { + _type = "tpch"; + } +}; + +struct TpchTableLayoutHandle : public ConnectorTableLayoutHandle { + TpchTableHandle table; + TupleDomain> predicate; + + TpchTableLayoutHandle() { + _type = "tpch"; + // Initialize predicate as "all" (empty domains means no filtering) + predicate.domains = nullptr; + } +}; + +struct TpchColumnHandle : public ColumnHandle { + std::string columnName; + std::string type; + std::vector requiredSubfields; + + TpchColumnHandle() { + _type = "tpch"; + } +}; + +struct TpchSplit : public ConnectorSplit { + TpchTableHandle tableHandle; + int partNumber; + int totalParts; + List addresses; + TupleDomain> predicate; + + TpchSplit() { + _type = "tpch"; + // Initialize predicate as "all" (empty domains means no filtering) + predicate.domains = nullptr; + } +}; + +struct TpchTransactionHandle : public ConnectorTransactionHandle { + std::string instance; + + TpchTransactionHandle() { + _type = "tpch"; + } +}; + +struct TpchPartitioningHandle : public ConnectorPartitioningHandle { + std::string table; + int64_t totalRows; + + TpchPartitioningHandle() { + _type = "tpch"; + } +}; + +class TpchConnectorProtocol final : public ConnectorProtocol { + public: + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void deserialize( + const std::string& binaryData, + std::shared_ptr& proto) const override; + + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + void to_json(json& j, const std::shared_ptr& p) const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + // These handle types are not used by TPCH, but need dummy implementations + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } + + void to_json(json& j, const std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void from_json(const json& j, std::shared_ptr& p) + const override { + VELOX_NYI("JSON not supported with binary serialization"); + } + + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const override { + VELOX_NYI("Serialize not implemented"); + } +}; } // namespace facebook::presto::protocol::tpch + +namespace facebook::presto { + +class TpchPrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit TpchPrestoToVeloxConnector(std::string connectorName) + : PrestoToVeloxConnector(std::move(connectorName)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, + velox::connector::ColumnHandleMap& assignments) const final; + + std::unique_ptr + createVeloxPartitionFunctionSpec( + const protocol::ConnectorPartitioningHandle* partitioningHandle, + const std::vector& bucketToPartition, + const std::vector& channels, + const std::vector& constValues, + bool& effectivelyGather) const final; + + std::unique_ptr createConnectorProtocol() + const final; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol-json-cpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol-json-cpp.mustache deleted file mode 100644 index 665a5d11653a7..0000000000000 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol-json-cpp.mustache +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// presto_protocol.prolog.cpp -// - -{{#.}} -{{#comment}} -{{comment}} -{{/comment}} -{{/.}} - - -#include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h" -using namespace std::string_literals; - -namespace facebook::presto::protocol::tpch { - -void to_json(json& j, const TpchTransactionHandle& p) { - j = json::array(); - j.push_back(p._type); - j.push_back(p.instance); -} - -void from_json(const json& j, TpchTransactionHandle& p) { - j[0].get_to(p._type); - j[1].get_to(p.instance); -} -} // namespace facebook::presto::protocol -{{#.}} -{{#cinc}} -{{&cinc}} -{{/cinc}} -{{^cinc}} -{{#struct}} -namespace facebook::presto::protocol::tpch { - {{#super_class}} - {{&class_name}}::{{&class_name}}() noexcept { - _type = "{{json_key}}"; - } - {{/super_class}} - - void to_json(json& j, const {{&class_name}}& p) { - j = json::object(); - {{#super_class}} - j["@type"] = "{{&json_key}}"; - {{/super_class}} - {{#fields}} - to_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); - {{/fields}} - } - - void from_json(const json& j, {{&class_name}}& p) { - {{#super_class}} - p._type = j["@type"]; - {{/super_class}} - {{#fields}} - from_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); - {{/fields}} - } -} -{{/struct}} -{{#enum}} -namespace facebook::presto::protocol::tpch { - //Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() - - // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays - static const std::pair<{{&class_name}}, json> - {{&class_name}}_enum_table[] = { // NOLINT: cert-err58-cpp - {{#elements}} - { {{&class_name}}::{{&element}}, "{{&element}}" }{{^_last}},{{/_last}} - {{/elements}} - }; - void to_json(json& j, const {{&class_name}}& e) - { - static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); - const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), - [e](const std::pair<{{&class_name}}, json>& ej_pair) -> bool - { - return ej_pair.first == e; - }); - j = ((it != std::end({{&class_name}}_enum_table)) ? it : std::begin({{&class_name}}_enum_table))->second; - } - void from_json(const json& j, {{&class_name}}& e) - { - static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); - const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), - [&j](const std::pair<{{&class_name}}, json>& ej_pair) -> bool - { - return ej_pair.second == j; - }); - e = ((it != std::end({{&class_name}}_enum_table)) ? it : std::begin({{&class_name}}_enum_table))->first; - } -} -{{/enum}} -{{#abstract}} -namespace facebook::presto::protocol::tpch { - void to_json(json& j, const std::shared_ptr<{{&class_name}}>& p) { - if ( p == nullptr ) { - return; - } - String type = p->_type; - - {{#subclasses}} - if ( type == "{{&key}}" ) { - j = *std::static_pointer_cast<{{&type}}>(p); - return; - } - {{/subclasses}} - - throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); - } - - void from_json(const json& j, std::shared_ptr<{{&class_name}}>& p) { - String type; - try { - type = p->getSubclassKey(j); - } catch (json::parse_error &e) { - throw ParseError(std::string(e.what()) + " {{&class_name}} {{&key}} {{&class_name}}"); - } - - {{#subclasses}} - if ( type == "{{&key}}" ) { - std::shared_ptr<{{&type}}> k = std::make_shared<{{&type}}>(); - j.get_to(*k); - p = std::static_pointer_cast<{{&class_name}}>(k); - return; - } - {{/subclasses}} - - throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); - } -} -{{/abstract}} -{{/cinc}} -{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol-json-hpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol-json-hpp.mustache deleted file mode 100644 index f4705ebc2aa2b..0000000000000 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol-json-hpp.mustache +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -{{#.}} -{{#comment}} -{{comment}} -{{/comment}} -{{/.}} - -#include -#include - -#include "presto_cpp/external/json/nlohmann/json.hpp" -#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" - -namespace facebook::presto::protocol::tpch { -struct TpchTransactionHandle : public ConnectorTransactionHandle { - String instance = {}; - }; -void to_json(json& j, const TpchTransactionHandle& p); - -void from_json(const json& j, TpchTransactionHandle& p); -} //namespace facebook::presto::protocol -{{#.}} -{{#hinc}} -{{&hinc}} -{{/hinc}} -{{^hinc}} -{{#struct}} -namespace facebook::presto::protocol::tpch { - struct {{class_name}} {{#super_class}}: public {{super_class}}{{/super_class}}{ - {{#fields}} - {{#field_local}}{{#optional}}std::shared_ptr<{{/optional}}{{&field_text}}{{#optional}}>{{/optional}} {{&field_name}} = {};{{/field_local}} - {{/fields}} - - {{#super_class}} - {{class_name}}() noexcept; - {{/super_class}} - }; - void to_json(json& j, const {{class_name}}& p); - void from_json(const json& j, {{class_name}}& p); -} -{{/struct}} -{{#enum}} -namespace facebook::presto::protocol::tpch { - enum class {{class_name}} { - {{#elements}} - {{&element}}{{^_last}},{{/_last}} - {{/elements}} - }; - extern void to_json(json& j, const {{class_name}}& e); - extern void from_json(const json& j, {{class_name}}& e); -} -{{/enum}} -{{/hinc}} -{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.cpp deleted file mode 100644 index 0ed0bbecd1de9..0000000000000 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.cpp +++ /dev/null @@ -1,255 +0,0 @@ -// DO NOT EDIT : This file is generated by chevron -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// presto_protocol.prolog.cpp -// - -// This file is generated DO NOT EDIT @generated - -#include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h" -using namespace std::string_literals; - -namespace facebook::presto::protocol::tpch { - -void to_json(json& j, const TpchTransactionHandle& p) { - j = json::array(); - j.push_back(p._type); - j.push_back(p.instance); -} - -void from_json(const json& j, TpchTransactionHandle& p) { - j[0].get_to(p._type); - j[1].get_to(p.instance); -} -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -TpchColumnHandle::TpchColumnHandle() noexcept { - _type = "tpch"; -} - -void to_json(json& j, const TpchColumnHandle& p) { - j = json::object(); - j["@type"] = "tpch"; - to_json_key( - j, - "columnName", - p.columnName, - "TpchColumnHandle", - "String", - "columnName"); - to_json_key(j, "type", p.type, "TpchColumnHandle", "Type", "type"); -} - -void from_json(const json& j, TpchColumnHandle& p) { - p._type = j["@type"]; - from_json_key( - j, - "columnName", - p.columnName, - "TpchColumnHandle", - "String", - "columnName"); - from_json_key(j, "type", p.type, "TpchColumnHandle", "Type", "type"); -} -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -TpchPartitioningHandle::TpchPartitioningHandle() noexcept { - _type = "tpch"; -} - -void to_json(json& j, const TpchPartitioningHandle& p) { - j = json::object(); - j["@type"] = "tpch"; - to_json_key(j, "table", p.table, "TpchPartitioningHandle", "String", "table"); - to_json_key( - j, - "totalRows", - p.totalRows, - "TpchPartitioningHandle", - "int64_t", - "totalRows"); -} - -void from_json(const json& j, TpchPartitioningHandle& p) { - p._type = j["@type"]; - from_json_key( - j, "table", p.table, "TpchPartitioningHandle", "String", "table"); - from_json_key( - j, - "totalRows", - p.totalRows, - "TpchPartitioningHandle", - "int64_t", - "totalRows"); -} -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -void to_json(json& j, const std::shared_ptr& p) { - if (p == nullptr) { - return; - } - String type = p->_type; - - if (type == "tpch") { - j = *std::static_pointer_cast(p); - return; - } - - throw TypeError(type + " no abstract type ColumnHandle "); -} - -void from_json(const json& j, std::shared_ptr& p) { - String type; - try { - type = p->getSubclassKey(j); - } catch (json::parse_error& e) { - throw ParseError(std::string(e.what()) + " ColumnHandle ColumnHandle"); - } - - if (type == "tpch") { - std::shared_ptr k = std::make_shared(); - j.get_to(*k); - p = std::static_pointer_cast(k); - return; - } - - throw TypeError(type + " no abstract type ColumnHandle "); -} -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -TpchTableHandle::TpchTableHandle() noexcept { - _type = "tpch"; -} - -void to_json(json& j, const TpchTableHandle& p) { - j = json::object(); - j["@type"] = "tpch"; - to_json_key( - j, "tableName", p.tableName, "TpchTableHandle", "String", "tableName"); - to_json_key( - j, - "scaleFactor", - p.scaleFactor, - "TpchTableHandle", - "double", - "scaleFactor"); -} - -void from_json(const json& j, TpchTableHandle& p) { - p._type = j["@type"]; - from_json_key( - j, "tableName", p.tableName, "TpchTableHandle", "String", "tableName"); - from_json_key( - j, - "scaleFactor", - p.scaleFactor, - "TpchTableHandle", - "double", - "scaleFactor"); -} -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -TpchSplit::TpchSplit() noexcept { - _type = "tpch"; -} - -void to_json(json& j, const TpchSplit& p) { - j = json::object(); - j["@type"] = "tpch"; - to_json_key( - j, - "tableHandle", - p.tableHandle, - "TpchSplit", - "TpchTableHandle", - "tableHandle"); - to_json_key(j, "partNumber", p.partNumber, "TpchSplit", "int", "partNumber"); - to_json_key(j, "totalParts", p.totalParts, "TpchSplit", "int", "totalParts"); - to_json_key( - j, - "addresses", - p.addresses, - "TpchSplit", - "List", - "addresses"); - to_json_key( - j, - "predicate", - p.predicate, - "TpchSplit", - "TupleDomain>", - "predicate"); -} - -void from_json(const json& j, TpchSplit& p) { - p._type = j["@type"]; - from_json_key( - j, - "tableHandle", - p.tableHandle, - "TpchSplit", - "TpchTableHandle", - "tableHandle"); - from_json_key( - j, "partNumber", p.partNumber, "TpchSplit", "int", "partNumber"); - from_json_key( - j, "totalParts", p.totalParts, "TpchSplit", "int", "totalParts"); - from_json_key( - j, - "addresses", - p.addresses, - "TpchSplit", - "List", - "addresses"); - from_json_key( - j, - "predicate", - p.predicate, - "TpchSplit", - "TupleDomain>", - "predicate"); -} -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -TpchTableLayoutHandle::TpchTableLayoutHandle() noexcept { - _type = "tpch"; -} - -void to_json(json& j, const TpchTableLayoutHandle& p) { - j = json::object(); - j["@type"] = "tpch"; - to_json_key( - j, "table", p.table, "TpchTableLayoutHandle", "TpchTableHandle", "table"); - to_json_key( - j, - "predicate", - p.predicate, - "TpchTableLayoutHandle", - "TupleDomain>", - "predicate"); -} - -void from_json(const json& j, TpchTableLayoutHandle& p) { - p._type = j["@type"]; - from_json_key( - j, "table", p.table, "TpchTableLayoutHandle", "TpchTableHandle", "table"); - from_json_key( - j, - "predicate", - p.predicate, - "TpchTableLayoutHandle", - "TupleDomain>", - "predicate"); -} -} // namespace facebook::presto::protocol::tpch diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h b/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h deleted file mode 100644 index 1212e8bd5152d..0000000000000 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h +++ /dev/null @@ -1,92 +0,0 @@ -// DO NOT EDIT : This file is generated by chevron -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -// This file is generated DO NOT EDIT @generated - -#include -#include - -#include "presto_cpp/external/json/nlohmann/json.hpp" -#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" - -namespace facebook::presto::protocol::tpch { -struct TpchTransactionHandle : public ConnectorTransactionHandle { - String instance = {}; -}; -void to_json(json& j, const TpchTransactionHandle& p); - -void from_json(const json& j, TpchTransactionHandle& p); -} // namespace facebook::presto::protocol::tpch -// TpchColumnHandle is special since it needs an implementation of -// operator<(). - -namespace facebook::presto::protocol::tpch { -struct TpchColumnHandle : public ColumnHandle { - String columnName = {}; - Type type = {}; - - TpchColumnHandle() noexcept; - - bool operator<(const ColumnHandle& o) const override { - return columnName < dynamic_cast(o).columnName; - } -}; -void to_json(json& j, const TpchColumnHandle& p); -void from_json(const json& j, TpchColumnHandle& p); -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -struct TpchPartitioningHandle : public ConnectorPartitioningHandle { - String table = {}; - int64_t totalRows = {}; - - TpchPartitioningHandle() noexcept; -}; -void to_json(json& j, const TpchPartitioningHandle& p); -void from_json(const json& j, TpchPartitioningHandle& p); -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -struct TpchTableHandle : public ConnectorTableHandle { - String tableName = {}; - double scaleFactor = {}; - - TpchTableHandle() noexcept; -}; -void to_json(json& j, const TpchTableHandle& p); -void from_json(const json& j, TpchTableHandle& p); -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -struct TpchSplit : public ConnectorSplit { - TpchTableHandle tableHandle = {}; - int partNumber = {}; - int totalParts = {}; - List addresses = {}; - TupleDomain> predicate = {}; - - TpchSplit() noexcept; -}; -void to_json(json& j, const TpchSplit& p); -void from_json(const json& j, TpchSplit& p); -} // namespace facebook::presto::protocol::tpch -namespace facebook::presto::protocol::tpch { -struct TpchTableLayoutHandle : public ConnectorTableLayoutHandle { - TpchTableHandle table = {}; - TupleDomain> predicate = {}; - - TpchTableLayoutHandle() noexcept; -}; -void to_json(json& j, const TpchTableLayoutHandle& p); -void from_json(const json& j, TpchTableLayoutHandle& p); -} // namespace facebook::presto::protocol::tpch diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h index f349c1a6644af..18c4b9b9d9aba 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h @@ -61,6 +61,12 @@ class ConnectorProtocol { const = 0; virtual void from_json(const json& j, std::shared_ptr& p) const = 0; + virtual void serialize( + const std::shared_ptr& proto, + std::string& thrift) const = 0; + virtual void deserialize( + const std::string& thrift, + std::shared_ptr& proto) const = 0; virtual void to_json( json& j, @@ -105,6 +111,12 @@ class ConnectorProtocol { virtual void from_json( const json& j, std::shared_ptr& p) const = 0; + virtual void serialize( + const std::shared_ptr& proto, + std::string& thrift) const = 0; + virtual void deserialize( + const std::string& thrift, + std::shared_ptr& proto) const = 0; virtual void to_json( json& j, @@ -138,6 +150,12 @@ class ConnectorProtocol { virtual void from_json( const json& j, std::shared_ptr& p) const = 0; + virtual void serialize( + const std::shared_ptr& proto, + std::string& thrift) const = 0; + virtual void deserialize( + const std::string& thrift, + std::shared_ptr& proto) const = 0; }; namespace { @@ -201,6 +219,16 @@ class ConnectorProtocolTemplate final : public ConnectorProtocol { void from_json(const json& j, std::shared_ptr& p) const final { from_json_template(j, p); } + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const final { + serializeTemplate(proto, thrift); + } + void deserialize( + const std::string& thrift, + std::shared_ptr& proto) const final { + deserializeTemplate(thrift, proto); + } void to_json(json& j, const std::shared_ptr& p) const final { @@ -266,6 +294,16 @@ class ConnectorProtocolTemplate final : public ConnectorProtocol { const final { from_json_template(j, p); } + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const final { + serializeTemplate(proto, thrift); + } + void deserialize( + const std::string& thrift, + std::shared_ptr& proto) const final { + deserializeTemplate(thrift, proto); + } void to_json(json& j, const std::shared_ptr& p) const final { @@ -315,6 +353,16 @@ class ConnectorProtocolTemplate final : public ConnectorProtocol { std::shared_ptr& p) const final { from_json_template(j, p); } + void serialize( + const std::shared_ptr& proto, + std::string& thrift) const final { + serializeTemplate(proto, thrift); + } + void deserialize( + const std::string& thrift, + std::shared_ptr& proto) const final { + deserializeTemplate(thrift, proto); + } private: template diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol-json-hpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol-json-hpp.mustache index 886735f96963c..48b16639dcb2a 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol-json-hpp.mustache +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol-json-hpp.mustache @@ -253,6 +253,14 @@ namespace facebook::presto::protocol { throw std::runtime_error("missing operator<() in {{class_name}} subclass"); } {{/comparable}} + static std::string serialize({{&class_name}}& p) { + VELOX_NYI("Serialization not implemented for {{&class_name}}"); + } + static std::shared_ptr<{{&class_name}}> deserialize( + const std::string& data, + std::shared_ptr<{{&class_name}}> p) { + VELOX_NYI("Deserialization not implemented for {{&class_name}}"); + } }; void to_json(json& j, const std::shared_ptr<{{&class_name}}>& p); void from_json(const json& j, std::shared_ptr<{{&class_name}}>& p); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 50d7de6024075..4e108083bc46f 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -1096,6 +1096,8 @@ void from_json(const json& j, AllOrNoneValueSet& p) { from_json_key(j, "all", p.all, "AllOrNoneValueSet", "bool", "all"); } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -1106,6 +1108,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -1116,6 +1128,8 @@ void from_json(const json& j, std::shared_ptr& p) { getConnectorProtocol(type).from_json(j, p); } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + // dependency TpchTransactionHandle // dependency ArrowTransactionHandle @@ -1134,6 +1148,21 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + VELOX_CHECK( + !type.empty() && type[0] != '$', + "Internal handle type '{}' should not have customSerializedValue", + type); + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -2399,6 +2428,8 @@ void from_json(const json& j, Lifespan& p) { } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -2418,6 +2449,21 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + VELOX_CHECK( + !type.empty() && type[0] != '$', + "Internal handle type '{}' should not have customSerializedValue", + type); + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -3004,6 +3050,8 @@ void from_json(const json& j, ConstantExpression& p) { from_json_key(j, "type", p.type, "ConstantExpression", "Type", "type"); } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -3014,6 +3062,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -3158,6 +3216,8 @@ void from_json(const json& j, DataOrganizationSpecification& p) { "orderingScheme"); } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -3168,6 +3228,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -4860,6 +4930,8 @@ void from_json(const json& j, ExchangeEncoding& e) { ->first; } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -4875,6 +4947,21 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + VELOX_CHECK( + !type.empty() && type[0] != '$', + "Internal handle type '{}' should not have customSerializedValue", + type); + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -5607,6 +5694,8 @@ void from_json(const json& j, GroupIdNode& p) { "groupIdVariable"); } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -5617,6 +5706,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -5830,6 +5929,8 @@ void from_json(const json& j, IndexJoinNode& p) { "lookupVariables"); } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -5840,6 +5941,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -5849,6 +5960,8 @@ void from_json(const json& j, std::shared_ptr& p) { getConnectorProtocol(type).from_json(j, p); } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -5859,6 +5972,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); @@ -6035,6 +6158,8 @@ void from_json(const json& j, IndexSourceNode& p) { "currentConstraint"); } } // namespace facebook::presto::protocol +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -6045,6 +6170,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index dae1c63b907d5..2eb236046a621 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -69,21 +69,21 @@ extern const char* const PRESTO_ABORT_TASK_URL_PARAM; class Exception : public std::runtime_error { public: explicit Exception(const std::string& message) - : std::runtime_error(message){}; + : std::runtime_error(message) {}; }; class TypeError : public Exception { public: - explicit TypeError(const std::string& message) : Exception(message){}; + explicit TypeError(const std::string& message) : Exception(message) {}; }; class OutOfRange : public Exception { public: - explicit OutOfRange(const std::string& message) : Exception(message){}; + explicit OutOfRange(const std::string& message) : Exception(message) {}; }; class ParseError : public Exception { public: - explicit ParseError(const std::string& message) : Exception(message){}; + explicit ParseError(const std::string& message) : Exception(message) {}; }; using String = std::string; @@ -267,13 +267,30 @@ struct adl_serializer> { // Forward declaration of all abstract types // namespace facebook::presto::protocol { -struct FunctionHandle : public JsonEncodedSubclass {}; +struct FunctionHandle : public JsonEncodedSubclass { + static std::string serialize(FunctionHandle& p) { + VELOX_NYI("Serialization not implemented for FunctionHandle"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI("Deserialization not implemented for FunctionHandle"); + } +}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { struct RowExpression : public JsonEncodedSubclass { std::shared_ptr sourceLocation = {}; + static std::string serialize(RowExpression& p) { + VELOX_NYI("Serialization not implemented for RowExpression"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI("Deserialization not implemented for RowExpression"); + } }; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); @@ -281,32 +298,86 @@ void from_json(const json& j, std::shared_ptr& p); namespace facebook::presto::protocol { struct PlanNode : public JsonEncodedSubclass { PlanNodeId id = {}; + static std::string serialize(PlanNode& p) { + VELOX_NYI("Serialization not implemented for PlanNode"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI("Deserialization not implemented for PlanNode"); + } }; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -struct ExecutionWriterTarget : public JsonEncodedSubclass {}; +struct ExecutionWriterTarget : public JsonEncodedSubclass { + static std::string serialize(ExecutionWriterTarget& p) { + VELOX_NYI("Serialization not implemented for ExecutionWriterTarget"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI("Deserialization not implemented for ExecutionWriterTarget"); + } +}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -struct InputDistribution : public JsonEncodedSubclass {}; +struct InputDistribution : public JsonEncodedSubclass { + static std::string serialize(InputDistribution& p) { + VELOX_NYI("Serialization not implemented for InputDistribution"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI("Deserialization not implemented for InputDistribution"); + } +}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -struct ValueSet : public JsonEncodedSubclass {}; +struct ValueSet : public JsonEncodedSubclass { + static std::string serialize(ValueSet& p) { + VELOX_NYI("Serialization not implemented for ValueSet"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI("Deserialization not implemented for ValueSet"); + } +}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -struct ConnectorPartitioningHandle : public JsonEncodedSubclass {}; +struct ConnectorPartitioningHandle : public JsonEncodedSubclass { + static std::string serialize(ConnectorPartitioningHandle& p) { + VELOX_NYI("Serialization not implemented for ConnectorPartitioningHandle"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI( + "Deserialization not implemented for ConnectorPartitioningHandle"); + } +}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -struct ConnectorIndexHandle : public JsonEncodedSubclass {}; +struct ConnectorIndexHandle : public JsonEncodedSubclass { + static std::string serialize(ConnectorIndexHandle& p) { + VELOX_NYI("Serialization not implemented for ConnectorIndexHandle"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI("Deserialization not implemented for ConnectorIndexHandle"); + } +}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol @@ -315,6 +386,14 @@ struct ColumnHandle : public JsonEncodedSubclass { virtual bool operator<(const ColumnHandle& /* o */) const { throw std::runtime_error("missing operator<() in ColumnHandle subclass"); } + static std::string serialize(ColumnHandle& p) { + VELOX_NYI("Serialization not implemented for ColumnHandle"); + } + static std::shared_ptr deserialize( + const std::string& data, + std::shared_ptr p) { + VELOX_NYI("Deserialization not implemented for ColumnHandle"); + } }; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ColumnHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ColumnHandle.cpp.inc index fc0fa0de04851..2a088b2648dfc 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ColumnHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ColumnHandle.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -21,6 +23,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDeleteTableHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDeleteTableHandle.cpp.inc index c8112b1847213..e7ebee037066a 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDeleteTableHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorDeleteTableHandle.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -21,6 +23,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorIndexHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorIndexHandle.cpp.inc index 9346fc4f31f26..5f2ed1262013f 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorIndexHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorIndexHandle.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -21,6 +23,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorInsertTableHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorInsertTableHandle.cpp.inc index 98eb8e913218c..4da611f8dc679 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorInsertTableHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorInsertTableHandle.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -21,6 +23,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorOutputTableHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorOutputTableHandle.cpp.inc index a078f4df8e40b..7d069bcbf0508 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorOutputTableHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorOutputTableHandle.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -21,6 +23,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorPartitioningHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorPartitioningHandle.cpp.inc index 89b2827625a41..88058c612d92a 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorPartitioningHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorPartitioningHandle.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -26,6 +28,19 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + VELOX_CHECK(!type.empty() && type[0] != '$', + "Internal handle type '{}' should not have customSerializedValue", type); + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorSplit.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorSplit.cpp.inc index fde3277731624..0883f13d0b8f2 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorSplit.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorSplit.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -30,6 +32,19 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + VELOX_CHECK(!type.empty() && type[0] != '$', + "Internal handle type '{}' should not have customSerializedValue", type); + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTableHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTableHandle.cpp.inc index ff84783523529..f1ea17bc70ace 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTableHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTableHandle.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -21,6 +23,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTableLayoutHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTableLayoutHandle.cpp.inc index 49b641db1cfa4..52934b691b3f4 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTableLayoutHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTableLayoutHandle.cpp.inc @@ -11,6 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" + namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -21,6 +23,16 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc index 1dfb17e4a908f..4c7e5ad4f421a 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc @@ -11,6 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/encode/Base64.h" // dependency TpchTransactionHandle // dependency ArrowTransactionHandle @@ -30,6 +31,19 @@ void to_json(json& j, const std::shared_ptr& p) { } void from_json(const json& j, std::shared_ptr& p) { + if (j.contains("customSerializedValue")) { + String type = j["@type"]; + + VELOX_CHECK(!type.empty() && type[0] != '$', + "Internal handle type '{}' should not have customSerializedValue", type); + + std::string base64Data = j["customSerializedValue"]; + std::string binaryData = velox::encoding::Base64::decode(base64Data); + + getConnectorProtocol(type).deserialize(binaryData, p); + return; + } + String type; try { type = p->getSubclassKey(j); diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp index 24f24f27f87a3..9841be4bdb34d 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp @@ -18,5 +18,4 @@ #include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp" -#include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.cpp" #include "presto_cpp/presto_protocol/core/presto_protocol_core.cpp" diff --git a/presto-native-execution/presto_cpp/presto_protocol/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/presto_protocol/tests/CMakeLists.txt index 893ab732a597c..b034b6402eff6 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/presto_protocol/tests/CMakeLists.txt @@ -24,6 +24,7 @@ add_executable( OptionalTest.cpp RowExpressionTest.cpp TaskUpdateRequestTest.cpp + TpchConnectorProtocolTest.cpp TupleDomainTest.cpp TypeErrorTest.cpp VariableReferenceExpressionTest.cpp @@ -39,11 +40,15 @@ target_link_libraries( GTest::gtest GTest::gtest_main $ + presto_tpch_connector_protocol velox_type velox_encode velox_exception velox_vector velox_presto_serializer + velox_connector + velox_exec + velox_expression Boost::filesystem ${RE2} ${FOLLY_LIBRARIES} diff --git a/presto-native-execution/presto_cpp/presto_protocol/tests/TpchConnectorProtocolTest.cpp b/presto-native-execution/presto_cpp/presto_protocol/tests/TpchConnectorProtocolTest.cpp new file mode 100644 index 0000000000000..6e07acccf5cda --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/tests/TpchConnectorProtocolTest.cpp @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h" + +using namespace facebook::presto; +using namespace facebook::presto::protocol; +using namespace facebook::presto::protocol::tpch; + +class TpchConnectorProtocolTest : public ::testing::Test { + protected: + TpchConnectorProtocol protocol; +}; + +TEST_F(TpchConnectorProtocolTest, TestTpchTableHandleDeserialization) { + std::ostringstream oss; + + std::string tableName = "customer"; + uint16_t nameLen = htons(static_cast(tableName.length())); + oss.write(reinterpret_cast(&nameLen), 2); + oss.write(tableName.data(), tableName.length()); + + double scaleFactor = 0.01; + uint64_t scaleFactorBits; + std::memcpy(&scaleFactorBits, &scaleFactor, sizeof(double)); + scaleFactorBits = folly::Endian::big(scaleFactorBits); + oss.write(reinterpret_cast(&scaleFactorBits), 8); + + std::string binaryData = oss.str(); + + std::shared_ptr handle; + protocol.deserialize(binaryData, handle); + + auto tpchHandle = std::dynamic_pointer_cast(handle); + ASSERT_NE(tpchHandle, nullptr); + EXPECT_EQ(tpchHandle->tableName, "customer"); + EXPECT_DOUBLE_EQ(tpchHandle->scaleFactor, 0.01); +} + +TEST_F(TpchConnectorProtocolTest, TestTpchColumnHandleDeserialization) { + std::ostringstream oss; + + std::string columnName = "c_custkey"; + uint16_t nameLen = htons(static_cast(columnName.length())); + oss.write(reinterpret_cast(&nameLen), 2); + oss.write(columnName.data(), columnName.length()); + + std::string type = "bigint"; + uint16_t typeLen = htons(static_cast(type.length())); + oss.write(reinterpret_cast(&typeLen), 2); + oss.write(type.data(), type.length()); + + uint32_t subfieldCount = htonl(0); + oss.write(reinterpret_cast(&subfieldCount), 4); + + std::string binaryData = oss.str(); + + std::shared_ptr handle; + protocol.deserialize(binaryData, handle); + + auto tpchHandle = std::dynamic_pointer_cast(handle); + ASSERT_NE(tpchHandle, nullptr); + EXPECT_EQ(tpchHandle->columnName, "c_custkey"); + EXPECT_EQ(tpchHandle->type, "bigint"); +} + +TEST_F(TpchConnectorProtocolTest, TestColumnHandleWithSubfields) { + std::ostringstream oss; + + std::string columnName = "complex_col"; + uint16_t nameLen = htons(static_cast(columnName.length())); + oss.write(reinterpret_cast(&nameLen), 2); + oss.write(columnName.data(), columnName.length()); + + std::string type = "row(field1 bigint, field2 varchar)"; + uint16_t typeLen = htons(static_cast(type.length())); + oss.write(reinterpret_cast(&typeLen), 2); + oss.write(type.data(), type.length()); + + uint32_t subfieldCount = htonl(2); + oss.write(reinterpret_cast(&subfieldCount), 4); + + std::string subfield1 = "field1"; + uint16_t sub1Len = htons(static_cast(subfield1.length())); + oss.write(reinterpret_cast(&sub1Len), 2); + oss.write(subfield1.data(), subfield1.length()); + + std::string subfield2 = "field2"; + uint16_t sub2Len = htons(static_cast(subfield2.length())); + oss.write(reinterpret_cast(&sub2Len), 2); + oss.write(subfield2.data(), subfield2.length()); + + std::string binaryData = oss.str(); + + std::shared_ptr handle; + protocol.deserialize(binaryData, handle); + + auto tpchHandle = std::dynamic_pointer_cast(handle); + ASSERT_NE(tpchHandle, nullptr); + EXPECT_EQ(tpchHandle->columnName, "complex_col"); + EXPECT_EQ(tpchHandle->type, type); + ASSERT_EQ(tpchHandle->requiredSubfields.size(), 2); + EXPECT_EQ(tpchHandle->requiredSubfields[0], "field1"); + EXPECT_EQ(tpchHandle->requiredSubfields[1], "field2"); +} + +TEST_F(TpchConnectorProtocolTest, TestAllTpchTableNames) { + std::vector tableNames = { + "customer", + "lineitem", + "nation", + "orders", + "part", + "partsupp", + "region", + "supplier"}; + + for (const auto& tableName : tableNames) { + std::ostringstream oss; + + uint16_t nameLen = htons(static_cast(tableName.length())); + oss.write(reinterpret_cast(&nameLen), 2); + oss.write(tableName.data(), tableName.length()); + + double scaleFactor = 1.0; + uint64_t scaleFactorBits; + std::memcpy(&scaleFactorBits, &scaleFactor, sizeof(double)); + scaleFactorBits = folly::Endian::big(scaleFactorBits); + oss.write(reinterpret_cast(&scaleFactorBits), 8); + + std::string binaryData = oss.str(); + + std::shared_ptr handle; + protocol.deserialize(binaryData, handle); + + auto tpchHandle = std::dynamic_pointer_cast(handle); + ASSERT_NE(tpchHandle, nullptr); + EXPECT_EQ(tpchHandle->tableName, tableName); + EXPECT_DOUBLE_EQ(tpchHandle->scaleFactor, 1.0); + } +} + +TEST_F(TpchConnectorProtocolTest, TestVariousScaleFactors) { + double scaleFactors[] = {0.01, 0.1, 1.0, 10.0, 100.0, 1000.0}; + + for (double scaleFactor : scaleFactors) { + std::ostringstream oss; + + std::string tableName = "lineitem"; + uint16_t nameLen = htons(static_cast(tableName.length())); + oss.write(reinterpret_cast(&nameLen), 2); + oss.write(tableName.data(), tableName.length()); + + uint64_t scaleFactorBits; + std::memcpy(&scaleFactorBits, &scaleFactor, sizeof(double)); + scaleFactorBits = folly::Endian::big(scaleFactorBits); + oss.write(reinterpret_cast(&scaleFactorBits), 8); + + std::string binaryData = oss.str(); + + std::shared_ptr handle; + protocol.deserialize(binaryData, handle); + + auto tpchHandle = std::dynamic_pointer_cast(handle); + ASSERT_NE(tpchHandle, nullptr); + EXPECT_EQ(tpchHandle->tableName, "lineitem"); + EXPECT_DOUBLE_EQ(tpchHandle->scaleFactor, scaleFactor); + } +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestNativeTpchQueriesWithBinarySerialization.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestNativeTpchQueriesWithBinarySerialization.java new file mode 100644 index 0000000000000..0875b3caef61e --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestNativeTpchQueriesWithBinarySerialization.java @@ -0,0 +1,145 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nativeworker; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class TestNativeTpchQueriesWithBinarySerialization + extends AbstractTestQueryFramework +{ + @Override + protected void createTables() + { + // No need to create tables - TPCH connector provides them + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + // Create a query runner with binary serialization enabled for TPCH connector + return PrestoNativeQueryRunnerUtils.nativeHiveQueryRunnerBuilder() + .setExtraProperties(ImmutableMap.builder() + .put("use-connector-provided-serialization-codecs", "true") + .build()) + .setExtraCoordinatorProperties(ImmutableMap.builder() + .put("use-connector-provided-serialization-codecs", "true") + .build()) + .build(); + } + + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(getSession()) + .setNodeCount(1) + .build(); + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpchstandard", "tpch", + ImmutableMap.of("tpch.column-naming", "STANDARD")); + return queryRunner; + } + + @Override + protected Session getSession() + { + // Override the session to use TPCH connector catalog and schema by default + return testSessionBuilder() + .setCatalog("tpchstandard") + .setSchema("tiny") + .build(); + } + + @Test + public void testTpchConnectorDirectQuery() + { + assertQuery("SELECT * FROM tpchstandard.tiny.nation"); + assertQuery("SELECT * FROM tpchstandard.tiny.region"); + } + + @Test + public void testTpchConnectorFilterPushdown() + { + assertQuery("SELECT * FROM tpchstandard.tiny.nation WHERE n_nationkey = 1"); + // Select only columns that match between Java and C++ TPCH implementations + // Avoiding c_address which has different random string generation + assertQuery("SELECT c_custkey, c_name, c_nationkey, c_phone, ROUND(c_acctbal, 2), c_mktsegment, c_comment " + + "FROM tpchstandard.tiny.customer WHERE c_custkey < 10"); + } + + @Test + public void testTpchConnectorJoin() + { + assertQuery("SELECT n.n_name, r.r_name " + + "FROM tpchstandard.tiny.nation n " + + "JOIN tpchstandard.tiny.region r ON n.n_regionkey = r.r_regionkey"); + } + + @Test + public void testTpchConnectorAggregation() + { + assertQuery("SELECT n_regionkey, COUNT(*) " + + "FROM tpchstandard.tiny.nation " + + "GROUP BY n_regionkey"); + + assertQuery("SELECT r_name, COUNT(*) as nation_count " + + "FROM tpchstandard.tiny.nation n " + + "JOIN tpchstandard.tiny.region r ON n.n_regionkey = r.r_regionkey " + + "GROUP BY r_name " + + "ORDER BY nation_count DESC"); + } + + @Test + public void testTpchConnectorComplexQuery() + { + assertQuerySucceeds("SELECT " + + " l.l_orderkey, " + + " SUM(l.l_extendedprice * (1 - l.l_discount)) as revenue, " + + " o.o_orderdate, " + + " o.o_shippriority " + + "FROM " + + " tpchstandard.tiny.customer c, " + + " tpchstandard.tiny.orders o, " + + " tpchstandard.tiny.lineitem l " + + "WHERE " + + " c.c_mktsegment = 'BUILDING' " + + " AND c.c_custkey = o.o_custkey " + + " AND l.l_orderkey = o.o_orderkey " + + " AND o.o_orderdate < DATE '1995-03-15' " + + " AND l.l_shipdate > DATE '1995-03-15' " + + "GROUP BY " + + " l.l_orderkey, " + + " o.o_orderdate, " + + " o.o_shippriority " + + "ORDER BY " + + " revenue DESC, " + + " o.o_orderdate " + + "LIMIT 10"); + } + + @Test + public void testDelimitedIdentifiers() + { + assertQuery("SELECT \"c_custkey\", \"c_name\" FROM tpchstandard.tiny.customer WHERE \"c_custkey\" < 10"); + } +} diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index f0f297904497b..a1b8efdc8ec6f 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -31,6 +31,7 @@ import com.facebook.presto.common.block.BlockEncoding; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.connector.ConnectorCodecManager; @@ -134,6 +135,7 @@ import com.facebook.presto.spark.planner.PrestoSparkRddFactory; import com.facebook.presto.spark.planner.PrestoSparkStatsCalculatorModule; import com.facebook.presto.spark.planner.optimizers.AdaptivePlanOptimizers; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; @@ -307,6 +309,7 @@ protected void setup(Binder binder) jsonCodecBinder(binder).bindJsonCodec(BatchTaskUpdateRequest.class); jsonCodecBinder(binder).bindJsonCodec(BroadcastFileInfo.class); jsonCodecBinder(binder).bindJsonCodec(SimplePlanFragment.class); + jsonCodecBinder(binder).bindJsonCodec(new TypeLiteral>() {}); binder.bind(SimplePlanFragmentSerde.class).to(JsonCodecSimplePlanFragmentSerde.class).in(Scopes.SINGLETON); // smile codecs diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/TestBatchTaskUpdateRequest.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/TestBatchTaskUpdateRequest.java index fda1cee85153c..731857bd3a467 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/TestBatchTaskUpdateRequest.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/TestBatchTaskUpdateRequest.java @@ -25,9 +25,9 @@ import com.facebook.presto.execution.TaskSource; import com.facebook.presto.execution.scheduler.TableWriteInfo; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.RemoteTransactionHandle; import com.facebook.presto.metadata.Split; +import com.facebook.presto.metadata.TestingHandleJsonModule; import com.facebook.presto.server.TaskUpdateRequest; import com.facebook.presto.spark.execution.http.BatchTaskUpdateRequest; import com.facebook.presto.spark.execution.shuffle.PrestoSparkLocalShuffleInfoTranslator; @@ -36,7 +36,6 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.split.RemoteSplit; -import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.testing.TestingSession; import com.facebook.presto.type.TypeDeserializer; @@ -51,7 +50,6 @@ import java.util.List; import java.util.Optional; -import static com.facebook.airlift.configuration.ConfigBinder.configBinder; import static com.facebook.airlift.json.JsonBinder.jsonBinder; import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; import static com.facebook.presto.execution.TaskTestUtils.createPlanFragment; @@ -150,8 +148,7 @@ private JsonCodec getJsonCodec() { Module module = binder -> { binder.install(new JsonModule()); - binder.install(new HandleJsonModule()); - configBinder(binder).bindConfig(FeaturesConfig.class); + binder.install(new TestingHandleJsonModule()); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/TupleDomainSerde.java b/presto-spi/src/main/java/com/facebook/presto/spi/TupleDomainSerde.java new file mode 100644 index 0000000000000..95487365fa321 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/TupleDomainSerde.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +import com.facebook.presto.common.predicate.TupleDomain; + +public interface TupleDomainSerde +{ + String serialize(TupleDomain tupleDomain); + + TupleDomain deserialize(String serialized); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java index 4bd2d81d456b4..fb9f180395356 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.spi.connector; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorCodec; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorIndexHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSplit; @@ -59,4 +61,19 @@ default Optional> getConnectorTableHandleCo { return Optional.empty(); } + + default Optional> getColumnHandleCodec() + { + return Optional.empty(); + } + + default Optional> getConnectorPartitioningHandleCodec() + { + return Optional.empty(); + } + + default Optional> getConnectorIndexHandleCodec() + { + return Optional.empty(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorContext.java index a954291e9c822..6bae6dafb19f6 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorContext.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.TupleDomainSerde; import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; @@ -75,4 +76,9 @@ default ConnectorSystemConfig getConnectorSystemConfig() { throw new UnsupportedOperationException(); } + + default TupleDomainSerde getTupleDomainSerde() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-tpcds/src/test/java/com/facebook/presto/tpcds/TestTpcdsWithCharColumnsAsChar.java b/presto-tpcds/src/test/java/com/facebook/presto/tpcds/TestTpcdsWithCharColumnsAsChar.java index b5bca766686f3..564329f1cdeed 100644 --- a/presto-tpcds/src/test/java/com/facebook/presto/tpcds/TestTpcdsWithCharColumnsAsChar.java +++ b/presto-tpcds/src/test/java/com/facebook/presto/tpcds/TestTpcdsWithCharColumnsAsChar.java @@ -14,6 +14,7 @@ package com.facebook.presto.tpcds; import com.facebook.presto.testing.QueryRunner; +import com.google.common.collect.ImmutableMap; public class TestTpcdsWithCharColumnsAsChar extends AbstractTestTpcds @@ -22,6 +23,7 @@ public class TestTpcdsWithCharColumnsAsChar protected QueryRunner createQueryRunner() throws Exception { - return TpcdsQueryRunner.createQueryRunner(); + return TpcdsQueryRunner.createQueryRunner( + ImmutableMap.of("use-connector-provided-serialization-codecs", "true")); } } diff --git a/presto-tpcds/src/test/java/com/facebook/presto/tpcds/TpcdsQueryRunner.java b/presto-tpcds/src/test/java/com/facebook/presto/tpcds/TpcdsQueryRunner.java index 8643bed4305d2..2074fd9963248 100644 --- a/presto-tpcds/src/test/java/com/facebook/presto/tpcds/TpcdsQueryRunner.java +++ b/presto-tpcds/src/test/java/com/facebook/presto/tpcds/TpcdsQueryRunner.java @@ -75,7 +75,9 @@ public static void main(String[] args) throws Exception { Logging.initialize(); - DistributedQueryRunner queryRunner = createQueryRunner(ImmutableMap.of("http-server.http.port", "8080")); + DistributedQueryRunner queryRunner = createQueryRunner(ImmutableMap.of( + "http-server.http.port", "8080", + "use-connector-provided-serialization-codecs", "true")); Thread.sleep(10); Logger log = Logger.get(TpcdsQueryRunner.class); log.info("======== SERVER STARTED ========"); diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorCodecProvider.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorCodecProvider.java new file mode 100644 index 0000000000000..31eff98cd9048 --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorCodecProvider.java @@ -0,0 +1,378 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tpch; + +import com.facebook.presto.common.Subfield; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.TupleDomainSerde; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.common.collect.ImmutableList; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static java.util.Objects.requireNonNull; + +public class TpchConnectorCodecProvider + implements ConnectorCodecProvider +{ + private final TypeManager typeManager; + private final TupleDomainSerde tupleDomainSerde; + + public TpchConnectorCodecProvider(TypeManager typeManager, TupleDomainSerde tupleDomainSerde) + { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.tupleDomainSerde = requireNonNull(tupleDomainSerde, "tupleDomainSerde is null"); + } + + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of(new TpchTableHandleCodec()); + } + + @Override + public Optional> getConnectorTableLayoutHandleCodec() + { + return Optional.of(new TpchTableLayoutHandleCodec(tupleDomainSerde)); + } + + @Override + public Optional> getColumnHandleCodec() + { + return Optional.of(new TpchColumnHandleCodec(typeManager)); + } + + @Override + public Optional> getConnectorSplitCodec() + { + return Optional.of(new TpchSplitCodec(tupleDomainSerde)); + } + + @Override + public Optional> getConnectorOutputTableHandleCodec() + { + // TPC-H doesn't support writes + return Optional.empty(); + } + + @Override + public Optional> getConnectorInsertTableHandleCodec() + { + // TPC-H doesn't support writes + return Optional.empty(); + } + + @Override + public Optional> getConnectorPartitioningHandleCodec() + { + return Optional.of(new TpchPartitioningHandleCodec()); + } + + @Override + public Optional> getConnectorTransactionHandleCodec() + { + return Optional.of(new TpchTransactionHandleCodec()); + } + + private static class TpchTableHandleCodec + implements ConnectorCodec + { + @Override + public byte[] serialize(ConnectorTableHandle handle) + { + try { + TpchTableHandle tableHandle = (TpchTableHandle) handle; + ByteArrayOutputStream byteOut = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(byteOut); + out.writeUTF(tableHandle.getTableName()); + out.writeDouble(tableHandle.getScaleFactor()); + return byteOut.toByteArray(); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to serialize TpchTableHandle", e); + } + } + + @Override + public ConnectorTableHandle deserialize(byte[] bytes) + { + try { + ByteArrayInputStream byteIn = new ByteArrayInputStream(bytes); + DataInputStream in = new DataInputStream(byteIn); + String tableName = in.readUTF(); + double scaleFactor = in.readDouble(); + return new TpchTableHandle(tableName, scaleFactor); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to deserialize TpchTableHandle", e); + } + } + } + + private static class TpchTableLayoutHandleCodec + implements ConnectorCodec + { + private final TupleDomainSerde tupleDomainSerde; + + public TpchTableLayoutHandleCodec(TupleDomainSerde tupleDomainSerde) + { + this.tupleDomainSerde = requireNonNull(tupleDomainSerde, "tupleDomainSerde is null"); + } + + @Override + public byte[] serialize(ConnectorTableLayoutHandle handle) + { + try { + TpchTableLayoutHandle layoutHandle = (TpchTableLayoutHandle) handle; + ByteArrayOutputStream byteOut = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(byteOut); + // Serialize the table handle + out.writeUTF(layoutHandle.getTable().getTableName()); + out.writeDouble(layoutHandle.getTable().getScaleFactor()); + + // Serialize the predicate using JSON + TupleDomain predicate = layoutHandle.getPredicate(); + String predicateJson = tupleDomainSerde.serialize(predicate); + out.writeUTF(predicateJson); + + return byteOut.toByteArray(); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to serialize TpchTableLayoutHandle", e); + } + } + + @Override + public ConnectorTableLayoutHandle deserialize(byte[] bytes) + { + try { + ByteArrayInputStream byteIn = new ByteArrayInputStream(bytes); + DataInputStream in = new DataInputStream(byteIn); + String tableName = in.readUTF(); + double scaleFactor = in.readDouble(); + TpchTableHandle table = new TpchTableHandle(tableName, scaleFactor); + + String predicateJson = in.readUTF(); + TupleDomain predicate = tupleDomainSerde.deserialize(predicateJson); + + return new TpchTableLayoutHandle(table, predicate); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to deserialize TpchTableLayoutHandle", e); + } + } + } + + private static class TpchColumnHandleCodec + implements ConnectorCodec + { + private final TypeManager typeManager; + + public TpchColumnHandleCodec(TypeManager typeManager) + { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + } + + @Override + public byte[] serialize(ColumnHandle handle) + { + try { + TpchColumnHandle columnHandle = (TpchColumnHandle) handle; + ByteArrayOutputStream byteOut = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(byteOut); + out.writeUTF(columnHandle.getColumnName()); + out.writeUTF(columnHandle.getType().getTypeSignature().toString()); + // Serialize required subfields + List subfields = columnHandle.getRequiredSubfields(); + out.writeInt(subfields.size()); + for (Subfield subfield : subfields) { + out.writeUTF(subfield.serialize()); + } + return byteOut.toByteArray(); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to serialize TpchColumnHandle", e); + } + } + + @Override + public ColumnHandle deserialize(byte[] bytes) + { + try { + ByteArrayInputStream byteIn = new ByteArrayInputStream(bytes); + DataInputStream in = new DataInputStream(byteIn); + String columnName = in.readUTF(); + String typeSignature = in.readUTF(); + Type type = typeManager.getType(parseTypeSignature(typeSignature)); + int subfieldCount = in.readInt(); + List subfields = new ArrayList<>(subfieldCount); + for (int i = 0; i < subfieldCount; i++) { + subfields.add(new Subfield(in.readUTF())); + } + return new TpchColumnHandle(columnName, type, subfields); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to deserialize TpchColumnHandle", e); + } + } + } + + private static class TpchSplitCodec + implements ConnectorCodec + { + private final TupleDomainSerde tupleDomainSerde; + + public TpchSplitCodec(TupleDomainSerde tupleDomainSerde) + { + this.tupleDomainSerde = requireNonNull(tupleDomainSerde, "tupleDomainSerde is null"); + } + + @Override + public byte[] serialize(ConnectorSplit split) + { + try { + TpchSplit tpchSplit = (TpchSplit) split; + ByteArrayOutputStream byteOut = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(byteOut); + // Serialize table handle + out.writeUTF(tpchSplit.getTableHandle().getTableName()); + out.writeDouble(tpchSplit.getTableHandle().getScaleFactor()); + // Serialize split info + out.writeInt(tpchSplit.getPartNumber()); + out.writeInt(tpchSplit.getTotalParts()); + // Serialize addresses + List addresses = tpchSplit.getAddresses(); + out.writeInt(addresses.size()); + for (HostAddress address : addresses) { + out.writeUTF(address.getHostText()); + out.writeInt(address.getPort()); + } + // Serialize the predicate using JSON + TupleDomain predicate = tpchSplit.getPredicate(); + String predicateJson = tupleDomainSerde.serialize(predicate); + out.writeUTF(predicateJson); + return byteOut.toByteArray(); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to serialize TpchSplit", e); + } + } + + @Override + public ConnectorSplit deserialize(byte[] bytes) + { + try { + ByteArrayInputStream byteIn = new ByteArrayInputStream(bytes); + DataInputStream in = new DataInputStream(byteIn); + // Deserialize table handle + String tableName = in.readUTF(); + double scaleFactor = in.readDouble(); + TpchTableHandle tableHandle = new TpchTableHandle(tableName, scaleFactor); + // Deserialize split info + int partNumber = in.readInt(); + int totalParts = in.readInt(); + // Deserialize addresses + int addressCount = in.readInt(); + ImmutableList.Builder addresses = ImmutableList.builder(); + for (int i = 0; i < addressCount; i++) { + String host = in.readUTF(); + int port = in.readInt(); + addresses.add(HostAddress.fromParts(host, port)); + } + // Deserialize the predicate + String predicateJson = in.readUTF(); + TupleDomain predicate = tupleDomainSerde.deserialize(predicateJson); + return new TpchSplit(tableHandle, partNumber, totalParts, addresses.build(), predicate); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to deserialize TpchSplit", e); + } + } + } + + private static class TpchPartitioningHandleCodec + implements ConnectorCodec + { + @Override + public byte[] serialize(ConnectorPartitioningHandle handle) + { + try { + TpchPartitioningHandle partitioningHandle = (TpchPartitioningHandle) handle; + ByteArrayOutputStream byteOut = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(byteOut); + out.writeUTF(partitioningHandle.getTable()); + out.writeLong(partitioningHandle.getTotalRows()); + return byteOut.toByteArray(); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to serialize TpchPartitioningHandle", e); + } + } + + @Override + public ConnectorPartitioningHandle deserialize(byte[] bytes) + { + try { + ByteArrayInputStream byteIn = new ByteArrayInputStream(bytes); + DataInputStream in = new DataInputStream(byteIn); + String tableName = in.readUTF(); + long totalRows = in.readLong(); + return new TpchPartitioningHandle(tableName, totalRows); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to deserialize TpchPartitioningHandle", e); + } + } + } + + private static class TpchTransactionHandleCodec + implements ConnectorCodec + { + @Override + public byte[] serialize(ConnectorTransactionHandle handle) + { + // TpchTransactionHandle is a singleton with no data + // Return empty byte array + return new byte[0]; + } + + @Override + public ConnectorTransactionHandle deserialize(byte[] bytes) + { + // Return the singleton instance + return TpchTransactionHandle.INSTANCE; + } + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorFactory.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorFactory.java index 56daec470e417..fd532a6e42275 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorFactory.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorFactory.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.connector.ConnectorFactory; import com.facebook.presto.spi.connector.ConnectorMetadata; @@ -105,6 +106,12 @@ public ConnectorNodePartitioningProvider getNodePartitioningProvider() { return new TpchNodePartitioningProvider(nodeManager, splitsPerNode); } + + @Override + public ConnectorCodecProvider getConnectorCodecProvider() + { + return new TpchConnectorCodecProvider(context.getTypeManager(), context.getTupleDomainSerde()); + } }; } diff --git a/presto-tpch/src/test/java/com/facebook/presto/tpch/TestTpchConnectorCodecProvider.java b/presto-tpch/src/test/java/com/facebook/presto/tpch/TestTpchConnectorCodecProvider.java new file mode 100644 index 0000000000000..dfef22fa6e5eb --- /dev/null +++ b/presto-tpch/src/test/java/com/facebook/presto/tpch/TestTpchConnectorCodecProvider.java @@ -0,0 +1,453 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tpch; + +import com.facebook.presto.common.Subfield; +import com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.predicate.Range; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.predicate.ValueSet; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.common.type.TypeSignatureParameter; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.TupleDomainSerde; +import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.HyperLogLogType.HYPER_LOG_LOG; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestTpchConnectorCodecProvider +{ + private TpchConnectorCodecProvider codecProvider; + + @BeforeMethod + public void setUp() + { + codecProvider = new TpchConnectorCodecProvider( + new TestingTypeManager(), + new TestTupleDomainSerde(new ObjectMapper())); + } + + @Test + public void testTableHandleSerialization() + { + ConnectorCodec codec = codecProvider.getConnectorTableHandleCodec().get(); + + TpchTableHandle originalHandle = new TpchTableHandle("customer", 0.01); + + byte[] serialized = codec.serialize(originalHandle); + ConnectorTableHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchTableHandle); + TpchTableHandle deserializedTpch = (TpchTableHandle) deserialized; + assertEquals(deserializedTpch.getTableName(), originalHandle.getTableName()); + assertEquals(deserializedTpch.getScaleFactor(), originalHandle.getScaleFactor()); + } + + @Test + public void testColumnHandleSerialization() + { + ConnectorCodec codec = codecProvider.getColumnHandleCodec().get(); + + TpchColumnHandle originalHandle = new TpchColumnHandle("c_custkey", BIGINT); + + byte[] serialized = codec.serialize(originalHandle); + ColumnHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchColumnHandle); + TpchColumnHandle deserializedTpch = (TpchColumnHandle) deserialized; + assertEquals(deserializedTpch.getColumnName(), originalHandle.getColumnName()); + assertEquals(deserializedTpch.getType(), originalHandle.getType()); + } + + @Test + public void testColumnHandleWithSubfields() + { + ConnectorCodec codec = codecProvider.getColumnHandleCodec().get(); + + List subfields = ImmutableList.of( + new Subfield("field1"), + new Subfield("field2.nested")); + TpchColumnHandle originalHandle = new TpchColumnHandle("complex_column", VARCHAR, subfields); + + byte[] serialized = codec.serialize(originalHandle); + ColumnHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchColumnHandle); + TpchColumnHandle deserializedTpch = (TpchColumnHandle) deserialized; + assertEquals(deserializedTpch.getColumnName(), originalHandle.getColumnName()); + assertEquals(deserializedTpch.getType(), originalHandle.getType()); + assertEquals(deserializedTpch.getRequiredSubfields(), originalHandle.getRequiredSubfields()); + } + + @Test + public void testSplitSerialization() + { + ConnectorCodec codec = codecProvider.getConnectorSplitCodec().get(); + + TpchTableHandle tableHandle = new TpchTableHandle("orders", 1.0); + List addresses = ImmutableList.of( + HostAddress.fromParts("localhost", 8080), + HostAddress.fromParts("192.168.1.1", 9090)); + TupleDomain predicate = TupleDomain.all(); + + TpchSplit originalSplit = new TpchSplit(tableHandle, 2, 10, addresses, predicate); + + byte[] serialized = codec.serialize(originalSplit); + ConnectorSplit deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchSplit); + TpchSplit deserializedTpch = (TpchSplit) deserialized; + assertEquals(deserializedTpch.getTableHandle().getTableName(), originalSplit.getTableHandle().getTableName()); + assertEquals(deserializedTpch.getTableHandle().getScaleFactor(), originalSplit.getTableHandle().getScaleFactor()); + assertEquals(deserializedTpch.getPartNumber(), originalSplit.getPartNumber()); + assertEquals(deserializedTpch.getTotalParts(), originalSplit.getTotalParts()); + assertEquals(deserializedTpch.getAddresses(), originalSplit.getAddresses()); + } + + @Test + public void testTableLayoutHandleSerialization() + { + ConnectorCodec codec = codecProvider.getConnectorTableLayoutHandleCodec().get(); + + TpchTableHandle table = new TpchTableHandle("nation", 0.1); + TupleDomain predicate = TupleDomain.all(); + + TpchTableLayoutHandle originalHandle = new TpchTableLayoutHandle(table, predicate); + + byte[] serialized = codec.serialize(originalHandle); + ConnectorTableLayoutHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchTableLayoutHandle); + TpchTableLayoutHandle deserializedTpch = (TpchTableLayoutHandle) deserialized; + assertEquals(deserializedTpch.getTable().getTableName(), originalHandle.getTable().getTableName()); + assertEquals(deserializedTpch.getTable().getScaleFactor(), originalHandle.getTable().getScaleFactor()); + } + + @Test + public void testPartitioningHandleSerialization() + { + ConnectorCodec codec = codecProvider.getConnectorPartitioningHandleCodec().get(); + + TpchPartitioningHandle originalHandle = new TpchPartitioningHandle("abc", 123); + + byte[] serialized = codec.serialize(originalHandle); + ConnectorPartitioningHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchPartitioningHandle); + TpchPartitioningHandle deserializedTpch = (TpchPartitioningHandle) deserialized; + assertEquals(deserializedTpch.getTable(), originalHandle.getTable()); + assertEquals(deserializedTpch.getTotalRows(), originalHandle.getTotalRows()); + } + + @Test + public void testAllTpchTableNames() + { + ConnectorCodec codec = codecProvider.getConnectorTableHandleCodec().get(); + + String[] tableNames = {"customer", "lineitem", "nation", "orders", + "part", "partsupp", "region", "supplier"}; + + for (String tableName : tableNames) { + TpchTableHandle originalHandle = new TpchTableHandle(tableName, 1.0); + + byte[] serialized = codec.serialize(originalHandle); + ConnectorTableHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchTableHandle); + TpchTableHandle deserializedTpch = (TpchTableHandle) deserialized; + assertEquals(deserializedTpch.getTableName(), originalHandle.getTableName()); + assertEquals(deserializedTpch.getScaleFactor(), originalHandle.getScaleFactor()); + } + } + + @Test + public void testColumnHandleWithSpecialCharacters() + { + ConnectorCodec codec = codecProvider.getColumnHandleCodec().get(); + + TpchColumnHandle originalHandle = new TpchColumnHandle("column with spaces", VARCHAR); + + byte[] serialized = codec.serialize(originalHandle); + ColumnHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchColumnHandle); + TpchColumnHandle deserializedTpch = (TpchColumnHandle) deserialized; + assertEquals(deserializedTpch.getColumnName(), originalHandle.getColumnName()); + assertEquals(deserializedTpch.getType(), originalHandle.getType()); + } + + @Test + public void testVariousScaleFactors() + { + ConnectorCodec codec = codecProvider.getConnectorTableHandleCodec().get(); + + double[] scaleFactors = {0.01, 0.1, 1.0, 10.0, 100.0, 1000.0}; + + for (double scaleFactor : scaleFactors) { + TpchTableHandle originalHandle = new TpchTableHandle("lineitem", scaleFactor); + + byte[] serialized = codec.serialize(originalHandle); + ConnectorTableHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchTableHandle); + TpchTableHandle deserializedTpch = (TpchTableHandle) deserialized; + assertEquals(deserializedTpch.getTableName(), originalHandle.getTableName()); + assertEquals(deserializedTpch.getScaleFactor(), originalHandle.getScaleFactor()); + } + } + + @Test + public void testSplitWithEmptyAddresses() + { + ConnectorCodec codec = codecProvider.getConnectorSplitCodec().get(); + + TpchTableHandle tableHandle = new TpchTableHandle("region", 0.01); + List addresses = ImmutableList.of(); + TupleDomain predicate = TupleDomain.all(); + + TpchSplit originalSplit = new TpchSplit(tableHandle, 1, 2, addresses, predicate); + + byte[] serialized = codec.serialize(originalSplit); + ConnectorSplit deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchSplit); + TpchSplit deserializedTpch = (TpchSplit) deserialized; + assertEquals(deserializedTpch.getTableHandle().getTableName(), originalSplit.getTableHandle().getTableName()); + assertEquals(deserializedTpch.getTableHandle().getScaleFactor(), originalSplit.getTableHandle().getScaleFactor()); + assertEquals(deserializedTpch.getPartNumber(), originalSplit.getPartNumber()); + assertEquals(deserializedTpch.getTotalParts(), originalSplit.getTotalParts()); + assertEquals(deserializedTpch.getAddresses(), originalSplit.getAddresses()); + } + + @Test + public void testColumnHandleDeserialization() + throws IOException + { + ConnectorCodec codec = codecProvider.getColumnHandleCodec().get(); + + TpchColumnHandle originalHandle = new TpchColumnHandle("test_column", BIGINT); + byte[] serialized = codec.serialize(originalHandle); + + ByteArrayInputStream byteIn = new ByteArrayInputStream(serialized); + DataInputStream in = new DataInputStream(byteIn); + + String columnName = in.readUTF(); + String typeSignature = in.readUTF(); + int subfieldCount = in.readInt(); + + assertEquals(columnName, "test_column"); + assertEquals(typeSignature, "bigint"); + assertEquals(subfieldCount, 0); + + ColumnHandle deserialized = codec.deserialize(serialized); + assertTrue(deserialized instanceof TpchColumnHandle); + TpchColumnHandle deserializedTpch = (TpchColumnHandle) deserialized; + assertEquals(deserializedTpch.getColumnName(), "test_column"); + assertEquals(deserializedTpch.getType(), BIGINT); + } + + @Test + public void testTableHandleDeserialization() + throws IOException + { + ConnectorCodec codec = codecProvider.getConnectorTableHandleCodec().get(); + + TpchTableHandle originalHandle = new TpchTableHandle("supplier", 10.0); + byte[] serialized = codec.serialize(originalHandle); + + ByteArrayInputStream byteIn = new ByteArrayInputStream(serialized); + DataInputStream in = new DataInputStream(byteIn); + + String tableName = in.readUTF(); + double scaleFactor = in.readDouble(); + + assertEquals(tableName, "supplier"); + assertEquals(scaleFactor, 10.0); + + ConnectorTableHandle deserialized = codec.deserialize(serialized); + assertTrue(deserialized instanceof TpchTableHandle); + TpchTableHandle deserializedTpch = (TpchTableHandle) deserialized; + assertEquals(deserializedTpch.getTableName(), "supplier"); + assertEquals(deserializedTpch.getScaleFactor(), 10.0); + } + + @Test + public void testSplitWithComplexPredicate() + { + ConnectorCodec codec = codecProvider.getConnectorSplitCodec().get(); + + TpchTableHandle tableHandle = new TpchTableHandle("customer", 1.0); + TpchColumnHandle columnHandle = new TpchColumnHandle("c_custkey", BIGINT); + + Domain domain = Domain.create( + ValueSet.ofRanges(Range.range(BIGINT, 1L, true, 100L, false)), + false); + TupleDomain predicate = TupleDomain.withColumnDomains( + ImmutableMap.of(columnHandle, domain)); + + TpchSplit originalSplit = new TpchSplit( + tableHandle, + 1, + 2, + ImmutableList.of(HostAddress.fromParts("localhost", 8080)), + predicate); + + byte[] serialized = codec.serialize(originalSplit); + ConnectorSplit deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchSplit); + TpchSplit deserializedTpch = (TpchSplit) deserialized; + assertEquals(deserializedTpch.getTableHandle().getTableName(), "customer"); + assertNotNull(deserializedTpch.getPredicate()); + } + + @Test + public void testColumnHandleWithVariousTypes() + { + ConnectorCodec codec = codecProvider.getColumnHandleCodec().get(); + + Map types = ImmutableMap.of( + "bigint_col", BIGINT, + "integer_col", INTEGER, + "double_col", DOUBLE, + "varchar_col", VARCHAR); + + for (Map.Entry entry : types.entrySet()) { + TpchColumnHandle originalHandle = new TpchColumnHandle(entry.getKey(), entry.getValue()); + + byte[] serialized = codec.serialize(originalHandle); + ColumnHandle deserialized = codec.deserialize(serialized); + + assertTrue(deserialized instanceof TpchColumnHandle); + TpchColumnHandle deserializedTpch = (TpchColumnHandle) deserialized; + assertEquals(deserializedTpch.getColumnName(), originalHandle.getColumnName()); + assertEquals(deserializedTpch.getType(), originalHandle.getType()); + } + } + + private static class TestTupleDomainSerde + implements TupleDomainSerde + { + private final ObjectMapper objectMapper; + + public TestTupleDomainSerde(ObjectMapper objectMapper) + { + this.objectMapper = objectMapper; + } + + @Override + public String serialize(TupleDomain tupleDomain) + { + try { + if (tupleDomain.isAll()) { + return "{\"all\":true}"; + } + if (tupleDomain.isNone()) { + return "{\"none\":true}"; + } + return "{\"columnDomains\":[]}"; + } + catch (Exception e) { + throw new RuntimeException("Failed to serialize TupleDomain", e); + } + } + + @Override + public TupleDomain deserialize(String serialized) + { + try { + JsonNode node = objectMapper.readTree(serialized); + if (node.has("all") && node.get("all").asBoolean()) { + return TupleDomain.all(); + } + if (node.has("none") && node.get("none").asBoolean()) { + return TupleDomain.none(); + } + return TupleDomain.all(); + } + catch (Exception e) { + throw new RuntimeException("Failed to deserialize TupleDomain", e); + } + } + } + + public static class TestingTypeManager + implements TypeManager + { + @Override + public Type getType(TypeSignature signature) + { + for (Type type : getTypes()) { + if (signature.getBase().equals(type.getTypeSignature().getBase())) { + return type; + } + } + return null; + } + + @Override + public Type getParameterizedType(String baseTypeName, List typeParameters) + { + return getType(new TypeSignature(baseTypeName, typeParameters)); + } + + @Override + public boolean canCoerce(Type actualType, Type expectedType) + { + throw new UnsupportedOperationException(); + } + + @Override + public List getTypes() + { + return ImmutableList.of(BOOLEAN, INTEGER, BIGINT, DOUBLE, VARCHAR, VARBINARY, TIMESTAMP, DATE, HYPER_LOG_LOG); + } + + @Override + public boolean hasType(TypeSignature signature) + { + return getType(signature) != null; + } + } +}