Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,11 @@
<artifactId>postgresql</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.github.luben</groupId>
<artifactId>zstd-jni</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
import com.facebook.presto.resourcemanager.ResourceManagerConfig;
import com.facebook.presto.resourcemanager.ResourceManagerInconsistentException;
import com.facebook.presto.resourcemanager.ResourceManagerResourceGroupService;
import com.facebook.presto.server.remotetask.DecompressionFilter;
import com.facebook.presto.server.remotetask.HttpLocationFactory;
import com.facebook.presto.server.remotetask.ReactorNettyHttpClientConfig;
import com.facebook.presto.server.thrift.FixedAddressSelector;
Expand Down Expand Up @@ -437,6 +438,7 @@ else if (serverConfig.isCoordinator()) {
// task execution
jaxrsBinder(binder).bind(TaskResource.class);
jaxrsBinder(binder).bind(ThriftTaskUpdateRequestBodyReader.class);
jaxrsBinder(binder).bind(DecompressionFilter.class);

newExporter(binder).export(TaskResource.class).withGeneratedName();
jaxrsBinder(binder).bind(TaskExecutorResource.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.server.remotetask;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.spi.PrestoException;
import com.github.luben.zstd.ZstdInputStream;
import jakarta.annotation.Priority;
import jakarta.ws.rs.Priorities;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.ext.Provider;

import java.io.IOException;
import java.io.InputStream;

import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static java.lang.String.format;

@Provider
@Priority(Priorities.ENTITY_CODER)
public class DecompressionFilter
implements ContainerRequestFilter
{
private static final Logger log = Logger.get(DecompressionFilter.class);

@Override
public void filter(ContainerRequestContext containerRequestContext)
throws IOException
{
String contentEncoding = containerRequestContext.getHeaderString("Content-Encoding");

if (contentEncoding != null && !contentEncoding.equalsIgnoreCase("identity")) {
InputStream originalStream = containerRequestContext.getEntityStream();
InputStream decompressedStream;

if (contentEncoding.equalsIgnoreCase("zstd")) {
decompressedStream = new ZstdInputStream(originalStream);
}
else {
throw new PrestoException(NOT_SUPPORTED, format("Unsupported Content-Encoding: '%s'. Only zstd compression is supported.", contentEncoding));
}

containerRequestContext.setEntityStream(decompressedStream);
containerRequestContext.getHeaders().remove("Content-Encoding");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
import com.facebook.airlift.http.client.StaticBodyGenerator;
import com.facebook.airlift.log.Logger;
import com.facebook.airlift.units.Duration;
import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdOutputStreamNoFinalizer;
import com.google.common.base.Splitter;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.util.concurrent.SettableFuture;
import com.google.inject.Inject;
import io.netty.channel.ChannelOption;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.epoll.Epoll;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.ssl.ApplicationProtocolConfig;
Expand All @@ -44,6 +47,7 @@
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.resources.LoopResources;

import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
Expand All @@ -62,6 +66,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.zip.GZIPInputStream;

import static com.facebook.airlift.security.pem.PemReader.loadPrivateKey;
import static com.facebook.airlift.security.pem.PemReader.readCertificateChain;
Expand All @@ -84,17 +89,25 @@ public class ReactorNettyHttpClient
private static final Logger log = Logger.get(ReactorNettyHttpClient.class);
private static final HeaderName CONTENT_TYPE_HEADER_NAME = HeaderName.of("Content-Type");
private static final HeaderName CONTENT_LENGTH_HEADER_NAME = HeaderName.of("Content-Length");
private static final HeaderName CONTENT_ENCODING_HEADER_NAME = HeaderName.of("Content-Encoding");
private static final HeaderName ACCEPT_ENCODING_HEADER_NAME = HeaderName.of("Accept-Encoding");

private final Duration requestTimeout;
private HttpClient httpClient;
private final HttpClientConnectionPoolStats connectionPoolStats;
private final HttpClientStats httpClientStats;
private final boolean isHttp2CompressionEnabled;
private final int payloadSizeThreshold;
private final double compressionSavingThreshold;

@Inject
public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientConnectionPoolStats connectionPoolStats, HttpClientStats httpClientStats)
{
this.connectionPoolStats = connectionPoolStats;
this.httpClientStats = httpClientStats;
this.isHttp2CompressionEnabled = config.isHttp2CompressionEnabled();
this.payloadSizeThreshold = config.getPayloadSizeThreshold();
this.compressionSavingThreshold = config.getCompressionSavingThreshold();
SslContext sslContext = null;
if (config.isHttpsEnabled()) {
try {
Expand All @@ -114,11 +127,11 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
if (os.toLowerCase(Locale.ENGLISH).contains("linux")) {
// Make sure Open ssl is available for linux deployments
if (!OpenSsl.isAvailable()) {
throw new UnsupportedOperationException(format("OpenSsl is not unavailable. Stacktrace: %s", Arrays.toString(OpenSsl.unavailabilityCause().getStackTrace()).replace(',', '\n')));
throw new UnsupportedOperationException(format("OpenSsl is not available. Stacktrace: %s", Arrays.toString(OpenSsl.unavailabilityCause().getStackTrace()).replace(',', '\n')));
}
// Make sure epoll threads are used for linux deployments
if (!Epoll.isAvailable()) {
throw new UnsupportedOperationException(format("Epoll is not unavailable. Stacktrace: %s", Arrays.toString(Epoll.unavailabilityCause().getStackTrace()).replace(',', '\n')));
throw new UnsupportedOperationException(format("Epoll is not available. Stacktrace: %s", Arrays.toString(Epoll.unavailabilityCause().getStackTrace()).replace(',', '\n')));
}
}

Expand Down Expand Up @@ -166,9 +179,10 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon

// Create HTTP/2 client
SslContext finalSslContext = sslContext;

this.httpClient = HttpClient
// The custom pool is wrapped with a HttpConnectionProvider over here
.create(pool)
.create(pool) // The custom pool is wrapped with a HttpConnectionProvider over here
.compress(false) // we will enable response compression manually
.protocol(HttpProtocol.H2, HttpProtocol.HTTP11)
.runOn(loopResources, true)
.http2Settings(settings -> {
Expand All @@ -179,6 +193,9 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) config.getConnectTimeout().getValue())
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_SNDBUF, config.getTcpBufferSize())
.option(ChannelOption.SO_RCVBUF, config.getTcpBufferSize())
.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(config.getWriteBufferWaterMarkLow(), config.getWriteBufferWaterMarkHigh()))
// Track HTTP client metrics
.metrics(true, () -> httpClientStats, Function.identity());

Expand Down Expand Up @@ -208,6 +225,10 @@ public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airli
for (Map.Entry<String, String> entry : airliftRequest.getHeaders().entries()) {
hdr.set(entry.getKey(), entry.getValue());
}

if (isHttp2CompressionEnabled) {
hdr.set(ACCEPT_ENCODING_HEADER_NAME.toString(), "zstd, gzip");
}
});

URI uri = airliftRequest.getUri();
Expand All @@ -223,9 +244,33 @@ public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airli
break;
case "POST":
byte[] postBytes = ((StaticBodyGenerator) airliftRequest.getBodyGenerator()).getBody();
disposable = client.post()
byte[] bodyToSend = postBytes;
HttpClient postClient = client;
// We manually do compression for request, use zstd
if (isHttp2CompressionEnabled && postBytes.length >= payloadSizeThreshold) {
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream(postBytes.length / 2);
try (ZstdOutputStreamNoFinalizer zstdOutput = new ZstdOutputStreamNoFinalizer(baos)) {
zstdOutput.write(postBytes);
}

byte[] compressedBytes = baos.toByteArray();
double compressionRatio = (double) (postBytes.length - compressedBytes.length) / postBytes.length;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Compression ratio calculation may be negative if compressedBytes is larger than postBytes.

Negative compressionRatio values may cause compressed data to be used when it is actually larger than the original. Ensure the ratio is non-negative, for example by using Math.max(0, compressionRatio).

if (compressionRatio >= compressionSavingThreshold) {
bodyToSend = compressedBytes;
postClient = client.headers(h -> h.set(CONTENT_ENCODING_HEADER_NAME.toString(), "zstd"));
}
}
catch (IOException e) {
onError(listenableFuture, e);
disposable = () -> {};
break;
}
}

disposable = postClient.post()
.uri(uri)
.send(ByteBufFlux.fromInbound(Mono.just(postBytes)))
.send(ByteBufFlux.fromInbound(Mono.just(bodyToSend)))
.responseSingle((response, bytes) -> bytes.asInputStream().zipWith(Mono.just(response)))
// Request timeout
.timeout(java.time.Duration.of(requestTimeout.toMillis(), MILLIS))
Expand Down Expand Up @@ -303,6 +348,7 @@ public void onSuccess(ResponseHandler responseHandler, InputStream inputStream,
}

long contentLength = 0;
String contentEncoding = null;
// Iterate over the headers
for (String name : headers.names()) {
if (name.equalsIgnoreCase(CONTENT_LENGTH_HEADER_NAME.toString())) {
Expand All @@ -313,6 +359,9 @@ public void onSuccess(ResponseHandler responseHandler, InputStream inputStream,
else if (name.equalsIgnoreCase(CONTENT_TYPE_HEADER_NAME.toString())) {
responseHeaders.put(CONTENT_TYPE_HEADER_NAME, headers.get(name));
}
else if (name.equalsIgnoreCase(CONTENT_ENCODING_HEADER_NAME.toString())) {
contentEncoding = headers.get(name);
}
else {
responseHeaders.put(HeaderName.of(name), headers.get(name));
}
Expand All @@ -323,7 +372,21 @@ else if (name.equalsIgnoreCase(CONTENT_TYPE_HEADER_NAME.toString())) {
return;
}

final InputStream[] streamHolder = new InputStream[1];
streamHolder[0] = inputStream;
try {
if (contentEncoding != null && !contentEncoding.equalsIgnoreCase("identity")) {
if (contentEncoding.equalsIgnoreCase("zstd")) {
streamHolder[0] = new ZstdInputStream(inputStream);
}
else if (contentEncoding.equalsIgnoreCase("gzip")) {
streamHolder[0] = new GZIPInputStream(inputStream);
}
else {
throw new RuntimeException(format("Unsupported Content-Encoding: %s. Supported: zstd, gzip.", contentEncoding));
}
}

long finalContentLength = contentLength;
Object a = responseHandler.handle(null, new Response()
{
Expand All @@ -349,19 +412,21 @@ public long getBytesRead()
public InputStream getInputStream()
throws IOException
{
return inputStream;
return streamHolder[0];
}
});
// closing it here to prevent memory leak of bytebuf
inputStream.close();
if (streamHolder[0] != null) {
streamHolder[0].close();
}
listenableFuture.set(a);
}
catch (Exception e) {
listenableFuture.setException(e);
}
finally {
try {
inputStream.close();
streamHolder[0].close();
Comment thread
shangm2 marked this conversation as resolved.
}
catch (IOException e) {
log.warn(e, "Failed to close input stream");
Expand Down
Loading
Loading