Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add OpenLineage support for BigQueryToBigQueryOperator #44214

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 54 additions & 21 deletions providers/src/airflow/providers/google/cloud/openlineage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from airflow.providers.common.compat.openlineage.facet import Dataset

from airflow.providers.common.compat.openlineage.facet import (
BaseFacet,
ColumnLineageDatasetFacet,
DocumentationDatasetFacet,
Fields,
Expand All @@ -41,50 +42,82 @@
BIGQUERY_URI = "bigquery"


def get_facets_from_bq_table(table: Table) -> dict[Any, Any]:
def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]:
"""Get facets from BigQuery table object."""
facets = {
"schema": SchemaDatasetFacet(
facets: dict[str, BaseFacet] = {}
if table.schema:
facets["schema"] = SchemaDatasetFacet(
fields=[
SchemaDatasetFacetFields(
name=field.name, type=field.field_type, description=field.description
name=schema_field.name, type=schema_field.field_type, description=schema_field.description
)
for field in table.schema
for schema_field in table.schema
]
),
"documentation": DocumentationDatasetFacet(description=table.description or ""),
}
)
if table.description:
facets["documentation"] = DocumentationDatasetFacet(description=table.description)

return facets


def get_identity_column_lineage_facet(
field_names: list[str],
dest_field_names: list[str],
input_datasets: list[Dataset],
) -> ColumnLineageDatasetFacet:
) -> dict[str, ColumnLineageDatasetFacet]:
"""
Get column lineage facet.

Simple lineage will be created, where each source column corresponds to single destination column
in each input dataset and there are no transformations made.
Get column lineage facet for identity transformations.

This function generates a simple column lineage facet, where each destination column
consists of source columns of the same name from all input datasets that have that column.
The lineage assumes there are no transformations applied, meaning the columns retain their
identity between the source and destination datasets.

Args:
dest_field_names: A list of destination column names for which lineage should be determined.
input_datasets: A list of input datasets with schema facets.

Returns:
A dictionary containing a single key, `columnLineage`, mapped to a `ColumnLineageDatasetFacet`.
If no column lineage can be determined, an empty dictionary is returned - see Notes below.

Notes:
- If any input dataset lacks a schema facet, the function immediately returns an empty dictionary.
- If any field in the source dataset's schema is not present in the destination table,
the function returns an empty dictionary. The destination table can contain extra fields, but all
source columns should be present in the destination table.
- If none of the destination columns can be matched to input dataset columns, an empty
dictionary is returned.
- Extra columns in the destination table that do not exist in the input datasets are ignored and
skipped in the lineage facet, as they cannot be traced back to a source column.
- The function assumes there are no transformations applied, meaning the columns retain their
identity between the source and destination datasets.
"""
if field_names and not input_datasets:
raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.")
fields_sources: dict[str, list[Dataset]] = {}
for ds in input_datasets:
if not ds.facets or "schema" not in ds.facets:
return {}
for schema_field in ds.facets["schema"].fields: # type: ignore[attr-defined]
if schema_field.name not in dest_field_names:
return {}
fields_sources[schema_field.name] = fields_sources.get(schema_field.name, []) + [ds]

if not fields_sources:
return {}

