Skip to content

Commit

Permalink
Feature/add token authentication to internal api (#40899)
Browse files Browse the repository at this point in the history
* Add Authentication via Token to Internal API

* Add Authentication via Token to Internal API

* Review feedback, use other signing method

* Fix pytest after change of token auth

* Review feedback, implement token with option b / own config value

* Add configuration pytest for additional secret

* Review Feedback, direct commit

Co-authored-by: Vincent <[email protected]>

---------

Co-authored-by: Vincent <[email protected]>
  • Loading branch information
jscheffl and vincbeck authored Jul 22, 2024
1 parent 193605a commit 5b28933
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 42 deletions.
47 changes: 45 additions & 2 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,24 @@
from typing import TYPE_CHECKING, Any, Callable
from uuid import uuid4

from flask import Response

from flask import Response, request
from itsdangerous import BadSignature
from jwt import (
ExpiredSignatureError,
ImmatureSignatureError,
InvalidAudienceError,
InvalidIssuedAtError,
InvalidSignatureError,
)

from airflow.api_connexion.exceptions import PermissionDenied
from airflow.configuration import conf
from airflow.jobs.job import Job, most_recent_job
from airflow.models.taskinstance import _record_task_map_for_downstreams
from airflow.models.xcom_arg import _get_task_map_length
from airflow.sensors.base import _orig_start_date
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.jwt_signer import JWTSigner
from airflow.utils.session import create_session

if TYPE_CHECKING:
Expand Down Expand Up @@ -142,6 +153,38 @@ def log_and_build_error_response(message, status):

def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
"""Handle Internal API /internal_api/v1/rpcapi endpoint."""
auth = request.headers.get("Authorization", "")
signer = JWTSigner(
secret_key=conf.get("core", "internal_api_secret_key"),
expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30),
audience="api",
)
try:
payload = signer.verify_token(auth)
signed_method = payload.get("method")
if not signed_method or signed_method != body.get("method"):
raise BadSignature("Invalid method in token authorization.")
except BadSignature:
raise PermissionDenied("Bad Signature. Please use only the tokens provided by the API.")
except InvalidAudienceError:
raise PermissionDenied("Invalid audience for the request", exc_info=True)
except InvalidSignatureError:
raise PermissionDenied("The signature of the request was wrong", exc_info=True)
except ImmatureSignatureError:
raise PermissionDenied("The signature of the request was sent from the future", exc_info=True)
except ExpiredSignatureError:
raise PermissionDenied(
"The signature of the request has expired. Make sure that all components "
"in your system have synchronized clocks.",
)
except InvalidIssuedAtError:
raise PermissionDenied(
"The request was issues in the future. Make sure that all components "
"in your system have synchronized clocks.",
)
except Exception:
raise PermissionDenied("Unable to authenticate API via token.")

log.debug("Got request")
json_rpc = body.get("jsonrpc")
if json_rpc != "2.0":
Expand Down
13 changes: 10 additions & 3 deletions airflow/api_internal/internal_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.settings import _ENABLE_AIP_44
from airflow.typing_compat import ParamSpec
from airflow.utils.jwt_signer import JWTSigner

