Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@

import javax.crypto.SecretKey;

import java.time.Instant;
import java.time.ZonedDateTime;
import java.util.Date;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;

import static io.airlift.http.client.Request.Builder.fromRequest;
import static io.jsonwebtoken.security.Keys.hmacShaKeyFor;
Expand All @@ -41,18 +45,23 @@
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static jakarta.ws.rs.core.Response.Status.UNAUTHORIZED;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.time.temporal.ChronoUnit.MINUTES;
import static java.util.Objects.requireNonNull;

public class InternalAuthenticationManager
implements HttpRequestFilter
{
private static final Logger log = Logger.get(InternalAuthenticationManager.class);
private static final Supplier<Instant> DEFAULT_EXPIRATION_SUPPLIER = () -> ZonedDateTime.now().plusMinutes(6).toInstant();
// Leave a 5 minute buffer to allow for clock skew and GC pauses
private static final Function<Instant, Instant> TOKEN_REUSE_THRESHOLD = instant -> instant.minus(5, MINUTES);

private static final String TRINO_INTERNAL_BEARER = "X-Trino-Internal-Bearer";

private final SecretKey hmac;
private final String nodeId;
private final JwtParser jwtParser;
private final AtomicReference<InternalToken> currentToken;

@Inject
public InternalAuthenticationManager(InternalCommunicationConfig internalCommunicationConfig, SecurityConfig securityConfig, NodeInfo nodeInfo)
Expand Down Expand Up @@ -84,6 +93,7 @@ public InternalAuthenticationManager(String sharedSecret, String nodeId)
this.hmac = hmacShaKeyFor(Hashing.sha256().hashString(sharedSecret, UTF_8).asBytes());
this.nodeId = nodeId;
this.jwtParser = newJwtParserBuilder().verifyWith(hmac).build();
this.currentToken = new AtomicReference<>(createJwt());
}

public static boolean isInternalRequest(ContainerRequestContext request)
Expand Down Expand Up @@ -118,17 +128,34 @@ public void handleInternalRequest(ContainerRequestContext request)
public Request filterRequest(Request request)
{
return fromRequest(request)
.addHeader(TRINO_INTERNAL_BEARER, generateJwt())
.addHeader(TRINO_INTERNAL_BEARER, getOrGenerateJwt())
.build();
}

private String generateJwt()
private String getOrGenerateJwt()
{
return newJwtBuilder()
InternalToken token = currentToken.get();
if (token.isExpired()) {
InternalToken newToken = createJwt();
if (currentToken.compareAndSet(token, newToken)) {
token = newToken;
}
else {
// Another thread already generated a new token
token = currentToken.get();
}
}
return token.token();
}

private InternalToken createJwt()
{
Instant expiration = DEFAULT_EXPIRATION_SUPPLIER.get();
return new InternalToken(expiration, newJwtBuilder()
.signWith(hmac)
.subject(nodeId)
.expiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant()))
.compact();
.expiration(Date.from(expiration))
.compact());
}

private String parseJwt(String jwt)
Expand All @@ -138,4 +165,18 @@ private String parseJwt(String jwt)
.getPayload()
.getSubject();
}

private record InternalToken(Instant expiration, String token)
{
public InternalToken
{
expiration = TOKEN_REUSE_THRESHOLD.apply(requireNonNull(expiration, "expiration is null"));
requireNonNull(token, "token is null");
}

public boolean isExpired()
{
return Instant.now().isAfter(expiration);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.server;

import com.google.inject.Binder;
import com.google.inject.Scopes;
import com.google.inject.multibindings.Multibinder;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.airlift.discovery.client.ForDiscoveryClient;
Expand Down Expand Up @@ -52,7 +53,7 @@ protected void setup(Binder binder)
}
discoveryFilterBinder.addBinding().to(InternalAuthenticationManager.class);
configBinder(binder).bindConfigDefaults(HttpClientConfig.class, ForDiscoveryClient.class, config -> configureClient(config, internalCommunicationConfig));
binder.bind(InternalAuthenticationManager.class);
binder.bind(InternalAuthenticationManager.class).in(Scopes.SINGLETON);
}

private static class DiscoveryEncodeAddressAsHostname
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
*/
package io.trino.sql.analyzer;

import com.google.common.base.CharMatcher;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheLoader;
import com.google.common.collect.ImmutableList;
import io.trino.cache.EvictableCacheBuilder;
import io.trino.spi.TrinoException;
import io.trino.spi.type.NamedTypeSignature;
import io.trino.spi.type.RowFieldName;
Expand All @@ -36,15 +40,14 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH;
import static io.trino.spi.type.StandardTypes.INTERVAL_DAY_TO_SECOND;
import static io.trino.spi.type.StandardTypes.INTERVAL_YEAR_TO_MONTH;
import static io.trino.spi.type.StandardTypes.ROW;
Expand All @@ -58,15 +61,34 @@
import static io.trino.spi.type.TypeSignatureParameter.typeParameter;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.VarcharType.UNBOUNDED_LENGTH;
import static io.trino.sql.analyzer.SemanticExceptions.semanticException;
import static io.trino.type.IntervalDayTimeType.INTERVAL_DAY_TIME;
import static io.trino.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;

