Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions core/src/main/java/org/opensearch/sql/calcite/utils/MathUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,54 @@ public static Number coerceToWidestFloatingType(Number a, Number b, double value
return (float) value;
}
}

public static Number integralCosh(Number x) {
double x0 = x.doubleValue();
return Math.cosh(x0);
}

public static Number floatingCosh(Number x) {
BigDecimal x0 = new BigDecimal(x.toString());
return Math.cosh(x0.doubleValue());
}

public static Number integralSinh(Number x) {
double x0 = x.doubleValue();
return Math.sinh(x0);
}

public static Number floatingSinh(Number x) {
BigDecimal x0 = new BigDecimal(x.toString());
return Math.sinh(x0.doubleValue());
}

public static Number integralExpm1(Number x) {
double x0 = x.doubleValue();
return Math.expm1(x0);
}

public static Number floatingExpm1(Number x) {
BigDecimal x0 = new BigDecimal(x.toString());
return Math.expm1(x0.doubleValue());
}

public static Number integralRint(Number x) {
double x0 = x.doubleValue();
return Math.rint(x0);
}

public static Number floatingRint(Number x) {
BigDecimal x0 = new BigDecimal(x.toString());
return Math.rint(x0.doubleValue());
}

public static Number integralSignum(Number x) {
double x0 = x.doubleValue();
return (int) Math.signum(x0);
}