column_lineage_facet = ColumnLineageDatasetFacet(
fields={
field: Fields(
field_name: Fields(
inputFields=[
InputField(namespace=dataset.namespace, name=dataset.name, field=field)
for dataset in input_datasets
InputField(namespace=dataset.namespace, name=dataset.name, field=field_name)
for dataset in source_datasets
],
transformationType="IDENTITY",
transformationDescription="identical",
)
for field in field_names
for field_name, source_datasets in fields_sources.items()
}
)
return column_lineage_facet
return {"columnLineage": column_lineage_facet}


@define
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain
self.hook: BigQueryHook | None = None
self._job_conf: dict = {}

def _prepare_job_configuration(self):
self.source_project_dataset_tables = (
Expand Down Expand Up @@ -154,39 +155,94 @@ def _prepare_job_configuration(self):

return configuration

def _submit_job(
self,
hook: BigQueryHook,
configuration: dict,
) -> str:
job = hook.insert_job(configuration=configuration, project_id=hook.project_id)
return job.job_id

def execute(self, context: Context) -> None:
self.log.info(
"Executing copy of %s into: %s",
self.source_project_dataset_tables,
self.destination_project_dataset_table,
)
hook = BigQueryHook(
self.hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)
self.hook = hook

if not hook.project_id:
if not self.hook.project_id:
raise ValueError("The project_id should be set")

configuration = self._prepare_job_configuration()
job_id = self._submit_job(hook=hook, configuration=configuration)
self._job_conf = self.hook.insert_job(
configuration=configuration, project_id=self.hook.project_id
).to_api_repr()

job = hook.get_job(job_id=job_id, location=self.location).to_api_repr()
conf = job["configuration"]["copy"]["destinationTable"]
dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"]
BigQueryTableLink.persist(
context=context,
task_instance=self,
dataset_id=conf["datasetId"],
project_id=conf["projectId"],
table_id=conf["tableId"],
dataset_id=dest_table_info["datasetId"],
project_id=dest_table_info["projectId"],
table_id=dest_table_info["tableId"],
)

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we will include final BQ job id."""
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
ExternalQueryRunFacet,
)
from airflow.providers.google.cloud.openlineage.utils import (
BIGQUERY_NAMESPACE,
get_facets_from_bq_table,
get_identity_column_lineage_facet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

if not self.hook:
self.hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)

if not self._job_conf:
self.log.debug("OpenLineage could not find BQ job configuration.")
return OperatorLineage()

bq_job_id = self._job_conf["jobReference"]["jobId"]
source_tables_info = self._job_conf["configuration"]["copy"]["sourceTables"]
dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"]

run_facets = {
"externalQuery": ExternalQueryRunFacet(externalQueryId=bq_job_id, source="bigquery"),
}

input_datasets = []
for in_table_info in source_tables_info:
table_id = ".".join(
(in_table_info["projectId"], in_table_info["datasetId"], in_table_info["tableId"])
)
table_object = self.hook.get_client().get_table(table_id)
input_datasets.append(
Dataset(
namespace=BIGQUERY_NAMESPACE, name=table_id, facets=get_facets_from_bq_table(table_object)
)
)

out_table_id = ".".join(
(dest_table_info["projectId"], dest_table_info["datasetId"], dest_table_info["tableId"])
)
out_table_object = self.hook.get_client().get_table(out_table_id)
output_dataset_facets = {
**get_facets_from_bq_table(out_table_object),
**get_identity_column_lineage_facet(
dest_field_names=[field.name for field in out_table_object.schema],
input_datasets=input_datasets,
),
}
output_dataset = Dataset(
namespace=BIGQUERY_NAMESPACE,
name=out_table_id,
facets=output_dataset_facets,
)

return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets)
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def get_openlineage_facets_on_complete(self, task_instance):
from pathlib import Path

from airflow.providers.common.compat.openlineage.facet import (
BaseFacet,
Dataset,
ExternalQueryRunFacet,
Identifier,
Expand Down Expand Up @@ -322,12 +323,12 @@ def get_openlineage_facets_on_complete(self, task_instance):
facets=get_facets_from_bq_table(table_object),
)

output_dataset_facets = {
"schema": input_dataset.facets["schema"],
"columnLineage": get_identity_column_lineage_facet(
field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset]
),
}
output_dataset_facets: dict[str, BaseFacet] = get_identity_column_lineage_facet(
dest_field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset]
)
if "schema" in input_dataset.facets:
output_dataset_facets["schema"] = input_dataset.facets["schema"]

output_datasets = []
for uri in sorted(self.destination_cloud_storage_uris):
bucket, blob = _parse_gcs_url(uri)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -784,9 +784,10 @@ def get_openlineage_facets_on_complete(self, task_instance):
source_objects = (
self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
)
input_dataset_facets = {
"schema": output_dataset_facets["schema"],
}
input_dataset_facets = {}
if "schema" in output_dataset_facets:
input_dataset_facets["schema"] = output_dataset_facets["schema"]

input_datasets = []
for blob in sorted(source_objects):
additional_facets = {}
Expand All @@ -811,14 +812,16 @@ def get_openlineage_facets_on_complete(self, task_instance):
)
input_datasets.append(dataset)

output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet(
field_names=[field.name for field in table_object.schema], input_datasets=input_datasets
)

output_dataset = Dataset(
namespace="bigquery",
name=str(table_object.reference),
facets=output_dataset_facets,
facets={
**output_dataset_facets,
**get_identity_column_lineage_facet(
dest_field_names=[field.name for field in table_object.schema],
input_datasets=input_datasets,
),
},
)

run_facets = {}
Expand Down
Loading