98 changes: 98 additions & 0 deletions src/datachain/checkpoint_event.py
@@ -0,0 +1,98 @@
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum


class CheckpointEventType(str, Enum):
"""Types of checkpoint events."""

# UDF events
UDF_SKIPPED = "UDF_SKIPPED"
UDF_CONTINUED = "UDF_CONTINUED"
UDF_FROM_SCRATCH = "UDF_FROM_SCRATCH"

# Dataset save events
DATASET_SAVE_SKIPPED = "DATASET_SAVE_SKIPPED"
DATASET_SAVE_COMPLETED = "DATASET_SAVE_COMPLETED"


class CheckpointStepType(str, Enum):
"""Types of checkpoint steps."""

UDF_MAP = "UDF_MAP"
UDF_GEN = "UDF_GEN"
DATASET_SAVE = "DATASET_SAVE"


@dataclass
class CheckpointEvent:
"""
Represents a checkpoint event for debugging and visibility.

Checkpoint events are logged during job execution to track checkpoint
decisions (skip, continue, run from scratch) and provide visibility
into what happened during script execution.
"""

id: str
job_id: str
run_group_id: str | None
timestamp: datetime
event_type: CheckpointEventType
step_type: CheckpointStepType
udf_name: str | None = None
dataset_name: str | None = None
checkpoint_hash: str | None = None
hash_partial: str | None = None
hash_input: str | None = None
hash_output: str | None = None
rows_input: int | None = None
rows_processed: int | None = None
rows_generated: int | None = None
rows_reused: int | None = None
rerun_from_job_id: str | None = None
details: dict | None = None

@classmethod
def parse( # noqa: PLR0913
cls,
id: str | uuid.UUID,
job_id: str,
run_group_id: str | None,
timestamp: datetime,
event_type: str,
step_type: str,
udf_name: str | None,
dataset_name: str | None,
checkpoint_hash: str | None,
hash_partial: str | None,
hash_input: str | None,
hash_output: str | None,
rows_input: int | None,
rows_processed: int | None,
rows_generated: int | None,
rows_reused: int | None,
rerun_from_job_id: str | None,
details: dict | None,
) -> "CheckpointEvent":
return cls(
id=str(id),
job_id=job_id,
run_group_id=run_group_id,
timestamp=timestamp,
event_type=CheckpointEventType(event_type),
step_type=CheckpointStepType(step_type),
udf_name=udf_name,
dataset_name=dataset_name,
checkpoint_hash=checkpoint_hash,
hash_partial=hash_partial,
hash_input=hash_input,
hash_output=hash_output,
rows_input=rows_input,
rows_processed=rows_processed,
rows_generated=rows_generated,
rows_reused=rows_reused,
rerun_from_job_id=rerun_from_job_id,
details=details,
)
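
For reviewers, a minimal round-trip sketch of the new dataclass: parse() is what the metastore calls on a raw row, coercing the UUID and the string enum columns back into typed fields. The IDs and names below are made up for illustration.

from datetime import datetime, timezone
from uuid import uuid4

from datachain.checkpoint_event import (
    CheckpointEvent,
    CheckpointEventType,
    CheckpointStepType,
)

# A raw metastore row carries a uuid.UUID for the ID and plain strings
# for the enum columns; parse() normalizes all three.
event = CheckpointEvent.parse(
    id=uuid4(),
    job_id="job-123",  # hypothetical ID
    run_group_id=None,
    timestamp=datetime.now(timezone.utc),
    event_type="UDF_SKIPPED",
    step_type="UDF_MAP",
    udf_name="embed_images",  # hypothetical UDF name
    dataset_name=None,
    checkpoint_hash="abc123",  # hypothetical hash
    hash_partial=None,
    hash_input=None,
    hash_output=None,
    rows_input=1000,
    rows_processed=None,
    rows_generated=None,
    rows_reused=1000,
    rerun_from_job_id=None,
    details=None,
)
assert isinstance(event.id, str)
assert event.event_type is CheckpointEventType.UDF_SKIPPED
assert event.step_type is CheckpointStepType.UDF_MAP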
196 changes: 196 additions & 0 deletions src/datachain/data_storage/metastore.py
@@ -39,6 +39,11 @@
from datachain import json
from datachain.catalog.dependency import DatasetDependencyNode
from datachain.checkpoint import Checkpoint
from datachain.checkpoint_event import (
CheckpointEvent,
CheckpointEventType,
CheckpointStepType,
)
from datachain.data_storage import JobQueryType, JobStatus
from datachain.data_storage.serializer import Serializable
from datachain.dataset import (
@@ -98,6 +103,7 @@ class AbstractMetastore(ABC, Serializable):
dependency_node_class: type[DatasetDependencyNode] = DatasetDependencyNode
job_class: type[Job] = Job
checkpoint_class: type[Checkpoint] = Checkpoint
checkpoint_event_class: type[CheckpointEvent] = CheckpointEvent

def __init__(
self,
@@ -559,6 +565,42 @@ def get_or_create_checkpoint(
def remove_checkpoint(self, checkpoint_id: str, conn: Any | None = None) -> None:
"""Removes a checkpoint by ID"""

#
# Checkpoint Events
#

@abstractmethod
def log_checkpoint_event( # noqa: PLR0913
self,
job_id: str,
event_type: "CheckpointEventType",
step_type: "CheckpointStepType",
run_group_id: str | None = None,
udf_name: str | None = None,
dataset_name: str | None = None,
checkpoint_hash: str | None = None,
hash_partial: str | None = None,
hash_input: str | None = None,
hash_output: str | None = None,
rows_input: int | None = None,
rows_processed: int | None = None,
rows_generated: int | None = None,
rows_reused: int | None = None,
rerun_from_job_id: str | None = None,
details: dict | None = None,
conn: Any | None = None,
) -> "CheckpointEvent":
"""Log a checkpoint event."""

@abstractmethod
def get_checkpoint_events(
self,
job_id: str | None = None,
run_group_id: str | None = None,
conn: Any | None = None,
) -> Iterator["CheckpointEvent"]:
"""Get checkpoint events, optionally filtered by job_id or run_group_id."""

#
# Dataset Version Jobs (many-to-many)
#
@@ -604,6 +646,7 @@ class AbstractDBMetastore(AbstractMetastore):
DATASET_VERSION_JOBS_TABLE = "dataset_version_jobs"
JOBS_TABLE = "jobs"
CHECKPOINTS_TABLE = "checkpoints"
CHECKPOINT_EVENTS_TABLE = "checkpoint_events"

db: "DatabaseEngine"

@@ -2114,6 +2157,73 @@ def _checkpoints(self) -> "Table":
@abstractmethod
def _checkpoints_insert(self) -> "Insert": ...

#
# Checkpoint Events
#

@staticmethod
def _checkpoint_events_columns() -> "list[SchemaItem]":
return [
Column(
"id",
Text,
default=uuid4,
primary_key=True,
nullable=False,
),
Column("job_id", Text, nullable=False),
Column("run_group_id", Text, nullable=True),
Column("timestamp", DateTime(timezone=True), nullable=False),
Column("event_type", Text, nullable=False),
Column("step_type", Text, nullable=False),
Column("udf_name", Text, nullable=True),
Column("dataset_name", Text, nullable=True),
Column("checkpoint_hash", Text, nullable=True),
Column("hash_partial", Text, nullable=True),
Column("hash_input", Text, nullable=True),
Column("hash_output", Text, nullable=True),
Column("rows_input", BigInteger, nullable=True),
Column("rows_processed", BigInteger, nullable=True),
Column("rows_generated", BigInteger, nullable=True),
Column("rows_reused", BigInteger, nullable=True),
Column("rerun_from_job_id", Text, nullable=True),
Column("details", JSON, nullable=True),
Index("dc_idx_ce_job_id", "job_id"),
Index("dc_idx_ce_run_group_id", "run_group_id"),
]

@cached_property
def _checkpoint_events_fields(self) -> list[str]:
return [
c.name # type: ignore[attr-defined]
for c in self._checkpoint_events_columns()
if isinstance(c, Column)
]

@cached_property
def _checkpoint_events(self) -> "Table":
return Table(
self.CHECKPOINT_EVENTS_TABLE,
self.db.metadata,
*self._checkpoint_events_columns(),
)

@abstractmethod
def _checkpoint_events_insert(self) -> "Insert": ...

def _checkpoint_events_select(self, *columns) -> "Select":
if not columns:
return self._checkpoint_events.select()
return select(*columns)

def _checkpoint_events_query(self):
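# Select the columns in their declared order so each row can be unpacked
# positionally into CheckpointEvent.parse().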
return self._checkpoint_events_select(
*[
getattr(self._checkpoint_events.c, f)
for f in self._checkpoint_events_fields
]
)

@classmethod
def _dataset_version_jobs_columns(cls) -> "list[SchemaItem]":
"""Junction table for dataset versions and jobs many-to-many relationship."""
@@ -2245,6 +2355,92 @@ def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
return None
return self.checkpoint_class.parse(*rows[0])

def log_checkpoint_event( # noqa: PLR0913
self,
job_id: str,
event_type: CheckpointEventType,
step_type: CheckpointStepType,
run_group_id: str | None = None,
udf_name: str | None = None,
dataset_name: str | None = None,
checkpoint_hash: str | None = None,
hash_partial: str | None = None,
hash_input: str | None = None,
hash_output: str | None = None,
rows_input: int | None = None,
rows_processed: int | None = None,
rows_generated: int | None = None,
rows_reused: int | None = None,
rerun_from_job_id: str | None = None,
details: dict | None = None,
conn: Any | None = None,
) -> CheckpointEvent:
"""Log a checkpoint event."""
event_id = str(uuid4())
timestamp = datetime.now(timezone.utc)

query = self._checkpoint_events_insert().values(
id=event_id,
job_id=job_id,
run_group_id=run_group_id,
timestamp=timestamp,
event_type=event_type.value,
step_type=step_type.value,
udf_name=udf_name,
dataset_name=dataset_name,
checkpoint_hash=checkpoint_hash,
hash_partial=hash_partial,
hash_input=hash_input,
hash_output=hash_output,
rows_input=rows_input,
rows_processed=rows_processed,
rows_generated=rows_generated,
rows_reused=rows_reused,
rerun_from_job_id=rerun_from_job_id,
details=details,
)
self.db.execute(query, conn=conn)

return CheckpointEvent(
id=event_id,
job_id=job_id,
run_group_id=run_group_id,
timestamp=timestamp,
event_type=event_type,
step_type=step_type,
udf_name=udf_name,
dataset_name=dataset_name,
checkpoint_hash=checkpoint_hash,
hash_partial=hash_partial,
hash_input=hash_input,
hash_output=hash_output,
rows_input=rows_input,
rows_processed=rows_processed,
rows_generated=rows_generated,
rows_reused=rows_reused,
rerun_from_job_id=rerun_from_job_id,
details=details,
)

def get_checkpoint_events(
self,
job_id: str | None = None,
run_group_id: str | None = None,
conn: Any | None = None,
) -> Iterator[CheckpointEvent]:
"""Get checkpoint events, optionally filtered by job_id or run_group_id."""
query = self._checkpoint_events_query()

if job_id is not None:
query = query.where(self._checkpoint_events.c.job_id == job_id)
if run_group_id is not None:
query = query.where(self._checkpoint_events.c.run_group_id == run_group_id)

query = query.order_by(self._checkpoint_events.c.timestamp)
rows = list(self.db.execute(query, conn=conn))

yield from [self.checkpoint_event_class.parse(*r) for r in rows]
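
And the consumer side, a sketch of replaying one run group's trace; the run-group ID is hypothetical and `metastore` is assumed as above.

# Events come back ordered by timestamp, giving a chronological trace of
# checkpoint decisions (skip / continue / from scratch) for one run group.
for event in metastore.get_checkpoint_events(run_group_id="rg-42"):
    print(event.timestamp, event.event_type.value, event.udf_name or event.dataset_name)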

def link_dataset_version_to_job(
self,
dataset_version_id: int,
7 changes: 7 additions & 0 deletions src/datachain/data_storage/sqlite.py
@@ -527,6 +527,7 @@ def _metastore_tables(self) -> list[Table]:
self._datasets_dependencies,
self._jobs,
self._checkpoints,
self._checkpoint_events,
self._dataset_version_jobs,
]

@@ -674,6 +675,12 @@ def _jobs_insert(self) -> "Insert":
def _checkpoints_insert(self) -> "Insert":
return sqlite.insert(self._checkpoints)

#
# Checkpoint Events
#
def _checkpoint_events_insert(self) -> "Insert":
return sqlite.insert(self._checkpoint_events)
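
Using the SQLite dialect's insert here (as the neighboring tables do) rather than the generic one presumably keeps dialect features such as conflict handling available. A runnable sketch of that capability, with a made-up table:

from sqlalchemy import Column, MetaData, Table, Text, create_engine
from sqlalchemy.dialects import sqlite

metadata = MetaData()
events = Table("demo_events", metadata, Column("id", Text, primary_key=True))
engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

stmt = (
    sqlite.insert(events)
    .values(id="evt-1")
    .on_conflict_do_nothing(index_elements=["id"])
)
with engine.begin() as conn:
    conn.execute(stmt)
    conn.execute(stmt)  # duplicate is skipped instead of raising IntegrityError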

def _dataset_version_jobs_insert(self) -> "Insert":
return sqlite.insert(self._dataset_version_jobs)

12 changes: 12 additions & 0 deletions src/datachain/dataset.py
@@ -74,6 +74,18 @@ def create_dataset_uri(
return uri


def create_dataset_full_name(
namespace: str, project: str, name: str, version: str
) -> str:
"""
Creates a fully qualified dataset name from namespace, project, name and version.
Example:
Input: dev, clothes, zalando, 3.0.1
Output: dev.clothes.zalando@3.0.1
"""
return f"{namespace}.{project}.{name}@{version}"
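
A quick check of the documented behavior:

from datachain.dataset import create_dataset_full_name

full = create_dataset_full_name("dev", "clothes", "zalando", "3.0.1")
assert full == "dev.clothes.zalando@3.0.1"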


def parse_dataset_name(name: str) -> tuple[str | None, str | None, str]:
"""Parses dataset name and returns namespace, project and name"""
if not name: