Skip to content

Commit 38fc833

Browse files
Handle zero denominator in divide and modulus for byte data type (#272) (#1716) (#1734)
* Fixed bug of byte/short values not handling divide/modulus arithmetic equations Signed-off-by: Matthew Wells <[email protected]> (cherry picked from commit 2c80631) Co-authored-by: Matthew Wells <[email protected]>
1 parent f133840 commit 38fc833

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ private static DefaultFunctionResolver addFunction() {
106106
private static DefaultFunctionResolver divideBase(FunctionName functionName) {
107107
return define(functionName,
108108
impl(nullMissingHandling(
109-
(v1, v2) -> new ExprByteValue(v1.byteValue() / v2.byteValue())),
109+
(v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() :
110+
new ExprByteValue(v1.byteValue() / v2.byteValue())),
110111
BYTE, BYTE, BYTE),
111112
impl(nullMissingHandling(
112113
(v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() :
@@ -140,7 +141,7 @@ private static DefaultFunctionResolver divideFunction() {
140141
}
141142

142143
/**
143-
* Definition of modulo(x, y) function.
144+
* Definition of modulus(x, y) function.
144145
* Returns the number x modulo by number y
145146
* The supported signature of modulo function is
146147
* (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE)
@@ -149,7 +150,8 @@ private static DefaultFunctionResolver divideFunction() {
149150
private static DefaultFunctionResolver modulusBase(FunctionName functionName) {
150151
return define(functionName,
151152
impl(nullMissingHandling(
152-
(v1, v2) -> new ExprByteValue(v1.byteValue() % v2.byteValue())),
153+
(v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() :
154+
new ExprByteValue(v1.byteValue() % v2.byteValue())),
153155
BYTE, BYTE, BYTE),
154156
impl(nullMissingHandling(
155157
(v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() :

core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public void mod(ExprValue op1, ExprValue op2) {
113113
assertEquals(String.format("mod(%s, %s)", op1.toString(), op2.toString()),
114114
expression.toString());
115115

116-
expression = DSL.mod(literal(op1), literal(new ExprShortValue(0)));
116+
expression = DSL.mod(literal(op1), literal(new ExprByteValue(0)));
117117
assertTrue(expression.valueOf(valueEnv()).isNull());
118118
assertEquals(String.format("mod(%s, 0)", op1.toString()), expression.toString());
119119
}
@@ -128,7 +128,7 @@ public void modulus(ExprValue op1, ExprValue op2) {
128128
assertEquals(String.format("%%(%s, %s)", op1.toString(), op2.toString()),
129129
expression.toString());
130130

131-
expression = DSL.modulus(literal(op1), literal(new ExprShortValue(0)));
131+
expression = DSL.modulus(literal(op1), literal(new ExprByteValue(0)));
132132
assertTrue(expression.valueOf(valueEnv()).isNull());
133133
assertEquals(String.format("%%(%s, 0)", op1.toString()), expression.toString());
134134
}
@@ -144,7 +144,7 @@ public void modulusFunction(ExprValue op1, ExprValue op2) {
144144
assertEquals(String.format("modulus(%s, %s)", op1.toString(), op2.toString()),
145145
expression.toString());
146146

147-
expression = DSL.modulusFunction(literal(op1), literal(new ExprShortValue(0)));
147+
expression = DSL.modulusFunction(literal(op1), literal(new ExprByteValue(0)));
148148
assertTrue(expression.valueOf(valueEnv()).isNull());
149149
assertEquals(String.format("modulus(%s, 0)", op1.toString()), expression.toString());
150150
}
@@ -183,7 +183,7 @@ public void divide(ExprValue op1, ExprValue op2) {
183183
assertEquals(String.format("/(%s, %s)", op1.toString(), op2.toString()),
184184
expression.toString());
185185

186-
expression = DSL.divide(literal(op1), literal(new ExprShortValue(0)));
186+
expression = DSL.divide(literal(op1), literal(new ExprByteValue(0)));
187187
assertTrue(expression.valueOf(valueEnv()).isNull());
188188
assertEquals(String.format("/(%s, 0)", op1.toString()), expression.toString());
189189
}
@@ -199,7 +199,7 @@ public void divideFunction(ExprValue op1, ExprValue op2) {
199199
assertEquals(String.format("divide(%s, %s)", op1.toString(), op2.toString()),
200200
expression.toString());
201201

202-
expression = DSL.divideFunction(literal(op1), literal(new ExprShortValue(0)));
202+
expression = DSL.divideFunction(literal(op1), literal(new ExprByteValue(0)));
203203
assertTrue(expression.valueOf(valueEnv()).isNull());
204204
assertEquals(String.format("divide(%s, 0)", op1.toString()), expression.toString());
205205
}

0 commit comments

Comments
 (0)