diff --git a/core/trino-main/src/main/java/io/trino/likematcher/DFA.java b/core/trino-main/src/main/java/io/trino/likematcher/DFA.java index 79aaebc9ede6..2eec9c32dda8 100644 --- a/core/trino-main/src/main/java/io/trino/likematcher/DFA.java +++ b/core/trino-main/src/main/java/io/trino/likematcher/DFA.java @@ -14,45 +14,23 @@ package io.trino.likematcher; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import static com.google.common.base.Preconditions.checkState; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -record DFA(State start, State failed, List states, Map> transitions) +record DFA(int start, IntArrayList acceptStates, List> transitions) { DFA { - requireNonNull(start, "start is null"); - requireNonNull(failed, "failed is null"); - states = ImmutableList.copyOf(states); - transitions = ImmutableMap.copyOf(transitions); + requireNonNull(acceptStates, "acceptStates is null"); + transitions = ImmutableList.copyOf(transitions); } - public List transitions(State state) - { - return transitions.get(state.id); - } - - record State(int id, String label, boolean accept) - { - @Override - public String toString() - { - return "%s:%s%s".formatted( - id, - accept ? "*" : "", - label); - } - } - - record Transition(int value, State target) + record Transition(int value, int target) { @Override public String toString() @@ -64,43 +42,34 @@ public String toString() public static class Builder { private int nextId; - private State start; - private State failed; - private final List states = new ArrayList<>(); - private final Map> transitions = new HashMap<>(); + private int start; + private final IntArrayList acceptStates = new IntArrayList(); + private final List> transitions = new ArrayList<>(); - public State addState(String label, boolean accept) + public int addState(boolean accept) { - State state = new State(nextId++, label, accept); - states.add(state); + int state = nextId++; + transitions.add(new ArrayList<>()); + if (accept) { + acceptStates.add(state); + } return state; } - public State addStartState(String label, boolean accept) + public int addStartState(boolean accept) { - checkState(start == null, "Start state already set"); - State state = addState(label, accept); - start = state; - return state; - } - - public State addFailState() - { - checkState(failed == null, "Fail state already set"); - State state = addState("fail", false); - failed = state; - return state; + start = addState(accept); + return start; } - public void addTransition(State from, int value, State to) + public void addTransition(int from, int value, int to) { - transitions.computeIfAbsent(from.id(), key -> new ArrayList<>()) - .add(new Transition(value, to)); + transitions.get(from).add(new Transition(value, to)); } public DFA build() { - return new DFA(start, failed, states, transitions); + return new DFA(start, acceptStates, transitions); } } } diff --git a/core/trino-main/src/main/java/io/trino/likematcher/DenseDfaMatcher.java b/core/trino-main/src/main/java/io/trino/likematcher/DenseDfaMatcher.java index ab70e14e7896..d5aa193fbabf 100644 --- a/core/trino-main/src/main/java/io/trino/likematcher/DenseDfaMatcher.java +++ b/core/trino-main/src/main/java/io/trino/likematcher/DenseDfaMatcher.java @@ -13,102 +13,204 @@ */ package io.trino.likematcher; +import java.util.Arrays; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; + class DenseDfaMatcher + implements Matcher { - // The DFA is encoded as a sequence of transitions for each possible byte value for each state. - // I.e., 256 transitions per state. - // The content of the transitions array is the base offset into - // the next state to follow. I.e., the desired state * 256 - private final int[] transitions; + public static final int FAIL_STATE = -1; - // The starting state + private final List pattern; private final int start; - - // For each state, whether it's an accepting state - private final boolean[] accept; - - // Artificial state to sink all invalid matches - private final int fail; - + private final int end; private final boolean exact; - /** - * @param exact whether to match to the end of the input - */ - public static DenseDfaMatcher newInstance(DFA dfa, boolean exact) - { - int[] transitions = new int[dfa.states().size() * 256]; - boolean[] accept = new boolean[dfa.states().size()]; - - for (DFA.State state : dfa.states()) { - for (DFA.Transition transition : dfa.transitions(state)) { - transitions[state.id() * 256 + transition.value()] = transition.target().id() * 256; - } + private volatile DenseDfa matcher; - if (state.accept()) { - accept[state.id()] = true; - } - } - - return new DenseDfaMatcher(transitions, dfa.start().id(), accept, 0, exact); - } - - private DenseDfaMatcher(int[] transitions, int start, boolean[] accept, int fail, boolean exact) + public DenseDfaMatcher(List pattern, int start, int end, boolean exact) { - this.transitions = transitions; + this.pattern = requireNonNull(pattern, "pattern is null"); this.start = start; - this.accept = accept; - this.fail = fail; + this.end = end; this.exact = exact; } + @Override public boolean match(byte[] input, int offset, int length) { + DenseDfa matcher = this.matcher; + if (matcher == null) { + matcher = DenseDfa.newInstance(pattern, start, end); + this.matcher = matcher; + } + if (exact) { - return exactMatch(input, offset, length); + return matcher.exactMatch(input, offset, length); } - return prefixMatch(input, offset, length); + return matcher.prefixMatch(input, offset, length); } - /** - * Returns a positive match when the final state after all input has been consumed is an accepting state - */ - private boolean exactMatch(byte[] input, int offset, int length) + private static class DenseDfa { - int state = start << 8; - for (int i = offset; i < offset + length; i++) { - byte inputByte = input[i]; - state = transitions[state | (inputByte & 0xFF)]; + // The DFA is encoded as a sequence of transitions for each possible byte value for each state. + // I.e., 256 transitions per state. + // The content of the transitions array is the base offset into + // the next state to follow. I.e., the desired state * 256 + private final int[] transitions; + + // The starting state + private final int start; + + // For each state, whether it's an accepting state + private final boolean[] accept; - if (state == fail) { - return false; + public static DenseDfa newInstance(List pattern, int start, int end) + { + DFA dfa = makeNfa(pattern, start, end).toDfa(); + + int[] transitions = new int[dfa.transitions().size() * 256]; + Arrays.fill(transitions, FAIL_STATE); + + for (int state = 0; state < dfa.transitions().size(); state++) { + for (DFA.Transition transition : dfa.transitions().get(state)) { + transitions[state * 256 + transition.value()] = transition.target() * 256; + } + } + boolean[] accept = new boolean[dfa.transitions().size()]; + for (int state : dfa.acceptStates()) { + accept[state] = true; } + + return new DenseDfa(transitions, dfa.start(), accept); } - return accept[state >>> 8]; - } + private DenseDfa(int[] transitions, int start, boolean[] accept) + { + this.transitions = transitions; + this.start = start; + this.accept = accept; + } - /** - * Returns a positive match as soon as the DFA reaches an accepting state, regardless of whether - * the whole input has been consumed - */ - private boolean prefixMatch(byte[] input, int offset, int length) - { - int state = start << 8; - for (int i = offset; i < offset + length; i++) { - byte inputByte = input[i]; - state = transitions[state | (inputByte & 0xFF)]; + /** + * Returns a positive match when the final state after all input has been consumed is an accepting state + */ + public boolean exactMatch(byte[] input, int offset, int length) + { + int state = start << 8; + for (int i = offset; i < offset + length; i++) { + byte inputByte = input[i]; + state = transitions[state | (inputByte & 0xFF)]; + + if (state == FAIL_STATE) { + return false; + } + } - if (state == fail) { - return false; + return accept[state >>> 8]; + } + + /** + * Returns a positive match as soon as the DFA reaches an accepting state, regardless of whether + * the whole input has been consumed + */ + public boolean prefixMatch(byte[] input, int offset, int length) + { + int state = start << 8; + for (int i = offset; i < offset + length; i++) { + byte inputByte = input[i]; + state = transitions[state | (inputByte & 0xFF)]; + + if (state == FAIL_STATE) { + return false; + } + + if (accept[state >>> 8]) { + return true; + } } - if (accept[state >>> 8]) { - return true; + return accept[state >>> 8]; + } + + private static NFA makeNfa(List pattern, int start, int end) + { + checkArgument(!pattern.isEmpty(), "pattern is empty"); + + NFA.Builder builder = new NFA.Builder(); + + int state = builder.addStartState(); + + for (int e = start; e <= end; e++) { + Pattern item = pattern.get(e); + if (item instanceof Pattern.Literal literal) { + for (byte current : literal.value().getBytes(UTF_8)) { + state = matchByte(builder, state, current); + } + } + else if (item instanceof Pattern.Any any) { + for (int i = 0; i < any.length(); i++) { + int next = builder.addState(); + matchSingleUtf8(builder, state, next); + state = next; + } + } + else if (item instanceof Pattern.ZeroOrMore) { + matchSingleUtf8(builder, state, state); + } + else { + throw new UnsupportedOperationException("Not supported: " + item.getClass().getName()); + } } + + builder.setAccept(state); + + return builder.build(); + } + + private static int matchByte(NFA.Builder builder, int state, byte value) + { + int next = builder.addState(); + builder.addTransition(state, new NFA.Value(value), next); + return next; } - return accept[state >>> 8]; + private static void matchSingleUtf8(NFA.Builder builder, int from, int to) + { + /* + Implements a state machine to recognize UTF-8 characters. + + 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + O ───────────► O ───────────► O ───────────► O ───────────► O + │ ▲ ▲ ▲ + ├─────────────────────────────┘ │ │ + │ 1110xxxx │ │ + │ │ │ + ├────────────────────────────────────────────┘ │ + │ 110xxxxx │ + │ │ + └───────────────────────────────────────────────────────────┘ + 0xxxxxxx + */ + + builder.addTransition(from, new NFA.Prefix(0, 1), to); + + int state1 = builder.addState(); + int state2 = builder.addState(); + int state3 = builder.addState(); + + builder.addTransition(from, new NFA.Prefix(0b11110, 5), state1); + builder.addTransition(from, new NFA.Prefix(0b1110, 4), state2); + builder.addTransition(from, new NFA.Prefix(0b110, 3), state3); + + builder.addTransition(state1, new NFA.Prefix(0b10, 2), state2); + builder.addTransition(state2, new NFA.Prefix(0b10, 2), state3); + builder.addTransition(state3, new NFA.Prefix(0b10, 2), to); + } } } diff --git a/core/trino-main/src/main/java/io/trino/likematcher/LikeMatcher.java b/core/trino-main/src/main/java/io/trino/likematcher/LikeMatcher.java index 560f04ddce8f..575cb25f2d7b 100644 --- a/core/trino-main/src/main/java/io/trino/likematcher/LikeMatcher.java +++ b/core/trino-main/src/main/java/io/trino/likematcher/LikeMatcher.java @@ -21,7 +21,6 @@ import java.util.Optional; import java.util.OptionalInt; -import static com.google.common.base.Preconditions.checkArgument; import static java.nio.charset.StandardCharsets.UTF_8; public class LikeMatcher @@ -33,7 +32,7 @@ public class LikeMatcher private final OptionalInt maxSize; private final byte[] prefix; private final byte[] suffix; - private final Optional matcher; + private final Optional matcher; private LikeMatcher( String pattern, @@ -42,7 +41,7 @@ private LikeMatcher( OptionalInt maxSize, byte[] prefix, byte[] suffix, - Optional matcher) + Optional matcher) { this.pattern = pattern; this.escape = escape; @@ -65,13 +64,17 @@ public Optional getEscape() public static LikeMatcher compile(String pattern) { - return compile(pattern, Optional.empty()); + return compile(pattern, Optional.empty(), true); } public static LikeMatcher compile(String pattern, Optional escape) + { + return compile(pattern, escape, true); + } + + public static LikeMatcher compile(String pattern, Optional escape, boolean optimize) { List parsed = parse(pattern, escape); - List optimized = optimize(parsed); // Calculate minimum and maximum size for candidate strings // This is used for short-circuiting the match if the size of @@ -79,18 +82,19 @@ public static LikeMatcher compile(String pattern, Optional escape) int minSize = 0; int maxSize = 0; boolean unbounded = false; - for (Pattern expression : optimized) { + for (Pattern expression : parsed) { if (expression instanceof Literal literal) { int length = literal.value().getBytes(UTF_8).length; minSize += length; maxSize += length; } + else if (expression instanceof Pattern.ZeroOrMore) { + unbounded = true; + } else if (expression instanceof Any any) { - int length = any.min(); + int length = any.length(); minSize += length; maxSize += length * 4; // at most 4 bytes for a single UTF-8 codepoint - - unbounded = unbounded || any.unbounded(); } else { throw new UnsupportedOperationException("Not supported: " + expression.getClass().getName()); @@ -102,24 +106,17 @@ else if (expression instanceof Any any) { // exact match to short-circuit DFA evaluation byte[] prefix = new byte[0]; byte[] suffix = new byte[0]; - List middle = new ArrayList<>(); - for (int i = 0; i < optimized.size(); i++) { - Pattern expression = optimized.get(i); - - if (i == 0) { - if (expression instanceof Literal literal) { - prefix = literal.value().getBytes(UTF_8); - continue; - } - } - else if (i == optimized.size() - 1) { - if (expression instanceof Literal literal) { - suffix = literal.value().getBytes(UTF_8); - continue; - } - } - middle.add(expression); + int patternStart = 0; + int patternEnd = parsed.size() - 1; + if (parsed.size() > 0 && parsed.get(0) instanceof Literal literal) { + prefix = literal.value().getBytes(UTF_8); + patternStart++; + } + + if (parsed.size() > 1 && parsed.get(parsed.size() - 1) instanceof Literal literal) { + suffix = literal.value().getBytes(UTF_8); + patternEnd--; } // If the pattern (after excluding constant prefix/suffixes) ends with an unbounded match (i.e., %) @@ -127,26 +124,20 @@ else if (i == optimized.size() - 1) { // is no need to consume the remaining input // This section determines whether the pattern is a candidate for non-exact match. boolean exact = true; // whether to match to the end of the input - if (!middle.isEmpty()) { - // guaranteed to be Any because any Literal would've been turned into a suffix above - Any last = (Any) middle.get(middle.size() - 1); - if (last.unbounded()) { - exact = false; - - // Since the matcher will stop early, no need for an unbounded matcher (it produces a simpler DFA) - if (last.min() == 0) { - // We'd end up with an empty string match at the end, so just remove it - middle.remove(middle.size() - 1); - } - else { - middle.set(middle.size() - 1, new Any(last.min(), false)); - } - } + if (patternStart <= patternEnd && parsed.get(patternEnd) instanceof Pattern.ZeroOrMore) { + // guaranteed to be Any or ZeroOrMore because any Literal would've been turned into a suffix above + exact = false; + patternEnd--; } - Optional matcher = Optional.empty(); - if (!middle.isEmpty()) { - matcher = Optional.of(DenseDfaMatcher.newInstance(makeNfa(middle).toDfa(), exact)); + Optional matcher = Optional.empty(); + if (patternStart <= patternEnd) { + if (optimize) { + matcher = Optional.of(new DenseDfaMatcher(parsed, patternStart, patternEnd, exact)); + } + else { + matcher = Optional.of(new NfaMatcher(parsed, patternStart, patternEnd, exact)); + } } return new LikeMatcher( @@ -200,11 +191,13 @@ private boolean startsWith(byte[] pattern, byte[] input, int offset) return true; } - private static List parse(String pattern, Optional escape) + static List parse(String pattern, Optional escape) { List result = new ArrayList<>(); StringBuilder literal = new StringBuilder(); + int anyCount = 0; + boolean anyUnbounded = false; boolean inEscape = false; for (int i = 0; i < pattern.length(); i++) { char character = pattern.charAt(i); @@ -213,26 +206,47 @@ private static List parse(String pattern, Optional escape) if (character != '%' && character != '_' && character != escape.get()) { throw new IllegalArgumentException("Escape character must be followed by '%', '_' or the escape character itself"); } + literal.append(character); inEscape = false; } else if (escape.isPresent() && character == escape.get()) { inEscape = true; + + if (anyCount != 0) { + result.add(new Any(anyCount)); + anyCount = 0; + } + + if (anyUnbounded) { + result.add(new Pattern.ZeroOrMore()); + anyUnbounded = false; + } } else if (character == '%' || character == '_') { if (literal.length() != 0) { result.add(new Literal(literal.toString())); - literal = new StringBuilder(); + literal.setLength(0); } if (character == '%') { - result.add(new Any(0, true)); + anyUnbounded = true; } else { - result.add(new Any(1, false)); + anyCount++; } } else { + if (anyCount != 0) { + result.add(new Any(anyCount)); + anyCount = 0; + } + + if (anyUnbounded) { + result.add(new Pattern.ZeroOrMore()); + anyUnbounded = false; + } + literal.append(character); } } @@ -244,143 +258,16 @@ else if (character == '%' || character == '_') { if (literal.length() != 0) { result.add(new Literal(literal.toString())); } - - return result; - } - - private static List optimize(List pattern) - { - if (pattern.isEmpty()) { - return pattern; - } - - List result = new ArrayList<>(); - - int anyPatternStart = -1; - for (int i = 0; i < pattern.size(); i++) { - Pattern current = pattern.get(i); - - if (anyPatternStart == -1 && current instanceof Any) { - anyPatternStart = i; + else { + if (anyCount != 0) { + result.add(new Any(anyCount)); } - else if (current instanceof Literal) { - if (anyPatternStart != -1) { - result.add(collapse(pattern, anyPatternStart, i)); - } - result.add(current); - anyPatternStart = -1; + if (anyUnbounded) { + result.add(new Pattern.ZeroOrMore()); } } - if (anyPatternStart != -1) { - result.add(collapse(pattern, anyPatternStart, pattern.size())); - } - return result; } - - /** - * Collapses a sequence of consecutive Any items - */ - private static Any collapse(List pattern, int start, int end) - { - int min = 0; - boolean unbounded = false; - - for (int i = start; i < end; i++) { - Any any = (Any) pattern.get(i); - - min += any.min(); - unbounded = unbounded || any.unbounded(); - } - - return new Any(min, unbounded); - } - - private static NFA makeNfa(List pattern) - { - checkArgument(!pattern.isEmpty(), "pattern is empty"); - - NFA.Builder builder = new NFA.Builder(); - - NFA.State state = builder.addStartState(); - - for (Pattern item : pattern) { - if (item instanceof Literal literal) { - for (byte current : literal.value().getBytes(UTF_8)) { - state = matchByte(builder, state, current); - } - } - else if (item instanceof Any any) { - NFA.State previous; - int i = 0; - do { - previous = state; - state = matchSingleUtf8(builder, state); - i++; - } - while (i < any.min()); - - if (any.min() == 0) { - builder.addTransition(previous, new NFA.Epsilon(), state); - } - - if (any.unbounded()) { - builder.addTransition(state, new NFA.Epsilon(), previous); - } - } - else { - throw new UnsupportedOperationException("Not supported: " + item.getClass().getName()); - } - } - - builder.setAccept(state); - - return builder.build(); - } - - private static NFA.State matchByte(NFA.Builder builder, NFA.State state, byte value) - { - NFA.State next = builder.addState(); - builder.addTransition(state, new NFA.Value(value), next); - return next; - } - - private static NFA.State matchSingleUtf8(NFA.Builder builder, NFA.State start) - { - /* - Implements a state machine to recognize UTF-8 characters. - - 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - O ───────────► O ───────────► O ───────────► O ───────────► O - │ ▲ ▲ ▲ - ├─────────────────────────────┘ │ │ - │ 1110xxxx │ │ - │ │ │ - ├────────────────────────────────────────────┘ │ - │ 110xxxxx │ - │ │ - └───────────────────────────────────────────────────────────┘ - 0xxxxxxx - */ - - NFA.State next = builder.addState(); - - builder.addTransition(start, new NFA.Prefix(0, 1), next); - - NFA.State state1 = builder.addState(); - NFA.State state2 = builder.addState(); - NFA.State state3 = builder.addState(); - - builder.addTransition(start, new NFA.Prefix(0b11110, 5), state1); - builder.addTransition(start, new NFA.Prefix(0b1110, 4), state2); - builder.addTransition(start, new NFA.Prefix(0b110, 3), state3); - - builder.addTransition(state1, new NFA.Prefix(0b10, 2), state2); - builder.addTransition(state2, new NFA.Prefix(0b10, 2), state3); - builder.addTransition(state3, new NFA.Prefix(0b10, 2), next); - - return next; - } } diff --git a/core/trino-main/src/main/java/io/trino/likematcher/Matcher.java b/core/trino-main/src/main/java/io/trino/likematcher/Matcher.java new file mode 100644 index 000000000000..1ca657cad848 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/likematcher/Matcher.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.likematcher; + +public interface Matcher +{ + boolean match(byte[] input, int offset, int length); +} diff --git a/core/trino-main/src/main/java/io/trino/likematcher/NFA.java b/core/trino-main/src/main/java/io/trino/likematcher/NFA.java index 70316f2eb79d..f06e954a389a 100644 --- a/core/trino-main/src/main/java/io/trino/likematcher/NFA.java +++ b/core/trino-main/src/main/java/io/trino/likematcher/NFA.java @@ -13,8 +13,8 @@ */ package io.trino.likematcher; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import it.unimi.dsi.fastutil.ints.IntArraySet; +import it.unimi.dsi.fastutil.ints.IntSet; import java.util.ArrayDeque; import java.util.ArrayList; @@ -24,40 +24,39 @@ import java.util.Map; import java.util.Queue; import java.util.Set; -import java.util.stream.Collectors; -import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; -record NFA(State start, State accept, List states, Map> transitions) +final class NFA { - NFA { - requireNonNull(start, "start is null"); - requireNonNull(accept, "accept is null"); - states = ImmutableList.copyOf(states); - transitions = ImmutableMap.copyOf(transitions); + private final int start; + private final int accept; + private final List> transitions; + + private NFA(int start, int accept, List> transitions) + { + this.start = start; + this.accept = accept; + this.transitions = requireNonNull(transitions, "transitions is null"); } public DFA toDfa() { - Map, DFA.State> activeStates = new HashMap<>(); + Map activeStates = new HashMap<>(); DFA.Builder builder = new DFA.Builder(); - DFA.State failed = builder.addFailState(); - for (int i = 0; i < 256; i++) { - builder.addTransition(failed, i, failed); - } - Set initial = transitiveClosure(Set.of(this.start)); - Queue> queue = new ArrayDeque<>(); + IntSet initial = new IntArraySet(); + initial.add(start); + Queue queue = new ArrayDeque<>(); queue.add(initial); - DFA.State dfaStartState = builder.addStartState(makeLabel(initial), initial.contains(accept)); + int dfaStartState = builder.addStartState(initial.contains(accept)); activeStates.put(initial, dfaStartState); - Set> visited = new HashSet<>(); + Set visited = new HashSet<>(); while (!queue.isEmpty()) { - Set current = queue.poll(); + IntSet current = queue.poll(); if (!visited.add(current)) { continue; @@ -65,11 +64,11 @@ public DFA toDfa() // For each possible byte value... for (int byteValue = 0; byteValue < 256; byteValue++) { - Set next = new HashSet<>(); - for (NFA.State nfaState : current) { + IntSet next = new IntArraySet(); + for (int nfaState : current) { for (Transition transition : transitions(nfaState)) { Condition condition = transition.condition(); - State target = states.get(transition.target()); + int target = transition.target(); if (condition instanceof Value valueTransition && valueTransition.value() == (byte) byteValue) { next.add(target); @@ -82,122 +81,66 @@ else if (condition instanceof Prefix prefixTransition) { } } - DFA.State from = activeStates.get(current); - DFA.State to = failed; if (!next.isEmpty()) { - Set closure = transitiveClosure(next); - to = activeStates.computeIfAbsent(closure, nfaStates -> builder.addState(makeLabel(nfaStates), nfaStates.contains(accept))); - queue.add(closure); + int from = activeStates.get(current); + int to = activeStates.computeIfAbsent(next, nfaStates -> builder.addState(nfaStates.contains(accept))); + builder.addTransition(from, byteValue, to); + + queue.add(next); } - builder.addTransition(from, byteValue, to); } } return builder.build(); } - private List transitions(State state) - { - return transitions.getOrDefault(state.id(), ImmutableList.of()); - } - - /** - * Traverse epsilon transitions to compute the reachable set of states - */ - private Set transitiveClosure(Set states) - { - Set result = new HashSet<>(); - - Queue queue = new ArrayDeque<>(states); - while (!queue.isEmpty()) { - State state = queue.poll(); - - if (result.contains(state)) { - continue; - } - - transitions(state).stream() - .filter(transition -> transition.condition() instanceof Epsilon) - .forEach(transition -> { - State target = this.states.get(transition.target()); - result.add(target); - queue.add(target); - }); - } - - result.addAll(states); - - return result; - } - - private String makeLabel(Set states) + private List transitions(int state) { - return "{" + states.stream() - .map(NFA.State::id) - .map(Object::toString) - .sorted() - .collect(Collectors.joining(",")) + "}"; + return transitions.get(state); } public static class Builder { private int nextId; - private State start; - private State accept; - private final List states = new ArrayList<>(); - private final Map> transitions = new HashMap<>(); + private int start; + private int accept; + private final List> transitions = new ArrayList<>(); - public State addState() + public int addState() { - State state = new State(nextId++); - states.add(state); - return state; + transitions.add(new ArrayList<>()); + return nextId++; } - public State addStartState() + public int addStartState() { - checkState(start == null, "Start state is already set"); start = addState(); return start; } - public void setAccept(State state) + public void setAccept(int state) { - checkState(accept == null, "Accept state is already set"); accept = state; } - public void addTransition(State from, Condition condition, State to) + public void addTransition(int from, Condition condition, int to) { - transitions.computeIfAbsent(from.id(), key -> new ArrayList<>()) - .add(new Transition(to.id(), condition)); + transitions.get(from).add(new Transition(to, condition)); } public NFA build() { - return new NFA(start, accept, states, transitions); - } - } - - public record State(int id) - { - @Override - public String toString() - { - return "(" + id + ")"; + return new NFA(start, accept, transitions); } } record Transition(int target, Condition condition) {} sealed interface Condition - permits Epsilon, Value, Prefix + permits Value, Prefix { } - record Epsilon() - implements Condition {} - record Value(byte value) implements Condition {} diff --git a/core/trino-main/src/main/java/io/trino/likematcher/NfaMatcher.java b/core/trino-main/src/main/java/io/trino/likematcher/NfaMatcher.java new file mode 100644 index 000000000000..c5beff515924 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/likematcher/NfaMatcher.java @@ -0,0 +1,163 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.likematcher; + +import java.util.Arrays; +import java.util.List; + +final class NfaMatcher + implements Matcher +{ + private static final int ANY = -1; + private static final int NONE = -2; + private static final int INVALID_CODEPOINT = -1; + + private final boolean exact; + + private final boolean[] loopback; + private final int[] match; + private final int acceptState; + private final int stateCount; + + public NfaMatcher(List pattern, int start, int end, boolean exact) + { + this.exact = exact; + + stateCount = calculateStateCount(pattern, start, end); + + loopback = new boolean[stateCount]; + match = new int[stateCount]; + Arrays.fill(match, NONE); + acceptState = stateCount - 1; + + int state = 0; + for (int j = start; j <= end; j++) { + Pattern element = pattern.get(j); + if (element instanceof Pattern.Literal literal) { + for (int i = 0; i < literal.value().length(); i++) { + match[state++] = literal.value().charAt(i); + } + } + else if (element instanceof Pattern.Any any) { + for (int i = 0; i < any.length(); i++) { + match[state++] = ANY; + } + } + else if (element instanceof Pattern.ZeroOrMore) { + loopback[state] = true; + } + } + } + + private static int calculateStateCount(List pattern, int start, int end) + { + int states = 1; + for (int i = start; i <= end; i++) { + Pattern element = pattern.get(i); + if (element instanceof Pattern.Literal literal) { + states += literal.value().length(); + } + else if (element instanceof Pattern.Any any) { + states += any.length(); + } + } + return states; + } + + @Override + public boolean match(byte[] input, int offset, int length) + { + boolean[] seen = new boolean[stateCount + 1]; + int[] currentStates = new int[stateCount]; + int[] nextStates = new int[stateCount]; + int currentStatesIndex = 0; + int nextStatesIndex; + + currentStates[currentStatesIndex++] = 0; + + int limit = offset + length; + int current = offset; + boolean accept = false; + while (current < limit) { + int codepoint = INVALID_CODEPOINT; + + // decode the next UTF-8 codepoint + int header = input[current] & 0xFF; + if (header < 0x80) { + // normal ASCII + // 0xxx_xxxx + codepoint = header; + current++; + } + else if ((header & 0b1110_0000) == 0b1100_0000) { + // 110x_xxxx 10xx_xxxx + if (current + 1 < limit) { + codepoint = ((header & 0b0001_1111) << 6) | (input[current + 1] & 0b0011_1111); + current += 2; + } + } + else if ((header & 0b1111_0000) == 0b1110_0000) { + // 1110_xxxx 10xx_xxxx 10xx_xxxx + if (current + 2 < limit) { + codepoint = ((header & 0b0000_1111) << 12) | ((input[current + 1] & 0b0011_1111) << 6) | (input[current + 2] & 0b0011_1111); + current += 3; + } + } + else if ((header & 0b1111_1000) == 0b1111_0000) { + // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx + if (current + 3 < limit) { + codepoint = ((header & 0b0000_0111) << 18) | ((input[current + 1] & 0b0011_1111) << 12) | ((input[current + 2] & 0b0011_1111) << 6) | (input[current + 3] & 0b0011_1111); + current += 4; + } + } + + if (codepoint == INVALID_CODEPOINT) { + return false; + } + + accept = false; + nextStatesIndex = 0; + Arrays.fill(seen, false); + for (int i = 0; i < currentStatesIndex; i++) { + int state = currentStates[i]; + if (!seen[state] && loopback[state]) { + nextStates[nextStatesIndex++] = state; + accept |= state == acceptState; + seen[state] = true; + } + int next = state + 1; + if (!seen[next] && (match[state] == ANY || match[state] == codepoint)) { + nextStates[nextStatesIndex++] = next; + accept |= next == acceptState; + seen[next] = true; + } + } + + if (nextStatesIndex == 0) { + return false; + } + + if (!exact && accept) { + return true; + } + + int[] tmp = currentStates; + currentStates = nextStates; + nextStates = tmp; + currentStatesIndex = nextStatesIndex; + } + + return accept; + } +} diff --git a/core/trino-main/src/main/java/io/trino/likematcher/Pattern.java b/core/trino-main/src/main/java/io/trino/likematcher/Pattern.java index dbf6cea13cbd..cd9b6947539c 100644 --- a/core/trino-main/src/main/java/io/trino/likematcher/Pattern.java +++ b/core/trino-main/src/main/java/io/trino/likematcher/Pattern.java @@ -13,11 +13,12 @@ */ package io.trino.likematcher; +import com.google.common.base.Strings; + import static com.google.common.base.Preconditions.checkArgument; -import static java.lang.String.format; sealed interface Pattern - permits Pattern.Any, Pattern.Literal + permits Pattern.Any, Pattern.Literal, Pattern.ZeroOrMore { record Literal(String value) implements Pattern @@ -29,18 +30,28 @@ public String toString() } } - record Any(int min, boolean unbounded) + record ZeroOrMore() + implements Pattern + { + @Override + public String toString() + { + return "%"; + } + } + + record Any(int length) implements Pattern { public Any { - checkArgument(min > 0 || unbounded, "Any must be unbounded or require at least 1 character"); + checkArgument(length > 0, "Length must be > 0"); } @Override public String toString() { - return format("{%s%s}", min, unbounded ? "+" : ""); + return Strings.repeat("_", length); } } } diff --git a/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java b/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java index 96936c9b27d6..839bbcedc50d 100644 --- a/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java +++ b/core/trino-main/src/main/java/io/trino/type/LikeFunctions.java @@ -61,7 +61,7 @@ public static boolean likeVarchar(@SqlType("varchar") Slice value, @SqlType(Like @SqlType(LikePatternType.NAME) public static LikeMatcher likePattern(@SqlType("varchar") Slice pattern) { - return LikeMatcher.compile(pattern.toStringUtf8(), Optional.empty()); + return LikeMatcher.compile(pattern.toStringUtf8(), Optional.empty(), false); } @ScalarFunction(value = LIKE_PATTERN_FUNCTION_NAME, hidden = true) @@ -69,7 +69,7 @@ public static LikeMatcher likePattern(@SqlType("varchar") Slice pattern) public static LikeMatcher likePattern(@SqlType("varchar") Slice pattern, @SqlType("varchar") Slice escape) { try { - return LikeMatcher.compile(pattern.toStringUtf8(), getEscapeCharacter(Optional.of(escape))); + return LikeMatcher.compile(pattern.toStringUtf8(), getEscapeCharacter(Optional.of(escape)), false); } catch (RuntimeException e) { throw new TrinoException(INVALID_FUNCTION_ARGUMENT, e); diff --git a/core/trino-main/src/test/java/io/trino/likematcher/TestLikeMatcher.java b/core/trino-main/src/test/java/io/trino/likematcher/TestLikeMatcher.java index 539e77279258..a02b21aa0ac3 100644 --- a/core/trino-main/src/test/java/io/trino/likematcher/TestLikeMatcher.java +++ b/core/trino-main/src/test/java/io/trino/likematcher/TestLikeMatcher.java @@ -74,13 +74,23 @@ public void test() // optimization of consecutive _ and % assertTrue(match("_%_%_%_%", "abcdefghij")); + assertTrue(match("%a%a%a%a%a%a%", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); + assertTrue(match("%a%a%a%a%a%a%", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab")); + assertTrue(match("%a%b%a%b%a%b%", "aabbaabbaabbaabbaabbaabbaabbaabbaabbaabbaabbaabbaabbaabb")); + assertTrue(match("%aaaa%bbbb%aaaa%bbbb%aaaa%bbbb%", "aaaabbbbaaaabbbbaaaabbbb")); + assertTrue(match("%aaaaaaaaaaaaaaaaaaaaaaaaaa%", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); + // utf-8 - LikeMatcher single = LikeMatcher.compile("_"); - LikeMatcher multiple = LikeMatcher.compile("_a%b_"); // prefix and suffix with _a and b_ to avoid optimizations + LikeMatcher singleOptimized = LikeMatcher.compile("_", Optional.empty(), true); + LikeMatcher multipleOptimized = LikeMatcher.compile("_a%b_", Optional.empty(), true); // prefix and suffix with _a and b_ to avoid optimizations + LikeMatcher single = LikeMatcher.compile("_", Optional.empty(), false); + LikeMatcher multiple = LikeMatcher.compile("_a%b_", Optional.empty(), false); // prefix and suffix with _a and b_ to avoid optimizations for (int i = 0; i < Character.MAX_CODE_POINT; i++) { + assertTrue(singleOptimized.match(Character.toString(i).getBytes(StandardCharsets.UTF_8))); assertTrue(single.match(Character.toString(i).getBytes(StandardCharsets.UTF_8))); String value = "aa" + (char) i + "bb"; + assertTrue(multipleOptimized.match(value.getBytes(StandardCharsets.UTF_8))); assertTrue(multiple.match(value.getBytes(StandardCharsets.UTF_8))); } } @@ -91,6 +101,8 @@ public void testEscape() assertTrue(match("-%", "%", '-')); assertTrue(match("-_", "_", '-')); assertTrue(match("--", "-", '-')); + + assertTrue(match("%$_%", "xxxxx_xxxxx", '$')); } private static boolean match(String pattern, String value) @@ -109,10 +121,17 @@ private static boolean match(String pattern, String value, Optional e String padded = padding + value + padding; byte[] bytes = padded.getBytes(StandardCharsets.UTF_8); - boolean withoutPadding = LikeMatcher.compile(pattern, escape).match(value.getBytes(StandardCharsets.UTF_8)); - boolean withPadding = LikeMatcher.compile(pattern, escape).match(bytes, padding.length(), bytes.length - padding.length() * 2); // exclude padding + boolean optimizedWithoutPadding = LikeMatcher.compile(pattern, escape, true).match(value.getBytes(StandardCharsets.UTF_8)); + + boolean optimizedWithPadding = LikeMatcher.compile(pattern, escape, true).match(bytes, padding.length(), bytes.length - padding.length() * 2); // exclude padding + assertEquals(optimizedWithoutPadding, optimizedWithPadding); + + boolean withoutPadding = LikeMatcher.compile(pattern, escape, false).match(value.getBytes(StandardCharsets.UTF_8)); + assertEquals(optimizedWithoutPadding, withoutPadding); + + boolean withPadding = LikeMatcher.compile(pattern, escape, false).match(bytes, padding.length(), bytes.length - padding.length() * 2); // exclude padding + assertEquals(optimizedWithoutPadding, withPadding); - assertEquals(withoutPadding, withPadding); return withPadding; } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkLike.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkLike.java index 778fe40f676c..7aad0cfce2dd 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkLike.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkLike.java @@ -85,13 +85,16 @@ public static class Data "_____", "abc%def%ghi", "%abc%def%", + "%a%a%a%a%", + "%aaaaaaaaaaaaaaaaaaaaaaaaaa%" }) private String pattern; private Slice data; private byte[] bytes; private JoniRegexp joniPattern; - private LikeMatcher matcher; + private LikeMatcher dfaMatcher; + private LikeMatcher nfaMatcher; @Setup public void setup() @@ -105,10 +108,13 @@ public void setup() case "_____" -> "abcde"; case "abc%def%ghi" -> "abc qeroighqeorhgqerhb2eriuyerqiubgier def ubgleuqrbgilquebriuqebryqebrhqerhqsnajkbcowuhet ghi"; case "%abc%def%" -> "fdnbqerbfklerqbgqjerbgkr abc qeroighqeorhgqerhb2eriuyerqiubgier def ubgleuqrbgilquebriuqebryqebrhqerhqsnajkbcowuhet"; + case "%a%a%a%a%" -> "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; + case "%aaaaaaaaaaaaaaaaaaaaaaaaaa%" -> "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; default -> throw new IllegalArgumentException("Unknown pattern: " + pattern); }); - matcher = LikeMatcher.compile(pattern, Optional.empty()); + dfaMatcher = LikeMatcher.compile(pattern, Optional.empty(), true); + nfaMatcher = LikeMatcher.compile(pattern, Optional.empty(), false); joniPattern = compileJoni(Slices.utf8Slice(pattern).toStringUtf8(), '0', false); bytes = data.getBytes(); @@ -116,15 +122,59 @@ public void setup() } @Benchmark - public boolean benchmarkJoni(Data data) + public boolean matchJoni(Data data) { return likeVarchar(data.data, data.joniPattern); } @Benchmark - public boolean benchmarkCurrent(Data data) + public boolean matchDfa(Data data) { - return data.matcher.match(data.bytes, 0, data.bytes.length); + return data.dfaMatcher.match(data.bytes, 0, data.bytes.length); + } + + @Benchmark + public boolean matchNfa(Data data) + { + return data.nfaMatcher.match(data.bytes, 0, data.bytes.length); + } + + @Benchmark + public JoniRegexp compileJoni(Data data) + { + return compileJoni(data.pattern, (char) 0, false); + } + + @Benchmark + public LikeMatcher compileDfa(Data data) + { + return LikeMatcher.compile(data.pattern, Optional.empty(), true); + } + + @Benchmark + public LikeMatcher compileNfa(Data data) + { + return LikeMatcher.compile(data.pattern, Optional.empty(), false); + } + + @Benchmark + public boolean allJoni(Data data) + { + return likeVarchar(data.data, compileJoni(Slices.utf8Slice(data.pattern).toStringUtf8(), '0', false)); + } + + @Benchmark + public boolean allDfa(Data data) + { + return LikeMatcher.compile(data.pattern, Optional.empty(), true) + .match(data.bytes, 0, data.bytes.length); + } + + @Benchmark + public boolean allNfa(Data data) + { + return LikeMatcher.compile(data.pattern, Optional.empty(), false) + .match(data.bytes, 0, data.bytes.length); } public static boolean likeVarchar(Slice value, JoniRegexp pattern)