|
18 | 18 | from __future__ import annotations
|
19 | 19 |
|
20 | 20 | import logging
|
21 |
| -from functools import partial |
22 |
| -from typing import Any, cast |
| 21 | +import os |
| 22 | +from functools import wraps |
| 23 | +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar, cast |
23 | 24 |
|
| 25 | +import kerberos |
| 26 | +from flask import Response, g, make_response, request |
24 | 27 | from requests_kerberos import HTTPKerberosAuth
|
25 | 28 |
|
26 |
| -from airflow.api.auth.backend.kerberos_auth import ( |
27 |
| - init_app as base_init_app, |
28 |
| - requires_authentication as base_requires_authentication, |
29 |
| -) |
| 29 | +from airflow.configuration import conf |
30 | 30 | from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride
|
| 31 | +from airflow.utils.net import getfqdn |
31 | 32 | from airflow.www.extensions.init_auth_manager import get_auth_manager
|
32 | 33 |
|
| 34 | +if TYPE_CHECKING: |
| 35 | + from airflow.auth.managers.models.base_user import BaseUser |
| 36 | + |
33 | 37 | log = logging.getLogger(__name__)
|
34 | 38 |
|
35 | 39 | CLIENT_AUTH: tuple[str, str] | Any | None = HTTPKerberosAuth(service="airflow")
|
36 | 40 |
|
37 | 41 |
|
| 42 | +class KerberosService: |
| 43 | + """Class to keep information about the Kerberos Service initialized.""" |
| 44 | + |
| 45 | + def __init__(self): |
| 46 | + self.service_name = None |
| 47 | + |
| 48 | + |
| 49 | +class _KerberosAuth(NamedTuple): |
| 50 | + return_code: int | None |
| 51 | + user: str = "" |
| 52 | + token: str | None = None |
| 53 | + |
| 54 | + |
| 55 | +# Stores currently initialized Kerberos Service |
| 56 | +_KERBEROS_SERVICE = KerberosService() |
| 57 | + |
| 58 | + |
| 59 | +def init_app(app): |
| 60 | + """Initialize application with kerberos.""" |
| 61 | + hostname = app.config.get("SERVER_NAME") |
| 62 | + if not hostname: |
| 63 | + hostname = getfqdn() |
| 64 | + log.info("Kerberos: hostname %s", hostname) |
| 65 | + |
| 66 | + service = "airflow" |
| 67 | + |
| 68 | + _KERBEROS_SERVICE.service_name = f"{service}@{hostname}" |
| 69 | + |
| 70 | + if "KRB5_KTNAME" not in os.environ: |
| 71 | + os.environ["KRB5_KTNAME"] = conf.get("kerberos", "keytab") |
| 72 | + |
| 73 | + try: |
| 74 | + log.info("Kerberos init: %s %s", service, hostname) |
| 75 | + principal = kerberos.getServerPrincipalDetails(service, hostname) |
| 76 | + except kerberos.KrbError as err: |
| 77 | + log.warning("Kerberos: %s", err) |
| 78 | + else: |
| 79 | + log.info("Kerberos API: server is %s", principal) |
| 80 | + |
| 81 | + |
| 82 | +def _unauthorized(): |
| 83 | + """Indicate that authorization is required.""" |
| 84 | + return Response("Unauthorized", 401, {"WWW-Authenticate": "Negotiate"}) |
| 85 | + |
| 86 | + |
| 87 | +def _forbidden(): |
| 88 | + return Response("Forbidden", 403) |
| 89 | + |
| 90 | + |
| 91 | +def _gssapi_authenticate(token) -> _KerberosAuth | None: |
| 92 | + state = None |
| 93 | + try: |
| 94 | + return_code, state = kerberos.authGSSServerInit(_KERBEROS_SERVICE.service_name) |
| 95 | + if return_code != kerberos.AUTH_GSS_COMPLETE: |
| 96 | + return _KerberosAuth(return_code=None) |
| 97 | + |
| 98 | + if (return_code := kerberos.authGSSServerStep(state, token)) == kerberos.AUTH_GSS_COMPLETE: |
| 99 | + return _KerberosAuth( |
| 100 | + return_code=return_code, |
| 101 | + user=kerberos.authGSSServerUserName(state), |
| 102 | + token=kerberos.authGSSServerResponse(state), |
| 103 | + ) |
| 104 | + elif return_code == kerberos.AUTH_GSS_CONTINUE: |
| 105 | + return _KerberosAuth(return_code=return_code) |
| 106 | + return _KerberosAuth(return_code=return_code) |
| 107 | + except kerberos.GSSError: |
| 108 | + return _KerberosAuth(return_code=None) |
| 109 | + finally: |
| 110 | + if state: |
| 111 | + kerberos.authGSSServerClean(state) |
| 112 | + |
| 113 | + |
| 114 | +T = TypeVar("T", bound=Callable) |
| 115 | + |
| 116 | + |
38 | 117 | def find_user(username=None, email=None):
|
39 | 118 | security_manager = cast(FabAirflowSecurityManagerOverride, get_auth_manager().security_manager)
|
40 | 119 | return security_manager.find_user(username=username, email=email)
|
41 | 120 |
|
42 | 121 |
|
43 |
| -init_app = base_init_app |
44 |
| -requires_authentication = partial(base_requires_authentication, find_user=find_user) |
| 122 | +def requires_authentication(function: T, find_user: Callable[[str], BaseUser] | None = find_user): |
| 123 | + """Decorate functions that require authentication with Kerberos.""" |
| 124 | + |
| 125 | + @wraps(function) |
| 126 | + def decorated(*args, **kwargs): |
| 127 | + header = request.headers.get("Authorization") |
| 128 | + if header: |
| 129 | + token = "".join(header.split()[1:]) |
| 130 | + auth = _gssapi_authenticate(token) |
| 131 | + if auth.return_code == kerberos.AUTH_GSS_COMPLETE: |
| 132 | + g.user = find_user(auth.user) |
| 133 | + response = function(*args, **kwargs) |
| 134 | + response = make_response(response) |
| 135 | + if auth.token is not None: |
| 136 | + response.headers["WWW-Authenticate"] = f"negotiate {auth.token}" |
| 137 | + return response |
| 138 | + elif auth.return_code != kerberos.AUTH_GSS_CONTINUE: |
| 139 | + return _forbidden() |
| 140 | + return _unauthorized() |
| 141 | + |
| 142 | + return cast(T, decorated) |
0 commit comments