Skip to content

Commit aae1b04

Browse files
authored
refactor: fix some string ops in the sqlglot compiler. (#2299)
This change aims to fix some string-related tests failing in #2248. Fixes internal issue 417774347🦕
1 parent 33a211e commit aae1b04

File tree

7 files changed

+42
-14
lines changed

7 files changed

+42
-14
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
107107
sge.If(this=sge.convert(key), true=sge.convert(value))
108108
for key, value in op.mappings
109109
],
110+
default=expr.expr,
110111
)
111112

112113

bigframes/core/compile/sqlglot/expressions/string_ops.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
import sqlglot.expressions as sge
2020

21+
from bigframes import dtypes
2122
from bigframes import operations as ops
2223
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2324
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
@@ -195,6 +196,9 @@ def _(expr: TypedExpr) -> sge.Expression:
195196

196197
@register_unary_op(ops.len_op)
197198
def _(expr: TypedExpr) -> sge.Expression:
199+
if dtypes.is_array_like(expr.dtype):
200+
return sge.func("ARRAY_LENGTH", expr.expr)
201+
198202
return sge.Length(this=expr.expr)
199203

200204

@@ -239,7 +243,7 @@ def to_startswith(pat: str) -> sge.Expression:
239243

240244
@register_unary_op(ops.StrStripOp, pass_op=True)
241245
def _(expr: TypedExpr, op: ops.StrStripOp) -> sge.Expression:
242-
return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr)
246+
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip))
243247

244248

245249
@register_unary_op(ops.StringSplitOp, pass_op=True)
@@ -284,27 +288,29 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
284288

285289
@register_unary_op(ops.ZfillOp, pass_op=True)
286290
def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression:
291+
length_expr = sge.Greatest(
292+
expressions=[sge.Length(this=expr.expr), sge.convert(op.width)]
293+
)
287294
return sge.Case(
288295
ifs=[
289296
sge.If(
290-
this=sge.EQ(
291-
this=sge.Substring(
292-
this=expr.expr, start=sge.convert(1), length=sge.convert(1)
293-
),
294-
expression=sge.convert("-"),
297+
this=sge.func(
298+
"STARTS_WITH",
299+
expr.expr,
300+
sge.convert("-"),
295301
),
296302
true=sge.Concat(
297303
expressions=[
298304
sge.convert("-"),
299305
sge.func(
300306
"LPAD",
301-
sge.Substring(this=expr.expr, start=sge.convert(1)),
302-
sge.convert(op.width - 1),
307+
sge.Substring(this=expr.expr, start=sge.convert(2)),
308+
length_expr - 1,
303309
sge.convert("0"),
304310
),
305311
]
306312
),
307313
)
308314
],
309-
default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")),
315+
default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")),
310316
)

tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
CASE `string_col` WHEN 'value1' THEN 'mapped1' END AS `bfcol_1`
8+
CASE `string_col` WHEN 'value1' THEN 'mapped1' ELSE `string_col` END AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT
tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_len_w_array/out.sql

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int_list_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`repeated_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ARRAY_LENGTH(`int_list_col`) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `int_list_col`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_strip/out.sql

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
TRIM(' ', `string_col`) AS `bfcol_1`
8+
TRIM(`string_col`, ' ') AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_zfill/out.sql

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,9 @@ WITH `bfcte_0` AS (
66
SELECT
77
*,
88
CASE
9-
WHEN SUBSTRING(`string_col`, 1, 1) = '-'
10-
THEN CONCAT('-', LPAD(SUBSTRING(`string_col`, 1), 9, '0'))
11-
ELSE LPAD(`string_col`, 10, '0')
9+
WHEN STARTS_WITH(`string_col`, '-')
10+
THEN CONCAT('-', LPAD(SUBSTRING(`string_col`, 2), GREATEST(LENGTH(`string_col`), 10) - 1, '0'))
11+
ELSE LPAD(`string_col`, GREATEST(LENGTH(`string_col`), 10), '0')
1212
END AS `bfcol_1`
1313
FROM `bfcte_0`
1414
)

tests/unit/core/compile/sqlglot/expressions/test_string_ops.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,14 @@ def test_len(scalar_types_df: bpd.DataFrame, snapshot):
120120
snapshot.assert_match(sql, "out.sql")
121121

122122

123+
def test_len_w_array(repeated_types_df: bpd.DataFrame, snapshot):
124+
col_name = "int_list_col"
125+
bf_df = repeated_types_df[[col_name]]
126+
sql = utils._apply_ops_to_sql(bf_df, [ops.len_op.as_expr(col_name)], [col_name])
127+
128+
snapshot.assert_match(sql, "out.sql")
129+
130+
123131
def test_lower(scalar_types_df: bpd.DataFrame, snapshot):
124132
col_name = "string_col"
125133
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)