Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/trino-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@

<dependency>
<groupId>io.airlift</groupId>
<artifactId>aircompressor</artifactId>
<artifactId>aircompressor-v3</artifactId>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,21 @@
*/
package io.trino.execution.buffer;

import io.airlift.compress.v3.lz4.Lz4Compressor;
import io.airlift.compress.v3.zstd.ZstdCompressor;

import java.util.OptionalInt;

public enum CompressionCodec
{
NONE, LZ4, ZSTD
NONE, LZ4, ZSTD;

public OptionalInt maxCompressedLength(int inputSize)
{
return switch (this) {
case NONE -> OptionalInt.empty();
case LZ4 -> OptionalInt.of(Lz4Compressor.create().maxCompressedLength(inputSize));
case ZSTD -> OptionalInt.of(ZstdCompressor.create().maxCompressedLength(inputSize));
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
package io.trino.execution.buffer;

import com.google.common.base.VerifyException;
import io.airlift.compress.Decompressor;
import io.airlift.compress.lz4.Lz4Decompressor;
import io.airlift.compress.lz4.Lz4RawCompressor;
import io.airlift.compress.v3.Decompressor;
import io.airlift.compress.v3.lz4.Lz4Decompressor;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.Slices;
Expand All @@ -34,6 +33,7 @@
import java.io.UnsupportedEncodingException;
import java.security.GeneralSecurityException;
import java.util.Optional;
import java.util.OptionalInt;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.SizeOf.instanceSize;
Expand Down Expand Up @@ -63,15 +63,17 @@ public PageDeserializer(
BlockEncodingSerde blockEncodingSerde,
Optional<Decompressor> decompressor,
Optional<SecretKey> encryptionKey,
int blockSizeInBytes)
int blockSizeInBytes,
OptionalInt maxCompressedBlockSizeInBytes)
{
this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
requireNonNull(encryptionKey, "encryptionKey is null");
encryptionKey.ifPresent(secretKey -> checkArgument(is256BitSecretKeySpec(secretKey), "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key"));
input = new SerializedPageInput(
requireNonNull(decompressor, "decompressor is null"),
encryptionKey,
blockSizeInBytes);
blockSizeInBytes,
maxCompressedBlockSizeInBytes);
}

public Page deserialize(Slice serializedPage)
Expand Down Expand Up @@ -101,7 +103,7 @@ private static class SerializedPageInput

private final ReadBuffer[] buffers;

private SerializedPageInput(Optional<Decompressor> decompressor, Optional<SecretKey> encryptionKey, int blockSizeInBytes)
private SerializedPageInput(Optional<Decompressor> decompressor, Optional<SecretKey> encryptionKey, int blockSizeInBytes, OptionalInt maxCompressedBlockSizeInBytes)
{
this.decompressor = requireNonNull(decompressor, "decompressor is null");
this.encryptionKey = requireNonNull(encryptionKey, "encryptionKey is null");
Expand All @@ -122,7 +124,7 @@ private SerializedPageInput(Optional<Decompressor> decompressor, Optional<Secret
int bufferSize;
if (decompressor.isPresent()) {
// to store compressed block size
bufferSize = Lz4RawCompressor.maxCompressedLength(blockSizeInBytes)
bufferSize = maxCompressedBlockSizeInBytes.orElseThrow()
// to store compressed block size
+ Integer.BYTES
// to guarantee a single long can always be read entirely
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
package io.trino.execution.buffer;

import com.google.common.base.VerifyException;
import io.airlift.compress.Compressor;
import io.airlift.compress.lz4.Lz4Compressor;
import io.airlift.compress.lz4.Lz4RawCompressor;
import io.airlift.compress.v3.Compressor;
import io.airlift.compress.v3.lz4.Lz4Compressor;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
Expand All @@ -33,12 +32,12 @@
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.util.Optional;
import java.util.OptionalInt;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.SizeOf.instanceSize;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.slice.SizeOf.sizeOfByteArray;
import static io.airlift.slice.SizeOf.sizeOfIntArray;
import static io.trino.execution.buffer.PageCodecMarker.COMPRESSED;
import static io.trino.execution.buffer.PageCodecMarker.ENCRYPTED;
import static io.trino.execution.buffer.PagesSerdeUtil.ESTIMATED_AES_CIPHER_RETAINED_SIZE;
Expand Down Expand Up @@ -67,15 +66,17 @@ public PageSerializer(
BlockEncodingSerde blockEncodingSerde,
Optional<Compressor> compressor,
Optional<SecretKey> encryptionKey,
int blockSizeInBytes)
int blockSizeInBytes,
OptionalInt maxCompressedBlockSize)
{
this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
requireNonNull(encryptionKey, "encryptionKey is null");
encryptionKey.ifPresent(secretKey -> checkArgument(is256BitSecretKeySpec(secretKey), "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key"));
output = new SerializedPageOutput(
requireNonNull(compressor, "compressor is null"),
encryptionKey,
blockSizeInBytes);
blockSizeInBytes,
maxCompressedBlockSize);
}

public Slice serialize(Page page)
Expand All @@ -95,7 +96,8 @@ private static class SerializedPageOutput
{
private static final int INSTANCE_SIZE = instanceSize(SerializedPageOutput.class);
// TODO: implement getRetainedSizeInBytes in Lz4Compressor
private static final int COMPRESSOR_RETAINED_SIZE = toIntExact(instanceSize(Lz4Compressor.class) + sizeOfIntArray(Lz4RawCompressor.MAX_TABLE_SIZE));
// TODO: need a fix
private static final int COMPRESSOR_RETAINED_SIZE = toIntExact(instanceSize(Lz4Compressor.class));
Comment thread
wendigo marked this conversation as resolved.
Outdated
private static final int ENCRYPTION_KEY_RETAINED_SIZE = toIntExact(instanceSize(SecretKeySpec.class) + sizeOfByteArray(256 / 8));

private static final double MINIMUM_COMPRESSION_RATIO = 0.8;
Expand All @@ -111,7 +113,8 @@ private static class SerializedPageOutput
private SerializedPageOutput(
Optional<Compressor> compressor,
Optional<SecretKey> encryptionKey,
int blockSizeInBytes)
int blockSizeInBytes,
OptionalInt maxCompressedBlockSize)
{
this.compressor = requireNonNull(compressor, "compressor is null");
this.encryptionKey = requireNonNull(encryptionKey, "encryptionKey is null");
Expand All @@ -129,7 +132,7 @@ private SerializedPageOutput(
if (encryptionKey.isPresent()) {
int bufferSize = blockSizeInBytes;
if (compressor.isPresent()) {
bufferSize = compressor.get().maxCompressedLength(blockSizeInBytes)
bufferSize = maxCompressedBlockSize.orElseThrow()
// to store compressed block size
+ Integer.BYTES;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,34 @@
*/
package io.trino.execution.buffer;

import io.airlift.compress.Compressor;
import io.airlift.compress.Decompressor;
import io.airlift.compress.lz4.Lz4Compressor;
import io.airlift.compress.lz4.Lz4Decompressor;
import io.airlift.compress.zstd.ZstdCompressor;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.airlift.compress.v3.Compressor;
import io.airlift.compress.v3.Decompressor;
import io.airlift.compress.v3.lz4.Lz4Compressor;
import io.airlift.compress.v3.lz4.Lz4Decompressor;
import io.airlift.compress.v3.zstd.ZstdCompressor;
import io.airlift.compress.v3.zstd.ZstdDecompressor;
import io.trino.spi.block.BlockEncodingSerde;

import javax.crypto.SecretKey;

import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;

import static io.trino.execution.buffer.CompressionCodec.LZ4;
import static io.trino.execution.buffer.CompressionCodec.NONE;
import static io.trino.execution.buffer.CompressionCodec.ZSTD;
import static java.util.Objects.requireNonNull;

public class PagesSerdeFactory
{
private static final int SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES = 64 * 1024;

private static final Map<CompressionCodec, OptionalInt> MAX_COMPRESSED_LENGTH = Map.of(
NONE, NONE.maxCompressedLength(SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES),
LZ4, LZ4.maxCompressedLength(SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES),
ZSTD, ZSTD.maxCompressedLength(SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES));

private final BlockEncodingSerde blockEncodingSerde;
private final CompressionCodec compressionCodec;

Expand All @@ -42,29 +52,39 @@ public PagesSerdeFactory(BlockEncodingSerde blockEncodingSerde, CompressionCodec

public PageSerializer createSerializer(Optional<SecretKey> encryptionKey)
{
return new PageSerializer(blockEncodingSerde, createCompressor(compressionCodec), encryptionKey, SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES);
return new PageSerializer(
blockEncodingSerde,
createCompressor(compressionCodec),
encryptionKey,
SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES,
MAX_COMPRESSED_LENGTH.get(compressionCodec));
}

public PageDeserializer createDeserializer(Optional<SecretKey> encryptionKey)
{
return new PageDeserializer(blockEncodingSerde, createDecompressor(compressionCodec), encryptionKey, SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES);
return new PageDeserializer(
blockEncodingSerde,
createDecompressor(compressionCodec),
encryptionKey,
SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES,
MAX_COMPRESSED_LENGTH.get(compressionCodec));
}

public static Optional<Compressor> createCompressor(CompressionCodec compressionCodec)
{
return switch (compressionCodec) {
case NONE -> Optional.empty();
case LZ4 -> Optional.of(new Lz4Compressor());
case ZSTD -> Optional.of(new ZstdCompressor());
case LZ4 -> Optional.of(Lz4Compressor.create());
case ZSTD -> Optional.of(ZstdCompressor.create());
};
}

public static Optional<Decompressor> createDecompressor(CompressionCodec compressionCodec)
{
return switch (compressionCodec) {
case NONE -> Optional.empty();
case LZ4 -> Optional.of(new Lz4Decompressor());
case ZSTD -> Optional.of(new ZstdDecompressor());
case LZ4 -> Optional.of(Lz4Decompressor.create());
case ZSTD -> Optional.of(ZstdDecompressor.create());
};
}
}
13 changes: 12 additions & 1 deletion core/trino-main/src/main/java/io/trino/server/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import com.google.inject.TypeLiteral;
import io.airlift.bootstrap.ApplicationConfigurationException;
import io.airlift.bootstrap.Bootstrap;
import io.airlift.compress.v3.lz4.Lz4NativeCompressor;
import io.airlift.compress.v3.snappy.SnappyNativeCompressor;
import io.airlift.compress.v3.zstd.ZstdNativeCompressor;
import io.airlift.discovery.client.Announcer;
import io.airlift.discovery.client.DiscoveryModule;
import io.airlift.discovery.client.ServiceAnnouncement;
Expand Down Expand Up @@ -142,6 +145,10 @@ private void doStart(String trinoVersion)
Injector injector = app.initialize();

log.info("Trino version: %s", injector.getInstance(NodeVersion.class).getVersion());
log.info("Zstandard native compression: %s", formatEnabled(ZstdNativeCompressor.isEnabled()));
log.info("Lz4 native compression: %s", formatEnabled(Lz4NativeCompressor.isEnabled()));
log.info("Snappy native compression: %s", formatEnabled(SnappyNativeCompressor.isEnabled()));

logLocation(log, "Working directory", Paths.get("."));
logLocation(log, "Etc directory", Paths.get("etc"));

Expand Down Expand Up @@ -188,7 +195,6 @@ private void doStart(String trinoVersion)
injector.getInstance(Announcer.class).start();

injector.getInstance(StartupStatus.class).startupComplete();

log.info("Server startup completed in %s", Duration.nanosSince(startTime).convertToMostSuccinctTimeUnit());
log.info("======== SERVER STARTED ========");
}
Expand Down Expand Up @@ -290,4 +296,9 @@ private static void logLocation(Logger log, String name, Path path)
}
log.info("%s: %s", name, path);
}

private static String formatEnabled(boolean flag)
{
return flag ? "enabled" : "disabled";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
package io.trino.server.protocol;

import com.google.inject.Inject;
import io.airlift.compress.zstd.ZstdCompressor;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.airlift.compress.v3.zstd.ZstdCompressor;
import io.airlift.compress.v3.zstd.ZstdDecompressor;
import io.trino.server.ProtocolConfig;

import static com.google.common.io.BaseEncoding.base64Url;
Expand Down Expand Up @@ -43,7 +43,7 @@ public String encodePreparedStatementForHeader(String preparedStatement)
return preparedStatement;
}

ZstdCompressor compressor = new ZstdCompressor();
ZstdCompressor compressor = ZstdCompressor.create();
byte[] inputBytes = preparedStatement.getBytes(UTF_8);
byte[] compressed = new byte[compressor.maxCompressedLength(inputBytes.length)];
int outputSize = compressor.compress(inputBytes, 0, inputBytes.length, compressed, 0, compressed.length);
Expand All @@ -63,9 +63,9 @@ public String decodePreparedStatementFromHeader(String headerValue)

String encoded = headerValue.substring(PREFIX.length());
byte[] compressed = base64Url().decode(encoded);

byte[] preparedStatement = new byte[toIntExact(ZstdDecompressor.getDecompressedSize(compressed, 0, compressed.length))];
new ZstdDecompressor().decompress(compressed, 0, compressed.length, preparedStatement, 0, preparedStatement.length);
ZstdDecompressor decompressor = ZstdDecompressor.create();
byte[] preparedStatement = new byte[toIntExact(decompressor.getDecompressedSize(compressed, 0, compressed.length))];
decompressor.decompress(compressed, 0, compressed.length, preparedStatement, 0, preparedStatement.length);
return new String(preparedStatement, UTF_8);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,51 +13,26 @@
*/
package io.trino.server.security.oauth2;

import io.airlift.compress.zstd.ZstdCompressor;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.airlift.compress.zstd.ZstdInputStream;
import io.airlift.compress.zstd.ZstdOutputStream;
import io.jsonwebtoken.CompressionCodec;
import io.jsonwebtoken.CompressionException;
import io.airlift.compress.v3.zstd.ZstdInputStream;
import io.airlift.compress.v3.zstd.ZstdOutputStream;
import io.jsonwebtoken.io.CompressionAlgorithm;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;

import static java.lang.Math.toIntExact;
import static java.util.Arrays.copyOfRange;

public class ZstdCodec
implements CompressionCodec
implements CompressionAlgorithm
{
public static final String CODEC_NAME = "ZSTD";

@Override
public String getAlgorithmName()
public String getId()
{
return CODEC_NAME;
}

@Override
public byte[] compress(byte[] bytes)
throws CompressionException
{
ZstdCompressor compressor = new ZstdCompressor();
byte[] compressed = new byte[compressor.maxCompressedLength(bytes.length)];
int outputSize = compressor.compress(bytes, 0, bytes.length, compressed, 0, compressed.length);
return copyOfRange(compressed, 0, outputSize);
}

@Override
public byte[] decompress(byte[] bytes)
throws CompressionException
{
byte[] output = new byte[toIntExact(ZstdDecompressor.getDecompressedSize(bytes, 0, bytes.length))];
new ZstdDecompressor().decompress(bytes, 0, bytes.length, output, 0, output.length);
return output;
}

@Override
public OutputStream compress(OutputStream out)
{
Expand All @@ -74,10 +49,4 @@ public InputStream decompress(InputStream in)
{
return new ZstdInputStream(in);
}

@Override
public String getId()
{
return CODEC_NAME;
}
}
Loading