diff --git a/docs/guide/checkpoints.md b/docs/guide/checkpoints.md index c88e71e92..e137fe0c7 100644 --- a/docs/guide/checkpoints.md +++ b/docs/guide/checkpoints.md @@ -11,10 +11,10 @@ Checkpoints are available for both local script runs and Studio executions. When you run a Python script locally (e.g., `python my_script.py`), DataChain automatically: 1. **Creates a job** for the script execution, using the script's absolute path as the job name -2. **Tracks parent jobs** by finding the last job with the same script name +2. **Tracks previous runs** by finding the last job with the same script name 3. **Calculates hashes** for each dataset save operation based on the DataChain operations chain 4. **Creates checkpoints** after each successful `.save()` call, storing the hash -5. **Checks for existing checkpoints** on subsequent runs - if a matching checkpoint exists in the parent job, DataChain skips the save and reuses the existing dataset +5. **Checks for existing checkpoints** on subsequent runs - if a matching checkpoint exists from the previous run, DataChain skips the save and reuses the existing dataset This means that if your script creates multiple datasets and fails partway through, the next run will skip recreating the datasets that were already successfully saved. @@ -37,7 +37,7 @@ When triggering jobs through the Studio interface: 2. **Checkpoint control** is explicit - you choose between: - **Run from scratch**: Ignores any existing checkpoints and recreates all datasets - **Continue from last checkpoint**: Resumes from the last successful checkpoint, skipping already-completed stages -3. **Parent-child job linking** is handled automatically by the system +3. **Job linking between runs** is handled automatically by the system 4. **Checkpoint behavior** during execution is the same as local runs: datasets are saved at each `.save()` call and can be reused on retry @@ -75,7 +75,7 @@ result = ( **First run:** The script executes all three stages and creates three datasets: `filtered_data`, `transformed_data`, and `final_results`. If the script fails during Stage 3, only `filtered_data` and `transformed_data` are saved. -**Second run:** DataChain detects that `filtered_data` and `transformed_data` were already created in the parent job with matching hashes. It skips recreating them and proceeds directly to Stage 3, creating only `final_results`. +**Second run:** DataChain detects that `filtered_data` and `transformed_data` were already created in the previous run with matching hashes. It skips recreating them and proceeds directly to Stage 3, creating only `final_results`. 
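+
+Because each stage is persisted as a named dataset, the stages that completed remain available even if the script never reaches Stage 3. As a minimal sketch (the dataset name comes from the example above; the `count()` call is only for illustration), a separate script can inspect the intermediate results while you debug the failing stage:
+
+```python
+import datachain as dc
+
+# Read an intermediate dataset that was checkpointed by my_script.py.
+# It exists even though `final_results` was never created on the failed run.
+filtered = dc.read_dataset("filtered_data")
+print(filtered.count())
+```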
## When Checkpoints Are Used @@ -84,27 +84,20 @@ Checkpoints are automatically used when: - Running a Python script locally (e.g., `python my_script.py`) - The script has been run before - A dataset with the same name is being saved -- The chain hash matches a checkpoint from the parent job +- The chain hash matches a checkpoint from the previous run Checkpoints are **not** used when: - Running code interactively (Python REPL, Jupyter notebooks) - Running code as a module (e.g., `python -m mymodule`) -- The `DATACHAIN_CHECKPOINTS_RESET` environment variable is set (see below) +- The `DATACHAIN_IGNORE_CHECKPOINTS` environment variable is set (see below) ## Resetting Checkpoints -To ignore existing checkpoints and run your script from scratch, set the `DATACHAIN_CHECKPOINTS_RESET` environment variable: +To ignore existing checkpoints and run your script from scratch, set the `DATACHAIN_IGNORE_CHECKPOINTS` environment variable: ```bash -export DATACHAIN_CHECKPOINTS_RESET=1 -python my_script.py -``` - -Or set it inline: - -```bash -DATACHAIN_CHECKPOINTS_RESET=1 python my_script.py +DATACHAIN_IGNORE_CHECKPOINTS=1 python my_script.py ``` This forces DataChain to recreate all datasets, regardless of existing checkpoints. @@ -121,7 +114,7 @@ When running `python my_script.py`, DataChain uses the **absolute path** to the /home/user/projects/my_script.py ``` -This allows DataChain to link runs of the same script together as parent-child jobs, enabling checkpoint lookup. +This allows DataChain to link runs of the same script together, enabling checkpoint lookup across runs. ### Interactive or Module Execution (Checkpoints Disabled) @@ -140,7 +133,7 @@ For each `.save()` operation, DataChain calculates a hash based on: 1. The hash of the previous checkpoint in the current job (if any) 2. The hash of the current DataChain operations chain -This creates a chain of hashes that uniquely identifies each stage of data processing. On subsequent runs, DataChain matches these hashes against the parent job's checkpoints and skips recreating datasets where the hashes match. +This creates a chain of hashes that uniquely identifies each stage of data processing. On subsequent runs, DataChain matches these hashes against checkpoints from the previous run and skips recreating datasets where the hashes match. ### Hash Invalidation @@ -198,29 +191,153 @@ for ds in dc.datasets(): print(ds.name) ``` -## Limitations +## UDF-Level Checkpoints -- **Script-based:** Code must be run as a script (not interactively or as a module). -- **Hash-based matching:** Any change to the chain will create a different hash, preventing checkpoint reuse. -- **Same script path:** The script must be run from the same absolute path for parent job linking to work. +In addition to dataset-level checkpointing via `.save()`, DataChain automatically creates checkpoints for individual UDFs (`.map()`, `.gen()`, `.agg()`) during execution. -## Future Plans +**Two levels of checkpointing:** +- **Dataset checkpoints** (via `.save()`): When you explicitly save a dataset, it's persisted and can be used in other scripts. If you re-run the same chain with unchanged code, DataChain skips recreation and reuses the saved dataset. +- **UDF checkpoints** (automatic): Each UDF execution is automatically checkpointed. If a UDF completes successfully, it's skipped entirely on re-run (if code unchanged). If a UDF fails mid-execution, only the unprocessed rows are recomputed on re-run. 
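+
+For example, in the rough sketch below (`extract_size` is a hypothetical UDF and `"images"` a placeholder dataset name), the `.map()` call is covered by an automatic UDF checkpoint, while `.save()` creates the explicit dataset checkpoint:
+
+```python
+import datachain as dc
+from datachain import File
+
+def extract_size(file: File) -> int:
+    # Per-file work; processed rows are tracked automatically while this runs
+    return len(file.read())
+
+(
+    dc.read_dataset("images")
+    .map(size=extract_size)  # UDF checkpoint: skipped or resumed on re-run
+    .save("file_sizes")      # dataset checkpoint: reused if the chain is unchanged
+)
+```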
-### UDF-Level Checkpoints +**Key differences:** +- `.save()` creates a named dataset that persists even if your script fails later, and can be used in other scripts +- UDF checkpoints are automatic and internal - they optimize execution within a single script by skipping or resuming UDFs -Currently, checkpoints are created only when datasets are saved using `.save()`. This means that if a script fails during a long-running UDF operation (like `.map()`, `.gen()`, or `.agg()`), the entire UDF computation must be rerun on the next execution. +For `.map()` and `.gen()`, **DataChain saves processed rows continuously during UDF execution**. This means: +- If a UDF **completes successfully**, a checkpoint is created and the entire UDF is skipped on re-run (unless code changes) +- If a UDF **fails mid-execution**, the next run continues from where it left off, skipping already-processed rows - even if you've modified the UDF code to fix a bug -Future versions will support **UDF-level checkpoints**, creating checkpoints after each UDF step in the chain. This will provide much more granular recovery: +**Note:** For `.agg()`, checkpoints are created when the aggregation completes successfully, but partial results are not tracked. If an aggregation fails partway through, it will restart from scratch on the next run. + +### How It Works + +When executing `.map()` or `.gen()`, DataChain: + +1. **Saves processed rows incrementally** as the UDF processes your dataset +2. **Creates a checkpoint** when the UDF completes successfully +3. **Allows you to fix bugs and continue** - if the UDF fails, you can modify the code and re-run, skipping already-processed rows +4. **Invalidates the checkpoint if you change the UDF after successful completion** - completed UDFs are recomputed from scratch if the code changes + +For `.agg()`, checkpoints are only created upon successful completion, without incremental progress tracking. + +### Example: Fixing a Bug Mid-Execution ```python -# Future behavior with UDF-level checkpoints -result = ( - dc.read_csv("data.csv") - .map(heavy_computation_1) # Checkpoint created after this UDF - .map(heavy_computation_2) # Checkpoint created after this UDF - .map(heavy_computation_3) # Checkpoint created after this UDF - .save("result") + +def process_image(file: File) -> int: + # Bug: this will fail on some images + img = Image.open(file.get_local_path()) + return img.size[0] + +( + dc.read_dataset("images") + .map(width=process_image) + .save("image_dimensions") ) ``` -If the script fails during `heavy_computation_3`, the next run will skip re-executing `heavy_computation_1` and `heavy_computation_2`, resuming only the work that wasn't completed. +**First run:** Script processes 50% of images successfully, then fails on a corrupted image. + +**After fixing the bug:** + +```python +from datachain import File + +def process_image(file: File) -> int: + # Fixed: handle corrupted images gracefully + try: + img = Image.open(file.get_local_path()) + return img.size[0] + except Exception: + return 0 +``` + +**Second run:** DataChain automatically skips the 50% of images that were already processed successfully, and continues processing the remaining images using the fixed code. You don't lose any progress from the first run. + +### When UDF Checkpoints Are Invalidated + +DataChain distinguishes between two types of UDF changes: + +#### 1. 
Code-Only Changes (Bug Fixes) - Continues from Partial Results + +When you fix a bug in your UDF code **without changing the output type**, DataChain allows you to continue from where the UDF failed. This is the key benefit of UDF-level checkpoints - you don't lose progress when fixing bugs. + +**Example: Bug fix without output change** +```python +# First run - fails partway through +def process(num: int) -> int: + if num > 100: + raise Exception("Bug!") # Oops, a bug! + return num * 10 + +# Second run - continues from where it failed +def process(num: int) -> int: + return num * 10 # Bug fixed! ✓ Continues from partial results +``` + +In this case, DataChain will skip already-processed rows and continue processing the remaining rows with your fixed code. + +#### 2. Output Schema Changes - Forces Re-run from Scratch + +When you change the **output type** of your UDF, DataChain automatically detects this and reruns the entire UDF from scratch. This prevents schema mismatches that would cause errors or corrupt data. + +**Example: Output change** +```python +# First run - fails partway through +def process(num: int) -> int: + if num > 100: + raise Exception("Bug!") + return num * 10 + +# Second run - output type changed +def process(num: int) -> str: + return f"value_{num * 10}" # Output type changed! ✗ Reruns from scratch +``` + +In this case, DataChain detects that the output type changed from `int` to `str` and discards partial results to avoid schema incompatibility. All rows will be reprocessed with the new output. + +#### Changes That Invalidate In-Progress UDF Checkpoints + +Partial results are automatically discarded when you change: + +- **Output type** - Changes to the `output` parameter or return type annotations +- **Operations before the UDF** - Any changes to the data processing chain before the UDF + +#### Changes That Invalidate Completed UDF Checkpoints + +Once a UDF completes successfully, its checkpoint is tied to the UDF function code. If you modify the function and re-run the script, DataChain will detect the change and recompute the entire UDF from scratch. + +Changes that invalidate completed UDF checkpoints: + +- **Modifying the UDF function logic** - Any code changes inside the function +- **Changing function parameters or output types** - Changes to input/output specifications +- **Altering any operations before the UDF in the chain** - Changes to upstream data processing + +**Key takeaway:** For in-progress (partial) UDFs, you can fix bugs freely as long as the output stays the same. For completed UDFs, any code change triggers a full recomputation. + +## Limitations + +When running locally: + +- **Script-based:** Code must be run as a script (not interactively or as a module). +- **Same script path:** The script must be run from the same absolute path for linking to previous runs to work. +- **Threading/Multiprocessing:** Checkpoints are automatically disabled when Python threading or multiprocessing is detected to prevent race conditions. Any checkpoints created before threading starts remain valid for future runs. DataChain's built-in `parallel` setting for UDF execution is not affected by this limitation. + +These limitations don't apply when running on Studio, where job linking between runs is handled automatically by the platform. + +### UDF Hashing Limitations + +DataChain computes checkpoint hashes by inspecting UDF code and metadata. 
Certain types of callables cannot be reliably hashed: + +- **Built-in functions** (`len`, `str`, `int`, etc.): Cannot access bytecode, so a random hash is generated on each run. Checkpoints using these functions will not be reused. +- **C extensions**: Same limitation as built-ins - no accessible bytecode means a new hash each run. +- **Mock objects**: `Mock(side_effect=...)` cannot be reliably hashed because the side effect is not discoverable via inspection. Use regular functions instead. +- **Dynamically generated callables**: If a callable is created via `exec`/`eval` or its behavior depends on runtime state, the hash reflects only the method's code, not captured state. + +To ensure checkpoints work correctly, use regular Python functions defined with `def` or lambda expressions for your UDFs. + +## Future Plans + +### Partial Result Tracking for Aggregations + +Currently, `.agg()` creates checkpoints only upon successful completion, without tracking partial progress. Future versions will extend the same incremental progress tracking that `.map()` and `.gen()` have to aggregations, allowing them to resume from where they failed rather than restarting from scratch. diff --git a/docs/guide/env.md b/docs/guide/env.md index 616768f78..5a33b739a 100644 --- a/docs/guide/env.md +++ b/docs/guide/env.md @@ -19,4 +19,7 @@ List of environment variables used to configure DataChain behavior. - `DATACHAIN_NAMESPACE` – Namespace name to use as default. - `DATACHAIN_PROJECT` – Project name or combination of namespace name and project name separated by `.` to use as default, example: `DATACHAIN_PROJECT=dev.analytics` +### Checkpoints +- `DATACHAIN_IGNORE_CHECKPOINTS` – When set to `1` or `true`, ignores all existing checkpoints and runs the script from scratch, forcing DataChain to recreate all datasets. + Note: Some environment variables are used internally and may not be documented here. For the most up-to-date list, refer to the source code. diff --git a/src/datachain/checkpoint_event.py b/src/datachain/checkpoint_event.py new file mode 100644 index 000000000..baf3ab445 --- /dev/null +++ b/src/datachain/checkpoint_event.py @@ -0,0 +1,101 @@ +import uuid +from dataclasses import dataclass +from datetime import datetime +from enum import Enum + + +class CheckpointEventType(str, Enum): +    """Types of checkpoint events.""" + +    # UDF events +    UDF_SKIPPED = "UDF_SKIPPED" +    UDF_CONTINUED = "UDF_CONTINUED" +    UDF_FROM_SCRATCH = "UDF_FROM_SCRATCH" + +    # Dataset save events +    DATASET_SAVE_SKIPPED = "DATASET_SAVE_SKIPPED" +    DATASET_SAVE_COMPLETED = "DATASET_SAVE_COMPLETED" + + +class CheckpointStepType(str, Enum): +    """Types of checkpoint steps.""" + +    UDF_MAP = "UDF_MAP" +    UDF_GEN = "UDF_GEN" +    DATASET_SAVE = "DATASET_SAVE" + + +@dataclass +class CheckpointEvent: +    """ +    Represents a checkpoint event for debugging and visibility. + +    Checkpoint events are logged during job execution to track checkpoint +    decisions (skip, continue, run from scratch) and provide visibility +    into what happened during script execution. 
+ """ + + id: str + job_id: str + run_group_id: str | None + timestamp: datetime + event_type: CheckpointEventType + step_type: CheckpointStepType + udf_name: str | None = None + dataset_name: str | None = None + checkpoint_hash: str | None = None + hash_partial: str | None = None + hash_input: str | None = None + hash_output: str | None = None + rows_input: int | None = None + rows_processed: int | None = None + rows_output: int | None = None + rows_input_reused: int | None = None + rows_output_reused: int | None = None + rerun_from_job_id: str | None = None + details: dict | None = None + + @classmethod + def parse( # noqa: PLR0913 + cls, + id: str | uuid.UUID, + job_id: str, + run_group_id: str | None, + timestamp: datetime, + event_type: str, + step_type: str, + udf_name: str | None, + dataset_name: str | None, + checkpoint_hash: str | None, + hash_partial: str | None, + hash_input: str | None, + hash_output: str | None, + rows_input: int | None, + rows_processed: int | None, + rows_output: int | None, + rows_input_reused: int | None, + rows_output_reused: int | None, + rerun_from_job_id: str | None, + details: dict | None, + ) -> "CheckpointEvent": + return cls( + id=str(id), + job_id=job_id, + run_group_id=run_group_id, + timestamp=timestamp, + event_type=CheckpointEventType(event_type), + step_type=CheckpointStepType(step_type), + udf_name=udf_name, + dataset_name=dataset_name, + checkpoint_hash=checkpoint_hash, + hash_partial=hash_partial, + hash_input=hash_input, + hash_output=hash_output, + rows_input=rows_input, + rows_processed=rows_processed, + rows_output=rows_output, + rows_input_reused=rows_input_reused, + rows_output_reused=rows_output_reused, + rerun_from_job_id=rerun_from_job_id, + details=details, + ) diff --git a/src/datachain/data_storage/db_engine.py b/src/datachain/data_storage/db_engine.py index 50a673484..22fb3d1f0 100644 --- a/src/datachain/data_storage/db_engine.py +++ b/src/datachain/data_storage/db_engine.py @@ -14,6 +14,7 @@ from sqlalchemy.sql.roles import DDLRole from datachain.data_storage.serializer import Serializable +from datachain.error import TableMissingError if TYPE_CHECKING: from sqlalchemy import MetaData, Table @@ -86,12 +87,16 @@ def execute( ) -> Iterator[tuple[Any, ...]]: ... def get_table(self, name: str) -> "Table": + """Get a table by name, raising TableMissingError if not found.""" table = self.metadata.tables.get(name) if table is None: - sa.Table(name, self.metadata, autoload_with=self.engine) - # ^^^ This table may not be correctly initialised on some dialects - # Grab it from metadata instead. - table = self.metadata.tables[name] + try: + sa.Table(name, self.metadata, autoload_with=self.engine) + table = self.metadata.tables.get(name) + if table is None: + raise TableMissingError(f"Table '{name}' not found") + except sa.exc.NoSuchTableError as e: + raise TableMissingError(f"Table '{name}' not found") from e return table @abstractmethod @@ -117,6 +122,18 @@ def has_table(self, name: str) -> bool: """ return sa.inspect(self.engine).has_table(name) + @abstractmethod + def list_tables(self, prefix: str = "") -> list[str]: + """ + List all table names, optionally filtered by prefix. + + Args: + prefix: Optional prefix to filter table names + + Returns: + List of table names matching the prefix + """ + @abstractmethod def create_table( self, @@ -124,7 +141,10 @@ def create_table( if_not_exists: bool = True, *, kind: str | None = None, - ) -> None: ... + ) -> None: + """ + Create table. 
Does nothing if table already exists when if_not_exists=True. + """ @abstractmethod def drop_table(self, table: "Table", if_exists: bool = False) -> None: ... diff --git a/src/datachain/data_storage/metastore.py b/src/datachain/data_storage/metastore.py index 6b89b608a..e52311377 100644 --- a/src/datachain/data_storage/metastore.py +++ b/src/datachain/data_storage/metastore.py @@ -4,7 +4,7 @@ import sys from abc import ABC, abstractmethod from collections.abc import Iterator -from contextlib import contextmanager, suppress +from contextlib import contextmanager, nullcontext, suppress from datetime import datetime, timezone from functools import cached_property, reduce from itertools import groupby @@ -39,6 +39,11 @@ from datachain import json from datachain.catalog.dependency import DatasetDependencyNode from datachain.checkpoint import Checkpoint +from datachain.checkpoint_event import ( + CheckpointEvent, + CheckpointEventType, + CheckpointStepType, +) from datachain.data_storage import JobQueryType, JobStatus from datachain.data_storage.serializer import Serializable from datachain.dataset import ( @@ -98,6 +103,7 @@ class AbstractMetastore(ABC, Serializable): dependency_node_class: type[DatasetDependencyNode] = DatasetDependencyNode job_class: type[Job] = Job checkpoint_class: type[Checkpoint] = Checkpoint + checkpoint_event_class: type[CheckpointEvent] = CheckpointEvent def __init__( self, @@ -486,6 +492,10 @@ def create_job( def get_job(self, job_id: str) -> Job | None: """Returns the job with the given ID.""" + @abstractmethod + def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]: + """Returns list of ancestor job IDs in order from parent to root.""" + @abstractmethod def update_job( self, @@ -542,14 +552,77 @@ def find_checkpoint( """ @abstractmethod - def create_checkpoint( + def get_or_create_checkpoint( self, job_id: str, _hash: str, partial: bool = False, conn: Any | None = None, ) -> Checkpoint: - """Creates new checkpoint""" + """Get or create checkpoint. 
Must be atomic and idempotent.""" + + @abstractmethod + def remove_checkpoint(self, checkpoint_id: str, conn: Any | None = None) -> None: + """Removes a checkpoint by ID""" + + # + # Checkpoint Events + # + + @abstractmethod + def log_checkpoint_event( # noqa: PLR0913 + self, + job_id: str, + event_type: "CheckpointEventType", + step_type: "CheckpointStepType", + run_group_id: str | None = None, + udf_name: str | None = None, + dataset_name: str | None = None, + checkpoint_hash: str | None = None, + hash_partial: str | None = None, + hash_input: str | None = None, + hash_output: str | None = None, + rows_input: int | None = None, + rows_processed: int | None = None, + rows_output: int | None = None, + rows_input_reused: int | None = None, + rows_output_reused: int | None = None, + rerun_from_job_id: str | None = None, + details: dict | None = None, + conn: Any | None = None, + ) -> "CheckpointEvent": + """Log a checkpoint event.""" + + @abstractmethod + def get_checkpoint_events( + self, + job_id: str | None = None, + run_group_id: str | None = None, + conn: Any | None = None, + ) -> Iterator["CheckpointEvent"]: + """Get checkpoint events, optionally filtered by job_id or run_group_id.""" + + # + # UDF Registry (SaaS only, no-op for local metastores) + # + + def add_udf( + self, + udf_id: str, + name: str, + status: str, + rows_total: int, + job_id: str, + tasks_created: int, + skipped: bool = False, + continued: bool = False, + rows_reused: int = 0, + output_rows_reused: int = 0, + ) -> None: + """ + Register a UDF in the registry. + No-op for local metastores, implemented in SaaS APIMetastore. + """ # # Dataset Version Jobs (many-to-many) @@ -563,17 +636,7 @@ def link_dataset_version_to_job( is_creator: bool = False, conn=None, ) -> None: - """ - Link dataset version to job. - - This atomically: - 1. Creates a link in the dataset_version_jobs junction table - 2. Updates dataset_version.job_id to point to this job - """ - - @abstractmethod - def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]: - """Get all ancestor job IDs for a given job.""" + """Link dataset version to job. Must be atomic.""" @abstractmethod def get_dataset_version_for_job_ancestry( @@ -606,6 +669,7 @@ class AbstractDBMetastore(AbstractMetastore): DATASET_VERSION_JOBS_TABLE = "dataset_version_jobs" JOBS_TABLE = "jobs" CHECKPOINTS_TABLE = "checkpoints" + CHECKPOINT_EVENTS_TABLE = "checkpoint_events" db: "DatabaseEngine" @@ -2116,6 +2180,74 @@ def _checkpoints(self) -> "Table": @abstractmethod def _checkpoints_insert(self) -> "Insert": ... 
+ # + # Checkpoint Events + # + + @staticmethod + def _checkpoint_events_columns() -> "list[SchemaItem]": + return [ + Column( + "id", + Text, + default=uuid4, + primary_key=True, + nullable=False, + ), + Column("job_id", Text, nullable=False), + Column("run_group_id", Text, nullable=True), + Column("timestamp", DateTime(timezone=True), nullable=False), + Column("event_type", Text, nullable=False), + Column("step_type", Text, nullable=False), + Column("udf_name", Text, nullable=True), + Column("dataset_name", Text, nullable=True), + Column("checkpoint_hash", Text, nullable=True), + Column("hash_partial", Text, nullable=True), + Column("hash_input", Text, nullable=True), + Column("hash_output", Text, nullable=True), + Column("rows_input", BigInteger, nullable=True), + Column("rows_processed", BigInteger, nullable=True), + Column("rows_output", BigInteger, nullable=True), + Column("rows_input_reused", BigInteger, nullable=True), + Column("rows_output_reused", BigInteger, nullable=True), + Column("rerun_from_job_id", Text, nullable=True), + Column("details", JSON, nullable=True), + Index("dc_idx_ce_job_id", "job_id"), + Index("dc_idx_ce_run_group_id", "run_group_id"), + ] + + @cached_property + def _checkpoint_events_fields(self) -> list[str]: + return [ + c.name # type: ignore[attr-defined] + for c in self._checkpoint_events_columns() + if isinstance(c, Column) + ] + + @cached_property + def _checkpoint_events(self) -> "Table": + return Table( + self.CHECKPOINT_EVENTS_TABLE, + self.db.metadata, + *self._checkpoint_events_columns(), + ) + + @abstractmethod + def _checkpoint_events_insert(self) -> "Insert": ... + + def _checkpoint_events_select(self, *columns) -> "Select": + if not columns: + return self._checkpoint_events.select() + return select(*columns) + + def _checkpoint_events_query(self): + return self._checkpoint_events_select( + *[ + getattr(self._checkpoint_events.c, f) + for f in self._checkpoint_events_fields + ] + ) + @classmethod def _dataset_version_jobs_columns(cls) -> "list[SchemaItem]": """Junction table for dataset versions and jobs many-to-many relationship.""" @@ -2170,28 +2302,39 @@ def _checkpoints_query(self): *[getattr(self._checkpoints.c, f) for f in self._checkpoints_fields] ) - def create_checkpoint( + def get_or_create_checkpoint( self, job_id: str, _hash: str, partial: bool = False, conn: Any | None = None, ) -> Checkpoint: - """ - Creates a new job query step. 
- """ - checkpoint_id = str(uuid4()) - self.db.execute( - self._checkpoints_insert().values( - id=checkpoint_id, + tx = self.db.transaction() if conn is None else nullcontext(conn) + with tx as active_conn: + query = self._checkpoints_insert().values( + id=str(uuid4()), job_id=job_id, hash=_hash, partial=partial, created_at=datetime.now(timezone.utc), - ), - conn=conn, - ) - return self.get_checkpoint_by_id(checkpoint_id) + ) + + # Use on_conflict_do_nothing to handle race conditions + if not hasattr(query, "on_conflict_do_nothing"): + raise RuntimeError("Database must support on_conflict_do_nothing") + query = query.on_conflict_do_nothing(index_elements=["job_id", "hash"]) + + self.db.execute(query, conn=active_conn) + + checkpoint = self.find_checkpoint( + job_id, _hash, partial=partial, conn=active_conn + ) + if checkpoint is None: + raise RuntimeError( + f"Checkpoint should exist after get_or_create for job_id={job_id}, " + f"hash={_hash}, partial={partial}" + ) + return checkpoint def list_checkpoints(self, job_id: str, conn=None) -> Iterator[Checkpoint]: """List checkpoints by job id.""" @@ -2236,6 +2379,95 @@ def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None: return None return self.checkpoint_class.parse(*rows[0]) + def log_checkpoint_event( # noqa: PLR0913 + self, + job_id: str, + event_type: CheckpointEventType, + step_type: CheckpointStepType, + run_group_id: str | None = None, + udf_name: str | None = None, + dataset_name: str | None = None, + checkpoint_hash: str | None = None, + hash_partial: str | None = None, + hash_input: str | None = None, + hash_output: str | None = None, + rows_input: int | None = None, + rows_processed: int | None = None, + rows_output: int | None = None, + rows_input_reused: int | None = None, + rows_output_reused: int | None = None, + rerun_from_job_id: str | None = None, + details: dict | None = None, + conn: Any | None = None, + ) -> CheckpointEvent: + """Log a checkpoint event.""" + event_id = str(uuid4()) + timestamp = datetime.now(timezone.utc) + + query = self._checkpoint_events_insert().values( + id=event_id, + job_id=job_id, + run_group_id=run_group_id, + timestamp=timestamp, + event_type=event_type.value, + step_type=step_type.value, + udf_name=udf_name, + dataset_name=dataset_name, + checkpoint_hash=checkpoint_hash, + hash_partial=hash_partial, + hash_input=hash_input, + hash_output=hash_output, + rows_input=rows_input, + rows_processed=rows_processed, + rows_output=rows_output, + rows_input_reused=rows_input_reused, + rows_output_reused=rows_output_reused, + rerun_from_job_id=rerun_from_job_id, + details=details, + ) + self.db.execute(query, conn=conn) + + return CheckpointEvent( + id=event_id, + job_id=job_id, + run_group_id=run_group_id, + timestamp=timestamp, + event_type=event_type, + step_type=step_type, + udf_name=udf_name, + dataset_name=dataset_name, + checkpoint_hash=checkpoint_hash, + hash_partial=hash_partial, + hash_input=hash_input, + hash_output=hash_output, + rows_input=rows_input, + rows_processed=rows_processed, + rows_output=rows_output, + rows_input_reused=rows_input_reused, + rows_output_reused=rows_output_reused, + rerun_from_job_id=rerun_from_job_id, + details=details, + ) + + def get_checkpoint_events( + self, + job_id: str | None = None, + run_group_id: str | None = None, + conn: Any | None = None, + ) -> Iterator[CheckpointEvent]: + """Get checkpoint events, optionally filtered by job_id or run_group_id.""" + query = self._checkpoint_events_query() + + if job_id is not None: + query = 
query.where(self._checkpoint_events.c.job_id == job_id) + if run_group_id is not None: + query = query.where(self._checkpoint_events.c.run_group_id == run_group_id) + + query = query.order_by(self._checkpoint_events.c.timestamp) + rows = list(self.db.execute(query, conn=conn)) + + yield from [self.checkpoint_event_class.parse(*r) for r in rows] + def link_dataset_version_to_job( self, dataset_version_id: int, @@ -2271,11 +2503,7 @@ def link_dataset_version_to_job( self.db.execute(update_query, conn=conn) def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]: - # Use recursive CTE to walk up the rerun chain - # Format: WITH RECURSIVE ancestors(id, rerun_from_job_id, run_group_id, - # depth) AS (...) - # Include depth tracking to prevent infinite recursion in case of - # circular dependencies + """Get all ancestor job IDs using recursive CTE.""" ancestors_cte = ( self._jobs_select( self._jobs.c.id.label("id"), @@ -2287,8 +2515,6 @@ def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]: .cte(name="ancestors", recursive=True) ) - # Recursive part: join with parent jobs, incrementing depth and checking limit - # Also ensure we only traverse jobs within the same run_group_id for safety ancestors_recursive = ancestors_cte.union_all( self._jobs_select( self._jobs.c.id.label("id"), @@ -2409,3 +2635,9 @@ def get_dataset_version_for_job_ancestry( ) return self.dataset_version_class.parse(*results[0]) + + def remove_checkpoint(self, checkpoint_id: str, conn: Any | None = None) -> None: + self.db.execute( + self._checkpoints_delete().where(self._checkpoints.c.id == checkpoint_id), + conn=conn, + ) diff --git a/src/datachain/data_storage/sqlite.py b/src/datachain/data_storage/sqlite.py index 81e20a69e..e8da58e02 100644 --- a/src/datachain/data_storage/sqlite.py +++ b/src/datachain/data_storage/sqlite.py @@ -31,7 +31,11 @@ from datachain.data_storage.db_engine import DatabaseEngine from datachain.data_storage.schema import DefaultSchema from datachain.dataset import DatasetRecord, StorageURI -from datachain.error import DataChainError, OutdatedDatabaseSchemaError +from datachain.error import ( + DataChainError, + OutdatedDatabaseSchemaError, + TableMissingError, +) from datachain.namespace import Namespace from datachain.project import Project from datachain.sql.sqlite import create_user_defined_sql_functions, sqlite_dialect @@ -217,7 +221,6 @@ def _reconnect(self) -> None: def get_table(self, name: str) -> Table: if self.is_closed: - # Reconnect in case of being closed previously. self._reconnect() return super().get_table(name) @@ -255,6 +258,21 @@ def execute_str(self, sql: str, parameters=None) -> sqlite3.Cursor: return self.db.execute(sql) return self.db.execute(sql, parameters) + def list_tables(self, prefix: str = "") -> list[str]: + """List all table names, optionally filtered by prefix.""" + sqlite_master = sqlalchemy.table( + "sqlite_master", + sqlalchemy.column("type"), + sqlalchemy.column("name"), + ) + query = sqlalchemy.select(sqlite_master.c.name).where( + sqlite_master.c.type == "table" + ) + if prefix: + query = query.where(sqlite_master.c.name.like(f"{prefix}%")) + result = self.execute(query) + return [row[0] for row in result.fetchall()] + def add_column(self, table_name: str, column: Column) -> None: """ Add a column to an existing table. @@ -364,15 +382,31 @@ def create_table( *, kind: str | None = None, ) -> None: + """ + Create table. Does nothing if table already exists when if_not_exists=True. 
+ """ self.execute(CreateTable(table, if_not_exists=if_not_exists)) def drop_table(self, table: "Table", if_exists: bool = False) -> None: self.execute(DropTable(table, if_exists=if_exists)) + # Remove the table from metadata to avoid stale references + if table.name in self.metadata.tables: + self.metadata.remove(table) def rename_table(self, old_name: str, new_name: str): + from datachain.error import TableRenameError + comp_old_name = quote_schema(old_name) comp_new_name = quote_schema(new_name) - self.execute_str(f"ALTER TABLE {comp_old_name} RENAME TO {comp_new_name}") + try: + self.execute_str(f"ALTER TABLE {comp_old_name} RENAME TO {comp_new_name}") + except Exception as e: + raise TableRenameError( + f"Failed to rename table from '{old_name}' to '{new_name}': {e}" + ) from e + # Remove old table from metadata to avoid stale references + if old_name in self.metadata.tables: + self.metadata.remove(self.metadata.tables[old_name]) class SQLiteMetastore(AbstractDBMetastore): @@ -497,6 +531,7 @@ def _metastore_tables(self) -> list[Table]: self._datasets_dependencies, self._jobs, self._checkpoints, + self._checkpoint_events, self._dataset_version_jobs, ] @@ -644,6 +679,12 @@ def _jobs_insert(self) -> "Insert": def _checkpoints_insert(self) -> "Insert": return sqlite.insert(self._checkpoints) + # + # Checkpoint Events + # + def _checkpoint_events_insert(self) -> "Insert": + return sqlite.insert(self._checkpoint_events) + def _dataset_version_jobs_insert(self) -> "Insert": return sqlite.insert(self._dataset_version_jobs) @@ -743,6 +784,10 @@ def create_dataset_rows_table( columns: Sequence["sqlalchemy.Column"] = (), if_not_exists: bool = True, ) -> Table: + # Return existing table if it exists (and caller allows it) + if if_not_exists and self.db.has_table(name): + return self.db.get_table(name) + table = self.schema.dataset_row_cls.new_table( name, columns=columns, @@ -812,7 +857,10 @@ def instr(self, source, target) -> "ColumnElement": def get_table(self, name: str) -> sqlalchemy.Table: # load table with latest schema to metadata self._reflect_tables(filter_tables=lambda t, _: t == name) - return self.db.metadata.tables[name] + try: + return self.db.metadata.tables[name] + except KeyError: + raise TableMissingError(f"Table '{name}' not found") from None def python_type(self, col_type: Union["TypeEngine", "SQLType"]) -> Any: if isinstance(col_type, SQLType): @@ -838,7 +886,7 @@ def export_dataset_table( ) -> None: raise NotImplementedError("Exporting dataset table not implemented for SQLite") - def copy_table( + def insert_into( self, table: Table, query: Select, @@ -935,14 +983,20 @@ def _system_row_number_expr(self): def _system_random_expr(self): return self._system_row_number_expr() * 1103515245 + 12345 - def create_pre_udf_table(self, query: "Select") -> "Table": + def create_pre_udf_table(self, query: "Select", name: str) -> "Table": """ Create a temporary table from a query for use in a UDF. + Populates the table from the query, using a staging pattern for atomicity. + + This ensures that if the process crashes during population, the next run + won't find a partially-populated table and incorrectly reuse it. 
""" columns = [sqlalchemy.Column(c.name, c.type) for c in query.selected_columns] - table = self.create_udf_table(columns) with tqdm(desc="Preparing", unit=" rows", leave=False) as pbar: - self.copy_table(table, query, progress_cb=pbar.update) - - return table + return self.create_table_from_query( + name, + query, + create_fn=lambda n: self.create_udf_table(columns, name=n), + progress_cb=pbar.update, + ) diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index 720600150..6a6bf12d0 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -24,6 +24,7 @@ from datachain.data_storage.schema import convert_rows_custom_column_types from datachain.data_storage.serializer import Serializable from datachain.dataset import DatasetRecord, StorageURI +from datachain.error import TableRenameError from datachain.lib.file import File from datachain.lib.model_store import ModelStore from datachain.lib.signal_schema import SignalSchema @@ -555,6 +556,16 @@ def get_table(self, name: str) -> sa.Table: create it """ + def rename_table(self, old_table: sa.Table, new_name: str) -> sa.Table: + """Rename table and return new Table object with same schema.""" + self.db.rename_table(old_table.name, new_name) + + return sa.Table( + new_name, + self.db.metadata, + *[sa.Column(c.name, c.type) for c in old_table.columns], + ) + @abstractmethod def export_dataset_table( self, @@ -995,6 +1006,9 @@ def create_udf_table( Create a temporary table for storing custom signals generated by a UDF. SQLite TEMPORARY tables cannot be directly used as they are process-specific, and UDFs are run in other processes when run in parallel. + + Returns: + table: The created SQLAlchemy Table object """ columns = [ c @@ -1011,16 +1025,55 @@ def create_udf_table( return tbl @abstractmethod - def copy_table( + def insert_into( self, table: sa.Table, query: sa.Select, progress_cb: Callable[[int], None] | None = None, ) -> None: """ - Copy the results of a query into a table. + Insert the results of a query into an existing table. """ + def create_table_from_query( + self, + name: str, + query: sa.Select, + create_fn: Callable[[str], sa.Table], + progress_cb: Callable[[int], None] | None = None, + ) -> sa.Table: + """ + Atomically create and populate a table from a query. + + Uses a staging pattern to ensure the final table only exists when fully + populated. This prevents race conditions and ensures data consistency + if the operation fails mid-way. + + Leftover staging tables (tmp_*) are cleaned by system maintenance. 
+ + Args: + name: Final table name + query: Query to populate the table from + create_fn: Function that creates an empty table given a name + progress_cb: Optional callback for progress updates + + Returns: + The created and populated table + """ + staging_name = self.temp_table_name() + staging_table = create_fn(staging_name) + + self.insert_into(staging_table, query, progress_cb=progress_cb) + + try: + return self.rename_table(staging_table, name) + except TableRenameError: + # Another process won the race - clean up our staging table + self.db.drop_table(staging_table, if_exists=True) + if self.db.has_table(name): + return self.get_table(name) + raise + @abstractmethod def join( self, @@ -1079,7 +1132,7 @@ def subtract_query( ) @abstractmethod - def create_pre_udf_table(self, query: sa.Select) -> sa.Table: + def create_pre_udf_table(self, query: sa.Select, name: str) -> sa.Table: """ Create a temporary table from a query for use in a UDF. """ @@ -1087,7 +1140,7 @@ def create_pre_udf_table(self, query: sa.Select) -> sa.Table: def is_temp_table_name(self, name: str) -> bool: """Returns if the given table name refers to a temporary or no longer needed table.""" - return name.startswith((self.TMP_TABLE_NAME_PREFIX, self.UDF_TABLE_NAME_PREFIX)) + return name.startswith(self.TMP_TABLE_NAME_PREFIX) def get_temp_table_names(self) -> list[str]: return [ diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py index 6d6f483ee..e9c92601b 100644 --- a/src/datachain/dataset.py +++ b/src/datachain/dataset.py @@ -74,6 +74,18 @@ def create_dataset_uri( return uri +def create_dataset_full_name( + namespace: str, project: str, name: str, version: str +) -> str: + """ + Creates a full dataset name including namespace, project and version. + Example: + Input: dev, clothes, zalando, 3.0.1 + Output: dev.clothes.zalando@3.0.1 + """ + return f"{namespace}.{project}.{name}@{version}" + + def parse_dataset_name(name: str) -> tuple[str | None, str | None, str]: """Parses dataset name and returns namespace, project and name""" if not name: diff --git a/src/datachain/error.py b/src/datachain/error.py index 5915997d5..1ac4c059a 100644 --- a/src/datachain/error.py +++ b/src/datachain/error.py @@ -99,6 +99,10 @@ class TableMissingError(DataChainError): pass +class TableRenameError(DataChainError): + pass + + class OutdatedDatabaseSchemaError(DataChainError): pass diff --git a/src/datachain/hash_utils.py b/src/datachain/hash_utils.py index 3a0748397..5ba7c8c64 100644 --- a/src/datachain/hash_utils.py +++ b/src/datachain/hash_utils.py @@ -1,13 +1,17 @@ import hashlib import inspect +import logging import textwrap from collections.abc import Sequence from typing import TypeAlias, TypeVar +from uuid import uuid4 from sqlalchemy.sql.elements import ClauseElement, ColumnElement from datachain import json +logger = logging.getLogger("datachain") + T = TypeVar("T", bound=ColumnElement) ColumnLike: TypeAlias = str | T @@ -81,16 +85,71 @@ def hash_column_elements(columns: ColumnLike | Sequence[ColumnLike]) -> str: def hash_callable(func): """ - Calculate a hash from a callable. - Rules: - - Named functions (def) → use source code for stable, cross-version hashing - - Lambdas → use bytecode (deterministic in same Python runtime) + Calculate a deterministic hash from a callable. 
+ + Hashing Strategy: + - **Named functions** (def): Uses source code via inspect.getsourcelines() + → Produces stable hashes across Python versions and sessions + - **Lambdas**: Uses bytecode (func.__code__.co_code) + → Stable within same Python runtime, may differ across Python versions + - **Callable objects** (with __call__): Extracts and hashes the __call__ method + + Supported Callables: + - Regular Python functions defined with 'def' + - Lambda functions + - Classes/instances with __call__ method (uses __call__ method's code) + - Methods (both bound and unbound) + + Limitations and Edge Cases: + - **Mock objects**: Cannot reliably hash Mock(side_effect=...) because the + side_effect is not discoverable via inspection. Use regular functions instead. + - **Built-in functions** (len, str, etc.): Cannot access __code__ attribute. + Returns a random hash that changes on each call. + - **C extensions**: Cannot access source or bytecode. Returns a random hash + that changes on each call. + - **Dynamically generated callables**: If __call__ is created via exec/eval + or the behavior depends on runtime state, the hash won't reflect changes + in behavior. Only the method's code is hashed, not captured state. + + Args: + func: A callable object (function, lambda, method, or object with __call__) + + Returns: + str: SHA256 hexdigest of the callable's code and metadata. For unhashable + callables (C extensions, built-ins), returns a hash of a random UUID that + changes on each invocation. + + Raises: + TypeError: If func is not callable + + Examples: + >>> def my_func(x): return x * 2 + >>> hash_callable(my_func) # Uses source code + 'abc123...' + + >>> hash_callable(lambda x: x * 2) # Uses bytecode + 'def456...' + + >>> class MyCallable: + ... def __call__(self, x): return x * 2 + >>> hash_callable(MyCallable()) # Hashes __call__ method + 'ghi789...' """ if not callable(func): raise TypeError("Expected a callable") + # Handle callable objects (instances with __call__) + # If it's not a function or method, it must be a callable object + if not inspect.isfunction(func) and not inspect.ismethod(func): + # For callable objects, hash the __call__ method instead + func = func.__call__ + # Determine if it is a lambda - is_lambda = func.__name__ == "" + try: + is_lambda = func.__name__ == "" + except AttributeError: + # Some callables (like Mock objects) may not have __name__ + is_lambda = False if not is_lambda: # Try to get exact source of named function @@ -99,20 +158,37 @@ def hash_callable(func): payload = textwrap.dedent("".join(lines)).strip() except (OSError, TypeError): # Fallback: bytecode if source not available - payload = func.__code__.co_code + try: + payload = func.__code__.co_code + except AttributeError: + # C extensions, built-ins - use random UUID + # Returns different hash on each call to avoid caching unhashable + # functions + logger.warning( + "Cannot hash callable %r (likely C extension or built-in). " + "Returning random hash.", + func, + ) + payload = f"unhashable-{uuid4()}" else: # For lambdas, fall back directly to bytecode - payload = func.__code__.co_code + try: + payload = func.__code__.co_code + except AttributeError: + # Unlikely for lambdas, but handle it just in case + logger.warning("Cannot hash lambda %r. 
Returning random hash.", func) + payload = f"unhashable-{uuid4()}" - # Normalize annotations + # Normalize annotations (may not exist for built-ins/C extensions) + raw_annotations = getattr(func, "__annotations__", {}) annotations = { - k: getattr(v, "__name__", str(v)) for k, v in func.__annotations__.items() + k: getattr(v, "__name__", str(v)) for k, v in raw_annotations.items() } # Extras to distinguish functions with same code but different metadata extras = { - "name": func.__name__, - "defaults": func.__defaults__, + "name": getattr(func, "__name__", ""), + "defaults": getattr(func, "__defaults__", None), "annotations": annotations, } diff --git a/src/datachain/lib/dc/datachain.py b/src/datachain/lib/dc/datachain.py index 4b8e8ff2e..9613fc931 100644 --- a/src/datachain/lib/dc/datachain.py +++ b/src/datachain/lib/dc/datachain.py @@ -1,5 +1,4 @@ import copy -import hashlib import logging import os import os.path @@ -23,7 +22,8 @@ from tqdm import tqdm from datachain import json, semver -from datachain.dataset import DatasetRecord +from datachain.checkpoint_event import CheckpointEventType, CheckpointStepType +from datachain.dataset import DatasetRecord, create_dataset_full_name from datachain.delta import delta_disabled from datachain.error import ( JobAncestryDepthExceededError, @@ -63,7 +63,13 @@ ) from datachain.query.schema import DEFAULT_DELIMITER, Column from datachain.sql.functions import path as pathfunc -from datachain.utils import batched_it, env2bool, inside_notebook, row_to_nested_dict +from datachain.utils import ( + batched_it, + checkpoints_enabled, + env2bool, + inside_notebook, + row_to_nested_dict, +) from .database import DEFAULT_DATABASE_BATCH_SIZE from .utils import ( @@ -289,6 +295,13 @@ def session(self) -> Session: """Session of the chain.""" return self._query.session + @property + def job(self) -> Job: + """ + Get existing job if running in SaaS, or creating new one if running locally + """ + return self.session.get_or_create_job() + @property def name(self) -> str | None: """Name of the underlying dataset, if there is one.""" @@ -586,19 +599,6 @@ def persist(self) -> "Self": signal_schema=self.signals_schema | SignalSchema({"sys": Sys}), ) - def _calculate_job_hash(self, job_id: str) -> str: - """ - Calculates hash of the job at the place of this chain's save method. - Hash is calculated using previous job checkpoint hash (if exists) and - adding hash of this chain to produce new hash. 
- """ - last_checkpoint = self.session.catalog.metastore.get_last_checkpoint(job_id) - - return hashlib.sha256( - (bytes.fromhex(last_checkpoint.hash) if last_checkpoint else b"") - + bytes.fromhex(self.hash()) - ).hexdigest() - def save( # type: ignore[override] self, name: str, @@ -632,9 +632,6 @@ def save( # type: ignore[override] self._validate_version(version) self._validate_update_version(update_version) - # get existing job if running in SaaS, or creating new one if running locally - job = self.session.get_or_create_job() - namespace_name, project_name, name = catalog.get_full_dataset_name( name, namespace_name=self._settings.namespace, @@ -642,8 +639,16 @@ def save( # type: ignore[override] ) project = self._get_or_create_project(namespace_name, project_name) + # Calculate hash including dataset name and job context to avoid conflicts + import hashlib + + base_hash = self._query.hash(job_aware=True) + _hash = hashlib.sha256( + (base_hash + f"{namespace_name}/{project_name}/{name}").encode("utf-8") + ).hexdigest() + # Checkpoint handling - _hash, result = self._resolve_checkpoint(name, project, job, kwargs) + result = self._resolve_checkpoint(name, project, _hash, kwargs) if bool(result): # Checkpoint was found and reused print(f"Checkpoint found for dataset '{name}', skipping creation") @@ -670,7 +675,22 @@ def save( # type: ignore[override] ) ) - catalog.metastore.create_checkpoint(job.id, _hash) # type: ignore[arg-type] + # Log checkpoint event for new dataset save + assert result.version is not None + full_dataset_name = create_dataset_full_name( + namespace_name, project_name, name, result.version + ) + catalog.metastore.log_checkpoint_event( + job_id=self.job.id, + event_type=CheckpointEventType.DATASET_SAVE_COMPLETED, + step_type=CheckpointStepType.DATASET_SAVE, + run_group_id=self.job.run_group_id, + dataset_name=full_dataset_name, + checkpoint_hash=_hash, + ) + + if checkpoints_enabled(): + catalog.metastore.get_or_create_checkpoint(self.job.id, _hash) return result def _validate_version(self, version: str | None) -> None: @@ -699,21 +719,20 @@ def _resolve_checkpoint( self, name: str, project: Project, - job: Job, + job_hash: str, kwargs: dict, - ) -> tuple[str, "DataChain | None"]: + ) -> "DataChain | None": """Check if checkpoint exists and return cached dataset if possible.""" from .datasets import read_dataset metastore = self.session.catalog.metastore - checkpoints_reset = env2bool("DATACHAIN_CHECKPOINTS_RESET", undefined=True) - - _hash = self._calculate_job_hash(job.id) + ignore_checkpoints = env2bool("DATACHAIN_IGNORE_CHECKPOINTS", undefined=False) if ( - job.rerun_from_job_id - and not checkpoints_reset - and metastore.find_checkpoint(job.rerun_from_job_id, _hash) + checkpoints_enabled() + and self.job.rerun_from_job_id + and not ignore_checkpoints + and metastore.find_checkpoint(self.job.rerun_from_job_id, job_hash) ): # checkpoint found → find which dataset version to reuse @@ -723,7 +742,7 @@ def _resolve_checkpoint( name, project.namespace.name, project.name, - job.id, + self.job.id, ) except JobAncestryDepthExceededError: raise JobAncestryDepthExceededError( @@ -737,18 +756,18 @@ def _resolve_checkpoint( "Checkpoint found but no dataset version for '%s' " "in job ancestry (job_id=%s). 
Creating new version.", name, - job.id, + self.job.id, ) # Dataset version not found (e.g deleted by user) - skip # checkpoint and recreate - return _hash, None + return None logger.debug( "Reusing dataset version '%s' v%s from job ancestry " "(job_id=%s, dataset_version_id=%s)", name, dataset_version.version, - job.id, + self.job.id, dataset_version.id, ) @@ -765,13 +784,27 @@ def _resolve_checkpoint( # This also updates dataset_version.job_id. metastore.link_dataset_version_to_job( dataset_version.id, - job.id, + self.job.id, is_creator=False, ) - return _hash, chain + # Log checkpoint event + full_dataset_name = create_dataset_full_name( + project.namespace.name, project.name, name, dataset_version.version + ) + metastore.log_checkpoint_event( + job_id=self.job.id, + event_type=CheckpointEventType.DATASET_SAVE_SKIPPED, + step_type=CheckpointStepType.DATASET_SAVE, + run_group_id=self.job.run_group_id, + dataset_name=full_dataset_name, + checkpoint_hash=job_hash, + rerun_from_job_id=self.job.rerun_from_job_id, + ) - return _hash, None + return chain + + return None def _handle_delta( self, @@ -1838,16 +1871,18 @@ def subtract( # type: ignore[override] if on is None and right_on is None: other_columns = set(other._effective_signals_schema.db_signals()) - signals = [ + common_signals = [ c for c in self._effective_signals_schema.db_signals() if c in other_columns ] - if not signals: + if not common_signals: raise DataChainParamsError("subtract(): no common columns") + signals = list(zip(common_signals, common_signals, strict=False)) elif on is not None and right_on is None: right_on = on - signals = list(self.signals_schema.resolve(*on).db_signals()) + resolved_signals = list(self.signals_schema.resolve(*on).db_signals()) + signals = list(zip(resolved_signals, resolved_signals, strict=False)) # type: ignore[arg-type] elif on is None and right_on is not None: raise DataChainParamsError( "'on' must be specified when 'right_on' is provided" diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py index c22f0fd61..03b02bafa 100644 --- a/src/datachain/lib/udf.py +++ b/src/datachain/lib/udf.py @@ -22,7 +22,7 @@ Partition, RowsOutputBatch, ) -from datachain.utils import safe_closing +from datachain.utils import safe_closing, with_last_flag if TYPE_CHECKING: from collections import abc @@ -109,6 +109,10 @@ class UDFAdapter: def hash(self) -> str: return self.inner.hash() + def output_schema_hash(self) -> str: + """Hash of just the output schema (not including code or inputs).""" + return self.inner.output_schema_hash() + def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy: if use_partitioning: return Partition() @@ -220,6 +224,14 @@ def hash(self) -> str: b"".join([bytes.fromhex(part) for part in parts]) ).hexdigest() + def output_schema_hash(self) -> str: + """Hash of just the output schema (not including code or inputs). + + Used for partial checkpoint hash to detect schema changes while + allowing code-only bug fixes to continue from partial results. 
+ """ + return self.output.hash() + def process(self, *args, **kwargs): """Processing function that needs to be defined by user""" if not self._func: @@ -346,14 +358,14 @@ def _set_stream_recursive( if isinstance(field_value, DataModel): self._set_stream_recursive(field_value, catalog, cache, download_cb) - def _prepare_row(self, row, udf_fields, catalog, cache, download_cb): - row_dict = RowDict(zip(udf_fields, row, strict=False)) - return self._parse_row(row_dict, catalog, cache, download_cb) - - def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb): + def _prepare_row( + self, row, udf_fields, catalog, cache, download_cb, include_id=False + ): row_dict = RowDict(zip(udf_fields, row, strict=False)) udf_input = self._parse_row(row_dict, catalog, cache, download_cb) - return row_dict["sys__id"], *udf_input + if include_id: + return row_dict["sys__id"], *udf_input + return udf_input def noop(*args, **kwargs): @@ -439,8 +451,8 @@ def run( def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]": with safe_closing(udf_inputs): for row in udf_inputs: - yield self._prepare_row_and_id( - row, udf_fields, catalog, cache, download_cb + yield self._prepare_row( + row, udf_fields, catalog, cache, download_cb, include_id=True ) prepared_inputs = _prepare_rows(udf_inputs) @@ -502,8 +514,8 @@ def run( n_rows = len(batch) row_ids, *udf_args = zip( *[ - self._prepare_row_and_id( - row, udf_fields, catalog, cache, download_cb + self._prepare_row( + row, udf_fields, catalog, cache, download_cb, include_id=True ) for row in batch ], @@ -547,14 +559,21 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]": with safe_closing(udf_inputs): for row in udf_inputs: yield self._prepare_row( - row, udf_fields, catalog, cache, download_cb + row, udf_fields, catalog, cache, download_cb, include_id=True ) def _process_row(row): + row_id, *row = row with safe_closing(self.process(*row)) as result_objs: - for result_obj in result_objs: + for result_obj, is_last in with_last_flag(result_objs): udf_output = self._flatten_row(result_obj) - yield dict(zip(self.signal_names, udf_output, strict=False)) + udf_output = dict(zip(self.signal_names, udf_output, strict=False)) + # Include sys__input_id to track which input generated this + # output. + udf_output["sys__input_id"] = row_id # input id + # Mark as partial=True unless it's the last output + udf_output["sys__partial"] = not is_last + yield udf_output prepared_inputs = _prepare_rows(udf_inputs) prepared_inputs = _prefetch_inputs( @@ -563,7 +582,14 @@ def _process_row(row): download_cb=download_cb, remove_prefetched=bool(self.prefetch) and not cache, ) + with closing(prepared_inputs): + # TODO: Fix limitation where inputs yielding nothing are not tracked in + # processed table. Currently, if process() yields nothing for an input, + # that input's sys__id is never added to the processed table, causing it + # to be re-processed on checkpoint recovery. Solution: yield a marker + # row with sys__input_id when process() yields nothing, then filter + # these marker rows before inserting to output table. 
for row in prepared_inputs: yield _process_row(row) processed_cb.relative_update(1) diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 5cde59b82..7053d4092 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -13,6 +13,7 @@ from functools import wraps from types import GeneratorType from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from uuid import uuid4 import attrs import sqlalchemy @@ -29,15 +30,25 @@ from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper from datachain.catalog.catalog import clone_catalog_with_cache +from datachain.checkpoint import Checkpoint +from datachain.checkpoint_event import ( + CheckpointEventType, + CheckpointStepType, +) from datachain.data_storage.schema import ( PARTITION_COLUMN_ID, partition_col_names, partition_columns, ) from datachain.dataset import DatasetDependency, DatasetStatus, RowDict -from datachain.error import DatasetNotFoundError, QueryScriptCancelError +from datachain.error import ( + DatasetNotFoundError, + QueryScriptCancelError, + TableMissingError, +) from datachain.func.base import Function from datachain.hash_utils import hash_column_elements +from datachain.job import Job from datachain.lib.listing import is_listing_dataset, listing_dataset_expired from datachain.lib.signal_schema import SignalSchema, generate_merge_root_mapping from datachain.lib.udf import JsonSerializationError, UdfError, _get_cache @@ -50,9 +61,11 @@ from datachain.sql.functions.random import rand from datachain.sql.types import SQLType from datachain.utils import ( + checkpoints_enabled, determine_processes, determine_workers, ensure_sequence, + env2bool, filtered_cloudpickle_dumps, get_datachain_executable, safe_closing, @@ -155,7 +168,11 @@ class Step(ABC): @abstractmethod def apply( - self, query_generator: "QueryGenerator", temp_tables: list[str] + self, + query_generator: QueryGenerator, + temp_tables: list[str], + *args, + **kwargs, ) -> "StepResult": """Apply the processing step.""" @@ -233,7 +250,13 @@ def query( Should return select query that calculates desired diff between dataset queries """ - def apply(self, query_generator, temp_tables: list[str]) -> "StepResult": + def apply( + self, + query_generator, + temp_tables: list[str], + *args, + **kwargs, + ) -> "StepResult": source_query = query_generator.select() right_before = len(self.dq.temp_table_names) @@ -261,7 +284,8 @@ def apply(self, query_generator, temp_tables: list[str]) -> "StepResult": diff_q = self.query(source_query, target_query) insert_q = temp_table.insert().from_select( - source_query.selected_columns, diff_q + source_query.selected_columns, # type: ignore[arg-type] + diff_q, ) self.catalog.warehouse.db.execute(insert_q) @@ -429,8 +453,16 @@ def _insert_rows(): udf_kind=udf_kind, ) - warehouse.insert_rows(udf_table, _insert_rows(), batch_size=batch_size) - warehouse.insert_rows_done(udf_table) + try: + warehouse.insert_rows( + udf_table, + _insert_rows(), + batch_size=batch_size, + ) + finally: + # Always flush the buffer even if an exception occurs + # This ensures partial results are visible for checkpoint continuation + warehouse.insert_rows_done(udf_table) def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback: @@ -464,7 +496,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback: @frozen class UDFStep(Step, ABC): udf: "UDFAdapter" - catalog: "Catalog" + session: "Session" partition_by: PartitionByType | None = None is_generator = False # Parameters 
from Settings @@ -485,14 +517,71 @@ def hash_inputs(self) -> str: return hashlib.sha256(b"".join(parts)).hexdigest() @abstractmethod - def create_udf_table(self, query: Select) -> "Table": + def create_output_table(self, name: str) -> "Table": """Method that creates a table where temp udf results will be saved""" - def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]: - """Materialize inputs, ensure sys columns are available, needed for checkpoints, - needed for map to work (merge results)""" - table = self.catalog.warehouse.create_pre_udf_table(query) - return sqlalchemy.select(*table.c), [table] + def _checkpoint_tracking_columns(self) -> list["sqlalchemy.Column"]: + """ + Columns needed for checkpoint tracking in UDF output tables. + + Returns list of columns: + - sys__input_id: Tracks which input produced each output. Allows atomic + writes and reconstruction of processed inputs from output table during + checkpoint recovery. Nullable because mappers use sys__id (1:1 mapping) + while generators populate this field explicitly (1:N mapping). + - sys__partial: Tracks incomplete inputs during checkpoint recovery. + For generators, all rows except the last one for each input are marked + as partial=True. If an input has no row with partial=False, it means the + input was not fully processed and needs to be re-run. Nullable because + mappers (1:1) don't use this field. + """ + return [ + sa.Column("sys__input_id", sa.Integer, nullable=True), + sa.Column("sys__partial", sa.Boolean, nullable=True), + ] + + def get_input_query(self, input_table_name: str, original_query: Select) -> Select: + """Get a select query for UDF input.""" + # Table was created from original_query by create_pre_udf_table, + # so they should have the same columns. However, get_table() reflects + # the table with database-specific types (e.g ClickHouse types) instead of + # SQLTypes. + # To preserve SQLTypes for proper type conversion while keeping columns bound + # to the table (to avoid ambiguous column names), we use type_coerce. + table = self.warehouse.db.get_table(input_table_name) + + # Create a mapping of column names to SQLTypes from original query + orig_col_types = {col.name: col.type for col in original_query.selected_columns} + + # Sys columns are added by create_udf_table and may not be in original query + sys_col_types = { + col.name: col.type for col in self.warehouse.dataset_row_cls.sys_columns() + } + + # Build select using bound columns from table, with type coercion for SQLTypes + select_columns = [] + for table_col in table.c: + if table_col.name in orig_col_types: + # Use type_coerce to preserve SQLType while keeping column bound + # to table. 
Use label() to preserve the column name + select_columns.append( + sqlalchemy.type_coerce( + table_col, orig_col_types[table_col.name] + ).label(table_col.name) + ) + elif table_col.name in sys_col_types: + # Sys column added by create_udf_table - use known type + select_columns.append( + sqlalchemy.type_coerce( + table_col, sys_col_types[table_col.name] + ).label(table_col.name) + ) + else: + raise RuntimeError( + f"Unexpected column '{table_col.name}' in input table" + ) + + return sqlalchemy.select(*select_columns).select_from(table) @abstractmethod def create_result_query( @@ -503,8 +592,24 @@ def create_result_query( to select """ - def populate_udf_table(self, udf_table: "Table", query: Select) -> None: - if (rows_total := self.catalog.warehouse.query_count(query)) == 0: + def populate_udf_output_table( + self, + udf_table: "Table", + query: Select, + continued: bool = False, + rows_reused: int = 0, + output_rows_reused: int = 0, + rows_total: int | None = None, + ) -> None: + catalog = self.session.catalog + rows_to_process = catalog.warehouse.query_count(query) + if rows_to_process == 0: + logger.debug( + "UDF(%s) [job=%s run_group=%s]: No rows to process, skipping", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + ) return from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE @@ -513,8 +618,19 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: get_udf_distributor_class, ) - workers = determine_workers(self.workers, rows_total=rows_total) - processes = determine_processes(self.parallel, rows_total=rows_total) + workers = determine_workers(self.workers, rows_total=rows_to_process) + processes = determine_processes(self.parallel, rows_total=rows_to_process) + logger.debug( + "UDF(%s) [job=%s run_group=%s]: Processing %d rows " + "(workers=%s, processes=%s, batch_size=%s)", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + rows_to_process, + workers, + processes, + self.batch_size, + ) use_partitioning = self.partition_by is not None batching = self.udf.get_batching(use_partitioning) @@ -522,8 +638,8 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: udf_distributor_class = get_udf_distributor_class() prefetch = self.udf.prefetch - with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache: - catalog = clone_catalog_with_cache(self.catalog, _cache) + with _get_cache(catalog.cache, prefetch, use_cache=self.cache) as _cache: + catalog = clone_catalog_with_cache(catalog, _cache) try: if udf_distributor_class and not catalog.in_memory: @@ -537,11 +653,15 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: workers=workers, processes=processes, udf_fields=udf_fields, + rows_to_process=rows_to_process, rows_total=rows_total, use_cache=self.cache, is_generator=self.is_generator, min_task_size=self.min_task_size, batch_size=self.batch_size, + continued=continued, + rows_reused=rows_reused, + output_rows_reused=output_rows_reused, ) udf_distributor() return @@ -577,7 +697,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: processes=processes, is_generator=self.is_generator, cache=self.cache, - rows_total=rows_total, + rows_total=rows_to_process, batch_size=self.batch_size, ) @@ -586,7 +706,14 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: exec_cmd = get_datachain_executable() cmd = [*exec_cmd, "internal-run-udf"] envs = dict(os.environ) - envs.update({"PYTHONPATH": os.getcwd()}) + envs.update( + { + 
"PYTHONPATH": os.getcwd(), + # Mark as DataChain-controlled subprocess to enable + # checkpoints + "DATACHAIN_SUBPROCESS": "1", + } + ) process_data = filtered_cloudpickle_dumps(udf_info) with subprocess.Popen( # noqa: S603 @@ -635,17 +762,19 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: generated_cb.close() except QueryScriptCancelError: - self.catalog.warehouse.close() + catalog.warehouse.close() sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE) except (Exception, KeyboardInterrupt): # Close any open database connections if an error is encountered - self.catalog.warehouse.close() + catalog.warehouse.close() raise def create_partitions_table(self, query: Select) -> "Table": """ Create temporary table with group by partitions. """ + catalog = self.session.catalog + if self.partition_by is None: raise RuntimeError("Query must have partition_by set to use partitioning") if (id_col := query.selected_columns.get("sys__id")) is None: @@ -661,14 +790,14 @@ def create_partitions_table(self, query: Select) -> "Table": ] # create table with partitions - tbl = self.catalog.warehouse.create_udf_table(partition_columns()) + tbl = catalog.warehouse.create_udf_table(partition_columns()) # fill table with partitions cols = [ id_col, f.dense_rank().over(order_by=partition_by).label(PARTITION_COLUMN_ID), ] - self.catalog.warehouse.db.execute( + catalog.warehouse.db.execute( tbl.insert().from_select( cols, query.offset(None).limit(None).with_only_columns(*cols), @@ -681,43 +810,576 @@ def clone(self, partition_by: PartitionByType | None = None) -> "Self": if partition_by is not None: return self.__class__( self.udf, - self.catalog, + self.session, partition_by=partition_by, parallel=self.parallel, workers=self.workers, min_task_size=self.min_task_size, batch_size=self.batch_size, ) - return self.__class__(self.udf, self.catalog) + return self.__class__(self.udf, self.session) + + @property + def _udf_name(self) -> str: + """Get UDF name for logging.""" + return self.udf.inner.verbose_name + + @property + def _job_id_short(self) -> str: + """Get short job_id for logging.""" + return self.job.id[:8] if self.job.id else "none" + + @property + def _run_group_id_short(self) -> str: + """Get short run_group_id for logging.""" + return self.job.run_group_id[:8] if self.job.run_group_id else "none" + + @property + @abstractmethod + def _step_type(self) -> CheckpointStepType: + """Get the step type for checkpoint events.""" + + def _log_event( + self, + event_type: CheckpointEventType, + checkpoint_hash: str | None = None, + hash_partial: str | None = None, + hash_input: str | None = None, + hash_output: str | None = None, + rows_input: int | None = None, + rows_processed: int | None = None, + rows_output: int | None = None, + rows_input_reused: int | None = None, + rows_output_reused: int | None = None, + rerun_from_job_id: str | None = None, + details: dict | None = None, + ) -> None: + """Log a checkpoint event and emit a log message.""" + self.metastore.log_checkpoint_event( + job_id=self.job.id, + event_type=event_type, + step_type=self._step_type, + run_group_id=self.job.run_group_id, + udf_name=self._udf_name, + checkpoint_hash=checkpoint_hash, + hash_partial=hash_partial, + hash_input=hash_input, + hash_output=hash_output, + rows_input=rows_input, + rows_processed=rows_processed, + rows_output=rows_output, + rows_input_reused=rows_input_reused, + rows_output_reused=rows_output_reused, + rerun_from_job_id=rerun_from_job_id, + details=details, + ) + logger.info( + "UDF(%s) [job=%s 
run_group=%s]: %s - " + "input=%s, processed=%s, output=%s, input_reused=%s, output_reused=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + event_type.value, + rows_input, + rows_processed, + rows_output, + rows_input_reused, + rows_output_reused, + ) + + def _find_udf_checkpoint( + self, _hash: str, partial: bool = False + ) -> Checkpoint | None: + """ + Find a reusable UDF checkpoint for the given hash. + Returns the Checkpoint object if found and checkpoints are enabled, + None otherwise. + """ + ignore_checkpoints = env2bool("DATACHAIN_IGNORE_CHECKPOINTS", undefined=False) + + if ( + checkpoints_enabled() + and self.job.rerun_from_job_id + and not ignore_checkpoints + and ( + checkpoint := self.metastore.find_checkpoint( + self.job.rerun_from_job_id, _hash, partial=partial + ) + ) + ): + logger.debug( + "UDF(%s) [job=%s run_group=%s]: Found %scheckpoint " + "hash=%s from job_id=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + "partial " if partial else "", + _hash[:8], + checkpoint.job_id, + ) + return checkpoint + + return None + + @property + def job(self) -> Job: + return self.session.get_or_create_job() + + @property + def metastore(self): + return self.session.catalog.metastore + + @property + def warehouse(self): + return self.session.catalog.warehouse + + @staticmethod + def input_table_name(run_group_id: str, _hash: str) -> str: + """Run-group-specific input table name. + + Uses run_group_id instead of job_id so all jobs in the same run group + share the same input table, eliminating the need for ancestor traversal. + """ + return f"udf_{run_group_id}_{_hash}_input" + + @staticmethod + def output_table_name(job_id: str, _hash: str) -> str: + """Job-specific final output table name.""" + return f"udf_{job_id}_{_hash}_output" + + @staticmethod + def partial_output_table_name(job_id: str, _hash: str) -> str: + """Job-specific partial output table name.""" + return f"udf_{job_id}_{_hash}_output_partial" + + def get_or_create_input_table(self, query: Select, _hash: str) -> "Table": + """ + Get or create input table for the given hash. + + Uses run_group_id for table naming so all jobs in the same run group + share the same input table. + + Returns the input table. + """ + assert self.job.run_group_id + input_table_name = UDFStep.input_table_name(self.job.run_group_id, _hash) + + # Check if input table already exists (created by ancestor job) + if self.warehouse.db.has_table(input_table_name): + return self.warehouse.get_table(input_table_name) + + # Create input table from original query + return self.warehouse.create_pre_udf_table(query, input_table_name) def apply( - self, query_generator: QueryGenerator, temp_tables: list[str] + self, + query_generator: QueryGenerator, + temp_tables: list[str], + hash_input: str, + hash_output: str, ) -> "StepResult": - query, tables = self.process_input_query(query_generator.select()) - _query = query + query = query_generator.select() + + # Calculate partial hash that includes output schema + # This allows continuing from partial when only code changes (bug fix), + # but forces re-run when output schema changes (incompatible) + partial_hash = hashlib.sha256( + (hash_input + self.udf.output_schema_hash()).encode() + ).hexdigest() - # Apply partitioning if needed. 
+ # If partition_by is set, we need to create input table first to ensure + # consistent sys__id if self.partition_by is not None: + # Create input table first so partition table can reference the + # same sys__id values + input_table = self.get_or_create_input_table(query, hash_input) + + # Now query from the input table for partition creation + # Use get_input_query to preserve SQLTypes from original query + query = self.get_input_query(input_table.name, query) + partition_tbl = self.create_partitions_table(query) + temp_tables.append(partition_tbl.name) query = query.outerjoin( partition_tbl, partition_tbl.c.sys__id == query.selected_columns.sys__id, ).add_columns(*partition_columns()) - tables = [*tables, partition_tbl] - temp_tables.extend(t.name for t in tables) - udf_table = self.create_udf_table(_query) - temp_tables.append(udf_table.name) - self.populate_udf_table(udf_table, query) - q, cols = self.create_result_query(udf_table, query) + # Aggregator checkpoints are not implemented yet - skip partial continuation + can_continue_from_partial = self.partition_by is None + if ch := self._find_udf_checkpoint(hash_output): + try: + output_table, input_table = self._skip_udf(ch, hash_input, query) + except TableMissingError: + logger.warning( + "UDF(%s) [job=%s run_group=%s]: Output table not found for " + "checkpoint %s. Running UDF from scratch.", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + ch, + ) + output_table, input_table = self._run_from_scratch( + partial_hash, ch.hash, hash_input, query + ) + elif can_continue_from_partial and ( + ch_partial := self._find_udf_checkpoint(partial_hash, partial=True) + ): + output_table, input_table = self._continue_udf( + ch_partial, hash_output, hash_input, query + ) + else: + output_table, input_table = self._run_from_scratch( + partial_hash, hash_output, hash_input, query + ) + + # Create result query from output table + input_query = self.get_input_query(input_table.name, query) + q, cols = self.create_result_query(output_table, input_query) return step_result(q, cols) + def _skip_udf( + self, checkpoint: Checkpoint, hash_input: str, query + ) -> tuple["Table", "Table"]: + """ + Skip UDF by copying existing output table. 
Returns (output_table, input_table) + """ + print(f"UDF '{self._udf_name}': Skipped, reusing output from checkpoint") + logger.info( + "UDF(%s) [job=%s run_group=%s]: Skipping execution, " + "reusing output from job_id=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + checkpoint.job_id, + ) + existing_output_table = self.warehouse.get_table( + UDFStep.output_table_name(checkpoint.job_id, checkpoint.hash) + ) + output_table = self.warehouse.create_table_from_query( + UDFStep.output_table_name(self.job.id, checkpoint.hash), + sa.select(existing_output_table), + create_fn=self.create_output_table, + ) + + input_table = self.get_or_create_input_table(query, hash_input) + + self.metastore.get_or_create_checkpoint(self.job.id, checkpoint.hash) + logger.debug( + "UDF(%s) [job=%s run_group=%s]: Created checkpoint hash=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + checkpoint.hash[:8], + ) + + # Log checkpoint event with row counts + rows_input = self.warehouse.table_rows_count(input_table) + output_rows_reused = self.warehouse.table_rows_count(output_table) + self._log_event( + CheckpointEventType.UDF_SKIPPED, + checkpoint_hash=checkpoint.hash, + hash_input=hash_input, + hash_output=checkpoint.hash, + rerun_from_job_id=checkpoint.job_id, + rows_input=rows_input, + rows_processed=0, + rows_output=0, + rows_input_reused=rows_input, + rows_output_reused=output_rows_reused, + ) + + # Register skipped UDF in the registry (no-op for local metastores) + self.metastore.add_udf( + udf_id=str(uuid4()), + name=self._udf_name, + status="DONE", + rows_total=rows_input, + job_id=self.job.id, + tasks_created=0, + skipped=True, + rows_reused=rows_input, + output_rows_reused=output_rows_reused, + ) + + return output_table, input_table + + def _run_from_scratch( + self, partial_hash: str, hash_output: str, hash_input: str, query + ) -> tuple["Table", "Table"]: + """Execute UDF from scratch. 
Returns (output_table, input_table).""" + logger.info( + "UDF(%s) [job=%s run_group=%s]: Running from scratch", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + ) + + partial_checkpoint = None + if checkpoints_enabled(): + partial_checkpoint = self.metastore.get_or_create_checkpoint( + self.job.id, partial_hash, partial=True + ) + logger.debug( + "UDF(%s) [job=%s run_group=%s]: Created partial checkpoint hash=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + partial_hash[:8], + ) + + input_table = self.get_or_create_input_table(query, hash_input) + + partial_output_table = self.create_output_table( + UDFStep.partial_output_table_name(self.job.id, partial_hash), + ) + + if self.partition_by is not None: + input_query = query + else: + input_query = self.get_input_query(input_table.name, query) + + self.populate_udf_output_table(partial_output_table, input_query) + + output_table = self.warehouse.rename_table( + partial_output_table, UDFStep.output_table_name(self.job.id, hash_output) + ) + + if partial_checkpoint: + self.metastore.remove_checkpoint(partial_checkpoint.id) + self.metastore.get_or_create_checkpoint(self.job.id, hash_output) + logger.debug( + "UDF(%s) [job=%s run_group=%s]: Promoted partial to final, hash=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + hash_output[:8], + ) + + # Log checkpoint event with row counts + rows_input = self.warehouse.table_rows_count(input_table) + rows_generated = self.warehouse.table_rows_count(output_table) + self._log_event( + CheckpointEventType.UDF_FROM_SCRATCH, + checkpoint_hash=hash_output, + hash_input=hash_input, + hash_output=hash_output, + rows_input=rows_input, + rows_processed=rows_input, + rows_output=rows_generated, + rows_input_reused=0, + rows_output_reused=0, + ) + + return output_table, input_table + + def _continue_udf( + self, checkpoint: Checkpoint, hash_output: str, hash_input: str, query + ) -> tuple["Table", "Table"]: + """ + Continue UDF from parent's partial output. 
Returns (output_table, input_table) + """ + if self.job.rerun_from_job_id is None: + raise RuntimeError( + f"UDF '{self._udf_name}': Cannot continue from checkpoint " + f"without a rerun_from_job_id" + ) + if checkpoint.job_id != self.job.rerun_from_job_id: + raise RuntimeError( + f"UDF '{self._udf_name}': Checkpoint job_id mismatch — " + f"expected {self.job.rerun_from_job_id}, " + f"got {checkpoint.job_id}" + ) + + print(f"UDF '{self._udf_name}': Continuing from checkpoint") + logger.info( + "UDF(%s) [job=%s run_group=%s]: Continuing from partial checkpoint, " + "parent_job_id=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + self.job.rerun_from_job_id, + ) + + partial_checkpoint = self.metastore.get_or_create_checkpoint( + self.job.id, checkpoint.hash, partial=True + ) + + input_table = self.get_or_create_input_table(query, hash_input) + + try: + parent_partial_table = self.warehouse.get_table( + UDFStep.partial_output_table_name( + self.job.rerun_from_job_id, checkpoint.hash + ) + ) + except TableMissingError: + logger.warning( + "UDF(%s) [job=%s run_group=%s]: Parent partial table not found for " + "checkpoint %s, falling back to run from scratch", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + checkpoint, + ) + return self._run_from_scratch( + checkpoint.hash, hash_output, hash_input, query + ) + + incomplete_input_ids = self.find_incomplete_inputs(parent_partial_table) + if incomplete_input_ids: + logger.debug( + "UDF(%s) [job=%s run_group=%s]: Found %d incomplete inputs " + "to re-process", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + len(incomplete_input_ids), + ) + + partial_table_name = UDFStep.partial_output_table_name( + self.job.id, checkpoint.hash + ) + if incomplete_input_ids: + # Filter out incomplete inputs - they will be re-processed + filtered_query = sa.select(parent_partial_table).where( + parent_partial_table.c.sys__input_id.not_in(incomplete_input_ids) + ) + partial_table = self.warehouse.create_table_from_query( + partial_table_name, + filtered_query, + create_fn=self.create_output_table, + ) + else: + partial_table = self.warehouse.create_table_from_query( + partial_table_name, + sa.select(parent_partial_table), + create_fn=self.create_output_table, + ) + + input_query = self.get_input_query(input_table.name, query) + + unprocessed_query = self.calculate_unprocessed_rows( + input_query, + partial_table, + incomplete_input_ids, + ) + + # Count rows before populating with new rows + output_rows_reused = self.warehouse.table_rows_count(partial_table) + rows_input = self.warehouse.table_rows_count(input_table) + rows_to_process = self.warehouse.query_count(unprocessed_query) + rows_reused = rows_input - rows_to_process # input rows reused + + self.populate_udf_output_table( + partial_table, + unprocessed_query, + continued=True, + rows_reused=rows_reused, + output_rows_reused=output_rows_reused, + rows_total=rows_input, + ) + + output_table = self.warehouse.rename_table( + partial_table, UDFStep.output_table_name(self.job.id, hash_output) + ) + + self.metastore.remove_checkpoint(partial_checkpoint.id) + self.metastore.get_or_create_checkpoint(self.job.id, hash_output) + logger.debug( + "UDF(%s) [job=%s run_group=%s]: Promoted partial to final, hash=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + hash_output[:8], + ) + + # Log checkpoint event with row counts + total_output = self.warehouse.table_rows_count(output_table) + rows_generated = total_output - 
output_rows_reused + self._log_event( + CheckpointEventType.UDF_CONTINUED, + checkpoint_hash=hash_output, + hash_partial=checkpoint.hash, + hash_input=hash_input, + hash_output=hash_output, + rerun_from_job_id=checkpoint.job_id, + rows_input=rows_input, + rows_processed=rows_to_process, + rows_output=rows_generated, + rows_input_reused=rows_reused, + rows_output_reused=output_rows_reused, + ) + + return output_table, input_table + + @abstractmethod + def processed_input_ids_query(self, partial_table: "Table"): + """ + Create a subquery that returns processed input sys__ids from partial table. + + Args: + partial_table: The UDF partial table + + Returns: + A subquery with a single column labeled 'sys__processed_id' containing + processed input IDs + """ + + @abstractmethod + def find_incomplete_inputs(self, partial_table: "Table") -> list[int]: + """ + Find input IDs that were only partially processed before a crash. + For generators (1:N), an input is incomplete if it has output rows but none + with sys__partial=False. For mappers (1:1), this never happens. + + Returns: + List of incomplete input IDs that need to be re-processed + """ + + def calculate_unprocessed_rows( + self, + input_query: Select, + partial_table: "Table", + incomplete_input_ids: None | list[int] = None, + ): + """ + Calculate which input rows haven't been processed yet. + + Args: + input_query: Select query for the UDF input table (with proper types) + partial_table: The UDF partial table + incomplete_input_ids: List of input IDs that were partially processed + and need to be re-run (for generators only) + + Returns: + A filtered query containing only unprocessed rows + """ + incomplete_input_ids = incomplete_input_ids or [] + # Get processed input IDs using subclass-specific logic + processed_input_ids_subquery = self.processed_input_ids_query(partial_table) + + sys_id_col = input_query.selected_columns.sys__id + + # Build filter: rows that haven't been processed OR were incompletely processed + unprocessed_filter: sa.ColumnElement[bool] = sys_id_col.notin_( + sa.select(processed_input_ids_subquery.c.sys__processed_id) + ) + + # Add incomplete inputs to the filter (they need to be re-processed) + if incomplete_input_ids: + unprocessed_filter = sa.or_( + unprocessed_filter, sys_id_col.in_(incomplete_input_ids) + ) + + return input_query.where(unprocessed_filter) + @frozen class UDFSignal(UDFStep): udf: "UDFAdapter" - catalog: "Catalog" + session: "Session" partition_by: PartitionByType | None = None is_generator = False # Parameters from Settings @@ -727,13 +1389,36 @@ class UDFSignal(UDFStep): min_task_size: int | None = None batch_size: int | None = None - def create_udf_table(self, query: Select) -> "Table": - udf_output_columns: list[sqlalchemy.Column[Any]] = [ + @property + def _step_type(self) -> CheckpointStepType: + return CheckpointStepType.UDF_MAP + + def processed_input_ids_query(self, partial_table: "Table"): + """ + For mappers (1:1 mapping): returns sys__id from partial table. + + Since mappers have a 1:1 relationship between input and output, + the sys__id in the partial table directly corresponds to input sys__ids. + """ + # labeling it with sys__processed_id to have common name since for udf signal + # we use sys__id and in generator we use sys__input_id + return sa.select(partial_table.c.sys__id.label("sys__processed_id")).subquery() + + def find_incomplete_inputs(self, partial_table: "Table") -> list[int]: + """ + For mappers (1:1 mapping): always returns empty list. 
+ Mappers cannot have incomplete inputs because each input produces exactly + one output atomically. Either the output exists or it doesn't. + """ + return [] + + def create_output_table(self, name: str) -> "Table": + columns: list[sqlalchemy.Column[Any]] = [ sqlalchemy.Column(col_name, col_type) for (col_name, col_type) in self.udf.output.items() ] - - return self.catalog.warehouse.create_udf_table(udf_output_columns) + columns.extend(self._checkpoint_tracking_columns()) + return self.warehouse.create_udf_table(columns, name=name) def create_result_query( self, udf_table, query @@ -804,7 +1489,7 @@ class RowGenerator(UDFStep): """Extend dataset with new rows.""" udf: "UDFAdapter" - catalog: "Catalog" + session: "Session" partition_by: PartitionByType | None = None is_generator = True # Parameters from Settings @@ -814,25 +1499,62 @@ class RowGenerator(UDFStep): min_task_size: int | None = None batch_size: int | None = None - def create_udf_table(self, query: Select) -> "Table": - warehouse = self.catalog.warehouse + @property + def _step_type(self) -> CheckpointStepType: + return CheckpointStepType.UDF_GEN - table_name = self.catalog.warehouse.udf_table_name() - columns: tuple[Column, ...] = tuple( - Column(name, typ) for name, typ in self.udf.output.items() + def processed_input_ids_query(self, partial_table: "Table"): + """ + For generators (1:N mapping): returns distinct sys__input_id from partial table. + + Since generators can produce multiple outputs per input (1:N relationship), + we use sys__input_id which tracks which input created each output row. + """ + # labeling it with sys__processed_id to have common name since for udf signal + # we use sys__id and in generator we use sys__input_id + return sa.select( + sa.distinct(partial_table.c.sys__input_id).label("sys__processed_id") + ).subquery() + + def find_incomplete_inputs(self, partial_table: "Table") -> list[int]: + """ + For generators (1:N mapping): find inputs missing sys__partial=False row. + + An input is incomplete if it has output rows but none with sys__partial=False, + indicating the process crashed before finishing all outputs for that input. + These inputs need to be re-processed and their partial results filtered out. 
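+
+        Example: if an input produced two output rows, both marked
+        sys__partial=True, before the process crashed, then no row for that input
+        has sys__partial=False; the input is returned here and its partial rows
+        are excluded when the previous run's partial output is copied.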
+ """ + # Find inputs that don't have any row with sys__partial=False + incomplete_query = sa.select(sa.distinct(partial_table.c.sys__input_id)).where( + partial_table.c.sys__input_id.not_in( + sa.select(partial_table.c.sys__input_id).where( + partial_table.c.sys__partial == False # noqa: E712 + ) + ) ) - return warehouse.create_dataset_rows_table( - table_name, - columns=columns, - if_not_exists=False, + return [row[0] for row in self.warehouse.db.execute(incomplete_query)] + + def create_output_table(self, name: str) -> "Table": + columns: list[Column] = [ + Column(name, typ) for name, typ in self.udf.output.items() + ] + columns.extend(self._checkpoint_tracking_columns()) + return self.warehouse.create_dataset_rows_table( + name, + columns=tuple(columns), + if_not_exists=True, ) def create_result_query( self, udf_table, query: Select ) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]: udf_table_query = udf_table.select().subquery() + # Exclude sys__input_id and sys__partial - they're only needed for tracking + # during UDF execution and checkpoint recovery udf_table_cols: list[sqlalchemy.Label[Any]] = [ - label(c.name, c) for c in udf_table_query.columns + label(c.name, c) + for c in udf_table_query.columns + if c.name not in ("sys__input_id", "sys__partial") ] def q(*columns): @@ -841,13 +1563,21 @@ def q(*columns): cols = [c for c in udf_table_cols if c.name in names] return sqlalchemy.select(*cols).select_from(udf_table_query) - return q, udf_table_query.columns + return q, [ + c + for c in udf_table_query.columns + if c.name not in ("sys__input_id", "sys__partial") + ] @frozen class SQLClause(Step, ABC): def apply( - self, query_generator: QueryGenerator, temp_tables: list[str] + self, + query_generator: QueryGenerator, + temp_tables: list[str], + *args, + **kwargs, ) -> StepResult: query = query_generator.select() new_query = self.apply_sql_clause(query) @@ -876,7 +1606,11 @@ def hash_inputs(self) -> str: return hashlib.sha256(b"regenerate_system_columns").hexdigest() def apply( - self, query_generator: QueryGenerator, temp_tables: list[str] + self, + query_generator: QueryGenerator, + temp_tables: list[str], + *args, + **kwargs, ) -> StepResult: query = query_generator.select() new_query = self.catalog.warehouse._regenerate_system_columns( @@ -1039,7 +1773,11 @@ def hash_inputs(self) -> str: ).hexdigest() def apply( - self, query_generator: QueryGenerator, temp_tables: list[str] + self, + query_generator: QueryGenerator, + temp_tables: list[str], + *args, + **kwargs, ) -> StepResult: left_before = len(self.query1.temp_table_names) q1 = self.query1.apply_steps().select().subquery() @@ -1126,7 +1864,7 @@ def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery: ) temp_tables.append(temp_table.name) - warehouse.copy_table(temp_table, query) + warehouse.insert_into(temp_table, query) return temp_table.select().subquery(dq.table.name) @@ -1160,7 +1898,11 @@ def validate_expression(self, exp: "ClauseElement", q1, q2): self.validate_expression(c, q1, q2) def apply( - self, query_generator: QueryGenerator, temp_tables: list[str] + self, + query_generator: QueryGenerator, + temp_tables: list[str], + *args, + **kwargs, ) -> StepResult: q1 = self.get_query(self.query1, temp_tables) q2 = self.get_query(self.query2, temp_tables) @@ -1409,23 +2151,46 @@ def _set_starting_step(self, ds: "DatasetRecord") -> None: self.column_types.pop("sys__id") self.project = ds.project + @property + def _starting_step_hash(self) -> str: + if self.starting_step: + return 
self.starting_step.hash() + assert self.list_ds_name + return self.list_ds_name + + @property + def job(self) -> Job: + """ + Get existing job if running in SaaS, or creating new one if running locally + """ + return self.session.get_or_create_job() + + @property + def _last_checkpoint_hash(self) -> str | None: + last_checkpoint = self.catalog.metastore.get_last_checkpoint(self.job.id) + return last_checkpoint.hash if last_checkpoint else None + def __iter__(self): return iter(self.db_results()) def __or__(self, other): return self.union(other) - def hash(self) -> str: + def hash(self, job_aware: bool = False) -> str: """ Calculates hash of this class taking into account hash of starting step and hashes of each following steps. Ordering is important. + + Args: + job_aware: If True, includes the last checkpoint hash from the job context. """ hasher = hashlib.sha256() - if self.starting_step: - hasher.update(self.starting_step.hash().encode("utf-8")) - else: - assert self.list_ds_name - hasher.update(self.list_ds_name.encode("utf-8")) + + start_hash = self._last_checkpoint_hash if job_aware else None + if start_hash: + hasher.update(start_hash.encode("utf-8")) + + hasher.update(self._starting_step_hash.encode("utf-8")) for step in self.steps: hasher.update(step.hash().encode("utf-8")) @@ -1490,6 +2255,13 @@ def apply_steps(self) -> QueryGenerator: Apply the steps in the query and return the resulting sqlalchemy.SelectBase. """ + hasher = hashlib.sha256() + start_hash = self._last_checkpoint_hash + if start_hash: + hasher.update(start_hash.encode("utf-8")) + + hasher.update(self._starting_step_hash.encode("utf-8")) + self.apply_listing_pre_step() query = self.clone() @@ -1514,9 +2286,18 @@ def apply_steps(self) -> QueryGenerator: result = query.starting_step.apply() self.dependencies.update(result.dependencies) + _hash = hasher.hexdigest() for step in query.steps: + hash_input = _hash + hasher.update(step.hash().encode("utf-8")) + _hash = hasher.hexdigest() + hash_output = _hash + result = step.apply( - result.query_generator, self.temp_table_names + result.query_generator, + self.temp_table_names, + hash_input=hash_input, + hash_output=hash_output, ) # a chain of steps linked by results self.dependencies.update(result.dependencies) @@ -1893,7 +2674,7 @@ def add_signals( query.steps.append( UDFSignal( udf, - self.catalog, + self.session, partition_by=partition_by, parallel=parallel, workers=workers, @@ -1931,7 +2712,7 @@ def generate( steps.append( RowGenerator( udf, - self.catalog, + self.session, partition_by=partition_by, parallel=parallel, workers=workers, @@ -2055,7 +2836,7 @@ def save( dr = self.catalog.warehouse.dataset_rows(dataset) - self.catalog.warehouse.copy_table(dr.get_table(), query.select()) + self.catalog.warehouse.insert_into(dr.get_table(), query.select()) self.catalog.update_dataset_version_with_warehouse_info(dataset, version) diff --git a/src/datachain/query/udf.py b/src/datachain/query/udf.py index 725973abc..9baf8b9d2 100644 --- a/src/datachain/query/udf.py +++ b/src/datachain/query/udf.py @@ -27,7 +27,7 @@ class UdfInfo(TypedDict): class AbstractUDFDistributor(ABC): @abstractmethod - def __init__( + def __init__( # noqa: PLR0913 self, catalog: "Catalog", table: "Table", @@ -37,11 +37,15 @@ def __init__( workers: bool | int, processes: bool | int, udf_fields: list[str], - rows_total: int, - use_cache: bool, + rows_to_process: int, + rows_total: int | None = None, + use_cache: bool = False, is_generator: bool = False, min_task_size: str | int | None = None, batch_size: 
int | None = None, + continued: bool = False, + rows_reused: int = 0, + output_rows_reused: int = 0, ) -> None: ... @abstractmethod diff --git a/src/datachain/utils.py b/src/datachain/utils.py index 4244246e3..795d14488 100644 --- a/src/datachain/utils.py +++ b/src/datachain/utils.py @@ -6,6 +6,7 @@ import random import re import sys +import threading import time from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager @@ -30,6 +31,16 @@ logger = logging.getLogger("datachain") + +class _CheckpointState: + """Internal state for checkpoint management.""" + + _lock = threading.Lock() + disabled = False + warning_shown = False + owner_thread: int | None = None # Thread ident of the checkpoint owner + + NUL = b"\0" TIME_ZERO = datetime.fromtimestamp(0, tz=timezone.utc) @@ -519,6 +530,75 @@ def uses_glob(path: str) -> bool: return glob.has_magic(os.path.basename(os.path.normpath(path))) +def checkpoints_enabled() -> bool: + """ + Check if checkpoints are enabled for the current execution context. + + Checkpoints are automatically disabled when code runs in: + 1. A user-created subprocess (detected via DATACHAIN_MAIN_PROCESS_PID mismatch) + 2. A thread that is not the original checkpoint owner thread + + DataChain-controlled subprocesses can enable checkpoints by setting + DATACHAIN_SUBPROCESS=1. + + This is because each checkpoint hash depends on the hash of the previous + checkpoint, making the computation order-sensitive. Concurrent execution can + cause non-deterministic hash calculations due to unpredictable ordering. + + Returns: + bool: True if checkpoints are enabled, False if disabled. + """ + # DataChain-controlled subprocess - explicitly allowed + if os.environ.get("DATACHAIN_SUBPROCESS"): + return True + + # Track the original main process PID via environment variable + # This env var is inherited by all child processes (fork and spawn) + current_pid = str(os.getpid()) + main_pid = os.environ.get("DATACHAIN_MAIN_PROCESS_PID") + + if main_pid is None: + # First call ever - we're the main process, set the marker + os.environ["DATACHAIN_MAIN_PROCESS_PID"] = current_pid + main_pid = current_pid + + if current_pid != main_pid: + # We're in a subprocess without DATACHAIN_SUBPROCESS flag + # This is a user-created subprocess - disable checkpoints + with _CheckpointState._lock: + if not _CheckpointState.warning_shown: + logger.warning( + "User subprocess detected. " + "Checkpoints will not be created in this subprocess. " + "Previously created checkpoints remain valid and can be reused." + ) + _CheckpointState.warning_shown = True + return False + + # Thread ownership tracking - first thread to call becomes the owner + # Threads share memory, so all threads see the same _CheckpointState + current_thread = threading.current_thread().ident + with _CheckpointState._lock: + if _CheckpointState.owner_thread is None: + _CheckpointState.owner_thread = current_thread + + is_owner = current_thread == _CheckpointState.owner_thread + + if not is_owner and not _CheckpointState.disabled: + _CheckpointState.disabled = True + if not _CheckpointState.warning_shown: + logger.warning( + "Concurrent thread detected. " + "New checkpoints will not be created from this point forward. " + "Previously created checkpoints remain valid and can be reused. " + "To enable checkpoints, ensure your script runs sequentially " + "without user-created threading." 
+ ) + _CheckpointState.warning_shown = True + + return is_owner and not _CheckpointState.disabled + + def env2bool(var, undefined=False): """ undefined: return value if env var is unset @@ -573,3 +653,26 @@ def ensure_sequence(x) -> Sequence: if isinstance(x, Sequence) and not isinstance(x, (str, bytes)): return x return [x] + + +def with_last_flag(iterable): + """ + Returns flag saying is this element the last in the iterator or not, together + with the element. + + Example: + for item, is_last in with_last_flag(my_gen()): + ... + """ + it = iter(iterable) + try: + prev = next(it) + except StopIteration: + return + + for item in it: + yield prev, False + prev = item + + # last item + yield prev, True diff --git a/tests/conftest.py b/tests/conftest.py index 652e12553..f6d5a56f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -200,26 +200,41 @@ def metastore(monkeypatch): def check_temp_tables_cleaned_up(original_warehouse): - """Ensure that temporary tables are cleaned up.""" + """Ensure that temporary tables are cleaned up. + + UDF tables are now expected to persist (they're shared across jobs), + so we only check for temp tables here. + """ with original_warehouse.clone() as warehouse: - assert [ + temp_tables = [ t for t in sqlalchemy.inspect(warehouse.db.engine).get_table_names() - if t.startswith( - (warehouse.UDF_TABLE_NAME_PREFIX, warehouse.TMP_TABLE_NAME_PREFIX) - ) - ] == [] + if t.startswith(warehouse.TMP_TABLE_NAME_PREFIX) + ] + assert temp_tables == [], f"Temporary tables not cleaned up: {temp_tables}" + + +def cleanup_udf_tables(warehouse): + """Clean up all UDF tables after each test. + + UDF tables are shared across jobs and persist after chain finishes, + so we need to clean them up after each test to prevent interference. + """ + for table_name in warehouse.db.list_tables(prefix=warehouse.UDF_TABLE_NAME_PREFIX): + table = warehouse.db.get_table(table_name) + warehouse.db.drop_table(table, if_exists=True) @pytest.fixture def warehouse(metastore): if os.environ.get("DATACHAIN_WAREHOUSE"): _warehouse = get_warehouse() - yield _warehouse try: check_temp_tables_cleaned_up(_warehouse) finally: + cleanup_udf_tables(_warehouse) _warehouse.cleanup_for_tests() + yield _warehouse else: _warehouse = SQLiteWarehouse(db_file=":memory:") yield _warehouse @@ -280,11 +295,12 @@ def metastore_tmpfile(monkeypatch, tmp_path): def warehouse_tmpfile(tmp_path, metastore_tmpfile): if os.environ.get("DATACHAIN_WAREHOUSE"): _warehouse = get_warehouse() - yield _warehouse try: check_temp_tables_cleaned_up(_warehouse) finally: + cleanup_udf_tables(_warehouse) _warehouse.cleanup_for_tests() + yield _warehouse else: _warehouse = SQLiteWarehouse(db_file=str(tmp_path / "test.db")) yield _warehouse diff --git a/tests/func/checkpoints/__init__.py b/tests/func/checkpoints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/func/checkpoints/test_checkpoint_concurrency.py b/tests/func/checkpoints/test_checkpoint_concurrency.py new file mode 100644 index 000000000..fc327ffaf --- /dev/null +++ b/tests/func/checkpoints/test_checkpoint_concurrency.py @@ -0,0 +1,217 @@ +import os +import threading +from concurrent.futures import ThreadPoolExecutor + +import pytest + +import datachain as dc +from datachain.catalog import Catalog +from datachain.query.session import Session +from tests.utils import reset_session_job_state + + +def clone_session(session: Session) -> Session: + """ + Create a new session with cloned metastore and warehouse for thread-safe access. 
+ + This is needed for tests that run DataChain operations in threads, as SQLite + connections cannot be shared across threads. For other databases (PostgreSQL, + Clickhouse), cloning ensures each thread has its own connection. + + Args: + session: The session to clone catalog from. + + Returns: + Session: A new session with cloned catalog components. + """ + catalog = session.catalog + thread_metastore = catalog.metastore.clone() + thread_warehouse = catalog.warehouse.clone() + thread_catalog = Catalog(metastore=thread_metastore, warehouse=thread_warehouse) + return Session("TestSession", catalog=thread_catalog) + + +@pytest.fixture(autouse=True) +def mock_is_script_run(monkeypatch): + """Mock is_script_run to return True for stable job names in tests.""" + monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True) + + +def test_threading_disables_checkpoints(test_session_tmpfile, caplog): + test_session = test_session_tmpfile + metastore = test_session.catalog.metastore + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + # -------------- FIRST RUN (main thread) ------------------- + reset_session_job_state() + + dc.read_dataset("nums", session=test_session).save("result1") + + job1 = test_session.get_or_create_job() + checkpoints_main = list(metastore.list_checkpoints(job1.id)) + assert len(checkpoints_main) > 0, "Checkpoint should be created in main thread" + + # -------------- SECOND RUN (in thread) ------------------- + reset_session_job_state() + + thread_ran = {"value": False} + + def run_datachain_in_thread(): + """Run DataChain operation in a thread - checkpoint should NOT be created.""" + thread_session = clone_session(test_session) + try: + thread_ran["value"] = True + dc.read_dataset("nums", session=thread_session).save("result2") + finally: + thread_session.catalog.close() + + thread = threading.Thread(target=run_datachain_in_thread) + thread.start() + thread.join() + + assert thread_ran["value"] is True + + assert any( + "Concurrent thread detected" in record.message for record in caplog.records + ), "Warning about concurrent thread should be logged" + + # Verify no checkpoint was created in thread + job2 = test_session.get_or_create_job() + checkpoints_thread = list(metastore.list_checkpoints(job2.id)) + assert len(checkpoints_thread) == 0, "No checkpoints should be created in thread" + + +def test_threading_with_executor(test_session_tmpfile, caplog): + test_session = test_session_tmpfile + metastore = test_session.catalog.metastore + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + # -------------- FIRST RUN (main thread) ------------------- + reset_session_job_state() + dc.read_dataset("nums", session=test_session).save("before_threading") + + job1 = test_session.get_or_create_job() + checkpoints_before = len(list(metastore.list_checkpoints(job1.id))) + assert checkpoints_before > 0, "Checkpoint should be created before threading" + + # -------------- SECOND RUN (in thread pool) ------------------- + reset_session_job_state() + + def worker(i): + """Worker function that runs DataChain operations in thread pool.""" + thread_session = clone_session(test_session) + try: + dc.read_dataset("nums", session=thread_session).save(f"result_{i}") + finally: + thread_session.catalog.close() + + with ThreadPoolExecutor(max_workers=3) as executor: + list(executor.map(worker, range(3))) + + assert any( + "Concurrent thread detected" in record.message for record in caplog.records + ), "Warning should be logged 
when using thread pool" + + job2 = test_session.get_or_create_job() + checkpoints_after = len(list(metastore.list_checkpoints(job2.id))) + assert checkpoints_after == 0, "No checkpoints should be created in thread pool" + + +def test_multiprocessing_disables_checkpoints(test_session, monkeypatch): + catalog = test_session.catalog + metastore = catalog.metastore + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + # -------------- FIRST RUN (main process) ------------------- + reset_session_job_state() + dc.read_dataset("nums", session=test_session).save("main_result") + + job1 = test_session.get_or_create_job() + checkpoints_main = list(metastore.list_checkpoints(job1.id)) + assert len(checkpoints_main) > 0, "Checkpoint should be created in main process" + + # -------------- SECOND RUN (simulated subprocess) ------------------- + reset_session_job_state() + + # Simulate being in a subprocess by setting DATACHAIN_MAIN_PROCESS_PID + # to a different PID than the current one + monkeypatch.setenv("DATACHAIN_MAIN_PROCESS_PID", str(os.getpid() + 1000)) + + # Run DataChain operation - checkpoint should NOT be created + dc.read_dataset("nums", session=test_session).save("subprocess_result") + + job2 = test_session.get_or_create_job() + checkpoints_subprocess = list(metastore.list_checkpoints(job2.id)) + assert len(checkpoints_subprocess) == 0, ( + "No checkpoints should be created in subprocess" + ) + + +def test_checkpoint_reuse_after_threading(test_session_tmpfile): + test_session = test_session_tmpfile + metastore = test_session.catalog.metastore + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + # -------------- FIRST RUN (creates checkpoints) ------------------- + reset_session_job_state() + dc.read_dataset("nums", session=test_session).save("result1") + dc.read_dataset("nums", session=test_session).save("result2") + + job1 = test_session.get_or_create_job() + checkpoints_initial = len(list(metastore.list_checkpoints(job1.id))) + assert checkpoints_initial > 0, "Checkpoints should be created initially" + + # Run something in a thread (disables checkpoints globally) + def thread_work(): + thread_session = clone_session(test_session) + try: + dc.read_dataset("nums", session=thread_session).save("thread_result") + finally: + thread_session.catalog.close() + + thread = threading.Thread(target=thread_work) + thread.start() + thread.join() + + # No new checkpoints should have been created in thread + assert len(list(metastore.list_checkpoints(job1.id))) == checkpoints_initial + + # -------------- SECOND RUN (new job, after threading) ------------------- + reset_session_job_state() + dc.read_dataset("nums", session=test_session).save("new_result") + + job2 = test_session.get_or_create_job() + checkpoints_new_job = list(metastore.list_checkpoints(job2.id)) + assert len(checkpoints_new_job) > 0, "New job should create checkpoints normally" + + +def test_warning_shown_once(test_session_tmpfile, caplog): + test_session = test_session_tmpfile + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + reset_session_job_state() + + def run_multiple_operations(): + """Run multiple DataChain operations in a thread.""" + thread_session = clone_session(test_session) + try: + # Each operation would check checkpoints_enabled() + dc.read_dataset("nums", session=thread_session).save("result1") + dc.read_dataset("nums", session=thread_session).save("result2") + dc.read_dataset("nums", session=thread_session).save("result3") + finally: 
+ thread_session.catalog.close() + + thread = threading.Thread(target=run_multiple_operations) + thread.start() + thread.join() + + warning_count = sum( + 1 for record in caplog.records if "Concurrent thread detected" in record.message + ) + + assert warning_count == 1, "Warning should be shown only once per process" diff --git a/tests/func/checkpoints/test_checkpoint_events.py b/tests/func/checkpoints/test_checkpoint_events.py new file mode 100644 index 000000000..ac2007e1a --- /dev/null +++ b/tests/func/checkpoints/test_checkpoint_events.py @@ -0,0 +1,365 @@ +from collections.abc import Iterator + +import pytest + +import datachain as dc +from datachain.checkpoint_event import CheckpointEventType, CheckpointStepType +from datachain.dataset import create_dataset_full_name +from tests.utils import reset_session_job_state + + +@pytest.fixture(autouse=True) +def mock_is_script_run(monkeypatch): + monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True) + + +@pytest.fixture +def nums_dataset(test_session): + return dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + +def get_events(metastore, job_id): + return list(metastore.get_checkpoint_events(job_id=job_id)) + + +def get_udf_events(metastore, job_id): + return [ + e + for e in get_events(metastore, job_id) + if e.event_type + in ( + CheckpointEventType.UDF_SKIPPED, + CheckpointEventType.UDF_CONTINUED, + CheckpointEventType.UDF_FROM_SCRATCH, + ) + ] + + +def get_dataset_events(metastore, job_id): + return [ + e + for e in get_events(metastore, job_id) + if e.event_type + in ( + CheckpointEventType.DATASET_SAVE_SKIPPED, + CheckpointEventType.DATASET_SAVE_COMPLETED, + ) + ] + + +def test_map_from_scratch_event(test_session, nums_dataset): + metastore = test_session.catalog.metastore + + def double(num) -> int: + return num * 2 + + reset_session_job_state() + dc.read_dataset("nums", session=test_session).map(doubled=double, output=int).save( + "doubled" + ) + job_id = test_session.get_or_create_job().id + + events = get_udf_events(metastore, job_id) + assert len(events) == 1 + + map_event = events[0] + assert map_event.event_type == CheckpointEventType.UDF_FROM_SCRATCH + assert map_event.step_type == CheckpointStepType.UDF_MAP + assert map_event.udf_name == "double" + assert map_event.rows_input == 6 + assert map_event.rows_processed == 6 + assert map_event.rows_output == 6 + assert map_event.rows_input_reused == 0 + assert map_event.rows_output_reused == 0 + assert map_event.rerun_from_job_id is None + assert map_event.hash_partial is None + + +def test_gen_from_scratch_event(test_session, nums_dataset): + metastore = test_session.catalog.metastore + + def duplicate(num) -> Iterator[int]: + yield num + yield num + + reset_session_job_state() + dc.read_dataset("nums", session=test_session).gen(dup=duplicate, output=int).save( + "duplicated" + ) + job_id = test_session.get_or_create_job().id + + events = get_udf_events(metastore, job_id) + gen_event = next(e for e in events if e.udf_name == "duplicate") + + assert gen_event.event_type == CheckpointEventType.UDF_FROM_SCRATCH + assert gen_event.step_type == CheckpointStepType.UDF_GEN + assert gen_event.rows_input == 6 + assert gen_event.rows_processed == 6 + assert gen_event.rows_output == 12 + assert gen_event.rows_input_reused == 0 + assert gen_event.rows_output_reused == 0 + + +def test_map_skipped_event(test_session, nums_dataset): + metastore = test_session.catalog.metastore + + def double(num) -> int: + return num * 2 + + chain = 
dc.read_dataset("nums", session=test_session).map( + doubled=double, output=int + ) + + reset_session_job_state() + chain.save("doubled") + first_job_id = test_session.get_or_create_job().id + + reset_session_job_state() + chain.save("doubled2") + second_job_id = test_session.get_or_create_job().id + + events = get_udf_events(metastore, second_job_id) + assert len(events) == 1 + + map_event = events[0] + assert map_event.event_type == CheckpointEventType.UDF_SKIPPED + assert map_event.udf_name == "double" + assert map_event.rows_input == 6 + assert map_event.rows_processed == 0 + assert map_event.rows_output == 0 + assert map_event.rows_input_reused == 6 + assert map_event.rows_output_reused == 6 + assert map_event.rerun_from_job_id == first_job_id + assert map_event.hash_partial is None + + +def test_gen_skipped_event(test_session, nums_dataset): + metastore = test_session.catalog.metastore + + def duplicate(num) -> Iterator[int]: + yield num + yield num + + chain = dc.read_dataset("nums", session=test_session).gen(dup=duplicate, output=int) + + reset_session_job_state() + chain.save("duplicated") + first_job_id = test_session.get_or_create_job().id + + reset_session_job_state() + chain.save("duplicated2") + second_job_id = test_session.get_or_create_job().id + + events = get_udf_events(metastore, second_job_id) + gen_event = next(e for e in events if e.udf_name == "duplicate") + + assert gen_event.event_type == CheckpointEventType.UDF_SKIPPED + assert gen_event.rows_input == 6 + assert gen_event.rows_processed == 0 + assert gen_event.rows_output == 0 + assert gen_event.rows_input_reused == 6 + assert gen_event.rows_output_reused == 12 + assert gen_event.rerun_from_job_id == first_job_id + + +def test_map_continued_event(test_session, nums_dataset): + metastore = test_session.catalog.metastore + processed = [] + + def buggy_double(num) -> int: + if len(processed) >= 3: + raise Exception("Simulated failure") + processed.append(num) + return num * 2 + + chain = dc.read_dataset("nums", session=test_session).map( + doubled=buggy_double, output=int + ) + + reset_session_job_state() + with pytest.raises(Exception, match="Simulated failure"): + chain.save("doubled") + first_job_id = test_session.get_or_create_job().id + + reset_session_job_state() + processed.clear() + + def fixed_double(num) -> int: + processed.append(num) + return num * 2 + + dc.read_dataset("nums", session=test_session).map( + doubled=fixed_double, output=int + ).save("doubled") + second_job_id = test_session.get_or_create_job().id + + events = get_udf_events(metastore, second_job_id) + map_event = next(e for e in events if e.udf_name == "fixed_double") + + assert map_event.event_type == CheckpointEventType.UDF_CONTINUED + assert map_event.rows_input == 6 + assert map_event.rows_input_reused == 3 + assert map_event.rows_output_reused == 3 + assert map_event.rows_processed == 3 + assert map_event.rows_output == 3 + assert map_event.rerun_from_job_id == first_job_id + assert map_event.hash_partial is not None + + +def test_gen_continued_event(test_session, nums_dataset): + metastore = test_session.catalog.metastore + processed = [] + + def buggy_gen(num) -> Iterator[int]: + if len(processed) >= 2: + raise Exception("Simulated failure") + processed.append(num) + yield num + yield num * 10 + + chain = dc.read_dataset("nums", session=test_session).gen( + result=buggy_gen, output=int + ) + + reset_session_job_state() + with pytest.raises(Exception, match="Simulated failure"): + chain.save("results") + first_job_id = 
test_session.get_or_create_job().id + + reset_session_job_state() + processed.clear() + + def fixed_gen(num) -> Iterator[int]: + processed.append(num) + yield num + yield num * 10 + + dc.read_dataset("nums", session=test_session).gen( + result=fixed_gen, output=int + ).save("results") + second_job_id = test_session.get_or_create_job().id + + events = get_udf_events(metastore, second_job_id) + gen_event = next(e for e in events if e.udf_name == "fixed_gen") + + assert gen_event.event_type == CheckpointEventType.UDF_CONTINUED + assert gen_event.rows_input == 6 + assert gen_event.rows_input_reused == 2 + assert gen_event.rows_output_reused == 4 + assert gen_event.rows_processed == 4 + assert gen_event.rows_output == 8 + assert gen_event.rerun_from_job_id == first_job_id + assert gen_event.hash_partial is not None + + +def test_dataset_save_completed_event(test_session, nums_dataset): + metastore = test_session.catalog.metastore + + reset_session_job_state() + dc.read_dataset("nums", session=test_session).save("nums_copy") + job_id = test_session.get_or_create_job().id + + events = get_dataset_events(metastore, job_id) + + assert len(events) == 1 + event = events[0] + + expected_name = create_dataset_full_name( + metastore.default_namespace_name, + metastore.default_project_name, + "nums_copy", + "1.0.0", + ) + assert event.event_type == CheckpointEventType.DATASET_SAVE_COMPLETED + assert event.step_type == CheckpointStepType.DATASET_SAVE + assert event.dataset_name == expected_name + assert event.checkpoint_hash is not None + + +def test_dataset_save_skipped_event(test_session, nums_dataset): + metastore = test_session.catalog.metastore + chain = dc.read_dataset("nums", session=test_session) + + reset_session_job_state() + chain.save("nums_copy") + first_job_id = test_session.get_or_create_job().id + + first_events = get_dataset_events(metastore, first_job_id) + assert len(first_events) == 1 + assert first_events[0].event_type == CheckpointEventType.DATASET_SAVE_COMPLETED + + reset_session_job_state() + chain.save("nums_copy") + second_job_id = test_session.get_or_create_job().id + + second_events = get_dataset_events(metastore, second_job_id) + + assert len(second_events) == 1 + event = second_events[0] + + expected_name = create_dataset_full_name( + metastore.default_namespace_name, + metastore.default_project_name, + "nums_copy", + "1.0.0", + ) + assert event.event_type == CheckpointEventType.DATASET_SAVE_SKIPPED + assert event.step_type == CheckpointStepType.DATASET_SAVE + assert event.dataset_name == expected_name + assert event.rerun_from_job_id is not None + + +def test_events_by_run_group(test_session, monkeypatch, nums_dataset): + metastore = test_session.catalog.metastore + + def double(num) -> int: + return num * 2 + + reset_session_job_state() + dc.read_dataset("nums", session=test_session).map(doubled=double, output=int).save( + "doubled" + ) + first_job = test_session.get_or_create_job() + + reset_session_job_state() + second_job_id = metastore.create_job( + "test-job", + "echo 1", + rerun_from_job_id=first_job.id, + run_group_id=first_job.run_group_id, + ) + monkeypatch.setenv("DATACHAIN_JOB_ID", second_job_id) + + dc.read_dataset("nums", session=test_session).map(doubled=double, output=int).save( + "doubled2" + ) + + run_group_events = list( + metastore.get_checkpoint_events(run_group_id=first_job.run_group_id) + ) + + job_ids = {e.job_id for e in run_group_events} + assert first_job.id in job_ids + assert second_job_id in job_ids + + +def 
test_hash_fields_populated(test_session, nums_dataset): + metastore = test_session.catalog.metastore + + def double(num) -> int: + return num * 2 + + reset_session_job_state() + dc.read_dataset("nums", session=test_session).map(doubled=double, output=int).save( + "doubled" + ) + job_id = test_session.get_or_create_job().id + + events = get_udf_events(metastore, job_id) + + for event in events: + assert event.checkpoint_hash is not None + assert event.hash_input is not None + assert event.hash_output is not None + if event.event_type == CheckpointEventType.UDF_FROM_SCRATCH: + assert event.hash_partial is None diff --git a/tests/func/checkpoints/test_checkpoint_invalidation.py b/tests/func/checkpoints/test_checkpoint_invalidation.py new file mode 100644 index 000000000..9c0080ec3 --- /dev/null +++ b/tests/func/checkpoints/test_checkpoint_invalidation.py @@ -0,0 +1,281 @@ +from collections.abc import Iterator + +import pytest + +import datachain as dc +from tests.utils import reset_session_job_state + + +@pytest.fixture(autouse=True) +def mock_is_script_run(monkeypatch): + monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True) + + +@pytest.fixture +def nums_dataset(test_session): + return dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + +def test_udf_code_change_triggers_rerun(test_session, monkeypatch): + map1_calls = [] + map2_calls = [] + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + chain = dc.read_dataset("nums", session=test_session).settings(batch_size=2) + + # Run 1: map1 succeeds, map2 fails + def mapper1_v1(num: int) -> int: + map1_calls.append(num) + return num * 2 + + def mapper2_failing(doubled: int) -> int: + if len(map2_calls) >= 3: + raise Exception("Map2 failure") + map2_calls.append(doubled) + return doubled * 3 + + reset_session_job_state() + with pytest.raises(Exception, match="Map2 failure"): + (chain.map(doubled=mapper1_v1).map(tripled=mapper2_failing).save("results")) + + assert len(map1_calls) == 6 # All processed + assert len(map2_calls) == 3 # Processed 3 before failing + + # Run 2: Change map1 code, map2 fixed - both should rerun + def mapper1_v2(num: int) -> int: + map1_calls.append(num) + return num * 2 + 1 # Different code = different hash + + def mapper2_fixed(doubled: int) -> int: + map2_calls.append(doubled) + return doubled * 3 + + map1_calls.clear() + map2_calls.clear() + reset_session_job_state() + (chain.map(doubled=mapper1_v2).map(tripled=mapper2_fixed).save("results")) + + assert len(map1_calls) == 6 # Reran due to code change + assert len(map2_calls) == 6 # Ran all (no partial to continue from) + result = dc.read_dataset("results", session=test_session).to_list("tripled") + # nums [1,2,3,4,5,6] → x2+1 = [3,5,7,9,11,13] → x3 = [9,15,21,27,33,39] + assert sorted(result) == sorted([(i,) for i in [9, 15, 21, 27, 33, 39]]) + + # Run 3: Keep both unchanged - both should skip + map1_calls.clear() + map2_calls.clear() + reset_session_job_state() + (chain.map(doubled=mapper1_v2).map(tripled=mapper2_fixed).save("results")) + + assert len(map1_calls) == 0 # Skipped (checkpoint found) + assert len(map2_calls) == 0 # Skipped (checkpoint found) + result = dc.read_dataset("results", session=test_session).to_list("tripled") + assert sorted(result) == sorted([(i,) for i in [9, 15, 21, 27, 33, 39]]) + + +def test_generator_output_schema_change_triggers_rerun(test_session, monkeypatch): + processed_nums_v1 = [] + processed_nums_v2 = [] + + dc.read_values(num=[1, 2, 3, 4, 5, 6], 
session=test_session).save("nums") + + # -------------- FIRST RUN (INT OUTPUT, FAILS) ------------------- + def generator_v1_int(num) -> Iterator[int]: + """Generator version 1: yields int, fails on num=4.""" + processed_nums_v1.append(num) + if num == 4: + raise Exception(f"Simulated failure on num={num}") + yield num * 10 + yield num * num + + reset_session_job_state() + + chain = dc.read_dataset("nums", session=test_session).settings(batch_size=2) + + with pytest.raises(Exception, match="Simulated failure"): + chain.gen(result=generator_v1_int, output=int).save("gen_results") + + assert len(processed_nums_v1) > 0 + + # -------------- SECOND RUN (STR OUTPUT, DIFFERENT SCHEMA) ------------------- + def generator_v2_str(num) -> Iterator[str]: + """Generator version 2: yields str instead of int (schema change!).""" + processed_nums_v2.append(num) + yield f"value_{num * 10}" + yield f"square_{num * num}" + + reset_session_job_state() + + # Use generator with different output type - should run from scratch + chain.gen(result=generator_v2_str, output=str).save("gen_results") + + # Verify ALL inputs were processed in second run (not continuing from partial) + assert sorted(processed_nums_v2) == sorted([1, 2, 3, 4, 5, 6]), ( + "All inputs should be processed when schema changes" + ) + + result = sorted( + dc.read_dataset("gen_results", session=test_session).to_list("result") + ) + expected = sorted( + [ + ("square_1",), + ("value_10",), # num=1 + ("square_4",), + ("value_20",), # num=2 + ("square_9",), + ("value_30",), # num=3 + ("square_16",), + ("value_40",), # num=4 + ("square_25",), + ("value_50",), # num=5 + ("square_36",), + ("value_60",), # num=6 + ] + ) + assert result == expected + + +def test_mapper_output_schema_change_triggers_rerun(test_session, monkeypatch): + processed_nums_v1 = [] + processed_nums_v2 = [] + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + # -------------- FIRST RUN (INT OUTPUT, FAILS) ------------------- + def mapper_v1_int(num) -> int: + processed_nums_v1.append(num) + if num == 4: + raise Exception(f"Simulated failure on num={num}") + return num * 10 + + reset_session_job_state() + + chain = dc.read_dataset("nums", session=test_session).settings(batch_size=2) + + with pytest.raises(Exception, match="Simulated failure"): + chain.map(result=mapper_v1_int, output=int).save("map_results") + + assert len(processed_nums_v1) > 0 + + # -------------- SECOND RUN (STR OUTPUT, DIFFERENT SCHEMA) ------------------- + def mapper_v2_str(num) -> str: + """Mapper version 2: returns str instead of int (schema change!).""" + processed_nums_v2.append(num) + return f"value_{num * 10}" + + reset_session_job_state() + + # Use mapper with different output type - should run from scratch + chain.map(result=mapper_v2_str, output=str).save("map_results") + + # Verify ALL inputs were processed in second run (not continuing from partial) + assert sorted(processed_nums_v2) == sorted([1, 2, 3, 4, 5, 6]), ( + "All inputs should be processed when schema changes" + ) + + result = sorted( + dc.read_dataset("map_results", session=test_session).to_list("result") + ) + expected = sorted( + [ + ("value_10",), # num=1 + ("value_20",), # num=2 + ("value_30",), # num=3 + ("value_40",), # num=4 + ("value_50",), # num=5 + ("value_60",), # num=6 + ] + ) + assert result == expected + + +@pytest.mark.parametrize( + "batch_size,fail_after_count", + [ + (2, 2), # batch_size=2: Fail after processing 2 partitions + (3, 2), # batch_size=3: Fail after processing 2 partitions + (10, 
2), # batch_size=10: Fail after processing 2 partitions + ], +) +def test_aggregator_always_runs_from_scratch( + test_session, + monkeypatch, + nums_dataset, + batch_size, + fail_after_count, +): + processed_partitions = [] + + def buggy_aggregator(letter, num) -> Iterator[tuple[str, int]]: + """ + Buggy aggregator that fails before processing the (fail_after_count+1)th + partition. + letter: partition key value (A, B, or C) + num: iterator of num values in that partition + """ + if len(processed_partitions) >= fail_after_count: + raise Exception( + f"Simulated failure after {len(processed_partitions)} partitions" + ) + nums_list = list(num) + processed_partitions.append(nums_list) + # Yield tuple of (letter, sum) to preserve partition key in output + yield letter[0], sum(n for n in nums_list) + + def fixed_aggregator(letter, num) -> Iterator[tuple[str, int]]: + nums_list = list(num) + processed_partitions.append(nums_list) + # Yield tuple of (letter, sum) to preserve partition key in output + yield letter[0], sum(n for n in nums_list) + + nums_data = [1, 2, 3, 4, 5, 6] + letters_data = ["A", "A", "B", "B", "C", "C"] + dc.read_values(num=nums_data, letter=letters_data, session=test_session).save( + "nums_letters" + ) + + # -------------- FIRST RUN (FAILS WITH BUGGY AGGREGATOR) ------------------- + reset_session_job_state() + + chain = dc.read_dataset("nums_letters", session=test_session).settings( + batch_size=batch_size + ) + + with pytest.raises(Exception, match="Simulated failure after"): + chain.agg( + total=buggy_aggregator, + partition_by="letter", + ).save("agg_results") + + first_run_count = len(processed_partitions) + + assert first_run_count == fail_after_count + + # -------------- SECOND RUN (FIXED AGGREGATOR) ------------------- + reset_session_job_state() + + processed_partitions.clear() + + chain.agg( + total=fixed_aggregator, + partition_by="letter", + ).save("agg_results") + + second_run_count = len(processed_partitions) + + assert sorted( + dc.read_dataset("agg_results", session=test_session).to_list( + "total_0", "total_1" + ) + ) == sorted( + [ + ("A", 3), # group A: 1 + 2 = 3 + ("B", 7), # group B: 3 + 4 = 7 + ("C", 11), # group C: 5 + 6 = 11 + ] + ) + + # should re-process everything + assert second_run_count == 3 diff --git a/tests/func/checkpoints/test_checkpoint_job_linking.py b/tests/func/checkpoints/test_checkpoint_job_linking.py new file mode 100644 index 000000000..e354af97a --- /dev/null +++ b/tests/func/checkpoints/test_checkpoint_job_linking.py @@ -0,0 +1,199 @@ +import pytest +import sqlalchemy as sa + +import datachain as dc +from datachain.error import JobAncestryDepthExceededError +from tests.utils import reset_session_job_state + + +@pytest.fixture(autouse=True) +def mock_is_script_run(monkeypatch): + monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True) + + +@pytest.fixture +def nums_dataset(test_session): + return dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + +def get_dataset_versions_for_job(metastore, job_id): + """ + Returns: + List of tuples (dataset_name, version, is_creator) + """ + query = ( + sa.select( + metastore._datasets_versions.c.dataset_id, + metastore._datasets_versions.c.version, + metastore._dataset_version_jobs.c.is_creator, + ) + .select_from( + metastore._dataset_version_jobs.join( + metastore._datasets_versions, + metastore._dataset_version_jobs.c.dataset_version_id + == metastore._datasets_versions.c.id, + ) + ) + .where(metastore._dataset_version_jobs.c.job_id == job_id) + ) + 
+ results = list(metastore.db.execute(query)) + + dataset_versions = [] + for dataset_id, version, is_creator in results: + dataset_query = sa.select(metastore._datasets.c.name).where( + metastore._datasets.c.id == dataset_id + ) + dataset_name = next(metastore.db.execute(dataset_query))[0] + dataset_versions.append((dataset_name, version, bool(is_creator))) + + return sorted(dataset_versions) + + +def test_dataset_job_linking(test_session, monkeypatch, nums_dataset): + """Test that dataset versions are correctly linked to jobs via many-to-many. + + This test verifies that datasets should appear in ALL jobs that use them in + the single job "chain", not just the job that created them. + """ + catalog = test_session.catalog + metastore = catalog.metastore + monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(False)) + + chain = dc.read_dataset("nums", session=test_session) + + # -------------- FIRST RUN: Create dataset ------------------- + reset_session_job_state() + chain.save("nums_linked") + job1_id = test_session.get_or_create_job().id + + # Verify job1 has the dataset associated (as creator) + job1_datasets = get_dataset_versions_for_job(metastore, job1_id) + assert len(job1_datasets) == 1 + assert job1_datasets[0] == ("nums_linked", "1.0.0", True) + + # -------------- SECOND RUN: Reuse dataset via checkpoint ------------------- + reset_session_job_state() + chain.save("nums_linked") + job2_id = test_session.get_or_create_job().id + + # Verify job2 also has the dataset associated (not creator) + job2_datasets = get_dataset_versions_for_job(metastore, job2_id) + assert len(job2_datasets) == 1 + assert job2_datasets[0] == ("nums_linked", "1.0.0", False) + + # Verify job1 still has it + job1_datasets = get_dataset_versions_for_job(metastore, job1_id) + assert len(job1_datasets) == 1 + assert job1_datasets[0][2] # still creator + + # -------------- THIRD RUN: Another reuse ------------------- + reset_session_job_state() + chain.save("nums_linked") + job3_id = test_session.get_or_create_job().id + + # Verify job3 also has the dataset associated (not creator) + job3_datasets = get_dataset_versions_for_job(metastore, job3_id) + assert len(job3_datasets) == 1 + assert job3_datasets[0] == ("nums_linked", "1.0.0", False) + + # Verify get_dataset_version_for_job_ancestry works correctly + dataset = catalog.get_dataset("nums_linked") + found_version = metastore.get_dataset_version_for_job_ancestry( + "nums_linked", + dataset.project.namespace.name, + dataset.project.name, + job3_id, + ) + assert found_version.version == "1.0.0" + + +def test_dataset_job_linking_with_reset(test_session, monkeypatch, nums_dataset): + catalog = test_session.catalog + metastore = catalog.metastore + monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(True)) + + chain = dc.read_dataset("nums", session=test_session) + + # -------------- FIRST RUN ------------------- + reset_session_job_state() + chain.save("nums_reset") + job1_id = test_session.get_or_create_job().id + + # Verify job1 created version 1.0.0 + job1_datasets = get_dataset_versions_for_job(metastore, job1_id) + assert len(job1_datasets) == 1 + assert job1_datasets[0] == ("nums_reset", "1.0.0", True) + + # -------------- SECOND RUN ------------------- + reset_session_job_state() + chain.save("nums_reset") + job2_id = test_session.get_or_create_job().id + + job2_datasets = get_dataset_versions_for_job(metastore, job2_id) + assert len(job2_datasets) == 1 + assert job2_datasets[0] == ("nums_reset", "1.0.1", True) + + job1_datasets = 
get_dataset_versions_for_job(metastore, job1_id) + assert len(job1_datasets) == 1 + assert job1_datasets[0] == ("nums_reset", "1.0.0", True) + + +def test_dataset_version_job_id_updates_to_latest( + test_session, monkeypatch, nums_dataset +): + catalog = test_session.catalog + monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(False)) + + chain = dc.read_dataset("nums", session=test_session) + name = "nums_jobid" + + # -------------- FIRST RUN ------------------- + reset_session_job_state() + chain.save(name) + job1_id = test_session.get_or_create_job().id + + dataset = catalog.get_dataset(name) + assert dataset.get_version(dataset.latest_version).job_id == job1_id + + # -------------- SECOND RUN: Reuse via checkpoint ------------------- + reset_session_job_state() + chain.save(name) + job2_id = test_session.get_or_create_job().id + + # job_id should now point to job2 (latest) + dataset = catalog.get_dataset(name) + assert dataset.get_version(dataset.latest_version).job_id == job2_id + + # -------------- THIRD RUN: Another reuse ------------------- + reset_session_job_state() + chain.save(name) + job3_id = test_session.get_or_create_job().id + + # job_id should now point to job3 (latest) + dataset = catalog.get_dataset(name) + assert dataset.get_version(dataset.latest_version).job_id == job3_id + + +def test_job_ancestry_depth_exceeded(test_session, monkeypatch, nums_dataset): + from datachain.data_storage import metastore + + monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(False)) + # Mock max depth to a small value (3) for testing + monkeypatch.setattr(metastore, "JOB_ANCESTRY_MAX_DEPTH", 3) + + chain = dc.read_dataset("nums", session=test_session) + + max_attempts = 10 # Safety limit to prevent infinite loop + for _ in range(max_attempts): + reset_session_job_state() + try: + chain.save("nums_depth") + except JobAncestryDepthExceededError as exc_info: + assert "too deep" in str(exc_info) + assert "from scratch" in str(exc_info) + # Test passed - we hit the max depth + return + + # If we get here, we never hit the max depth error + pytest.fail(f"Expected JobAncestryDepthExceededError after {max_attempts} saves") diff --git a/tests/func/checkpoints/test_checkpoint_parallel.py b/tests/func/checkpoints/test_checkpoint_parallel.py new file mode 100644 index 000000000..59295edde --- /dev/null +++ b/tests/func/checkpoints/test_checkpoint_parallel.py @@ -0,0 +1,171 @@ +from collections.abc import Iterator + +import pytest + +import datachain as dc +from datachain.error import DatasetNotFoundError +from tests.utils import reset_session_job_state + + +@pytest.fixture(autouse=True) +def mock_is_script_run(monkeypatch): + monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True) + + +def test_checkpoints_parallel(test_session_tmpfile, monkeypatch): + def mapper_fail(num) -> int: + raise Exception("Error") + + test_session = test_session_tmpfile + catalog = test_session.catalog + + dc.read_values(num=list(range(1000)), session=test_session).save("nums") + + chain = dc.read_dataset("nums", session=test_session).settings(parallel=True) + + # -------------- FIRST RUN ------------------- + reset_session_job_state() + chain.save("nums1") + chain.save("nums2") + with pytest.raises(RuntimeError): + chain.map(new=mapper_fail).save("nums3") + first_job_id = test_session.get_or_create_job().id + + catalog.get_dataset("nums1") + catalog.get_dataset("nums2") + with pytest.raises(DatasetNotFoundError): + catalog.get_dataset("nums3") + + # -------------- SECOND RUN ------------------- 
+ reset_session_job_state() + chain.save("nums1") + chain.save("nums2") + chain.save("nums3") + second_job_id = test_session.get_or_create_job().id + + assert len(catalog.get_dataset("nums1").versions) == 1 + assert len(catalog.get_dataset("nums2").versions) == 1 + assert len(catalog.get_dataset("nums3").versions) == 1 + + assert len(list(catalog.metastore.list_checkpoints(first_job_id))) == 3 + assert len(list(catalog.metastore.list_checkpoints(second_job_id))) == 3 + + +def test_udf_generator_continue_parallel(test_session_tmpfile, monkeypatch): + test_session = test_session_tmpfile + + processed_nums = [] + run_count = {"count": 0} + + def gen_multiple(num) -> Iterator[int]: + """Generator that yields multiple outputs per input.""" + processed_nums.append(num) + # Fail on input 4 in first run only + if num == 4 and run_count["count"] == 0: + raise Exception(f"Simulated failure on num={num}") + yield num * 10 + yield num + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + # -------------- FIRST RUN (FAILS) ------------------- + reset_session_job_state() + + chain = ( + dc.read_dataset("nums", session=test_session) + .settings(parallel=2, batch_size=2) + .gen(result=gen_multiple, output=int) + ) + + with pytest.raises(RuntimeError): + chain.save("results") + + # -------------- SECOND RUN (CONTINUE) ------------------- + reset_session_job_state() + + processed_nums.clear() + run_count["count"] += 1 + + # Should complete successfully + chain.save("results") + + result = ( + dc.read_dataset("results", session=test_session) + .order_by("result") + .to_list("result") + ) + # Each of 6 inputs yields 2 outputs: [10,1], [20,2], ..., [60,6] + assert result == [ + (1,), + (2,), + (3,), + (4,), + (5,), + (6,), + (10,), + (20,), + (30,), + (40,), + (50,), + (60,), + ] + + # Verify only unprocessed inputs were processed in second run + # (should be less than all 6 inputs) + assert len(processed_nums) < 6 + + +@pytest.mark.parametrize("parallel", [2, 4, 6, 20]) +def test_parallel_checkpoint_recovery_no_duplicates(test_session_tmpfile, parallel): + """Test that parallel checkpoint recovery processes all inputs exactly once. 
+ + Verifies: + - No duplicate outputs in final result + - All inputs produce correct outputs (n^2) + - Correct total number of outputs (100) + """ + test_session = test_session_tmpfile + + # Track run count to fail only on first run + run_count = {"value": 0} + + def gen_square(num) -> Iterator[int]: + # Fail on input 95 during first run only + if num == 95 and run_count["value"] == 0: + raise Exception(f"Simulated failure on num={num}") + + yield num * num + + dc.read_values(num=list(range(1, 101)), session=test_session).save("nums") + reset_session_job_state() + + chain = ( + dc.read_dataset("nums", session=test_session) + .order_by("num") + .settings(parallel=parallel, batch_size=2) + .gen(result=gen_square, output=int) + ) + + # First run - fails on num=95 + with pytest.raises(RuntimeError): + chain.save("results") + + # Second run - should recover and complete + reset_session_job_state() + run_count["value"] += 1 + chain.save("results") + + # Verify: Final result has correct number of outputs and values + result = dc.read_dataset("results", session=test_session).to_list("result") + assert len(result) == 100, f"Expected 100 outputs, got {len(result)}" + + # Verify: No duplicate outputs + output_values = [row[0] for row in result] + assert len(output_values) == len(set(output_values)), ( + "Found duplicate outputs in final result" + ) + + # Verify: All expected outputs present (1^2, 2^2, ..., 100^2) + expected = {i * i for i in range(1, 101)} + actual = set(output_values) + assert actual == expected, f"Outputs don't match. Missing: {expected - actual}" diff --git a/tests/func/checkpoints/test_checkpoint_recovery.py b/tests/func/checkpoints/test_checkpoint_recovery.py new file mode 100644 index 000000000..f9edb862d --- /dev/null +++ b/tests/func/checkpoints/test_checkpoint_recovery.py @@ -0,0 +1,697 @@ +from collections.abc import Iterator + +import pytest + +import datachain as dc +from datachain.lib.file import File +from tests.utils import reset_session_job_state + + +@pytest.fixture(autouse=True) +def mock_is_script_run(monkeypatch): + monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True) + + +@pytest.fixture +def nums_dataset(test_session): + return dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + +@pytest.mark.parametrize( + "batch_size,fail_after_count", + [ + (2, 3), # batch_size=2: Fail after 3 rows + (3, 4), # batch_size=3: Fail after 4 rows + (5, 3), # batch_size=5: Fail after 3 rows + ], +) +def test_udf_signals_continue_from_partial( + test_session_tmpfile, + monkeypatch, + nums_dataset, + batch_size, + fail_after_count, +): + """Test continuing UDF execution from partial output table. + + Tests with different batch sizes to ensure partial results are correctly handled + regardless of batch boundaries. 
+ """ + test_session = test_session_tmpfile + processed_nums = [] + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + def process_buggy(num) -> int: + if len(processed_nums) >= fail_after_count: + raise Exception(f"Simulated failure after {len(processed_nums)} rows") + processed_nums.append(num) + return num * 10 + + chain = dc.read_dataset("nums", session=test_session).settings( + batch_size=batch_size + ) + + # -------------- FIRST RUN (FAILS WITH BUGGY UDF) ------------------- + reset_session_job_state() + + with pytest.raises(Exception, match="Simulated failure after"): + chain.map(result=process_buggy, output=int).save("results") + + assert len(processed_nums) == fail_after_count + + # -------------- SECOND RUN (FIXED UDF) ------------------- + reset_session_job_state() + + processed_nums.clear() + + def process_fixed(num) -> int: + processed_nums.append(num) + return num * 10 + + chain.map(result=process_fixed, output=int).save("results") + + result = dc.read_dataset("results", session=test_session).to_list("result") + assert sorted(result) == [(10,), (20,), (30,), (40,), (50,), (60,)] + + # Second run should process remaining rows (checkpoint continuation working) + assert 0 < len(processed_nums) <= 6 + + +@pytest.mark.parametrize( + "batch_size,fail_after_count", + [ + (2, 2), # batch_size=2: Fail after 2 inputs (4 outputs → 2 batches saved) + (3, 4), # batch_size=3: Fail after 4 inputs + (10, 3), # batch_size=10: Fail after 3 inputs + ], +) +def test_udf_generator_continue_from_partial( + test_session, + monkeypatch, + batch_size, + fail_after_count, +): + processed_nums = [] + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + def buggy_generator(num) -> Iterator[int]: + if len(processed_nums) >= fail_after_count: + raise Exception(f"Simulated failure after {len(processed_nums)} inputs") + processed_nums.append(num) + yield num * 10 + yield num * num + + chain = dc.read_dataset("nums", session=test_session).settings( + batch_size=batch_size + ) + + # -------------- FIRST RUN (FAILS WITH BUGGY GENERATOR) ------------------- + reset_session_job_state() + + with pytest.raises(Exception, match="Simulated failure after"): + chain.gen(value=buggy_generator, output=int).save("gen_results") + + assert len(processed_nums) == fail_after_count + + # -------------- SECOND RUN (FIXED GENERATOR) ------------------- + reset_session_job_state() + + processed_nums.clear() + + def fixed_generator(num) -> Iterator[int]: + processed_nums.append(num) + yield num * 10 + yield num * num + + chain.gen(value=fixed_generator, output=int).save("gen_results") + + result = sorted( + dc.read_dataset("gen_results", session=test_session).to_list("value") + ) + expected = sorted( + [ + (1,), + (10,), + (4,), + (20,), + (9,), + (30,), + (16,), + (40,), + (25,), + (50,), + (36,), + (60,), + ] + ) + + assert result == expected + + # Second run should process remaining inputs (checkpoint continuation working) + assert 0 < len(processed_nums) <= 6 + + +def test_generator_incomplete_input_recovery(test_session): + """Test full recovery flow from incomplete inputs. + + Tests the complete checkpoint recovery mechanism: + 1. First run fails, leaving some inputs incomplete (missing final row) + 2. Second run detects incomplete inputs + 3. Filters out partial results from incomplete inputs + 4. Re-processes incomplete inputs + 5. 
Final results are correct (no duplicates, no missing values) + """ + processed_inputs = [] + run_count = [0] + numbers = [6, 2, 8, 7] + + def gen_multiple(num) -> Iterator[int]: + processed_inputs.append(num) + for i in range(5): + if num == 8 and i == 2 and run_count[0] == 0: + raise Exception("Simulated crash") + yield num * 100 + i + + dc.read_values(num=numbers, session=test_session).save("nums") + + # -------------- FIRST RUN (FAILS) ------------------- + reset_session_job_state() + processed_inputs.clear() + + with pytest.raises(Exception, match="Simulated crash"): + ( + dc.read_dataset("nums", session=test_session) + .order_by("num") + .settings(batch_size=2) # Small batch for partial commits + .gen(result=gen_multiple, output=int) + .save("results") + ) + + # With order_by("num") and batch_size=2, sorted order is [2, 6, 7, 8]: + # - Batch 1: [2, 6] - fully committed before crash + # - Batch 2: [7, 8] - 7 completes but batch crashes on 8, entire batch uncommitted + # Both inputs in the crashed batch need re-processing. + incomplete_batch = [7, 8] + complete_batch = [2, 6] + + # -------------- SECOND RUN (RECOVERS) ------------------- + reset_session_job_state() + processed_inputs.clear() + run_count[0] += 1 # Increment so generator succeeds this time + + ( + dc.read_dataset("nums", session=test_session) + .order_by("num") + .settings(batch_size=2) + .gen(result=gen_multiple, output=int) + .save("results") + ) + + # Verify inputs from crashed batch are re-processed + assert any(inp in processed_inputs for inp in incomplete_batch), ( + f"Inputs from crashed batch {incomplete_batch} should be re-processed, " + f"but only processed: {processed_inputs}" + ) + + # Verify inputs from committed batch are NOT re-processed + # (tests sys__partial flag correctness - complete inputs are correctly skipped) + for inp in complete_batch: + assert inp not in processed_inputs, ( + f"Input {inp} from committed batch should NOT be re-processed, " + f"but was found in processed: {processed_inputs}" + ) + + result = ( + dc.read_dataset("results", session=test_session) + .order_by("result") + .to_list("result") + ) + + expected = sorted([(num * 100 + i,) for num in numbers for i in range(5)]) + actual = sorted(result) + + assert actual == expected, ( + f"Should have all 20 outputs with no duplicates or missing.\n" + f"Expected: {expected}\n" + f"Actual: {actual}" + ) + + # Verify each input has exactly 5 outputs + result_by_input = {} + for (val,) in result: + input_id = val // 100 + result_by_input.setdefault(input_id, []).append(val) + + for input_id in numbers: + assert len(result_by_input.get(input_id, [])) == 5, ( + f"Input {input_id} should have exactly 5 outputs" + ) + + # Verify no duplicates + all_results = [val for (val,) in result] + assert len(all_results) == len(set(all_results)), "Should have no duplicate results" + + +@pytest.mark.xfail( + reason="Known limitation: inputs that yield nothing are not tracked " + "in processed table" +) +def test_generator_yielding_nothing(test_session, monkeypatch, nums_dataset): + """Test that generator correctly handles inputs that yield zero outputs.""" + processed = [] + + def selective_generator(num) -> Iterator[int]: + processed.append(num) + if num == 3: + raise Exception("Simulated failure") + if num % 2 == 0: # Only even numbers yield outputs + yield num * 10 + + # First run - fails on num=3 + reset_session_job_state() + chain = dc.read_dataset("nums", session=test_session).gen( + value=selective_generator, output=int + ) + + with 
pytest.raises(Exception, match="Simulated failure"): + chain.save("results") + + # Second run - should continue from checkpoint + reset_session_job_state() + processed.clear() + chain.save("results") + + # Only inputs 3,4,5,6 should be processed (1,2 were already done) + assert processed == [3, 4, 5, 6] + result = sorted(dc.read_dataset("results", session=test_session).to_list("value")) + assert result == [(20,), (40,), (60,)] + + +def test_empty_dataset_checkpoint(test_session): + processed = [] + + def mapper(num) -> int: + processed.append(num) + return num * 10 + + dc.read_values(num=[], session=test_session).save("empty_nums") + + reset_session_job_state() + chain = dc.read_dataset("empty_nums", session=test_session).map( + result=mapper, output=int + ) + chain.save("results") + + assert len(processed) == 0 + + # Second run should also work (checkpoint reuse with empty result) + reset_session_job_state() + processed.clear() + chain.save("results") + + assert len(processed) == 0 + + result = dc.read_dataset("results", session=test_session).to_list("result") + assert result == [] + + +def test_single_row_dataset_checkpoint(test_session): + processed = [] + run_count = {"value": 0} + + def mapper(num) -> int: + processed.append(num) + if run_count["value"] == 0: + raise Exception("First run failure") + return num * 10 + + dc.read_values(num=[42], session=test_session).save("single_num") + + reset_session_job_state() + chain = ( + dc.read_dataset("single_num", session=test_session) + .settings( + batch_size=10 # Batch size larger than dataset + ) + .map(result=mapper, output=int) + ) + + with pytest.raises(Exception, match="First run failure"): + chain.save("results") + + assert len(processed) == 1 + + reset_session_job_state() + processed.clear() + run_count["value"] += 1 + + chain.save("results") + + result = dc.read_dataset("results", session=test_session).to_list("result") + assert result == [(420,)] + + +def test_multiple_consecutive_failures(test_session): + """Test checkpoint recovery across multiple consecutive failures. + + Scenario: fail at row 3, then fail at row 5, then succeed. + Each run should continue from where the previous one left off. 
+ """ + processed = [] + run_count = {"value": 0} + + def flaky_mapper(num) -> int: + processed.append(num) + if run_count["value"] == 0 and len(processed) >= 3: + raise Exception("First failure at row 3") + if run_count["value"] == 1 and len(processed) >= 3: + raise Exception("Second failure at row 3 (of remaining)") + return num * 10 + + dc.read_values(num=[1, 2, 3, 4, 5, 6, 7, 8], session=test_session).save("nums") + + chain = dc.read_dataset("nums", session=test_session).settings(batch_size=2) + + # -------------- FIRST RUN: Fails after processing 3 rows ------------------- + reset_session_job_state() + + with pytest.raises(Exception, match="First failure"): + chain.map(result=flaky_mapper, output=int).save("results") + + first_run_processed = len(processed) + assert first_run_processed == 3 + + # -------------- SECOND RUN: Continues but fails again ------------------- + reset_session_job_state() + processed.clear() + run_count["value"] += 1 + + with pytest.raises(Exception, match="Second failure"): + chain.map(result=flaky_mapper, output=int).save("results") + + second_run_processed = len(processed) + # Should process some rows (continuing from first run's checkpoint) + assert second_run_processed > 0 + + # -------------- THIRD RUN: Finally succeeds ------------------- + reset_session_job_state() + processed.clear() + run_count["value"] += 1 + + chain.map(result=flaky_mapper, output=int).save("results") + + # Verify final result is correct + result = dc.read_dataset("results", session=test_session).to_list("result") + assert sorted(result) == [(10,), (20,), (30,), (40,), (50,), (60,), (70,), (80,)] + + +def test_generator_multiple_consecutive_failures(test_session): + processed = [] + run_count = {"value": 0} + + def flaky_generator(num) -> Iterator[int]: + processed.append(num) + if run_count["value"] == 0 and num == 3: + raise Exception("First failure on num=3") + if run_count["value"] == 1 and num == 5: + raise Exception("Second failure on num=5") + yield num * 10 + yield num * 100 + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + chain = ( + dc.read_dataset("nums", session=test_session) + .order_by("num") + .settings(batch_size=2) + ) + + # -------------- FIRST RUN: Fails on num=3 ------------------- + reset_session_job_state() + + with pytest.raises(Exception, match="First failure"): + chain.gen(result=flaky_generator, output=int).save("results") + + # -------------- SECOND RUN: Continues but fails on num=5 ------------------- + reset_session_job_state() + processed.clear() + run_count["value"] += 1 + + with pytest.raises(Exception, match="Second failure"): + chain.gen(result=flaky_generator, output=int).save("results") + + # -------------- THIRD RUN: Finally succeeds ------------------- + reset_session_job_state() + processed.clear() + run_count["value"] += 1 + + chain.gen(result=flaky_generator, output=int).save("results") + + # Verify final result is correct (each input produces 2 outputs) + result = dc.read_dataset("results", session=test_session).to_list("result") + expected = [(i * 10,) for i in range(1, 7)] + [(i * 100,) for i in range(1, 7)] + assert sorted(result) == sorted(expected) + + # Verify no duplicates + values = [r[0] for r in result] + assert len(values) == len(set(values)) + + +def test_multiple_udf_chain_continue(test_session): + """Test continuing from partial with multiple UDFs in chain. + + When mapper fails, only mapper's partial table exists. On retry, mapper + completes and gen runs from scratch. 
+ """ + map_processed = [] + gen_processed = [] + fail_once = [True] # Mutable flag to track if we should fail + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + + def mapper(num: int) -> int: + map_processed.append(num) + # Fail before processing the 4th row in first run only + if fail_once[0] and len(map_processed) == 3: + fail_once[0] = False + raise Exception("Map failure") + return num * 2 + + def doubler(doubled) -> Iterator[int]: + gen_processed.append(doubled) + yield doubled + yield doubled + + # First run - fails in mapper + # batch_size=2: processes [1,2] (commits), then [3,4] (fails on 4) + reset_session_job_state() + chain = ( + dc.read_dataset("nums", session=test_session) + .settings(batch_size=2) + .map(doubled=mapper) + .gen(value=doubler, output=int) + ) + + with pytest.raises(Exception, match="Map failure"): + chain.save("results") + + # Second run - completes successfully + # Mapper continues from partial checkpoint + reset_session_job_state() + chain.save("results") + + # Verify mapper processed some rows (continuation working) + # First run: 3 rows attempted + # Second run: varies by warehouse (0-6 rows depending on batching/buffer behavior) + # Total: 6-9 calls (some rows may be reprocessed if not saved to partial) + assert 6 <= len(map_processed) <= 9, "Expected 6-9 total mapper calls" + + assert len(gen_processed) == 6 + + result = sorted(dc.read_dataset("results", session=test_session).to_list("value")) + assert sorted([v[0] for v in result]) == sorted( + [2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12] + ) + + +def test_file_udf_continue_from_partial(test_session, tmp_dir): + """Test checkpoint continuation with File objects (file downloading UDFs). + + Ensures that File objects are correctly reconstructed from the checkpoint's + input table on the second run (regression test for bytes vs str path issue). 
+ """ + # Create test files + file_names = [f"file_{i}.txt" for i in range(6)] + for name in file_names: + (tmp_dir / name).write_text(f"content of {name}", encoding="utf-8") + + processed_files = [] + + def process_file(file: File) -> int: + if len(processed_files) >= 3: + raise Exception("Simulated failure after 3 files") + data = file.read() + processed_files.append(file.path) + return len(data) + + chain = ( + dc.read_storage(tmp_dir.as_uri(), session=test_session) + .order_by("file.path") + .settings(batch_size=2) + ) + + # -------------- FIRST RUN (FAILS AFTER 3 FILES) ------------------- + reset_session_job_state() + + with pytest.raises(Exception, match="Simulated failure after 3 files"): + chain.map(file_size=process_file).save("file_results") + + assert len(processed_files) == 3 + + # -------------- SECOND RUN (CONTINUES FROM CHECKPOINT) ------------------- + reset_session_job_state() + processed_files.clear() + + def process_file_fixed(file: File) -> int: + data = file.read() + processed_files.append(file.path) + return len(data) + + chain.map(file_size=process_file_fixed).save("file_results") + + result = dc.read_dataset("file_results", session=test_session).to_list("file_size") + assert len(result) == 6 + + # Second run should only process remaining files + assert 0 < len(processed_files) <= 6 + + +def test_skip_udf_fallback_when_output_table_missing(test_session): + call_count = {"value": 0} + + def mapper(num) -> int: + call_count["value"] += 1 + return num * 10 + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + chain = dc.read_dataset("nums", session=test_session).map(result=mapper, output=int) + + # -------------- FIRST RUN ------------------- + reset_session_job_state() + assert chain.count() == 6 + assert call_count["value"] == 6 + + catalog = test_session.catalog + + # Drop all UDF output tables from first run + for table_name in catalog.warehouse.db.list_tables(prefix="udf_"): + if "_output" in table_name and "_partial" not in table_name: + table = catalog.warehouse.db.get_table(table_name) + catalog.warehouse.db.drop_table(table, if_exists=True) + + # -------------- SECOND RUN ------------------- + reset_session_job_state() + call_count["value"] = 0 + + result = chain.order_by("num").to_list("result") + + # UDF should have been re-executed from scratch (fallback from skip) + assert call_count["value"] == 6 + assert result == [(10,), (20,), (30,), (40,), (50,), (60,)] + + +def test_continue_udf_fallback_when_partial_table_missing(test_session): + fail_flag = [True] + + def mapper(num) -> int: + if fail_flag[0] and num >= 4: + raise RuntimeError("Simulated failure") + return num * 10 + + dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") + chain = dc.read_dataset("nums", session=test_session).settings(batch_size=2) + + # -------------- FIRST RUN (FAILS) ------------------- + reset_session_job_state() + with pytest.raises(RuntimeError, match="Simulated failure"): + chain.map(result=mapper, output=int).save("results") + + catalog = test_session.catalog + test_session.get_or_create_job() + + # Drop all partial output tables from first run + for table_name in catalog.warehouse.db.list_tables(prefix="udf_"): + if "_partial" in table_name: + table = catalog.warehouse.db.get_table(table_name) + catalog.warehouse.db.drop_table(table, if_exists=True) + + # -------------- SECOND RUN ------------------- + reset_session_job_state() + fail_flag[0] = False + + chain.map(result=mapper, output=int).save("results") + + # UDF should 
have been re-executed from scratch (fallback from continue) + result = dc.read_dataset("results", session=test_session).to_list("result") + assert sorted(result) == [(10,), (20,), (30,), (40,), (50,), (60,)] + + +def test_aggregator_checkpoint_no_partial_continuation(test_session): + call_count = {"value": 0} + + def sum_by_group(num: list[int]) -> Iterator[tuple[int, int]]: + call_count["value"] += 1 + total = sum(num) + count = len(num) + yield total, count + + dc.read_values( + num=[1, 2, 3, 4, 5, 6], + category=["a", "a", "a", "b", "b", "b"], + session=test_session, + ).save("grouped_nums") + + chain = dc.read_dataset("grouped_nums", session=test_session) + + # -------------- FIRST RUN (FAILS) ------------------- + reset_session_job_state() + + fail_flag = [True] + + def sum_by_group_buggy(num: list[int]) -> Iterator[tuple[int, int]]: + call_count["value"] += 1 + if fail_flag[0] and call_count["value"] >= 2: + raise RuntimeError("Simulated aggregator failure") + total = sum(num) + count = len(num) + yield total, count + + with pytest.raises(RuntimeError, match="Simulated aggregator failure"): + chain.agg( + sum_by_group_buggy, + partition_by="category", + output={"total": int, "count": int}, + ).save("agg_results") + + # -------------- SECOND RUN (FIXED) ------------------- + reset_session_job_state() + call_count["value"] = 0 + fail_flag[0] = False + + chain.agg( + sum_by_group, + partition_by="category", + output={"total": int, "count": int}, + ).save("agg_results") + + # Aggregator should have run from scratch (not continued from partial) + # because partition_by prevents partial continuation + assert call_count["value"] == 2 # Two groups: "a" and "b" + + result = sorted( + dc.read_dataset("agg_results", session=test_session).to_list("total", "count") + ) + # Group "a": sum(1,2,3)=6, count=3; Group "b": sum(4,5,6)=15, count=3 + assert result == [(6, 3), (15, 3)] diff --git a/tests/unit/lib/test_checkpoints.py b/tests/func/checkpoints/test_checkpoint_workflows.py similarity index 50% rename from tests/unit/lib/test_checkpoints.py rename to tests/func/checkpoints/test_checkpoint_workflows.py index e28ab51bd..4bfe1211a 100644 --- a/tests/unit/lib/test_checkpoints.py +++ b/tests/func/checkpoints/test_checkpoint_workflows.py @@ -1,10 +1,8 @@ import pytest -import sqlalchemy as sa import datachain as dc from datachain.error import ( DatasetNotFoundError, - JobAncestryDepthExceededError, JobNotFoundError, ) from tests.utils import reset_session_job_state @@ -18,43 +16,6 @@ def mapper_fail(num: int) -> int: raise CustomMapperError("Error") -def get_dataset_versions_for_job(metastore, job_id): - """Helper to get all dataset versions associated with a job. 
- - Returns: - List of tuples (dataset_name, version, is_creator) - """ - query = ( - sa.select( - metastore._datasets_versions.c.dataset_id, - metastore._datasets_versions.c.version, - metastore._dataset_version_jobs.c.is_creator, - ) - .select_from( - metastore._dataset_version_jobs.join( - metastore._datasets_versions, - metastore._dataset_version_jobs.c.dataset_version_id - == metastore._datasets_versions.c.id, - ) - ) - .where(metastore._dataset_version_jobs.c.job_id == job_id) - ) - - results = list(metastore.db.execute(query)) - - # Get dataset names - dataset_versions = [] - for dataset_id, version, is_creator in results: - dataset_query = sa.select(metastore._datasets.c.name).where( - metastore._datasets.c.id == dataset_id - ) - dataset_name = next(metastore.db.execute(dataset_query))[0] - # Convert is_creator to boolean for consistent assertions across databases - dataset_versions.append((dataset_name, version, bool(is_creator))) - - return sorted(dataset_versions) - - @pytest.fixture(autouse=True) def mock_is_script_run(monkeypatch): """Mock is_script_run to return True for stable job names in tests.""" @@ -63,13 +24,9 @@ def mock_is_script_run(monkeypatch): @pytest.fixture def nums_dataset(test_session): - return dc.read_values(num=[1, 2, 3], session=test_session).save("nums") + return dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums") -@pytest.mark.skipif( - "os.environ.get('DATACHAIN_DISTRIBUTED')", - reason="Checkpoints test skipped in distributed mode", -) @pytest.mark.parametrize("reset_checkpoints", [True, False]) @pytest.mark.parametrize("with_delta", [True, False]) @pytest.mark.parametrize("use_datachain_job_id_env", [True, False]) @@ -84,7 +41,7 @@ def test_checkpoints( catalog = test_session.catalog metastore = catalog.metastore - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(reset_checkpoints)) + monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(reset_checkpoints)) if with_delta: chain = dc.read_dataset( @@ -124,7 +81,6 @@ def test_checkpoints( run_group_id=first_job.run_group_id, ), ) - chain.save("nums1") chain.save("nums2") chain.save("nums3") @@ -135,20 +91,16 @@ def test_checkpoints( assert len(catalog.get_dataset("nums2").versions) == expected_versions assert len(catalog.get_dataset("nums3").versions) == 1 - assert len(list(catalog.metastore.list_checkpoints(first_job_id))) == 2 + assert len(list(catalog.metastore.list_checkpoints(first_job_id))) == 3 assert len(list(catalog.metastore.list_checkpoints(second_job_id))) == 3 -@pytest.mark.skipif( - "os.environ.get('DATACHAIN_DISTRIBUTED')", - reason="Checkpoints test skipped in distributed mode", -) @pytest.mark.parametrize("reset_checkpoints", [True, False]) def test_checkpoints_modified_chains( test_session, monkeypatch, nums_dataset, reset_checkpoints ): catalog = test_session.catalog - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(reset_checkpoints)) + monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(reset_checkpoints)) chain = dc.read_dataset("nums", session=test_session) @@ -174,17 +126,13 @@ def test_checkpoints_modified_chains( assert len(list(catalog.metastore.list_checkpoints(second_job_id))) == 3 -@pytest.mark.skipif( - "os.environ.get('DATACHAIN_DISTRIBUTED')", - reason="Checkpoints test skipped in distributed mode", -) @pytest.mark.parametrize("reset_checkpoints", [True, False]) def test_checkpoints_multiple_runs( test_session, monkeypatch, nums_dataset, reset_checkpoints ): catalog = test_session.catalog - 
monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(reset_checkpoints)) + monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(reset_checkpoints)) chain = dc.read_dataset("nums", session=test_session) @@ -237,22 +185,17 @@ def test_checkpoints_multiple_runs( assert num2_versions == 2 assert num3_versions == 2 - assert len(list(catalog.metastore.list_checkpoints(first_job_id))) == 2 + assert len(list(catalog.metastore.list_checkpoints(first_job_id))) == 3 assert len(list(catalog.metastore.list_checkpoints(second_job_id))) == 3 - assert len(list(catalog.metastore.list_checkpoints(third_job_id))) == 2 + assert len(list(catalog.metastore.list_checkpoints(third_job_id))) == 3 assert len(list(catalog.metastore.list_checkpoints(fourth_job_id))) == 3 -@pytest.mark.skipif( - "os.environ.get('DATACHAIN_DISTRIBUTED')", - reason="Checkpoints test skipped in distributed mode", -) def test_checkpoints_check_valid_chain_is_returned( test_session, monkeypatch, nums_dataset, ): - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(False)) chain = dc.read_dataset("nums", session=test_session) # -------------- FIRST RUN ------------------- @@ -268,7 +211,7 @@ def test_checkpoints_check_valid_chain_is_returned( assert ds.dataset is not None assert ds.dataset.name == "nums1" assert len(ds.dataset.versions) == 1 - assert ds.order_by("num").to_list("num") == [(1,), (2,), (3,)] + assert ds.order_by("num").to_list("num") == [(1,), (2,), (3,), (4,), (5,), (6,)] def test_checkpoints_invalid_parent_job_id(test_session, monkeypatch, nums_dataset): @@ -279,194 +222,130 @@ def test_checkpoints_invalid_parent_job_id(test_session, monkeypatch, nums_datas dc.read_dataset("nums", session=test_session).save("nums1") -def test_dataset_job_linking(test_session, monkeypatch, nums_dataset): - """Test that dataset versions are correctly linked to jobs via many-to-many. - - This test verifies that datasets should appear in ALL jobs that use them in - the single job "chain", not just the job that created them. 
- """ +def test_checkpoint_with_deleted_dataset_version( + test_session, monkeypatch, nums_dataset +): catalog = test_session.catalog - metastore = catalog.metastore - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(False)) + monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(False)) chain = dc.read_dataset("nums", session=test_session) # -------------- FIRST RUN: Create dataset ------------------- reset_session_job_state() - chain.save("nums_linked") - job1_id = test_session.get_or_create_job().id - - # Verify job1 has the dataset associated (as creator) - job1_datasets = get_dataset_versions_for_job(metastore, job1_id) - assert len(job1_datasets) == 1 - assert job1_datasets[0] == ("nums_linked", "1.0.0", True) - - # -------------- SECOND RUN: Reuse dataset via checkpoint ------------------- - reset_session_job_state() - chain.save("nums_linked") - job2_id = test_session.get_or_create_job().id - - # Verify job2 also has the dataset associated (not creator) - job2_datasets = get_dataset_versions_for_job(metastore, job2_id) - assert len(job2_datasets) == 1 - assert job2_datasets[0] == ("nums_linked", "1.0.0", False) - - # Verify job1 still has it - job1_datasets = get_dataset_versions_for_job(metastore, job1_id) - assert len(job1_datasets) == 1 - assert job1_datasets[0][2] # still creator - - # -------------- THIRD RUN: Another reuse ------------------- - reset_session_job_state() - chain.save("nums_linked") - job3_id = test_session.get_or_create_job().id - - # Verify job3 also has the dataset associated (not creator) - job3_datasets = get_dataset_versions_for_job(metastore, job3_id) - assert len(job3_datasets) == 1 - assert job3_datasets[0] == ("nums_linked", "1.0.0", False) - - # Verify get_dataset_version_for_job_ancestry works correctly - dataset = catalog.get_dataset("nums_linked") - found_version = metastore.get_dataset_version_for_job_ancestry( - "nums_linked", - dataset.project.namespace.name, - dataset.project.name, - job3_id, - ) - assert found_version.version == "1.0.0" - - -def test_dataset_job_linking_with_reset(test_session, monkeypatch, nums_dataset): - """Test that with CHECKPOINTS_RESET=True, new versions are created each run.""" - catalog = test_session.catalog - metastore = catalog.metastore - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(True)) + chain.save("nums_deleted") + test_session.get_or_create_job() - chain = dc.read_dataset("nums", session=test_session) + dataset = catalog.get_dataset("nums_deleted") + assert len(dataset.versions) == 1 + assert dataset.latest_version == "1.0.0" - # -------------- FIRST RUN ------------------- - reset_session_job_state() - chain.save("nums_reset") - job1_id = test_session.get_or_create_job().id + catalog.remove_dataset("nums_deleted", version="1.0.0", force=True) - # Verify job1 created version 1.0.0 - job1_datasets = get_dataset_versions_for_job(metastore, job1_id) - assert len(job1_datasets) == 1 - assert job1_datasets[0] == ("nums_reset", "1.0.0", True) + with pytest.raises(DatasetNotFoundError): + catalog.get_dataset("nums_deleted") - # -------------- SECOND RUN ------------------- + # -------------- SECOND RUN: Checkpoint exists but version gone reset_session_job_state() - chain.save("nums_reset") + chain.save("nums_deleted") job2_id = test_session.get_or_create_job().id - # Verify job2 created NEW version 1.0.1 (not reusing 1.0.0) - job2_datasets = get_dataset_versions_for_job(metastore, job2_id) - assert len(job2_datasets) == 1 - assert job2_datasets[0] == ("nums_reset", "1.0.1", True) + # Should create a 
NEW version since old one was deleted + dataset = catalog.get_dataset("nums_deleted") + assert len(dataset.versions) == 1 + assert dataset.latest_version == "1.0.0" - # Verify job1 still only has version 1.0.0 - job1_datasets = get_dataset_versions_for_job(metastore, job1_id) - assert len(job1_datasets) == 1 - assert job1_datasets[0] == ("nums_reset", "1.0.0", True) + new_version = dataset.get_version("1.0.0") + assert new_version.job_id == job2_id -def test_dataset_version_job_id_updates_to_latest( +def test_udf_checkpoints_multiple_calls_same_job( test_session, monkeypatch, nums_dataset ): - """Test that dataset_version.job_id is updated to the latest job that used it.""" - catalog = test_session.catalog - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(False)) - - chain = dc.read_dataset("nums", session=test_session) - name = "nums_jobid" + """ + Test that UDF execution creates checkpoints, but subsequent calls in the same + job will re-execute because the hash changes (includes previous checkpoint hash). + Checkpoint reuse is designed for cross-job execution, not within-job execution. + """ + call_count = {"count": 0} - # -------------- FIRST RUN ------------------- - reset_session_job_state() - chain.save(name) - job1_id = test_session.get_or_create_job().id + def add_ten(num) -> int: + call_count["count"] += 1 + return num + 10 - dataset = catalog.get_dataset(name) - assert dataset.get_version(dataset.latest_version).job_id == job1_id + chain = dc.read_dataset("nums", session=test_session).map( + plus_ten=add_ten, output=int + ) - # -------------- SECOND RUN: Reuse via checkpoint ------------------- reset_session_job_state() - chain.save(name) - job2_id = test_session.get_or_create_job().id - # job_id should now point to job2 (latest) - dataset = catalog.get_dataset(name) - assert dataset.get_version(dataset.latest_version).job_id == job2_id + # First count() - should execute UDF + assert chain.count() == 6 + first_calls = call_count["count"] + assert first_calls == 6, "Mapper should be called 6 times on first count()" - # -------------- THIRD RUN: Another reuse ------------------- - reset_session_job_state() - chain.save(name) - job3_id = test_session.get_or_create_job().id + # Second count() - will re-execute because hash includes previous checkpoint + call_count["count"] = 0 + assert chain.count() == 6 + assert call_count["count"] == 6, "Mapper re-executes in same job" - # job_id should now point to job3 (latest) - dataset = catalog.get_dataset(name) - assert dataset.get_version(dataset.latest_version).job_id == job3_id + # Third count() - will also re-execute + call_count["count"] = 0 + assert chain.count() == 6 + assert call_count["count"] == 6, "Mapper re-executes in same job" + # Other operations like to_list() will also re-execute + call_count["count"] = 0 + result = chain.order_by("num").to_list("plus_ten") + assert result == [(11,), (12,), (13,), (14,), (15,), (16,)] + assert call_count["count"] == 6, "Mapper re-executes in same job" -def test_job_ancestry_depth_exceeded(test_session, monkeypatch, nums_dataset): - from datachain.data_storage import metastore - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(False)) - # Mock max depth to a small value (3) for testing - monkeypatch.setattr(metastore, "JOB_ANCESTRY_MAX_DEPTH", 3) +@pytest.mark.parametrize("reset_checkpoints", [True, False]) +def test_udf_checkpoints_cross_job_reuse( + test_session, monkeypatch, nums_dataset, reset_checkpoints +): + catalog = test_session.catalog + 
monkeypatch.setenv("DATACHAIN_IGNORE_CHECKPOINTS", str(reset_checkpoints)) - chain = dc.read_dataset("nums", session=test_session) + call_count = {"count": 0} - # Keep saving until we hit the max depth error - max_attempts = 10 # Safety limit to prevent infinite loop - for _ in range(max_attempts): - reset_session_job_state() - try: - chain.save("nums_depth") - except JobAncestryDepthExceededError as exc_info: - # Verify the error message - assert "too deep" in str(exc_info) - assert "from scratch" in str(exc_info) - # Test passed - we hit the max depth - return + def double_num(num) -> int: + call_count["count"] += 1 + return num * 2 - # If we get here, we never hit the max depth error - pytest.fail(f"Expected JobAncestryDepthExceededError after {max_attempts} saves") + chain = dc.read_dataset("nums", session=test_session).map( + doubled=double_num, output=int + ) + # -------------- FIRST RUN - count() triggers UDF execution ------------------- + reset_session_job_state() + assert chain.count() == 6 + first_job_id = test_session.get_or_create_job().id -def test_checkpoint_with_deleted_dataset_version( - test_session, monkeypatch, nums_dataset -): - """Test checkpoint found but dataset version deleted from ancestry.""" - catalog = test_session.catalog - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(False)) + assert call_count["count"] == 6 - chain = dc.read_dataset("nums", session=test_session) + checkpoints = list(catalog.metastore.list_checkpoints(first_job_id)) + assert len(checkpoints) == 1 + assert checkpoints[0].partial is False - # -------------- FIRST RUN: Create dataset ------------------- + # -------------- SECOND RUN - should reuse UDF checkpoint ------------------- reset_session_job_state() - chain.save("nums_deleted") - test_session.get_or_create_job() + call_count["count"] = 0 # Reset counter - dataset = catalog.get_dataset("nums_deleted") - assert len(dataset.versions) == 1 - assert dataset.latest_version == "1.0.0" - - catalog.remove_dataset("nums_deleted", version="1.0.0", force=True) - - with pytest.raises(DatasetNotFoundError): - catalog.get_dataset("nums_deleted") + assert chain.count() == 6 + second_job_id = test_session.get_or_create_job().id - # -------------- SECOND RUN: Checkpoint exists but version gone - reset_session_job_state() - chain.save("nums_deleted") - job2_id = test_session.get_or_create_job().id + if reset_checkpoints: + assert call_count["count"] == 6, "Mapper should be called again" + else: + assert call_count["count"] == 0, "Mapper should NOT be called" - # Should create a NEW version since old one was deleted - dataset = catalog.get_dataset("nums_deleted") - assert len(dataset.versions) == 1 - assert dataset.latest_version == "1.0.0" + checkpoints_second = list(catalog.metastore.list_checkpoints(second_job_id)) + # After successful completion, only final checkpoint remains + # (partial checkpoint is deleted after promotion) + assert len(checkpoints_second) == 1 + assert checkpoints_second[0].partial is False - # Verify the new version was created by job2, not job1 - new_version = dataset.get_version("1.0.0") - assert new_version.job_id == job2_id + # Verify the data is correct + result = chain.order_by("num").to_list("doubled") + assert result == [(2,), (4,), (6,), (8,), (10,), (12,)] diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index 7a6114a9e..495d3557d 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -660,12 +660,12 @@ def test_enlist_source_handles_file(cloud_test_catalog): 
@pytest.mark.parametrize("from_cli", [False, True]) -def test_garbage_collect(cloud_test_catalog, from_cli, capsys): +def test_garbage_collect_temp_tables(cloud_test_catalog, from_cli, capsys): catalog = cloud_test_catalog.catalog assert catalog.get_temp_table_names() == [] temp_tables = [ "tmp_vc12F", - "udf_jh653", + "tmp_jh653", ] for t in temp_tables: catalog.warehouse.create_udf_table(name=t) diff --git a/tests/func/test_checkpoints.py b/tests/func/test_checkpoints.py deleted file mode 100644 index acb25fc86..000000000 --- a/tests/func/test_checkpoints.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest - -import datachain as dc -from datachain.error import DatasetNotFoundError -from tests.utils import reset_session_job_state - - -@pytest.fixture(autouse=True) -def mock_is_script_run(monkeypatch): - """Mock is_script_run to return True for stable job names in tests.""" - monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True) - - -@pytest.mark.skipif( - "os.environ.get('DATACHAIN_DISTRIBUTED')", - reason="Checkpoints test skipped in distributed mode", -) -def test_checkpoints_parallel(test_session_tmpfile, monkeypatch): - def mapper_fail(num) -> int: - raise Exception("Error") - - test_session = test_session_tmpfile - catalog = test_session.catalog - - dc.read_values(num=list(range(1000)), session=test_session).save("nums") - - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", str(False)) - - chain = dc.read_dataset("nums", session=test_session).settings(parallel=True) - - # -------------- FIRST RUN ------------------- - reset_session_job_state() - chain.save("nums1") - chain.save("nums2") - with pytest.raises(RuntimeError): - chain.map(new=mapper_fail).save("nums3") - first_job_id = test_session.get_or_create_job().id - - catalog.get_dataset("nums1") - catalog.get_dataset("nums2") - with pytest.raises(DatasetNotFoundError): - catalog.get_dataset("nums3") - - # -------------- SECOND RUN ------------------- - reset_session_job_state() - chain.save("nums1") - chain.save("nums2") - chain.save("nums3") - second_job_id = test_session.get_or_create_job().id - - assert len(catalog.get_dataset("nums1").versions) == 1 - assert len(catalog.get_dataset("nums2").versions) == 1 - assert len(catalog.get_dataset("nums3").versions) == 1 - - assert len(list(catalog.metastore.list_checkpoints(first_job_id))) == 2 - assert len(list(catalog.metastore.list_checkpoints(second_job_id))) == 3 diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 386c8a7cd..fe2df41f2 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta, timezone from pathlib import Path, PurePosixPath from typing import cast -from unittest.mock import Mock, patch +from unittest.mock import patch from urllib.parse import urlparse import numpy as np @@ -277,7 +277,11 @@ def test_to_storage( file_type, num_threads, ): - mapper = Mock(side_effect=lambda file_path: len(file_path)) # noqa: PLW0108 + call_count = {"count": 0} + + def mapper(file_path): + call_count["count"] += 1 + return len(file_path) ctc = cloud_test_catalog df = dc.read_storage(ctc.src_uri, type=file_type, session=test_session) @@ -330,7 +334,7 @@ def _expected_destination_rel(file_obj: File, placement: str) -> Path: with (output_root / destination_rel).open() as f: assert f.read() == expected[file_obj.name] - assert mapper.call_count == len(expected) + assert call_count["count"] == len(expected) @pytest.mark.parametrize("use_cache", [True, False]) diff --git 
a/tests/func/test_delta.py b/tests/func/test_delta.py index 499840a51..e4257c0e2 100644 --- a/tests/func/test_delta.py +++ b/tests/func/test_delta.py @@ -597,7 +597,6 @@ def get_index(file: File) -> int: def test_delta_update_check_num_calls( test_session, tmp_dir, tmp_path, capsys, monkeypatch ): - monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", "True") ds_name = "delta_ds" path = tmp_dir.as_uri() tmp_dir = tmp_dir / "images" @@ -645,7 +644,6 @@ def get_index(file: File) -> int: create_delta_dataset() captured = capsys.readouterr() - # assert captured.out == "Garbage collecting 2 tables.\n" assert captured.out == "\n".join([map_print] * 20) + "\n" diff --git a/tests/func/test_metastore.py b/tests/func/test_metastore.py index 9e814a89a..08c99f003 100644 --- a/tests/func/test_metastore.py +++ b/tests/func/test_metastore.py @@ -921,7 +921,6 @@ def test_get_ancestor_job_ids(metastore, depth): rerun_from_id = None group_id = None - # Create jobs from root to leaf for i in range(depth + 1): job_id = metastore.create_job( name=f"job_{i}", @@ -934,16 +933,13 @@ def test_get_ancestor_job_ids(metastore, depth): ) job_ids.append(job_id) rerun_from_id = job_id - # First job sets the group_id if group_id is None: group_id = metastore.get_job(job_id).run_group_id - # The last job is the leaf (youngest) leaf_job_id = job_ids[-1] ancestors = metastore.get_ancestor_job_ids(leaf_job_id) - # Should return all ancestors except the leaf itself, in order from parent to root expected_ancestors = list(reversed(job_ids[:-1])) assert ancestors == expected_ancestors diff --git a/tests/func/test_warehouse.py b/tests/func/test_warehouse.py index da18208b4..8c7a40083 100644 --- a/tests/func/test_warehouse.py +++ b/tests/func/test_warehouse.py @@ -51,7 +51,9 @@ def udf_gen(value: int) -> Iterator[int]: wraps=warehouse.db.executemany, ) as mock_executemany: dc.read_values(value=list(range(100)), session=test_session).save("values") - assert mock_executemany.call_count == 2 # 1 for read_values, 1 for save + # 1 for read_values gen() output, 1 for save + # Note: processed_table no longer exists (sys__input_id is in output table now) + assert mock_executemany.call_count == 2 mock_executemany.reset_mock() # Mapper @@ -73,6 +75,7 @@ def udf_gen(value: int) -> Iterator[int]: # Generator dc.read_dataset("values", session=test_session).gen(x2=udf_gen).save("large") + # Only 1 call for gen() output (processed_table no longer exists) assert mock_executemany.call_count == 1 mock_executemany.reset_mock() @@ -82,6 +85,7 @@ def udf_gen(value: int) -> Iterator[int]: .gen(x2=udf_gen) .save("large") ) + # Only 20 for outputs (processed_table no longer exists) assert mock_executemany.call_count == 20 mock_executemany.reset_mock() assert set(chain.to_values("x2")) == set(range(200)) diff --git a/tests/test_cli_e2e.py b/tests/test_cli_e2e.py index 03b72d542..856b935ed 100644 --- a/tests/test_cli_e2e.py +++ b/tests/test_cli_e2e.py @@ -158,7 +158,7 @@ def _tabulated_datasets(name, version): }, { "command": ("datachain", "gc"), - "expected": "Nothing to clean up.\n", + "expected": ("Nothing to clean up.\n"), }, ) diff --git a/tests/test_query_e2e.py b/tests/test_query_e2e.py index 3e2dc661b..0c749ea6c 100644 --- a/tests/test_query_e2e.py +++ b/tests/test_query_e2e.py @@ -113,7 +113,7 @@ }, { "command": ("datachain", "gc"), - "expected": "Nothing to clean up.\n", + "expected": ("Nothing to clean up.\n"), }, ) @@ -178,6 +178,14 @@ def run_step(step, catalog): popen_args = {"start_new_session": True} stdin_path = step.get("stdin_file") with 
open(stdin_path) if stdin_path else nullcontext(None) as stdin_file: + # Build env without DATACHAIN_MAIN_PROCESS_PID so script starts fresh + # as its own main process (with checkpoints enabled) + script_env = { + k: v for k, v in os.environ.items() if k != "DATACHAIN_MAIN_PROCESS_PID" + } + script_env["DATACHAIN__METASTORE"] = catalog.metastore.serialize() + script_env["DATACHAIN__WAREHOUSE"] = catalog.warehouse.serialize() + process = subprocess.Popen( # noqa: S603 command, shell=False, @@ -185,11 +193,7 @@ def run_step(step, catalog): stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8", - env={ - **os.environ, - "DATACHAIN__METASTORE": catalog.metastore.serialize(), - "DATACHAIN__WAREHOUSE": catalog.warehouse.serialize(), - }, + env=script_env, **popen_args, ) interrupt_after = step.get("interrupt_after") diff --git a/tests/unit/test_data_storage.py b/tests/unit/test_data_storage.py index c6c852f7d..95b425bad 100644 --- a/tests/unit/test_data_storage.py +++ b/tests/unit/test_data_storage.py @@ -75,3 +75,28 @@ def test_db_defaults(col_type, default_value, catalog): assert values[0] == default_value warehouse.db.drop_table(table) + + +def test_get_table_missing(catalog): + from datachain.error import TableMissingError + + with pytest.raises(TableMissingError, match="not found"): + catalog.warehouse.db.get_table("nonexistent_table_12345") + + +def test_list_tables(catalog): + db = catalog.warehouse.db + tables = db.list_tables() + assert isinstance(tables, list) + + # Create a test table + table = catalog.warehouse.create_udf_table([], name="test_list_tables_abc") + try: + tables_after = db.list_tables() + assert "test_list_tables_abc" in tables_after + + # Test with prefix filter + filtered = db.list_tables(prefix="test_list_tables") + assert "test_list_tables_abc" in filtered + finally: + db.drop_table(table) diff --git a/tests/unit/test_hash_utils.py b/tests/unit/test_hash_utils.py index e6d777e3c..c1ca415b8 100644 --- a/tests/unit/test_hash_utils.py +++ b/tests/unit/test_hash_utils.py @@ -370,3 +370,56 @@ def test_lambda_different_hashes(): # Ensure hashes are all different assert len({h1, h2, h3}) == 3 + + +def test_hash_callable_objects(): + """Test hashing of callable objects (instances with __call__).""" + + class MyCallable: + def __call__(self, x): + return x * 2 + + class AnotherCallable: + def __call__(self, y): + return y + 1 + + obj1 = MyCallable() + obj2 = AnotherCallable() + + assert ( + hash_callable(obj1) + == "41dd7a38058975b10d5604beb5b60041e5b9d7de0f85c2364c11c3907b4ee9fc" + ) + assert ( + hash_callable(obj2) + == "7ae5ff45f5acd08e75373bb332b99a8c30d931645c98d18b5bef16ad638a205e" + ) + + +@pytest.mark.parametrize("value", ["not a callable", 42, None, [1, 2, 3]]) +def test_hash_callable_not_callable(value): + with pytest.raises(TypeError, match="Expected a callable"): + hash_callable(value) + + +def test_hash_callable_builtin_functions(): + h1 = hash_callable(len) + h2 = hash_callable(len) + # Built-ins return random hash each time + assert h1 != h2 + assert len(h1) == 64 + + +def test_hash_callable_no_name_attribute(): + from unittest.mock import MagicMock + + mock_callable = MagicMock() + del mock_callable.__name__ + h = hash_callable(mock_callable) + assert len(h) == 64 + + +def test_hash_column_elements_single_element(): + single_hash = hash_column_elements(C("name")) + list_hash = hash_column_elements([C("name")]) + assert single_hash == list_hash diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 99081c15c..039a2c0c9 100644 --- 
a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,8 +1,13 @@ +import os +import threading + import pytest from datachain.utils import ( + _CheckpointState, batched, batched_it, + checkpoints_enabled, datachain_paths_join, determine_processes, determine_workers, @@ -13,6 +18,7 @@ sql_escape_like, suffix_to_number, uses_glob, + with_last_flag, ) DATACHAIN_TEST_PATHS = ["/file1", "file2", "/dir/file3", "dir/file4"] @@ -295,3 +301,147 @@ def test_batched_it(num_rows, batch_size): assert num_batches == num_rows / batch_size assert len(uniq_data) == num_rows + + +def gen3(): + yield from range(3) + + +@pytest.mark.parametrize( + "input_data, expected", + [ + ( + [10, 20, 30], + [(10, False), (20, False), (30, True)], + ), + ( + [42], + [(42, True)], + ), + ( + [], + [], + ), + ( + gen3(), # generator input + [(0, False), (1, False), (2, True)], + ), + ], +) +def test_with_last_flag(input_data, expected): + assert list(with_last_flag(input_data)) == expected + + +@pytest.fixture(autouse=True) +def reset_checkpoint_state(monkeypatch): + """Reset checkpoint state before each test.""" + _CheckpointState.disabled = False + _CheckpointState.warning_shown = False + _CheckpointState.owner_thread = None + # Clear any existing env vars + monkeypatch.delenv("DATACHAIN_MAIN_PROCESS_PID", raising=False) + monkeypatch.delenv("DATACHAIN_SUBPROCESS", raising=False) + yield + # Reset after test as well + _CheckpointState.disabled = False + _CheckpointState.warning_shown = False + _CheckpointState.owner_thread = None + + +def test_checkpoints_enabled_main_thread(): + """Test that checkpoints are enabled in main thread and main process.""" + assert checkpoints_enabled() is True + assert _CheckpointState.disabled is False + assert _CheckpointState.warning_shown is False + + +def test_checkpoints_enabled_non_main_thread(monkeypatch): + # First call establishes ownership with the real current thread + assert checkpoints_enabled() is True + owner_ident = _CheckpointState.owner_thread + assert owner_ident == threading.current_thread().ident + + # Now simulate a different thread by mocking the ident + class MockThread: + ident = owner_ident + 1 + name = "Thread-1" + + monkeypatch.setattr("datachain.utils.threading.current_thread", MockThread) + + assert checkpoints_enabled() is False + assert _CheckpointState.disabled is True + + +def test_checkpoints_enabled_non_main_process(monkeypatch): + # Simulate a subprocess by setting DATACHAIN_MAIN_PROCESS_PID to a different PID + monkeypatch.setenv("DATACHAIN_MAIN_PROCESS_PID", str(os.getpid() + 1000)) + + assert checkpoints_enabled() is False + # Note: disabled flag is not set for subprocess detection, just returns False + assert _CheckpointState.warning_shown is True + + +def test_checkpoints_enabled_warning_shown_once(monkeypatch, caplog): + """Test that the warning is only shown once even when called multiple times.""" + # First call establishes ownership + assert checkpoints_enabled() is True + owner_ident = _CheckpointState.owner_thread + + class MockThread: + ident = owner_ident + 1 # Different thread ident + name = "Thread-1" + + monkeypatch.setattr("datachain.utils.threading.current_thread", MockThread) + + # Call multiple times from non-owner thread + assert checkpoints_enabled() is False + assert checkpoints_enabled() is False + assert checkpoints_enabled() is False + + # Verify warning was logged only once + warning_count = sum( + 1 for record in caplog.records if "Concurrent thread detected" in record.message + ) + assert warning_count == 1 + assert 
_CheckpointState.warning_shown is True + + +def test_checkpoints_enabled_stays_disabled(monkeypatch): + """Test that once disabled, checkpoints stay disabled even in owner thread.""" + # First call establishes ownership + assert checkpoints_enabled() is True + owner_ident = _CheckpointState.owner_thread + + class MockThread: + ident = owner_ident + 1 # Different thread ident + name = "Thread-1" + + class MockOwnerThread: + ident = owner_ident # Same as owner + name = "MainThread" + + # Call from non-owner thread disables checkpoints + monkeypatch.setattr("datachain.utils.threading.current_thread", MockThread) + assert checkpoints_enabled() is False + assert _CheckpointState.disabled is True + + # Even if we go back to owner thread, it should stay disabled + monkeypatch.setattr("datachain.utils.threading.current_thread", MockOwnerThread) + assert checkpoints_enabled() is False + assert _CheckpointState.disabled is True + + +def test_checkpoints_enabled_datachain_subprocess(monkeypatch): + """Test that DATACHAIN_SUBPROCESS env var enables checkpoints regardless of PID.""" + # Set the main PID to something different (simulating we're in a subprocess) + monkeypatch.setenv("DATACHAIN_MAIN_PROCESS_PID", str(os.getpid() + 1000)) + + # Without DATACHAIN_SUBPROCESS, checkpoints should be disabled + assert checkpoints_enabled() is False + + # Reset state for next test + _CheckpointState.warning_shown = False + + # With DATACHAIN_SUBPROCESS, checkpoints should be enabled + monkeypatch.setenv("DATACHAIN_SUBPROCESS", "1") + assert checkpoints_enabled() is True diff --git a/tests/unit/test_warehouse.py b/tests/unit/test_warehouse.py index ae1c29e8b..dc4a726fa 100644 --- a/tests/unit/test_warehouse.py +++ b/tests/unit/test_warehouse.py @@ -45,7 +45,7 @@ def test_serialize(sqlite_db): def test_is_temp_table_name(warehouse): assert warehouse.is_temp_table_name("tmp_vc12F") is True - assert warehouse.is_temp_table_name("udf_jh653") is True + assert warehouse.is_temp_table_name("udf_jh653") is False assert warehouse.is_temp_table_name("ds_my_dataset") is False assert warehouse.is_temp_table_name("src_my_bucket") is False assert warehouse.is_temp_table_name("ds_ds_my_query_script_1_1") is False diff --git a/tests/utils.py b/tests/utils.py index c75606b69..b53351efe 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -259,5 +259,12 @@ def reset_session_job_state(): Session._OWNS_JOB = None Session._JOB_HOOKS_REGISTERED = False + # Clear checkpoint state (now in utils module) + from datachain.utils import _CheckpointState + + _CheckpointState.disabled = False + _CheckpointState.warning_shown = False + # Clear DATACHAIN_JOB_ID env var to allow new job creation on next run + # This is important for studio/SaaS mode where job_id comes from env var os.environ.pop("DATACHAIN_JOB_ID", None)
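
For reviewers skimming the new `checkpoints_enabled()` tests in `tests/unit/test_utils.py` above, the following is a minimal sketch of the gating behavior those tests encode, reconstructed purely from their assertions. The `_CheckpointState` fields, the `DATACHAIN_MAIN_PROCESS_PID` and `DATACHAIN_SUBPROCESS` env vars, and the "Concurrent thread detected" message come from the tests themselves; everything else (exact warning texts, overall structure) is illustrative and not necessarily DataChain's actual implementation.

```python
import logging
import os
import threading

logger = logging.getLogger(__name__)


class _CheckpointState:
    disabled = False       # sticky: once True, checkpoints stay off for this process
    warning_shown = False  # ensures the warning is logged at most once
    owner_thread = None    # ident of the thread that first called checkpoints_enabled()


def checkpoints_enabled() -> bool:
    # DataChain-spawned workers opt back in explicitly.
    if os.environ.get("DATACHAIN_SUBPROCESS"):
        return True

    # A recorded main PID that differs from ours means we are a foreign subprocess.
    main_pid = os.environ.get("DATACHAIN_MAIN_PROCESS_PID")
    if main_pid and int(main_pid) != os.getpid():
        if not _CheckpointState.warning_shown:
            logger.warning("Subprocess detected, checkpoints disabled")
            _CheckpointState.warning_shown = True
        return False  # note: `disabled` is intentionally left untouched here

    # Once disabled by a concurrent thread, stay disabled even back on the owner thread.
    if _CheckpointState.disabled:
        return False

    current = threading.current_thread().ident
    if _CheckpointState.owner_thread is None:
        _CheckpointState.owner_thread = current
    elif current != _CheckpointState.owner_thread:
        _CheckpointState.disabled = True
        if not _CheckpointState.warning_shown:
            logger.warning("Concurrent thread detected, checkpoints disabled")
            _CheckpointState.warning_shown = True
        return False

    return True
```

Note the asymmetry the tests pin down: the concurrent-thread path flips the sticky `disabled` flag (checked before the thread comparison, so the warning fires only once and the owner thread cannot re-enable checkpoints), while the foreign-PID path only returns `False` and sets `warning_shown`, matching `test_checkpoints_enabled_non_main_process` and `test_checkpoints_enabled_stays_disabled`.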