diff --git a/ibis-server/resources/function_list/oracle.csv b/ibis-server/resources/function_list/oracle.csv new file mode 100644 index 000000000..ca96109a1 --- /dev/null +++ b/ibis-server/resources/function_list/oracle.csv @@ -0,0 +1,2 @@ +function_type,name,return_type,param_names,param_types,description +scalar,to_timestamp_tz,timestamp with time zone,,,"Convert string to timestamp with time zone" \ No newline at end of file diff --git a/ibis-server/tests/routers/v2/connector/test_oracle.py b/ibis-server/tests/routers/v2/connector/test_oracle.py index 87021ef0b..6aa20be2b 100644 --- a/ibis-server/tests/routers/v2/connector/test_oracle.py +++ b/ibis-server/tests/routers/v2/connector/test_oracle.py @@ -20,6 +20,7 @@ manifest = { "catalog": "my_catalog", "schema": "my_schema", + "dataSource": "oracle", "models": [ { "name": "Orders", @@ -139,7 +140,7 @@ def oracle(request) -> OracleDbContainer: # Add table and column comments conn.execute(text("COMMENT ON TABLE orders IS 'This is a table comment'")) conn.execute(text("COMMENT ON COLUMN orders.o_comment IS 'This is a comment'")) - + request.addfinalizer(oracle.stop) return oracle diff --git a/ibis-server/tests/routers/v3/connector/oracle/__init__.py b/ibis-server/tests/routers/v3/connector/oracle/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ibis-server/tests/routers/v3/connector/oracle/conftest.py b/ibis-server/tests/routers/v3/connector/oracle/conftest.py new file mode 100644 index 000000000..414807900 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/oracle/conftest.py @@ -0,0 +1,82 @@ +import pathlib + +import pandas as pd +import pytest +import sqlalchemy +from sqlalchemy import text +from testcontainers.oracle import OracleDbContainer + +from app.config import get_config +from tests.conftest import file_path + +pytestmark = pytest.mark.oracle + +base_url = "/v3/connector/oracle" +oracle_password = "Oracle123" +oracle_user = "SYSTEM" +oracle_database = "FREEPDB1" + + +def pytest_collection_modifyitems(items): + current_file_dir = pathlib.Path(__file__).resolve().parent + for item in items: + if pathlib.Path(item.fspath).is_relative_to(current_file_dir): + item.add_marker(pytestmark) + + +@pytest.fixture(scope="module") +def oracle(request) -> OracleDbContainer: + oracle = OracleDbContainer( + "gvenzl/oracle-free:23.6-slim-faststart", oracle_password=f"{oracle_password}" + ).start() + engine = sqlalchemy.create_engine(oracle.get_connection_url()) + with engine.begin() as conn: + pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql( + "orders", engine, index=False + ) + pd.read_parquet(file_path("resource/tpch/data/customer.parquet")).to_sql( + "customer", engine, index=False + ) + + # Create a table with a large CLOB column + large_text = "x" * (1024 * 1024 * 2) # 2MB + conn.execute(text("CREATE TABLE test_lob (id NUMBER, content CLOB)")) + conn.execute( + text("INSERT INTO test_lob VALUES (1, :content)"), {"content": large_text} + ) + + # Add table and column comments + conn.execute(text("COMMENT ON TABLE orders IS 'This is a table comment'")) + conn.execute(text("COMMENT ON COLUMN orders.o_comment IS 'This is a comment'")) + request.addfinalizer(oracle.stop) + return oracle + + +@pytest.fixture(scope="module") +def connection_info(oracle: OracleDbContainer): + # We can't use oracle.user, oracle.password, oracle.dbname here + # since these values are None at this point + return { + "host": oracle.get_container_host_ip(), + "port": oracle.get_exposed_port(oracle.port), + "user": f"{oracle_user}", + "password": f"{oracle_password}", + "database": f"{oracle_database}", + } + + +@pytest.fixture(scope="module") +def connection_url(connection_info: dict[str, str]): + info = connection_info + return f"oracle://{info['user']}:{info['password']}@{info['host']}:{info['port']}/{info['database']}" + + +function_list_path = file_path("../resources/function_list") + + +@pytest.fixture(autouse=True) +def set_remote_function_list_path(): + config = get_config() + config.set_remote_function_list_path(function_list_path) + yield + config.set_remote_function_list_path(None) diff --git a/ibis-server/tests/routers/v3/connector/oracle/test_function.py b/ibis-server/tests/routers/v3/connector/oracle/test_function.py new file mode 100644 index 000000000..26cff20cc --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/oracle/test_function.py @@ -0,0 +1,105 @@ +import base64 + +import orjson +import pytest + +from app.config import get_config +from app.dependencies import X_WREN_FALLBACK_DISABLE +from tests.conftest import DATAFUSION_FUNCTION_COUNT +from tests.routers.v3.connector.oracle.conftest import base_url, function_list_path + +manifest = { + "dataSource": "oracle", + "catalog": "my_catalog", + "schema": "my_schema", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "SYSTEM", + "table": "ORDERS", + }, + "columns": [ + {"name": "orderkey", "expression": '"O_ORDERKEY"', "type": "number"}, + ], + }, + ], +} + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +async def test_function_list(client): + config = get_config() + + config.set_remote_function_list_path(None) + response = await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + assert len(result) == DATAFUSION_FUNCTION_COUNT + + config.set_remote_function_list_path(function_list_path) + response = await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + assert len(result) == DATAFUSION_FUNCTION_COUNT + 1 + the_func = next(filter(lambda x: x["name"] == "to_timestamp_tz", result)) + assert the_func == { + "name": "to_timestamp_tz", + "description": "Convert string to timestamp with time zone", + "function_type": "scalar", + "param_names": None, + "param_types": None, + "return_type": None, + } + + config.set_remote_function_list_path(None) + response = await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + assert len(result) == DATAFUSION_FUNCTION_COUNT + + +async def test_scalar_function(client, manifest_str: str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT ABS(-1) AS col", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 200 + result = response.json() + assert result == { + "columns": ["col"], + "data": [[1]], + "dtypes": {"col": "int64"}, + } + + +async def test_aggregate_function(client, manifest_str: str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT COUNT(*) AS col FROM (SELECT 1) AS temp_table", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 200 + result = response.json() + assert result == { + "columns": ["col"], + "data": [[1]], + "dtypes": {"col": "int64"}, + } diff --git a/ibis-server/tests/routers/v3/connector/oracle/test_query.py b/ibis-server/tests/routers/v3/connector/oracle/test_query.py new file mode 100644 index 000000000..22a8fe983 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/oracle/test_query.py @@ -0,0 +1,134 @@ +import base64 + +import orjson +import pytest + +from app.dependencies import X_WREN_FALLBACK_DISABLE +from tests.routers.v3.connector.oracle.conftest import base_url + +manifest = { + "catalog": "my_catalog", + "schema": "my_schema", + "dataSource": "oracle", + "models": [ + { + "name": "Orders", + "tableReference": { + "schema": "SYSTEM", + "table": "ORDERS", + }, + "columns": [ + {"name": "orderkey", "expression": '"O_ORDERKEY"', "type": "number"}, + {"name": "custkey", "expression": '"O_CUSTKEY"', "type": "number"}, + { + "name": "orderstatus", + "expression": '"O_ORDERSTATUS"', + "type": "varchar", + }, + { + "name": "totalprice", + "expression": '"O_TOTALPRICE"', + "type": "number", + }, + { + "name": "O_ORDERDATE", + "type": "float64", + "isHidden": True, + }, + { + "name": "orderdate", + "expression": 'TRUNC("O_ORDERDATE")', + "type": "date", + }, + { + "name": "order_cust_key", + "expression": '"O_ORDERKEY" || \'_\' || "O_CUSTKEY"', + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "TO_TIMESTAMP('2024-01-01 23:59:59', 'YYYY-MM-DD HH24:MI:SS')", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "TO_TIMESTAMP_TZ( '2024-01-01 23:59:59.000000 +00:00', 'YYYY-MM-DD HH24:MI:SS.FF6 TZH:TZM')", + "type": "timestamp", + }, + { + "name": "test_null_time", + "expression": "CAST(NULL AS TIMESTAMP)", + "type": "timestamp", + }, + ], + "primaryKey": "orderkey", + } + ], +} + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +async def test_query(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 200 + result = response.json() + # include one hidden column + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) - 1 + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 370, + "O", + "172799.49", + "1996-01-02 00:00:00.000000", + "1_370", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000 UTC", + None, + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "object", + "totalprice": "object", + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "object", + "timestamptz": "object", + "test_null_time": "datetime64[ns]", + } + + +async def test_query_with_connection_url(client, manifest_str, connection_url): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": {"connectionUrl": connection_url}, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 200 + result = response.json() + # include one hidden column + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) - 1 + assert len(result["data"]) == 1 + assert result["data"][0][0] == 1 + assert result["dtypes"] is not None diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index 65f145879..bfac5e625 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -102,6 +102,8 @@ pub fn data_source(python_binding: proc_macro::TokenStream) -> proc_macro::Token GcsFile, #[serde(alias = "minio_file")] MinioFile, + #[serde(alias = "oracle")] + Oracle, } }; proc_macro::TokenStream::from(expanded) diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index 2d4896ca4..e47bb5ed8 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -106,6 +106,7 @@ impl Display for DataSource { DataSource::S3File => write!(f, "S3_FILE"), DataSource::GcsFile => write!(f, "GCS_FILE"), DataSource::MinioFile => write!(f, "MINIO_FILE"), + DataSource::Oracle => write!(f, "ORACLE"), } } } diff --git a/wren-core/core/src/mdl/dialect/inner_dialect.rs b/wren-core/core/src/mdl/dialect/inner_dialect.rs index 34c53b416..a6a281f07 100644 --- a/wren-core/core/src/mdl/dialect/inner_dialect.rs +++ b/wren-core/core/src/mdl/dialect/inner_dialect.rs @@ -20,9 +20,12 @@ use crate::mdl::dialect::utils::scalar_function_to_sql_internal; use crate::mdl::manifest::DataSource; use datafusion::common::Result; +use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS; use datafusion::logical_expr::Expr; + use datafusion::sql::sqlparser::ast; use datafusion::sql::unparser::Unparser; +use regex::Regex; /// [InnerDialect] is a trait that defines the methods that for dialect-specific SQL generation. pub trait InnerDialect: Send + Sync { @@ -41,6 +44,10 @@ pub trait InnerDialect: Send + Sync { fn unnest_as_table_factor(&self) -> bool { false } + + fn identifier_quote_style(&self, _identifier: &str) -> Option { + None + } } /// [get_inner_dialect] returns the suitable InnerDialect for the given data source. @@ -48,6 +55,7 @@ pub fn get_inner_dialect(data_source: &DataSource) -> Box { match data_source { DataSource::MySQL => Box::new(MySQLDialect {}), DataSource::BigQuery => Box::new(BigQueryDialect {}), + DataSource::Oracle => Box::new(OracleDialect {}), _ => Box::new(GenericDialect {}), } } @@ -82,3 +90,25 @@ impl InnerDialect for BigQueryDialect { true } } + +pub struct OracleDialect {} + +impl InnerDialect for OracleDialect { + fn identifier_quote_style(&self, identifier: &str) -> Option { + // Oracle defaults to upper case for identifiers + let identifier_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap(); + if ALL_KEYWORDS.contains(&identifier.to_uppercase().as_str()) + || !identifier_regex.is_match(identifier) + || non_uppercase(identifier) + { + Some('"') + } else { + None + } + } +} + +fn non_uppercase(sql: &str) -> bool { + let uppsercase = sql.to_uppercase(); + uppsercase != sql +} diff --git a/wren-core/core/src/mdl/dialect/wren_dialect.rs b/wren-core/core/src/mdl/dialect/wren_dialect.rs index c6df175d4..56dfa55dd 100644 --- a/wren-core/core/src/mdl/dialect/wren_dialect.rs +++ b/wren-core/core/src/mdl/dialect/wren_dialect.rs @@ -38,6 +38,10 @@ pub struct WrenDialect { impl Dialect for WrenDialect { fn identifier_quote_style(&self, identifier: &str) -> Option { + if let Some(quote) = self.inner_dialect.identifier_quote_style(identifier) { + return Some(quote); + } + let identifier_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap(); if ALL_KEYWORDS.contains(&identifier.to_uppercase().as_str()) || !identifier_regex.is_match(identifier)