diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index 9d3ca90bcd..d869927c28 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -1,12 +1,12 @@ import importlib -from datetime import datetime import logging -from typing import Any, Optional +from typing import Any, Dict, Optional from airflow.models import BaseOperator from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup -from cosmos.core.graph.entities import CosmosEntity, Group, Task + +from cosmos.core.graph.entities import Group, Task logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ def __init__( super().__init__(*args, **kwargs) - entities: dict[str, Any] = {} + entities: Dict[str, Any] = {} # render all the entities in the group for ent in cosmos_group.entities: @@ -63,7 +63,7 @@ def __init__( kwargs["dag"] = dag super().__init__(*args, **kwargs) - entities: dict[str, Any] = {} + entities: Dict[str, Any] = {} # render all the entities in the group for ent in cosmos_group.entities: diff --git a/cosmos/core/graph/entities.py b/cosmos/core/graph/entities.py index 05e7351da9..0ef70060ac 100644 --- a/cosmos/core/graph/entities.py +++ b/cosmos/core/graph/entities.py @@ -1,9 +1,8 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import List, Any - import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ class CosmosEntity: """ id: str - upstream_entity_ids: list[str] = field(default_factory=list) + upstream_entity_ids: List[str] = field(default_factory=list) def add_upstream(self, entity: CosmosEntity) -> None: """ @@ -35,7 +34,7 @@ class Group(CosmosEntity): A Group represents a collection of entities that are connected by dependencies. """ - entities: list[CosmosEntity] = field(default_factory=list) + entities: List[CosmosEntity] = field(default_factory=list) def add_entity(self, entity: CosmosEntity) -> None: """ @@ -58,4 +57,4 @@ class Task(CosmosEntity): """ operator_class: str = "airflow.operators.dummy.DummyOperator" - arguments: dict[str, Any] = field(default_factory=dict) + arguments: Dict[str, Any] = field(default_factory=dict) diff --git a/cosmos/providers/dbt/core/operators.py b/cosmos/providers/dbt/core/operators.py index a47911d95b..ad8bda0d20 100644 --- a/cosmos/providers/dbt/core/operators.py +++ b/cosmos/providers/dbt/core/operators.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import Sequence +from typing import List, Sequence import yaml from airflow.compat.functools import cached_property @@ -81,7 +81,7 @@ def __init__( self, project_dir: str, conn_id: str, - base_cmd: str | list[str] = None, + base_cmd: str | List[str] = None, select: str = None, exclude: str = None, selector: str = None, @@ -251,7 +251,7 @@ def build_and_run_cmd(self, env: dict, cmd_flags: list = None): dbt_cmd.append(profile) ## set env vars - env = env | profile_vars + env = {**env, **profile_vars} result = self.run_command(cmd=dbt_cmd, env=env) return result diff --git a/cosmos/providers/dbt/dag.py b/cosmos/providers/dbt/dag.py index 7fe47d3c70..df94f9356a 100644 --- a/cosmos/providers/dbt/dag.py +++ b/cosmos/providers/dbt/dag.py @@ -1,13 +1,16 @@ """ This module contains a function to render a dbt project as an Airflow DAG. """ -from typing import Any, Literal +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +from typing import Any, Dict, List + from cosmos.core.airflow import CosmosDag -from cosmos.providers.dbt.parser.project import DbtProject -from .render import render_project -from airflow.models import DAG -from airflow.utils.decorators import apply_defaults +from .render import render_project class DbtDag(CosmosDag): @@ -29,11 +32,11 @@ def __init__( self, dbt_project_name: str, conn_id: str, - dbt_args: dict[str, Any] = {}, + dbt_args: Dict[str, Any] = {}, emit_datasets: bool = True, dbt_root_path: str = "/usr/local/airflow/dbt", test_behavior: Literal["none", "after_each", "after_all"] = "after_each", - dbt_tags: list[str] = [], + dbt_tags: List[str] = [], *args: Any, **kwargs: Any, ) -> None: diff --git a/cosmos/providers/dbt/parser/project.py b/cosmos/providers/dbt/parser/project.py index 7b00d3c49e..877954bad6 100644 --- a/cosmos/providers/dbt/parser/project.py +++ b/cosmos/providers/dbt/parser/project.py @@ -3,14 +3,14 @@ """ from __future__ import annotations -import os import logging -import yaml # type: ignore -import jinja2 - +import os from dataclasses import dataclass, field -from typing import Any, ClassVar from pathlib import Path +from typing import Dict + +import jinja2 +import yaml # type: ignore logger = logging.getLogger(__name__) @@ -116,7 +116,7 @@ class DbtProject: dbt_root_path: str = "/usr/local/airflow/dbt" # private instance variables for managing state - models: dict[str, DbtModel] = field(default_factory=dict) + models: Dict[str, DbtModel] = field(default_factory=dict) project_dir: Path = field(init=False) models_dir: Path = field(init=False) @@ -167,7 +167,7 @@ def _handle_config_file(self, path: Path) -> None: model_name = config.get("name") # if the model doesn't exist, we can't do anything - if not model_name in self.models: + if model_name not in self.models: continue # parse out the config fields we can recognize diff --git a/cosmos/providers/dbt/render.py b/cosmos/providers/dbt/render.py index a5b3334d13..14e76c10af 100644 --- a/cosmos/providers/dbt/render.py +++ b/cosmos/providers/dbt/render.py @@ -2,7 +2,13 @@ This module contains a function to render a dbt project into Cosmos entities. """ import logging -from typing import Any, Literal + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +from typing import Any, Dict, List from airflow.datasets import Dataset @@ -15,11 +21,11 @@ def render_project( dbt_project_name: str, dbt_root_path: str = "/usr/local/airflow/dbt", - task_args: dict[str, Any] = {}, + task_args: Dict[str, Any] = {}, test_behavior: Literal["none", "after_each", "after_all"] = "after_each", emit_datasets: bool = True, conn_id: str = "default_conn_id", - dbt_tags: list[str] = [], + dbt_tags: List[str] = [], ) -> Group: """ Turn a dbt project into a Group @@ -40,7 +46,7 @@ def render_project( ) base_group = Group(id=dbt_project_name) # this is the group that will be returned - entities: dict[ + entities: Dict[ str, CosmosEntity ] = {} # this is a dict of all the entities we create @@ -53,8 +59,8 @@ def render_project( if dbt_tags and not set(dbt_tags).intersection(model.config.tags): continue - run_args: dict[str, Any] = {**task_args, "models": model_name} - test_args: dict[str, Any] = {**task_args, "models": model_name} + run_args: Dict[str, Any] = {**task_args, "models": model_name} + test_args: Dict[str, Any] = {**task_args, "models": model_name} if emit_datasets: outlets = [ diff --git a/cosmos/providers/dbt/task_group.py b/cosmos/providers/dbt/task_group.py index 4924558852..14a4f6ad7d 100644 --- a/cosmos/providers/dbt/task_group.py +++ b/cosmos/providers/dbt/task_group.py @@ -1,13 +1,16 @@ """ This module contains a function to render a dbt project as an Airflow Task Group. """ -from typing import Any, Literal +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +from typing import Any, Dict, List + from cosmos.core.airflow import CosmosTaskGroup -from cosmos.providers.dbt.parser.project import DbtProject -from .render import render_project -from airflow.models import DAG -from airflow.utils.decorators import apply_defaults +from .render import render_project class DbtTaskGroup(CosmosTaskGroup): @@ -29,11 +32,11 @@ def __init__( self, dbt_project_name: str, conn_id: str, - dbt_args: dict[str, Any] = {}, + dbt_args: Dict[str, Any] = {}, emit_datasets: bool = True, dbt_root_path: str = "/usr/local/airflow/dbt", test_behavior: Literal["none", "after_each", "after_all"] = "after_each", - dbt_tags: list[str] = [], + dbt_tags: List[str] = [], *args: Any, **kwargs: Any, ) -> None: diff --git a/pyproject.toml b/pyproject.toml index f9ebebee9e..38f4b5e4e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ classifiers = [ dependencies = [ "apache-airflow>=2.4", "Jinja2>=3.0.0", + "typing-extensions; python_version < '3.8'", ] [project.optional-dependencies]