Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/datachain/data_storage/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Table,
Text,
UniqueConstraint,
desc,
select,
)
from sqlalchemy.sql import func as f
Expand Down Expand Up @@ -399,6 +400,7 @@ def create_job(
workers: int = 1,
python_version: Optional[str] = None,
params: Optional[dict[str, str]] = None,
parent_job_id: Optional[str] = None,
) -> str:
"""
Creates a new job.
Expand Down Expand Up @@ -443,6 +445,10 @@ def get_job_status(self, job_id: str) -> Optional[JobStatus]:
def list_checkpoints(self, job_id: str, conn=None) -> Iterator["Checkpoint"]:
"""Returns all checkpoints related to some job"""

@abstractmethod
def get_last_checkpoint(self, job_id: str, conn=None) -> Optional[Checkpoint]:
"""Get last created checkpoint for some job."""

@abstractmethod
def get_checkpoint_by_id(self, checkpoint_id: str, conn=None) -> Checkpoint:
"""Gets single checkpoint by id"""
Expand Down Expand Up @@ -1548,6 +1554,7 @@ def _jobs_columns() -> "list[SchemaItem]":
Column("error_stack", Text, nullable=False, default=""),
Column("params", JSON, nullable=False),
Column("metrics", JSON, nullable=False),
Column("parent_job_id", Text, nullable=True),
]

@cached_property
Expand Down Expand Up @@ -1595,6 +1602,7 @@ def create_job(
workers: int = 1,
python_version: Optional[str] = None,
params: Optional[dict[str, str]] = None,
parent_job_id: Optional[str] = None,
conn: Optional[Any] = None,
) -> str:
"""
Expand All @@ -1616,6 +1624,7 @@ def create_job(
error_stack="",
params=json.dumps(params or {}),
metrics=json.dumps({}),
parent_job_id=parent_job_id,
),
conn=conn,
)
Expand Down Expand Up @@ -1770,7 +1779,7 @@ def create_checkpoint(
)
return self.get_checkpoint_by_id(checkpoint_id)

def list_checkpoints(self, job_id: str, conn=None) -> Iterator["Checkpoint"]:
def list_checkpoints(self, job_id: str, conn=None) -> Iterator[Checkpoint]:
"""List checkpoints by job id."""
query = self._checkpoints_query().where(self._checkpoints.c.job_id == job_id)
rows = list(self.db.execute(query, conn=conn))
Expand Down Expand Up @@ -1800,3 +1809,15 @@ def find_checkpoint(
if not rows:
return None
return self.checkpoint_class.parse(*rows[0])

def get_last_checkpoint(self, job_id: str, conn=None) -> Optional[Checkpoint]:
query = (
self._checkpoints_query()
.where(self._checkpoints.c.job_id == job_id)
.order_by(desc(self._checkpoints.c.created_at))
.limit(1)
)
rows = list(self.db.execute(query, conn=conn))
if not rows:
return None
return self.checkpoint_class.parse(*rows[0])
4 changes: 4 additions & 0 deletions src/datachain/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,7 @@ class OutdatedDatabaseSchemaError(DataChainError):

class CheckpointNotFoundError(NotFoundError):
pass


class JobNotFoundError(NotFoundError):
pass
3 changes: 3 additions & 0 deletions src/datachain/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Job:
python_version: Optional[str] = None
error_message: str = ""
error_stack: str = ""
parent_job_id: Optional[str] = None

@classmethod
def parse(
Expand All @@ -39,6 +40,7 @@ def parse(
error_stack: str,
params: str,
metrics: str,
parent_job_id: Optional[str],
) -> "Job":
return cls(
str(id),
Expand All @@ -54,4 +56,5 @@ def parse(
python_version,
error_message,
error_stack,
parent_job_id,
)
Loading