diff --git a/src/datachain/checkpoint_event.py b/src/datachain/checkpoint_event.py
new file mode 100644
index 000000000..f614bab94
--- /dev/null
+++ b/src/datachain/checkpoint_event.py
@@ -0,0 +1,98 @@
+import uuid
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+
+
+class CheckpointEventType(str, Enum):
+    """Types of checkpoint events."""
+
+    # UDF events
+    UDF_SKIPPED = "UDF_SKIPPED"
+    UDF_CONTINUED = "UDF_CONTINUED"
+    UDF_FROM_SCRATCH = "UDF_FROM_SCRATCH"
+
+    # Dataset save events
+    DATASET_SAVE_SKIPPED = "DATASET_SAVE_SKIPPED"
+    DATASET_SAVE_COMPLETED = "DATASET_SAVE_COMPLETED"
+
+
+class CheckpointStepType(str, Enum):
+    """Types of checkpoint steps."""
+
+    UDF_MAP = "UDF_MAP"
+    UDF_GEN = "UDF_GEN"
+    DATASET_SAVE = "DATASET_SAVE"
+
+
+@dataclass
+class CheckpointEvent:
+    """
+    Represents a checkpoint event for debugging and visibility.
+
+    Checkpoint events are logged during job execution to track checkpoint
+    decisions (skip, continue, run from scratch) and provide visibility
+    into what happened during script execution.
+    """
+
+    id: str
+    job_id: str
+    run_group_id: str | None
+    timestamp: datetime
+    event_type: CheckpointEventType
+    step_type: CheckpointStepType
+    udf_name: str | None = None
+    dataset_name: str | None = None
+    checkpoint_hash: str | None = None
+    hash_partial: str | None = None
+    hash_input: str | None = None
+    hash_output: str | None = None
+    rows_input: int | None = None
+    rows_processed: int | None = None
+    rows_generated: int | None = None
+    rows_reused: int | None = None
+    rerun_from_job_id: str | None = None
+    details: dict | None = None
+
+    @classmethod
+    def parse(  # noqa: PLR0913
+        cls,
+        id: str | uuid.UUID,
+        job_id: str,
+        run_group_id: str | None,
+        timestamp: datetime,
+        event_type: str,
+        step_type: str,
+        udf_name: str | None,
+        dataset_name: str | None,
+        checkpoint_hash: str | None,
+        hash_partial: str | None,
+        hash_input: str | None,
+        hash_output: str | None,
+        rows_input: int | None,
+        rows_processed: int | None,
+        rows_generated: int | None,
+        rows_reused: int | None,
+        rerun_from_job_id: str | None,
+        details: dict | None,
+    ) -> "CheckpointEvent":
+        return cls(
+            id=str(id),
+            job_id=job_id,
+            run_group_id=run_group_id,
+            timestamp=timestamp,
+            event_type=CheckpointEventType(event_type),
+            step_type=CheckpointStepType(step_type),
+            udf_name=udf_name,
+            dataset_name=dataset_name,
+            checkpoint_hash=checkpoint_hash,
+            hash_partial=hash_partial,
+            hash_input=hash_input,
+            hash_output=hash_output,
+            rows_input=rows_input,
+            rows_processed=rows_processed,
+            rows_generated=rows_generated,
+            rows_reused=rows_reused,
+            rerun_from_job_id=rerun_from_job_id,
+            details=details,
+        )
diff --git a/src/datachain/data_storage/metastore.py b/src/datachain/data_storage/metastore.py
index 0f9476c4d..ecbe78988 100644
--- a/src/datachain/data_storage/metastore.py
+++ b/src/datachain/data_storage/metastore.py
@@ -39,6 +39,11 @@
 from datachain import json
 from datachain.catalog.dependency import DatasetDependencyNode
 from datachain.checkpoint import Checkpoint
+from datachain.checkpoint_event import (
+    CheckpointEvent,
+    CheckpointEventType,
+    CheckpointStepType,
+)
 from datachain.data_storage import JobQueryType, JobStatus
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import (
@@ -98,6 +103,7 @@ class AbstractMetastore(ABC, Serializable):
     dependency_node_class: type[DatasetDependencyNode] = DatasetDependencyNode
     job_class: type[Job] = Job
     checkpoint_class: type[Checkpoint] = Checkpoint
+    checkpoint_event_class: type[CheckpointEvent] = CheckpointEvent
 
     def __init__(
         self,
@@ -559,6 +565,42 @@ def get_or_create_checkpoint(
     def remove_checkpoint(self, checkpoint_id: str, conn: Any | None = None) -> None:
         """Removes a checkpoint by ID"""
 
+    #
+    # Checkpoint Events
+    #
+
+    @abstractmethod
+    def log_checkpoint_event(  # noqa: PLR0913
+        self,
+        job_id: str,
+        event_type: "CheckpointEventType",
+        step_type: "CheckpointStepType",
+        run_group_id: str | None = None,
+        udf_name: str | None = None,
+        dataset_name: str | None = None,
+        checkpoint_hash: str | None = None,
+        hash_partial: str | None = None,
+        hash_input: str | None = None,
+        hash_output: str | None = None,
+        rows_input: int | None = None,
+        rows_processed: int | None = None,
+        rows_generated: int | None = None,
+        rows_reused: int | None = None,
+        rerun_from_job_id: str | None = None,
+        details: dict | None = None,
+        conn: Any | None = None,
+    ) -> "CheckpointEvent":
+        """Log a checkpoint event."""
+
+    @abstractmethod
+    def get_checkpoint_events(
+        self,
+        job_id: str | None = None,
+        run_group_id: str | None = None,
+        conn: Any | None = None,
+    ) -> Iterator["CheckpointEvent"]:
+        """Get checkpoint events, optionally filtered by job_id or run_group_id."""
+
     #
     # Dataset Version Jobs (many-to-many)
     #
@@ -604,6 +646,7 @@ class AbstractDBMetastore(AbstractMetastore):
     DATASET_VERSION_JOBS_TABLE = "dataset_version_jobs"
     JOBS_TABLE = "jobs"
     CHECKPOINTS_TABLE = "checkpoints"
+    CHECKPOINT_EVENTS_TABLE = "checkpoint_events"
 
     db: "DatabaseEngine"
 
@@ -2114,6 +2157,73 @@ def _checkpoints(self) -> "Table":
 
     @abstractmethod
     def _checkpoints_insert(self) -> "Insert": ...
+
+    #
+    # Checkpoint Events
+    #
+
+    @staticmethod
+    def _checkpoint_events_columns() -> "list[SchemaItem]":
+        return [
+            Column(
+                "id",
+                Text,
+                default=uuid4,
+                primary_key=True,
+                nullable=False,
+            ),
+            Column("job_id", Text, nullable=False),
+            Column("run_group_id", Text, nullable=True),
+            Column("timestamp", DateTime(timezone=True), nullable=False),
+            Column("event_type", Text, nullable=False),
+            Column("step_type", Text, nullable=False),
+            Column("udf_name", Text, nullable=True),
+            Column("dataset_name", Text, nullable=True),
+            Column("checkpoint_hash", Text, nullable=True),
+            Column("hash_partial", Text, nullable=True),
+            Column("hash_input", Text, nullable=True),
+            Column("hash_output", Text, nullable=True),
+            Column("rows_input", BigInteger, nullable=True),
+            Column("rows_processed", BigInteger, nullable=True),
+            Column("rows_generated", BigInteger, nullable=True),
+            Column("rows_reused", BigInteger, nullable=True),
+            Column("rerun_from_job_id", Text, nullable=True),
+            Column("details", JSON, nullable=True),
+            Index("dc_idx_ce_job_id", "job_id"),
+            Index("dc_idx_ce_run_group_id", "run_group_id"),
+        ]
+
+    @cached_property
+    def _checkpoint_events_fields(self) -> list[str]:
+        return [
+            c.name  # type: ignore[attr-defined]
+            for c in self._checkpoint_events_columns()
+            if isinstance(c, Column)
+        ]
+
+    @cached_property
+    def _checkpoint_events(self) -> "Table":
+        return Table(
+            self.CHECKPOINT_EVENTS_TABLE,
+            self.db.metadata,
+            *self._checkpoint_events_columns(),
+        )
+
+    @abstractmethod
+    def _checkpoint_events_insert(self) -> "Insert": ...
+
+    def _checkpoint_events_select(self, *columns) -> "Select":
+        if not columns:
+            return self._checkpoint_events.select()
+        return select(*columns)
+
+    def _checkpoint_events_query(self):
+        return self._checkpoint_events_select(
+            *[
+                getattr(self._checkpoint_events.c, f)
+                for f in self._checkpoint_events_fields
+            ]
+        )
+
     @classmethod
     def _dataset_version_jobs_columns(cls) -> "list[SchemaItem]":
         """Junction table for dataset versions and jobs many-to-many relationship."""
@@ -2245,6 +2355,92 @@ def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
             return None
         return self.checkpoint_class.parse(*rows[0])
 
+    def log_checkpoint_event(  # noqa: PLR0913
+        self,
+        job_id: str,
+        event_type: CheckpointEventType,
+        step_type: CheckpointStepType,
+        run_group_id: str | None = None,
+        udf_name: str | None = None,
+        dataset_name: str | None = None,
+        checkpoint_hash: str | None = None,
+        hash_partial: str | None = None,
+        hash_input: str | None = None,
+        hash_output: str | None = None,
+        rows_input: int | None = None,
+        rows_processed: int | None = None,
+        rows_generated: int | None = None,
+        rows_reused: int | None = None,
+        rerun_from_job_id: str | None = None,
+        details: dict | None = None,
+        conn: Any | None = None,
+    ) -> CheckpointEvent:
+        """Log a checkpoint event."""
+        event_id = str(uuid4())
+        timestamp = datetime.now(timezone.utc)
+
+        query = self._checkpoint_events_insert().values(
+            id=event_id,
+            job_id=job_id,
+            run_group_id=run_group_id,
+            timestamp=timestamp,
+            event_type=event_type.value,
+            step_type=step_type.value,
+            udf_name=udf_name,
+            dataset_name=dataset_name,
+            checkpoint_hash=checkpoint_hash,
+            hash_partial=hash_partial,
+            hash_input=hash_input,
+            hash_output=hash_output,
+            rows_input=rows_input,
+            rows_processed=rows_processed,
+            rows_generated=rows_generated,
+            rows_reused=rows_reused,
+            rerun_from_job_id=rerun_from_job_id,
+            details=details,
+        )
+        self.db.execute(query, conn=conn)
+
+        return CheckpointEvent(
+            id=event_id,
+            job_id=job_id,
+            run_group_id=run_group_id,
+            timestamp=timestamp,
+            event_type=event_type,
+            step_type=step_type,
+            udf_name=udf_name,
+            dataset_name=dataset_name,
+            checkpoint_hash=checkpoint_hash,
+            hash_partial=hash_partial,
+            hash_input=hash_input,
+            hash_output=hash_output,
+            rows_input=rows_input,
+            rows_processed=rows_processed,
+            rows_generated=rows_generated,
+            rows_reused=rows_reused,
+            rerun_from_job_id=rerun_from_job_id,
+            details=details,
+        )
+
+    def get_checkpoint_events(
+        self,
+        job_id: str | None = None,
+        run_group_id: str | None = None,
+        conn: Any | None = None,
+    ) -> Iterator[CheckpointEvent]:
+        """Get checkpoint events, optionally filtered by job_id or run_group_id."""
+        query = self._checkpoint_events_query()
+
+        if job_id is not None:
+            query = query.where(self._checkpoint_events.c.job_id == job_id)
+        if run_group_id is not None:
+            query = query.where(self._checkpoint_events.c.run_group_id == run_group_id)
+
+        query = query.order_by(self._checkpoint_events.c.timestamp)
+        rows = list(self.db.execute(query, conn=conn))
+
+        yield from [self.checkpoint_event_class.parse(*r) for r in rows]
+
     def link_dataset_version_to_job(
         self,
         dataset_version_id: int,
diff --git a/src/datachain/data_storage/sqlite.py b/src/datachain/data_storage/sqlite.py
index 8813f36a4..18e3fda36 100644
--- a/src/datachain/data_storage/sqlite.py
+++ b/src/datachain/data_storage/sqlite.py
@@ -527,6 +527,7 @@ def _metastore_tables(self) -> list[Table]:
             self._datasets_dependencies,
             self._jobs,
             self._checkpoints,
+            self._checkpoint_events,
             self._dataset_version_jobs,
         ]
 
@@ -674,6 +675,12 @@ def _jobs_insert(self) -> "Insert":
     def _checkpoints_insert(self) -> "Insert":
         return sqlite.insert(self._checkpoints)
 
+    #
+    # Checkpoint Events
+    #
+    def _checkpoint_events_insert(self) -> "Insert":
+        return sqlite.insert(self._checkpoint_events)
+
     def _dataset_version_jobs_insert(self) -> "Insert":
         return sqlite.insert(self._dataset_version_jobs)
 
diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py
index 7566ff0c2..7aa5fa245 100644
--- a/src/datachain/dataset.py
+++ b/src/datachain/dataset.py
@@ -74,6 +74,18 @@ def create_dataset_uri(
     return uri
 
 
+def create_dataset_full_name(
+    namespace: str, project: str, name: str, version: str
+) -> str:
+    """
+    Creates a full dataset name including namespace, project and version.
+    Example:
+        Input: dev, clothes, zalando, 3.0.1
+        Output: dev.clothes.zalando@3.0.1
+    """
+    return f"{namespace}.{project}.{name}@{version}"
+
+
 def parse_dataset_name(name: str) -> tuple[str | None, str | None, str]:
     """Parses dataset name and returns namespace, project and name"""
     if not name:
diff --git a/src/datachain/lib/dc/datachain.py b/src/datachain/lib/dc/datachain.py
index ba558bcf4..c8ec591de 100644
--- a/src/datachain/lib/dc/datachain.py
+++ b/src/datachain/lib/dc/datachain.py
@@ -22,7 +22,8 @@
 from tqdm import tqdm
 
 from datachain import json, semver
-from datachain.dataset import DatasetRecord
+from datachain.checkpoint_event import CheckpointEventType, CheckpointStepType
+from datachain.dataset import DatasetRecord, create_dataset_full_name
 from datachain.delta import delta_disabled
 from datachain.error import (
     JobAncestryDepthExceededError,
@@ -674,6 +675,20 @@ def save(  # type: ignore[override]
             )
         )
 
+        # Log checkpoint event for new dataset save
+        assert result.version is not None
+        full_dataset_name = create_dataset_full_name(
+            namespace_name, project_name, name, result.version
+        )
+        catalog.metastore.log_checkpoint_event(
+            job_id=self.job.id,
+            event_type=CheckpointEventType.DATASET_SAVE_COMPLETED,
+            step_type=CheckpointStepType.DATASET_SAVE,
+            run_group_id=self.job.run_group_id,
+            dataset_name=full_dataset_name,
+            checkpoint_hash=_hash,
+        )
+
         if checkpoints_enabled():
             catalog.metastore.get_or_create_checkpoint(self.job.id, _hash)
         return result
@@ -773,6 +788,20 @@ def _resolve_checkpoint(
                     is_creator=False,
                 )
 
+            # Log checkpoint event
+            full_dataset_name = create_dataset_full_name(
+                project.namespace.name, project.name, name, dataset_version.version
+            )
+            metastore.log_checkpoint_event(
+                job_id=self.job.id,
+                event_type=CheckpointEventType.DATASET_SAVE_SKIPPED,
+                step_type=CheckpointStepType.DATASET_SAVE,
+                run_group_id=self.job.run_group_id,
+                dataset_name=full_dataset_name,
+                checkpoint_hash=job_hash,
+                rerun_from_job_id=self.job.rerun_from_job_id,
+            )
+
             return chain
 
         return None
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index b5ba7f2ce..f60b441a6 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -30,6 +30,10 @@
 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
 from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.checkpoint import Checkpoint
+from datachain.checkpoint_event import (
+    CheckpointEventType,
+    CheckpointStepType,
+)
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,
@@ -816,6 +820,56 @@ def _run_group_id_short(self) -> str:
         """Get short run_group_id for logging."""
         return self.job.run_group_id[:8] if self.job.run_group_id else "none"
"none" + @property + @abstractmethod + def _step_type(self) -> CheckpointStepType: + """Get the step type for checkpoint events.""" + + def _log_event( + self, + event_type: CheckpointEventType, + checkpoint_hash: str | None = None, + hash_partial: str | None = None, + hash_input: str | None = None, + hash_output: str | None = None, + rows_input: int | None = None, + rows_processed: int | None = None, + rows_generated: int | None = None, + rows_reused: int | None = None, + rerun_from_job_id: str | None = None, + details: dict | None = None, + ) -> None: + """Log a checkpoint event and emit a log message.""" + self.metastore.log_checkpoint_event( + job_id=self.job.id, + event_type=event_type, + step_type=self._step_type, + run_group_id=self.job.run_group_id, + udf_name=self._udf_name, + checkpoint_hash=checkpoint_hash, + hash_partial=hash_partial, + hash_input=hash_input, + hash_output=hash_output, + rows_input=rows_input, + rows_processed=rows_processed, + rows_generated=rows_generated, + rows_reused=rows_reused, + rerun_from_job_id=rerun_from_job_id, + details=details, + ) + logger.info( + "UDF(%s) [job=%s run_group=%s]: %s - " + "input=%s, processed=%s, generated=%s, reused=%s", + self._udf_name, + self._job_id_short, + self._run_group_id_short, + event_type.value, + rows_input, + rows_processed, + rows_generated, + rows_reused, + ) + def _find_udf_checkpoint( self, _hash: str, partial: bool = False ) -> Checkpoint | None: @@ -1002,6 +1056,19 @@ def _skip_udf( checkpoint.hash[:8], ) + # Log checkpoint event with row counts + rows_input = self.warehouse.table_rows_count(input_table) + rows_reused = self.warehouse.table_rows_count(output_table) + self._log_event( + CheckpointEventType.UDF_SKIPPED, + checkpoint_hash=checkpoint.hash, + rerun_from_job_id=checkpoint.job_id, + rows_input=rows_input, + rows_processed=0, + rows_generated=0, + rows_reused=rows_reused, + ) + return output_table, input_table def _run_from_scratch( @@ -1056,6 +1123,20 @@ def _run_from_scratch( hash_output[:8], ) + # Log checkpoint event with row counts + rows_input = self.warehouse.table_rows_count(input_table) + rows_generated = self.warehouse.table_rows_count(output_table) + self._log_event( + CheckpointEventType.UDF_FROM_SCRATCH, + checkpoint_hash=hash_output, + hash_input=hash_input, + hash_output=hash_output, + rows_input=rows_input, + rows_processed=rows_input, + rows_generated=rows_generated, + rows_reused=0, + ) + return output_table, input_table def _continue_udf( @@ -1138,6 +1219,10 @@ def _continue_udf( incomplete_input_ids, ) + # Count rows before populating with new rows + rows_reused = self.warehouse.table_rows_count(partial_table) + rows_processed = self.warehouse.query_count(unprocessed_query) + self.populate_udf_output_table(partial_table, unprocessed_query) output_table = self.warehouse.rename_table( @@ -1154,6 +1239,23 @@ def _continue_udf( hash_output[:8], ) + # Log checkpoint event with row counts + rows_input = self.warehouse.table_rows_count(input_table) + total_output = self.warehouse.table_rows_count(output_table) + rows_generated = total_output - rows_reused + self._log_event( + CheckpointEventType.UDF_CONTINUED, + checkpoint_hash=hash_output, + hash_partial=checkpoint.hash, + hash_input=hash_input, + hash_output=hash_output, + rerun_from_job_id=checkpoint.job_id, + rows_input=rows_input, + rows_processed=rows_processed, + rows_generated=rows_generated, + rows_reused=rows_reused, + ) + return output_table, input_table @abstractmethod @@ -1232,6 +1334,10 @@ class UDFSignal(UDFStep): 
     min_task_size: int | None = None
     batch_size: int | None = None
 
+    @property
+    def _step_type(self) -> CheckpointStepType:
+        return CheckpointStepType.UDF_MAP
+
     def processed_input_ids_query(self, partial_table: "Table"):
         """
         For mappers (1:1 mapping): returns sys__id from partial table.
@@ -1338,6 +1444,10 @@ class RowGenerator(UDFStep):
     min_task_size: int | None = None
     batch_size: int | None = None
 
+    @property
+    def _step_type(self) -> CheckpointStepType:
+        return CheckpointStepType.UDF_GEN
+
     def processed_input_ids_query(self, partial_table: "Table"):
         """
         For generators (1:N mapping): returns distinct sys__input_id from partial table.
diff --git a/tests/func/checkpoints/test_checkpoint_events.py b/tests/func/checkpoints/test_checkpoint_events.py
new file mode 100644
index 000000000..36f67b7ce
--- /dev/null
+++ b/tests/func/checkpoints/test_checkpoint_events.py
@@ -0,0 +1,359 @@
+from collections.abc import Iterator
+
+import pytest
+
+import datachain as dc
+from datachain.checkpoint_event import CheckpointEventType, CheckpointStepType
+from datachain.dataset import create_dataset_full_name
+from tests.utils import reset_session_job_state
+
+
+@pytest.fixture(autouse=True)
+def mock_is_script_run(monkeypatch):
+    monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True)
+
+
+@pytest.fixture
+def nums_dataset(test_session):
+    return dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums")
+
+
+def get_events(metastore, job_id):
+    return list(metastore.get_checkpoint_events(job_id=job_id))
+
+
+def get_udf_events(metastore, job_id):
+    return [
+        e
+        for e in get_events(metastore, job_id)
+        if e.event_type
+        in (
+            CheckpointEventType.UDF_SKIPPED,
+            CheckpointEventType.UDF_CONTINUED,
+            CheckpointEventType.UDF_FROM_SCRATCH,
+        )
+    ]
+
+
+def get_dataset_events(metastore, job_id):
+    return [
+        e
+        for e in get_events(metastore, job_id)
+        if e.event_type
+        in (
+            CheckpointEventType.DATASET_SAVE_SKIPPED,
+            CheckpointEventType.DATASET_SAVE_COMPLETED,
+        )
+    ]
+
+
+def test_map_from_scratch_event(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+
+    def double(num) -> int:
+        return num * 2
+
+    reset_session_job_state()
+    dc.read_dataset("nums", session=test_session).map(doubled=double, output=int).save(
+        "doubled"
+    )
+    job_id = test_session.get_or_create_job().id
+
+    events = get_udf_events(metastore, job_id)
+    assert len(events) == 1
+
+    map_event = events[0]
+    assert map_event.event_type == CheckpointEventType.UDF_FROM_SCRATCH
+    assert map_event.step_type == CheckpointStepType.UDF_MAP
+    assert map_event.udf_name == "double"
+    assert map_event.rows_input == 6
+    assert map_event.rows_processed == 6
+    assert map_event.rows_generated == 6
+    assert map_event.rows_reused == 0
+    assert map_event.rerun_from_job_id is None
+    assert map_event.hash_partial is None
+
+
+def test_gen_from_scratch_event(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+
+    def duplicate(num) -> Iterator[int]:
+        yield num
+        yield num
+
+    reset_session_job_state()
+    dc.read_dataset("nums", session=test_session).gen(dup=duplicate, output=int).save(
+        "duplicated"
+    )
+    job_id = test_session.get_or_create_job().id
+
+    events = get_udf_events(metastore, job_id)
+    gen_event = next(e for e in events if e.udf_name == "duplicate")
+
+    assert gen_event.event_type == CheckpointEventType.UDF_FROM_SCRATCH
+    assert gen_event.step_type == CheckpointStepType.UDF_GEN
+    assert gen_event.rows_input == 6
+    assert gen_event.rows_processed == 6
+    assert gen_event.rows_generated == 12
+    assert gen_event.rows_reused == 0
+
+
+def test_map_skipped_event(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+
+    def double(num) -> int:
+        return num * 2
+
+    chain = dc.read_dataset("nums", session=test_session).map(
+        doubled=double, output=int
+    )
+
+    reset_session_job_state()
+    chain.save("doubled")
+    first_job_id = test_session.get_or_create_job().id
+
+    reset_session_job_state()
+    chain.save("doubled2")
+    second_job_id = test_session.get_or_create_job().id
+
+    events = get_udf_events(metastore, second_job_id)
+    assert len(events) == 1
+
+    map_event = events[0]
+    assert map_event.event_type == CheckpointEventType.UDF_SKIPPED
+    assert map_event.udf_name == "double"
+    assert map_event.rows_input == 6
+    assert map_event.rows_processed == 0
+    assert map_event.rows_generated == 0
+    assert map_event.rows_reused == 6
+    assert map_event.rerun_from_job_id == first_job_id
+    assert map_event.hash_partial is None
+
+
+def test_gen_skipped_event(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+
+    def duplicate(num) -> Iterator[int]:
+        yield num
+        yield num
+
+    chain = dc.read_dataset("nums", session=test_session).gen(dup=duplicate, output=int)
+
+    reset_session_job_state()
+    chain.save("duplicated")
+    first_job_id = test_session.get_or_create_job().id
+
+    reset_session_job_state()
+    chain.save("duplicated2")
+    second_job_id = test_session.get_or_create_job().id
+
+    events = get_udf_events(metastore, second_job_id)
+    gen_event = next(e for e in events if e.udf_name == "duplicate")
+
+    assert gen_event.event_type == CheckpointEventType.UDF_SKIPPED
+    assert gen_event.rows_input == 6
+    assert gen_event.rows_processed == 0
+    assert gen_event.rows_generated == 0
+    assert gen_event.rows_reused == 12
+    assert gen_event.rerun_from_job_id == first_job_id
+
+
+def test_map_continued_event(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+    processed = []
+
+    def buggy_double(num) -> int:
+        if len(processed) >= 3:
+            raise Exception("Simulated failure")
+        processed.append(num)
+        return num * 2
+
+    chain = dc.read_dataset("nums", session=test_session).map(
+        doubled=buggy_double, output=int
+    )
+
+    reset_session_job_state()
+    with pytest.raises(Exception, match="Simulated failure"):
+        chain.save("doubled")
+    first_job_id = test_session.get_or_create_job().id
+
+    reset_session_job_state()
+    processed.clear()
+
+    def fixed_double(num) -> int:
+        processed.append(num)
+        return num * 2
+
+    dc.read_dataset("nums", session=test_session).map(
+        doubled=fixed_double, output=int
+    ).save("doubled")
+    second_job_id = test_session.get_or_create_job().id
+
+    events = get_udf_events(metastore, second_job_id)
+    map_event = next(e for e in events if e.udf_name == "fixed_double")
+
+    assert map_event.event_type == CheckpointEventType.UDF_CONTINUED
+    assert map_event.rows_input == 6
+    assert map_event.rows_reused == 3
+    assert map_event.rows_processed == 3
+    assert map_event.rows_generated == 3
+    assert map_event.rerun_from_job_id == first_job_id
+    assert map_event.hash_partial is not None
+
+
+def test_gen_continued_event(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+    processed = []
+
+    def buggy_gen(num) -> Iterator[int]:
+        if len(processed) >= 2:
+            raise Exception("Simulated failure")
+        processed.append(num)
+        yield num
+        yield num * 10
+
+    chain = dc.read_dataset("nums", session=test_session).gen(
+        result=buggy_gen, output=int
+    )
+
+    reset_session_job_state()
+    with pytest.raises(Exception, match="Simulated failure"):
+        chain.save("results")
+    first_job_id = test_session.get_or_create_job().id
+
+    reset_session_job_state()
+    processed.clear()
+
+    def fixed_gen(num) -> Iterator[int]:
+        processed.append(num)
+        yield num
+        yield num * 10
+
+    dc.read_dataset("nums", session=test_session).gen(
+        result=fixed_gen, output=int
+    ).save("results")
+    second_job_id = test_session.get_or_create_job().id
+
+    events = get_udf_events(metastore, second_job_id)
+    gen_event = next(e for e in events if e.udf_name == "fixed_gen")
+
+    assert gen_event.event_type == CheckpointEventType.UDF_CONTINUED
+    assert gen_event.rows_input == 6
+    assert gen_event.rows_reused == 4
+    assert gen_event.rows_processed == 4
+    assert gen_event.rows_generated == 8
+    assert gen_event.rerun_from_job_id == first_job_id
+    assert gen_event.hash_partial is not None
+
+
+def test_dataset_save_completed_event(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+
+    reset_session_job_state()
+    dc.read_dataset("nums", session=test_session).save("nums_copy")
+    job_id = test_session.get_or_create_job().id
+
+    events = get_dataset_events(metastore, job_id)
+
+    assert len(events) == 1
+    event = events[0]
+
+    expected_name = create_dataset_full_name(
+        metastore.default_namespace_name,
+        metastore.default_project_name,
+        "nums_copy",
+        "1.0.0",
+    )
+    assert event.event_type == CheckpointEventType.DATASET_SAVE_COMPLETED
+    assert event.step_type == CheckpointStepType.DATASET_SAVE
+    assert event.dataset_name == expected_name
+    assert event.checkpoint_hash is not None
+
+
+def test_dataset_save_skipped_event(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+    chain = dc.read_dataset("nums", session=test_session)
+
+    reset_session_job_state()
+    chain.save("nums_copy")
+    first_job_id = test_session.get_or_create_job().id
+
+    first_events = get_dataset_events(metastore, first_job_id)
+    assert len(first_events) == 1
+    assert first_events[0].event_type == CheckpointEventType.DATASET_SAVE_COMPLETED
+
+    reset_session_job_state()
+    chain.save("nums_copy")
+    second_job_id = test_session.get_or_create_job().id
+
+    second_events = get_dataset_events(metastore, second_job_id)
+
+    assert len(second_events) == 1
+    event = second_events[0]
+
+    expected_name = create_dataset_full_name(
+        metastore.default_namespace_name,
+        metastore.default_project_name,
+        "nums_copy",
+        "1.0.0",
+    )
+    assert event.event_type == CheckpointEventType.DATASET_SAVE_SKIPPED
+    assert event.step_type == CheckpointStepType.DATASET_SAVE
+    assert event.dataset_name == expected_name
+    assert event.rerun_from_job_id is not None
+
+
+def test_events_by_run_group(test_session, monkeypatch, nums_dataset):
+    metastore = test_session.catalog.metastore
+
+    def double(num) -> int:
+        return num * 2
+
+    reset_session_job_state()
+    dc.read_dataset("nums", session=test_session).map(doubled=double, output=int).save(
+        "doubled"
+    )
+    first_job = test_session.get_or_create_job()
+
+    reset_session_job_state()
+    second_job_id = metastore.create_job(
+        "test-job",
+        "echo 1",
+        rerun_from_job_id=first_job.id,
+        run_group_id=first_job.run_group_id,
+    )
+    monkeypatch.setenv("DATACHAIN_JOB_ID", second_job_id)
+
+    dc.read_dataset("nums", session=test_session).map(doubled=double, output=int).save(
+        "doubled2"
+    )
+
+    run_group_events = list(
+        metastore.get_checkpoint_events(run_group_id=first_job.run_group_id)
+    )
+
+    job_ids = {e.job_id for e in run_group_events}
+    assert first_job.id in job_ids
+    assert second_job_id in job_ids
+
+
+def test_hash_fields_populated(test_session, nums_dataset):
+    metastore = test_session.catalog.metastore
+
+    def double(num) -> int:
+        return num * 2
+
+    reset_session_job_state()
+    dc.read_dataset("nums", session=test_session).map(doubled=double, output=int).save(
+        "doubled"
+    )
+    job_id = test_session.get_or_create_job().id
+
+    events = get_udf_events(metastore, job_id)
+
+    for event in events:
+        assert event.checkpoint_hash is not None
+        assert event.hash_input is not None
+        assert event.hash_output is not None
+        if event.event_type == CheckpointEventType.UDF_FROM_SCRATCH:
+            assert event.hash_partial is None
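
Usage sketch (not part of the patch): a minimal example of how the checkpoint-event API added above might be exercised from a session. `log_checkpoint_event`, `get_checkpoint_events`, `CheckpointEventType`, and `CheckpointStepType` are defined in this diff, and the metastore and job are reached the same way the tests do (`session.catalog.metastore`, `session.get_or_create_job()`); the `session` object and the dataset name below are hypothetical placeholders.

from datachain.checkpoint_event import CheckpointEventType, CheckpointStepType

# `session` is a hypothetical datachain Session set up elsewhere.
metastore = session.catalog.metastore
job = session.get_or_create_job()

# Record that a dataset save completed for this job.
metastore.log_checkpoint_event(
    job_id=job.id,
    event_type=CheckpointEventType.DATASET_SAVE_COMPLETED,
    step_type=CheckpointStepType.DATASET_SAVE,
    run_group_id=job.run_group_id,
    dataset_name="dev.clothes.zalando@3.0.1",  # hypothetical full dataset name
)

# Read back all events recorded for this job, ordered by timestamp.
for event in metastore.get_checkpoint_events(job_id=job.id):
    print(event.event_type.value, event.step_type.value, event.dataset_name)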