Commit: Fix sql scans on empty tables
desmondcheongzx committed Sep 23, 2024
1 parent 7666669 commit ed957b7
Showing 4 changed files with 50 additions and 2 deletions.
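In short: to_scan_tasks in daft/sql/sql_scan.py now returns an empty iterator when there are zero scan tasks, and _get_size_estimates clamps the scan-task count to the table's row count, so partitioned SQL reads no longer break on empty tables. An empty-table fixture and a regression test parametrized over the configured database URLs are added alongside.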
3 changes: 3 additions & 0 deletions daft/sql/sql_scan.py
@@ -71,6 +71,8 @@ def multiline_display(self) -> list[str]:

     def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
         total_rows, total_size, num_scan_tasks = self._get_size_estimates()
+        if num_scan_tasks == 0:
+            return iter(())
         if num_scan_tasks == 1 or self._partition_col is None:
             return self._single_scan_task(pushdowns, total_rows, total_size)

@@ -136,6 +138,7 @@ def _get_size_estimates(self) -> tuple[int, float, int]:
             if self._num_partitions is None
             else self._num_partitions
         )
+        num_scan_tasks = min(num_scan_tasks, total_rows)
         return total_rows, total_size, num_scan_tasks

     def _get_num_rows(self) -> int:
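Taken together, the two hunks above mean _get_size_estimates can never report more scan tasks than rows, and a zero-task estimate short-circuits to_scan_tasks before any partition boundaries are computed. A minimal sketch of that control flow, using a hypothetical plan_scan_tasks helper rather than Daft's actual code:

from typing import Iterator

def plan_scan_tasks(total_rows: int, requested_partitions: int) -> Iterator[str]:
    # Clamp the task count to the row count, mirroring _get_size_estimates;
    # an empty table therefore estimates zero scan tasks.
    num_scan_tasks = min(requested_partitions, total_rows)
    if num_scan_tasks == 0:
        return iter(())  # nothing to scan: yield no tasks at all
    return iter(f"task-{i}" for i in range(num_scan_tasks))

assert list(plan_scan_tasks(total_rows=0, requested_partitions=4)) == []
assert len(list(plan_scan_tasks(total_rows=2, requested_partitions=4))) == 2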
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -148,7 +148,7 @@ def assert_df_equals(
     sort_key_list: list[str] = [sort_key] if isinstance(sort_key, str) else sort_key
     for key in sort_key_list:
         assert key in daft_pd_df.columns, (
-            f"DaFt Dataframe missing key: {key}\nNOTE: This doesn't necessarily mean your code is "
+            f"Daft Dataframe missing key: {key}\nNOTE: This doesn't necessarily mean your code is "
             "breaking, but our testing utilities require sorting on this key in order to compare your "
             "Dataframe against the expected Pandas Dataframe."
         )
23 changes: 23 additions & 0 deletions tests/integration/sql/conftest.py
@@ -28,6 +28,7 @@
"mysql+pymysql://username:password@localhost:3306/mysql",
]
TEST_TABLE_NAME = "example"
EMPTY_TEST_TABLE_NAME = "empty_table"


@pytest.fixture(scope="session", params=[{"num_rows": 200}])
@@ -54,6 +55,28 @@ def test_db(request: pytest.FixtureRequest, generated_data: pd.DataFrame) -> Gen
     yield db_url


+@pytest.fixture(scope="session", params=URLS)
+def empty_test_db(request: pytest.FixtureRequest) -> Generator[str, None, None]:
+    data = pd.DataFrame(
+        {
+            "id": pd.Series(dtype="int"),
+            "string_col": pd.Series(dtype="str"),
+        }
+    )
+    db_url = request.param
+    engine = create_engine(db_url)
+    metadata = MetaData()
+    table = Table(
+        EMPTY_TEST_TABLE_NAME,
+        metadata,
+        Column("id", Integer),
+        Column("string_col", String(50)),
+    )
+    metadata.create_all(engine)
+    data.to_sql(table.name, con=engine, if_exists="replace", index=False)
+    yield db_url
+
+
 @tenacity.retry(stop=tenacity.stop_after_delay(10), wait=tenacity.wait_fixed(5), reraise=True)
 def setup_database(db_url: str, data: pd.DataFrame) -> None:
     engine = create_engine(db_url)
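The new empty_test_db fixture creates the table schema with SQLAlchemy and then writes a zero-row DataFrame through to_sql, so the table exists with typed columns but contains no data. A quick sanity check one could run against the URL the fixture yields; the table_is_empty helper below is illustrative, not part of the test suite:

import pandas as pd
from sqlalchemy import create_engine

def table_is_empty(db_url: str, table_name: str) -> bool:
    # Count rows via pandas; the fixture's table should report zero.
    engine = create_engine(db_url)
    counts = pd.read_sql_query(f"SELECT COUNT(*) AS n FROM {table_name}", engine)
    return int(counts["n"].iloc[0]) == 0

# e.g. table_is_empty("mysql+pymysql://username:password@localhost:3306/mysql", "empty_table")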
24 changes: 23 additions & 1 deletion tests/integration/sql/test_sql.py
@@ -9,7 +9,7 @@

 import daft
 from tests.conftest import assert_df_equals
-from tests.integration.sql.conftest import TEST_TABLE_NAME
+from tests.integration.sql.conftest import EMPTY_TEST_TABLE_NAME, TEST_TABLE_NAME


 @pytest.fixture(scope="session")
@@ -65,6 +65,28 @@ def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
     assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")


+@pytest.mark.integration()
+@pytest.mark.parametrize("num_partitions", [0, 1, 2])
+@pytest.mark.parametrize("partition_col", ["id", "string_col"])
+def test_sql_partitioned_read_on_empty_table(empty_test_db, num_partitions, partition_col) -> None:
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        df = daft.read_sql(
+            f"SELECT * FROM {EMPTY_TEST_TABLE_NAME}",
+            empty_test_db,
+            partition_col=partition_col,
+            num_partitions=num_partitions,
+            schema={"id": daft.DataType.int64(), "string_col": daft.DataType.string()},
+        )
+        assert df.num_partitions() == 1
+        empty_pdf = pd.read_sql_query(
+            f"SELECT * FROM {EMPTY_TEST_TABLE_NAME}", empty_test_db, dtype={"id": "int64", "string_col": "str"}
+        )
+        assert_df_equals(df.to_pandas(), empty_pdf, sort_key="id")


 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [1, 2, 3, 4])
 def test_sql_partitioned_read_with_non_uniformly_distributed_column(test_db, num_partitions, pdf) -> None:
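From the user's side, the behavior this test locks in looks roughly like the following; the connection string, table name, and schema are placeholders, assuming an empty table named empty_table is reachable at that URL:

import daft

df = daft.read_sql(
    "SELECT * FROM empty_table",
    "sqlite:///example.db",  # placeholder connection string
    partition_col="id",
    num_partitions=4,
    schema={"id": daft.DataType.int64()},
)
# However many partitions were requested, an empty table now collapses
# to a single empty partition instead of erroring.
assert df.num_partitions() == 1
assert len(df.to_pandas()) == 0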
