Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import base64
import copy
import datetime
import functools
import itertools
Expand Down Expand Up @@ -64,6 +65,21 @@
logger = get_logger(__name__)


@functools.lru_cache(maxsize=8)
def _load_manifest_cached(path: str, mtime: float) -> dict[str, Any]:
"""
Load and cache a parsed dbt manifest.json file.

When multiple DbtDag/DbtTaskGroup instances share the same manifest file,
this avoids re-parsing the JSON for each one during a single DagBag import cycle.

The cache is keyed on (path, mtime) so it auto-invalidates when the file changes.
maxsize=8 bounds memory for projects with multiple distinct manifests.
"""
with open(path) as fp:
return json.load(fp) or {}


def _normalize_path(path: str | None) -> str:
"""
Converts a potentially Windows path string into a Posix-friendly path.
Expand Down Expand Up @@ -1236,8 +1252,14 @@ def load_from_dbt_manifest(self) -> None:
if TYPE_CHECKING:
assert self.project.manifest_path is not None # pragma: no cover

with self.project.manifest_path.open() as fp:
manifest = json.load(fp) or {}
manifest_path = self.project.manifest_path
manifest_path_str = str(manifest_path)
is_local = not ("://" in manifest_path_str and not manifest_path_str.startswith("file://"))
if is_local:
manifest = copy.deepcopy(_load_manifest_cached(manifest_path_str, manifest_path.stat().st_mtime))
else:
with manifest_path.open() as fp:
manifest = json.load(fp)

project_path = self.execution_config.project_path
nodes = self._load_nodes_from_manifest_data(manifest, project_path)
Expand Down
145 changes: 145 additions & 0 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,151 @@ def test_load_from_dbt_manifest_handles_null_manifest(tmp_path):
assert dbt_graph.filtered_nodes == {}


def test_load_manifest_cached_shares_across_dags(tmp_path):
"""When multiple DbtGraph instances share the same manifest file, it is only parsed once."""
from cosmos.dbt.graph import _load_manifest_cached

# Clear cache from previous tests
_load_manifest_cached.cache_clear()

manifest = {
"metadata": {"project_name": "my_project"},
"nodes": {},
"sources": {},
"exposures": {},
}
manifest_file = tmp_path / "manifest.json"
manifest_file.write_text(json.dumps(manifest))
project_path = tmp_path / "project"
project_path.mkdir()
(project_path / "dbt_project.yml").write_text("name: my_project")

graphs = []
for _ in range(3):
project_config = ProjectConfig(manifest_path=manifest_file, project_name="my_project")
execution_config = ExecutionConfig(dbt_project_path=project_path)
dbt_graph = DbtGraph(
project=project_config,
execution_config=execution_config,
profile_config=ProfileConfig(
profile_name="test",
target_name="test",
profile_mapping=PostgresUserPasswordProfileMapping(conn_id="test", profile_args={}),
),
render_config=RenderConfig(load_method=LoadMode.DBT_MANIFEST),
)
dbt_graph.load_from_dbt_manifest()
graphs.append(dbt_graph)

# After 3 loads of the same file, the cache should have exactly 1 miss (first load)
cache_info = _load_manifest_cached.cache_info()
assert cache_info.hits >= 2
assert cache_info.misses >= 1

_load_manifest_cached.cache_clear()


def test_load_manifest_cached_invalidates_on_file_change(tmp_path):
"""Cache invalidates when the manifest file is modified (mtime changes)."""
import os

from cosmos.dbt.graph import _load_manifest_cached

_load_manifest_cached.cache_clear()

manifest_file = tmp_path / "manifest.json"
manifest_file.write_text(
json.dumps({"metadata": {"project_name": "p"}, "nodes": {}, "sources": {}, "exposures": {}})
)

path_str = str(manifest_file)
mtime1 = manifest_file.stat().st_mtime
result1 = _load_manifest_cached(path_str, mtime1)
assert result1["metadata"]["project_name"] == "p"

manifest_file.write_text(
json.dumps({"metadata": {"project_name": "q"}, "nodes": {}, "sources": {}, "exposures": {}})
)
# Force a different mtime deterministically (works even on filesystems with 1s granularity)
future_time = mtime1 + 10
os.utime(manifest_file, (future_time, future_time))
mtime2 = manifest_file.stat().st_mtime

assert mtime2 != mtime1
result2 = _load_manifest_cached(path_str, mtime2)
assert result2["metadata"]["project_name"] == "q"

_load_manifest_cached.cache_clear()


def test_load_manifest_cached_different_selectors_no_interference(tmp_path):
"""Two DbtGraphs with different select filters sharing a cached manifest produce independent filtered_nodes."""
from cosmos.dbt.graph import _load_manifest_cached

_load_manifest_cached.cache_clear()

manifest = {
"metadata": {"project_name": "my_project"},
"nodes": {
"model.my_project.alpha": {
"resource_type": "model",
"depends_on": {"nodes": []},
"original_file_path": "models/alpha.sql",
"package_name": "my_project",
"tags": ["daily"],
"config": {},
"fqn": ["my_project", "alpha"],
},
"model.my_project.beta": {
"resource_type": "model",
"depends_on": {"nodes": []},
"original_file_path": "models/beta.sql",
"package_name": "my_project",
"tags": ["hourly"],
"config": {},
"fqn": ["my_project", "beta"],
},
},
"sources": {},
"exposures": {},
}
manifest_file = tmp_path / "manifest.json"
manifest_file.write_text(json.dumps(manifest))
project_path = tmp_path / "project"
project_path.mkdir()
(project_path / "dbt_project.yml").write_text("name: my_project")
(project_path / "models").mkdir()
(project_path / "models" / "alpha.sql").write_text("select 1")
(project_path / "models" / "beta.sql").write_text("select 2")

def make_graph(select):
project_config = ProjectConfig(manifest_path=manifest_file, project_name="my_project")
execution_config = ExecutionConfig(dbt_project_path=project_path)
g = DbtGraph(
project=project_config,
execution_config=execution_config,
profile_config=ProfileConfig(
profile_name="test",
target_name="test",
profile_mapping=PostgresUserPasswordProfileMapping(conn_id="test", profile_args={}),
),
render_config=RenderConfig(load_method=LoadMode.DBT_MANIFEST, select=select),
)
g.load_from_dbt_manifest()
return g

graph_alpha = make_graph(["tag:daily"])
graph_beta = make_graph(["tag:hourly"])

assert "model.my_project.alpha" in graph_alpha.filtered_nodes
assert "model.my_project.beta" not in graph_alpha.filtered_nodes

assert "model.my_project.beta" in graph_beta.filtered_nodes
assert "model.my_project.alpha" not in graph_beta.filtered_nodes

_load_manifest_cached.cache_clear()


def test_load_from_dbt_manifest_resolves_package_path(tmp_path):
"""Package nodes get file_path under project_path/dbt_packages/<package_name>/."""
manifest = {
Expand Down
Loading