21 changes: 14 additions & 7 deletions .pre-commit-config.yaml
@@ -93,7 +93,7 @@ repos:
- --fuzzy-match-generates-todo
- id: insert-license
name: Add license for all Python files
exclude: ^\.github/.*$|^airflow/_vendor/
exclude: ^\.github/.*$|^airflow/_vendor/|.*_pb2.py$|.*_pb2.pyi$|.*_pb2_grpc.py$
files: \.py$|\.pyi$
args:
- --comment-style
@@ -150,7 +150,7 @@ repos:
- id: black
name: Run Black (the uncompromising Python code formatter)
args: [--config=./pyproject.toml]
exclude: ^airflow/_vendor/
exclude: ^airflow/_vendor/|.*_pb2.py$|.*_pb2.pyi$|.*_pb2_grpc.py$
- repo: https://github.com/asottile/blacken-docs
rev: v1.12.1
hooks:
@@ -194,7 +194,7 @@ repos:
exclude: ^airflow/_vendor/|^images/breeze/output.*$
- id: fix-encoding-pragma
name: Remove encoding header from python files
exclude: ^airflow/_vendor/
exclude: ^airflow/_vendor/|.*_pb2.py$|.*_pb2.pyi$|.*_pb2_grpc.py$
args:
- --remove
- id: pretty-format-json
@@ -213,7 +213,7 @@
- id: pyupgrade
name: Upgrade Python code automatically
args: ["--py37-plus"]
exclude: ^airflow/_vendor/
exclude: ^airflow/_vendor/|.*_pb2.py$|.*_pb2.pyi$|.*_pb2_grpc.py$
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0
hooks:
@@ -238,7 +238,7 @@
name: Run isort to sort imports in Python files
files: \.py$|\.pyi$
# To keep consistent with the global isort skip config defined in setup.cfg
exclude: ^airflow/_vendor/|^build/.*$|^venv/.*$|^\.tox/.*$
exclude: ^airflow/_vendor/|^build/.*$|^venv/.*$|^\.tox/.*$|.*_pb2.py$|.*_pb2_grpc.py$
- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
hooks:
@@ -785,6 +785,12 @@ repos:
pass_filenames: false
files: ^Dockerfile$|^Dockerfile.ci$|^scripts/docker/.*$
require_serial: true
- id: grpc-proto-compile
name: Compile GRPC proto to python code
entry: ./scripts/ci/pre_commit/pre_commit_grpc_compile.py
language: python
files: ^airflow/.*\.proto
additional_dependencies: ['mypy', 'grpcio-tools', 'mypy-protobuf', 'types-protobuf']
- id: check-changelog-has-no-duplicates
name: Check changelogs for duplicate entries
language: python
@@ -860,7 +866,8 @@ repos:
language: python
entry: ./scripts/ci/pre_commit/pre_commit_mypy.py --namespace-packages
files: \.py$
exclude: ^provider_packages|^docs|^airflow/_vendor/|^airflow/providers|^airflow/migrations|^dev
exclude: "^provider_packages|^docs|^airflow/_vendor/|^airflow/providers\
|^airflow/migrations|^dev|.*_pb2.py$|.*_pb2.pyi$|.*_pb2_grpc.py$"
require_serial: true
additional_dependencies: ['rich>=12.4.4', 'inputimeout']
- id: run-mypy
@@ -884,7 +891,7 @@
entry: ./scripts/ci/pre_commit/pre_commit_flake8.py
files: \.py$
pass_filenames: true
exclude: ^airflow/_vendor/
exclude: ^airflow/_vendor/|.*_pb2.py$|.*_pb2.pyi$|.*_pb2_grpc.py$
additional_dependencies: ['rich>=12.4.4', 'inputimeout']
- id: update-migration-references
name: Update migration ref doc
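The `.*_pb2.py$|.*_pb2.pyi$|.*_pb2_grpc.py$` patterns added to the `exclude` lists above keep the formatters and linters away from generated protobuf code, and the new `grpc-proto-compile` hook is what produces those files. A minimal sketch of the compile step, assuming the conventional `grpc_tools.protoc` invocation; the actual `scripts/ci/pre_commit/pre_commit_grpc_compile.py` script and the proto path are not shown in this diff, so both are assumptions here:

    # Hedged sketch of what the grpc-proto-compile hook is expected to run.
    # The proto path is inferred from the imports later in this PR
    # (airflow.internal_api.grpc.internal_api_pb2) and may differ.
    from grpc_tools import protoc  # provided by grpcio-tools

    protoc.main(
        [
            "grpc_tools.protoc",
            "-I.",
            "--python_out=.",       # emits internal_api_pb2.py
            "--mypy_out=.",         # mypy-protobuf plugin, emits internal_api_pb2.pyi
            "--grpc_python_out=.",  # emits internal_api_pb2_grpc.py service stubs
            "airflow/internal_api/grpc/internal_api.proto",
        ]
    )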
2 changes: 2 additions & 0 deletions STATIC_CODE_CHECKS.rst
@@ -233,6 +233,8 @@ require Breeze Docker image to be build locally.
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| flynt | Run flynt string format converter for Python | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| grpc-proto-compile | Compile GRPC proto to python code | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| identity | Print input to the static check hooks for troubleshooting | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| insert-license | * Add license for all SQL files | |
85 changes: 84 additions & 1 deletion airflow/callbacks/callback_requests.py
@@ -16,7 +16,12 @@
# under the License.

import json
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from google.protobuf.internal.containers import RepeatedCompositeFieldContainer

from airflow.internal_api.grpc import internal_api_pb2
from airflow.internal_api.grpc.internal_api_pb2 import Callback

if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance
@@ -50,6 +55,28 @@ def from_json(cls, json_str: str):
json_object = json.loads(json_str)
return cls(**json_object)

def to_protobuf(
self,
) -> Callback:
raise NotImplementedError()

@staticmethod
def get_callbacks_from_protobuf(
callbacks: RepeatedCompositeFieldContainer[Callback],
) -> List["CallbackRequest"]:
result_callbacks: List[CallbackRequest] = []
for callback in callbacks:
type = callback.WhichOneof('callback_type')
if type == "task_request":
result_callbacks.append(TaskCallbackRequest.from_protobuf(callback.task_request))
elif type == "dag_request":
result_callbacks.append(DagCallbackRequest.from_protobuf(callback.dag_request))
elif type == 'sla_request':
result_callbacks.append(SlaCallbackRequest.from_protobuf(callback.sla_request))
else:
raise ValueError(f"Bad type: {type}")
return result_callbacks


class TaskCallbackRequest(CallbackRequest):
"""
@@ -86,6 +113,27 @@ def from_json(cls, json_str: str):
simple_ti = SimpleTaskInstance.from_dict(obj_dict=kwargs.pop("simple_task_instance"))
return cls(simple_task_instance=simple_ti, **kwargs)

@classmethod
def from_protobuf(cls, request: internal_api_pb2.TaskCallbackRequest) -> "TaskCallbackRequest":
from airflow.models.taskinstance import SimpleTaskInstance

return cls(
full_filepath=request.full_filepath,
simple_task_instance=SimpleTaskInstance.from_protobuf(request.task_instance),
is_failure_callback=request.is_failure_callback,
msg=request.message,
)

def to_protobuf(self) -> Callback:
return Callback(
task_request=internal_api_pb2.TaskCallbackRequest(
full_filepath=self.full_filepath,
task_instance=self.simple_task_instance.to_protobuf(),
is_failure_callback=self.is_failure_callback,
message=self.msg,
)
)


class DagCallbackRequest(CallbackRequest):
"""
@@ -111,15 +159,50 @@ def __init__(
self.run_id = run_id
self.is_failure_callback = is_failure_callback

@classmethod
def from_protobuf(cls, request: internal_api_pb2.DagCallbackRequest) -> "DagCallbackRequest":
return cls(
full_filepath=request.full_filepath,
dag_id=request.dag_id,
run_id=request.run_id,
is_failure_callback=request.is_failure_callback,
msg=request.message,
)

def to_protobuf(self) -> Callback:
return Callback(
dag_request=internal_api_pb2.DagCallbackRequest(
full_filepath=self.full_filepath,
dag_id=self.dag_id,
run_id=self.run_id,
is_failure_callback=self.is_failure_callback,
message=self.msg,
)
)


class SlaCallbackRequest(CallbackRequest):
"""
A class with information about the SLA callback to be executed.

:param full_filepath: File Path to use to run the callback
:param dag_id: DAG ID
:param msg: Additional Message that can be used for logging
"""

def __init__(self, full_filepath: str, dag_id: str, msg: Optional[str] = None):
super().__init__(full_filepath, msg)
self.dag_id = dag_id

@classmethod
def from_protobuf(cls, request: internal_api_pb2.SlaCallbackRequest) -> "SlaCallbackRequest":
return cls(full_filepath=request.full_filepath, dag_id=request.dag_id, msg=request.message)

def to_protobuf(self) -> Callback:
return Callback(
sla_request=internal_api_pb2.SlaCallbackRequest(
full_filepath=self.full_filepath,
dag_id=self.dag_id,
message=self.msg,
)
)
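Taken together, `to_protobuf` on each subclass and `CallbackRequest.get_callbacks_from_protobuf` give a symmetric round trip through the `Callback` oneof. A minimal sketch of that round trip, assuming the generated `internal_api_pb2` module from this PR is importable; a plain list stands in for the repeated field container, since the helper only iterates it:

    from airflow.callbacks.callback_requests import CallbackRequest, DagCallbackRequest

    request = DagCallbackRequest(
        full_filepath="/dags/example_bash_operator.py",
        dag_id="example_bash_operator",
        run_id="scheduled__2022-01-01",
        is_failure_callback=True,
        msg="Task failed",
    )
    proto = request.to_protobuf()  # wraps the request in Callback(dag_request=...)
    assert proto.WhichOneof("callback_type") == "dag_request"

    restored = CallbackRequest.get_callbacks_from_protobuf([proto])[0]
    assert isinstance(restored, DagCallbackRequest)
    assert restored.run_id == request.run_id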
55 changes: 55 additions & 0 deletions airflow/cli/cli_parser.py
@@ -915,6 +915,29 @@ def string_lower_type(val):
("--include-dags",), help="If passed, DAG specific permissions will also be synced.", action="store_true"
)

# internal API client
ARG_NUM_REPEATS = Arg(
("--num-repeats",),
type=positive_int(allow_zero=False),
default=1,
help="The number of times to repeat the operation.",
)
ARG_USE_GRPC = Arg(
("--use-grpc",),
default=False,
action='store_true',
help="Whether to use GRPC for tests",
)
ARG_NUM_CALLBACKS = Arg(
("--num-callbacks",),
type=positive_int(allow_zero=False),
default=1,
help="The multiplier for number of callbacks.",
)
ARG_TEST = Arg(
('--test',), help='Choose test.', type=str, choices=['file_processor'], default="file_processor"
)

# triggerer
ARG_CAPACITY = Arg(
("--capacity",),
@@ -1461,6 +1484,33 @@ class GroupCommand(NamedTuple):
),
),
)

INTERNAL_API_COMMANDS = (
ActionCommand(
name='server',
help="Start an internal API server instance",
func=lazy_load_command('airflow.cli.commands.internal_api_server_command.internal_api_server'),
args=(
ARG_PID,
ARG_DAEMON,
ARG_STDOUT,
ARG_STDERR,
ARG_LOG_FILE,
),
),
ActionCommand(
name='test-client',
help="Test client for internal API",
func=lazy_load_command('airflow.cli.commands.internal_api_client_command.internal_api_client'),
args=(
ARG_NUM_REPEATS,
ARG_NUM_CALLBACKS,
ARG_USE_GRPC,
ARG_TEST,
),
),
)

CONNECTIONS_COMMANDS = (
ActionCommand(
name='get',
@@ -1811,6 +1861,11 @@ class GroupCommand(NamedTuple):
help="Database operations",
subcommands=DB_COMMANDS,
),
GroupCommand(
name='internal-api',
help='Internal API commands',
subcommands=INTERNAL_API_COMMANDS,
),
ActionCommand(
name='kerberos',
help="Start a kerberos ticket renewer",
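The new `internal-api` group wires both commands through the standard parser, so they dispatch like any other Airflow CLI command. A hedged usage sketch, assuming Airflow's usual `get_parser` entry point; the flag values are illustrative:

    from airflow.cli import cli_parser

    parser = cli_parser.get_parser()
    args = parser.parse_args(
        ["internal-api", "test-client", "--use-grpc", "--num-repeats", "3", "--num-callbacks", "2"]
    )
    args.func(args)  # lazily loads and runs internal_api_client_command.internal_api_client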
111 changes: 111 additions & 0 deletions airflow/cli/commands/internal_api_client_command.py
@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import logging
from pathlib import Path
from typing import List

import grpc
from kubernetes.client import models as k8s
from rich.console import Console

from airflow.callbacks.callback_requests import (
CallbackRequest,
DagCallbackRequest,
SlaCallbackRequest,
TaskCallbackRequest,
)
from airflow.dag_processing.processor import DagFileProcessor
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
from airflow.utils import cli as cli_utils

console = Console(width=400, color_system="standard")


def process_example_files(num_callbacks: int, processor: DagFileProcessor):
example_dags_folder = Path(__file__).parents[3] / "airflow" / "example_dags"
callback_base = [
TaskCallbackRequest(
full_filepath=str(example_dags_folder / "example_bash_operator.py"),
simple_task_instance=SimpleTaskInstance(
task_id="run_this_last",
dag_id="example_python_operator",
run_id="run_id",
start_date=datetime.datetime.now(),
end_date=datetime.datetime.now(),
try_number=1,
map_index=1,
state="RUNNING",
executor_config={
"test": "test",
"pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(annotations={"test": "annotation"})),
},
pool="pool",
queue="queue",
key=TaskInstanceKey(dag_id="dag", task_id="task_id", run_id="run_id"),
run_as_user="user",
),
),
DagCallbackRequest(
full_filepath="file",
dag_id="example_bash_operator",
run_id="run_this_last",
is_failure_callback=False,
msg="Error Message",
),
SlaCallbackRequest(full_filepath="file", dag_id="example_bash_operator", msg="Error message"),
]
callbacks: List[CallbackRequest] = []
for i in range(num_callbacks):
callbacks.extend(callback_base)
sum_dags = 0
sum_errors = 0
for file in example_dags_folder.iterdir():
if file.is_file() and file.name.endswith(".py"):
dags, errors = processor.process_file(
file_path=str(file), callback_requests=callbacks, pickle_dags=True
)
sum_dags += dags
sum_errors += errors
console.print(f"Found {sum_dags} dags with {sum_errors} errors")
return sum_dags, sum_errors


def file_processor_test(num_callbacks: int, processor: DagFileProcessor, num_repeats: int):
total_dags = 0
total_errors = 0
for i in range(num_repeats):
dags, errors = process_example_files(num_callbacks, processor)
total_dags += dags
total_errors += errors
console.print(f"Total found {total_dags} dags with {total_errors} errors")


@cli_utils.action_cli
def internal_api_client(args):
use_grpc = args.use_grpc
num_repeats = args.num_repeats
processor = DagFileProcessor(
dag_ids=[],
log=logging.getLogger('airflow'),
use_grpc=use_grpc,
channel=grpc.insecure_channel('localhost:50051') if use_grpc else None,
)
if args.test == "file_processor":
file_processor_test(args.num_callbacks, processor=processor, num_repeats=num_repeats)
else:
console.print(f"[red]Wrong test {args.test}")
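The client above expects an internal API server listening on localhost:50051 (see the `grpc.insecure_channel` call). The `internal-api server` command's implementation is not part of the hunks shown here; purely for orientation, a generic `grpc.server` skeleton of the shape such a command typically wraps, with the servicer registration left as a hypothetical comment since the generated stub names are not visible in this diff:

    from concurrent import futures

    import grpc

    def serve(port: int = 50051) -> None:
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        # Registration would use the generated internal_api_pb2_grpc module,
        # e.g. add_<Service>Servicer_to_server(servicer, server); the service
        # name is not shown in this diff, so it is left hypothetical here.
        server.add_insecure_port(f"[::]:{port}")
        server.start()
        server.wait_for_termination()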