diff --git a/presto-client/pom.xml b/presto-client/pom.xml index 07e8ca839dfdf..7530bee57cd58 100644 --- a/presto-client/pom.xml +++ b/presto-client/pom.xml @@ -98,5 +98,23 @@ testng test + + + com.facebook.drift + drift-protocol + test + + + + com.facebook.drift + drift-codec + test + + + + com.facebook.drift + drift-codec-utils + test + diff --git a/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java b/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java index ed27f3836f1c5..ec199a7c4bda3 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java +++ b/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.client; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import io.airlift.units.Duration; @@ -26,6 +29,7 @@ import static java.util.Objects.requireNonNull; @Immutable +@ThriftStruct public class ServerInfo { private final NodeVersion nodeVersion; @@ -36,6 +40,7 @@ public class ServerInfo // optional to maintain compatibility with older servers private final Optional uptime; + @ThriftConstructor @JsonCreator public ServerInfo( @JsonProperty("nodeVersion") NodeVersion nodeVersion, @@ -51,30 +56,35 @@ public ServerInfo( this.uptime = requireNonNull(uptime, "uptime is null"); } + @ThriftField(1) @JsonProperty public NodeVersion getNodeVersion() { return nodeVersion; } + @ThriftField(2) @JsonProperty public String getEnvironment() { return environment; } + @ThriftField(3) @JsonProperty public boolean isCoordinator() { return coordinator; } + @ThriftField(4) @JsonProperty public boolean isStarting() { return starting; } + @ThriftField(5) @JsonProperty public Optional getUptime() { diff --git a/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java b/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java index 34d80bf134720..8573ddd64d3c1 100644 --- a/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java +++ b/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java @@ -14,18 +14,55 @@ package com.facebook.presto.client; import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.ThriftCodec; +import com.facebook.drift.codec.ThriftCodecManager; +import com.facebook.drift.codec.internal.compiler.CompilerThriftCodecFactory; +import com.facebook.drift.codec.internal.reflection.ReflectionThriftCodecFactory; +import com.facebook.drift.codec.metadata.ThriftCatalog; +import com.facebook.drift.codec.utils.DurationToMillisThriftCodec; +import com.facebook.drift.protocol.TBinaryProtocol; +import com.facebook.drift.protocol.TCompactProtocol; +import com.facebook.drift.protocol.TFacebookCompactProtocol; +import com.facebook.drift.protocol.TMemoryBuffer; +import com.facebook.drift.protocol.TProtocol; +import com.facebook.drift.protocol.TTransport; import io.airlift.units.Duration; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.Optional; +import java.util.function.Function; import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.client.NodeVersion.UNKNOWN; import static org.testng.Assert.assertEquals; +@Test(singleThreaded = true) public class TestServerInfo { private static final JsonCodec SERVER_INFO_CODEC = jsonCodec(ServerInfo.class); + private static final ThriftCatalog COMMON_CATALOG = new ThriftCatalog(); + private static final DurationToMillisThriftCodec DURATION_CODEC = new DurationToMillisThriftCodec(COMMON_CATALOG); + private static final ThriftCodecManager COMPILER_READ_CODEC_MANAGER = new ThriftCodecManager(new CompilerThriftCodecFactory(false), DURATION_CODEC); + private static final ThriftCodecManager COMPILER_WRITE_CODEC_MANAGER = new ThriftCodecManager(new CompilerThriftCodecFactory(false), DURATION_CODEC); + private static final ThriftCodec COMPILER_READ_CODEC = COMPILER_READ_CODEC_MANAGER.getCodec(ServerInfo.class); + private static final ThriftCodec COMPILER_WRITE_CODEC = COMPILER_WRITE_CODEC_MANAGER.getCodec(ServerInfo.class); + private static final ThriftCodecManager REFLECTION_READ_CODEC_MANAGER = new ThriftCodecManager(new ReflectionThriftCodecFactory(), DURATION_CODEC); + private static final ThriftCodecManager REFLECTION_WRITE_CODEC_MANAGER = new ThriftCodecManager(new ReflectionThriftCodecFactory(), DURATION_CODEC); + private static final ThriftCodec REFLECTION_READ_CODEC = REFLECTION_READ_CODEC_MANAGER.getCodec(ServerInfo.class); + private static final ThriftCodec REFLECTION_WRITE_CODEC = REFLECTION_WRITE_CODEC_MANAGER.getCodec(ServerInfo.class); + private static final TMemoryBuffer transport = new TMemoryBuffer(100 * 1024); + + @DataProvider + public Object[][] codecCombinations() + { + return new Object[][] { + {COMPILER_READ_CODEC, COMPILER_WRITE_CODEC}, + {COMPILER_READ_CODEC, REFLECTION_WRITE_CODEC}, + {REFLECTION_READ_CODEC, COMPILER_WRITE_CODEC}, + {REFLECTION_READ_CODEC, REFLECTION_WRITE_CODEC} + }; + } @Test public void testJsonRoundTrip() @@ -42,6 +79,46 @@ public void testBackwardsCompatible() assertEquals(newServerInfo, legacyServerInfo); } + @Test(dataProvider = "codecCombinations") + public void testRoundTripSerializeBinaryProtocol(ThriftCodec readCodec, ThriftCodec writeCodec) + throws Exception + { + ServerInfo serverInfo = getServerInfo(); + ServerInfo roundTripServerInfo = getRoundTripSerialize(readCodec, writeCodec, TBinaryProtocol::new, serverInfo); + assertEquals(serverInfo, roundTripServerInfo); + } + + @Test(dataProvider = "codecCombinations") + public void testRoundTripSerializeCompactProtocol(ThriftCodec readCodec, ThriftCodec writeCodec) + throws Exception + { + ServerInfo serverInfo = getServerInfo(); + ServerInfo roundTripServerInfo = getRoundTripSerialize(readCodec, writeCodec, TCompactProtocol::new, serverInfo); + assertEquals(serverInfo, roundTripServerInfo); + } + + @Test(dataProvider = "codecCombinations") + public void testRoundTripSerializeFacebookCompactProtocol(ThriftCodec readCodec, ThriftCodec writeCodec) + throws Exception + { + ServerInfo serverInfo = getServerInfo(); + ServerInfo roundTripServerInfo = getRoundTripSerialize(readCodec, writeCodec, TFacebookCompactProtocol::new, serverInfo); + assertEquals(serverInfo, roundTripServerInfo); + } + + private ServerInfo getServerInfo() + { + return new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m"))); + } + + private ServerInfo getRoundTripSerialize(ThriftCodec readCodec, ThriftCodec writeCodec, + Function protocolFactory, ServerInfo serverInfo) throws Exception + { + TProtocol protocol = protocolFactory.apply(transport); + writeCodec.write(serverInfo, protocol); + return readCodec.read(protocol); + } + private static void assertJsonRoundTrip(ServerInfo serverInfo) { String json = SERVER_INFO_CODEC.toJson(serverInfo);