diff --git a/core/trino-main/src/main/java/io/trino/likematcher/DFA.java b/core/trino-main/src/main/java/io/trino/likematcher/DFA.java new file mode 100644 index 000000000000..79aaebc9ede6 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/likematcher/DFA.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.likematcher; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +record DFA(State start, State failed, List states, Map> transitions) +{ + DFA + { + requireNonNull(start, "start is null"); + requireNonNull(failed, "failed is null"); + states = ImmutableList.copyOf(states); + transitions = ImmutableMap.copyOf(transitions); + } + + public List transitions(State state) + { + return transitions.get(state.id); + } + + record State(int id, String label, boolean accept) + { + @Override + public String toString() + { + return "%s:%s%s".formatted( + id, + accept ? 
"*" : "", + label); + } + } + + record Transition(int value, State target) + { + @Override + public String toString() + { + return format("-[%s]-> %s", value, target); + } + } + + public static class Builder + { + private int nextId; + private State start; + private State failed; + private final List states = new ArrayList<>(); + private final Map> transitions = new HashMap<>(); + + public State addState(String label, boolean accept) + { + State state = new State(nextId++, label, accept); + states.add(state); + return state; + } + + public State addStartState(String label, boolean accept) + { + checkState(start == null, "Start state already set"); + State state = addState(label, accept); + start = state; + return state; + } + + public State addFailState() + { + checkState(failed == null, "Fail state already set"); + State state = addState("fail", false); + failed = state; + return state; + } + + public void addTransition(State from, int value, State to) + { + transitions.computeIfAbsent(from.id(), key -> new ArrayList<>()) + .add(new Transition(value, to)); + } + + public DFA build() + { + return new DFA(start, failed, states, transitions); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/likematcher/DenseDfaMatcher.java b/core/trino-main/src/main/java/io/trino/likematcher/DenseDfaMatcher.java new file mode 100644 index 000000000000..ab70e14e7896 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/likematcher/DenseDfaMatcher.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.likematcher; + +class DenseDfaMatcher +{ + // The DFA is encoded as a sequence of transitions for each possible byte value for each state. + // I.e., 256 transitions per state. + // The content of the transitions array is the base offset into + // the next state to follow. I.e., the desired state * 256 + private final int[] transitions; + + // The starting state + private final int start; + + // For each state, whether it's an accepting state + private final boolean[] accept; + + // Artificial state to sink all invalid matches + private final int fail; + + private final boolean exact; + + /** + * @param exact whether to match to the end of the input + */ + public static DenseDfaMatcher newInstance(DFA dfa, boolean exact) + { + int[] transitions = new int[dfa.states().size() * 256]; + boolean[] accept = new boolean[dfa.states().size()]; + + for (DFA.State state : dfa.states()) { + for (DFA.Transition transition : dfa.transitions(state)) { + transitions[state.id() * 256 + transition.value()] = transition.target().id() * 256; + } + + if (state.accept()) { + accept[state.id()] = true; + } + } + + return new DenseDfaMatcher(transitions, dfa.start().id(), accept, 0, exact); + } + + private DenseDfaMatcher(int[] transitions, int start, boolean[] accept, int fail, boolean exact) + { + this.transitions = transitions; + this.start = start; + this.accept = accept; + this.fail = fail; + this.exact = exact; + } + + public boolean match(byte[] input, int offset, int length) + { + if (exact) { + return exactMatch(input, offset, length); + } + + return prefixMatch(input, offset, length); + } + + /** + * Returns a positive match when the final state after all input has been consumed is an accepting state + */ + private boolean exactMatch(byte[] input, int offset, int length) + { + int state = start << 8; + for (int i = offset; i < offset + length; i++) { 
+ byte inputByte = input[i]; + state = transitions[state | (inputByte & 0xFF)]; + + if (state == fail) { + return false; + } + } + + return accept[state >>> 8]; + } + + /** + * Returns a positive match as soon as the DFA reaches an accepting state, regardless of whether + * the whole input has been consumed + */ + private boolean prefixMatch(byte[] input, int offset, int length) + { + int state = start << 8; + for (int i = offset; i < offset + length; i++) { + byte inputByte = input[i]; + state = transitions[state | (inputByte & 0xFF)]; + + if (state == fail) { + return false; + } + + if (accept[state >>> 8]) { + return true; + } + } + + return accept[state >>> 8]; + } +} diff --git a/core/trino-main/src/main/java/io/trino/likematcher/LikeMatcher.java b/core/trino-main/src/main/java/io/trino/likematcher/LikeMatcher.java new file mode 100644 index 000000000000..560f04ddce8f --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/likematcher/LikeMatcher.java @@ -0,0 +1,386 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.likematcher; + +import io.trino.likematcher.Pattern.Any; +import io.trino.likematcher.Pattern.Literal; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.nio.charset.StandardCharsets.UTF_8; + +public class LikeMatcher +{ + private final String pattern; + private final Optional escape; + + private final int minSize; + private final OptionalInt maxSize; + private final byte[] prefix; + private final byte[] suffix; + private final Optional matcher; + + private LikeMatcher( + String pattern, + Optional escape, + int minSize, + OptionalInt maxSize, + byte[] prefix, + byte[] suffix, + Optional matcher) + { + this.pattern = pattern; + this.escape = escape; + this.minSize = minSize; + this.maxSize = maxSize; + this.prefix = prefix; + this.suffix = suffix; + this.matcher = matcher; + } + + public String getPattern() + { + return pattern; + } + + public Optional getEscape() + { + return escape; + } + + public static LikeMatcher compile(String pattern) + { + return compile(pattern, Optional.empty()); + } + + public static LikeMatcher compile(String pattern, Optional escape) + { + List parsed = parse(pattern, escape); + List optimized = optimize(parsed); + + // Calculate minimum and maximum size for candidate strings + // This is used for short-circuiting the match if the size of + // the input is outside those bounds + int minSize = 0; + int maxSize = 0; + boolean unbounded = false; + for (Pattern expression : optimized) { + if (expression instanceof Literal literal) { + int length = literal.value().getBytes(UTF_8).length; + minSize += length; + maxSize += length; + } + else if (expression instanceof Any any) { + int length = any.min(); + minSize += length; + maxSize += length * 4; // at most 4 bytes for a single UTF-8 codepoint + + unbounded = unbounded || any.unbounded(); + } + else { + throw new 
UnsupportedOperationException("Not supported: " + expression.getClass().getName()); + } + } + + // Calculate exact match prefix and suffix + // If the pattern starts and ends with a literal, we can perform a quick + // exact match to short-circuit DFA evaluation + byte[] prefix = new byte[0]; + byte[] suffix = new byte[0]; + List middle = new ArrayList<>(); + for (int i = 0; i < optimized.size(); i++) { + Pattern expression = optimized.get(i); + + if (i == 0) { + if (expression instanceof Literal literal) { + prefix = literal.value().getBytes(UTF_8); + continue; + } + } + else if (i == optimized.size() - 1) { + if (expression instanceof Literal literal) { + suffix = literal.value().getBytes(UTF_8); + continue; + } + } + + middle.add(expression); + } + + // If the pattern (after excluding constant prefix/suffixes) ends with an unbounded match (i.e., %) + // we can perform a non-exact match and end as soon as the DFA reaches an accept state -- there + // is no need to consume the remaining input + // This section determines whether the pattern is a candidate for non-exact match. + boolean exact = true; // whether to match to the end of the input + if (!middle.isEmpty()) { + // guaranteed to be Any because any Literal would've been turned into a suffix above + Any last = (Any) middle.get(middle.size() - 1); + if (last.unbounded()) { + exact = false; + + // Since the matcher will stop early, no need for an unbounded matcher (it produces a simpler DFA) + if (last.min() == 0) { + // We'd end up with an empty string match at the end, so just remove it + middle.remove(middle.size() - 1); + } + else { + middle.set(middle.size() - 1, new Any(last.min(), false)); + } + } + } + + Optional matcher = Optional.empty(); + if (!middle.isEmpty()) { + matcher = Optional.of(DenseDfaMatcher.newInstance(makeNfa(middle).toDfa(), exact)); + } + + return new LikeMatcher( + pattern, + escape, + minSize, + unbounded ? 
OptionalInt.empty() : OptionalInt.of(maxSize), + prefix, + suffix, + matcher); + } + + public boolean match(byte[] input) + { + return match(input, 0, input.length); + } + + public boolean match(byte[] input, int offset, int length) + { + if (length < minSize) { + return false; + } + + if (maxSize.isPresent() && length > maxSize.getAsInt()) { + return false; + } + + if (!startsWith(prefix, input, offset)) { + return false; + } + + if (!startsWith(suffix, input, offset + length - suffix.length)) { + return false; + } + + if (matcher.isPresent()) { + return matcher.get().match(input, offset + prefix.length, length - suffix.length - prefix.length); + } + + return true; + } + + private boolean startsWith(byte[] pattern, byte[] input, int offset) + { + for (int i = 0; i < pattern.length; i++) { + if (pattern[i] != input[offset + i]) { + return false; + } + } + + return true; + } + + private static List parse(String pattern, Optional escape) + { + List result = new ArrayList<>(); + + StringBuilder literal = new StringBuilder(); + boolean inEscape = false; + for (int i = 0; i < pattern.length(); i++) { + char character = pattern.charAt(i); + + if (inEscape) { + if (character != '%' && character != '_' && character != escape.get()) { + throw new IllegalArgumentException("Escape character must be followed by '%', '_' or the escape character itself"); + } + literal.append(character); + inEscape = false; + } + else if (escape.isPresent() && character == escape.get()) { + inEscape = true; + } + else if (character == '%' || character == '_') { + if (literal.length() != 0) { + result.add(new Literal(literal.toString())); + literal = new StringBuilder(); + } + + if (character == '%') { + result.add(new Any(0, true)); + } + else { + result.add(new Any(1, false)); + } + } + else { + literal.append(character); + } + } + + if (inEscape) { + throw new IllegalArgumentException("Escape character must be followed by '%', '_' or the escape character itself"); + } + + if (literal.length() 
!= 0) { + result.add(new Literal(literal.toString())); + } + + return result; + } + + private static List optimize(List pattern) + { + if (pattern.isEmpty()) { + return pattern; + } + + List result = new ArrayList<>(); + + int anyPatternStart = -1; + for (int i = 0; i < pattern.size(); i++) { + Pattern current = pattern.get(i); + + if (anyPatternStart == -1 && current instanceof Any) { + anyPatternStart = i; + } + else if (current instanceof Literal) { + if (anyPatternStart != -1) { + result.add(collapse(pattern, anyPatternStart, i)); + } + + result.add(current); + anyPatternStart = -1; + } + } + + if (anyPatternStart != -1) { + result.add(collapse(pattern, anyPatternStart, pattern.size())); + } + + return result; + } + + /** + * Collapses a sequence of consecutive Any items + */ + private static Any collapse(List pattern, int start, int end) + { + int min = 0; + boolean unbounded = false; + + for (int i = start; i < end; i++) { + Any any = (Any) pattern.get(i); + + min += any.min(); + unbounded = unbounded || any.unbounded(); + } + + return new Any(min, unbounded); + } + + private static NFA makeNfa(List pattern) + { + checkArgument(!pattern.isEmpty(), "pattern is empty"); + + NFA.Builder builder = new NFA.Builder(); + + NFA.State state = builder.addStartState(); + + for (Pattern item : pattern) { + if (item instanceof Literal literal) { + for (byte current : literal.value().getBytes(UTF_8)) { + state = matchByte(builder, state, current); + } + } + else if (item instanceof Any any) { + NFA.State previous; + int i = 0; + do { + previous = state; + state = matchSingleUtf8(builder, state); + i++; + } + while (i < any.min()); + + if (any.min() == 0) { + builder.addTransition(previous, new NFA.Epsilon(), state); + } + + if (any.unbounded()) { + builder.addTransition(state, new NFA.Epsilon(), previous); + } + } + else { + throw new UnsupportedOperationException("Not supported: " + item.getClass().getName()); + } + } + + builder.setAccept(state); + + return 
builder.build(); + } + + private static NFA.State matchByte(NFA.Builder builder, NFA.State state, byte value) + { + NFA.State next = builder.addState(); + builder.addTransition(state, new NFA.Value(value), next); + return next; + } + + private static NFA.State matchSingleUtf8(NFA.Builder builder, NFA.State start) + { + /* + Implements a state machine to recognize UTF-8 characters. + + 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + O ───────────► O ───────────► O ───────────► O ───────────► O + │ ▲ ▲ ▲ + ├─────────────────────────────┘ │ │ + │ 1110xxxx │ │ + │ │ │ + ├────────────────────────────────────────────┘ │ + │ 110xxxxx │ + │ │ + └───────────────────────────────────────────────────────────┘ + 0xxxxxxx + */ + + NFA.State next = builder.addState(); + + builder.addTransition(start, new NFA.Prefix(0, 1), next); + + NFA.State state1 = builder.addState(); + NFA.State state2 = builder.addState(); + NFA.State state3 = builder.addState(); + + builder.addTransition(start, new NFA.Prefix(0b11110, 5), state1); + builder.addTransition(start, new NFA.Prefix(0b1110, 4), state2); + builder.addTransition(start, new NFA.Prefix(0b110, 3), state3); + + builder.addTransition(state1, new NFA.Prefix(0b10, 2), state2); + builder.addTransition(state2, new NFA.Prefix(0b10, 2), state3); + builder.addTransition(state3, new NFA.Prefix(0b10, 2), next); + + return next; + } +} diff --git a/core/trino-main/src/main/java/io/trino/likematcher/NFA.java b/core/trino-main/src/main/java/io/trino/likematcher/NFA.java new file mode 100644 index 000000000000..70316f2eb79d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/likematcher/NFA.java @@ -0,0 +1,206 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.likematcher; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +record NFA(State start, State accept, List states, Map> transitions) +{ + NFA { + requireNonNull(start, "start is null"); + requireNonNull(accept, "accept is null"); + states = ImmutableList.copyOf(states); + transitions = ImmutableMap.copyOf(transitions); + } + + public DFA toDfa() + { + Map, DFA.State> activeStates = new HashMap<>(); + + DFA.Builder builder = new DFA.Builder(); + DFA.State failed = builder.addFailState(); + for (int i = 0; i < 256; i++) { + builder.addTransition(failed, i, failed); + } + + Set initial = transitiveClosure(Set.of(this.start)); + Queue> queue = new ArrayDeque<>(); + queue.add(initial); + + DFA.State dfaStartState = builder.addStartState(makeLabel(initial), initial.contains(accept)); + activeStates.put(initial, dfaStartState); + + Set> visited = new HashSet<>(); + while (!queue.isEmpty()) { + Set current = queue.poll(); + + if (!visited.add(current)) { + continue; + } + + // For each possible byte value... 
+ for (int byteValue = 0; byteValue < 256; byteValue++) { + Set next = new HashSet<>(); + for (NFA.State nfaState : current) { + for (Transition transition : transitions(nfaState)) { + Condition condition = transition.condition(); + State target = states.get(transition.target()); + + if (condition instanceof Value valueTransition && valueTransition.value() == (byte) byteValue) { + next.add(target); + } + else if (condition instanceof Prefix prefixTransition) { + if (byteValue >>> (8 - prefixTransition.bits()) == prefixTransition.prefix()) { + next.add(target); + } + } + } + } + + DFA.State from = activeStates.get(current); + DFA.State to = failed; + if (!next.isEmpty()) { + Set closure = transitiveClosure(next); + to = activeStates.computeIfAbsent(closure, nfaStates -> builder.addState(makeLabel(nfaStates), nfaStates.contains(accept))); + queue.add(closure); + } + builder.addTransition(from, byteValue, to); + } + } + + return builder.build(); + } + + private List transitions(State state) + { + return transitions.getOrDefault(state.id(), ImmutableList.of()); + } + + /** + * Traverse epsilon transitions to compute the reachable set of states + */ + private Set transitiveClosure(Set states) + { + Set result = new HashSet<>(); + + Queue queue = new ArrayDeque<>(states); + while (!queue.isEmpty()) { + State state = queue.poll(); + + if (result.contains(state)) { + continue; + } + + transitions(state).stream() + .filter(transition -> transition.condition() instanceof Epsilon) + .forEach(transition -> { + State target = this.states.get(transition.target()); + result.add(target); + queue.add(target); + }); + } + + result.addAll(states); + + return result; + } + + private String makeLabel(Set states) + { + return "{" + states.stream() + .map(NFA.State::id) + .map(Object::toString) + .sorted() + .collect(Collectors.joining(",")) + "}"; + } + + public static class Builder + { + private int nextId; + private State start; + private State accept; + private final List states = 
new ArrayList<>(); + private final Map> transitions = new HashMap<>(); + + public State addState() + { + State state = new State(nextId++); + states.add(state); + return state; + } + + public State addStartState() + { + checkState(start == null, "Start state is already set"); + start = addState(); + return start; + } + + public void setAccept(State state) + { + checkState(accept == null, "Accept state is already set"); + accept = state; + } + + public void addTransition(State from, Condition condition, State to) + { + transitions.computeIfAbsent(from.id(), key -> new ArrayList<>()) + .add(new Transition(to.id(), condition)); + } + + public NFA build() + { + return new NFA(start, accept, states, transitions); + } + } + + public record State(int id) + { + @Override + public String toString() + { + return "(" + id + ")"; + } + } + + record Transition(int target, Condition condition) {} + + sealed interface Condition + permits Epsilon, Value, Prefix + { + } + + record Epsilon() + implements Condition {} + + record Value(byte value) + implements Condition {} + + record Prefix(int prefix, int bits) + implements Condition {} +} diff --git a/core/trino-main/src/main/java/io/trino/likematcher/Pattern.java b/core/trino-main/src/main/java/io/trino/likematcher/Pattern.java new file mode 100644 index 000000000000..dbf6cea13cbd --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/likematcher/Pattern.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.likematcher; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; + +sealed interface Pattern + permits Pattern.Any, Pattern.Literal +{ + record Literal(String value) + implements Pattern + { + @Override + public String toString() + { + return value; + } + } + + record Any(int min, boolean unbounded) + implements Pattern + { + public Any + { + checkArgument(min > 0 || unbounded, "Any must be unbounded or require at least 1 character"); + } + + @Override + public String toString() + { + return format("{%s%s}", min, unbounded ? "+" : ""); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java index 1598cdc545ca..7744bfe4d768 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java @@ -21,6 +21,7 @@ import io.airlift.slice.Slices; import io.trino.Session; import io.trino.execution.warnings.WarningCollector; +import io.trino.likematcher.LikeMatcher; import io.trino.metadata.FunctionNullability; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; @@ -94,7 +95,6 @@ import io.trino.sql.tree.SymbolReference; import io.trino.sql.tree.WhenClause; import io.trino.type.FunctionType; -import io.trino.type.JoniRegexp; import io.trino.type.LikeFunctions; import io.trino.type.TypeCoercion; import io.trino.util.FastutilSetHelper; @@ -170,7 +170,7 @@ public class ExpressionInterpreter private final TypeCoercion typeCoercion; // identity-based cache for LIKE expressions with constant pattern and escape char - private final IdentityHashMap likePatternCache = new IdentityHashMap<>(); + private final IdentityHashMap likePatternCache = new 
IdentityHashMap<>(); private final IdentityHashMap> inListCache = new IdentityHashMap<>(); public ExpressionInterpreter(Expression expression, PlannerContext plannerContext, Session session, Map, Type> expressionTypes) @@ -1153,15 +1153,15 @@ protected Object visitLikePredicate(LikePredicate node, Object context) if (value instanceof Slice && pattern instanceof Slice && (escape == null || escape instanceof Slice)) { - JoniRegexp regex; + LikeMatcher matcher; if (escape == null) { - regex = LikeFunctions.compileLikePattern((Slice) pattern); + matcher = LikeMatcher.compile(((Slice) pattern).toStringUtf8(), Optional.empty()); } else { - regex = LikeFunctions.likePattern((Slice) pattern, (Slice) escape); + matcher = LikeFunctions.likePattern((Slice) pattern, (Slice) escape); } - return evaluateLikePredicate(node, (Slice) value, regex); + return evaluateLikePredicate(node, (Slice) value, matcher); } // if pattern is a constant without % or _ replace with a comparison @@ -1205,20 +1205,20 @@ else if (valueType instanceof VarcharType) { optimizedEscape); } - private boolean evaluateLikePredicate(LikePredicate node, Slice value, JoniRegexp regex) + private boolean evaluateLikePredicate(LikePredicate node, Slice value, LikeMatcher matcher) { if (type(node.getValue()) instanceof VarcharType) { - return LikeFunctions.likeVarchar(value, regex); + return LikeFunctions.likeVarchar(value, matcher); } Type type = type(node.getValue()); checkState(type instanceof CharType, "LIKE value is neither VARCHAR or CHAR"); - return LikeFunctions.likeChar((long) ((CharType) type).getLength(), value, regex); + return LikeFunctions.likeChar((long) ((CharType) type).getLength(), value, matcher); } - private JoniRegexp getConstantPattern(LikePredicate node) + private LikeMatcher getConstantPattern(LikePredicate node) { - JoniRegexp result = likePatternCache.get(node); + LikeMatcher result = likePatternCache.get(node); if (result == null) { StringLiteral pattern = (StringLiteral) 
node.getPattern(); @@ -1228,7 +1228,7 @@ private JoniRegexp getConstantPattern(LikePredicate node) result = LikeFunctions.likePattern(Slices.utf8Slice(pattern.getValue()), escape); } else { - result = LikeFunctions.compileLikePattern(Slices.utf8Slice(pattern.getValue())); + result = LikeMatcher.compile(pattern.getValue(), Optional.empty()); } likePatternCache.put(node, result); diff --git a/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java b/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java index 8020104bf0db..3190e7eaedbe 100644 --- a/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java +++ b/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java @@ -13,14 +13,9 @@ */ package io.trino.type; -import io.airlift.jcodings.specific.NonStrictUTF8Encoding; -import io.airlift.joni.Matcher; -import io.airlift.joni.Option; -import io.airlift.joni.Regex; -import io.airlift.joni.Syntax; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; -import io.airlift.slice.Slices; +import io.trino.likematcher.LikeMatcher; import io.trino.spi.TrinoException; import io.trino.spi.function.LiteralParameter; import io.trino.spi.function.LiteralParameters; @@ -30,40 +25,22 @@ import java.util.Optional; -import static io.airlift.joni.constants.MetaChar.INEFFECTIVE_META_CHAR; -import static io.airlift.joni.constants.SyntaxProperties.OP_ASTERISK_ZERO_INF; -import static io.airlift.joni.constants.SyntaxProperties.OP_DOT_ANYCHAR; -import static io.airlift.joni.constants.SyntaxProperties.OP_LINE_ANCHOR; import static io.airlift.slice.SliceUtf8.getCodePointAt; import static io.airlift.slice.SliceUtf8.lengthOfCodePoint; -import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.type.Chars.padSpaces; import static io.trino.util.Failures.checkCondition; -import static java.nio.charset.StandardCharsets.UTF_8; public final class 
LikeFunctions { public static final String LIKE_PATTERN_FUNCTION_NAME = "$like_pattern"; - private static final Syntax SYNTAX = new Syntax( - OP_DOT_ANYCHAR | OP_ASTERISK_ZERO_INF | OP_LINE_ANCHOR, - 0, - 0, - Option.NONE, - new Syntax.MetaCharTable( - '\\', /* esc */ - INEFFECTIVE_META_CHAR, /* anychar '.' */ - INEFFECTIVE_META_CHAR, /* anytime '*' */ - INEFFECTIVE_META_CHAR, /* zero or one time '?' */ - INEFFECTIVE_META_CHAR, /* one or more time '+' */ - INEFFECTIVE_META_CHAR)); /* anychar anytime */ private LikeFunctions() {} @ScalarFunction(value = "like", hidden = true) @LiteralParameters("x") @SqlType(StandardTypes.BOOLEAN) - public static boolean likeChar(@LiteralParameter("x") Long x, @SqlType("char(x)") Slice value, @SqlType(LikePatternType.NAME) JoniRegexp pattern) + public static boolean likeChar(@LiteralParameter("x") Long x, @SqlType("char(x)") Slice value, @SqlType(LikePatternType.NAME) LikeMatcher pattern) { return likeVarchar(padSpaces(value, x.intValue()), pattern); } @@ -72,42 +49,35 @@ public static boolean likeChar(@LiteralParameter("x") Long x, @SqlType("char(x)" @ScalarFunction(value = "like", hidden = true) @LiteralParameters("x") @SqlType(StandardTypes.BOOLEAN) - public static boolean likeVarchar(@SqlType("varchar(x)") Slice value, @SqlType(LikePatternType.NAME) JoniRegexp pattern) + public static boolean likeVarchar(@SqlType("varchar(x)") Slice value, @SqlType(LikePatternType.NAME) LikeMatcher matcher) { - // Joni can infinite loop with UTF8Encoding when invalid UTF-8 is encountered. - // NonStrictUTF8Encoding must be used to avoid this issue. 
- Matcher matcher; - int offset; if (value.hasByteArray()) { - offset = value.byteArrayOffset(); - matcher = pattern.regex().matcher(value.byteArray(), offset, offset + value.length()); + return matcher.match(value.byteArray(), value.byteArrayOffset(), value.length()); } else { - offset = 0; - matcher = pattern.matcher(value.getBytes()); + return matcher.match(value.getBytes(), 0, value.length()); } - return getMatchingOffset(matcher, offset, offset + value.length()) != -1; } @ScalarFunction(value = LIKE_PATTERN_FUNCTION_NAME, hidden = true) @LiteralParameters("x") @SqlType(LikePatternType.NAME) - public static JoniRegexp likePattern(@SqlType("varchar(x)") Slice pattern) + public static LikeMatcher likePattern(@SqlType("varchar(x)") Slice pattern) { - return compileLikePattern(pattern); - } - - public static JoniRegexp compileLikePattern(Slice pattern) - { - return likePattern(pattern.toStringUtf8(), '0', false); + return LikeMatcher.compile(pattern.toStringUtf8(), Optional.empty()); } @ScalarFunction(value = LIKE_PATTERN_FUNCTION_NAME, hidden = true) @LiteralParameters({"x", "y"}) @SqlType(LikePatternType.NAME) - public static JoniRegexp likePattern(@SqlType("varchar(x)") Slice pattern, @SqlType("varchar(y)") Slice escape) + public static LikeMatcher likePattern(@SqlType("varchar(x)") Slice pattern, @SqlType("varchar(y)") Slice escape) { - return likePattern(pattern.toStringUtf8(), getEscapeChar(escape), true); + try { + return LikeMatcher.compile(pattern.toStringUtf8(), getEscapeChar(escape)); + } + catch (RuntimeException e) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, e); + } } public static boolean isLikePattern(Slice pattern, Optional escape) @@ -187,76 +157,16 @@ private static void checkEscape(boolean condition) checkCondition(condition, INVALID_FUNCTION_ARGUMENT, "Escape character must be followed by '%%', '_' or the escape character itself"); } - @SuppressWarnings("NestedSwitchStatement") - private static JoniRegexp likePattern(String 
patternString, char escapeChar, boolean shouldEscape) - { - StringBuilder regex = new StringBuilder(patternString.length() * 2); - - regex.append('^'); - boolean escaped = false; - for (char currentChar : patternString.toCharArray()) { - checkEscape(!escaped || currentChar == '%' || currentChar == '_' || currentChar == escapeChar); - if (shouldEscape && !escaped && (currentChar == escapeChar)) { - escaped = true; - } - else { - switch (currentChar) { - case '%': - regex.append(escaped ? "%" : ".*"); - escaped = false; - break; - case '_': - regex.append(escaped ? "_" : "."); - escaped = false; - break; - default: - // escape special regex characters - switch (currentChar) { - case '\\': - case '^': - case '$': - case '.': - case '*': - regex.append('\\'); - } - - regex.append(currentChar); - escaped = false; - } - } - } - checkEscape(!escaped); - regex.append('$'); - - byte[] bytes = regex.toString().getBytes(UTF_8); - Regex joniRegex = new Regex(bytes, 0, bytes.length, Option.MULTILINE, NonStrictUTF8Encoding.INSTANCE, SYNTAX); - return new JoniRegexp(Slices.wrappedBuffer(bytes), joniRegex); - } - - @SuppressWarnings("NumericCastThatLosesPrecision") - private static char getEscapeChar(Slice escape) + private static Optional getEscapeChar(Slice escape) { String escapeString = escape.toStringUtf8(); if (escapeString.isEmpty()) { // escaping disabled - return (char) -1; // invalid character + return Optional.empty(); // invalid character } if (escapeString.length() == 1) { - return escapeString.charAt(0); + return Optional.of(escapeString.charAt(0)); } throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Escape string must be a single character"); } - - private static int getMatchingOffset(Matcher matcher, int at, int range) - { - try { - return matcher.matchInterruptible(at, range, Option.NONE); - } - catch (InterruptedException interruptedException) { - Thread.currentThread().interrupt(); - throw new TrinoException(GENERIC_USER_ERROR, "" + - "Regular expression 
matching was interrupted, likely because it took too long. " + - "Regular expression in the worst case can have a catastrophic amount of backtracking and having exponential time complexity"); - } - } } diff --git a/core/trino-main/src/main/java/io/trino/type/LikePatternType.java b/core/trino-main/src/main/java/io/trino/type/LikePatternType.java index 63cc636a072c..5edce0406160 100644 --- a/core/trino-main/src/main/java/io/trino/type/LikePatternType.java +++ b/core/trino-main/src/main/java/io/trino/type/LikePatternType.java @@ -14,13 +14,17 @@ package io.trino.type; import io.airlift.slice.Slice; +import io.trino.likematcher.LikeMatcher; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; import io.trino.spi.type.TypeSignature; -import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp; +import java.util.Optional; + +import static io.airlift.slice.SizeOf.SIZE_OF_INT; +import static io.airlift.slice.Slices.utf8Slice; public class LikePatternType extends AbstractVariableWidthType @@ -30,7 +34,7 @@ public class LikePatternType private LikePatternType() { - super(new TypeSignature(NAME), JoniRegexp.class); + super(new TypeSignature(NAME), LikeMatcher.class); } @Override @@ -52,13 +56,40 @@ public Object getObject(Block block, int position) return null; } - return joniRegexp(block.getSlice(position, 0, block.getSliceLength(position))); + // layout is: <pattern length> <pattern> <hasEscape> <escape>?
+ int offset = 0; + int length = block.getInt(position, offset); + offset += SIZE_OF_INT; + String pattern = block.getSlice(position, offset, length).toStringUtf8(); + offset += length; + + boolean hasEscape = block.getByte(position, offset) != 0; + offset++; + + Optional escape = Optional.empty(); + if (hasEscape) { + escape = Optional.of((char) block.getInt(position, offset)); + } + + return LikeMatcher.compile(pattern, escape); } @Override public void writeObject(BlockBuilder blockBuilder, Object value) { - Slice pattern = ((JoniRegexp) value).pattern(); - blockBuilder.writeBytes(pattern, 0, pattern.length()).closeEntry(); + LikeMatcher matcher = (LikeMatcher) value; + + Slice pattern = utf8Slice(matcher.getPattern()); + int length = pattern.length(); + blockBuilder.writeInt(length); + blockBuilder.writeBytes(pattern, 0, length); + if (matcher.getEscape().isEmpty()) { + blockBuilder.writeByte(0); + } + else { + blockBuilder.writeByte(1); + blockBuilder.writeInt(matcher.getEscape().get()); + } + blockBuilder.closeEntry(); } } diff --git a/core/trino-main/src/test/java/io/trino/likematcher/TestLikeMatcher.java b/core/trino-main/src/test/java/io/trino/likematcher/TestLikeMatcher.java new file mode 100644 index 000000000000..539e77279258 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/likematcher/TestLikeMatcher.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.likematcher; + +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TestLikeMatcher +{ + @Test + public void test() + { + // min length short-circuit + assertFalse(match("__", "a")); + + // max length short-circuit + assertFalse(match("__", "abcdefghi")); + + // prefix short-circuit + assertFalse(match("a%", "xyz")); + + // prefix match + assertTrue(match("a%", "a")); + assertTrue(match("a%", "ab")); + assertTrue(match("a_", "ab")); + + // suffix short-circuit + assertFalse(match("%a", "xyz")); + + // suffix match + assertTrue(match("%z", "z")); + assertTrue(match("%z", "yz")); + assertTrue(match("_z", "yz")); + + // match literal + assertTrue(match("abcd", "abcd")); + + // match one + assertFalse(match("_", "")); + assertTrue(match("_", "a")); + assertFalse(match("_", "ab")); + + // match zero or more + assertTrue(match("%", "")); + assertTrue(match("%", "a")); + assertTrue(match("%", "ab")); + + // non-strict matching + assertTrue(match("_%", "abcdefg")); + assertFalse(match("_a%", "abcdefg")); + + // strict matching + assertTrue(match("_ab_", "xabc")); + assertFalse(match("_ab_", "xyxw")); + assertTrue(match("_a%b_", "xaxxxbx")); + + // optimization of consecutive _ and % + assertTrue(match("_%_%_%_%", "abcdefghij")); + + // utf-8 + LikeMatcher single = LikeMatcher.compile("_"); + LikeMatcher multiple = LikeMatcher.compile("_a%b_"); // prefix and suffix with _a and b_ to avoid optimizations + for (int i = 0; i < Character.MAX_CODE_POINT; i++) { + assertTrue(single.match(Character.toString(i).getBytes(StandardCharsets.UTF_8))); + + String value = "aa" + (char) i + "bb"; + assertTrue(multiple.match(value.getBytes(StandardCharsets.UTF_8))); + } + } + + @Test + public void testEscape() + { + 
assertTrue(match("-%", "%", '-')); + assertTrue(match("-_", "_", '-')); + assertTrue(match("--", "-", '-')); + } + + private static boolean match(String pattern, String value) + { + return match(pattern, value, Optional.empty()); + } + + private static boolean match(String pattern, String value, char escape) + { + return match(pattern, value, Optional.of(escape)); + } + + private static boolean match(String pattern, String value, Optional escape) + { + String padding = "++++"; + String padded = padding + value + padding; + byte[] bytes = padded.getBytes(StandardCharsets.UTF_8); + + boolean withoutPadding = LikeMatcher.compile(pattern, escape).match(value.getBytes(StandardCharsets.UTF_8)); + boolean withPadding = LikeMatcher.compile(pattern, escape).match(bytes, padding.length(), bytes.length - padding.length() * 2); // exclude padding + + assertEquals(withoutPadding, withPadding); + return withPadding; + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkLike.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkLike.java new file mode 100644 index 000000000000..778fe40f676c --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkLike.java @@ -0,0 +1,205 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator.scalar; + +import io.airlift.jcodings.specific.NonStrictUTF8Encoding; +import io.airlift.joni.Matcher; +import io.airlift.joni.Option; +import io.airlift.joni.Regex; +import io.airlift.joni.Syntax; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.likematcher.LikeMatcher; +import io.trino.type.JoniRegexp; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.results.format.ResultFormatType; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.util.Optional; + +import static io.airlift.joni.constants.MetaChar.INEFFECTIVE_META_CHAR; +import static io.airlift.joni.constants.SyntaxProperties.OP_ASTERISK_ZERO_INF; +import static io.airlift.joni.constants.SyntaxProperties.OP_DOT_ANYCHAR; +import static io.airlift.joni.constants.SyntaxProperties.OP_LINE_ANCHOR; +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.util.Failures.checkCondition; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.openjdk.jmh.annotations.Mode.AverageTime; +import static org.openjdk.jmh.annotations.Scope.Thread; + +@State(Thread) +@OutputTimeUnit(NANOSECONDS) +@BenchmarkMode(AverageTime) +@Fork(3) +@Warmup(iterations = 10, time = 500, timeUnit = MILLISECONDS) 
+@Measurement(iterations = 30, time = 500, timeUnit = MILLISECONDS) +public class BenchmarkLike +{ + private static final Syntax SYNTAX = new Syntax( + OP_DOT_ANYCHAR | OP_ASTERISK_ZERO_INF | OP_LINE_ANCHOR, + 0, + 0, + Option.NONE, + new Syntax.MetaCharTable( + '\\', /* esc */ + INEFFECTIVE_META_CHAR, /* anychar '.' */ + INEFFECTIVE_META_CHAR, /* anytime '*' */ + INEFFECTIVE_META_CHAR, /* zero or one time '?' */ + INEFFECTIVE_META_CHAR, /* one or more time '+' */ + INEFFECTIVE_META_CHAR)); /* anychar anytime */ + + @State(Thread) + public static class Data + { + @Param({ + "%", + "_%", + "%_", + "abc%", + "%abc", + "_____", + "abc%def%ghi", + "%abc%def%", + }) + private String pattern; + + private Slice data; + private byte[] bytes; + private JoniRegexp joniPattern; + private LikeMatcher matcher; + + @Setup + public void setup() + { + data = Slices.utf8Slice( + switch (pattern) { + case "%" -> "qeroighqeorhgqerhb2eriuyerqiubgierubgleuqrbgilquebriuqebryqebrhqerhqsnajkbcowuhet"; + case "_%", "%_" -> "qeroighqeorhgqerhb2eriuyerqiubgierubgleuqrbgilquebriuqebryqebrhqerhqsnajkbcowuhet"; + case "abc%" -> "abcqeroighqeorhgqerhb2eriuyerqiubgierubgleuqrbgilquebriuqebryqebrhqerhqsnajkbcowuhet"; + case "%abc" -> "qeroighqeorhgqerhb2eriuyerqiubgierubgleuqrbgilquebriuqebryqebrhqerhqsnajkbcowuhetabc"; + case "_____" -> "abcde"; + case "abc%def%ghi" -> "abc qeroighqeorhgqerhb2eriuyerqiubgier def ubgleuqrbgilquebriuqebryqebrhqerhqsnajkbcowuhet ghi"; + case "%abc%def%" -> "fdnbqerbfklerqbgqjerbgkr abc qeroighqeorhgqerhb2eriuyerqiubgier def ubgleuqrbgilquebriuqebryqebrhqerhqsnajkbcowuhet"; + default -> throw new IllegalArgumentException("Unknown pattern: " + pattern); + }); + + matcher = LikeMatcher.compile(pattern, Optional.empty()); + joniPattern = compileJoni(Slices.utf8Slice(pattern).toStringUtf8(), '0', false); + + bytes = data.getBytes(); + } + } + + @Benchmark + public boolean benchmarkJoni(Data data) + { + return likeVarchar(data.data, data.joniPattern); + } + + @Benchmark + 
public boolean benchmarkCurrent(Data data) + { + return data.matcher.match(data.bytes, 0, data.bytes.length); + } + + public static boolean likeVarchar(Slice value, JoniRegexp pattern) + { + Matcher matcher; + int offset; + if (value.hasByteArray()) { + offset = value.byteArrayOffset(); + matcher = pattern.regex().matcher(value.byteArray(), offset, offset + value.length()); + } + else { + offset = 0; + matcher = pattern.matcher(value.getBytes()); + } + return matcher.match(offset, offset + value.length(), Option.NONE) != -1; + } + + private static JoniRegexp compileJoni(String patternString, char escapeChar, boolean shouldEscape) + { + byte[] bytes = likeToRegex(patternString, escapeChar, shouldEscape).getBytes(UTF_8); + Regex joniRegex = new Regex(bytes, 0, bytes.length, Option.MULTILINE, NonStrictUTF8Encoding.INSTANCE, SYNTAX); + return new JoniRegexp(Slices.wrappedBuffer(bytes), joniRegex); + } + + private static String likeToRegex(String patternString, char escapeChar, boolean shouldEscape) + { + StringBuilder regex = new StringBuilder(patternString.length() * 2); + + regex.append('^'); + boolean escaped = false; + for (char currentChar : patternString.toCharArray()) { + checkEscape(!escaped || currentChar == '%' || currentChar == '_' || currentChar == escapeChar); + if (shouldEscape && !escaped && (currentChar == escapeChar)) { + escaped = true; + } + else { + switch (currentChar) { + case '%' -> { + regex.append(escaped ? "%" : ".*"); + escaped = false; + } + case '_' -> { + regex.append(escaped ? 
"_" : "."); + escaped = false; + } + default -> { + // escape special regex characters + switch (currentChar) { + case '\\', '^', '$', '.', '*' -> regex.append('\\'); + } + regex.append(currentChar); + escaped = false; + } + } + } + } + checkEscape(!escaped); + regex.append('$'); + return regex.toString(); + } + + private static void checkEscape(boolean condition) + { + checkCondition(condition, INVALID_FUNCTION_ARGUMENT, "Escape character must be followed by '%%', '_' or the escape character itself"); + } + + public static void main(String[] args) + throws RunnerException + { + Options options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkLike.class.getSimpleName() + ".*") + .resultFormat(ResultFormatType.JSON) + .build(); + + new Runner(options).run(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestJoniRegexpFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestJoniRegexpFunctions.java index aef185dd25a6..d1664e5532c4 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestJoniRegexpFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestJoniRegexpFunctions.java @@ -25,7 +25,6 @@ import static io.trino.operator.scalar.JoniRegexpFunctions.regexpReplace; import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; import static io.trino.sql.analyzer.RegexLibrary.JONI; -import static io.trino.type.LikeFunctions.likeVarchar; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -38,16 +37,6 @@ public TestJoniRegexpFunctions() super(JONI); } - @Test - public void testMatchInterruptible() - throws IOException, InterruptedException - { - String source = Resources.toString(Resources.getResource("regularExpressionExtraLongSource.txt"), UTF_8); - String pattern = "\\((.*,)+(.*\\))"; - // Test the interruptible version of `Matcher#match` by 
"LIKE" - testJoniRegexpFunctionsInterruptible(() -> likeVarchar(utf8Slice(source), joniRegexp(utf8Slice(pattern)))); - } - @Test public void testSearchInterruptible() throws IOException, InterruptedException diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestLikeFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestLikeFunctions.java index 61b25126516d..e2f7d39f4842 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestLikeFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestLikeFunctions.java @@ -15,9 +15,9 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.trino.likematcher.LikeMatcher; import io.trino.spi.TrinoException; import io.trino.spi.expression.StandardFunctions; -import io.trino.type.JoniRegexp; import io.trino.type.LikeFunctions; import org.testng.annotations.Test; @@ -59,9 +59,9 @@ public void testFunctionNameConstantsInSync() @Test public void testLikeBasic() { - JoniRegexp regex = LikeFunctions.compileLikePattern(utf8Slice("f%b__")); - assertTrue(likeVarchar(utf8Slice("foobar"), regex)); - assertTrue(likeVarchar(offsetHeapSlice("foobar"), regex)); + LikeMatcher matcher = LikeMatcher.compile(utf8Slice("f%b__").toStringUtf8(), Optional.empty()); + assertTrue(likeVarchar(utf8Slice("foobar"), matcher)); + assertTrue(likeVarchar(offsetHeapSlice("foobar"), matcher)); assertFunction("'foob' LIKE 'f%b__'", BOOLEAN, false); assertFunction("'foob' LIKE 'f%b'", BOOLEAN, true); @@ -81,13 +81,13 @@ public void testLikeBasic() @Test public void testLikeChar() { - JoniRegexp regex = LikeFunctions.compileLikePattern(utf8Slice("f%b__")); - assertTrue(likeChar(6L, utf8Slice("foobar"), regex)); - assertTrue(likeChar(6L, offsetHeapSlice("foobar"), regex)); - assertTrue(likeChar(6L, utf8Slice("foob"), regex)); - assertTrue(likeChar(6L, offsetHeapSlice("foob"), regex)); - assertFalse(likeChar(7L, utf8Slice("foob"), regex)); - assertFalse(likeChar(7L, 
offsetHeapSlice("foob"), regex)); + LikeMatcher matcher = LikeMatcher.compile(utf8Slice("f%b__").toStringUtf8(), Optional.empty()); + assertTrue(likeChar(6L, utf8Slice("foobar"), matcher)); + assertTrue(likeChar(6L, offsetHeapSlice("foobar"), matcher)); + assertTrue(likeChar(6L, utf8Slice("foob"), matcher)); + assertTrue(likeChar(6L, offsetHeapSlice("foob"), matcher)); + assertFalse(likeChar(7L, utf8Slice("foob"), matcher)); + assertFalse(likeChar(7L, offsetHeapSlice("foob"), matcher)); // pattern shorter than value length assertFunction("CAST('foo' AS char(6)) LIKE 'foo'", BOOLEAN, false); @@ -124,67 +124,66 @@ public void testLikeChar() @Test public void testLikeSpacesInPattern() { - JoniRegexp regex = LikeFunctions.compileLikePattern(utf8Slice("ala ")); - assertTrue(likeVarchar(utf8Slice("ala "), regex)); - assertFalse(likeVarchar(utf8Slice("ala"), regex)); + LikeMatcher matcher = LikeMatcher.compile(utf8Slice("ala ").toStringUtf8(), Optional.empty()); + assertTrue(likeVarchar(utf8Slice("ala "), matcher)); + assertFalse(likeVarchar(utf8Slice("ala"), matcher)); } @Test public void testLikeNewlineInPattern() { - JoniRegexp regex = LikeFunctions.compileLikePattern(utf8Slice("%o\nbar")); - assertTrue(likeVarchar(utf8Slice("foo\nbar"), regex)); + LikeMatcher matcher = LikeMatcher.compile(utf8Slice("%o\nbar").toStringUtf8(), Optional.empty()); + assertTrue(likeVarchar(utf8Slice("foo\nbar"), matcher)); } @Test public void testLikeNewlineBeforeMatch() { - JoniRegexp regex = LikeFunctions.compileLikePattern(utf8Slice("%b%")); - assertTrue(likeVarchar(utf8Slice("foo\nbar"), regex)); + LikeMatcher matcher = LikeMatcher.compile(utf8Slice("%b%").toStringUtf8(), Optional.empty()); + assertTrue(likeVarchar(utf8Slice("foo\nbar"), matcher)); } @Test public void testLikeNewlineInMatch() { - JoniRegexp regex = LikeFunctions.compileLikePattern(utf8Slice("f%b%")); - assertTrue(likeVarchar(utf8Slice("foo\nbar"), regex)); + LikeMatcher matcher = 
LikeMatcher.compile(utf8Slice("f%b%").toStringUtf8(), Optional.empty()); + assertTrue(likeVarchar(utf8Slice("foo\nbar"), matcher)); } - @Test(timeOut = 1000) + @Test public void testLikeUtf8Pattern() { - JoniRegexp regex = likePattern(utf8Slice("%\u540d\u8a89%"), utf8Slice("\\")); - assertFalse(likeVarchar(utf8Slice("foo"), regex)); + LikeMatcher matcher = likePattern(utf8Slice("%\u540d\u8a89%"), utf8Slice("\\")); + assertFalse(likeVarchar(utf8Slice("foo"), matcher)); } - @SuppressWarnings("NumericCastThatLosesPrecision") - @Test(timeOut = 1000) + @Test public void testLikeInvalidUtf8Value() { Slice value = Slices.wrappedBuffer(new byte[] {'a', 'b', 'c', (byte) 0xFF, 'x', 'y'}); - JoniRegexp regex = likePattern(utf8Slice("%b%"), utf8Slice("\\")); - assertTrue(likeVarchar(value, regex)); + LikeMatcher matcher = likePattern(utf8Slice("%b%"), utf8Slice("\\")); + assertTrue(likeVarchar(value, matcher)); } @Test public void testBackslashesNoSpecialTreatment() { - JoniRegexp regex = LikeFunctions.compileLikePattern(utf8Slice("\\abc\\/\\\\")); - assertTrue(likeVarchar(utf8Slice("\\abc\\/\\\\"), regex)); + LikeMatcher matcher = LikeMatcher.compile(utf8Slice("\\abc\\/\\\\").toStringUtf8(), Optional.empty()); + assertTrue(likeVarchar(utf8Slice("\\abc\\/\\\\"), matcher)); } @Test public void testSelfEscaping() { - JoniRegexp regex = likePattern(utf8Slice("\\\\abc\\%"), utf8Slice("\\")); - assertTrue(likeVarchar(utf8Slice("\\abc%"), regex)); + LikeMatcher matcher = likePattern(utf8Slice("\\\\abc\\%"), utf8Slice("\\")); + assertTrue(likeVarchar(utf8Slice("\\abc%"), matcher)); } @Test public void testAlternateEscapedCharacters() { - JoniRegexp regex = likePattern(utf8Slice("xxx%x_abcxx"), utf8Slice("x")); - assertTrue(likeVarchar(utf8Slice("x%_abcx"), regex)); + LikeMatcher matcher = likePattern(utf8Slice("xxx%x_abcxx"), utf8Slice("x")); + assertTrue(likeVarchar(utf8Slice("x%_abcx"), matcher)); } @Test diff --git 
a/core/trino-main/src/test/java/io/trino/sql/gen/TestExpressionCompiler.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestExpressionCompiler.java index 0290da73cbec..dd8b042be211 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestExpressionCompiler.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestExpressionCompiler.java @@ -24,6 +24,7 @@ import io.airlift.log.Logging; import io.airlift.slice.Slice; import io.airlift.units.Duration; +import io.trino.likematcher.LikeMatcher; import io.trino.operator.scalar.BitwiseFunctions; import io.trino.operator.scalar.FunctionAssertions; import io.trino.operator.scalar.JoniRegexpFunctions; @@ -49,7 +50,6 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.sql.tree.Extract.Field; -import io.trino.type.JoniRegexp; import io.trino.type.LikeFunctions; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -1590,8 +1590,8 @@ public void testLike() for (String pattern : stringLefts) { Boolean expected = null; if (value != null && pattern != null) { - JoniRegexp regex = LikeFunctions.likePattern(utf8Slice(pattern), utf8Slice("\\")); - expected = LikeFunctions.likeVarchar(utf8Slice(value), regex); + LikeMatcher matcher = LikeFunctions.likePattern(utf8Slice(pattern), utf8Slice("\\")); + expected = LikeFunctions.likeVarchar(utf8Slice(value), matcher); } assertExecute(generateExpression("%s like %s", value, pattern), BOOLEAN, expected); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java index 4aadd6e5dc2a..4ba428f1d08b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java @@ -70,6 +70,7 @@ import static io.trino.type.CodePointsType.CODE_POINTS; import static io.trino.type.JoniRegexpType.JONI_REGEXP; import static 
io.trino.type.JsonPathType.JSON_PATH; +import static io.trino.type.LikeFunctions.likePattern; import static io.trino.type.LikePatternType.LIKE_PATTERN; import static io.trino.type.Re2JRegexpType.RE2J_REGEXP_SIGNATURE; import static io.trino.type.UnknownType.UNKNOWN; @@ -239,11 +240,20 @@ public void testEncodeTimestampWithTimeZone() @Test public void testEncodeRegex() { - assertRoundTrip(castVarcharToJoniRegexp(utf8Slice("[a-z]")), LIKE_PATTERN, (left, right) -> left.pattern().equals(right.pattern())); assertRoundTrip(castVarcharToJoniRegexp(utf8Slice("[a-z]")), JONI_REGEXP, (left, right) -> left.pattern().equals(right.pattern())); assertRoundTrip(castVarcharToRe2JRegexp(utf8Slice("[a-z]")), PLANNER_CONTEXT.getTypeManager().getType(RE2J_REGEXP_SIGNATURE), (left, right) -> left.pattern().equals(right.pattern())); } + @Test + public void testEncodeLikePattern() + { + assertRoundTrip(likePattern(utf8Slice("abc")), LIKE_PATTERN, (left, right) -> left.getPattern().equals(right.getPattern())); + assertRoundTrip(likePattern(utf8Slice("abc_")), LIKE_PATTERN, (left, right) -> left.getPattern().equals(right.getPattern())); + assertRoundTrip(likePattern(utf8Slice("abc%")), LIKE_PATTERN, (left, right) -> left.getPattern().equals(right.getPattern())); + + assertRoundTrip(likePattern(utf8Slice("a_b%cX%X_"), utf8Slice("/")), LIKE_PATTERN, (left, right) -> left.getPattern().equals(right.getPattern())); + } + @Test public void testEncodeJsonPath() { diff --git a/core/trino-main/src/test/java/io/trino/type/TestLikePatternType.java b/core/trino-main/src/test/java/io/trino/type/TestLikePatternType.java new file mode 100644 index 000000000000..5516ac528932 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/type/TestLikePatternType.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.type; + +import io.trino.likematcher.LikeMatcher; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.PageBuilderStatus; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +import static io.trino.type.LikePatternType.LIKE_PATTERN; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestLikePatternType +{ + @Test + public void testGetObject() + { + BlockBuilder blockBuilder = LIKE_PATTERN.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 10); + LIKE_PATTERN.writeObject(blockBuilder, LikeMatcher.compile("helloX_world", Optional.of('X'))); + LIKE_PATTERN.writeObject(blockBuilder, LikeMatcher.compile("foo%_bar")); + Block block = blockBuilder.build(); + + LikeMatcher pattern = (LikeMatcher) LIKE_PATTERN.getObject(block, 0); + assertThat(pattern.getPattern()).isEqualTo("helloX_world"); + assertThat(pattern.getEscape()).isEqualTo(Optional.of('X')); + + pattern = (LikeMatcher) LIKE_PATTERN.getObject(block, 1); + assertThat(pattern.getPattern()).isEqualTo("foo%_bar"); + assertThat(pattern.getEscape()).isEqualTo(Optional.empty()); + } +}