Skip to content

Commit

Permalink
use whole-process/cross-module global variable for token (#13)
Browse files Browse the repository at this point in the history
* use whole-process/cross-module global variable for token
  • Loading branch information
sigmarkarl authored Dec 21, 2023
1 parent 2eb3811 commit 05b122d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
39 changes: 33 additions & 6 deletions jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations

import contextvars
import functools
import inspect
import ipaddress
import json
import mimetypes
import os
import re
import sys
import types
import typing as ty
import logging
import warnings
from http.client import responses
from logging import Logger
Expand All @@ -23,7 +25,6 @@
from jupyter_core.paths import is_hidden
from jupyter_events import EventLogger
from tornado import web
from tornado.httputil import HTTPServerRequest
from tornado.log import app_log
from traitlets.config import Application

Expand Down Expand Up @@ -56,12 +57,12 @@
from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
from jupyter_server.services.sessions.sessionmanager import SessionManager

_current_request_var: contextvars.ContextVar[HTTPServerRequest] = contextvars.ContextVar("current_request")
# -----------------------------------------------------------------------------
# Top-level handlers
# -----------------------------------------------------------------------------

_sys_info_cache = None
_my_globals = "myglobals"


def json_sys_info():
Expand All @@ -80,11 +81,35 @@ def log() -> Logger:
return app_log


def get_token_value(request: ty.Any, prev: str) -> str:
header = "Authorization"
if header not in request.headers:
logging.error(f'Header "{header}" is missing')
return prev
logging.debug(f'Getting value from header "{header}"')
auth_header_value: str = request.headers[header]
if len(auth_header_value) == 0:
logging.error(f'Header "{header}" is empty')
return prev

try:
logging.info(f"Auth header value: {auth_header_value}")
# We expect the header value to be of the form "Bearer: XXX"
return auth_header_value.split(" ", maxsplit=1)[1]
except Exception as e:
logging.error(f"Could not read token from auth header: {str(e)}")

return prev


class AuthenticatedHandler(web.RequestHandler):
"""A RequestHandler with an authenticated user."""

def prepare(self):
_current_request_var.set(self.request)
if _my_globals not in sys.modules:
sys.modules[_my_globals] = types.ModuleType(_my_globals)
prevtoken = sys.modules[_my_globals].token if hasattr(sys.modules[_my_globals], "token") else ""
sys.modules[_my_globals].token = get_token_value(self.request, prevtoken)

@property
def base_url(self) -> str:
Expand Down Expand Up @@ -1141,11 +1166,13 @@ def get(self) -> None:
self.set_header("Content-Type", prometheus_client.CONTENT_TYPE_LATEST)
self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY))

def get_current_request():
def get_current_token():
"""
Get :class:`tornado.httputil.HTTPServerRequest` that is currently being processed.
"""
return _current_request_var.get(None)
if _my_globals in sys.modules and hasattr(sys.modules[_my_globals], "token"):
return sys.modules[_my_globals].token
return ""

# -----------------------------------------------------------------------------
# URL pattern fragments for reuse
Expand Down
27 changes: 4 additions & 23 deletions jupyter_server/gateway/spottokenrenewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,6 @@
import jupyter_server.serverapp


def get_header_value(request: ty.Any, header: str) -> str:
if header not in request.headers:
logging.error(f'Header "{header}" is missing')
return ""
logging.debug(f'Getting value from header "{header}"')
value: str = request.headers[header]
if len(value) == 0:
logging.error(f'Header "{header}" is empty')
return ""
return value


class SpotTokenRenewer(GatewayTokenRenewerBase): # type:ignore[misc]

def get_token(
Expand All @@ -27,17 +15,10 @@ def get_token(
auth_token: str,
**kwargs: ty.Any,
) -> str:
request = jupyter_server.base.handlers.get_current_request()
if request is None:
token = jupyter_server.base.handlers.get_current_token()
if token is "":
logging.error("Could not get current request")
return auth_token

auth_header_value = get_header_value(request, auth_header_key)
if auth_header_value:
try:
# We expect the header value to be of the form "Bearer: XXX"
auth_token = auth_header_value.split(" ", maxsplit=1)[1]
except Exception as e:
logging.error(f"Could not read token from auth header: {str(e)}")

return auth_token
logging.info("Auth token refreshed")
return token

0 comments on commit 05b122d

Please sign in to comment.