public final class TypeSignatureTranslator
{
private static final SqlParser SQL_PARSER = new SqlParser();

private static final Cache<String, DataType> DATA_TYPE_CACHE = EvictableCacheBuilder.newBuilder()
.maximumSize(4096)
.build(new CacheLoader<>() {
@Override
public DataType load(String signature)
{
return parseDataType(signature);
}
});

private static final CharMatcher IS_DIGIT = CharMatcher.inRange('0', '9')
.precomputed();

private static final CharMatcher IS_VALID_IDENTIFIER_CHAR = CharMatcher.inRange('a', 'z')
.or(CharMatcher.inRange('A', 'Z'))
.or(CharMatcher.is('_'))
.or(CharMatcher.inRange('0', '9'))
.precomputed();

private TypeSignatureTranslator() {}

public static DataType toSqlType(Type type)
Expand All @@ -93,7 +115,16 @@ public static TypeSignature parseTypeSignature(String signature, Set<String> typ
{
Set<String> variables = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
variables.addAll(typeVariables);
return toTypeSignature(SQL_PARSER.createType(signature), variables);
try {
return toTypeSignature(DATA_TYPE_CACHE.get(signature.toLowerCase(ENGLISH), () -> parseDataType(signature)), variables);
}
catch (Exception e) {
if (e.getCause() != null) {
throwIfUnchecked(e.getCause());
}
throwIfUnchecked(e);
throw new RuntimeException(e);
Copy link

Copilot AI Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrapping parseTypeSignature failures in a generic RuntimeException loses the original semanticException and TYPE_MISMATCH error code. Consider throwing a semanticException with the appropriate error code to preserve user-facing diagnostics.

Suggested change
throw new RuntimeException(e);
throw new TrinoException(TYPE_MISMATCH, e.getMessage(), e);

Copilot uses AI. Check for mistakes.
}
}

private static TypeSignature toTypeSignature(GenericDataType type, Set<String> typeVariables)
Expand All @@ -111,13 +142,7 @@ private static TypeSignature toTypeSignature(GenericDataType type, Set<String> t
for (DataTypeParameter parameter : type.getArguments()) {
switch (parameter) {
case NumericParameter numericParameter -> {
String value = numericParameter.getValue();
try {
parameters.add(numericParameter(Long.parseLong(value)));
}
catch (NumberFormatException e) {
throw semanticException(TYPE_MISMATCH, parameter, "Invalid type parameter: %s", value);
}
parameters.add(numericParameter(numericParameter.getParsedValue()));
}
case TypeParameter typeParameter -> {
DataType value = typeParameter.getValue();
Expand Down Expand Up @@ -182,7 +207,7 @@ private static List<TypeSignatureParameter> translateParameters(DateTimeDataType
if (type.getPrecision().isPresent()) {
DataTypeParameter precision = type.getPrecision().get();
if (precision instanceof NumericParameter numericParameter) {
parameters.add(numericParameter(Long.parseLong(numericParameter.getValue())));
parameters.add(numericParameter(numericParameter.getParsedValue()));
}
else if (precision instanceof TypeParameter typeParameter) {
DataType typeVariable = typeParameter.getValue();
Expand All @@ -203,7 +228,7 @@ private static String canonicalize(Identifier identifier)
return identifier.getValue();
}

return identifier.getValue().toLowerCase(Locale.ENGLISH); // TODO: make this toUpperCase to match standard SQL semantics
return identifier.getValue().toLowerCase(ENGLISH); // TODO: make this toUpperCase to match standard SQL semantics
}

@VisibleForTesting
Expand Down Expand Up @@ -253,7 +278,7 @@ static DataType toDataType(TypeSignature typeSignature)
new Identifier(typeSignature.getBase(), false),
typeSignature.getParameters().stream()
.filter(parameter -> parameter.getLongLiteral() != UNBOUNDED_LENGTH)
.map(parameter -> new NumericParameter(Optional.empty(), String.valueOf(parameter)))
.map(parameter -> new NumericParameter(Optional.empty(), parameter.toString()))
.collect(toImmutableList()));
default -> new GenericDataType(
Optional.empty(),
Expand All @@ -266,19 +291,35 @@ static DataType toDataType(TypeSignature typeSignature)

private static boolean requiresDelimiting(String identifier)
{
if (!identifier.matches("[a-zA-Z][a-zA-Z0-9_]*")) {
if (!isValidIdentifier(identifier)) {
return true;
}

return ReservedIdentifiers.reserved(identifier);
}

private static boolean isValidIdentifier(String identifier)
{
if (IS_DIGIT.matches(identifier.charAt(0))) {
return false;
}

// We've already checked that first char does not contain digits,
// so to avoid copying we are checking whole string.
return IS_VALID_IDENTIFIER_CHAR.matchesAllOf(identifier);
}

private static DataTypeParameter toTypeParameter(TypeSignatureParameter parameter)
{
return switch (parameter.getKind()) {
case LONG -> new NumericParameter(Optional.empty(), String.valueOf(parameter.getLongLiteral()));
case LONG -> new NumericParameter(Optional.empty(), parameter.toString());
case TYPE -> new TypeParameter(toDataType(parameter.getTypeSignature()));
default -> throw new UnsupportedOperationException("Unsupported parameter kind");
};
}

private static DataType parseDataType(String signature)
{
return SQL_PARSER.createType(signature);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ private boolean exploreNode(int group, Context context, Set<PlanNodeId> changedP
context.checkTimeoutNotExhausted();

done = true;
Iterator<Rule<?>> possiblyMatchingRules = ruleIndex.getCandidates(node).iterator();
while (possiblyMatchingRules.hasNext()) {
Rule<?> rule = possiblyMatchingRules.next();
for (Rule<?> rule : ruleIndex.getCandidates(node)) {
long timeStart = nanoTime();
long timeEnd;
boolean invoked = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,51 @@

package io.trino.sql.planner.iterative;

import com.google.common.cache.Cache;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.reflect.TypeToken;
import io.trino.cache.EvictableCacheBuilder;
import io.trino.matching.Pattern;
import io.trino.matching.pattern.TypeOfPattern;

import java.util.Set;
import java.util.stream.Stream;
import java.util.concurrent.ExecutionException;

public class RuleIndex
{
private final ListMultimap<Class<?>, Rule<?>> rulesByRootType;
private final Cache<Class<?>, Set<Rule<?>>> rulesByClass;

private RuleIndex(ListMultimap<Class<?>, Rule<?>> rulesByRootType)
{
this.rulesByRootType = ImmutableListMultimap.copyOf(rulesByRootType);
this.rulesByClass = EvictableCacheBuilder.newBuilder()
.maximumSize(128) // we have a limited number of node types, so this is more than enough
.build();
}

public Stream<Rule<?>> getCandidates(Object object)
public Set<Rule<?>> getCandidates(Object object)
{
return supertypes(object.getClass())
.flatMap(clazz -> rulesByRootType.get(clazz).stream());
try {
return rulesByClass.get(object.getClass(), () -> computeCandidates(object.getClass()));
}
catch (ExecutionException e) {
throw new RuntimeException(e);
}
}

private static Stream<Class<?>> supertypes(Class<?> type)
public Set<Rule<?>> computeCandidates(Class<?> key)
{
return TypeToken.of(type).getTypes().stream()
.map(TypeToken::getRawType);
ImmutableSet.Builder<Rule<?>> builder = ImmutableSet.builder();
TypeToken.of(key).getTypes().forEach(clazz -> {
Class<?> rawType = clazz.getRawType();
if (rulesByRootType.containsKey(rawType)) {
builder.addAll(rulesByRootType.get(rawType));
}
});
return builder.build();
}

public static Builder builder()
Expand All @@ -62,12 +79,10 @@ public Builder register(Set<Rule<?>> rules)
public Builder register(Rule<?> rule)
{
Pattern<?> pattern = getFirstPattern(rule.getPattern());
if (pattern instanceof TypeOfPattern) {
rulesByRootType.put(((TypeOfPattern<?>) pattern).expectedClass(), rule);
}
else {
if (!(pattern instanceof TypeOfPattern<?> typeOfPattern)) {
throw new IllegalArgumentException("Unexpected Pattern: " + pattern);
}
rulesByRootType.put(typeOfPattern.expectedClass(), rule);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,11 @@ public Result apply(ValuesNode valuesNode, Captures captures, Context context)
Expression rewritten;
if (row instanceof Row value) {
// preserve the structure of row
rewritten = new Row(value.items().stream()
.map(item -> rewriter.rewrite(item, context))
.collect(toImmutableList()));
ImmutableList.Builder<Expression> rowValues = ImmutableList.builderWithExpectedSize(value.items().size());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

original looks nicer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah but resizes the list instead of preallocating it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returns a new builder, expecting the specified number of elements to be added.
If expectedSize is exactly the number of elements added to the builder before ImmutableList.Builder.build is called, the builder is likely to perform better than an unsized builder() would have.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the toImmutableList() is:

  private static final Collector<Object, ?, ImmutableList<Object>> TO_IMMUTABLE_LIST =
      Collector.of(
          ImmutableList::builder,
          ImmutableList.Builder::add,
          ImmutableList.Builder::combine,
          ImmutableList.Builder::build);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nicer != more performant :P

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meh - I doubt it matters at all, so I would keep nicer. I will not fight though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we observed a measurable performance improvement? If not, then readability is more important.

for (Expression item : value.items()) {
rowValues.add(rewriter.rewrite(item, context));
}
rewritten = new Row(rowValues.build());
}
else {
rewritten = rewriter.rewrite(row, context);
Expand Down
Loading
Loading