Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ibis-server/resources/function_list/oracle.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
function_type,name,return_type,param_names,param_types,description
scalar,to_timestamp_tz,timestamp with time zone,,,"Convert string to timestamp with time zone"
3 changes: 2 additions & 1 deletion ibis-server/tests/routers/v2/connector/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"dataSource": "oracle",
"models": [
{
"name": "Orders",
Expand Down Expand Up @@ -139,7 +140,7 @@ def oracle(request) -> OracleDbContainer:
# Add table and column comments
conn.execute(text("COMMENT ON TABLE orders IS 'This is a table comment'"))
conn.execute(text("COMMENT ON COLUMN orders.o_comment IS 'This is a comment'"))

request.addfinalizer(oracle.stop)
return oracle


Expand Down
Empty file.
82 changes: 82 additions & 0 deletions ibis-server/tests/routers/v3/connector/oracle/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pathlib

import pandas as pd
import pytest
import sqlalchemy
from sqlalchemy import text
from testcontainers.oracle import OracleDbContainer

from app.config import get_config
from tests.conftest import file_path

pytestmark = pytest.mark.oracle

base_url = "/v3/connector/oracle"
oracle_password = "Oracle123"
oracle_user = "SYSTEM"
oracle_database = "FREEPDB1"


def pytest_collection_modifyitems(items):
current_file_dir = pathlib.Path(__file__).resolve().parent
for item in items:
if pathlib.Path(item.fspath).is_relative_to(current_file_dir):
item.add_marker(pytestmark)


@pytest.fixture(scope="module")
def oracle(request) -> OracleDbContainer:
oracle = OracleDbContainer(
"gvenzl/oracle-free:23.6-slim-faststart", oracle_password=f"{oracle_password}"
).start()
engine = sqlalchemy.create_engine(oracle.get_connection_url())
with engine.begin() as conn:
pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql(
"orders", engine, index=False
)
pd.read_parquet(file_path("resource/tpch/data/customer.parquet")).to_sql(
"customer", engine, index=False
)

# Create a table with a large CLOB column
large_text = "x" * (1024 * 1024 * 2) # 2MB
conn.execute(text("CREATE TABLE test_lob (id NUMBER, content CLOB)"))
conn.execute(
text("INSERT INTO test_lob VALUES (1, :content)"), {"content": large_text}
)

# Add table and column comments
conn.execute(text("COMMENT ON TABLE orders IS 'This is a table comment'"))
conn.execute(text("COMMENT ON COLUMN orders.o_comment IS 'This is a comment'"))
request.addfinalizer(oracle.stop)
return oracle


@pytest.fixture(scope="module")
def connection_info(oracle: OracleDbContainer):
# We can't use oracle.user, oracle.password, oracle.dbname here
# since these values are None at this point
return {
"host": oracle.get_container_host_ip(),
"port": oracle.get_exposed_port(oracle.port),
"user": f"{oracle_user}",
"password": f"{oracle_password}",
"database": f"{oracle_database}",
}


@pytest.fixture(scope="module")
def connection_url(connection_info: dict[str, str]):
info = connection_info
return f"oracle://{info['user']}:{info['password']}@{info['host']}:{info['port']}/{info['database']}"


function_list_path = file_path("../resources/function_list")


@pytest.fixture(autouse=True)
def set_remote_function_list_path():
config = get_config()
config.set_remote_function_list_path(function_list_path)
yield
config.set_remote_function_list_path(None)
105 changes: 105 additions & 0 deletions ibis-server/tests/routers/v3/connector/oracle/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import base64

import orjson
import pytest

from app.config import get_config
from app.dependencies import X_WREN_FALLBACK_DISABLE
from tests.conftest import DATAFUSION_FUNCTION_COUNT
from tests.routers.v3.connector.oracle.conftest import base_url, function_list_path

manifest = {
"dataSource": "oracle",
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "orders",
"tableReference": {
"schema": "SYSTEM",
"table": "ORDERS",
},
"columns": [
{"name": "orderkey", "expression": '"O_ORDERKEY"', "type": "number"},
],
},
],
}


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


async def test_function_list(client):
config = get_config()

config.set_remote_function_list_path(None)
response = await client.get(url=f"{base_url}/functions")
assert response.status_code == 200
result = response.json()
assert len(result) == DATAFUSION_FUNCTION_COUNT

