diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index 49a54f9dc..a7fa6482c 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -348,25 +348,37 @@ def get_oracle_connection(info: OracleConnectionInfo) -> BaseBackend: @staticmethod def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend: + # private key authentication if hasattr(info, "private_key") and info.private_key: - return ibis.snowflake.connect( - user=info.user.get_secret_value(), - account=info.account.get_secret_value(), - database=info.database.get_secret_value(), - schema=info.sf_schema.get_secret_value(), - warehouse=info.warehouse.get_secret_value(), - private_key=info.private_key.get_secret_value(), - **info.kwargs if info.kwargs else dict(), - ) + connection_params = { + "user": info.user.get_secret_value(), + "private_key": info.private_key.get_secret_value(), + "account": info.account.get_secret_value(), + "database": info.database.get_secret_value(), + "schema": info.sf_schema.get_secret_value(), + } + # warehouse if it exists and is not None/empty + if hasattr(info, "warehouse") and info.warehouse: + connection_params["warehouse"] = info.warehouse.get_secret_value() + if info.kwargs: + connection_params.update(info.kwargs) + return ibis.snowflake.connect(**connection_params) else: - return ibis.snowflake.connect( - user=info.user.get_secret_value(), - password=info.password.get_secret_value(), - account=info.account.get_secret_value(), - database=info.database.get_secret_value(), - schema=info.sf_schema.get_secret_value(), - **info.kwargs if info.kwargs else dict(), - ) + # password authentication + connection_params = { + "user": info.user.get_secret_value(), + "password": info.password.get_secret_value(), + "account": info.account.get_secret_value(), + "database": info.database.get_secret_value(), + "schema": info.sf_schema.get_secret_value(), + } + + # warehouse if it exists and is not None/empty + if hasattr(info, "warehouse") and info.warehouse: + connection_params["warehouse"] = info.warehouse.get_secret_value() + if info.kwargs: + connection_params.update(info.kwargs) + return ibis.snowflake.connect(**connection_params) @staticmethod def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend: diff --git a/ibis-server/tests/routers/v2/connector/test_snowflake.py b/ibis-server/tests/routers/v2/connector/test_snowflake.py index 419ed247b..6c3789da5 100644 --- a/ibis-server/tests/routers/v2/connector/test_snowflake.py +++ b/ibis-server/tests/routers/v2/connector/test_snowflake.py @@ -19,6 +19,14 @@ "private_key": os.getenv("SNOWFLAKE_PRIVATE_KEY"), } +connection_info_without_warehouse = { + "user": os.getenv("SNOWFLAKE_USER"), + "account": os.getenv("SNOWFLAKE_ACCOUNT"), + "database": "SNOWFLAKE_SAMPLE_DATA", + "schema": "TPCH_SF1", + "private_key": os.getenv("SNOWFLAKE_PRIVATE_KEY"), +} + password_connection_info = { "user": os.getenv("SNOWFLAKE_USER"), "password": os.getenv("SNOWFLAKE_PASSWORD"), @@ -109,7 +117,44 @@ async def test_query(client, manifest_str): "orderkey": "int64", "custkey": "int64", "orderstatus": "string", - "totalprice": "decimal128(12, 2)", + "totalprice": "decimal128(38, 9)", + "orderdate": "date32[day]", + "order_cust_key": "string", + "timestamp": "timestamp[ns]", + "timestamptz": "timestamp[ns, tz=UTC]", + "test_null_time": "timestamp[ns]", + } + + +async def test_query_without_warehouse(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info_without_warehouse, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1', + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 36901, + "O", + "173665.47", + "1996-01-02", + "1_36901", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000 +00:00", + None, + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "string", + "totalprice": "decimal128(38, 9)", "orderdate": "date32[day]", "order_cust_key": "string", "timestamp": "timestamp[ns]", @@ -139,14 +184,14 @@ async def test_query_with_password_connection_info(client, manifest_str): "1996-01-02", "1_36901", "2024-01-01 23:59:59.000000", - "2024-01-01 23:59:59.000000 UTC", + "2024-01-01 23:59:59.000000 +00:00", None, ] assert result["dtypes"] == { "orderkey": "int64", "custkey": "int64", "orderstatus": "string", - "totalprice": "decimal128(12, 2)", + "totalprice": "decimal128(38, 9)", "orderdate": "date32[day]", "order_cust_key": "string", "timestamp": "timestamp[ns]",