Skip to content

Commit 2dcf6ae

Browse files
authored
chore: Migrate pow_op operator to SQLGlot (#2291)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes b/447388852 🦕
1 parent d1ecc61 commit 2dcf6ae

File tree

4 files changed

+492
-0
lines changed

4 files changed

+492
-0
lines changed

bigframes/core/compile/sqlglot/expressions/constants.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import math
16+
1517
import sqlglot.expressions as sge
1618

1719
_ZERO = sge.Cast(this=sge.convert(0), to="INT64")
@@ -23,3 +25,13 @@
2325
# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)
2426
# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow.
2527
_FLOAT64_EXP_BOUND = sge.convert(709.78)
28+
29+
# The natural logarithm of the maximum value for a signed 64-bit integer.
30+
# This is used to check for potential overflows in power operations involving integers
31+
# by checking if `exponent * log(base)` exceeds this value.
32+
_INT64_LOG_BOUND = math.log(2**63 - 1)
33+
34+
# Represents the largest integer N where all integers from -N to N can be
35+
# represented exactly as a float64. Float64 types have a 53-bit significand precision,
36+
# so integers beyond this value may lose precision.
37+
_FLOAT64_MAX_INT_PRECISION = 2**53

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,141 @@ def _(expr: TypedExpr) -> sge.Expression:
210210
return expr.expr
211211

212212

213+
@register_binary_op(ops.pow_op)
214+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
215+
left_expr = _coerce_bool_to_int(left)
216+
right_expr = _coerce_bool_to_int(right)
217+
if left.dtype == dtypes.INT_DTYPE and right.dtype == dtypes.INT_DTYPE:
218+
return _int_pow_op(left_expr, right_expr)
219+
else:
220+
return _float_pow_op(left_expr, right_expr)
221+
222+
223+
def _int_pow_op(
224+
left_expr: sge.Expression, right_expr: sge.Expression
225+
) -> sge.Expression:
226+
overflow_cond = sge.and_(
227+
sge.NEQ(this=left_expr, expression=sge.convert(0)),
228+
sge.GT(
229+
this=sge.Mul(
230+
this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
231+
),
232+
expression=sge.convert(constants._INT64_LOG_BOUND),
233+
),
234+
)
235+
236+
return sge.Case(
237+
ifs=[
238+
sge.If(
239+
this=overflow_cond,
240+
true=sge.Null(),
241+
)
242+
],
243+
default=sge.Cast(
244+
this=sge.Pow(
245+
this=sge.Cast(
246+
this=left_expr, to=sge.DataType(this=sge.DataType.Type.DECIMAL)
247+
),
248+
expression=right_expr,
249+
),
250+
to="INT64",
251+
),
252+
)
253+
254+
255+
def _float_pow_op(
256+
left_expr: sge.Expression, right_expr: sge.Expression
257+
) -> sge.Expression:
258+
# Most conditions here seek to prevent calling BQ POW with inputs that would generate errors.
259+
# See: https://cloud.google.com/bigquery/docs/reference/standard-sql/mathematical_functions#pow
260+
overflow_cond = sge.and_(
261+
sge.NEQ(this=left_expr, expression=constants._ZERO),
262+
sge.GT(
263+
this=sge.Mul(
264+
this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
265+
),
266+
expression=constants._FLOAT64_EXP_BOUND,
267+
),
268+
)
269+
270+
# Float64 lose integer precision beyond 2**53, beyond this insufficient precision to get parity
271+
exp_too_big = sge.GT(
272+
this=sge.Abs(this=right_expr),
273+
expression=sge.convert(constants._FLOAT64_MAX_INT_PRECISION),
274+
)
275+
# Treat very large exponents as +=INF
276+
norm_exp = sge.Case(
277+
ifs=[
278+
sge.If(
279+
this=exp_too_big,
280+
true=sge.Mul(this=constants._INF, expression=sge.Sign(this=right_expr)),
281+
)
282+
],
283+
default=right_expr,
284+
)
285+
286+
pow_result = sge.Pow(this=left_expr, expression=norm_exp)
287+
288+
# This cast is dangerous, need to only excuted where y_val has been bounds-checked
289+
# Ibis needs try_cast binding to bq safe_cast
290+
exponent_is_whole = sge.EQ(
291+
this=sge.Cast(this=right_expr, to="INT64"), expression=right_expr
292+
)
293+
odd_exponent = sge.and_(
294+
sge.LT(this=left_expr, expression=constants._ZERO),
295+
sge.EQ(
296+
this=sge.Mod(
297+
this=sge.Cast(this=right_expr, to="INT64"), expression=sge.convert(2)
298+
),
299+
expression=sge.convert(1),
300+
),
301+
)
302+
infinite_base = sge.EQ(this=sge.Abs(this=left_expr), expression=constants._INF)
303+
304+
return sge.Case(
305+
ifs=[
306+
# Might be able to do something more clever with x_val==0 case
307+
sge.If(
308+
this=sge.EQ(this=right_expr, expression=constants._ZERO),
309+
true=sge.convert(1),
310+
),
311+
sge.If(
312+
this=sge.EQ(this=left_expr, expression=sge.convert(1)),
313+
true=sge.convert(1),
314+
), # Need to ignore exponent, even if it is NA
315+
sge.If(
316+
this=sge.and_(
317+
sge.EQ(this=left_expr, expression=constants._ZERO),
318+
sge.LT(this=right_expr, expression=constants._ZERO),
319+
),
320+
true=constants._INF,
321+
), # This case would error POW function in BQ
322+
sge.If(this=infinite_base, true=pow_result),
323+
sge.If(
324+
this=exp_too_big, true=pow_result
325+
), # Bigquery can actually handle the +-inf cases gracefully
326+
sge.If(
327+
this=sge.and_(
328+
sge.LT(this=left_expr, expression=constants._ZERO),
329+
sge.Not(this=exponent_is_whole),
330+
),
331+
true=constants._NAN,
332+
),
333+
sge.If(
334+
this=overflow_cond,
335+
true=sge.Mul(
336+
this=constants._INF,
337+
expression=sge.Case(
338+
ifs=[sge.If(this=odd_exponent, true=sge.convert(-1))],
339+
default=sge.convert(1),
340+
),
341+
),
342+
), # finite overflows would cause bq to error
343+
],
344+
default=pow_result,
345+
)
346+
347+
213348
@register_unary_op(ops.sqrt_op)
214349
def _(expr: TypedExpr) -> sge.Expression:
215350
return sge.Case(

0 commit comments

Comments
 (0)