Skip to content
71 changes: 20 additions & 51 deletions core/trino-main/src/main/java/io/trino/likematcher/DFA.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,23 @@
package io.trino.likematcher;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

record DFA(State start, State failed, List<State> states, Map<Integer, List<Transition>> transitions)
record DFA(int start, IntArrayList acceptStates, List<List<Transition>> transitions)
Comment thread
martint marked this conversation as resolved.
Outdated
{
/**
 * Compact canonical constructor.
 *
 * Rejects a null accept-state list and replaces the outer transition list with an
 * immutable snapshot. {@code start} is a primitive state id and needs no null check.
 */
DFA
{
    requireNonNull(acceptStates, "acceptStates is null");
    // Only the outer list is snapshotted; the per-state transition lists are stored as passed in
    transitions = ImmutableList.copyOf(transitions);
}

/**
 * Looks up the transitions stored under the id of the given state.
 */
public List<Transition> transitions(State state)
{
    int stateId = state.id();
    return transitions.get(stateId);
}

record State(int id, String label, boolean accept)
{
    /**
     * Renders the state as {@code id:label}; accepting states get a {@code *}
     * between the colon and the label.
     */
    @Override
    public String toString()
    {
        String acceptMarker = "";
        if (accept) {
            acceptMarker = "*";
        }
        return "%s:%s%s".formatted(id, acceptMarker, label);
    }
}

record Transition(int value, State target)
record Transition(int value, int target)
{
@Override
public String toString()
Expand All @@ -64,43 +42,34 @@ public String toString()
public static class Builder
{
private int nextId;
private State start;
private State failed;
private final List<State> states = new ArrayList<>();
private final Map<Integer, List<Transition>> transitions = new HashMap<>();
private int start;
private final IntArrayList acceptStates = new IntArrayList();
private final List<List<Transition>> transitions = new ArrayList<>();

public State addState(String label, boolean accept)
public int addState(boolean accept)
{
State state = new State(nextId++, label, accept);
states.add(state);
int state = nextId++;
transitions.add(new ArrayList<>());
if (accept) {
acceptStates.add(state);
}
return state;
}

public State addStartState(String label, boolean accept)
public int addStartState(boolean accept)
{
checkState(start == null, "Start state already set");
State state = addState(label, accept);
start = state;
return state;
}

public State addFailState()
{
checkState(failed == null, "Fail state already set");
State state = addState("fail", false);
failed = state;
return state;
start = addState(accept);
Comment thread
martint marked this conversation as resolved.
return start;
}

public void addTransition(State from, int value, State to)
public void addTransition(int from, int value, int to)
{
transitions.computeIfAbsent(from.id(), key -> new ArrayList<>())
.add(new Transition(value, to));
transitions.get(from).add(new Transition(value, to));
}

public DFA build()
{
return new DFA(start, failed, states, transitions);
return new DFA(start, acceptStates, transitions);
}
}
}
234 changes: 168 additions & 66 deletions core/trino-main/src/main/java/io/trino/likematcher/DenseDfaMatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,102 +13,204 @@
*/
package io.trino.likematcher;

import java.util.Arrays;
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

class DenseDfaMatcher
implements Matcher
{
// The DFA is encoded as a sequence of transitions for each possible byte value for each state.
// I.e., 256 transitions per state.
// The content of the transitions array is the base offset into
// the next state to follow. I.e., the desired state * 256
private final int[] transitions;
public static final int FAIL_STATE = -1;

// The starting state
private final List<Pattern> pattern;
private final int start;

// For each state, whether it's an accepting state
private final boolean[] accept;

// Artificial state to sink all invalid matches
private final int fail;

private final int end;
private final boolean exact;

/**
* @param exact whether to match to the end of the input
*/
public static DenseDfaMatcher newInstance(DFA dfa, boolean exact)
{
int[] transitions = new int[dfa.states().size() * 256];
boolean[] accept = new boolean[dfa.states().size()];

for (DFA.State state : dfa.states()) {
for (DFA.Transition transition : dfa.transitions(state)) {
transitions[state.id() * 256 + transition.value()] = transition.target().id() * 256;
}
private volatile DenseDfa matcher;

if (state.accept()) {
accept[state.id()] = true;
}
}

return new DenseDfaMatcher(transitions, dfa.start().id(), accept, 0, exact);
}

private DenseDfaMatcher(int[] transitions, int start, boolean[] accept, int fail, boolean exact)
public DenseDfaMatcher(List<Pattern> pattern, int start, int end, boolean exact)
{
this.transitions = transitions;
this.pattern = requireNonNull(pattern, "pattern is null");
this.start = start;
this.accept = accept;
this.fail = fail;
this.end = end;
this.exact = exact;
}

@Override
public boolean match(byte[] input, int offset, int length)
{
DenseDfa matcher = this.matcher;
if (matcher == null) {
matcher = DenseDfa.newInstance(pattern, start, end);
this.matcher = matcher;
}

if (exact) {
return exactMatch(input, offset, length);
return matcher.exactMatch(input, offset, length);
}

return prefixMatch(input, offset, length);
return matcher.prefixMatch(input, offset, length);
}

/**
* Returns a positive match when the final state after all input has been consumed is an accepting state
*/
private boolean exactMatch(byte[] input, int offset, int length)
private static class DenseDfa
{
int state = start << 8;
for (int i = offset; i < offset + length; i++) {
byte inputByte = input[i];
state = transitions[state | (inputByte & 0xFF)];
// The DFA is encoded as a sequence of transitions for each possible byte value for each state.
// I.e., 256 transitions per state.
// The content of the transitions array is the base offset into
// the next state to follow. I.e., the desired state * 256
private final int[] transitions;

// The starting state
private final int start;

// For each state, whether it's an accepting state
private final boolean[] accept;

if (state == fail) {
return false;
public static DenseDfa newInstance(List<Pattern> pattern, int start, int end)
{
DFA dfa = makeNfa(pattern, start, end).toDfa();

int[] transitions = new int[dfa.transitions().size() * 256];
Arrays.fill(transitions, FAIL_STATE);

for (int state = 0; state < dfa.transitions().size(); state++) {
for (DFA.Transition transition : dfa.transitions().get(state)) {
transitions[state * 256 + transition.value()] = transition.target() * 256;
}
}
boolean[] accept = new boolean[dfa.transitions().size()];
for (int state : dfa.acceptStates()) {
accept[state] = true;
}

return new DenseDfa(transitions, dfa.start(), accept);
}

return accept[state >>> 8];
}
private DenseDfa(int[] transitions, int start, boolean[] accept)
{
this.transitions = transitions;
this.start = start;
this.accept = accept;
}

/**
* Returns a positive match as soon as the DFA reaches an accepting state, regardless of whether
* the whole input has been consumed
*/
private boolean prefixMatch(byte[] input, int offset, int length)
{
int state = start << 8;
for (int i = offset; i < offset + length; i++) {
byte inputByte = input[i];
state = transitions[state | (inputByte & 0xFF)];
/**
* Returns a positive match when the final state after all input has been consumed is an accepting state
*/
public boolean exactMatch(byte[] input, int offset, int length)
{
int state = start << 8;
for (int i = offset; i < offset + length; i++) {
byte inputByte = input[i];
state = transitions[state | (inputByte & 0xFF)];

if (state == FAIL_STATE) {
return false;
}
}

if (state == fail) {
return false;
return accept[state >>> 8];
}

/**
* Returns a positive match as soon as the DFA reaches an accepting state, regardless of whether
* the whole input has been consumed
*/
public boolean prefixMatch(byte[] input, int offset, int length)
{
int state = start << 8;
for (int i = offset; i < offset + length; i++) {
byte inputByte = input[i];
state = transitions[state | (inputByte & 0xFF)];

if (state == FAIL_STATE) {
return false;
}

if (accept[state >>> 8]) {
return true;
}
}

if (accept[state >>> 8]) {
return true;
return accept[state >>> 8];
}

private static NFA makeNfa(List<Pattern> pattern, int start, int end)
{
checkArgument(!pattern.isEmpty(), "pattern is empty");

NFA.Builder builder = new NFA.Builder();

int state = builder.addStartState();

for (int e = start; e <= end; e++) {
Pattern item = pattern.get(e);
if (item instanceof Pattern.Literal literal) {
Comment thread
martint marked this conversation as resolved.
Outdated
for (byte current : literal.value().getBytes(UTF_8)) {
state = matchByte(builder, state, current);
}
}
else if (item instanceof Pattern.Any any) {
for (int i = 0; i < any.length(); i++) {
int next = builder.addState();
matchSingleUtf8(builder, state, next);
state = next;
}
}
else if (item instanceof Pattern.ZeroOrMore) {
matchSingleUtf8(builder, state, state);
}
else {
throw new UnsupportedOperationException("Not supported: " + item.getClass().getName());
}
}

builder.setAccept(state);

return builder.build();
}

private static int matchByte(NFA.Builder builder, int state, byte value)
{
int next = builder.addState();
builder.addTransition(state, new NFA.Value(value), next);
return next;
}

return accept[state >>> 8];
private static void matchSingleUtf8(NFA.Builder builder, int from, int to)
{
/*
Implements a state machine to recognize UTF-8 characters.

11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
O ───────────► O ───────────► O ───────────► O ───────────► O
│ ▲ ▲ ▲
├─────────────────────────────┘ │ │
│ 1110xxxx │ │
│ │ │
├────────────────────────────────────────────┘ │
│ 110xxxxx │
│ │
└───────────────────────────────────────────────────────────┘
0xxxxxxx
*/

builder.addTransition(from, new NFA.Prefix(0, 1), to);

int state1 = builder.addState();
int state2 = builder.addState();
int state3 = builder.addState();

builder.addTransition(from, new NFA.Prefix(0b11110, 5), state1);
builder.addTransition(from, new NFA.Prefix(0b1110, 4), state2);
builder.addTransition(from, new NFA.Prefix(0b110, 3), state3);

builder.addTransition(state1, new NFA.Prefix(0b10, 2), state2);
builder.addTransition(state2, new NFA.Prefix(0b10, 2), state3);
builder.addTransition(state3, new NFA.Prefix(0b10, 2), to);
}
}
}
Loading