Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions core/src/main/java/org/opensearch/sql/ast/tree/Rex.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,52 @@
public class Rex extends UnresolvedPlan {

public enum RexMode {
EXTRACT
EXTRACT,
SED
}

/** Field to extract from. */
private final UnresolvedExpression field;

/** Pattern with named capture groups. */
/** Pattern with named capture groups or sed expression. */
private final Literal pattern;

/** Rex mode (only EXTRACT supported). */
/** Rex mode (extract or sed). */
private final RexMode mode;

/** Maximum number of matches (optional). */
private final Optional<Integer> maxMatch;

/** Offset field name for position tracking (optional). */
private final Optional<String> offsetField;

/** Child Plan. */
@Setter private UnresolvedPlan child;

public Rex(UnresolvedExpression field, Literal pattern) {
this(field, pattern, RexMode.EXTRACT, Optional.empty());
this(field, pattern, RexMode.EXTRACT, Optional.empty(), Optional.empty());
}

public Rex(UnresolvedExpression field, Literal pattern, Optional<Integer> maxMatch) {
this(field, pattern, RexMode.EXTRACT, maxMatch);
this(field, pattern, RexMode.EXTRACT, maxMatch, Optional.empty());
}

public Rex(
UnresolvedExpression field, Literal pattern, RexMode mode, Optional<Integer> maxMatch) {
this(field, pattern, mode, maxMatch, Optional.empty());
}

public Rex(
UnresolvedExpression field,
Literal pattern,
RexMode mode,
Optional<Integer> maxMatch,
Optional<String> offsetField) {
this.field = field;
this.pattern = pattern;
this.mode = mode;
this.maxMatch = maxMatch;
this.offsetField = offsetField;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@ public RelNode visitRex(Rex node, CalcitePlanContext context) {
RexNode fieldRex = rexVisitor.analyze(node.getField(), context);
String patternStr = (String) node.getPattern().getValue();

if (node.getMode() == Rex.RexMode.SED) {
RexNode sedCall = createOptimizedSedCall(fieldRex, patternStr, context);
String fieldName = node.getField().toString();
projectPlusOverriding(List.of(sedCall), List.of(fieldName), context);
return context.relBuilder.peek();
}

List<String> namedGroups = RegexCommonUtils.getNamedGroupCandidates(patternStr);

if (namedGroups.isEmpty()) {
Expand Down Expand Up @@ -251,6 +258,17 @@ public RelNode visitRex(Rex node, CalcitePlanContext context) {
newFieldNames.add(namedGroups.get(i));
}

if (node.getOffsetField().isPresent()) {
RexNode offsetCall =
PPLFuncImpTable.INSTANCE.resolve(
context.rexBuilder,
BuiltinFunctionName.REX_OFFSET,
fieldRex,
context.rexBuilder.makeLiteral(patternStr));
newFields.add(offsetCall);
newFieldNames.add(node.getOffsetField().get());
}

projectPlusOverriding(newFields, newFieldNames, context);
return context.relBuilder.peek();
}
Expand Down Expand Up @@ -2136,4 +2154,115 @@ private void buildExpandRelNode(
context.relBuilder.rename(names);
}
}

/** Creates an optimized sed call using native Calcite functions */
private RexNode createOptimizedSedCall(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np: If sed/regex is common across different commands, it may worth creating dedicated AST expression nodes and separate these into its own visit method in the future.

RexNode fieldRex, String sedExpression, CalcitePlanContext context) {
if (sedExpression.startsWith("s/")) {
return createOptimizedSubstitution(fieldRex, sedExpression, context);
} else if (sedExpression.startsWith("y/")) {
return createOptimizedTransliteration(fieldRex, sedExpression, context);
} else {
throw new RuntimeException("Unsupported sed pattern: " + sedExpression);
}
}

/**
 * Creates optimized substitution calls for s/pattern/replacement/flags syntax.
 *
 * <p>The expression is split positionally on the first two {@code /} characters that follow the
 * leading {@code s/}; everything after the second one is treated as the flags. The split is not
 * escape-aware, so an escaped delimiter inside the pattern or replacement is not recognised.
 *
 * <p>The flags select which replace function is resolved:
 *
 * <ul>
 *   <li>no flags: the 3-argument regexp replace
 *   <li>flags made up only of {@code g} and {@code i}: the 4-argument variant, flags passed through
 *   <li>digits only: the 5-argument variant with start position 1 and the number as occurrence
 * </ul>
 *
 * @param fieldRex expression for the field the substitution is applied to
 * @param sedExpression sed substitution expression
 * @param context current Calcite plan context
 * @return the resolved replace call
 * @throws RuntimeException if the expression is malformed, uses unsupported flags, or the call
 *     cannot be resolved; the underlying failure is attached as the cause
 */
private RexNode createOptimizedSubstitution(
    RexNode fieldRex, String sedExpression, CalcitePlanContext context) {
  try {
    // Parse sed substitution: s/pattern/replacement/flags
    if (!sedExpression.matches("s/.+/.*/.*")) {
      throw new IllegalArgumentException("Invalid sed substitution format");
    }

    // Find the delimiters - sed format is s/pattern/replacement/flags
    int firstDelimiter = sedExpression.indexOf('/', 2); // First '/' after 's/'
    int secondDelimiter = sedExpression.indexOf('/', firstDelimiter + 1); // Second '/'

    if (firstDelimiter == -1 || secondDelimiter == -1) {
      throw new IllegalArgumentException("Invalid sed substitution format");
    }

    String pattern = sedExpression.substring(2, firstDelimiter);
    String replacement = sedExpression.substring(firstDelimiter + 1, secondDelimiter);
    String flags =
        secondDelimiter + 1 < sedExpression.length()
            ? sedExpression.substring(secondDelimiter + 1)
            : "";

    // Convert sed backreferences (\1, \2) to Java style ($1, $2)
    String javaReplacement = replacement.replaceAll("\\\\(\\d+)", "\\$$1");

    if (flags.isEmpty()) {
      // 3-parameter REGEXP_REPLACE
      return PPLFuncImpTable.INSTANCE.resolve(
          context.rexBuilder,
          BuiltinFunctionName.INTERNAL_REGEXP_REPLACE_3,
          fieldRex,
          context.rexBuilder.makeLiteral(pattern),
          context.rexBuilder.makeLiteral(javaReplacement));
    } else if (flags.matches("[gi]+")) {
      // 4-parameter REGEXP_REPLACE with flags
      return PPLFuncImpTable.INSTANCE.resolve(
          context.rexBuilder,
          BuiltinFunctionName.INTERNAL_REGEXP_REPLACE_PG_4,
          fieldRex,
          context.rexBuilder.makeLiteral(pattern),
          context.rexBuilder.makeLiteral(javaReplacement),
          context.rexBuilder.makeLiteral(flags));
    } else if (flags.matches("\\d+")) {
      // 5-parameter REGEXP_REPLACE with occurrence
      int occurrence = Integer.parseInt(flags);
      return PPLFuncImpTable.INSTANCE.resolve(
          context.rexBuilder,
          BuiltinFunctionName.INTERNAL_REGEXP_REPLACE_5,
          fieldRex,
          context.rexBuilder.makeLiteral(pattern),
          context.rexBuilder.makeLiteral(javaReplacement),
          context.relBuilder.literal(1), // start position
          context.relBuilder.literal(occurrence));
    } else {
      throw new RuntimeException(
          "Unsupported sed flags: " + flags + " in expression: " + sedExpression);
    }
  } catch (Exception e) {
    // Every failure above, including the format and flag errors, is reported uniformly.
    throw new RuntimeException("Failed to optimize sed expression: " + sedExpression, e);
  }
}

/**
 * Builds the translate call for a sed transliteration expression of the form y/from/to/.
 *
 * <p>The text between the first and second {@code /} becomes the source character set and the
 * text between the second and third {@code /} becomes the target character set. Both are passed
 * as literals, together with the field, to the internal TRANSLATE3 function.
 *
 * @param fieldRex expression for the field the transliteration is applied to
 * @param sedExpression sed transliteration expression
 * @param context current Calcite plan context
 * @return the resolved translate call
 * @throws RuntimeException if the expression is malformed or the call cannot be resolved; the
 *     underlying failure is attached as the cause
 */
private RexNode createOptimizedTransliteration(
    RexNode fieldRex, String sedExpression, CalcitePlanContext context) {
  try {
    // Shape check first: y/from/to/
    final boolean wellFormed = sedExpression.matches("y/.+/.*/.*");
    if (!wellFormed) {
      throw new IllegalArgumentException("Invalid sed transliteration format");
    }

    // Scanning from index 1 lands on the '/' that directly follows the leading 'y'.
    final int openingSlash = sedExpression.indexOf('/', 1);
    final int middleSlash = sedExpression.indexOf('/', openingSlash + 1);
    final int closingSlash = sedExpression.indexOf('/', middleSlash + 1);

    if (!(openingSlash != -1 && middleSlash != -1)) {
      throw new IllegalArgumentException("Invalid sed transliteration format");
    }

    // Without a closing slash the target set runs to the end of the expression.
    final int targetEnd;
    if (closingSlash == -1) {
      targetEnd = sedExpression.length();
    } else {
      targetEnd = closingSlash;
    }

    final String sourceChars = sedExpression.substring(openingSlash + 1, middleSlash);
    final String targetChars = sedExpression.substring(middleSlash + 1, targetEnd);

    // Use Calcite's native TRANSLATE3 function
    return PPLFuncImpTable.INSTANCE.resolve(
        context.rexBuilder,
        BuiltinFunctionName.INTERNAL_TRANSLATE3,
        fieldRex,
        context.rexBuilder.makeLiteral(sourceChars),
        context.rexBuilder.makeLiteral(targetChars));
  } catch (Exception e) {
    throw new RuntimeException("Failed to optimize sed expression: " + sedExpression, e);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ public enum BuiltinFunctionName {
REGEX_MATCH(FunctionName.of("regex_match")),
REX_EXTRACT(FunctionName.of("REX_EXTRACT")),
REX_EXTRACT_MULTI(FunctionName.of("REX_EXTRACT_MULTI")),
REX_OFFSET(FunctionName.of("REX_OFFSET")),
REPLACE(FunctionName.of("replace")),
REVERSE(FunctionName.of("reverse")),
RIGHT(FunctionName.of("right")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.sql.expression.function.udf.RelevanceQueryFunction;
import org.opensearch.sql.expression.function.udf.RexExtractFunction;
import org.opensearch.sql.expression.function.udf.RexExtractMultiFunction;
import org.opensearch.sql.expression.function.udf.RexOffsetFunction;
import org.opensearch.sql.expression.function.udf.SpanFunction;
import org.opensearch.sql.expression.function.udf.condition.EarliestFunction;
import org.opensearch.sql.expression.function.udf.condition.EnhancedCoalesceFunction;
Expand Down Expand Up @@ -406,6 +407,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
public static final SqlOperator REX_EXTRACT = new RexExtractFunction().toUDF("REX_EXTRACT");
public static final SqlOperator REX_EXTRACT_MULTI =
new RexExtractMultiFunction().toUDF("REX_EXTRACT_MULTI");
public static final SqlOperator REX_OFFSET = new RexOffsetFunction().toUDF("REX_OFFSET");

// Aggregation functions
public static final SqlAggFunction AVG_NULLABLE = new NullableSqlAvgAggFunction(SqlKind.AVG);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.REVERSE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.REX_EXTRACT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.REX_EXTRACT_MULTI;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.REX_OFFSET;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.RIGHT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.RINT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ROUND;
Expand Down Expand Up @@ -715,6 +716,7 @@ void populate() {
registerOperator(MULTI_MATCH, PPLBuiltinOperators.MULTI_MATCH);
registerOperator(REX_EXTRACT, PPLBuiltinOperators.REX_EXTRACT);
registerOperator(REX_EXTRACT_MULTI, PPLBuiltinOperators.REX_EXTRACT_MULTI);
registerOperator(REX_OFFSET, PPLBuiltinOperators.REX_OFFSET);

// Register PPL Datetime UDF operator
registerOperator(TIMESTAMP, PPLBuiltinOperators.TIMESTAMP);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function.udf;

import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.opensearch.sql.calcite.utils.PPLOperandTypes;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.expression.function.UDFOperandMetadata;

/**
 * Custom REX_OFFSET function for calculating regex match positions.
 *
 * <p>Takes two string operands, the text and a regex pattern, and returns a string describing
 * where each named capture group of the first match is located in the text.
 */
public final class RexOffsetFunction extends ImplementorUDF {

  /**
   * Finds the openers of named capture groups in a pattern string. The name part follows the
   * java.util.regex rule (a letter followed by letters or digits), so lookbehind openers such as
   * {@code (?<=} and {@code (?<!} are not mistaken for group names.
   */
  private static final Pattern NAMED_GROUP_PATTERN =
      Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>");

  public RexOffsetFunction() {
    super(new RexOffsetImplementor(), NullPolicy.ARG0);
  }

  @Override
  public SqlReturnTypeInference getReturnTypeInference() {
    return ReturnTypes.VARCHAR_2000_NULLABLE;
  }

  @Override
  public UDFOperandMetadata getOperandMetadata() {
    return PPLOperandTypes.STRING_STRING;
  }

  /** Translates the operator call into a direct call to {@link #calculateOffsets}. */
  private static class RexOffsetImplementor implements NotNullImplementor {

    @Override
    public Expression implement(
        RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
      Expression field = translatedOperands.get(0);
      Expression pattern = translatedOperands.get(1);

      return Expressions.call(RexOffsetFunction.class, "calculateOffsets", field, pattern);
    }
  }

  /**
   * Computes the positions of the named capture groups in the first match of a pattern.
   *
   * <p>Each participating group contributes one {@code name=start-end} entry, where {@code start}
   * is the index of the group's first character and {@code end} is the index of its last
   * character (inclusive). Entries are sorted lexicographically and joined with {@code &}.
   *
   * <p>Offsets are looked up by group name rather than by counting named groups, so unnamed
   * capturing groups that precede a named group do not shift the reported positions.
   *
   * @param text text to search; may be null
   * @param patternStr regular expression with named capture groups; may be null
   * @return the joined offset entries, or null if either argument is null, the pattern does not
   *     match, or no named group took part in the match
   * @throws IllegalArgumentException if {@code patternStr} is not a valid regular expression
   */
  public static String calculateOffsets(String text, String patternStr) {
    if (text == null || patternStr == null) {
      return null;
    }

    final Matcher matcher;
    try {
      matcher = Pattern.compile(patternStr).matcher(text);
    } catch (PatternSyntaxException e) {
      throw new IllegalArgumentException(
          "Invalid regex pattern in rex command: " + e.getMessage(), e);
    }

    if (!matcher.find()) {
      return null;
    }

    // Sorted set: keeps the entries in lexicographic order and drops repeated candidates.
    java.util.Set<String> offsetPairs = new java.util.TreeSet<>();

    Matcher candidates = NAMED_GROUP_PATTERN.matcher(patternStr);
    while (candidates.find()) {
      String groupName = candidates.group(1);

      int start;
      int end;
      try {
        start = matcher.start(groupName);
        end = matcher.end(groupName);
      } catch (IllegalArgumentException ignored) {
        // The candidate text is not an actual group of the compiled pattern (for example it
        // sits inside a character class or a quoted section), so there is nothing to report.
        continue;
      }

      // A group that did not take part in the match reports -1 for both positions.
      if (start >= 0 && end >= 0) {
        offsetPairs.add(groupName + "=" + start + "-" + (end - 1));
      }
    }

    return offsetPairs.isEmpty() ? null : String.join("&", offsetPairs);
  }
}
Loading
Loading