diff --git a/engine/src/main/java/com/arcadedb/query/sql/function/DefaultSQLFunctionFactory.java b/engine/src/main/java/com/arcadedb/query/sql/function/DefaultSQLFunctionFactory.java index ff992cb1df..6edd0d379c 100644 --- a/engine/src/main/java/com/arcadedb/query/sql/function/DefaultSQLFunctionFactory.java +++ b/engine/src/main/java/com/arcadedb/query/sql/function/DefaultSQLFunctionFactory.java @@ -50,6 +50,7 @@ import com.arcadedb.query.sql.function.math.SQLFunctionEval; import com.arcadedb.query.sql.function.math.SQLFunctionMax; import com.arcadedb.query.sql.function.math.SQLFunctionMin; +import com.arcadedb.query.sql.function.math.SQLFunctionSquareRoot; import com.arcadedb.query.sql.function.math.SQLFunctionSum; import com.arcadedb.query.sql.function.misc.SQLFunctionCoalesce; import com.arcadedb.query.sql.function.misc.SQLFunctionCount; @@ -99,6 +100,7 @@ public DefaultSQLFunctionFactory() { register(SQLFunctionMax.NAME, SQLFunctionMax.class); register(SQLFunctionMin.NAME, SQLFunctionMin.class); register(SQLFunctionSet.NAME, SQLFunctionSet.class); + register(SQLFunctionSquareRoot.NAME, SQLFunctionSquareRoot.class); register(SQLFunctionSum.NAME, SQLFunctionSum.class); register(SQLFunctionSysdate.NAME, SQLFunctionSysdate.class); register(SQLFunctionUnionAll.NAME, SQLFunctionUnionAll.class); diff --git a/engine/src/main/java/com/arcadedb/query/sql/function/math/SQLFunctionSquareRoot.java b/engine/src/main/java/com/arcadedb/query/sql/function/math/SQLFunctionSquareRoot.java new file mode 100644 index 0000000000..3c2d0da410 --- /dev/null +++ b/engine/src/main/java/com/arcadedb/query/sql/function/math/SQLFunctionSquareRoot.java @@ -0,0 +1,83 @@ +/* + * Copyright © 2021-present Arcade Data Ltd (info@arcadedata.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-FileCopyrightText: 2021-present Arcade Data Ltd (info@arcadedata.com) + * SPDX-License-Identifier: Apache-2.0 + */ +package com.arcadedb.query.sql.function.math; + +import com.arcadedb.database.Identifiable; +import com.arcadedb.query.sql.executor.CommandContext; + +import java.math.*; +import java.time.*; + +public class SQLFunctionSquareRoot extends SQLFunctionMathAbstract { + public static final String NAME = "sqrt"; + private Object result; + + public SQLFunctionSquareRoot() { + super(NAME); + } + + public Object execute(final Object iThis, final Identifiable iRecord, final Object iCurrentResult, final Object[] iParams, final CommandContext iContext) { + final Object inputValue = iParams[0]; + + if (inputValue == null) { + result = null; + } else if (inputValue instanceof Number && ((Number) inputValue).doubleValue() < 0.0) { + result = null; + } else if (inputValue instanceof BigDecimal) { + result = ((BigDecimal) inputValue).sqrt(new MathContext(10)); + } else if (inputValue instanceof BigInteger) { + result = ((BigInteger) inputValue).sqrt(); + } else if (inputValue instanceof Integer) { + result = (int) Math.sqrt((Integer) inputValue); + } else if (inputValue instanceof Long) { + result = (new Double ((int) Math.sqrt((Long) inputValue))).longValue(); + } else if (inputValue instanceof Short) { + result = (new Double (Math.sqrt((Short) inputValue))).shortValue(); + } else if (inputValue instanceof Double) { + result = Math.sqrt((Double) inputValue); + } else if (inputValue instanceof Float) { + result = (new Double(Math.sqrt((Float) inputValue))).floatValue(); + } else if (inputValue instanceof Duration) { + final int seconds = ((Duration) inputValue).toSecondsPart(); + final long nanos = ((Duration) inputValue).toNanosPart(); + if (seconds < 0 && nanos < 0) + result = null; + else { + result = Duration.ofSeconds((int) Math.sqrt(seconds), (long) Math.sqrt(nanos)); + } + } else { + throw new IllegalArgumentException("Argument to square root must be a number."); + } + + return getResult(); + } + + public boolean aggregateResults() { + return false; + } + + public String getSyntax() { + return "sqrt()"; + } + + @Override + public Object getResult() { + return result; + } +} diff --git a/engine/src/test/java/com/arcadedb/query/sql/functions/math/SQLFunctionSquareRootTest.java b/engine/src/test/java/com/arcadedb/query/sql/functions/math/SQLFunctionSquareRootTest.java new file mode 100644 index 0000000000..d1eecdf6ef --- /dev/null +++ b/engine/src/test/java/com/arcadedb/query/sql/functions/math/SQLFunctionSquareRootTest.java @@ -0,0 +1,178 @@ +/* + * Copyright © 2021-present Arcade Data Ltd (info@arcadedata.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-FileCopyrightText: 2021-present Arcade Data Ltd (info@arcadedata.com) + * SPDX-License-Identifier: Apache-2.0 + */ +package com.arcadedb.query.sql.functions.math; + +import com.arcadedb.TestHelper; +import com.arcadedb.query.sql.executor.ResultSet; +import com.arcadedb.query.sql.function.math.SQLFunctionSquareRoot; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.math.*; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SQLFunctionSquareRootTest { + + private SQLFunctionSquareRoot function; + + @BeforeEach + public void setup() { + function = new SQLFunctionSquareRoot(); + } + + @Test + public void testEmpty() { + final Object result = function.getResult(); + assertNull(result); + } + + @Test + public void testNull() { + function.execute(null, null, null, new Object[]{null}, null); + final Object result = function.getResult(); + assertNull(result); + } + + @Test + public void testPositiveInteger() { + function.execute(null, null, null, new Object[]{4}, null); + final Object result = function.getResult(); + assertTrue(result instanceof Integer); + assertEquals(result, 2); + } + + @Test + public void testNegativeInteger() { + function.execute(null, null, null, new Object[]{-4}, null); + final Object result = function.getResult(); + assertEquals(result, null); + } + + @Test + public void testPositiveLong() { + function.execute(null, null, null, new Object[]{4L}, null); + final Object result = function.getResult(); + assertTrue(result instanceof Long); + assertEquals(result, 2L); + } + + @Test + public void testNegativeLong() { + function.execute(null, null, null, new Object[]{-4L}, null); + final Object result = function.getResult(); + assertEquals(result, null); + } + + @Test + public void testPositiveShort() { + function.execute(null, null, null, new Object[]{(short) 4}, null); + final Object result = function.getResult(); + assertTrue(result instanceof Short); + assertEquals(result, (short) 2); + } + + @Test + public void testNegativeShort() { + function.execute(null, null, null, new Object[]{(short) -4}, null); + final Object result = function.getResult(); + assertEquals(result, null); + } + + @Test + public void testPositiveDouble() { + function.execute(null, null, null, new Object[]{4.0D}, null); + final Object result = function.getResult(); + assertTrue(result instanceof Double); + assertEquals(result, 2.0D); + } + + @Test + public void testNegativeDouble() { + function.execute(null, null, null, new Object[]{-4.0D}, null); + final Object result = function.getResult(); + assertEquals(result, null); + } + + @Test + public void testPositiveFloat() { + function.execute(null, null, null, new Object[]{4.0F}, null); + final Object result = function.getResult(); + assertTrue(result instanceof Float); + assertEquals(result, 2.0F); + } + + @Test + public void testNegativeFloat() { + function.execute(null, null, null, new Object[]{-4.0F}, null); + final Object result = function.getResult(); + assertEquals(result, null); + } + + @Test + public void testPositiveBigDecimal() { + function.execute(null, null, null, new Object[]{new BigDecimal("4.0")}, null); + final Object result = function.getResult(); + assertTrue(result instanceof BigDecimal); + assertEquals(result, new BigDecimal("2")); + } + + @Test + public void testNegativeBigDecimal() { + function.execute(null, null, null, new Object[]{BigDecimal.valueOf(-4.0D)}, null); + final Object result = function.getResult(); + assertEquals(result, null); + } + + @Test + public void testPositiveBigInteger() { + function.execute(null, null, null, new Object[]{new BigInteger("4")}, null); + final Object result = function.getResult(); + assertTrue(result instanceof BigInteger); + assertEquals(result, new BigInteger("2")); + } + + @Test + public void testNegativeBigInteger() { + function.execute(null, null, null, new Object[]{new BigInteger("-4")}, null); + final Object result = function.getResult(); + assertEquals(result, null); + } + + @Test + public void testNonNumber() { + try { + function.execute(null, null, null, new Object[]{"abc"}, null); + Assertions.fail("Expected IllegalArgumentException"); + } catch (final IllegalArgumentException e) { + // OK + } + } + + @Test + public void testFromQuery() throws Exception { + TestHelper.executeInNewDatabase("./target/databases/testSqrtFunction", (db) -> { + final ResultSet result = db.query("sql", "select sqrt(4.0) as sqrt"); + assertEquals(2.0F, ((Number) result.next().getProperty("sqrt")).floatValue()); + }); + } +}