Skip to content

Commit 7c062a6

Browse files
authored
feat: use end user credentials for bigframes.bigquery.ai functions when connection_id is not present (#2272)
Fixes #460856043 🦕
1 parent ca86686 commit 7c062a6

File tree

18 files changed

+181
-34
lines changed
  • bigframes
  • tests/unit/core/compile/sqlglot/expressions
    • snapshots/test_ai_ops
      • test_ai_generate_bool_with_connection_id
      • test_ai_generate_bool_with_model_param
      • test_ai_generate_bool
      • test_ai_generate_double_with_connection_id
      • test_ai_generate_double_with_model_param
      • test_ai_generate_double
      • test_ai_generate_int_with_connection_id
      • test_ai_generate_int_with_model_param
      • test_ai_generate_int
      • test_ai_generate_with_connection_id
      • test_ai_generate_with_model_param
      • test_ai_generate_with_output_schema
      • test_ai_generate
  • third_party/bigframes_vendored/ibis/expr/operations

18 files changed

+181
-34
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def generate(
8888
or pandas Series.
8989
connection_id (str, optional):
9090
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
91-
If not provided, the connection from the current session will be used.
91+
If not provided, the query uses your end-user credential.
9292
endpoint (str, optional):
9393
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
9494
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
@@ -131,7 +131,7 @@ def generate(
131131

132132
operator = ai_ops.AIGenerate(
133133
prompt_context=tuple(prompt_context),
134-
connection_id=_resolve_connection_id(series_list[0], connection_id),
134+
connection_id=connection_id,
135135
endpoint=endpoint,
136136
request_type=request_type,
137137
model_params=json.dumps(model_params) if model_params else None,
@@ -186,7 +186,7 @@ def generate_bool(
186186
or pandas Series.
187187
connection_id (str, optional):
188188
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
189-
If not provided, the connection from the current session will be used.
189+
If not provided, the query uses your end-user credential.
190190
endpoint (str, optional):
191191
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
192192
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
@@ -216,7 +216,7 @@ def generate_bool(
216216

217217
operator = ai_ops.AIGenerateBool(
218218
prompt_context=tuple(prompt_context),
219-
connection_id=_resolve_connection_id(series_list[0], connection_id),
219+
connection_id=connection_id,
220220
endpoint=endpoint,
221221
request_type=request_type,
222222
model_params=json.dumps(model_params) if model_params else None,
@@ -267,7 +267,7 @@ def generate_int(
267267
or pandas Series.
268268
connection_id (str, optional):
269269
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
270-
If not provided, the connection from the current session will be used.
270+
If not provided, the query uses your end-user credential.
271271
endpoint (str, optional):
272272
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
273273
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
@@ -297,7 +297,7 @@ def generate_int(
297297

298298
operator = ai_ops.AIGenerateInt(
299299
prompt_context=tuple(prompt_context),
300-
connection_id=_resolve_connection_id(series_list[0], connection_id),
300+
connection_id=connection_id,
301301
endpoint=endpoint,
302302
request_type=request_type,
303303
model_params=json.dumps(model_params) if model_params else None,
@@ -348,7 +348,7 @@ def generate_double(
348348
or pandas Series.
349349
connection_id (str, optional):
350350
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
351-
If not provided, the connection from the current session will be used.
351+
If not provided, the query uses your end-user credential.
352352
endpoint (str, optional):
353353
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
354354
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
@@ -378,7 +378,7 @@ def generate_double(
378378

379379
operator = ai_ops.AIGenerateDouble(
380380
prompt_context=tuple(prompt_context),
381-
connection_id=_resolve_connection_id(series_list[0], connection_id),
381+
connection_id=connection_id,
382382
endpoint=endpoint,
383383
request_type=request_type,
384384
model_params=json.dumps(model_params) if model_params else None,

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
104104

105105
op_args = asdict(op)
106106

107-
connection_id = op_args["connection_id"]
108-
args.append(
109-
sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id))
110-
)
107+
connection_id = op_args.get("connection_id", None)
108+
if connection_id is not None:
109+
args.append(
110+
sge.Kwarg(
111+
this="connection_id", expression=sge.Literal.string(connection_id)
112+
)
113+
)
111114

112115
endpoit = op_args.get("endpoint", None)
113116
if endpoit is not None:

bigframes/operations/ai_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class AIGenerate(base_ops.NaryOp):
2929
name: ClassVar[str] = "ai_generate"
3030

3131
prompt_context: Tuple[str | None, ...]
32-
connection_id: str
32+
connection_id: str | None
3333
endpoint: str | None
3434
request_type: Literal["dedicated", "shared", "unspecified"]
3535
model_params: str | None
@@ -57,7 +57,7 @@ class AIGenerateBool(base_ops.NaryOp):
5757
name: ClassVar[str] = "ai_generate_bool"
5858

5959
prompt_context: Tuple[str | None, ...]
60-
connection_id: str
60+
connection_id: str | None
6161
endpoint: str | None
6262
request_type: Literal["dedicated", "shared", "unspecified"]
6363
model_params: str | None
@@ -79,7 +79,7 @@ class AIGenerateInt(base_ops.NaryOp):
7979
name: ClassVar[str] = "ai_generate_int"
8080

8181
prompt_context: Tuple[str | None, ...]
82-
connection_id: str
82+
connection_id: str | None
8383
endpoint: str | None
8484
request_type: Literal["dedicated", "shared", "unspecified"]
8585
model_params: str | None
@@ -101,7 +101,7 @@ class AIGenerateDouble(base_ops.NaryOp):
101101
name: ClassVar[str] = "ai_generate_double"
102102

103103
prompt_context: Tuple[str | None, ...]
104-
connection_id: str
104+
connection_id: str | None
105105
endpoint: str | None
106106
request_type: Literal["dedicated", "shared", "unspecified"]
107107
model_params: str | None

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate/out.sql

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ WITH `bfcte_0` AS (
77
*,
88
AI.GENERATE(
99
prompt => (`string_col`, ' is the same as ', `string_col`),
10-
connection_id => 'bigframes-dev.us.bigframes-default-connection',
1110
endpoint => 'gemini-2.5-flash',
1211
request_type => 'SHARED'
1312
) AS `bfcol_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ WITH `bfcte_0` AS (
77
*,
88
AI.GENERATE_BOOL(
99
prompt => (`string_col`, ' is the same as ', `string_col`),
10-
connection_id => 'bigframes-dev.us.bigframes-default-connection',
1110
endpoint => 'gemini-2.5-flash',
1211
request_type => 'SHARED'
1312
) AS `bfcol_1`
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.GENERATE_BOOL(
9+
prompt => (`string_col`, ' is the same as ', `string_col`),
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
11+
endpoint => 'gemini-2.5-flash',
12+
request_type => 'SHARED'
13+
) AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `result`
18+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ WITH `bfcte_0` AS (
77
*,
88
AI.GENERATE_BOOL(
99
prompt => (`string_col`, ' is the same as ', `string_col`),
10-
connection_id => 'bigframes-dev.us.bigframes-default-connection',
1110
request_type => 'SHARED',
1211
model_params => JSON '{}'
1312
) AS `bfcol_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ WITH `bfcte_0` AS (
77
*,
88
AI.GENERATE_DOUBLE(
99
prompt => (`string_col`, ' is the same as ', `string_col`),
10-
connection_id => 'bigframes-dev.us.bigframes-default-connection',
1110
endpoint => 'gemini-2.5-flash',
1211
request_type => 'SHARED'
1312
) AS `bfcol_1`
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.GENERATE_DOUBLE(
9+
prompt => (`string_col`, ' is the same as ', `string_col`),
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
11+
endpoint => 'gemini-2.5-flash',
12+
request_type => 'SHARED'
13+
) AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `result`
18+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ WITH `bfcte_0` AS (
77
*,
88
AI.GENERATE_DOUBLE(
99
prompt => (`string_col`, ' is the same as ', `string_col`),
10-
connection_id => 'bigframes-dev.us.bigframes-default-connection',
1110
request_type => 'SHARED',
1211
model_params => JSON '{}'
1312
) AS `bfcol_1`

0 commit comments

Comments
 (0)