diff --git a/dvc/repo/checkout.py b/dvc/repo/checkout.py index 06cff4e386..5e146eb57c 100644 --- a/dvc/repo/checkout.py +++ b/dvc/repo/checkout.py @@ -127,8 +127,10 @@ def onerror(target, exc): raise CheckoutErrorSuggestGit(target) from exc raise # noqa: PLE0704 - view = self.index.targets_view( - targets, recursive=recursive, with_deps=with_deps, onerror=onerror + from .index import index_from_targets + + view = index_from_targets( + self, targets=targets, recursive=recursive, with_deps=with_deps, onerror=onerror ) with ui.progress(unit="entry", desc="Building workspace index", leave=True) as pb: diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py index d4f64e1d80..60571bbacd 100644 --- a/dvc/repo/fetch.py +++ b/dvc/repo/fetch.py @@ -41,6 +41,8 @@ def _collect_indexes( # noqa: PLR0913 onerror=None, push=False, ): + from .index import index_from_targets + indexes = {} collection_exc = None @@ -68,7 +70,8 @@ def outs_filter(out: "Output") -> bool: try: repo.config.merge(config) - idx = repo.index.targets_view( + idx = index_from_targets( + repo, targets, with_deps=with_deps, recursive=recursive, diff --git a/dvc/repo/index.py b/dvc/repo/index.py index ea872086df..2f042727e7 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from networkx import DiGraph from pygtrie import Trie + from typing_extensions import Self from dvc.dependency import Dependency from dvc.fs.callbacks import Callback @@ -318,32 +319,9 @@ def from_repo( repo: "Repo", onerror: Optional[Callable[[str, Exception], None]] = None, ) -> "Index": - stages = [] - metrics = {} - plots = {} - params = {} - artifacts = {} - datasets = {} - datasets_lock = {} - onerror = onerror or repo.stage_collection_error_handler - for _, idx in collect_files(repo, onerror=onerror): - stages.extend(idx.stages) - metrics.update(idx._metrics) - plots.update(idx._plots) - params.update(idx._params) - artifacts.update(idx._artifacts) - datasets.update(idx._datasets) - datasets_lock.update(idx._datasets_lock) - return cls( - repo, - stages=stages, - metrics=metrics, - plots=plots, - params=params, - artifacts=artifacts, - datasets=datasets, - datasets_lock=datasets_lock, + return cls.from_indexes( + repo, (idx for _, idx in collect_files(repo, onerror=onerror)) ) @classmethod @@ -364,7 +342,7 @@ def from_file(cls, repo: "Repo", path: str) -> "Index": else {}, ) - def update(self, stages: Iterable["Stage"]) -> "Index": + def update(self, stages: Iterable["Stage"]) -> "Self": stages = set(stages) # we remove existing stages with same hashes at first # and then re-add the new ones later. @@ -379,6 +357,36 @@ def update(self, stages: Iterable["Stage"]) -> "Index": datasets=self._datasets, ) + @classmethod + def from_indexes(cls, repo, idxs: Iterable["Self"]) -> "Self": + stages = [] + metrics = {} + plots = {} + params = {} + artifacts = {} + datasets = {} + datasets_lock = {} + + for idx in idxs: + stages.extend(idx.stages) + metrics.update(idx._metrics) + plots.update(idx._plots) + params.update(idx._params) + artifacts.update(idx._artifacts) + datasets.update(idx._datasets) + datasets_lock.update(idx._datasets_lock) + + return cls( + repo, + stages=stages, + metrics=metrics, + plots=plots, + params=params, + artifacts=artifacts, + datasets=datasets, + datasets_lock=datasets_lock, + ) + @cached_property def outs_trie(self) -> "Trie": from dvc.repo.trie import build_outs_trie @@ -735,6 +743,10 @@ def deps(self) -> Iterator["Dependency"]: for stage in self.stages: yield from stage.deps + @property + def index(self) -> "Index": + return self._index + @property def _filtered_outs(self) -> Iterator[tuple["Output", Optional[str]]]: for stage, filter_info in self._stage_infos: @@ -927,3 +939,51 @@ def _get_entry_hash_name( return src_entry.hash_info.name return DEFAULT_ALGORITHM + + +def index_from_targets( + repo: "Repo", + targets: Optional["TargetType"] = None, + stage_filter: Optional[Callable[["Stage"], bool]] = None, + outs_filter: Optional[Callable[["Output"], bool]] = None, + max_size: Optional[int] = None, + types: Optional[list[str]] = None, + with_deps: bool = False, + recursive: bool = False, + **kwargs: Any, +) -> "IndexView": + from dvc.stage.exceptions import StageFileDoesNotExistError, StageNotFound + from dvc.utils import parse_target + + index: Optional[Index] = None + if targets and all(targets) and not with_deps and not recursive: + indexes: list[Index] = [] + try: + for target in targets: + if not target: + continue + file, name = parse_target(target) + if file and not name: + index = Index.from_file(repo, file) + else: + stages = repo.stage.collect(target) + index = Index(repo, stages=list(stages)) + indexes.append(index) + except (StageFileDoesNotExistError, StageNotFound): + pass + else: + index = Index.from_indexes(repo, indexes) + targets = None + + if index is None: + index = repo.index + return index.targets_view( + targets, + stage_filter=stage_filter, + outs_filter=outs_filter, + max_size=max_size, + types=types, + recursive=recursive, + with_deps=with_deps, + **kwargs, + ) diff --git a/dvc/repo/push.py b/dvc/repo/push.py index acf31573ce..2de07ed018 100644 --- a/dvc/repo/push.py +++ b/dvc/repo/push.py @@ -152,10 +152,12 @@ def push( # noqa: PLR0913 finally: ws_idx = indexes.get("workspace") if ws_idx is not None: + from dvc.repo.index import IndexView + + _index = ws_idx.index if isinstance(ws_idx, IndexView) else ws_idx _update_meta( - self.index, + _index, targets=glob_targets(targets, glob=glob), - remote=remote, with_deps=with_deps, recursive=recursive, ) diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index bf59f557b1..4763fac4a4 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -8,7 +8,7 @@ from dulwich.porcelain import remove as git_rm from dvc.cli import main -from dvc.dvcfile import PROJECT_FILE, load_file +from dvc.dvcfile import PROJECT_FILE, FileMixin, SingleStageFile, load_file from dvc.exceptions import CheckoutError, CheckoutErrorSuggestGit, NoOutputOrStageError from dvc.fs import system from dvc.stage.exceptions import StageFileDoesNotExistError @@ -754,3 +754,20 @@ def test_checkout_cleanup_properly_on_untracked_nested_directories(tmp_dir, scm, dvc.checkout(force=True) assert (tmp_dir / "datasets").read_text() == {"dir1": {"file1": "file1"}} + + +def test_checkout_loads_specific_file(tmp_dir, dvc, mocker): + tmp_dir.dvc_gen("foo", "foo") + tmp_dir.dvc_gen("bar", "bar") + + (tmp_dir / "bar").unlink() + (tmp_dir / "foo").unlink() + + f = SingleStageFile(dvc, "foo.dvc") + + spy = mocker.spy(FileMixin, "_load") + assert dvc.checkout("foo.dvc") == {"added": ["foo"], "deleted": [], "modified": []} + + spy.assert_called_with(f) + assert (tmp_dir / "foo").exists() + assert not (tmp_dir / "bar").exists() diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index a506a205ce..ff7081ac75 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -7,6 +7,7 @@ import dvc_data from dvc.cli import main +from dvc.dvcfile import FileMixin, SingleStageFile from dvc.exceptions import CheckoutError from dvc.repo.open_repo import clean_repos from dvc.scm import CloneError @@ -678,3 +679,38 @@ def test_pull_granular_excluding_import_that_cannot_be_pulled( dvc.pull() with pytest.raises(CloneError, match="SCM error"): dvc.pull(imp_stage.addressing) + + +def test_loads_single_file(tmp_dir, dvc, local_remote, mocker): + tmp_dir.dvc_gen("foo", "foo") + tmp_dir.dvc_gen("bar", "bar") + + foo_dvcfile = SingleStageFile(dvc, "foo.dvc") + bar_dvcfile = SingleStageFile(dvc, "bar.dvc") + + spy = mocker.spy(FileMixin, "_load") + assert dvc.push("foo.dvc") == 1 + spy.assert_called_with(foo_dvcfile) + spy.reset_mock() + + assert dvc.push("bar.dvc") == 1 + spy.assert_called_with(bar_dvcfile) + spy.reset_mock() + + dvc.cache.local.clear() + (tmp_dir / "bar").unlink() + (tmp_dir / "foo").unlink() + + assert dvc.pull("foo.dvc") == { + "added": ["foo"], + "deleted": [], + "fetched": 1, + "modified": [], + } + spy.assert_called_with(foo_dvcfile) + assert (tmp_dir / "foo").exists() + assert not (tmp_dir / "bar").exists() + spy.reset_mock() + + assert dvc.fetch("bar.dvc") == 1 + spy.assert_called_with(bar_dvcfile) diff --git a/tests/func/test_repo_index.py b/tests/func/test_repo_index.py index ebcbee8538..34db5cbbdd 100644 --- a/tests/func/test_repo_index.py +++ b/tests/func/test_repo_index.py @@ -4,7 +4,8 @@ import pytest from pygtrie import Trie -from dvc.repo.index import Index +from dvc.exceptions import NoOutputOrStageError +from dvc.repo.index import Index, index_from_targets from dvc.stage import Stage @@ -402,3 +403,31 @@ def test_data_index(tmp_dir, dvc, local_cloud, erepo_dir): assert data.storage_map[("ifoo_partial",)].remote.read_only assert data.storage_map[("idir_partial",)].remote.read_only + + +def test_index_from_targets(tmp_dir, dvc): + stage1 = dvc.stage.add(name="stage1", cmd="echo hello") + stage2 = dvc.stage.add(name="stage2", cmd="echo hello world") + + (foo_stage,) = tmp_dir.dvc_gen("foo", "foo") + + index = index_from_targets(dvc, ["stage1"]) + assert index.stages == [stage1] + + index = index_from_targets(dvc, ["stage2"]) + assert index.stages == [stage2] + + index = index_from_targets(dvc, ["dvc.yaml"]) + assert set(index.stages) == {stage1, stage2} + + index = index_from_targets(dvc, ["dvc.yaml:stage2"]) + assert index.stages == [stage2] + + index = index_from_targets(dvc, ["foo.dvc"]) + assert index.stages == [foo_stage] + + index = index_from_targets(dvc, ["stage1", "foo.dvc"]) + assert set(index.stages) == {foo_stage, stage1} + + with pytest.raises(NoOutputOrStageError): + index = index_from_targets(dvc, ["not-existing-stage"])