PS = ParamSpec("PS")
RT = TypeVar("RT")
Expand Down Expand Up @@ -117,9 +118,6 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
See [AIP-44](https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API)
for more information .
"""
headers = {
"Content-Type": "application/json",
}
from requests.exceptions import ConnectionError

@tenacity.retry(
Expand All @@ -129,6 +127,15 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
before_sleep=tenacity.before_log(logger, logging.WARNING),
)
def make_jsonrpc_request(method_name: str, params_json: str) -> bytes:
signer = JWTSigner(
secret_key=conf.get("core", "internal_api_secret_key"),
expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30),
audience="api",
)
headers = {
"Content-Type": "application/json",
"Authorization": signer.generate_signed_token({"method": method_name}),
}
data = {"jsonrpc": "2.0", "method": method_name, "params": params_json}
internal_api_endpoint = InternalApiConfig.get_internal_api_endpoint()
response = requests.post(url=internal_api_endpoint, data=json.dumps(data), headers=headers)
Expand Down
1 change: 0 additions & 1 deletion airflow/api_internal/openapi/internal_api_v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ servers:
paths:
"/rpcapi":
post:
operationId: rpcapi
deprecated: false
x-openapi-router-controller: airflow.api_internal.endpoints.rpc_api_endpoint
operationId: internal_airflow_api
Expand Down
13 changes: 13 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,19 @@ core:
type: string
default: ~
example: 'http://localhost:8080'
internal_api_secret_key:
description: |
Secret key used to authenticate internal API clients to core. It should be as random as possible.
However, when running more than one instance of the webserver / internal API services, make sure
all of them use the same ``internal_api_secret_key``, otherwise calls will fail on authentication.
The authentication token generated using the secret key has a short expiry time though - make
sure that time on ALL the machines that you run airflow components on is synchronized
(for example using ntpd) otherwise you might get "forbidden" errors when internal API calls are made.
version_added: 2.10.0
type: string
sensitive: true
example: ~
default: "{SECRET_KEY}"
test_connection:
description: |
The ability to allow testing connections across Airflow UI, API and CLI.
Expand Down
80 changes: 64 additions & 16 deletions tests/api_internal/endpoints/test_rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@

import pytest

from airflow.api_connexion.exceptions import PermissionDenied
from airflow.configuration import conf
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.settings import _ENABLE_AIP_44
from airflow.utils.jwt_signer import JWTSigner
from airflow.utils.state import State
from airflow.www import app
from tests.test_utils.config import conf_vars
Expand Down Expand Up @@ -82,6 +85,14 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator:
}
yield mock_initialize_method_map

@pytest.fixture
def signer(self) -> JWTSigner:
    """Return a ``JWTSigner`` built from the ``core`` internal API settings.

    Tests use it to produce the ``Authorization`` token for requests sent to
    the RPC endpoint.
    """
    shared_secret = conf.get("core", "internal_api_secret_key")
    clock_grace_seconds = conf.getint("core", "internal_api_clock_grace", fallback=30)
    return JWTSigner(
        secret_key=shared_secret,
        expiration_time_in_seconds=clock_grace_seconds,
        audience="api",
    )

@pytest.mark.parametrize(
"input_params, method_result, result_cmp_func, method_params",
[
Expand All @@ -108,17 +119,20 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator:
),
],
)
def test_method(self, input_params, method_result, result_cmp_func, method_params):
def test_method(self, input_params, method_result, result_cmp_func, method_params, signer: JWTSigner):
mock_test_method.return_value = method_result

headers = {
"Content-Type": "application/json",
"Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}),
}
input_data = {
"jsonrpc": "2.0",
"method": TEST_METHOD_NAME,
"params": input_params,
}
response = self.client.post(
"/internal_api/v1/rpcapi",
headers={"Content-Type": "application/json"},
headers=headers,
data=json.dumps(input_data),
)
assert response.status_code == 200
Expand All @@ -131,33 +145,67 @@ def test_method(self, input_params, method_result, result_cmp_func, method_param

mock_test_method.assert_called_once_with(**method_params, session=mock.ANY)

def test_method_with_exception(self, signer: JWTSigner):
    """An authenticated call whose target method raises must yield HTTP 500.

    The request carries a token signed for ``TEST_METHOD_NAME`` so that it
    passes authentication and reaches the (mocked) method, which raises.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}),
    }
    mock_test_method.side_effect = ValueError("Error!!!")
    data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": {}}

    response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data))
    assert response.status_code == 500
    # Compare the body explicitly: ``assert response.data, b"..."`` treated the
    # bytes literal as the assertion *message* and only checked non-emptiness.
    # NOTE(review): ``startswith`` mirrors the other error-path tests in this
    # class; confirm the endpoint only appends text after this prefix.
    assert response.data.startswith(b"Error executing method: test_method.")
    mock_test_method.assert_called_once()

def test_unknown_method(self, signer: JWTSigner):
    """An authenticated call to a method that is not registered must yield HTTP 400.

    The token is signed for the same unknown method name that the body carries,
    so the request passes the token/method check and fails only on lookup.
    """
    unknown_method = "i-bet-it-does-not-exist"
    headers = {
        "Content-Type": "application/json",
        "Authorization": signer.generate_signed_token({"method": unknown_method}),
    }
    data = {"jsonrpc": "2.0", "method": unknown_method, "params": {}}

    response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data))
    assert response.status_code == 400
    assert response.data.startswith(b"Unrecognized method: i-bet-it-does-not-exist.")
    mock_test_method.assert_not_called()

