Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions cosmos/core/airflow.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import importlib
from datetime import datetime
import logging
from typing import Any, Optional
from typing import Any, Dict, Optional

from airflow.models import BaseOperator
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
from cosmos.core.graph.entities import CosmosEntity, Group, Task

from cosmos.core.graph.entities import Group, Task

logger = logging.getLogger(__name__)

Expand All @@ -28,7 +28,7 @@ def __init__(

super().__init__(*args, **kwargs)

entities: dict[str, Any] = {}
entities: Dict[str, Any] = {}

# render all the entities in the group
for ent in cosmos_group.entities:
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
kwargs["dag"] = dag
super().__init__(*args, **kwargs)

entities: dict[str, Any] = {}
entities: Dict[str, Any] = {}

# render all the entities in the group
for ent in cosmos_group.entities:
Expand Down
11 changes: 5 additions & 6 deletions cosmos/core/graph/entities.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Any

import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List

logger = logging.getLogger(__name__)

Expand All @@ -18,7 +17,7 @@ class CosmosEntity:
"""

id: str
upstream_entity_ids: list[str] = field(default_factory=list)
upstream_entity_ids: List[str] = field(default_factory=list)

def add_upstream(self, entity: CosmosEntity) -> None:
"""
Expand All @@ -35,7 +34,7 @@ class Group(CosmosEntity):
A Group represents a collection of entities that are connected by dependencies.
"""

entities: list[CosmosEntity] = field(default_factory=list)
entities: List[CosmosEntity] = field(default_factory=list)

def add_entity(self, entity: CosmosEntity) -> None:
"""
Expand All @@ -58,4 +57,4 @@ class Task(CosmosEntity):
"""

operator_class: str = "airflow.operators.dummy.DummyOperator"
arguments: dict[str, Any] = field(default_factory=dict)
arguments: Dict[str, Any] = field(default_factory=dict)
6 changes: 3 additions & 3 deletions cosmos/providers/dbt/core/operators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import os
from typing import Sequence
from typing import List, Sequence

import yaml
from airflow.compat.functools import cached_property
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
self,
project_dir: str,
conn_id: str,
base_cmd: str | list[str] = None,
base_cmd: str | List[str] = None,
select: str = None,
exclude: str = None,
selector: str = None,
Expand Down Expand Up @@ -251,7 +251,7 @@ def build_and_run_cmd(self, env: dict, cmd_flags: list = None):
dbt_cmd.append(profile)

## set env vars
env = env | profile_vars
env = {**env, **profile_vars}
result = self.run_command(cmd=dbt_cmd, env=env)
return result

Expand Down
17 changes: 10 additions & 7 deletions cosmos/providers/dbt/dag.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""
This module contains a function to render a dbt project as an Airflow DAG.
"""
from typing import Any, Literal
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
Comment thread
jbandoro marked this conversation as resolved.

from typing import Any, Dict, List

from cosmos.core.airflow import CosmosDag
from cosmos.providers.dbt.parser.project import DbtProject
from .render import render_project

from airflow.models import DAG
from airflow.utils.decorators import apply_defaults
from .render import render_project


class DbtDag(CosmosDag):
Expand All @@ -29,11 +32,11 @@ def __init__(
self,
dbt_project_name: str,
conn_id: str,
dbt_args: dict[str, Any] = {},
dbt_args: Dict[str, Any] = {},
emit_datasets: bool = True,
dbt_root_path: str = "/usr/local/airflow/dbt",
test_behavior: Literal["none", "after_each", "after_all"] = "after_each",
dbt_tags: list[str] = [],
dbt_tags: List[str] = [],
*args: Any,
**kwargs: Any,
) -> None:
Expand Down
14 changes: 7 additions & 7 deletions cosmos/providers/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
"""
from __future__ import annotations

import os
import logging
import yaml # type: ignore
import jinja2

import os
from dataclasses import dataclass, field
from typing import Any, ClassVar
from pathlib import Path
from typing import Dict

import jinja2
import yaml # type: ignore

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -116,7 +116,7 @@ class DbtProject:
dbt_root_path: str = "/usr/local/airflow/dbt"

# private instance variables for managing state
models: dict[str, DbtModel] = field(default_factory=dict)
models: Dict[str, DbtModel] = field(default_factory=dict)
project_dir: Path = field(init=False)
models_dir: Path = field(init=False)

Expand Down Expand Up @@ -167,7 +167,7 @@ def _handle_config_file(self, path: Path) -> None:
model_name = config.get("name")

# if the model doesn't exist, we can't do anything
if not model_name in self.models:
if model_name not in self.models:
continue

# parse out the config fields we can recognize
Expand Down
18 changes: 12 additions & 6 deletions cosmos/providers/dbt/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
This module contains a function to render a dbt project into Cosmos entities.
"""
import logging
from typing import Any, Literal

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

from typing import Any, Dict, List

from airflow.datasets import Dataset

Expand All @@ -15,11 +21,11 @@
def render_project(
dbt_project_name: str,
dbt_root_path: str = "/usr/local/airflow/dbt",
task_args: dict[str, Any] = {},
task_args: Dict[str, Any] = {},
test_behavior: Literal["none", "after_each", "after_all"] = "after_each",
emit_datasets: bool = True,
conn_id: str = "default_conn_id",
dbt_tags: list[str] = [],
dbt_tags: List[str] = [],
) -> Group:
"""
Turn a dbt project into a Group
Expand All @@ -40,7 +46,7 @@ def render_project(
)

base_group = Group(id=dbt_project_name) # this is the group that will be returned
entities: dict[
entities: Dict[
str, CosmosEntity
] = {} # this is a dict of all the entities we create

Expand All @@ -53,8 +59,8 @@ def render_project(
if dbt_tags and not set(dbt_tags).intersection(model.config.tags):
continue

run_args: dict[str, Any] = {**task_args, "models": model_name}
test_args: dict[str, Any] = {**task_args, "models": model_name}
run_args: Dict[str, Any] = {**task_args, "models": model_name}
test_args: Dict[str, Any] = {**task_args, "models": model_name}

if emit_datasets:
outlets = [
Expand Down
17 changes: 10 additions & 7 deletions cosmos/providers/dbt/task_group.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""
This module contains a function to render a dbt project as an Airflow Task Group.
"""
from typing import Any, Literal
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

from typing import Any, Dict, List

from cosmos.core.airflow import CosmosTaskGroup
from cosmos.providers.dbt.parser.project import DbtProject
from .render import render_project

from airflow.models import DAG
from airflow.utils.decorators import apply_defaults
from .render import render_project


class DbtTaskGroup(CosmosTaskGroup):
Expand All @@ -29,11 +32,11 @@ def __init__(
self,
dbt_project_name: str,
conn_id: str,
dbt_args: dict[str, Any] = {},
dbt_args: Dict[str, Any] = {},
emit_datasets: bool = True,
dbt_root_path: str = "/usr/local/airflow/dbt",
test_behavior: Literal["none", "after_each", "after_all"] = "after_each",
dbt_tags: list[str] = [],
dbt_tags: List[str] = [],
*args: Any,
**kwargs: Any,
) -> None:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ classifiers = [
dependencies = [
"apache-airflow>=2.4",
"Jinja2>=3.0.0",
"typing-extensions; python_version < '3.8'",
]

[project.optional-dependencies]
Expand Down