Skip to content

Commit

Permalink
Make dbt CLI cold start load up to 8% faster with lazy loading of task modules and `agate` (#9744)
Browse files Browse the repository at this point in the history
  • Loading branch information
dwreeves authored Mar 18, 2024
1 parent 58344f4 commit 29395ac
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 38 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240309-141054.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Improve dbt CLI speed
time: 2024-03-09T14:10:54.549618-05:00
custom:
  Author: dwreeves
  Issue: "4627"
9 changes: 6 additions & 3 deletions core/dbt/artifacts/schemas/run/v5/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import threading
from typing import Any, Optional, Iterable, Tuple, Sequence, Dict
import agate
from typing import Any, Optional, Iterable, Tuple, Sequence, Dict, TYPE_CHECKING
from dataclasses import dataclass, field
from datetime import datetime

Expand All @@ -22,9 +21,13 @@
from dbt_common.clients.system import write_json


if TYPE_CHECKING:
import agate


@dataclass
class RunResult(NodeResult):
agate_table: Optional[agate.Table] = field(
agate_table: Optional["agate.Table"] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)

Expand Down
53 changes: 35 additions & 18 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,6 @@
from dbt.artifacts.schemas.catalog import CatalogArtifact
from dbt.artifacts.schemas.run import RunExecutionResult
from dbt_common.events.base_types import EventMsg
from dbt.task.build import BuildTask
from dbt.task.clean import CleanTask
from dbt.task.clone import CloneTask
from dbt.task.compile import CompileTask
from dbt.task.debug import DebugTask
from dbt.task.deps import DepsTask
from dbt.task.docs.generate import GenerateTask
from dbt.task.docs.serve import ServeTask
from dbt.task.freshness import FreshnessTask
from dbt.task.init import InitTask
from dbt.task.list import ListTask
from dbt.task.retry import RetryTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
from dbt.task.show import ShowTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask


@dataclass
Expand Down Expand Up @@ -211,6 +193,8 @@ def cli(ctx, **kwargs):
@requires.manifest
def build(ctx, **kwargs):
"""Run all seeds, models, snapshots, and tests in DAG order"""
from dbt.task.build import BuildTask

task = BuildTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -239,6 +223,8 @@ def build(ctx, **kwargs):
@requires.project
def clean(ctx, **kwargs):
"""Delete all folders in the clean-targets list (usually the dbt_packages and target directories.)"""
from dbt.task.clean import CleanTask

task = CleanTask(ctx.obj["flags"], ctx.obj["project"])

results = task.run()
Expand Down Expand Up @@ -279,6 +265,8 @@ def docs(ctx, **kwargs):
@requires.manifest(write=False)
def docs_generate(ctx, **kwargs):
"""Generate the documentation website for your project"""
from dbt.task.docs.generate import GenerateTask

task = GenerateTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -309,6 +297,8 @@ def docs_generate(ctx, **kwargs):
@requires.runtime_config
def docs_serve(ctx, **kwargs):
"""Serve the documentation website for your project"""
from dbt.task.docs.serve import ServeTask

task = ServeTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -348,6 +338,8 @@ def docs_serve(ctx, **kwargs):
def compile(ctx, **kwargs):
"""Generates executable SQL from source, model, test, and analysis files. Compiled SQL files are written to the
target/ directory."""
from dbt.task.compile import CompileTask

task = CompileTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -387,6 +379,8 @@ def compile(ctx, **kwargs):
def show(ctx, **kwargs):
"""Generates executable SQL for a named resource or inline query, runs that SQL, and returns a preview of the
results. Does not materialize anything to the warehouse."""
from dbt.task.show import ShowTask

task = ShowTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand All @@ -413,6 +407,7 @@ def show(ctx, **kwargs):
@requires.preflight
def debug(ctx, **kwargs):
"""Show information on the current dbt environment and check dependencies, then test the database connection. Not to be confused with the --debug option which increases verbosity."""
from dbt.task.debug import DebugTask

task = DebugTask(
ctx.obj["flags"],
Expand Down Expand Up @@ -452,6 +447,8 @@ def deps(ctx, **kwargs):
There is a way to add new packages by providing an `--add-package` flag to deps command
which will allow user to specify a package they want to add in the format of packagename@version.
"""
from dbt.task.deps import DepsTask

flags = ctx.obj["flags"]
if flags.ADD_PACKAGE:
if not flags.ADD_PACKAGE["version"] and flags.SOURCE != "local":
Expand Down Expand Up @@ -481,6 +478,8 @@ def deps(ctx, **kwargs):
@requires.preflight
def init(ctx, **kwargs):
"""Initialize a new dbt project."""
from dbt.task.init import InitTask

task = InitTask(ctx.obj["flags"], None)

results = task.run()
Expand Down Expand Up @@ -514,6 +513,8 @@ def init(ctx, **kwargs):
@requires.manifest
def list(ctx, **kwargs):
"""List the resources in your project"""
from dbt.task.list import ListTask

task = ListTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -578,6 +579,8 @@ def parse(ctx, **kwargs):
@requires.manifest
def run(ctx, **kwargs):
"""Compile SQL and execute against the current target database."""
from dbt.task.run import RunTask

task = RunTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -608,6 +611,8 @@ def run(ctx, **kwargs):
@requires.runtime_config
def retry(ctx, **kwargs):
"""Retry the nodes that failed in the previous run."""
from dbt.task.retry import RetryTask

# Retry will parse manifest inside the task after we consolidate the flags
task = RetryTask(
ctx.obj["flags"],
Expand Down Expand Up @@ -644,6 +649,8 @@ def retry(ctx, **kwargs):
@requires.postflight
def clone(ctx, **kwargs):
"""Create clones of selected nodes based on their location in the manifest provided to --state."""
from dbt.task.clone import CloneTask

task = CloneTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -676,6 +683,8 @@ def clone(ctx, **kwargs):
@requires.manifest
def run_operation(ctx, **kwargs):
"""Run the named macro with any supplied arguments."""
from dbt.task.run_operation import RunOperationTask

task = RunOperationTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -711,6 +720,8 @@ def run_operation(ctx, **kwargs):
@requires.manifest
def seed(ctx, **kwargs):
"""Load data from csv files into your data warehouse."""
from dbt.task.seed import SeedTask

task = SeedTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -743,6 +754,8 @@ def seed(ctx, **kwargs):
@requires.manifest
def snapshot(ctx, **kwargs):
"""Execute snapshots defined in your project"""
from dbt.task.snapshot import SnapshotTask

task = SnapshotTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -785,6 +798,8 @@ def source(ctx, **kwargs):
@requires.manifest
def freshness(ctx, **kwargs):
"""check the current freshness of the project's sources"""
from dbt.task.freshness import FreshnessTask

task = FreshnessTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down Expand Up @@ -825,6 +840,8 @@ def freshness(ctx, **kwargs):
@requires.manifest
def test(ctx, **kwargs):
"""Runs tests on data in deployed models. Run this after `dbt run`"""
from dbt.task.test import TestTask

task = TestTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
Expand Down
15 changes: 10 additions & 5 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Iterable,
Mapping,
Tuple,
TYPE_CHECKING,
)

from typing_extensions import Protocol
Expand All @@ -22,7 +23,6 @@
from dbt_common.clients.jinja import MacroProtocol
from dbt_common.context import get_invocation_context
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt_common.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
from dbt.config import RuntimeConfig, Project
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
Expand Down Expand Up @@ -82,7 +82,8 @@
from dbt_common.utils import merge, AttrDict, cast_to_str
from dbt import selected_resources

import agate
if TYPE_CHECKING:
import agate


_MISSING = object()
Expand Down Expand Up @@ -851,8 +852,10 @@ def load_result(self, name: str) -> Optional[AttrDict]:

@contextmember()
def store_result(
self, name: str, response: Any, agate_table: Optional[agate.Table] = None
self, name: str, response: Any, agate_table: Optional["agate.Table"] = None
) -> str:
from dbt_common.clients import agate_helper

if agate_table is None:
agate_table = agate_helper.empty_table()

Expand All @@ -872,7 +875,7 @@ def store_raw_result(
message=Optional[str],
code=Optional[str],
rows_affected=Optional[str],
agate_table: Optional[agate.Table] = None,
agate_table: Optional["agate.Table"] = None,
) -> str:
response = AdapterResponse(_message=message, code=code, rows_affected=rows_affected)
return self.store_result(name, response, agate_table)
Expand Down Expand Up @@ -921,7 +924,9 @@ def try_or_compiler_error(
raise CompilationError(message_if_exception, self.model)

@contextmember()
def load_agate_table(self) -> agate.Table:
def load_agate_table(self) -> "agate.Table":
from dbt_common.clients import agate_helper

if not isinstance(self.model, SeedNode):
raise LoadAgateTableNotSeedError(self.model.resource_type, node=self.model)

Expand Down
11 changes: 7 additions & 4 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
import re
import io
import agate
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union, TYPE_CHECKING

from dbt_common.exceptions import (
DbtRuntimeError,
Expand All @@ -19,6 +18,10 @@
from dbt_common.dataclass_schema import ValidationError


if TYPE_CHECKING:
import agate


class ContractBreakingChangeError(DbtRuntimeError):
CODE = 10016
MESSAGE = "Breaking Change to Contract"
Expand Down Expand Up @@ -1349,7 +1352,7 @@ def __init__(self, yaml_columns, sql_columns):
self.sql_columns = sql_columns
super().__init__(msg=self.get_message())

def get_mismatches(self) -> agate.Table:
def get_mismatches(self) -> "agate.Table":
# avoid a circular import
from dbt_common.clients.agate_helper import table_from_data_flat

Expand Down Expand Up @@ -1400,7 +1403,7 @@ def get_message(self) -> str:
"This model has an enforced contract, and its 'columns' specification is missing"
)

table: agate.Table = self.get_mismatches()
table: "agate.Table" = self.get_mismatches()
# Hack to get Agate table output as string
output = io.StringIO()
table.print_table(output=output, max_rows=None, max_column_width=50) # type: ignore
Expand Down
9 changes: 6 additions & 3 deletions core/dbt/task/run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import threading
import traceback
from datetime import datetime

import agate
from typing import TYPE_CHECKING

import dbt_common.exceptions
from dbt.adapters.factory import get_adapter
Expand All @@ -24,6 +23,10 @@
RESULT_FILE_NAME = "run_results.json"


if TYPE_CHECKING:
import agate


class RunOperationTask(ConfiguredTask):
def _get_macro_parts(self):
macro_name = self.args.macro
Expand All @@ -34,7 +37,7 @@ def _get_macro_parts(self):

return package_name, macro_name

def _run_unsafe(self, package_name, macro_name) -> agate.Table:
def _run_unsafe(self, package_name, macro_name) -> "agate.Table":
adapter = get_adapter(self.config)

macro_kwargs = self.args.args
Expand Down
13 changes: 8 additions & 5 deletions core/dbt/task/test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import agate
import daff
import io
import json
Expand All @@ -8,7 +7,7 @@
from dbt_common.events.format import pluralize
from dbt_common.dataclass_schema import dbtClassMixin
import threading
from typing import Dict, Any, Optional, Union, List
from typing import Dict, Any, Optional, Union, List, TYPE_CHECKING

from .compile import CompileRunner
from .run import RunTask
Expand Down Expand Up @@ -37,6 +36,10 @@
from dbt_common.ui import green, red


if TYPE_CHECKING:
import agate


@dataclass
class UnitTestDiff(dbtClassMixin):
actual: List[Dict[str, Any]]
Expand Down Expand Up @@ -325,7 +328,7 @@ def _get_unit_test_agate_table(self, result_table, actual_or_expected: str):
return unit_test_table.select(columns)

def _get_daff_diff(
self, expected: agate.Table, actual: agate.Table, ordered: bool = False
self, expected: "agate.Table", actual: "agate.Table", ordered: bool = False
) -> daff.TableDiff:

expected_daff_table = daff.PythonTableView(list_rows_from_table(expected))
Expand Down Expand Up @@ -388,7 +391,7 @@ def get_runner_type(self, _):


# This was originally in agate_helper, but that was moved out into dbt_common
def json_rows_from_table(table: agate.Table) -> List[Dict[str, Any]]:
def json_rows_from_table(table: "agate.Table") -> List[Dict[str, Any]]:
"Convert a table to a list of row dict objects"
output = io.StringIO()
table.to_json(path=output) # type: ignore
Expand All @@ -397,7 +400,7 @@ def json_rows_from_table(table: agate.Table) -> List[Dict[str, Any]]:


# This was originally in agate_helper, but that was moved out into dbt_common
def list_rows_from_table(table: agate.Table) -> List[Any]:
def list_rows_from_table(table: "agate.Table") -> List[Any]:
"Convert a table to a list of lists, where the first element represents the header"
rows = [[col.name for col in table.columns]]
for row in table.rows:
Expand Down

0 comments on commit 29395ac

Please sign in to comment.