diff --git a/airflow/api_internal/__init__.py b/airflow/api_internal/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/api_internal/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_internal/endpoints/__init__.py b/airflow/api_internal/endpoints/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/api_internal/endpoints/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py new file mode 100644 index 0000000000000..90bb23e112840 --- /dev/null +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import logging + +from flask import Response + +from airflow.api_connexion.types import APIResponse +from airflow.dag_processing.processor import DagFileProcessor +from airflow.serialization.serialized_objects import BaseSerialization + +log = logging.getLogger(__name__) + + +def _build_methods_map(list) -> dict: + return {f"{func.__module__}.{func.__name__}": func for func in list} + + +METHODS_MAP = _build_methods_map( + [ + DagFileProcessor.update_import_errors, + ] +) + + +def internal_airflow_api( + body: dict, +) -> APIResponse: + """Handler for Internal API /internal_api/v1/rpcapi endpoint.""" + log.debug("Got request") + json_rpc = body.get("jsonrpc") + if json_rpc != "2.0": + log.error("Not jsonrpc-2.0 request.") + return Response(response="Expected jsonrpc 2.0 request.", status=400) + + method_name = str(body.get("method")) + if method_name not in METHODS_MAP: + log.error("Unrecognized method: %s.", method_name) + return Response(response=f"Unrecognized method: {method_name}.", status=400) + + handler = METHODS_MAP[method_name] + params = {} + try: + if body.get("params"): + params_json = json.loads(str(body.get("params"))) + params = BaseSerialization.deserialize(params_json) + except Exception as err: + log.error("Error deserializing parameters.") + log.error(err) + return Response(response="Error deserializing parameters.", status=400) + + log.debug("Calling method %.", {method_name}) + try: + output = handler(**params) + output_json = BaseSerialization.serialize(output) + log.debug("Returning response") + return Response( + response=json.dumps(output_json or "{}"), headers={"Content-Type": "application/json"} + ) + except Exception as e: + log.error("Error when calling method %s.", method_name) + log.error(e) + return Response(response=f"Error executing method: {method_name}.", status=500) diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py new file mode 100644 index 0000000000000..9de4d33c864d7 --- /dev/null +++ b/airflow/api_internal/internal_api_call.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import inspect +import json +from functools import wraps +from typing import Callable, TypeVar + +import requests + +from airflow.configuration import conf +from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.serialization.serialized_objects import BaseSerialization +from airflow.typing_compat import ParamSpec + +PS = ParamSpec("PS") +RT = TypeVar("RT") + + +class InternalApiConfig: + """Stores and caches configuration for Internal API.""" + + _initialized = False + _use_internal_api = False + _internal_api_endpoint = "" + + @staticmethod + def get_use_internal_api(): + if not InternalApiConfig._initialized: + InternalApiConfig._init_values() + return InternalApiConfig._use_internal_api + + @staticmethod + def get_internal_api_endpoint(): + if not InternalApiConfig._initialized: + InternalApiConfig._init_values() + return InternalApiConfig._internal_api_endpoint + + @staticmethod + def _init_values(): + use_internal_api = conf.getboolean("core", "database_access_isolation") + internal_api_endpoint = "" + if use_internal_api: + internal_api_url = conf.get("core", "internal_api_url") + internal_api_endpoint = internal_api_url + "/internal_api/v1/rpcapi" + if not internal_api_endpoint.startswith("http://"): + raise AirflowConfigException("[core]internal_api_url must start with http://") + + InternalApiConfig._initialized = True + InternalApiConfig._use_internal_api = use_internal_api + InternalApiConfig._internal_api_endpoint = internal_api_endpoint + + +def internal_api_call(func: Callable[PS, RT | None]) -> Callable[PS, RT | None]: + """Decorator for methods which may be executed in database isolation mode. + + If [core]database_access_isolation is true then such method are not executed locally, + but instead RPC call is made to Database API (aka Internal API). This makes some components + decouple from direct Airflow database access. + Each decorated method must be present in METHODS list in airflow.api_internal.endpoints.rpc_api_endpoint. + Only static methods can be decorated. This decorator must be before "provide_session". + + See [AIP-44](https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API) + for more information . + """ + headers = { + "Content-Type": "application/json", + } + + def make_jsonrpc_request(method_name: str, params_json: str) -> bytes: + data = {"jsonrpc": "2.0", "method": method_name, "params": params_json} + internal_api_endpoint = InternalApiConfig.get_internal_api_endpoint() + response = requests.post(url=internal_api_endpoint, data=json.dumps(data), headers=headers) + if response.status_code != 200: + raise AirflowException( + f"Got {response.status_code}:{response.reason} when sending the internal api request." + ) + return response.content + + @wraps(func) + def wrapper(*args, **kwargs) -> RT | None: + use_internal_api = InternalApiConfig.get_use_internal_api() + if not use_internal_api: + return func(*args, **kwargs) + + bound = inspect.signature(func).bind(*args, **kwargs) + arguments_dict = dict(bound.arguments) + if "session" in arguments_dict: + del arguments_dict["session"] + args_json = json.dumps(BaseSerialization.serialize(arguments_dict)) + method_name = f"{func.__module__}.{func.__name__}" + result = make_jsonrpc_request(method_name, args_json) + if result: + return BaseSerialization.deserialize(json.loads(result)) + else: + return None + + return wrapper diff --git a/airflow/api_internal/openapi/internal_api_v1.yaml b/airflow/api_internal/openapi/internal_api_v1.yaml new file mode 100644 index 0000000000000..58ef96217969e --- /dev/null +++ b/airflow/api_internal/openapi/internal_api_v1.yaml @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +openapi: 3.0.2 +info: + title: Airflow Internal API + version: 1.0.0 + description: | + This is Airflow Internal API - which is a proxy for components running + customer code for connecting to Airflow Database. + + It is not intended to be used by any external code. + + You can find more information in AIP-44 + https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API + + +servers: + - url: /internal_api/v1 + description: Airflow Internal API +paths: + "/rpcapi": + post: + operationId: rpcapi + deprecated: false + x-openapi-router-controller: airflow.api_internal.endpoints.rpc_api_endpoint + operationId: internal_airflow_api + tags: + - JSONRPC + parameters: [] + responses: + '200': + description: Successful response + requestBody: + x-body-name: body + required: true + content: + application/json: + schema: + type: object + required: + - method + - jsonrpc + - params + properties: + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + method: + type: string + description: Method name + params: + title: Parameters + type: string +x-headers: [] +x-explorer-enabled: true +x-proxy-enabled: true +components: + schemas: + JsonRpcRequired: + type: object + required: + - method + - jsonrpc + properties: + method: + type: string + description: Method name + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + discriminator: + propertyName: method_name +tags: [] diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index b15981aab98a9..09b02fa84b502 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -418,6 +418,19 @@ type: string default: ~ example: '{"some_param": "some_value"}' + - name: database_access_isolation + description: (experimental) Whether components should use Airflow Internal API for DB connectivity. + version_added: 2.6.0 + type: boolean + example: ~ + default: "False" + - name: internal_api_url + description: | + (experimental)Airflow Internal API url. Only used if [core] database_access_isolation is True. + version_added: 2.6.0 + type: string + default: ~ + example: 'http://localhost:8080' - name: database description: ~ @@ -1482,6 +1495,13 @@ type: string example: "dagrun_cleared,failed" default: ~ + - name: run_internal_api + description: | + Boolean for running Internal API in the webserver. + version_added: 2.6.0 + type: boolean + example: ~ + default: "False" - name: email description: | diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 146d0077a9060..1dfd0e98d2b97 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -240,6 +240,13 @@ daemon_umask = 0o077 # Example: dataset_manager_kwargs = {{"some_param": "some_value"}} # dataset_manager_kwargs = +# (experimental) Whether components should use Airflow Internal API for DB connectivity. +database_access_isolation = False + +# (experimental)Airflow Internal API url. Only used if [core] database_access_isolation is True. +# Example: internal_api_url = http://localhost:8080 +# internal_api_url = + [database] # The SqlAlchemy connection string to the metadata database. # SqlAlchemy supports many different database engines. @@ -752,6 +759,9 @@ audit_view_excluded_events = gantt,landing_times,tries,duration,calendar,graph,g # Example: audit_view_included_events = dagrun_cleared,failed # audit_view_included_events = +# Boolean for running Internal API in the webserver. +run_internal_api = False + [email] # Configuration email backend and whether to diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 323ff94c5e63f..fbb2bd298d743 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import datetime import logging import multiprocessing import os @@ -24,7 +23,7 @@ import threading import time from contextlib import redirect_stderr, redirect_stdout, suppress -from datetime import timedelta +from datetime import datetime, timedelta from multiprocessing.connection import Connection as MultiprocessingConnection from typing import TYPE_CHECKING, Iterator @@ -33,6 +32,7 @@ from sqlalchemy.orm.session import Session from airflow import settings +from airflow.api_internal.internal_api_call import internal_api_call from airflow.callbacks.callback_requests import ( CallbackRequest, DagCallbackRequest, @@ -94,7 +94,7 @@ def __init__( # Whether the process is done running. self._done = False # When the process started. - self._start_time: datetime.datetime | None = None + self._start_time: datetime | None = None # This ID is use to uniquely name the process / thread that's launched # by this processor instance self._instance_id = DagFileProcessorProcess.class_creation_counter @@ -327,7 +327,7 @@ def result(self) -> tuple[int, int] | None: return self._result @property - def start_time(self) -> datetime.datetime: + def start_time(self) -> datetime: """Time when this started to process the file.""" if self._start_time is None: raise AirflowException("Tried to get start time before it started!") @@ -448,7 +448,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: .all() ) if slas: - sla_dates: list[datetime.datetime] = [sla.execution_date for sla in slas] + sla_dates: list[datetime] = [sla.execution_date for sla in slas] fetched_tis: list[TI] = ( session.query(TI) .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id) @@ -524,17 +524,21 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: session.commit() @staticmethod - def update_import_errors(session: Session, dagbag: DagBag) -> None: + @internal_api_call + @provide_session + def update_import_errors( + file_last_changed: dict[str, datetime], import_errors: dict[str, str], session: Session = NEW_SESSION + ) -> None: """ Update any import errors to be displayed in the UI. For the DAGs in the given DagBag, record any associated import errors and clears errors for files that no longer have them. These are usually displayed through the Airflow UI so that users know that there are issues parsing DAGs. - :param session: session for ORM operations :param dagbag: DagBag containing DAGs with import errors + :param session: session for ORM operations """ - files_without_error = dagbag.file_last_changed - dagbag.import_errors.keys() + files_without_error = file_last_changed - import_errors.keys() # Clear the errors of the processed files # that no longer have errors @@ -547,7 +551,7 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None: existing_import_error_files = [x.filename for x in session.query(errors.ImportError.filename).all()] # Add the errors of the processed files - for filename, stacktrace in dagbag.import_errors.items(): + for filename, stacktrace in import_errors.items(): if filename in existing_import_error_files: session.query(errors.ImportError).filter(errors.ImportError.filename == filename).update( dict(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace), @@ -754,7 +758,11 @@ def process_file( self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) else: self.log.warning("No viable dags retrieved from %s", file_path) - self.update_import_errors(session, dagbag) + DagFileProcessor.update_import_errors( + file_last_changed=dagbag.file_last_changed, + import_errors=dagbag.import_errors, + session=session, + ) if callback_requests: # If there were callback requests for this file but there was a # parse error we still need to progress the state of TIs, @@ -781,7 +789,11 @@ def process_file( # Record import errors into the ORM try: - self.update_import_errors(session, dagbag) + DagFileProcessor.update_import_errors( + file_last_changed=dagbag.file_last_changed, + import_errors=dagbag.import_errors, + session=session, + ) except Exception: self.log.exception("Error logging import errors!") diff --git a/airflow/www/app.py b/airflow/www/app.py index 6a64401e1cecb..19d2831dfdbcb 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -48,6 +48,7 @@ from airflow.www.extensions.init_views import ( init_api_connexion, init_api_experimental, + init_api_internal, init_appbuilder_views, init_connection_form, init_error_handlers, @@ -149,6 +150,8 @@ def create_app(config=None, testing=False): init_connection_form() init_error_handlers(flask_app) init_api_connexion(flask_app) + if conf.getboolean("webserver", "run_internal_api", fallback=False): + init_api_internal(flask_app) init_api_experimental(flask_app) sync_appbuilder_roles(flask_app) diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py index 25d5d5898c642..86f94d2f229ce 100644 --- a/airflow/www/extensions/init_views.py +++ b/airflow/www/extensions/init_views.py @@ -220,6 +220,27 @@ def _handle_method_not_allowed(ex): app.extensions["csrf"].exempt(api_bp) +def init_api_internal(app: Flask) -> None: + """Initialize Internal API""" + if not conf.getboolean("webserver", "run_internal_api", fallback=False): + return + base_path = "/internal_api/v1" + + spec_dir = path.join(ROOT_APP_DIR, "api_internal", "openapi") + internal_app = App(__name__, specification_dir=spec_dir, skip_error_handlers=True) + internal_app.app = app + api_bp = internal_app.add_api( + specification="internal_api_v1.yaml", + base_path=base_path, + validate_responses=True, + strict_validation=True, + ).blueprint + # Like "api_bp.after_request", but the BP is already registered, so we have + # to register it in the app directly. + app.after_request_funcs.setdefault(api_bp.name, []).append(set_cors_headers_on_response) + app.extensions["csrf"].exempt(api_bp) + + def init_api_experimental(app): """Initialize Experimental API""" if not conf.getboolean("api", "enable_experimental_api", fallback=False): diff --git a/tests/api_internal/__init__.py b/tests/api_internal/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/api_internal/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/api_internal/endpoints/__init__.py b/tests/api_internal/endpoints/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/api_internal/endpoints/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py new file mode 100644 index 0000000000000..68f22fe6cc1a6 --- /dev/null +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from unittest import mock + +import pytest +from flask import Flask + +from airflow.api_internal.endpoints import rpc_api_endpoint +from airflow.serialization.serialized_objects import BaseSerialization +from airflow.www import app +from tests.test_utils.config import conf_vars +from tests.test_utils.decorators import dont_initialize_flask_app_submodules + +TEST_METHOD_NAME = "test_method" + +mock_test_method = mock.MagicMock() + + +@pytest.fixture(scope="session") +def minimal_app_for_internal_api() -> Flask: + @dont_initialize_flask_app_submodules( + skip_all_except=[ + "init_appbuilder", + "init_api_internal", + ] + ) + def factory() -> Flask: + with conf_vars({("webserver", "run_internal_api"): "true"}): + return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + + return factory() + + +class TestRpcApiEndpoint: + @pytest.fixture(autouse=True) + def setup_attrs(self, minimal_app_for_internal_api: Flask) -> None: + rpc_api_endpoint.METHODS_MAP[TEST_METHOD_NAME] = mock_test_method + self.app = minimal_app_for_internal_api + self.client = self.app.test_client() # type:ignore + mock_test_method.reset_mock() + mock_test_method.side_effect = None + + @pytest.mark.parametrize( + "input_data, method_result, method_params, expected_code", + [ + ({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}, "test_me", None, 200), + ({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}, None, None, 200), + ( + { + "jsonrpc": "2.0", + "method": TEST_METHOD_NAME, + "params": json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})), + }, + ("dag_id_15", "fake-task", 1), + {"dag_id": 15, "task_id": "fake-task"}, + 200, + ), + ], + ) + def test_method(self, input_data, method_result, method_params, expected_code): + if method_result: + mock_test_method.return_value = method_result + + response = self.client.post( + "/internal_api/v1/rpcapi", + headers={"Content-Type": "application/json"}, + data=json.dumps(input_data), + ) + assert response.status_code == expected_code + if method_result: + response_data = BaseSerialization.deserialize(json.loads(response.data)) + assert response_data == method_result + if method_params: + mock_test_method.assert_called_once_with(**method_params) + else: + mock_test_method.assert_called_once() + + def test_method_with_exception(self): + mock_test_method.side_effect = ValueError("Error!!!") + data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} + + response = self.client.post( + "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + assert response.status_code == 500 + assert response.data, b"Error executing method: test_method." + mock_test_method.assert_called_once() + + def test_unknown_method(self): + data = {"jsonrpc": "2.0", "method": "i-bet-it-does-not-exist", "params": ""} + + response = self.client.post( + "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + assert response.status_code == 400 + assert response.data == b"Unrecognized method: i-bet-it-does-not-exist." + mock_test_method.assert_not_called() + + def test_invalid_jsonrpc(self): + data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": ""} + + response = self.client.post( + "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + assert response.status_code == 400 + assert response.data == b"Expected jsonrpc 2.0 request." + mock_test_method.assert_not_called() diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py new file mode 100644 index 0000000000000..579a7720cc8b0 --- /dev/null +++ b/tests/api_internal/test_internal_api_call.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from __future__ import annotations + +import json +import unittest +from unittest import mock + +import requests + +from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.serialization.serialized_objects import BaseSerialization +from tests.test_utils.config import conf_vars + + +class TestInternalApiConfig(unittest.TestCase): + def setUp(self): + InternalApiConfig._initialized = False + + @conf_vars( + { + ("core", "database_access_isolation"): "false", + ("core", "internal_api_url"): "http://localhost:8888", + } + ) + def test_get_use_internal_api_disabled(self): + self.assertFalse(InternalApiConfig.get_use_internal_api()) + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "internal_api_url"): "http://localhost:8888", + } + ) + def test_get_use_internal_api_enabled(self): + self.assertTrue(InternalApiConfig.get_use_internal_api()) + self.assertEqual( + InternalApiConfig.get_internal_api_endpoint(), + "http://localhost:8888/internal_api/v1/rpcapi", + ) + + +@internal_api_call +def fake_method() -> str: + return "local-call" + + +@internal_api_call +def fake_method_with_params(dag_id: str, task_id: int) -> str: + return f"local-call-with-params-{dag_id}-{task_id}" + + +class TestInternalApiCall(unittest.TestCase): + def setUp(self): + InternalApiConfig._initialized = False + + @conf_vars( + { + ("core", "database_access_isolation"): "false", + ("core", "internal_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_local_call(self, mock_requests): + result = fake_method() + + self.assertEqual(result, "local-call") + mock_requests.post.assert_not_called() + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "internal_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_remote_call(self, mock_requests): + response = requests.Response() + response.status_code = 200 + + response._content = json.dumps(BaseSerialization.serialize("remote-call")) + + mock_requests.post.return_value = response + + result = fake_method() + self.assertEqual(result, "remote-call") + expected_data = json.dumps( + { + "jsonrpc": "2.0", + "method": "tests.api_internal.test_internal_api_call.fake_method", + "params": json.dumps(BaseSerialization.serialize({})), + } + ) + mock_requests.post.assert_called_once_with( + url="http://localhost:8888/internal_api/v1/rpcapi", + data=expected_data, + headers={"Content-Type": "application/json"}, + ) + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "internal_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_remote_call_with_params(self, mock_requests): + response = requests.Response() + response.status_code = 200 + + response._content = json.dumps(BaseSerialization.serialize("remote-call")) + + mock_requests.post.return_value = response + + result = fake_method_with_params("fake-dag", task_id=123) + self.assertEqual(result, "remote-call") + expected_data = json.dumps( + { + "jsonrpc": "2.0", + "method": "tests.api_internal.test_internal_api_call.fake_method_with_params", + "params": json.dumps( + BaseSerialization.serialize( + { + "dag_id": "fake-dag", + "task_id": 123, + } + ) + ), + } + ) + mock_requests.post.assert_called_once_with( + url="http://localhost:8888/internal_api/v1/rpcapi", + data=expected_data, + headers={"Content-Type": "application/json"}, + ) diff --git a/tests/test_utils/decorators.py b/tests/test_utils/decorators.py index d0b71b502c3d0..7d809834da4d5 100644 --- a/tests/test_utils/decorators.py +++ b/tests/test_utils/decorators.py @@ -39,6 +39,7 @@ def no_op(*args, **kwargs): "init_connection_form", "init_error_handlers", "init_api_connexion", + "init_api_internal", "init_api_experimental", "sync_appbuilder_roles", "init_jinja_globals",