Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class ClickHouseConnectionInfo(BaseConnectionInfo):
password: SecretStr | None = Field(
description="the password of your database", examples=["password"], default=None
)
kwargs: dict[str, str] | None = Field(
description="Client specific keyword arguments", default=None
)


class MSSqlConnectionInfo(BaseConnectionInfo):
Expand Down Expand Up @@ -200,22 +203,37 @@ class PostgresConnectionInfo(BaseConnectionInfo):
password: SecretStr | None = Field(
examples=["password"], description="the password of your database", default=None
)
kwargs: dict[str, str] | None = Field(
description="Additional keyword arguments to pass to the backend client connection.",
default=None,
)


class OracleConnectionInfo(BaseConnectionInfo):
host: SecretStr = Field(
examples=["localhost"], description="the hostname of your database"
examples=["localhost"],
description="the hostname of your database",
default="localhost",
)
port: SecretStr = Field(
examples=[1521], description="the port of your database", default="1521"
)
port: SecretStr = Field(examples=[1521], description="the port of your database")
database: SecretStr = Field(
examples=["orcl"], description="the database name of your database"
examples=["orcl"],
description="the database name of your database",
default="orcl",
)
user: SecretStr = Field(
examples=["admin"], description="the username of your database"
)
password: SecretStr | None = Field(
examples=["password"], description="the password of your database", default=None
)
dsn: SecretStr | None = Field(
default=None,
description="An Oracle Data Source Name. If provided, overrides all other connection arguments except username and password.",
examples=["localhost:1521/orcl"],
)


class SnowflakeConnectionInfo(BaseConnectionInfo):
Expand All @@ -236,6 +254,10 @@ class SnowflakeConnectionInfo(BaseConnectionInfo):
description="the schema name of your database",
examples=["myschema"],
) # Use `sf_schema` to avoid `schema` shadowing in BaseModel
kwargs: dict[str, str] | None = Field(
description="Additional arguments passed to the DBAPI connection call.",
default=None,
)


class TrinoConnectionInfo(BaseConnectionInfo):
Expand All @@ -257,6 +279,10 @@ class TrinoConnectionInfo(BaseConnectionInfo):
password: SecretStr | None = Field(
description="the password of your database", examples=["password"], default=None
)
kwargs: dict[str, str] | None = Field(
description="Additional keyword arguments passed directly to the trino.dbapi.connect API.",
default=None,
)


class LocalFileConnectionInfo(BaseConnectionInfo):
Expand Down
12 changes: 12 additions & 0 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def get_clickhouse_connection(info: ClickHouseConnectionInfo) -> BaseBackend:
database=info.database.get_secret_value(),
user=info.user.get_secret_value(),
password=(info.password and info.password.get_secret_value()),
**info.kwargs if info.kwargs else dict(),
)

@classmethod
Expand Down Expand Up @@ -177,10 +178,19 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend:
database=info.database.get_secret_value(),
user=info.user.get_secret_value(),
password=(info.password and info.password.get_secret_value()),
**info.kwargs if info.kwargs else dict(),
)

@staticmethod
def get_oracle_connection(info: OracleConnectionInfo) -> BaseBackend:
# if dsn is provided, use it to connect
# otherwise, use host, port, database, user, password, and sid
if hasattr(info, "dsn") and info.dsn:
return ibis.oracle.connect(
dsn=info.dsn.get_secret_value(),
user=info.user.get_secret_value(),
password=(info.password and info.password.get_secret_value()),
)
return ibis.oracle.connect(
host=info.host.get_secret_value(),
port=int(info.port.get_secret_value()),
Expand All @@ -197,6 +207,7 @@ def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
account=info.account.get_secret_value(),
database=info.database.get_secret_value(),
schema=info.sf_schema.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
)

@staticmethod
Expand All @@ -208,6 +219,7 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
schema=info.trino_schema.get_secret_value(),
user=(info.user and info.user.get_secret_value()),
password=(info.password and info.password.get_secret_value()),
**info.kwargs if info.kwargs else dict(),
)

@staticmethod
Expand Down
26 changes: 26 additions & 0 deletions ibis-server/tests/routers/v3/connector/oracle/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,29 @@ async def test_query_with_connection_url(client, manifest_str, connection_url):
assert len(result["data"]) == 1
assert result["data"][0][0] == 1
assert result["dtypes"] is not None


async def test_query_with_dsn(client, manifest_str, connection_info):
dsn = f"{connection_info['host']}:{connection_info['port']}/{connection_info['database']}"
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": {
"dsn": dsn,
"user": connection_info["user"],
"password": connection_info["password"],
},
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1',
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
# include one hidden column
assert len(result["columns"]) == len(manifest["models"][0]["columns"]) - 1
assert len(result["data"]) == 1
assert result["data"][0][0] == 1
assert result["dtypes"] is not None