Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -115,6 +116,7 @@ public final class TypeRegistry
private final ConcurrentMap<String, ParametricType> parametricTypes = new ConcurrentHashMap<>();

private final NonEvictableCache<TypeSignature, Type> parametricTypeCache;
private final NonEvictableCache<String, Type> sqlTypeCache;
private final TypeManager typeManager;
private final TypeOperators typeOperators;

Expand Down Expand Up @@ -168,6 +170,7 @@ public TypeRegistry(TypeOperators typeOperators, FeaturesConfig featuresConfig)
addParametricType(TIME_WITH_TIME_ZONE);

parametricTypeCache = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1000));
sqlTypeCache = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1000));

typeManager = new InternalTypeManager(this, typeOperators);

Expand Down Expand Up @@ -198,11 +201,17 @@ public Type getType(TypeId id)
public Type fromSqlType(String sqlType)
{
try {
return getType(toTypeSignature(SQL_PARSER.createType(sqlType)));
return sqlTypeCache.get(sqlType, () -> getType(toTypeSignature(SQL_PARSER.createType(sqlType))));
}
catch (ParsingException e) {
throw new TypeNotFoundException(sqlType, e);
}
catch (ExecutionException e) {
if (e.getCause() instanceof ParsingException parsingException) {
throw new TypeNotFoundException(sqlType, parsingException);
}
throw new RuntimeException("Could not get type from cache", e);
}
}

private Type instantiateParametricType(TypeSignature signature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@

import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.KeyDeserializer;
import com.google.common.base.CharMatcher;
import com.google.inject.Inject;
import io.trino.spi.type.TypeId;
import io.trino.spi.type.TypeManager;

import java.util.Base64;

import static com.google.common.base.Preconditions.checkArgument;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.lang.Integer.parseInt;

public class SymbolKeyDeserializer
extends KeyDeserializer
{
private static final Base64.Decoder DECODER = Base64.getDecoder();
private static final CharMatcher DIGIT_MATCHER = CharMatcher.inRange('0', '9').precomputed();

private final TypeManager typeManager;

@Inject
Expand All @@ -39,11 +39,28 @@ public SymbolKeyDeserializer(TypeManager typeManager)
@Override
public Object deserializeKey(String key, DeserializationContext context)
{
String[] parts = key.split(":");
checkArgument(parts.length == 2, "Expected two parts, found: " + parts.length);
int keyLength = key.length();

// Shortest valid key is "1|n|t", which has length 5
checkArgument(keyLength > 4, "Symbol key is malformed: %s", key);

int lastDigitIndex = getLeadingDigitsLength(key, keyLength);
checkArgument(lastDigitIndex > 0, "Symbol key is malformed: %s", key);

String name = new String(DECODER.decode(parts[0].getBytes(UTF_8)), UTF_8);
String type = new String(DECODER.decode(parts[1].getBytes(UTF_8)), UTF_8);
int length = parseInt(key.substring(0, lastDigitIndex));
checkArgument(lastDigitIndex + length + 2 < keyLength, "Symbol key is malformed: %s", key);

String type = key.substring(lastDigitIndex + 1, lastDigitIndex + length + 1);
String name = key.substring(lastDigitIndex + length + 2);
return new Symbol(typeManager.getType(TypeId.of(type)), name);
}

public static int getLeadingDigitsLength(String input, int length)
{
int index = 0;
while (index < length && DIGIT_MATCHER.matches(input.charAt(index))) {
index++;
}
return index;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,15 @@
import com.fasterxml.jackson.databind.SerializerProvider;

import java.io.IOException;
import java.util.Base64;

import static java.nio.charset.StandardCharsets.UTF_8;

public class SymbolKeySerializer
extends JsonSerializer<Symbol>
{
private static final Base64.Encoder ENCODER = Base64.getEncoder();

@Override
public void serialize(Symbol value, JsonGenerator generator, SerializerProvider serializers)
throws IOException
{
String name = ENCODER.encodeToString(value.name().getBytes(UTF_8));
String type = ENCODER.encodeToString(value.type().getTypeId().getId().getBytes(UTF_8));
generator.writeFieldName(name + ":" + type);
String type = value.type().getTypeId().getId();
generator.writeFieldName(type.length() + "|" + type + "|" + value.name());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner;

import com.google.common.collect.ImmutableMap;
import io.airlift.json.JsonCodec;
import io.airlift.json.JsonCodecFactory;
import io.airlift.json.ObjectMapperProvider;
import io.trino.spi.type.TestingTypeManager;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeId;
import io.trino.type.TypeDeserializer;
import org.junit.jupiter.api.Test;

import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class TestSymbolKeySerialization
{
private static final TestingTypeManager TYPE_MANAGER = new TestingTypeManager();
private static final ObjectMapperProvider OBJECT_MAPPER_PROVIDER = createObjectMapperProvider(TYPE_MANAGER);
private static final JsonCodec<Map<Symbol, String>> SYMBOL_KEY_CODEC = new JsonCodecFactory(OBJECT_MAPPER_PROVIDER)
.mapJsonCodec(Symbol.class, String.class);

@Test
void testRoundTrip()
{
Map<Symbol, String> symbols = Map.of(
new Symbol(TYPE_MANAGER.getType(TypeId.of("integer")), "a"), "value",
new Symbol(TYPE_MANAGER.getType(TypeId.of("varchar")), "b"), "value",
new Symbol(TYPE_MANAGER.getType(TypeId.of("integer")), "abcd"), "value",
new Symbol(TYPE_MANAGER.getType(TypeId.of("integer")), "1abcd"), "value",
new Symbol(TYPE_MANAGER.getType(TypeId.of("varchar")), "b".repeat(256)), "value",
new Symbol(TYPE_MANAGER.getType(TypeId.of("id")), "a"), "value");

assertThat(SYMBOL_KEY_CODEC.fromJson(SYMBOL_KEY_CODEC.toJson(symbols)))
.isEqualTo(symbols);

assertThat(SYMBOL_KEY_CODEC.toJson(symbols))
.contains("7|integer|a")
.contains("7|varchar|b")
.contains("7|varchar|%s".formatted("b".repeat(256)))
.contains("7|integer|1abcd")
.contains("7|integer|abcd")
.contains("2|id|a");
}

@Test
void testMalformedSymbolKey()
{
assertThatThrownBy(() -> SYMBOL_KEY_CODEC.fromJson("{\"1|a|\":\"value\"}"))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Symbol key is malformed: 1|a|");

assertThatThrownBy(() -> SYMBOL_KEY_CODEC.fromJson("{\"256|a|\":\"value\"}"))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Symbol key is malformed: 256|a|");

assertThatThrownBy(() -> SYMBOL_KEY_CODEC.fromJson("{\"1|a\":\"value\"}"))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Symbol key is malformed: 1|a");
}

private static ObjectMapperProvider createObjectMapperProvider(TestingTypeManager typeManager)
{
ObjectMapperProvider provider = new ObjectMapperProvider();
provider.setKeyDeserializers(ImmutableMap.of(Symbol.class, new SymbolKeyDeserializer(typeManager)));
provider.setJsonDeserializers(ImmutableMap.of(Type.class, new TypeDeserializer(typeManager::getType)));
return provider;
}
}
Loading