config.set_remote_function_list_path(function_list_path)
response = await client.get(url=f"{base_url}/functions")
assert response.status_code == 200
result = response.json()
assert len(result) == DATAFUSION_FUNCTION_COUNT + 1
the_func = next(filter(lambda x: x["name"] == "to_timestamp_tz", result))
assert the_func == {
"name": "to_timestamp_tz",
"description": "Convert string to timestamp with time zone",
"function_type": "scalar",
"param_names": None,
"param_types": None,
"return_type": None,
}

config.set_remote_function_list_path(None)
response = await client.get(url=f"{base_url}/functions")
assert response.status_code == 200
result = response.json()
assert len(result) == DATAFUSION_FUNCTION_COUNT


async def test_scalar_function(client, manifest_str: str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT ABS(-1) AS col",
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
assert result == {
"columns": ["col"],
"data": [[1]],
"dtypes": {"col": "int64"},
}


async def test_aggregate_function(client, manifest_str: str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT COUNT(*) AS col FROM (SELECT 1) AS temp_table",
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
assert result == {
"columns": ["col"],
"data": [[1]],
"dtypes": {"col": "int64"},
}
134 changes: 134 additions & 0 deletions ibis-server/tests/routers/v3/connector/oracle/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import base64

import orjson
import pytest

from app.dependencies import X_WREN_FALLBACK_DISABLE
from tests.routers.v3.connector.oracle.conftest import base_url

manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"dataSource": "oracle",
"models": [
{
"name": "Orders",
"tableReference": {
"schema": "SYSTEM",
"table": "ORDERS",
},
"columns": [
{"name": "orderkey", "expression": '"O_ORDERKEY"', "type": "number"},
{"name": "custkey", "expression": '"O_CUSTKEY"', "type": "number"},
{
"name": "orderstatus",
"expression": '"O_ORDERSTATUS"',
"type": "varchar",
},
{
"name": "totalprice",
"expression": '"O_TOTALPRICE"',
"type": "number",
},
{
"name": "O_ORDERDATE",
"type": "float64",
"isHidden": True,
},
{
"name": "orderdate",
"expression": 'TRUNC("O_ORDERDATE")',
"type": "date",
},
{
"name": "order_cust_key",
"expression": '"O_ORDERKEY" || \'_\' || "O_CUSTKEY"',
"type": "varchar",
},
{
"name": "timestamp",
"expression": "TO_TIMESTAMP('2024-01-01 23:59:59', 'YYYY-MM-DD HH24:MI:SS')",
"type": "timestamp",
},
{
"name": "timestamptz",
"expression": "TO_TIMESTAMP_TZ( '2024-01-01 23:59:59.000000 +00:00', 'YYYY-MM-DD HH24:MI:SS.FF6 TZH:TZM')",
"type": "timestamp",
},
{
"name": "test_null_time",
"expression": "CAST(NULL AS TIMESTAMP)",
"type": "timestamp",
},
],
"primaryKey": "orderkey",
}
],
}


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


async def test_query(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1',
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
# include one hidden column
assert len(result["columns"]) == len(manifest["models"][0]["columns"]) - 1
assert len(result["data"]) == 1
assert result["data"][0] == [
1,
370,
"O",
"172799.49",
"1996-01-02 00:00:00.000000",
"1_370",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 UTC",
None,
]
assert result["dtypes"] == {
"orderkey": "int64",
"custkey": "int64",
"orderstatus": "object",
"totalprice": "object",
"orderdate": "object",
"order_cust_key": "object",
"timestamp": "object",
"timestamptz": "object",
"test_null_time": "datetime64[ns]",
}


async def test_query_with_connection_url(client, manifest_str, connection_url):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": {"connectionUrl": connection_url},
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1',
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
# include one hidden column
assert len(result["columns"]) == len(manifest["models"][0]["columns"]) - 1
assert len(result["data"]) == 1
assert result["data"][0][0] == 1
assert result["dtypes"] is not None
2 changes: 2 additions & 0 deletions wren-core-base/manifest-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ pub fn data_source(python_binding: proc_macro::TokenStream) -> proc_macro::Token
GcsFile,
#[serde(alias = "minio_file")]
MinioFile,
#[serde(alias = "oracle")]
Oracle,
}
};
proc_macro::TokenStream::from(expanded)
Expand Down
1 change: 1 addition & 0 deletions wren-core-base/src/mdl/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ impl Display for DataSource {
DataSource::S3File => write!(f, "S3_FILE"),
DataSource::GcsFile => write!(f, "GCS_FILE"),
DataSource::MinioFile => write!(f, "MINIO_FILE"),
DataSource::Oracle => write!(f, "ORACLE"),
}
}
}
Expand Down
Loading