diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index af4b46964..67e84f2c6 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -131,6 +131,9 @@ class ClickHouseConnectionInfo(BaseConnectionInfo): password: SecretStr | None = Field( description="the password of your database", examples=["password"], default=None ) + kwargs: dict[str, str] | None = Field( + description="Client specific keyword arguments", default=None + ) class MSSqlConnectionInfo(BaseConnectionInfo): @@ -200,15 +203,25 @@ class PostgresConnectionInfo(BaseConnectionInfo): password: SecretStr | None = Field( examples=["password"], description="the password of your database", default=None ) + kwargs: dict[str, str] | None = Field( + description="Additional keyword arguments to pass to the backend client connection.", + default=None, + ) class OracleConnectionInfo(BaseConnectionInfo): host: SecretStr = Field( - examples=["localhost"], description="the hostname of your database" + examples=["localhost"], + description="the hostname of your database", + default="localhost", + ) + port: SecretStr = Field( + examples=[1521], description="the port of your database", default="1521" ) - port: SecretStr = Field(examples=[1521], description="the port of your database") database: SecretStr = Field( - examples=["orcl"], description="the database name of your database" + examples=["orcl"], + description="the database name of your database", + default="orcl", ) user: SecretStr = Field( examples=["admin"], description="the username of your database" @@ -216,6 +229,11 @@ class OracleConnectionInfo(BaseConnectionInfo): password: SecretStr | None = Field( examples=["password"], description="the password of your database", default=None ) + dsn: SecretStr | None = Field( + default=None, + description="An Oracle Data Source Name. If provided, overrides all other connection arguments except username and password.", + examples=["localhost:1521/orcl"], + ) class SnowflakeConnectionInfo(BaseConnectionInfo): @@ -236,6 +254,10 @@ class SnowflakeConnectionInfo(BaseConnectionInfo): description="the schema name of your database", examples=["myschema"], ) # Use `sf_schema` to avoid `schema` shadowing in BaseModel + kwargs: dict[str, str] | None = Field( + description="Additional arguments passed to the DBAPI connection call.", + default=None, + ) class TrinoConnectionInfo(BaseConnectionInfo): @@ -257,6 +279,10 @@ class TrinoConnectionInfo(BaseConnectionInfo): password: SecretStr | None = Field( description="the password of your database", examples=["password"], default=None ) + kwargs: dict[str, str] | None = Field( + description="Additional keyword arguments passed directly to the trino.dbapi.connect API.", + default=None, + ) class LocalFileConnectionInfo(BaseConnectionInfo): diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index 4328fec6b..2d766973c 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -135,6 +135,7 @@ def get_clickhouse_connection(info: ClickHouseConnectionInfo) -> BaseBackend: database=info.database.get_secret_value(), user=info.user.get_secret_value(), password=(info.password and info.password.get_secret_value()), + **info.kwargs if info.kwargs else dict(), ) @classmethod @@ -177,10 +178,19 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend: database=info.database.get_secret_value(), user=info.user.get_secret_value(), password=(info.password and info.password.get_secret_value()), + **info.kwargs if info.kwargs else dict(), ) @staticmethod def get_oracle_connection(info: OracleConnectionInfo) -> BaseBackend: + # if dsn is provided, use it to connect + # otherwise, use host, port, database, user, password, and sid + if hasattr(info, "dsn") and info.dsn: + return ibis.oracle.connect( + dsn=info.dsn.get_secret_value(), + user=info.user.get_secret_value(), + password=(info.password and info.password.get_secret_value()), + ) return ibis.oracle.connect( host=info.host.get_secret_value(), port=int(info.port.get_secret_value()), @@ -197,6 +207,7 @@ def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend: account=info.account.get_secret_value(), database=info.database.get_secret_value(), schema=info.sf_schema.get_secret_value(), + **info.kwargs if info.kwargs else dict(), ) @staticmethod @@ -208,6 +219,7 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend: schema=info.trino_schema.get_secret_value(), user=(info.user and info.user.get_secret_value()), password=(info.password and info.password.get_secret_value()), + **info.kwargs if info.kwargs else dict(), ) @staticmethod diff --git a/ibis-server/tests/routers/v3/connector/oracle/test_query.py b/ibis-server/tests/routers/v3/connector/oracle/test_query.py index 22a8fe983..ce4046aee 100644 --- a/ibis-server/tests/routers/v3/connector/oracle/test_query.py +++ b/ibis-server/tests/routers/v3/connector/oracle/test_query.py @@ -132,3 +132,29 @@ async def test_query_with_connection_url(client, manifest_str, connection_url): assert len(result["data"]) == 1 assert result["data"][0][0] == 1 assert result["dtypes"] is not None + + +async def test_query_with_dsn(client, manifest_str, connection_info): + dsn = f"{connection_info['host']}:{connection_info['port']}/{connection_info['database']}" + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": { + "dsn": dsn, + "user": connection_info["user"], + "password": connection_info["password"], + }, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 200 + result = response.json() + # include one hidden column + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) - 1 + assert len(result["data"]) == 1 + assert result["data"][0][0] == 1 + assert result["dtypes"] is not None