Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ibis-server/justfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pre-commit-install:
port := "8000"

run:
poetry run fastapi run --port {{ port }}
poetry run python -m fastapi run --port {{ port }}

workers := "2"

Expand All @@ -45,7 +45,7 @@ run-trace-otlp:
--service_name wren-engine \
fastapi run --port {{ port }}
dev:
poetry run fastapi dev --port {{ port }}
poetry run python -m fastapi dev --port {{ port }}

# run the pytest tests for the given marker
test MARKER:
Expand Down
98 changes: 97 additions & 1 deletion ibis-server/tests/routers/v3/connector/snowflake/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
import pathlib

import pytest
import snowflake.connector

from app.config import get_config
from tests.conftest import file_path

pytestmark = pytest.mark.snowflake

base_url = "/v3/connector/snowflake"
function_list_path = file_path("../resources/function_list")


def pytest_collection_modifyitems(items):
Expand All @@ -15,8 +20,87 @@ def pytest_collection_modifyitems(items):
item.add_marker(pytestmark)


@pytest.fixture(scope="module", autouse=True)
def init_snowflake():
user = os.getenv("SNOWFLAKE_USER")
account = os.getenv("SNOWFLAKE_ACCOUNT")
private_key = os.getenv("SNOWFLAKE_PRIVATE_KEY")
if not user or not account or not private_key:
pytest.skip("Snowflake credentials are not set", allow_module_level=True)

conn = snowflake.connector.connect(
user=user,
account=account,
private_key=private_key,
warehouse="COMPUTE_WH",
)
try:
cs = conn.cursor()
try:
cs.execute("USE WREN")
except Exception:
cs.execute("CREATE DATABASE IF NOT EXISTS WREN")
cs.execute("USE WREN")
try:
cs.execute("USE SCHEMA PUBLIC")
except Exception:
cs.execute("CREATE SCHEMA IF NOT EXISTS PUBLIC")
cs.execute("USE SCHEMA PUBLIC")
# prepare table with variant column
cs.execute(
"""
CREATE OR REPLACE TABLE car_sales
(
src variant
)
AS
SELECT PARSE_JSON(column1) AS src
FROM VALUES
('{
"date" : "2017-04-28",
"dealership" : "Valley View Auto Sales",
"salesperson" : {
"id": "55",
"name": "Frank Beasley"
},
"customer" : [
{"name": "Joyce Ridgely", "phone": "16504378889", "address": "San Francisco, CA"}
],
"vehicle" : [
{"make": "Honda", "model": "Civic", "year": "2017", "price": "20275", "extras":["ext warranty", "paint protection"]}
]
}'),
('{
"date" : "2017-04-28",
"dealership" : "Tindel Toyota",
"salesperson" : {
"id": "274",
"name": "Greg Northrup"
},
"customer" : [
{"name": "Bradley Greenbloom", "phone": "12127593751", "address": "New York, NY"}
],
"vehicle" : [
{"make": "Toyota", "model": "Camry", "year": "2017", "price": "23500", "extras":["ext warranty", "rust proofing", "fabric protection"]}
]
}') v;
"""
)
finally:
cs.close()
conn.close()


@pytest.fixture(scope="module", autouse=True)
def set_remote_function_list_path():
config = get_config()
config.set_remote_function_list_path(function_list_path)
yield
config.set_remote_function_list_path(None)


@pytest.fixture(scope="module")
def connection_info() -> dict[str, str]:
def tpch_connection_info() -> dict[str, str]:
return {
"user": os.getenv("SNOWFLAKE_USER"),
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
Expand All @@ -25,3 +109,15 @@ def connection_info() -> dict[str, str]:
"warehouse": "COMPUTE_WH",
"private_key": os.getenv("SNOWFLAKE_PRIVATE_KEY"),
}


@pytest.fixture(scope="module")
def snowflake_connection_info() -> dict[str, str]:
return {
"user": os.getenv("SNOWFLAKE_USER"),
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
"database": "WREN",
"schema": "PUBLIC",
"warehouse": "COMPUTE_WH",
"private_key": os.getenv("SNOWFLAKE_PRIVATE_KEY"),
}
22 changes: 6 additions & 16 deletions ibis-server/tests/routers/v3/connector/snowflake/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest

