Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make dbt CLI cold start load up to 8% faster with lazy loading of task modules and agate #9744

Merged
merged 5 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240309-141054.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Improve dbt CLI speed
time: 2024-03-09T14:10:54.549618-05:00
custom:
Author: dwreeves
Issue: "4627"
9 changes: 6 additions & 3 deletions core/dbt/artifacts/schemas/run/v5/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import threading
from typing import Any, Optional, Iterable, Tuple, Sequence, Dict
import agate
from typing import Any, Optional, Iterable, Tuple, Sequence, Dict, TYPE_CHECKING
from dataclasses import dataclass, field
from datetime import datetime

Expand All @@ -22,9 +21,13 @@
from dbt_common.clients.system import write_json


if TYPE_CHECKING:
import agate

Check warning on line 25 in core/dbt/artifacts/schemas/run/v5/run.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/artifacts/schemas/run/v5/run.py#L25

Added line #L25 was not covered by tests


@dataclass
class RunResult(NodeResult):
agate_table: Optional[agate.Table] = field(
agate_table: Optional["agate.Table"] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)

Expand Down
53 changes: 35 additions & 18 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,6 @@
from dbt.artifacts.schemas.catalog import CatalogArtifact
from dbt.artifacts.schemas.run import RunExecutionResult
from dbt_common.events.base_types import EventMsg
from dbt.task.build import BuildTask
from dbt.task.clean import CleanTask
from dbt.task.clone import CloneTask
from dbt.task.compile import CompileTask
from dbt.task.debug import DebugTask
from dbt.task.deps import DepsTask
from dbt.task.docs.generate import GenerateTask
from dbt.task.docs.serve import ServeTask
from dbt.task.freshness import FreshnessTask
from dbt.task.init import InitTask
from dbt.task.list import ListTask
from dbt.task.retry import RetryTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
from dbt.task.show import ShowTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask


@dataclass
Expand Down Expand Up @@ -211,6 +193,8 @@
@requires.manifest
def build(ctx, **kwargs):
"""Run all seeds, models, snapshots, and tests in DAG order"""
from dbt.task.build import BuildTask

task = BuildTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -239,6 +223,8 @@
@requires.project
def clean(ctx, **kwargs):
"""Delete all folders in the clean-targets list (usually the dbt_packages and target directories.)"""
from dbt.task.clean import CleanTask

task = CleanTask(ctx.obj["flags"], ctx.obj["project"])

results = task.run()
Expand Down Expand Up @@ -279,6 +265,8 @@
@requires.manifest(write=False)
def docs_generate(ctx, **kwargs):
"""Generate the documentation website for your project"""
from dbt.task.docs.generate import GenerateTask

task = GenerateTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -309,6 +297,8 @@
@requires.runtime_config
def docs_serve(ctx, **kwargs):
"""Serve the documentation website for your project"""
from dbt.task.docs.serve import ServeTask

Check warning on line 300 in core/dbt/cli/main.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/main.py#L300

Added line #L300 was not covered by tests

task = ServeTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -348,6 +338,8 @@
def compile(ctx, **kwargs):
"""Generates executable SQL from source, model, test, and analysis files. Compiled SQL files are written to the
target/ directory."""
from dbt.task.compile import CompileTask

task = CompileTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -387,6 +379,8 @@
def show(ctx, **kwargs):
"""Generates executable SQL for a named resource or inline query, runs that SQL, and returns a preview of the
results. Does not materialize anything to the warehouse."""
from dbt.task.show import ShowTask

task = ShowTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand All @@ -413,6 +407,7 @@
@requires.preflight
def debug(ctx, **kwargs):
"""Show information on the current dbt environment and check dependencies, then test the database connection. Not to be confused with the --debug option which increases verbosity."""
from dbt.task.debug import DebugTask

