Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ class SnowflakeConnectionInfo(BaseConnectionInfo):
user: SecretStr = Field(
description="the username of your database", examples=["admin"]
)
password: SecretStr | None = Field(
description="the password of your database", examples=["password"], default=None
)
account: SecretStr = Field(
description="the account name of your database", examples=["myaccount"]
)
Expand All @@ -329,8 +332,10 @@ class SnowflakeConnectionInfo(BaseConnectionInfo):
description="the schema name of your database",
examples=["myschema"],
) # Use `sf_schema` to avoid `schema` shadowing in BaseModel
warehouse: SecretStr = Field(
description="the warehouse name of your database", examples=["COMPUTE_WH"]
warehouse: SecretStr | None = Field(
description="the warehouse name of your database",
examples=["COMPUTE_WH"],
default=None,
)
private_key: SecretStr | None = Field(
description="the private key for key pair authentication",
Expand Down
28 changes: 19 additions & 9 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,25 @@ def get_oracle_connection(info: OracleConnectionInfo) -> BaseBackend:

@staticmethod
def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
return ibis.snowflake.connect(
user=info.user.get_secret_value(),
account=info.account.get_secret_value(),
database=info.database.get_secret_value(),
schema=info.sf_schema.get_secret_value(),
warehouse=info.warehouse.get_secret_value(),
private_key=info.private_key.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
)
if hasattr(info, "private_key") and info.private_key:
return ibis.snowflake.connect(
user=info.user.get_secret_value(),
account=info.account.get_secret_value(),
database=info.database.get_secret_value(),
schema=info.sf_schema.get_secret_value(),
warehouse=info.warehouse.get_secret_value(),
private_key=info.private_key.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
)
else:
return ibis.snowflake.connect(
user=info.user.get_secret_value(),
password=info.password.get_secret_value(),
account=info.account.get_secret_value(),
database=info.database.get_secret_value(),
schema=info.sf_schema.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
)
Comment on lines +221 to +239
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Verify authentication credential validation and improve code style.

The authentication branching logic is sound, but consider these improvements:

  1. Add explicit validation for missing credentials to provide clearer error messages
  2. Simplify the else clause and improve dictionary usage per static analysis hints

Apply this diff to improve the implementation:

 @staticmethod
 def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
-    if hasattr(info, "private_key") and info.private_key:
+    if hasattr(info, "private_key") and info.private_key:
         return ibis.snowflake.connect(
             user=info.user.get_secret_value(),
             account=info.account.get_secret_value(),
             database=info.database.get_secret_value(),
             schema=info.sf_schema.get_secret_value(),
             warehouse=info.warehouse.get_secret_value(),
             private_key=info.private_key.get_secret_value(),
-            **info.kwargs if info.kwargs else dict(),
+            **info.kwargs if info.kwargs else {},
         )
-    else:
+    
+    if not info.password:
+        raise ValueError("Either private_key or password must be provided for Snowflake authentication")
+    
+    return ibis.snowflake.connect(
+        user=info.user.get_secret_value(),
+        password=info.password.get_secret_value(),
+        account=info.account.get_secret_value(),
+        database=info.database.get_secret_value(),
+        schema=info.sf_schema.get_secret_value(),
+        **info.kwargs if info.kwargs else {},
+    )
-        return ibis.snowflake.connect(
-            user=info.user.get_secret_value(),
-            password=info.password.get_secret_value(),
-            account=info.account.get_secret_value(),
-            database=info.database.get_secret_value(),
-            schema=info.sf_schema.get_secret_value(),
-            **info.kwargs if info.kwargs else dict(),
-        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if hasattr(info, "private_key") and info.private_key:
return ibis.snowflake.connect(
user=info.user.get_secret_value(),
account=info.account.get_secret_value(),
database=info.database.get_secret_value(),
schema=info.sf_schema.get_secret_value(),
warehouse=info.warehouse.get_secret_value(),
private_key=info.private_key.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
)
else:
return ibis.snowflake.connect(
user=info.user.get_secret_value(),
password=info.password.get_secret_value(),
account=info.account.get_secret_value(),
database=info.database.get_secret_value(),
schema=info.sf_schema.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
)
@staticmethod
def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
if hasattr(info, "private_key") and info.private_key:
return ibis.snowflake.connect(
user=info.user.get_secret_value(),
account=info.account.get_secret_value(),
database=info.database.get_secret_value(),
schema=info.sf_schema.get_secret_value(),
warehouse=info.warehouse.get_secret_value(),
private_key=info.private_key.get_secret_value(),
**info.kwargs if info.kwargs else {},
)
if not info.password:
raise ValueError(
"Either private_key or password must be provided for Snowflake authentication"
)
return ibis.snowflake.connect(
user=info.user.get_secret_value(),
password=info.password.get_secret_value(),
account=info.account.get_secret_value(),
database=info.database.get_secret_value(),
schema=info.sf_schema.get_secret_value(),
**info.kwargs if info.kwargs else {},
)
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 221-239: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it

(R1705)


[refactor] 229-229: Consider using '{}' instead of a call to 'dict'.

(R1735)


[refactor] 238-238: Consider using '{}' instead of a call to 'dict'.

(R1735)

🤖 Prompt for AI Agents
In ibis-server/app/model/data_source.py around lines 221 to 239, add explicit
checks to validate that required authentication credentials (like user,
password, or private_key) are present before attempting connection, raising
clear errors if missing. Simplify the else clause by using a single dictionary
unpacking with a default empty dict for info.kwargs, avoiding redundant
conditional expressions to improve code clarity and comply with static analysis
recommendations.


@staticmethod
def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
Expand Down
45 changes: 45 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
"private_key": os.getenv("SNOWFLAKE_PRIVATE_KEY"),
}

password_connection_info = {
"user": os.getenv("SNOWFLAKE_USER"),
"password": os.getenv("SNOWFLAKE_PASSWORD"),
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
"database": "SNOWFLAKE_SAMPLE_DATA",
"schema": "TPCH_SF1",
}

manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
Expand Down Expand Up @@ -110,6 +118,43 @@ async def test_query(client, manifest_str):
}


async def test_query_with_password_connection_info(client, manifest_str):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1',
},
)
assert response.status_code == 200
result = response.json()
assert len(result["columns"]) == len(manifest["models"][0]["columns"])
assert len(result["data"]) == 1
assert result["data"][0] == [
1,
36901,
"O",
"173665.47",
"1996-01-02",
"1_36901",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 UTC",
None,
]
assert result["dtypes"] == {
"orderkey": "int64",
"custkey": "int64",
"orderstatus": "string",
"totalprice": "decimal128(12, 2)",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[ns]",
"timestamptz": "timestamp[ns, tz=UTC]",
"test_null_time": "timestamp[ns]",
}


async def test_query_without_manifest(client):
response = await client.post(
url=f"{base_url}/query",
Expand Down
Loading