public static Number floatingSignum(Number x) {
BigDecimal x0 = new BigDecimal(x.toString());
return (int) Math.signum(x0.doubleValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Optionality;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
import org.opensearch.sql.calcite.udf.UserDefinedAggFunction;
import org.opensearch.sql.data.model.ExprValueUtils;
Expand Down Expand Up @@ -220,6 +222,42 @@ public UDFOperandMetadata getOperandMetadata() {
};
}

public static ImplementorUDF adaptMathFunctionToUDF(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a Javadoc for this class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course. Added.

String methodName,
SqlReturnTypeInference returnTypeInference,
NullPolicy nullPolicy,
UDFOperandMetadata operandMetadata) {

NotNullImplementor implementor =
(translator, call, translatedOperands) -> {
Expression operand = translatedOperands.get(0);
RelDataType inputType = call.getOperands().get(0).getType();

// 保留类型区分逻辑
if (SqlTypeFamily.INTEGER.contains(inputType)) {
operand = Expressions.convert_(operand, Number.class);
return Expressions.call(
MathUtils.class, "integral" + StringUtils.capitalize(methodName), operand);
} else {
operand = Expressions.convert_(operand, Number.class);
return Expressions.call(
MathUtils.class, "floating" + StringUtils.capitalize(methodName), operand);
}
};

return new ImplementorUDF(implementor, nullPolicy) {
@Override
public SqlReturnTypeInference getReturnTypeInference() {
return returnTypeInference;
}

@Override
public UDFOperandMetadata getOperandMetadata() {
return operandMetadata;
}
};
}

public static List<Expression> prependFunctionProperties(
List<Expression> operands, RexToLixTranslator translator) {
List<Expression> operandsWithProperties = new ArrayList<>(operands);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptExprMethodToUDF;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptExprMethodWithPropertiesToUDF;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptMathFunctionToUDF;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -103,6 +104,51 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
public static final SqlOperator DIVIDE = new DivideFunction().toUDF("DIVIDE");
public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2");
public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH");
// public static final SqlOperator COSH = new CoshFunction().toUDF("COSH");
// public static final SqlOperator SINH = new SinhFunction().toUDF("SINH");
// public static final SqlOperator EXPM1 = new Expm1Function().toUDF("EXPM1");
// public static final SqlOperator RINT = new RintFunction().toUDF("RINT");
// public static final SqlOperator SIGNUM = new SignumFunction().toUDF("SIGNUM");

public static final SqlOperator COSH =
adaptMathFunctionToUDF(
"Cosh",
ReturnTypes.DOUBLE.andThen(SqlTypeTransforms.FORCE_NULLABLE),
NullPolicy.ANY,
PPLOperandTypes.NUMERIC)
.toUDF("COSH");

public static final SqlOperator SINH =
adaptMathFunctionToUDF(
"Sinh",
ReturnTypes.DOUBLE.andThen(SqlTypeTransforms.FORCE_NULLABLE),
NullPolicy.ANY,
PPLOperandTypes.NUMERIC)
.toUDF("SINH");

public static final SqlOperator RINT =
adaptMathFunctionToUDF(
"Rint",
ReturnTypes.DOUBLE.andThen(SqlTypeTransforms.FORCE_NULLABLE),
NullPolicy.ANY,
PPLOperandTypes.NUMERIC)
.toUDF("RINT");

public static final SqlOperator EXPM1 =
adaptMathFunctionToUDF(
"Expm1",
ReturnTypes.DOUBLE.andThen(SqlTypeTransforms.FORCE_NULLABLE),
NullPolicy.ANY,
PPLOperandTypes.NUMERIC)
.toUDF("EXPM1");

public static final SqlOperator SIGNUM =
adaptMathFunctionToUDF(
"Signum",
ReturnTypes.INTEGER.andThen(SqlTypeTransforms.FORCE_NULLABLE),
NullPolicy.ANY,
PPLOperandTypes.NUMERIC)
.toUDF("SIGNUM");

// IP comparing functions
public static final SqlOperator NOT_EQUALS_IP =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ACOS;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDFUNCTION;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDTIME;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.AND;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY;
Expand All @@ -37,6 +38,7 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CONV;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CONVERT_TZ;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.COS;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.COSH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.COT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CRC32;
Expand All @@ -61,11 +63,13 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_YEAR;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.DEGREES;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDEFUNCTION;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.E;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.EARLIEST;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.EXISTS;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.EXP;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.EXPM1;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.EXTRACT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.FILTER;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.FLOOR;
Expand Down Expand Up @@ -137,6 +141,7 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MONTHNAME;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MONTH_OF_YEAR;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLYFUNCTION;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTI_MATCH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOTEQUAL;
Expand All @@ -159,6 +164,7 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.REPLACE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.REVERSE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.RIGHT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.RINT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ROUND;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.RTRIM;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SECOND;
Expand All @@ -167,8 +173,10 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SHA1;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SHA2;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SIGN;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SIGNUM;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SIMPLE_QUERY_STRING;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SIN;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SINH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SPAN;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SQRT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.STDDEV_POP;
Expand All @@ -180,6 +188,7 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBSTRING;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTIME;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACTFUNCTION;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUM;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SYSDATE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TAKE;
Expand Down Expand Up @@ -668,8 +677,11 @@ void populate() {
registerOperator(OR, SqlStdOperatorTable.OR);
registerOperator(NOT, SqlStdOperatorTable.NOT);
registerOperator(ADD, SqlStdOperatorTable.PLUS);
registerOperator(ADDFUNCTION, SqlStdOperatorTable.PLUS);
registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS);
registerOperator(SUBTRACTFUNCTION, SqlStdOperatorTable.MINUS);
registerOperator(MULTIPLY, SqlStdOperatorTable.MULTIPLY);
registerOperator(MULTIPLYFUNCTION, SqlStdOperatorTable.MULTIPLY);
registerOperator(TRUNCATE, SqlStdOperatorTable.TRUNCATE);
registerOperator(ASCII, SqlStdOperatorTable.ASCII);
registerOperator(LENGTH, SqlStdOperatorTable.CHAR_LENGTH);
Expand Down Expand Up @@ -725,6 +737,11 @@ void populate() {
registerOperator(INTERNAL_REGEXP_REPLACE_3, SqlLibraryOperators.REGEXP_REPLACE_3);

// Register PPL UDF operator
registerOperator(COSH, PPLBuiltinOperators.COSH);
registerOperator(SINH, PPLBuiltinOperators.SINH);
registerOperator(EXPM1, PPLBuiltinOperators.EXPM1);
registerOperator(RINT, PPLBuiltinOperators.RINT);
registerOperator(SIGNUM, PPLBuiltinOperators.SIGNUM);
registerOperator(SPAN, PPLBuiltinOperators.SPAN);
registerOperator(E, PPLBuiltinOperators.E);
registerOperator(CONV, PPLBuiltinOperators.CONV);
Expand All @@ -733,6 +750,7 @@ void populate() {
registerOperator(MODULUSFUNCTION, PPLBuiltinOperators.MOD);
registerOperator(CRC32, PPLBuiltinOperators.CRC32);
registerOperator(DIVIDE, PPLBuiltinOperators.DIVIDE);
registerOperator(DIVIDEFUNCTION, PPLBuiltinOperators.DIVIDE);
registerOperator(SHA2, PPLBuiltinOperators.SHA2);
registerOperator(CIDRMATCH, PPLBuiltinOperators.CIDRMATCH);
registerOperator(INTERNAL_GROK, PPLBuiltinOperators.GROK);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function.udf.math;

import java.math.BigDecimal;
import java.util.List;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.opensearch.sql.calcite.utils.PPLOperandTypes;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.expression.function.UDFOperandMetadata;

/** Implementation for cosh function. */
public class CoshFunction extends ImplementorUDF {
public CoshFunction() {
super(new CoshImplementor(), NullPolicy.ANY);
}

@Override
public SqlReturnTypeInference getReturnTypeInference() {
return ReturnTypes.DOUBLE.andThen(SqlTypeTransforms.FORCE_NULLABLE);
}

@Override
public UDFOperandMetadata getOperandMetadata() {
return PPLOperandTypes.NUMERIC;
}

public static class CoshImplementor implements NotNullImplementor {

@Override
public Expression implement(
RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
Expression operand = translatedOperands.get(0);
RelDataType inputType = call.getOperands().get(0).getType();

if (SqlTypeFamily.INTEGER.contains(inputType)) {
operand = Expressions.convert_(operand, Number.class);
return Expressions.call(CoshImplementor.class, "IntegralCosh", operand);
} else {
operand = Expressions.convert_(operand, Number.class);
return Expressions.call(CoshImplementor.class, "FloatingCosh", operand);
}
}

public static Number IntegralCosh(Number x) {
double x0 = x.doubleValue();
return Math.cosh(x0);
}

public static Number FloatingCosh(Number x) {
BigDecimal x0 = new BigDecimal(x.toString());
return Math.cosh(x0.doubleValue());
}
}
}
Loading
Loading