4 changes: 4 additions & 0 deletions airflow_dbt_python/hooks/remote/__init__.py
@@ -151,6 +151,10 @@ def get_remote(scheme: str, conn_id: Optional[str] = None) -> DbtRemoteHook:
from .s3 import DbtS3RemoteHook

remote_cls: Type[DbtRemoteHook] = DbtS3RemoteHook
elif scheme == "gs":
from .gcs import DbtGCSRemoteHook

remote_cls = DbtGCSRemoteHook
elif scheme in ("https", "git", "git+ssh", "ssh", "http"):
from .git import DbtGitRemoteHook

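For context on how the new branch is reached, a minimal usage sketch (not part of the diff; the connection id is illustrative and the Google provider is assumed to be installed — per its annotation, `get_remote` returns an instance of the selected hook class):

```python
from airflow_dbt_python.hooks.remote import get_remote

# The "gs" scheme now resolves to the GCS-backed remote added in this PR.
remote = get_remote("gs", conn_id="google_cloud_default")
print(type(remote).__name__)  # Expected: DbtGCSRemoteHook
```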
234 changes: 234 additions & 0 deletions airflow_dbt_python/hooks/remote/gcs.py
@@ -0,0 +1,234 @@
"""An implementation for an GCS remote for dbt."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable, Optional

from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from google.cloud.storage import Blob

from airflow_dbt_python.hooks.remote import DbtRemoteHook
from airflow_dbt_python.utils.url import URL, URLLike


class DbtGCSRemoteHook(GCSHook, DbtRemoteHook):
tomasfarias (Owner) commented on Apr 1, 2025:
praise: I'm not very familiar with GCS unfortunately, so I'll default on saying this is fine 👍, although I do have one comment about some (apparently) copy-pasted comment.

"""A dbt remote implementation for GCS.
This concrete remote class implements the DbtRemote interface by using GCS as a
storage for uploading and downloading dbt files to and from.
The DbtGCSRemoteHook subclasses Airflow's GCSHook to interact with GCS.
A connection id may be passed to set the connection to use with GCS.
"""

conn_type = "gcs"
hook_name = "dbt GCS Remote"

def __init__(self, *args, **kwargs):
"""Initialize a dbt remote for GCS."""
super().__init__(*args, **kwargs)

def _upload(
self,
source: URL,
destination: URL,
replace: bool = False,
delete_before: bool = False,
) -> None:
"""Upload one or more files under source URL to GCS.
Args:
source: A local URL from which to fetch the file(s) to push.
destination: A GCS URL where the file should be uploaded. The bucket
name and key prefix will be extracted by calling _parse_gcs_url.
replace: Whether to replace existing GCS keys or not.
delete_before: Whether to delete the contents of destination before pushing.
"""
self.log.info("Uploading to GCS from %s to %s", source, destination)
self.log.debug("All files: %s", [s for s in source])

bucket_name, key = _parse_gcs_url(str(destination))

if delete_before:
keys = self.list(bucket_name, prefix=key)
for _key in keys:
self.delete(bucket_name, _key)

base_key = URL(f"gs://{bucket_name}/{key}")
for file_url in source:
self.log.debug("Uploading: %s", file_url)

if file_url.is_dir():
continue

gcs_key = base_key / file_url.relative_to(source)

self.load_file_handle_replace_error(
file_url=file_url,
key=str(gcs_key),
replace=replace,
)

def load_file_handle_replace_error(
self,
file_url: URLLike,
key: str,
bucket_name: Optional[str] = None,
replace: bool = False,
encrypt: bool = False,
gzip: bool = False,
) -> bool:
"""Calls GCSHook.load_file but handles ValueError when replacing existing keys.
Will also log a warning whenever attempting to replace an existing key with
replace = False.
Returns:
True if no ValueError was raised, False otherwise.
"""
success = True

if bucket_name is None:
# We can't call load_file with bucket_name=None as it checks for the
# presence of the parameter to decide whether setting a bucket_name is
# required. By passing bucket_name=None, the parameter is set, and
# 'None' will be used as the bucket name.
bucket_name, key = _parse_gcs_url(str(key))

self.log.info("Loading file %s to GCS: %s", file_url, key)
try:
self.load_file(
str(file_url),
key,
bucket_name=bucket_name,
replace=replace,
encrypt=encrypt,
gzip=gzip,
)
except ValueError:
success = False
self.log.warning("Failed to load %s: key already exists in GCS.", key)

return success

def _download(
self,
source: URL,
destination: URL,
replace: bool = False,
delete_before: bool = False,
):
"""Download one or more files from a destination URL in GCS.
Lists all GCS keys that have source as a prefix to find what to download.
Args:
source: An GCS URL to a key prefix containing objects to download.
destination: A destination URL where to download the objects to. The
existing sub-directory hierarchy in GCS will be preserved.
replace: Indicates whether to replace existing files when downloading.
This flag is kept here to comply with the DbtRemote interface but its
ignored as files downloaded from GCS always overwrite local files.
delete_before: Delete destination directory before download.
"""
gcs_object_keys = self.iter_url(source)

if destination.exists() and delete_before is True:
for _file in destination:
_file.unlink()

if destination.is_dir():
destination.rmdir()

for gcs_object_key in gcs_object_keys:
self.log.info("GCSObjectKey: %s", gcs_object_key)
self.log.info("Source: %s", source)

bucket_name, object_name = _parse_gcs_url(str(gcs_object_key))
gcs_object = self.get_key(object_name, bucket_name)
gcs_object_url = URL(gcs_object_key)

if source != gcs_object_url and gcs_object_url.is_relative_to(source):
gcs_object_url = gcs_object_url.relative_to(source)

if gcs_object_url.suffix == "" and str(gcs_object_url).endswith("/"):
# Empty GCS files may also be confused with unwanted directories.
self.log.warning(
"A file with no name was found in GCS at %s", gcs_object
)
continue

if destination.is_dir():
destination_url = destination / gcs_object_url.path
else:
destination_url = destination

destination_url.parent.mkdir(parents=True, exist_ok=True)

gcs_object.download_to_filename(str(destination_url))

def iter_url(self, source: URL) -> Iterable[URL]:
"""Iterate over an GCS key given by a URL."""
bucket_name, key_prefix = _parse_gcs_url(str(source))

for key in self.list(bucket_name=bucket_name, prefix=key_prefix):
if key.endswith("//"):
# Sometimes, GCS files with empty names can appear, usually when using
# the UI. These empty GCS files may also be confused with directories.
continue
yield URL.from_parts(scheme="gs", netloc=bucket_name, path=key)

def get_key(self, key: str, bucket_name: str) -> Blob:
"""Get Blob object by key and bucket name."""
return self._get_blob(bucket_name, key)

def check_for_key(self, key: str, bucket_name: str) -> bool:
"""Checking if the key exists in the bucket."""
client = self.get_conn()
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=key)
return blob.exists()

def load_file(
self,
filename: Path | str,
key: str,
bucket_name: str | None = None,
replace: bool = False,
encrypt: bool = False,
gzip: bool = False,
) -> None:
"""Load a local file to GCS.
:param filename: path to the file to load.
:param key: GCS key that will point to the file
:param bucket_name: Name of the bucket in which to store the file
:param replace: A flag to decide whether or not to overwrite the key
if it already exists. If replace is False and the key exists, an
error will be raised.
:param encrypt: If True, the file will be encrypted on the server-side
by GCS and will be stored in an encrypted form while at rest in GCS.
:param gzip: If True, the file will be compressed locally
"""
filename = str(filename)
if not replace and self.check_for_key(key, bucket_name):
raise ValueError(f"The key {key} already exists.")

metadata = {}
if encrypt:
raise NotImplementedError("Encrypt is not implemented in GCSHook.")

self.upload(
bucket_name=bucket_name,
object_name=key,
filename=filename,
gzip=gzip,
metadata=metadata,
)
get_hook_lineage_collector().add_input_asset(
context=self, scheme="file", asset_kwargs={"path": filename}
)
get_hook_lineage_collector().add_output_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket_name, "key": key}
)
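For reviewers who want to exercise the hook outside of Airflow, a rough sketch of a push/pull round trip (not part of the diff; the bucket name and local paths are made up, default Google credentials are assumed, and the underscore methods are called only because they are the ones shown above — the public DbtRemoteHook interface would normally wrap them):

```python
from airflow_dbt_python.hooks.remote.gcs import DbtGCSRemoteHook
from airflow_dbt_python.utils.url import URL

hook = DbtGCSRemoteHook()

# Push a local dbt project to a GCS prefix, wiping the prefix first.
hook._upload(
    source=URL("/tmp/my_dbt_project"),
    destination=URL("gs://my-hypothetical-bucket/dbt/project"),
    replace=True,
    delete_before=True,
)

# Pull the project back down, preserving the key hierarchy locally.
hook._download(
    source=URL("gs://my-hypothetical-bucket/dbt/project"),
    destination=URL("/tmp/restored_project"),
)
```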
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ build-backend = "poetry.core.masonry.api"
[project.optional-dependencies]
airflow-providers = [
"apache-airflow-providers-amazon>=3.0.0",
"apache-airflow-providers-google>=11.0.0",
"apache-airflow-providers-ssh>=3.0.0",
]
adapters = [
@@ -47,6 +48,9 @@ adapters = [
bigquery = [
"dbt-bigquery>=1.8.0,<2.0.0",
]
gcs = [
"apache-airflow-providers-google>=11.0.0",
]
git = [
"apache-airflow-providers-ssh>=3.0.0",
"dulwich>=0.21",
@@ -75,6 +79,7 @@ optional = true

[tool.poetry.group.dev.dependencies]
apache-airflow-providers-amazon = ">=3.0.0"
apache-airflow-providers-google = ">=11.0.0"
apache-airflow-providers-ssh = ">=3.0.0"
black = ">=22"
boto3-stubs = { extras = ["s3"], version = ">=1.26.8" }
@@ -91,6 +96,8 @@ pytest-postgresql = ">=5"
ruff = ">=0.0.254"
types-freezegun = ">=1.1.6"
types-PyYAML = ">=6.0.7"
pytest-mock = "^3.14.0"
mock-gcp = {git = "https://github.com/millin/mock-gcp.git", rev = "0d972df9b6cce164b49f09ec4417a4eb77beb960"}
tomasfarias (Owner) commented:
question: I skimmed through this, and it seems like upstream is not active. Could you publish your own fork on PyPI? I feel a bit uneasy about including a git source, so if publishing to PyPI is not possible I would like a comment briefly describing what we are getting from the fork that is not in upstream.

Contributor (Author) replied:
The difference with upstream is that it mostly doesn't work 😄. I sent a PR to the original author but haven't gotten a response. But I liked the concept and used it as a basis.

I may try publishing this in PyPI (never done it, to be honest).


[tool.poetry.group.docs]
optional = true
69 changes: 65 additions & 4 deletions tests/conftest.py
@@ -10,11 +10,13 @@
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Generator, List
from unittest.mock import patch

import boto3
import pytest
from airflow import settings
from airflow.models.connection import Connection
from mockgcp.storage.client import MockClient as MockStorageClient
from moto import mock_aws
from pytest_postgresql.janitor import DatabaseJanitor

@@ -376,10 +378,7 @@ def mocked_s3_res():
@pytest.fixture
def s3_hook():
"""Provide an S3 for testing."""
try:
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
except ImportError:
from airflow.hooks.S3_hook import S3Hook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

return S3Hook()

@@ -411,6 +410,68 @@ def s3_bucket(mocked_s3_res, s3_hook):
assert keys is None or len(keys) == 0


@pytest.fixture
def gcp_conn_id():
"""Provide a GCS connection for testing."""
from airflow.providers.google.cloud.hooks.gcs import GCSHook

conn_id = GCSHook.default_conn_name

session = settings.Session()
existing = session.query(Connection).filter_by(conn_id=conn_id).first()
if existing is not None:
# Connections may exist from previous test run.
session.delete(existing)
session.commit()

conn = Connection(conn_id=conn_id, conn_type=GCSHook.conn_type)

session.add(conn)

session.commit()

yield conn_id

session.delete(conn)

session.commit()
session.close()


@pytest.fixture
def mocked_gcs_client():
"""Provide mock Google Storage Client for testing."""
with patch("google.cloud.storage.client.Client", MockStorageClient):
yield MockStorageClient(project="test-project")


@pytest.fixture
def gcs_hook(gcp_conn_id):
"""Provide an GCS for testing."""
from airflow_dbt_python.hooks.remote.gcs import DbtGCSRemoteHook

with patch(
"airflow.providers.google.cloud.hooks.gcs.GCSHook.get_credentials_and_project_id",
lambda x: ({}, "test-project"),
):
with patch("google.cloud.storage.Client", MockStorageClient):
yield DbtGCSRemoteHook()


@pytest.fixture
def gcs_bucket(mocked_gcs_client, gcs_hook):
"""Return a mocked gcs bucket for testing.

Bucket is cleaned after every use.
"""
bucket_name = "airflow-dbt-test-gcs-bucket"
bucket = mocked_gcs_client.create_bucket(bucket_name)

yield bucket_name

bucket.delete()


BROKEN_SQL = """
SELECT
field1 AS field1
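As a rough illustration of how the new fixtures might be combined in a test (hypothetical test, not part of the diff; how faithfully mock-gcp emulates the real client is an assumption):

```python
def test_load_file_to_mocked_bucket(gcs_bucket, gcs_hook, tmp_path):
    """Load a local file into the mocked bucket and check it exists."""
    local_file = tmp_path / "profiles.yml"
    local_file.write_text("default:\n  target: dev\n")

    # load_file accepts a Path or str and defaults to replace=False.
    gcs_hook.load_file(
        local_file, key="project/profiles.yml", bucket_name=gcs_bucket
    )

    assert gcs_hook.check_for_key("project/profiles.yml", bucket_name=gcs_bucket)
```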