diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 87bbfefe422..9aaf6502f23 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2284,6 +2284,7 @@ def add_bases( def cleanup_old_versions( self, older_than: Optional[timedelta] = None, + retain_versions: Optional[int] = None, *, delete_unverified: bool = False, error_if_tagged_old_versions: bool = True, @@ -2303,8 +2304,11 @@ def cleanup_old_versions( ---------- older_than: timedelta, optional - Only versions older than this will be removed. If not specified, this - will default to two weeks. + Only versions older than this will be removed. If ``older_than`` and + ``retain_versions`` are not specified, this will default to two weeks. + + retain_versions: int, optional + Retain the last N versions of the dataset. delete_unverified: bool, default False Files leftover from a failed transaction may appear to be part of an @@ -2324,10 +2328,14 @@ def cleanup_old_versions( be ignored without any error and only untagged versions will be cleaned up. 
""" - if older_than is None: + if older_than is None and retain_versions is None: older_than = timedelta(days=14) + return self._ds.cleanup_old_versions( - td_to_micros(older_than), delete_unverified, error_if_tagged_old_versions + td_to_micros(older_than) if older_than else None, + retain_versions, + delete_unverified, + error_if_tagged_old_versions, ) def create_scalar_index( diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 1c8578c1cd6..9ddb6db6881 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -1181,6 +1181,44 @@ def test_cleanup_around_tagged_old_versions(tmp_path): assert stats.old_versions == 1 +def test_cleanup_with_retain_versions(tmp_path: Path): + base_dir = tmp_path / "cleanup_policy" + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + lance.write_dataset(table, base_dir, mode="create") + time.sleep(0.05) + lance.write_dataset(table, base_dir, mode="overwrite") + time.sleep(0.05) + lance.write_dataset(table, base_dir, mode="overwrite") + time.sleep(0.05) + ds = lance.write_dataset(table, base_dir, mode="append") + + assert len(ds.versions()) == 4 + stats = ds.cleanup_old_versions(retain_versions=3) + assert stats.old_versions == 1 + assert len(ds.versions()) == 3 + assert ds.count_rows() == len(ds.to_table()) + + +def test_cleanup_with_older_than_and_retain_versions(tmp_path: Path): + base_dir = tmp_path / "cleanup_policy" + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + lance.write_dataset(table, base_dir, mode="create") + time.sleep(0.05) + lance.write_dataset(table, base_dir, mode="overwrite") + time.sleep(0.05) + lance.write_dataset(table, base_dir, mode="overwrite") + moment = datetime.now() + time.sleep(0.05) + ds = lance.write_dataset(table, base_dir, mode="append") + + stats = ds.cleanup_old_versions( + older_than=datetime.now() - moment, retain_versions=2 + ) + assert stats.old_versions == 2 + assert len(ds.versions()) 
== 2 + assert ds.count_rows() == len(ds.to_table()) + + def test_auto_cleanup(tmp_path): table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) base_dir = tmp_path / "test" diff --git a/python/src/dataset.rs b/python/src/dataset.rs index ade2b4516ca..2afea042521 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -15,7 +15,7 @@ use arrow_data::ArrayData; use arrow_schema::{DataType, Schema as ArrowSchema}; use async_trait::async_trait; use blob::LanceBlobFile; -use chrono::{Duration, TimeDelta}; +use chrono::{Duration, TimeDelta, Utc}; use futures::{StreamExt, TryFutureExt}; use lance_index::vector::bq::RQBuildParams; use log::error; @@ -33,6 +33,7 @@ use pyo3::{ use pyo3::{prelude::*, IntoPyObjectExt}; use snafu::location; +use lance::dataset::cleanup::CleanupPolicyBuilder; use lance::dataset::index::LanceIndexStoreExt; use lance::dataset::refs::{Ref, TagContents}; use lance::dataset::scanner::{ @@ -1496,23 +1497,33 @@ impl Dataset { } /// Cleanup old versions from the dataset - #[pyo3(signature = (older_than_micros, delete_unverified = None, error_if_tagged_old_versions = None))] + #[pyo3(signature = (older_than_micros = None, retain_versions = None, delete_unverified = None, error_if_tagged_old_versions = None))] fn cleanup_old_versions( &self, - older_than_micros: i64, + older_than_micros: Option<i64>, + retain_versions: Option<usize>, delete_unverified: Option<bool>, error_if_tagged_old_versions: Option<bool>, ) -> PyResult<CleanupStats> { - let older_than = Duration::microseconds(older_than_micros); let cleanup_stats = rt() - .block_on( - None, - self.ds.cleanup_old_versions( - older_than, - delete_unverified, - error_if_tagged_old_versions, - ), - )?
+ .block_on(None, async { + let mut builder = CleanupPolicyBuilder::default(); + if let Some(v) = older_than_micros { + let older_than = Duration::microseconds(v); + builder = builder.before_timestamp(Utc::now() - older_than); + } + if let Some(v) = retain_versions { + builder = builder.retain_n_versions(self.ds.as_ref(), v).await?; + } + if let Some(v) = delete_unverified { + builder = builder.delete_unverified(v); + } + if let Some(v) = error_if_tagged_old_versions { + builder = builder.error_if_tagged_old_versions(v); + } + + self.ds.cleanup_with_policy(builder.build()).await + })? .map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?; Ok(CleanupStats { bytes_removed: cleanup_stats.bytes_removed,