from app.config import get_config
from tests.conftest import DATAFUSION_FUNCTION_COUNT, file_path
from tests.routers.v3.connector.snowflake.conftest import base_url
from tests.conftest import DATAFUSION_FUNCTION_COUNT
from tests.routers.v3.connector.snowflake.conftest import base_url, function_list_path

manifest = {
"catalog": "my_catalog",
Expand All @@ -24,22 +24,12 @@
],
}

function_list_path = file_path("../resources/function_list")


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


@pytest.fixture(autouse=True)
def set_remote_function_list_path():
config = get_config()
config.set_remote_function_list_path(function_list_path)
yield
config.set_remote_function_list_path(None)


async def test_function_list(client):
config = get_config()

Expand Down Expand Up @@ -71,11 +61,11 @@ async def test_function_list(client):
assert len(result) == DATAFUSION_FUNCTION_COUNT


async def test_scalar_function(client, manifest_str: str, connection_info):
async def test_scalar_function(client, manifest_str: str, tpch_connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"connectionInfo": tpch_connection_info,
"manifestStr": manifest_str,
"sql": "SELECT ABS(-1) AS col",
},
Expand All @@ -89,11 +79,11 @@ async def test_scalar_function(client, manifest_str: str, connection_info):
}


async def test_aggregate_function(client, manifest_str: str, connection_info):
async def test_aggregate_function(client, manifest_str: str, tpch_connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"connectionInfo": tpch_connection_info,
"manifestStr": manifest_str,
"sql": "SELECT COUNT(*) AS col FROM (SELECT 1) AS temp_table",
},
Expand Down
56 changes: 56 additions & 0 deletions ibis-server/tests/routers/v3/connector/snowflake/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import base64

import orjson
import pytest

from app.dependencies import X_WREN_FALLBACK_DISABLE
from tests.routers.v3.connector.snowflake.conftest import base_url

manifest = {
"catalog": "wren",
"schema": "public",
"models": [
{
"name": "car_sales",
"tableReference": {
"catalog": "wren",
"schema": "PUBLIC",
"table": "car_sales",
},
"columns": [
{"name": "src", "type": "variant"},
],
},
],
"dataSource": "snowflake",
}


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


async def test_qeury(client, manifest_str, snowflake_connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": snowflake_connection_info,
"manifestStr": manifest_str,
"sql": "select t.a from car_sales c, UNNEST(to_array(get_path(c.src, 'customer'))) t(a)",
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 2
assert result["data"] == [
[
'{"address":"San Francisco, CA","name":"Joyce Ridgely","phone":"16504378889"}'
],
[
'{"address":"New York, NY","name":"Bradley Greenbloom","phone":"12127593751"}'
],
]
Comment on lines +34 to +56
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix typo in function name.

The test function is named test_qeury but should be test_query.

Apply this diff:

-async def test_qeury(client, manifest_str, snowflake_connection_info):
+async def test_query(client, manifest_str, snowflake_connection_info):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def test_qeury(client, manifest_str, snowflake_connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": snowflake_connection_info,
"manifestStr": manifest_str,
"sql": "select t.a from car_sales c, UNNEST(to_array(get_path(c.src, 'customer'))) t(a)",
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 2
assert result["data"] == [
[
'{"address":"San Francisco, CA","name":"Joyce Ridgely","phone":"16504378889"}'
],
[
'{"address":"New York, NY","name":"Bradley Greenbloom","phone":"12127593751"}'
],
]
async def test_query(client, manifest_str, snowflake_connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": snowflake_connection_info,
"manifestStr": manifest_str,
"sql": "select t.a from car_sales c, UNNEST(to_array(get_path(c.src, 'customer'))) t(a)",
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 2
assert result["data"] == [
[
'{"address":"San Francisco, CA","name":"Joyce Ridgely","phone":"16504378889"}'
],
[
'{"address":"New York, NY","name":"Bradley Greenbloom","phone":"12127593751"}'
],
]
🤖 Prompt for AI Agents
In ibis-server/tests/routers/v3/connector/snowflake/test_query.py around lines
34 to 56, the test function is misspelled as `test_qeury`; rename the function
to `test_query` (update the `def` line only) so the test runner recognizes it,
and ensure there are no other references to the old name in this file or imports
that need updating.

Loading