Skip processing chain if checkpoint exists #1362
Conversation
Reviewer's Guide

Adds a checkpointing mechanism in DataChain.save by computing job-specific chain hashes, skipping redundant executions when a matching checkpoint exists, and persisting checkpoints across save paths; extends the Job model and metastore schema to track parent_job relationships and handle missing jobs; and updates tests to use metastore-generated jobs and validate checkpoint behavior.
Deploying datachain-documentation with Cloudflare Pages

| Latest commit: | 79792ce |
| Status: | ✅ Deploy successful! |
| Preview URL: | https://77f3d125.datachain-documentation.pages.dev |
| Branch Preview URL: | https://ilongin-1350-checkpoints-ski.datachain-documentation.pages.dev |
Hey there - I've reviewed your changes and they look great!
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location> `src/datachain/lib/dc/datachain.py:584-593` </location>
<code_context>
            query=self._query.save(project=project, feature_schema=schema)
        )
+    def _calculate_job_hash(self, job_id: str) -> str:
+        """
+        Calculates hash of the job at the place of this chain's save method.
+        Hash is calculated using previous job checkpoint hash (if exists) and
+        adding hash of this chain to produce new hash.
+        """
+        last_checkpoint = max(
+            self.session.catalog.metastore.list_checkpoints(job_id),
+            key=lambda obj: obj.created_at,
+            default=None,
+        )
+
+        return hashlib.sha256(
+            bytes.fromhex(last_checkpoint.hash)
+            if last_checkpoint
</code_context>
<issue_to_address>
**issue:** The hash calculation logic may not combine the previous checkpoint and current chain hash as intended.
Currently, only one hash is used depending on whether last_checkpoint exists. To combine both, concatenate their byte representations before hashing.
</issue_to_address>
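A minimal sketch of what the combined computation could look like, assuming the chain's own hash is available as a hex string via a helper (`self._chain_hash()` below is a hypothetical placeholder, not the actual attribute in the PR):

```python
import hashlib

def _calculate_job_hash(self, job_id: str) -> str:
    last_checkpoint = max(
        self.session.catalog.metastore.list_checkpoints(job_id),
        key=lambda obj: obj.created_at,
        default=None,
    )
    prev_bytes = bytes.fromhex(last_checkpoint.hash) if last_checkpoint else b""
    chain_bytes = bytes.fromhex(self._chain_hash())  # hypothetical helper for this chain's hash
    # Concatenate the previous checkpoint hash (if any) with this chain's hash so
    # that both inputs influence the resulting checkpoint hash.
    return hashlib.sha256(prev_bytes + chain_bytes).hexdigest()
```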
### Comment 2
<location> `src/datachain/lib/dc/datachain.py:602` </location>
<code_context>
def save(  # type: ignore[override]
    self,
    name: str,
    version: Optional[str] = None,
    description: Optional[str] = None,
    attrs: Optional[list[str]] = None,
    update_version: Optional[str] = "patch",
    **kwargs,
) -> "DataChain":
    """Save to a Dataset. It returns the chain itself.

    Parameters:
        name: dataset name. This can be either a fully qualified name, including
            the namespace and project, or just a regular dataset name. In the latter
            case, the namespace and project will be taken from the settings
            (if specified) or from the default values otherwise.
        version: version of a dataset. If version is not specified and dataset
            already exists, version patch increment will happen e.g 1.2.1 -> 1.2.2.
        description: description of a dataset.
        attrs: attributes of a dataset. They can be without value, e.g "NLP",
            or with a value, e.g "location=US".
        update_version: which part of the dataset version to automatically increase.
            Available values: `major`, `minor` or `patch`. Default is `patch`.
    """
    catalog = self.session.catalog
    metastore = catalog.metastore
    job = None
    _hash = None
    job_id = os.getenv("DATACHAIN_JOB_ID")
    checkpoints_reset = env2bool("DATACHAIN_CHECKPOINTS_RESET")

    if version is not None:
        semver.validate(version)
    if update_version is not None and update_version not in [
        "patch",
        "major",
        "minor",
    ]:
        raise ValueError(
            "update_version can have one of the following values: major, minor or"
            " patch"
        )

    namespace_name, project_name, name = catalog.get_full_dataset_name(
        name,
        namespace_name=self._settings.namespace,
        project_name=self._settings.project,
    )
    try:
        project = self.session.catalog.metastore.get_project(
            project_name,
            namespace_name,
            create=is_studio(),
        )
    except ProjectNotFoundError as e:
        # not being able to create it as creation is not allowed
        raise ProjectCreateNotAllowedError("Creating project is not allowed") from e

    # checking checkpoints and skip re-calculation of the chain if checkpoint exist
    if job_id:
        job = metastore.get_job(job_id)  # type: ignore[arg-type]
        if not job:
            raise JobNotFoundError(f"Job with id {job_id} not found")
        _hash = self._calculate_job_hash(job.id)
        if (
            job.parent_job_id
            and not checkpoints_reset
            and metastore.find_checkpoint(job.parent_job_id, _hash)
        ):
            # if we find checkpoint with correct hash, we can skip chain calculation
            catalog.metastore.create_checkpoint(job.id, _hash)
            from .datasets import read_dataset

            return read_dataset(
                name, namespace=namespace_name, project=project_name, **kwargs
            )

    schema = self.signals_schema.clone_without_sys_signals().serialize()

    # Handle retry and delta functionality
    if self.delta and name:
        from datachain.delta import delta_retry_update

        # Delta chains must have delta_on defined (ensured by _as_delta method)
        assert self._delta_on is not None, "Delta chain must have delta_on defined"

        result_ds, dependencies, has_changes = delta_retry_update(
            self,
            namespace_name,
            project_name,
            name,
            on=self._delta_on,
            right_on=self._delta_result_on,
            compare=self._delta_compare,
            delta_retry=self._delta_retry,
        )
        if result_ds:
            return self._evolve(
                query=result_ds._query.save(
                    name=name,
                    version=version,
                    project=project,
                    feature_schema=schema,
                    dependencies=dependencies,
                    **kwargs,
                )
            )
        if not has_changes:
            # sources have not been changed so new version of resulting dataset
            # would be the same as previous one. To avoid duplicating exact
            # datasets, we won't create new version of it and we will return
            # current latest version instead.
            from .datasets import read_dataset

            if job:
                catalog.metastore.create_checkpoint(job.id, _hash)  # type: ignore[arg-type]
            return read_dataset(
                name, namespace=namespace_name, project=project_name, **kwargs
            )

    result = self._evolve(
        query=self._query.save(
            name=name,
            version=version,
            project=project,
            description=description,
            attrs=attrs,
            feature_schema=schema,
            update_version=update_version,
            **kwargs,
        )
    )
    if job:
        catalog.metastore.create_checkpoint(job.id, _hash)  # type: ignore[arg-type]
    return result
</code_context>
<issue_to_address>
**issue (code-quality):** Low code quality found in DataChain.save - 21% ([`low-code-quality`](https://docs.sourcery.ai/Reference/Default-Rules/comments/low-code-quality/))
<br/><details><summary>Explanation</summary>The quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.
How can you solve this?
It might be worth refactoring this function to make it shorter and more readable.
- Reduce the function length by extracting pieces of functionality out into
their own functions. This is the most important thing you can do - ideally a
function should be less than 10 lines.
- Reduce nesting, perhaps by introducing guard clauses to return early.
- Ensure that variables are tightly scoped, so that code using related concepts
sits together within the function rather than being scattered.</details>
</issue_to_address>
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1362      +/-   ##
==========================================
+ Coverage   87.61%   87.71%   +0.09%
==========================================
  Files         159      159
  Lines       14838    14883      +45
  Branches     2129     2135       +6
==========================================
+ Hits        13001    13054      +53
+ Misses       1351     1344       -7
+ Partials      486      485       -1
```

Flags with carried forward coverage won't be shown.
Hey there - I've reviewed your changes and they look great!
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location> `src/datachain/lib/dc/datachain.py:721` </location>
<code_context>
# current latest version instead.
from .datasets import read_dataset
+ if job:
+ catalog.metastore.create_checkpoint(job.id, _hash) # type: ignore[arg-type]
return read_dataset(
</code_context>
<issue_to_address>
**suggestion:** Checkpoint creation is repeated in multiple places; consider centralizing logic.
Centralizing checkpoint creation will help avoid duplication and reduce the risk of inconsistent behavior across the codebase.
Suggested implementation:
```python
if job:
    self.create_job_checkpoint(job, _hash, catalog)
```
```python
# Handle retry and delta functionality
# current latest version instead.
from .datasets import read_dataset
def create_job_checkpoint(self, job, _hash, catalog):
    """
    Centralized checkpoint creation for jobs.
    """
    if job:
        catalog.metastore.create_checkpoint(job.id, _hash)  # type: ignore[arg-type]
```
You should also:
- Replace any other occurrences of `catalog.metastore.create_checkpoint(job.id, _hash)` in this file with `self.create_job_checkpoint(job, _hash, catalog)`.
- If this logic is needed in other files, consider moving `create_job_checkpoint` to a shared utility module or the appropriate class base.
</issue_to_address>
### Comment 2
<location> `tests/unit/lib/test_checkpoints.py:18-27` </location>
<code_context>
+
+
+@pytest.mark.parametrize("reset_checkpoints", [True, False])
+def test_checkpoints(test_session, monkeypatch, nums_dataset, reset_checkpoints):
+    catalog = test_session.catalog
+
+    monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", reset_checkpoints)
+
+    # -------------- FIRST RUN -------------------
+    first_job_id = catalog.metastore.create_job("my-job", "echo 1;")
+    monkeypatch.setenv("DATACHAIN_JOB_ID", first_job_id)
+    dc.read_dataset("nums", session=test_session).save("nums1")
+    dc.read_dataset("nums", session=test_session).save("nums2")
+    with pytest.raises(DataChainError):
+        (
+            dc.read_dataset("nums", session=test_session)
+            .map(new=mapper_fail)
+            .save("nums3")
+        )
+    catalog.get_dataset("nums1")
+    catalog.get_dataset("nums2")
+    with pytest.raises(DatasetNotFoundError):
+        catalog.get_dataset("nums3")
+
</code_context>
<issue_to_address>
**suggestion (testing):** Consider adding assertions to verify that the skipped chain processing actually returns the correct dataset object.
Add assertions to confirm that the dataset returned after skipping chain processing matches the expected data and schema, ensuring the checkpoint mechanism's correctness.
Suggested implementation:
```python
ds1 = catalog.get_dataset("nums1")
ds2 = catalog.get_dataset("nums2")
# Assert that the data matches expected values
assert list(ds1.data) == [1, 2, 3]
assert list(ds2.data) == [1, 2, 3]
# Assert that the schema matches expected schema
assert ds1.schema == ds2.schema
assert ds1.schema == {"num": int} or ds1.schema == {"num": "int"} # Adjust as needed for your schema representation
with pytest.raises(DatasetNotFoundError):
    catalog.get_dataset("nums3")
```
- If your dataset objects use different attribute names for data or schema, adjust `ds1.data`, `ds1.schema` accordingly.
- If your schema representation differs (e.g., uses type objects or strings), update the assertion to match your codebase's conventions.
</issue_to_address>
### Comment 3
<location> `tests/unit/lib/test_checkpoints.py:96-105` </location>
<code_context>
+def test_checkpoints_multiple_runs(
</code_context>
<issue_to_address>
**suggestion (testing):** Consider adding a test for invalid or missing parent_job_id.
Add a test case where parent_job_id is invalid or missing, and verify that JobNotFoundError is raised to ensure robust error handling.
</issue_to_address>
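A sketch of what such a test could look like, following the shape of the existing tests in this file. Note that, per the `save()` code above, `JobNotFoundError` is raised when `DATACHAIN_JOB_ID` itself doesn't resolve to a job; the fixture names come from the surrounding tests, while the test name and the chosen fake id are assumptions:

```python
def test_save_with_unknown_job_id(test_session, monkeypatch, nums_dataset):
    # DATACHAIN_JOB_ID points at a job that was never created in the metastore,
    # so DataChain.save() should fail fast instead of silently skipping checkpoints.
    monkeypatch.setenv("DATACHAIN_JOB_ID", "00000000-0000-0000-0000-000000000000")

    with pytest.raises(JobNotFoundError):
        dc.read_dataset("nums", session=test_session).save("nums1")
```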
### Comment 4
<location> `tests/unit/lib/test_checkpoints.py:9` </location>
<code_context>
</code_context>
<issue_to_address>
**issue (code-quality):** Raise a specific error instead of the general `Exception` or `BaseException` ([`raise-specific-error`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/raise-specific-error))
<details><summary>Explanation</summary>If a piece of code raises a specific exception type
rather than the generic
[`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException)
or [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception),
the calling code can:
- get more information about what type of error it is
- define specific exception handling for it
This way, callers of the code can handle the error appropriately.
How can you solve this?
- Use one of the [built-in exceptions](https://docs.python.org/3/library/exceptions.html) of the standard library.
- [Define your own error class](https://docs.python.org/3/tutorial/errors.html#tut-userexceptions) that subclasses `Exception`.
So instead of having code raising `Exception` or `BaseException` like
```python
if incorrect_input(value):
    raise Exception("The input is incorrect")
```
you can have code raising a specific error like
```python
if incorrect_input(value):
    raise ValueError("The input is incorrect")
```
or
```python
class IncorrectInputError(Exception):
    pass

if incorrect_input(value):
    raise IncorrectInputError("The input is incorrect")
```
</details>
</issue_to_address>
### Comment 5
<location> `tests/unit/lib/test_checkpoints.py:49` </location>
<code_context>
len(catalog.get_dataset("nums1").versions) == 1 if not reset_checkpoints else 2
</code_context>
<issue_to_address>
**suggestion (code-quality):** Swap if/else branches of if expression to remove negation ([`swap-if-expression`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/swap-if-expression))
```suggestion
2 if reset_checkpoints else len(catalog.get_dataset("nums1").versions) == 1
```
<br/><details><summary>Explanation</summary>Negated conditions are more difficult to read than positive ones, so it is best
to avoid them where we can. By swapping the `if` and `else` conditions around we
can invert the condition and make it positive.
</details>
</issue_to_address>
### Comment 6
<location> `tests/unit/lib/test_checkpoints.py:52` </location>
<code_context>
len(catalog.get_dataset("nums2").versions) == 1 if not reset_checkpoints else 2
</code_context>
<issue_to_address>
**suggestion (code-quality):** Swap if/else branches of if expression to remove negation ([`swap-if-expression`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/swap-if-expression))
```suggestion
2 if reset_checkpoints else len(catalog.get_dataset("nums2").versions) == 1
```
<br/><details><summary>Explanation</summary>Negated conditions are more difficult to read than positive ones, so it is best
to avoid them where we can. By swapping the `if` and `else` conditions around we
can invert the condition and make it positive.
</details>
</issue_to_address>
### Comment 7
<location> `tests/unit/lib/test_checkpoints.py:86` </location>
<code_context>
len(catalog.get_dataset("nums1").versions) == 1 if not reset_checkpoints else 2
</code_context>
<issue_to_address>
**suggestion (code-quality):** Swap if/else branches of if expression to remove negation ([`swap-if-expression`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/swap-if-expression))
```suggestion
2 if reset_checkpoints else len(catalog.get_dataset("nums1").versions) == 1
```
<br/><details><summary>Explanation</summary>Negated conditions are more difficult to read than positive ones, so it is best
to avoid them where we can. By swapping the `if` and `else` conditions around we
can invert the condition and make it positive.
</details>
</issue_to_address>
### Comment 8
<location> `tests/unit/lib/test_checkpoints.py:155-163` </location>
<code_context>
</code_context>
<issue_to_address>
**issue (code-quality):** Avoid conditionals in tests. ([`no-conditionals-in-tests`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/no-conditionals-in-tests))
<details><summary>Explanation</summary>Avoid complex code, like conditionals, in test functions.
Google's software engineering guidelines says:
"Clear tests are trivially correct upon inspection"
To reach that avoid complex code in tests:
* loops
* conditionals
Some ways to fix this:
* Use parametrized tests to get rid of the loop.
* Move the complex logic into helpers.
* Move the complex part into pytest fixtures.
> Complexity is most often introduced in the form of logic. Logic is defined via the imperative parts of programming languages such as operators, loops, and conditionals. When a piece of code contains logic, you need to do a bit of mental computation to determine its result instead of just reading it off of the screen. It doesn't take much logic to make a test more difficult to reason about.
Software Engineering at Google / [Don't Put Logic in Tests](https://abseil.io/resources/swe-book/html/ch12.html#donapostrophet_put_logic_in_tests)
</details>
</issue_to_address>
### Comment 9
<location> `src/datachain/lib/dc/datachain.py:601` </location>
<code_context>
(same `DataChain.save` code context as in Comment 2 of the first review above)
</code_context>
<issue_to_address>
**issue (code-quality):** Low code quality found in DataChain.save - 21% ([`low-code-quality`](https://docs.sourcery.ai/Reference/Default-Rules/comments/low-code-quality/))
<br/><details><summary>Explanation</summary>The quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.
How can you solve this?
It might be worth refactoring this function to make it shorter and more readable.
- Reduce the function length by extracting pieces of functionality out into
their own functions. This is the most important thing you can do - ideally a
function should be less than 10 lines.
- Reduce nesting, perhaps by introducing guard clauses to return early.
- Ensure that variables are tightly scoped, so that code using related concepts
sits together within the function rather than being scattered.</details>
</issue_to_address>
### Comment 10
<location> `tests/unit/lib/test_checkpoints.py:158-163` </location>
<code_context>
@pytest.mark.parametrize("reset_checkpoints", [True, False])
def test_checkpoints_multiple_runs(
    test_session, monkeypatch, nums_dataset, reset_checkpoints
):
    catalog = test_session.catalog
    monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", reset_checkpoints)

    # -------------- FIRST RUN -------------------
    first_job_id = catalog.metastore.create_job("my-job", "echo 1;")
    monkeypatch.setenv("DATACHAIN_JOB_ID", first_job_id)
    dc.read_dataset("nums", session=test_session).save("nums1")
    dc.read_dataset("nums", session=test_session).save("nums2")
    with pytest.raises(DataChainError):
        (
            dc.read_dataset("nums", session=test_session)
            .map(new=mapper_fail)
            .save("nums3")
        )
    catalog.get_dataset("nums1")
    catalog.get_dataset("nums2")
    with pytest.raises(DatasetNotFoundError):
        catalog.get_dataset("nums3")

    # -------------- SECOND RUN -------------------
    second_job_id = catalog.metastore.create_job(
        "my-job", "echo 1;", parent_job_id=first_job_id
    )
    monkeypatch.setenv("DATACHAIN_JOB_ID", second_job_id)
    dc.read_dataset("nums", session=test_session).save("nums1")
    dc.read_dataset("nums", session=test_session).save("nums2")
    dc.read_dataset("nums", session=test_session).save("nums3")

    # -------------- THIRD RUN -------------------
    third_job_id = catalog.metastore.create_job(
        "my-job", "echo 1;", parent_job_id=second_job_id
    )
    monkeypatch.setenv("DATACHAIN_JOB_ID", third_job_id)
    dc.read_dataset("nums", session=test_session).save("nums1")
    dc.read_dataset("nums", session=test_session).filter(dc.C("num") > 1).save("nums2")
    with pytest.raises(DataChainError):
        (
            dc.read_dataset("nums", session=test_session)
            .map(new=mapper_fail)
            .save("nums3")
        )

    # -------------- FOURTH RUN -------------------
    fourth_job_id = catalog.metastore.create_job(
        "my-job", "echo 1;", parent_job_id=third_job_id
    )
    monkeypatch.setenv("DATACHAIN_JOB_ID", fourth_job_id)
    dc.read_dataset("nums", session=test_session).save("nums1")
    dc.read_dataset("nums", session=test_session).filter(dc.C("num") > 1).save("nums2")
    dc.read_dataset("nums", session=test_session).save("nums3")

    num1_versions = len(catalog.get_dataset("nums1").versions)
    num2_versions = len(catalog.get_dataset("nums2").versions)
    num3_versions = len(catalog.get_dataset("nums3").versions)

    if reset_checkpoints:
        assert num1_versions == 4
        assert num2_versions == 4
        assert num3_versions == 2
    else:
        assert num1_versions == 1
        assert num2_versions == 2
        assert num3_versions == 2

    assert len(list(catalog.metastore.list_checkpoints(first_job_id))) == 2
    assert len(list(catalog.metastore.list_checkpoints(second_job_id))) == 3
    assert len(list(catalog.metastore.list_checkpoints(third_job_id))) == 2
    assert len(list(catalog.metastore.list_checkpoints(fourth_job_id))) == 3
</code_context>
<issue_to_address>
**issue (code-quality):** Hoist repeated code outside conditional statement ([`hoist-statement-from-if`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/hoist-statement-from-if/))
</issue_to_address>
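For reference, the hoist in this case amounts to pulling the assertion that is identical in both branches out of the conditional; a sketch based on the test shown above:

```python
assert num3_versions == 2  # identical in both branches, so it can be hoisted
if reset_checkpoints:
    assert num1_versions == 4
    assert num2_versions == 4
else:
    assert num1_versions == 1
    assert num2_versions == 2
```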
tests/unit/lib/test_checkpoints.py (Outdated)

```python
def test_checkpoints(test_session, monkeypatch, nums_dataset, reset_checkpoints):
    catalog = test_session.catalog

    monkeypatch.setenv("DATACHAIN_CHECKPOINTS_RESET", reset_checkpoints)

    # -------------- FIRST RUN -------------------
    first_job_id = catalog.metastore.create_job("my-job", "echo 1;")
    monkeypatch.setenv("DATACHAIN_JOB_ID", first_job_id)
    dc.read_dataset("nums", session=test_session).save("nums1")
    dc.read_dataset("nums", session=test_session).save("nums2")
```
suggestion (testing): Consider adding assertions to verify that the skipped chain processing actually returns the correct dataset object.
Add assertions to confirm that the dataset returned after skipping chain processing matches the expected data and schema, ensuring the checkpoint mechanism's correctness.
Suggested implementation:

```python
ds1 = catalog.get_dataset("nums1")
ds2 = catalog.get_dataset("nums2")
# Assert that the data matches expected values
assert list(ds1.data) == [1, 2, 3]
assert list(ds2.data) == [1, 2, 3]
# Assert that the schema matches expected schema
assert ds1.schema == ds2.schema
assert ds1.schema == {"num": int} or ds1.schema == {"num": "int"}  # Adjust as needed for your schema representation
with pytest.raises(DatasetNotFoundError):
    catalog.get_dataset("nums3")
```

- If your dataset objects use different attribute names for data or schema, adjust `ds1.data`, `ds1.schema` accordingly.
- If your schema representation differs (e.g., uses type objects or strings), update the assertion to match your codebase's conventions.
src/datachain/lib/dc/datachain.py (Outdated)

```python
        Hash is calculated using previous job checkpoint hash (if exists) and
        adding hash of this chain to produce new hash.
        """
        last_checkpoint = max(
```
can we have an API for the last checkpoint to simplify this?
Added special metastore api
thanks!
update it here?
New method is now being used here. Is that what you meant?
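The new metastore method isn't shown in this thread; a minimal sketch of what a "last checkpoint" helper could look like, built only on the `list_checkpoints` call and `created_at` field visible in the diff above (the method name is an assumption):

```python
from typing import Optional

def get_last_checkpoint(self, job_id: str) -> Optional["Checkpoint"]:
    """Return the most recently created checkpoint for a job, or None if there are none."""
    return max(
        self.list_checkpoints(job_id),
        key=lambda checkpoint: checkpoint.created_at,
        default=None,
    )
```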
src/datachain/lib/dc/datachain.py (Outdated)

```python
        job = None
        _hash = None
        job_id = os.getenv("DATACHAIN_JOB_ID")
```
how will users provide this? what is an example of the interface?
We currently provide this from Studio (and from the `datachain query` command, but that one is not that important it seems). Note that this still doesn't touch the local CLI / Python checkpoint implementation. That will be handled in #1361
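Roughly, the wiring follows the pattern used in this PR's tests; a sketch of how a runner such as Studio could expose the job to `save()` (illustrative only, not the actual Studio integration):

```python
import os

import datachain as dc


def run_with_job(session) -> None:
    """Mimic how a runner wires DATACHAIN_JOB_ID into a user script."""
    metastore = session.catalog.metastore

    # Register the run as a job; its id is handed to the script via the env var.
    job_id = metastore.create_job("my-job", "echo 1;")
    os.environ["DATACHAIN_JOB_ID"] = job_id

    # save() reads DATACHAIN_JOB_ID, computes the chain hash, and either reuses a
    # matching checkpoint from the parent job or re-runs the chain.
    dc.read_dataset("nums", session=session).save("nums1")
```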
src/datachain/lib/dc/datachain.py (Outdated)

```python
            ):
                # if we find checkpoint with correct hash, we can skip chain calculation
                catalog.metastore.create_checkpoint(job.id, _hash)
                from .datasets import read_dataset
```
do we have to `from .datasets import read_dataset` in multiple places?
also, the whole method is now very large - it needs refactoring
I've put it at the top of `.save()` now to have it in one place. I will think about refactoring
let's do the refactoring, it is getting out of hand (and we ignore the linter complaint?)
I've refactored it by calling helper methods for checkpoints, delta etc.
src/datachain/lib/dc/datachain.py (Outdated)

```python
            from .datasets import read_dataset

            if job:
                catalog.metastore.create_checkpoint(job.id, _hash)  # type: ignore[arg-type]
```
same, fix linter
This is because the linter is not smart enough to realize that `_hash` will 100% be defined if `job` is defined. The only thing I can do here is assert `_hash` before the method call
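i.e. something along these lines, just to illustrate the pattern being discussed:

```python
if job:
    assert _hash is not None  # always true when job is set; narrows the type for the linter
    catalog.metastore.create_checkpoint(job.id, _hash)
```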
src/datachain/lib/dc/datachain.py (Outdated)

```python
                str(uuid4()).encode()
            ).hexdigest(),
        )
        if job:
```
how will it work the first time we run it from Python?
Note that this is not implementing checkpoints for pure Python runs... Checkpoints with pure Python runs will be implemented in #1361. This PR relies on the DATACHAIN_JOB_ID env var to get the job. In the Python implementation we will just need to create a special function to get the job somehow (either by env var or in a different way if run with local Python) - TBD.
So either way, the job should be fetched here and the code should work. Was that your question? I'm not sure I understood it 100%.
src/datachain/lib/dc/datachain.py (Outdated)

```python
        if job_id:
            job = metastore.get_job(job_id)  # type: ignore[arg-type]
            if not job:
                raise JobNotFoundError(f"Job with id {job_id} not found")
```
so, can `job` below be None? (I see `if job:` there)?
Yes, those are not under the same branch. btw I'm working on refactoring this whole method to separate checkpoints, delta, some helper methods etc. I think it will be much cleaner and clearer then
@ilongin fix tests, make sure everything is green before merging
This PR adds logic in `DataChain.save()` to create checkpoints where needed and to skip processing chain steps if a checkpoint for this chain already exists.

Summary by Sourcery

Add job-level checkpointing to DataChain.save to reuse prior results and skip redundant processing, introduce parent_job_id tracking, and expand tests to cover checkpoint reset and chain modification scenarios
New Features:
Bug Fixes:
Enhancements:
Tests: