diff --git a/src/datachain/data_storage/metastore.py b/src/datachain/data_storage/metastore.py
index 88cd9b922..7095c1dfe 100644
--- a/src/datachain/data_storage/metastore.py
+++ b/src/datachain/data_storage/metastore.py
@@ -466,6 +466,8 @@ def create_job(
         python_version: str | None = None,
         params: dict[str, str] | None = None,
         parent_job_id: str | None = None,
+        rerun_from_job_id: str | None = None,
+        run_group_id: str | None = None,
     ) -> str:
         """
         Creates a new job.
@@ -1835,7 +1837,11 @@ def _jobs_columns() -> "list[SchemaItem]":
             Column("params", JSON, nullable=False),
             Column("metrics", JSON, nullable=False),
             Column("parent_job_id", Text, nullable=True),
+            Column("rerun_from_job_id", Text, nullable=True),
+            Column("run_group_id", Text, nullable=True),
             Index("idx_jobs_parent_job_id", "parent_job_id"),
+            Index("idx_jobs_rerun_from_job_id", "rerun_from_job_id"),
+            Index("idx_jobs_run_group_id", "run_group_id"),
         ]

     @cached_property
@@ -1896,6 +1902,8 @@ def create_job(
         python_version: str | None = None,
         params: dict[str, str] | None = None,
         parent_job_id: str | None = None,
+        rerun_from_job_id: str | None = None,
+        run_group_id: str | None = None,
         conn: Any = None,
     ) -> str:
         """
@@ -1903,6 +1911,20 @@
         Returns the job id.
         """
         job_id = str(uuid4())
+
+        # Validate run_group_id and rerun_from_job_id consistency
+        if rerun_from_job_id:
+            # Rerun job: run_group_id must be provided by caller
+            assert run_group_id is not None, (
+                "run_group_id must be provided when rerun_from_job_id is set"
+            )
+        else:
+            # First job: run_group_id should not be provided (we set it here)
+            assert run_group_id is None, (
+                "run_group_id should not be provided when rerun_from_job_id is not set"
+            )
+            run_group_id = job_id
+
         self.db.execute(
             self._jobs_insert().values(
                 id=job_id,
@@ -1918,6 +1940,8 @@
                 params=json.dumps(params or {}),
                 metrics=json.dumps({}),
                 parent_job_id=parent_job_id,
+                rerun_from_job_id=rerun_from_job_id,
+                run_group_id=run_group_id,
             ),
             conn=conn,
         )
@@ -2191,14 +2215,16 @@ def link_dataset_version_to_job(
         self.db.execute(update_query, conn=conn)

     def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
-        # Use recursive CTE to walk up the parent chain
-        # Format: WITH RECURSIVE ancestors(id, parent_job_id, depth) AS (...)
+        # Use recursive CTE to walk up the rerun chain
+        # Format: WITH RECURSIVE ancestors(id, rerun_from_job_id, run_group_id,
+        #   depth) AS (...)
         # Include depth tracking to prevent infinite recursion in case of
         # circular dependencies
         ancestors_cte = (
             self._jobs_select(
                 self._jobs.c.id.label("id"),
-                self._jobs.c.parent_job_id.label("parent_job_id"),
+                self._jobs.c.rerun_from_job_id.label("rerun_from_job_id"),
+                self._jobs.c.run_group_id.label("run_group_id"),
                 literal(0).label("depth"),
             )
             .where(self._jobs.c.id == job_id)
@@ -2206,20 +2232,30 @@ def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
         )

         # Recursive part: join with parent jobs, incrementing depth and checking limit
+        # Also ensure we only traverse jobs within the same run_group_id for safety
         ancestors_recursive = ancestors_cte.union_all(
             self._jobs_select(
                 self._jobs.c.id.label("id"),
-                self._jobs.c.parent_job_id.label("parent_job_id"),
+                self._jobs.c.rerun_from_job_id.label("rerun_from_job_id"),
+                self._jobs.c.run_group_id.label("run_group_id"),
                 (ancestors_cte.c.depth + 1).label("depth"),
             ).select_from(
                 self._jobs.join(
                     ancestors_cte,
                     (
                         self._jobs.c.id
-                        == cast(ancestors_cte.c.parent_job_id, self._jobs.c.id.type)
+                        == cast(ancestors_cte.c.rerun_from_job_id, self._jobs.c.id.type)
                     )
-                    & (ancestors_cte.c.parent_job_id.isnot(None))  # Stop at root jobs
-                    & (ancestors_cte.c.depth < JOB_ANCESTRY_MAX_DEPTH),
+                    & (
+                        ancestors_cte.c.rerun_from_job_id.isnot(None)
+                    )  # Stop at root jobs
+                    & (ancestors_cte.c.depth < JOB_ANCESTRY_MAX_DEPTH)
+                    & (
+                        self._jobs.c.run_group_id
+                        == cast(
+                            ancestors_cte.c.run_group_id, self._jobs.c.run_group_id.type
+                        )
+                    ),  # Safety: only traverse within same run group
                 )
             )
         )
diff --git a/src/datachain/job.py b/src/datachain/job.py
index 685d2e191..37fa4c7bb 100644
--- a/src/datachain/job.py
+++ b/src/datachain/job.py
@@ -24,6 +24,8 @@ class Job:
     error_message: str = ""
     error_stack: str = ""
     parent_job_id: str | None = None
+    rerun_from_job_id: str | None = None
+    run_group_id: str | None = None

     @classmethod
     def parse(
@@ -42,6 +44,8 @@ def parse(
         params: str,
         metrics: str,
         parent_job_id: str | None,
+        rerun_from_job_id: str | None,
+        run_group_id: str | None,
     ) -> "Job":
         return cls(
             str(id),
@@ -58,4 +62,6 @@ def parse(
             error_message,
             error_stack,
             str(parent_job_id) if parent_job_id else None,
+            str(rerun_from_job_id) if rerun_from_job_id else None,
+            str(run_group_id) if run_group_id else None,
         )
diff --git a/src/datachain/lib/dc/datachain.py b/src/datachain/lib/dc/datachain.py
index 233614d8e..d467c53fd 100644
--- a/src/datachain/lib/dc/datachain.py
+++ b/src/datachain/lib/dc/datachain.py
@@ -718,9 +718,9 @@ def _resolve_checkpoint(
         _hash = self._calculate_job_hash(job.id)

         if (
-            job.parent_job_id
+            job.rerun_from_job_id
             and not checkpoints_reset
-            and metastore.find_checkpoint(job.parent_job_id, _hash)
+            and metastore.find_checkpoint(job.rerun_from_job_id, _hash)
         ):
             # checkpoint found → find which dataset version to reuse

diff --git a/src/datachain/query/session.py b/src/datachain/query/session.py
index 98e0d4027..e32388771 100644
--- a/src/datachain/query/session.py
+++ b/src/datachain/query/session.py
@@ -154,7 +154,7 @@ def get_or_create_job(self) -> "Job":
         script = str(uuid4())
         python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

-        # try to find the parent job
+        # try to find the parent job for checkpoint/rerun chain
         parent = self.catalog.metastore.get_last_job_by_name(script)

         job_id = self.catalog.metastore.create_job(
@@ -163,7 +163,8 @@ def get_or_create_job(self) -> "Job":
             query_type=JobQueryType.PYTHON,
             status=JobStatus.RUNNING,
             python_version=python_version,
-            parent_job_id=parent.id if parent else None,
+            rerun_from_job_id=parent.id if parent else None,
+            run_group_id=parent.run_group_id if parent else None,
         )
         Session._CURRENT_JOB = self.catalog.metastore.get_job(job_id)
         Session._OWNS_JOB = True
diff --git a/tests/func/test_metastore.py b/tests/func/test_metastore.py
index 6d872a751..9e814a89a 100644
--- a/tests/func/test_metastore.py
+++ b/tests/func/test_metastore.py
@@ -912,13 +912,14 @@ def test_get_job_status(metastore):

 @pytest.mark.parametrize("depth", [0, 1, 2, 3, 5])
 def test_get_ancestor_job_ids(metastore, depth):
     """Test get_ancestor_job_ids with different hierarchy depths."""
-    # Create a chain of jobs with parent relationships
-    # depth=0: single job with no parent
-    # depth=1: job -> parent
-    # depth=2: job -> parent -> grandparent
+    # Create a chain of jobs with rerun relationships
+    # depth=0: single job with no rerun ancestor
+    # depth=1: job -> rerun_from
+    # depth=2: job -> rerun_from -> rerun_from
     job_ids = []
-    parent_id = None
+    rerun_from_id = None
+    group_id = None

     # Create jobs from root to leaf
     for i in range(depth + 1):
@@ -928,10 +929,14 @@ def test_get_ancestor_job_ids(metastore, depth):
             query_type=JobQueryType.PYTHON,
             status=JobStatus.CREATED,
             workers=1,
-            parent_job_id=parent_id,
+            rerun_from_job_id=rerun_from_id,
+            run_group_id=group_id,
         )
         job_ids.append(job_id)
-        parent_id = job_id
+        rerun_from_id = job_id
+        # First job sets the group_id
+        if group_id is None:
+            group_id = metastore.get_job(job_id).run_group_id

     # The last job is the leaf (youngest)
     leaf_job_id = job_ids[-1]
diff --git a/tests/unit/lib/test_checkpoints.py b/tests/unit/lib/test_checkpoints.py
index b8bf0427a..e28ab51bd 100644
--- a/tests/unit/lib/test_checkpoints.py
+++ b/tests/unit/lib/test_checkpoints.py
@@ -104,7 +104,8 @@ def test_checkpoints(
     chain.save("nums2")
     with pytest.raises(CustomMapperError):
         chain.map(new=mapper_fail).save("nums3")
-    first_job_id = test_session.get_or_create_job().id
+    first_job = test_session.get_or_create_job()
+    first_job_id = first_job.id

     catalog.get_dataset("nums1")
     catalog.get_dataset("nums2")
@@ -116,7 +117,12 @@ def test_checkpoints(
     if use_datachain_job_id_env:
         monkeypatch.setenv(
             "DATACHAIN_JOB_ID",
-            metastore.create_job("my-job", "echo 1;", parent_job_id=first_job_id),
+            metastore.create_job(
+                "my-job",
+                "echo 1;",
+                rerun_from_job_id=first_job_id,
+                run_group_id=first_job.run_group_id,
+            ),
         )

     chain.save("nums1")
diff --git a/tests/unit/test_job_management.py b/tests/unit/test_job_management.py
index dcab2112e..870b6e648 100644
--- a/tests/unit/test_job_management.py
+++ b/tests/unit/test_job_management.py
@@ -117,7 +117,7 @@ def test_get_or_create_links_to_parent(test_session, patch_argv, monkeypatch):
     session2 = Session(catalog=test_session.catalog)
     job2 = session2.get_or_create_job()

-    assert job2.parent_job_id == job1.id
+    assert job2.rerun_from_job_id == job1.id


 def test_nested_sessions_share_same_job(test_session, patch_argv, monkeypatch):
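
For reference, a minimal sketch of the rerun-chain lifecycle this patch introduces. It is illustrative only, not part of the diff, and assumes a `metastore` instance like the fixture used in tests/func/test_metastore.py; all names come from the patch itself.

# Illustrative sketch (not part of the patch), assuming a `metastore`
# instance such as the fixture in tests/func/test_metastore.py.

# First run: no rerun ancestor, so create_job seeds run_group_id with the
# new job's own id.
first_id = metastore.create_job("my-job", "echo 1;")
first = metastore.get_job(first_id)
assert first.run_group_id == first_id

# Rerun: the caller passes rerun_from_job_id together with the inherited
# run_group_id (create_job asserts they are provided as a pair).
rerun_id = metastore.create_job(
    "my-job",
    "echo 1;",
    rerun_from_job_id=first_id,
    run_group_id=first.run_group_id,
)

# The recursive CTE in get_ancestor_job_ids walks rerun_from_job_id links,
# staying within one run group and stopping at the root or at
# JOB_ANCESTRY_MAX_DEPTH.
assert first_id in metastore.get_ancestor_job_ids(rerun_id)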