Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import java.nio.charset.StandardCharsets;

import static io.airlift.slice.SliceUtf8.lengthOfCodePointFromStartByte;
import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.lang.Math.toIntExact;
Expand All @@ -64,7 +65,7 @@ public static boolean regexpLike(@SqlType("varchar(x)") Slice source, @SqlType(J
offset = 0;
matcher = pattern.matcher(source.getBytes());
}
return matcher.search(offset, offset + source.length(), Option.DEFAULT) != -1;
return getSearchingOffset(matcher, offset, offset + source.length()) != -1;
}

private static int getNextStart(Slice source, Matcher matcher)
Expand Down Expand Up @@ -110,7 +111,7 @@ public static Slice regexpReplace(@SqlType("varchar(x)") Slice source, @SqlType(
int lastEnd = 0;
int nextStart = 0; // nextStart is the same as lastEnd, unless the last match was zero-width. In such case, nextStart is lastEnd + 1.
while (true) {
int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
int offset = getSearchingOffset(matcher, nextStart, source.length());
if (offset == -1) {
break;
}
Expand Down Expand Up @@ -227,7 +228,7 @@ public static Block regexpExtractAll(@SqlType("varchar(x)") Slice source, @SqlTy

int nextStart = 0;
while (true) {
int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
int offset = getSearchingOffset(matcher, nextStart, source.length());
if (offset == -1) {
break;
}
Expand Down Expand Up @@ -267,7 +268,7 @@ public static Slice regexpExtract(@SqlType("varchar(x)") Slice source, @SqlType(
validateGroup(groupIndex, matcher.getEagerRegion());
int group = toIntExact(groupIndex);

int offset = matcher.search(0, source.length(), Option.DEFAULT);
int offset = getSearchingOffset(matcher, 0, source.length());
if (offset == -1) {
return null;
}
Expand All @@ -294,7 +295,7 @@ public static Block regexpSplit(@SqlType("varchar(x)") Slice source, @SqlType(Jo
int lastEnd = 0;
int nextStart = 0;
while (true) {
int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
int offset = getSearchingOffset(matcher, nextStart, source.length());
if (offset == -1) {
break;
}
Expand Down Expand Up @@ -368,8 +369,8 @@ public static long regexpPosition(
// subtract 1 because codePointCount starts from zero
int nextStart = SliceUtf8.offsetOfCodePoint(source, (int) start - 1);
while (true) {
int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
// Check whether offset is negative, offset is -1 if no pattern was found or -2 if process was interrupted
int offset = getSearchingOffset(matcher, nextStart, source.length());
// Check whether offset is negative, offset is -1 if no pattern was found
if (offset < 0) {
return -1;
}
Expand All @@ -395,9 +396,9 @@ public static long regexpCount(@SqlType("varchar(x)") Slice source, @SqlType(Jon
// Start from zero, implies the first byte
int nextStart = 0;
while (true) {
// mather.search returns `source.length` if `nextStart` equals `source.length - 1`.
// getSearchingOffset returns `source.length` if `nextStart` equals `source.length - 1`.
// It should return -1 if `nextStart` is greater than `source.length - 1`.
int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
int offset = getSearchingOffset(matcher, nextStart, source.length());
if (offset < 0) {
break;
}
Expand All @@ -408,4 +409,20 @@ public static long regexpCount(@SqlType("varchar(x)") Slice source, @SqlType(Jon

return count;
}

public static int getSearchingOffset(Matcher matcher, int at, int range)
{
try {
return matcher.searchInterruptible(at, range, Option.DEFAULT);
}
catch (InterruptedException interruptedException) {
// The JONI library is compliant with the InterruptedException contract. They reset the interrupted flag before throwing an exception.
// Since the InterruptedException is being caught the interrupt flag must either be recovered or the thread must be terminated.
// Since we are simply throwing a different exception, the interrupt flag must be recovered to propagate the interrupted status to the upper level code.
Comment thread
tangjiangling marked this conversation as resolved.
Thread.currentThread().interrupt();
throw new TrinoException(GENERIC_USER_ERROR, "" +
"Regular expression matching was interrupted, likely because it took too long. " +
"Regular expression in the worst case can have a catastrophic amount of backtracking and having exponential time complexity");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import com.google.common.collect.ImmutableList;
import io.airlift.joni.Matcher;
import io.airlift.joni.Option;
import io.airlift.joni.Region;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
Expand All @@ -33,6 +32,7 @@
import io.trino.type.JoniRegexpType;

import static io.airlift.slice.SliceUtf8.lengthOfCodePointFromStartByte;
import static io.trino.operator.scalar.JoniRegexpFunctions.getSearchingOffset;
import static io.trino.spi.type.VarcharType.VARCHAR;

@ScalarFunction("regexp_replace")
Expand All @@ -51,7 +51,7 @@ public Slice regexpReplace(
{
// If there is no match we can simply return the original source without doing copy.
Matcher matcher = pattern.matcher(source.getBytes());
if (matcher.search(0, source.length(), Option.DEFAULT) == -1) {
if (getSearchingOffset(matcher, 0, source.length()) == -1) {
return source;
}

Expand Down Expand Up @@ -110,8 +110,7 @@ public Slice regexpReplace(
}
output.appendBytes(replaced);
}
while (matcher.search(nextStart, source.length(), Option.DEFAULT) != -1);

while (getSearchingOffset(matcher, nextStart, source.length()) != -1);
// Append the last un-matched part
output.writeBytes(source, appendPosition, source.length() - appendPosition);
return output.slice();
Expand Down
19 changes: 18 additions & 1 deletion core/trino-main/src/main/java/io/trino/type/LikeFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static io.airlift.joni.constants.SyntaxProperties.OP_LINE_ANCHOR;
import static io.airlift.slice.SliceUtf8.getCodePointAt;
import static io.airlift.slice.SliceUtf8.lengthOfCodePoint;
import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.type.Chars.padSpaces;
import static io.trino.util.Failures.checkCondition;
Expand Down Expand Up @@ -85,7 +86,7 @@ public static boolean likeVarchar(@SqlType("varchar(x)") Slice value, @SqlType(L
offset = 0;
matcher = pattern.matcher(value.getBytes());
}
return matcher.match(offset, offset + value.length(), Option.NONE) != -1;
return getMatchingOffset(matcher, offset, offset + value.length()) != -1;
}

@ScalarFunction(value = LIKE_PATTERN_FUNCTION_NAME, hidden = true)
Expand Down Expand Up @@ -245,4 +246,20 @@ private static char getEscapeChar(Slice escape)
}
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Escape string must be a single character");
}

private static int getMatchingOffset(Matcher matcher, int at, int range)
{
try {
return matcher.matchInterruptible(at, range, Option.NONE);
}
catch (InterruptedException interruptedException) {
// The JONI library is compliant with the InterruptedException contract. They reset the interrupted flag before throwing an exception.
// Since the InterruptedException is being caught the interrupt flag must either be recovered or the thread must be terminated.
// Since we are simply throwing a different exception, the interrupt flag must be recovered to propagate the interrupted status to the upper level code.
Thread.currentThread().interrupt();
throw new TrinoException(GENERIC_USER_ERROR, "" +
"Regular expression matching was interrupted, likely because it took too long. " +
"Regular expression in the worst case can have a catastrophic amount of backtracking and having exponential time complexity");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,22 @@
*/
package io.trino.operator.scalar;

import com.google.common.io.Resources;
import io.trino.spi.TrinoException;
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicReference;

import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp;
import static io.trino.operator.scalar.JoniRegexpFunctions.regexpReplace;
import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR;
import static io.trino.sql.analyzer.RegexLibrary.JONI;
import static io.trino.type.LikeFunctions.likeVarchar;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;

public class TestJoniRegexpFunctions
extends AbstractTestRegexpFunctions
Expand All @@ -22,4 +37,49 @@ public TestJoniRegexpFunctions()
{
super(JONI);
}

@Test
public void testMatchInterruptible()
throws IOException, InterruptedException
{
String source = Resources.toString(Resources.getResource("regularExpressionExtraLongSource.txt"), UTF_8);
String pattern = "\\((.*,)+(.*\\))";
// Test the interruptible version of `Matcher#match` by "LIKE"
testJoniRegexpFunctionsInterruptible(() -> likeVarchar(utf8Slice(source), joniRegexp(utf8Slice(pattern))));
}

@Test
public void testSearchInterruptible()
throws IOException, InterruptedException
{
String source = Resources.toString(Resources.getResource("regularExpressionExtraLongSource.txt"), UTF_8);
String pattern = "\\((.*,)+(.*\\))";
// Test the interruptible version of `Matcher#search` by "REGEXP_REPLACE"
testJoniRegexpFunctionsInterruptible(() -> regexpReplace(utf8Slice(source), joniRegexp(utf8Slice(pattern))));
}

private static void testJoniRegexpFunctionsInterruptible(Runnable joniRegexpRunnable)
throws InterruptedException
{
AtomicReference<TrinoException> trinoException = new AtomicReference<>();
Thread searchChildThread = new Thread(() -> {
try {
joniRegexpRunnable.run();
}
catch (TrinoException e) {
trinoException.compareAndSet(null, e);
}
});

searchChildThread.start();

// wait for the child thread to make some progress
searchChildThread.join(1000);
searchChildThread.interrupt();

// wait for child thread to get in to terminated state
searchChildThread.join();
assertNotNull(trinoException.get());
assertEquals(trinoException.get().getErrorCode(), GENERIC_USER_ERROR.toErrorCode());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql;
package io.trino.operator.scalar;

import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.operator.scalar.AbstractTestFunctions;
import io.trino.spi.TrinoException;
import io.trino.spi.expression.StandardFunctions;
import io.trino.type.JoniRegexp;
Expand Down

Large diffs are not rendered by default.