task = DebugTask(
ctx.obj["flags"],
Expand Down Expand Up @@ -452,6 +447,8 @@
There is a way to add new packages by providing an `--add-package` flag to deps command
which will allow user to specify a package they want to add in the format of packagename@version.
"""
from dbt.task.deps import DepsTask

flags = ctx.obj["flags"]
if flags.ADD_PACKAGE:
if not flags.ADD_PACKAGE["version"] and flags.SOURCE != "local":
Expand Down Expand Up @@ -481,6 +478,8 @@
@requires.preflight
def init(ctx, **kwargs):
"""Initialize a new dbt project."""
from dbt.task.init import InitTask

task = InitTask(ctx.obj["flags"], None)

results = task.run()
Expand Down Expand Up @@ -514,6 +513,8 @@
@requires.manifest
def list(ctx, **kwargs):
"""List the resources in your project"""
from dbt.task.list import ListTask

task = ListTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -578,6 +579,8 @@
@requires.manifest
def run(ctx, **kwargs):
"""Compile SQL and execute against the current target database."""
from dbt.task.run import RunTask

task = RunTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -608,6 +611,8 @@
@requires.runtime_config
def retry(ctx, **kwargs):
"""Retry the nodes that failed in the previous run."""
from dbt.task.retry import RetryTask

# Retry will parse manifest inside the task after we consolidate the flags
task = RetryTask(
ctx.obj["flags"],
Expand Down Expand Up @@ -644,6 +649,8 @@
@requires.postflight
def clone(ctx, **kwargs):
"""Create clones of selected nodes based on their location in the manifest provided to --state."""
from dbt.task.clone import CloneTask

task = CloneTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -676,6 +683,8 @@
@requires.manifest
def run_operation(ctx, **kwargs):
"""Run the named macro with any supplied arguments."""
from dbt.task.run_operation import RunOperationTask

task = RunOperationTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -711,6 +720,8 @@
@requires.manifest
def seed(ctx, **kwargs):
"""Load data from csv files into your data warehouse."""
from dbt.task.seed import SeedTask

task = SeedTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -743,6 +754,8 @@
@requires.manifest
def snapshot(ctx, **kwargs):
"""Execute snapshots defined in your project"""
from dbt.task.snapshot import SnapshotTask

task = SnapshotTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -785,6 +798,8 @@
@requires.manifest
def freshness(ctx, **kwargs):
"""check the current freshness of the project's sources"""
from dbt.task.freshness import FreshnessTask

task = FreshnessTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -825,6 +840,8 @@
@requires.manifest
def test(ctx, **kwargs):
"""Runs tests on data in deployed models. Run this after `dbt run`"""
from dbt.task.test import TestTask

task = TestTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down
15 changes: 10 additions & 5 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Iterable,
Mapping,
Tuple,
TYPE_CHECKING,
)

from typing_extensions import Protocol
Expand All @@ -22,7 +23,6 @@
from dbt_common.clients.jinja import MacroProtocol
from dbt_common.context import get_invocation_context
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt_common.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
from dbt.config import RuntimeConfig, Project
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
Expand Down Expand Up @@ -82,7 +82,8 @@
from dbt_common.utils import merge, AttrDict, cast_to_str
from dbt import selected_resources

import agate
if TYPE_CHECKING:
import agate

Check warning on line 86 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L86

Added line #L86 was not covered by tests


_MISSING = object()
Expand Down Expand Up @@ -851,8 +852,10 @@

@contextmember()
def store_result(
self, name: str, response: Any, agate_table: Optional[agate.Table] = None
self, name: str, response: Any, agate_table: Optional["agate.Table"] = None
) -> str:
from dbt_common.clients import agate_helper

if agate_table is None:
agate_table = agate_helper.empty_table()

Expand All @@ -872,7 +875,7 @@
message=Optional[str],
code=Optional[str],
rows_affected=Optional[str],
agate_table: Optional[agate.Table] = None,
agate_table: Optional["agate.Table"] = None,
) -> str:
response = AdapterResponse(_message=message, code=code, rows_affected=rows_affected)
return self.store_result(name, response, agate_table)
Expand Down Expand Up @@ -921,7 +924,9 @@
raise CompilationError(message_if_exception, self.model)

@contextmember()
def load_agate_table(self) -> agate.Table:
def load_agate_table(self) -> "agate.Table":
from dbt_common.clients import agate_helper

if not isinstance(self.model, SeedNode):
raise LoadAgateTableNotSeedError(self.model.resource_type, node=self.model)

Expand Down
11 changes: 7 additions & 4 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
import re
import io
import agate
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union, TYPE_CHECKING

