diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index d43900e3d..c4b06edd6 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -318,6 +318,9 @@ class SnowflakeConnectionInfo(BaseConnectionInfo): user: SecretStr = Field( description="the username of your database", examples=["admin"] ) + password: SecretStr | None = Field( + description="the password of your database", examples=["password"], default=None + ) account: SecretStr = Field( description="the account name of your database", examples=["myaccount"] ) @@ -329,8 +332,10 @@ class SnowflakeConnectionInfo(BaseConnectionInfo): description="the schema name of your database", examples=["myschema"], ) # Use `sf_schema` to avoid `schema` shadowing in BaseModel - warehouse: SecretStr = Field( - description="the warehouse name of your database", examples=["COMPUTE_WH"] + warehouse: SecretStr | None = Field( + description="the warehouse name of your database", + examples=["COMPUTE_WH"], + default=None, ) private_key: SecretStr | None = Field( description="the private key for key pair authentication", diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index 187acc532..6f371c617 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -218,15 +218,25 @@ def get_oracle_connection(info: OracleConnectionInfo) -> BaseBackend: @staticmethod def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend: - return ibis.snowflake.connect( - user=info.user.get_secret_value(), - account=info.account.get_secret_value(), - database=info.database.get_secret_value(), - schema=info.sf_schema.get_secret_value(), - warehouse=info.warehouse.get_secret_value(), - private_key=info.private_key.get_secret_value(), - **info.kwargs if info.kwargs else dict(), - ) + if hasattr(info, "private_key") and info.private_key: + return ibis.snowflake.connect( + user=info.user.get_secret_value(), + account=info.account.get_secret_value(), + database=info.database.get_secret_value(), + schema=info.sf_schema.get_secret_value(), + warehouse=info.warehouse.get_secret_value(), + private_key=info.private_key.get_secret_value(), + **info.kwargs if info.kwargs else dict(), + ) + else: + return ibis.snowflake.connect( + user=info.user.get_secret_value(), + password=info.password.get_secret_value(), + account=info.account.get_secret_value(), + database=info.database.get_secret_value(), + schema=info.sf_schema.get_secret_value(), + **info.kwargs if info.kwargs else dict(), + ) @staticmethod def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend: diff --git a/ibis-server/tests/routers/v2/connector/test_snowflake.py b/ibis-server/tests/routers/v2/connector/test_snowflake.py index f126acebc..b8f166683 100644 --- a/ibis-server/tests/routers/v2/connector/test_snowflake.py +++ b/ibis-server/tests/routers/v2/connector/test_snowflake.py @@ -19,6 +19,14 @@ "private_key": os.getenv("SNOWFLAKE_PRIVATE_KEY"), } +password_connection_info = { + "user": os.getenv("SNOWFLAKE_USER"), + "password": os.getenv("SNOWFLAKE_PASSWORD"), + "account": os.getenv("SNOWFLAKE_ACCOUNT"), + "database": "SNOWFLAKE_SAMPLE_DATA", + "schema": "TPCH_SF1", +} + manifest = { "catalog": "my_catalog", "schema": "my_schema", @@ -110,6 +118,43 @@ async def test_query(client, manifest_str): } +async def test_query_with_password_connection_info(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1', + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 36901, + "O", + "173665.47", + "1996-01-02", + "1_36901", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000 UTC", + None, + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "string", + "totalprice": "decimal128(12, 2)", + "orderdate": "date32[day]", + "order_cust_key": "string", + "timestamp": "timestamp[ns]", + "timestamptz": "timestamp[ns, tz=UTC]", + "test_null_time": "timestamp[ns]", + } + + async def test_query_without_manifest(client): response = await client.post( url=f"{base_url}/query",