def test_invalid_jsonrpc(self, signer: JWTSigner):
    """An authenticated request with a non-2.0 ``jsonrpc`` version must yield HTTP 400.

    The token is valid for ``TEST_METHOD_NAME``; only the protocol version in
    the body is wrong, and the target method must never be invoked.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}),
    }
    data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}}

    response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data))
    assert response.status_code == 400
    assert response.data.startswith(b"Expected jsonrpc 2.0 request.")
    mock_test_method.assert_not_called()

def test_missing_token(self):
    """A request sent without an ``Authorization`` header raises ``PermissionDenied``."""
    mock_test_method.return_value = None

    # Well-formed JSON-RPC body; the only thing missing is the token header.
    request_body = json.dumps({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": {}})
    unauthenticated_headers = {"Content-Type": "application/json"}

    with pytest.raises(PermissionDenied, match="Unable to authenticate API via token."):
        self.client.post(
            "/internal_api/v1/rpcapi",
            headers=unauthenticated_headers,
            data=request_body,
        )

def test_invalid_token(self, signer: JWTSigner):
    """A token signed for a different method than the one requested raises ``PermissionDenied``."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}),
    }
    # Keep the payload otherwise valid (jsonrpc 2.0) so that the mismatch between
    # the signed method and the requested method is the only reason for rejection.
    data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": {}}

    with pytest.raises(
        PermissionDenied, match="Bad Signature. Please use only the tokens provided by the API."
    ):
        self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data))
44 changes: 24 additions & 20 deletions tests/api_internal/test_internal_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,12 @@ def test_remote_call(self, mock_requests):
"params": BaseSerialization.serialize({}),
}
)
mock_requests.post.assert_called_once_with(
url="http://localhost:8888/internal_api/v1/rpcapi",
data=expected_data,
headers={"Content-Type": "application/json"},
)
mock_requests.post.assert_called_once()
call_kwargs: dict = mock_requests.post.call_args.kwargs
assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi"
assert call_kwargs["data"] == expected_data
assert call_kwargs["headers"]["Content-Type"] == "application/json"
assert "Authorization" in call_kwargs["headers"]

@conf_vars(
{
Expand Down Expand Up @@ -192,11 +193,12 @@ def test_remote_call_with_params(self, mock_requests):
),
}
)
mock_requests.post.assert_called_once_with(
url="http://localhost:8888/internal_api/v1/rpcapi",
data=expected_data,
headers={"Content-Type": "application/json"},
)
mock_requests.post.assert_called_once()
call_kwargs: dict = mock_requests.post.call_args.kwargs
assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi"
assert call_kwargs["data"] == expected_data
assert call_kwargs["headers"]["Content-Type"] == "application/json"
assert "Authorization" in call_kwargs["headers"]

@conf_vars(
{
Expand Down Expand Up @@ -228,11 +230,12 @@ def test_remote_classmethod_call_with_params(self, mock_requests):
),
}
)
mock_requests.post.assert_called_once_with(
url="http://localhost:8888/internal_api/v1/rpcapi",
data=expected_data,
headers={"Content-Type": "application/json"},
)
mock_requests.post.assert_called_once()
call_kwargs: dict = mock_requests.post.call_args.kwargs
assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi"
assert call_kwargs["data"] == expected_data
assert call_kwargs["headers"]["Content-Type"] == "application/json"
assert "Authorization" in call_kwargs["headers"]

@conf_vars(
{
Expand Down Expand Up @@ -261,8 +264,9 @@ def test_remote_call_with_serialized_model(self, mock_requests):
"params": BaseSerialization.serialize({"ti": ti}, use_pydantic_models=True),
}
)
mock_requests.post.assert_called_once_with(
url="http://localhost:8888/internal_api/v1/rpcapi",
data=expected_data,
headers={"Content-Type": "application/json"},
)
mock_requests.post.assert_called_once()
call_kwargs: dict = mock_requests.post.call_args.kwargs
assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi"
assert call_kwargs["data"] == expected_data
assert call_kwargs["headers"]["Content-Type"] == "application/json"
assert "Authorization" in call_kwargs["headers"]
1 change: 1 addition & 0 deletions tests/core/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,7 @@ def test_sensitive_values():
sensitive_values = {
("database", "sql_alchemy_conn"),
("core", "fernet_key"),
("core", "internal_api_secret_key"),
("smtp", "smtp_password"),
("webserver", "secret_key"),
("secrets", "backend_kwargs"),
Expand Down

0 comments on commit 5b28933

Please sign in to comment.