from dbt_common.exceptions import (
DbtRuntimeError,
Expand All @@ -19,6 +18,10 @@
from dbt_common.dataclass_schema import ValidationError


if TYPE_CHECKING:
import agate

Check warning on line 22 in core/dbt/exceptions.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/exceptions.py#L22

Added line #L22 was not covered by tests


class ContractBreakingChangeError(DbtRuntimeError):
CODE = 10016
MESSAGE = "Breaking Change to Contract"
Expand Down Expand Up @@ -1349,7 +1352,7 @@
self.sql_columns = sql_columns
super().__init__(msg=self.get_message())

def get_mismatches(self) -> agate.Table:
def get_mismatches(self) -> "agate.Table":
# avoid a circular import
from dbt_common.clients.agate_helper import table_from_data_flat

Expand Down Expand Up @@ -1400,7 +1403,7 @@
"This model has an enforced contract, and its 'columns' specification is missing"
)

table: agate.Table = self.get_mismatches()
table: "agate.Table" = self.get_mismatches()
# Hack to get Agate table output as string
output = io.StringIO()
table.print_table(output=output, max_rows=None, max_column_width=50) # type: ignore
Expand Down
9 changes: 6 additions & 3 deletions core/dbt/task/run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import threading
import traceback
from datetime import datetime

import agate
from typing import TYPE_CHECKING

import dbt_common.exceptions
from dbt.adapters.factory import get_adapter
Expand All @@ -24,6 +23,10 @@
RESULT_FILE_NAME = "run_results.json"


if TYPE_CHECKING:
import agate

Check warning on line 27 in core/dbt/task/run_operation.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/run_operation.py#L27

Added line #L27 was not covered by tests


class RunOperationTask(ConfiguredTask):
def _get_macro_parts(self):
macro_name = self.args.macro
Expand All @@ -34,7 +37,7 @@

return package_name, macro_name

def _run_unsafe(self, package_name, macro_name) -> agate.Table:
def _run_unsafe(self, package_name, macro_name) -> "agate.Table":
adapter = get_adapter(self.config)

macro_kwargs = self.args.args
Expand Down
13 changes: 8 additions & 5 deletions core/dbt/task/test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import agate
import daff
import io
import json
Expand All @@ -8,7 +7,7 @@
from dbt_common.events.format import pluralize
from dbt_common.dataclass_schema import dbtClassMixin
import threading
from typing import Dict, Any, Optional, Union, List
from typing import Dict, Any, Optional, Union, List, TYPE_CHECKING

from .compile import CompileRunner
from .run import RunTask
Expand Down Expand Up @@ -37,6 +36,10 @@
from dbt_common.ui import green, red


if TYPE_CHECKING:
import agate

Check warning on line 40 in core/dbt/task/test.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/test.py#L40

Added line #L40 was not covered by tests


@dataclass
class UnitTestDiff(dbtClassMixin):
actual: List[Dict[str, Any]]
Expand Down Expand Up @@ -325,7 +328,7 @@
return unit_test_table.select(columns)

def _get_daff_diff(
self, expected: agate.Table, actual: agate.Table, ordered: bool = False
self, expected: "agate.Table", actual: "agate.Table", ordered: bool = False
) -> daff.TableDiff:

expected_daff_table = daff.PythonTableView(list_rows_from_table(expected))
Expand Down Expand Up @@ -388,7 +391,7 @@


# This was originally in agate_helper, but that was moved out into dbt_common
def json_rows_from_table(table: agate.Table) -> List[Dict[str, Any]]:
def json_rows_from_table(table: "agate.Table") -> List[Dict[str, Any]]:
"Convert a table to a list of row dict objects"
output = io.StringIO()
table.to_json(path=output) # type: ignore
Expand All @@ -397,7 +400,7 @@


# This was originally in agate_helper, but that was moved out into dbt_common
def list_rows_from_table(table: agate.Table) -> List[Any]:
def list_rows_from_table(table: "agate.Table") -> List[Any]:
"Convert a table to a list of lists, where the first element represents the header"
rows = [[col.name for col in table.columns]]
for row in table.rows:
Expand Down
Loading