Skip to content

Commit 24ba4a1

Browse files
authored
chore: Migrate cosine_distance_op operator to SQLGlot (#2236)
Fixes b/447388852 🦕
1 parent 5bee573 commit 24ba4a1

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ def _(expr: TypedExpr) -> sge.Expression:
125125
)
126126

127127

128+
@register_binary_op(ops.cosine_distance_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``cosine_distance_op`` to an ``ML.DISTANCE`` SQL function call.

    Args:
        left: Typed expression supplying the first argument of the call.
        right: Typed expression supplying the second argument of the call.

    Returns:
        An anonymous SQLGlot function node named ``ML.DISTANCE`` whose
        arguments are ``left.expr``, ``right.expr`` and the string literal
        ``'COSINE'``, in that order.
    """
    # The trailing string literal selects the distance type.
    call_args = [left.expr, right.expr, sge.Literal.string("COSINE")]
    return sge.Anonymous(this="ML.DISTANCE", expressions=call_args)
138+
139+
128140
@register_unary_op(ops.exp_op)
129141
def _(expr: TypedExpr) -> sge.Expression:
130142
return sge.Case(
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int_list_col` AS `bfcol_0`,
4+
`float_list_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`repeated_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
ML.DISTANCE(`bfcol_0`, `bfcol_0`, 'COSINE') AS `bfcol_2`,
10+
ML.DISTANCE(`bfcol_1`, `bfcol_1`, 'COSINE') AS `bfcol_3`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_2` AS `int_list_col`,
15+
`bfcol_3` AS `float_list_col`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,21 @@ def test_cosh(scalar_types_df: bpd.DataFrame, snapshot):
117117
snapshot.assert_match(sql, "out.sql")
118118

119119

120+
def test_cosine_distance(repeated_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for ``cosine_distance_op`` on list columns."""
    list_cols = ["int_list_col", "float_list_col"]
    source_df = repeated_types_df[list_cols]

    # Each list column is measured against itself; the results are emitted
    # under the same names as the input columns.
    distance_exprs = [
        ops.cosine_distance_op.as_expr(name, name) for name in list_cols
    ]
    compiled_sql = utils._apply_ops_to_sql(
        source_df,
        distance_exprs,
        list(list_cols),
    )

    snapshot.assert_match(compiled_sql, "out.sql")
133+
134+
120135
def test_exp(scalar_types_df: bpd.DataFrame, snapshot):
121136
col_name = "float64_col"
122137
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)