Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more typings #1356

Merged
merged 11 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jupyter_server/_tz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ZERO = timedelta(0)


class tzUTC(tzinfo): # noqa
class tzUTC(tzinfo): # noqa: N801
"""tzinfo object for UTC (zero offset)"""

def utcoffset(self, d: datetime | None) -> timedelta:
Expand All @@ -30,7 +30,7 @@ def utcnow() -> datetime:
return datetime.now(timezone.utc)


def utcfromtimestamp(timestamp):
def utcfromtimestamp(timestamp: float) -> datetime:
return datetime.fromtimestamp(timestamp, timezone.utc)


Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/auth/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,6 @@ def inner(self, *args, **kwargs):
method = action
action = None
# no-arguments `@authorized` decorator called
return wrapper(method)
return cast(FuncT, wrapper(method))

return cast(FuncT, wrapper)
105 changes: 52 additions & 53 deletions jupyter_server/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
import os
import re
import sys
import typing as t
import uuid
from dataclasses import asdict, dataclass
from http.cookies import Morsel
from typing import TYPE_CHECKING, Any, Awaitable

from tornado import escape, httputil, web
from traitlets import Bool, Dict, Type, Unicode, default
Expand All @@ -27,11 +27,6 @@
from .security import passwd_check, set_password
from .utils import get_anonymous_username

# circular imports for type checking
if TYPE_CHECKING:
from jupyter_server.base.handlers import AuthenticatedHandler, JupyterHandler
from jupyter_server.serverapp import ServerApp

_non_alphanum = re.compile(r"[^A-Za-z0-9]")


Expand Down Expand Up @@ -82,7 +77,7 @@ def fill_defaults(self):
self.display_name = self.name


def _backward_compat_user(got_user: Any) -> User:
def _backward_compat_user(got_user: t.Any) -> User:
"""Backward-compatibility for LoginHandler.get_user

Prior to 2.0, LoginHandler.get_user could return anything truthy.
Expand Down Expand Up @@ -128,7 +123,7 @@ class IdentityProvider(LoggingConfigurable):
.. versionadded:: 2.0
"""

cookie_name: str | Unicode = Unicode(
cookie_name: str | Unicode[str, str | bytes] = Unicode(
"",
config=True,
help=_i18n("Name of the cookie to set for persisting login. Default: username-${Host}."),
Expand All @@ -142,7 +137,7 @@ class IdentityProvider(LoggingConfigurable):
),
)

secure_cookie: bool | Bool = Bool(
secure_cookie: bool | Bool[bool | None, bool | int | None] = Bool(
None,
allow_none=True,
config=True,
Expand All @@ -160,7 +155,7 @@ class IdentityProvider(LoggingConfigurable):
),
)

token: str | Unicode = Unicode(
token: str | Unicode[str, str | bytes] = Unicode(
"<generated>",
help=_i18n(
"""Token used for authenticating first-time connections to the server.
Expand Down Expand Up @@ -211,9 +206,9 @@ def _token_default(self):
self.token_generated = True
return binascii.hexlify(os.urandom(24)).decode("ascii")

need_token: bool | Bool = Bool(True)
need_token: bool | Bool[bool, t.Union[bool, int]] = Bool(True)

def get_user(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]:
def get_user(self, handler: web.RequestHandler) -> User | None | t.Awaitable[User | None]:
"""Get the authenticated user for a request

Must return a :class:`jupyter_server.auth.User`,
Expand All @@ -228,17 +223,17 @@ def get_user(self, handler: JupyterHandler) -> User | None | Awaitable[User | No
# not sure how to have optional-async type signature
# on base class with `async def` without splitting it into two methods

async def _get_user(self, handler: JupyterHandler) -> User | None:
async def _get_user(self, handler: web.RequestHandler) -> User | None:
"""Get the user."""
if getattr(handler, "_jupyter_current_user", None):
# already authenticated
return handler._jupyter_current_user
_token_user: User | None | Awaitable[User | None] = self.get_user_token(handler)
if isinstance(_token_user, Awaitable):
return t.cast(User, handler._jupyter_current_user) # type:ignore[attr-defined]
_token_user: User | None | t.Awaitable[User | None] = self.get_user_token(handler)
if isinstance(_token_user, t.Awaitable):
_token_user = await _token_user
token_user: User | None = _token_user # need second variable name to collapse type
_cookie_user = self.get_user_cookie(handler)
if isinstance(_cookie_user, Awaitable):
if isinstance(_cookie_user, t.Awaitable):
_cookie_user = await _cookie_user
cookie_user: User | None = _cookie_user
# prefer token to cookie if both given,
Expand Down Expand Up @@ -273,12 +268,12 @@ async def _get_user(self, handler: JupyterHandler) -> User | None:

return user

def identity_model(self, user: User) -> dict:
def identity_model(self, user: User) -> dict[str, t.Any]:
"""Return a User as an Identity model"""
# TODO: validate?
return asdict(user)

def get_handlers(self) -> list:
def get_handlers(self) -> list[tuple[str, object]]:
"""Return list of additional handlers for this identity provider

For example, an OAuth callback handler.
Expand Down Expand Up @@ -321,7 +316,7 @@ def user_from_cookie(self, cookie_value: str) -> User | None:
user["color"],
)

def get_cookie_name(self, handler: AuthenticatedHandler) -> str:
def get_cookie_name(self, handler: web.RequestHandler) -> str:
"""Return the login cookie name

Uses IdentityProvider.cookie_name, if defined.
Expand All @@ -333,7 +328,7 @@ def get_cookie_name(self, handler: AuthenticatedHandler) -> str:
else:
return _non_alphanum.sub("-", f"username-{handler.request.host}")

def set_login_cookie(self, handler: AuthenticatedHandler, user: User) -> None:
def set_login_cookie(self, handler: web.RequestHandler, user: User) -> None:
"""Call this on handlers to set the login cookie for success"""
cookie_options = {}
cookie_options.update(self.cookie_options)
Expand All @@ -345,12 +340,12 @@ def set_login_cookie(self, handler: AuthenticatedHandler, user: User) -> None:
secure_cookie = handler.request.protocol == "https"
if secure_cookie:
cookie_options.setdefault("secure", True)
cookie_options.setdefault("path", handler.base_url)
cookie_options.setdefault("path", handler.base_url) # type:ignore[attr-defined]
cookie_name = self.get_cookie_name(handler)
handler.set_secure_cookie(cookie_name, self.user_to_cookie(user), **cookie_options)

def _force_clear_cookie(
self, handler: AuthenticatedHandler, name: str, path: str = "/", domain: str | None = None
self, handler: web.RequestHandler, name: str, path: str = "/", domain: str | None = None
) -> None:
"""Deletes the cookie with the given name.

Expand All @@ -368,19 +363,19 @@ def _force_clear_cookie(
name = escape.native_str(name)
expires = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=365)

morsel: Morsel = Morsel()
morsel: Morsel[t.Any] = Morsel()
morsel.set(name, "", '""')
morsel["expires"] = httputil.format_timestamp(expires)
morsel["path"] = path
if domain:
morsel["domain"] = domain
handler.add_header("Set-Cookie", morsel.OutputString())

def clear_login_cookie(self, handler: AuthenticatedHandler) -> None:
def clear_login_cookie(self, handler: web.RequestHandler) -> None:
"""Clear the login cookie, effectively logging out the session."""
cookie_options = {}
cookie_options.update(self.cookie_options)
path = cookie_options.setdefault("path", handler.base_url)
path = cookie_options.setdefault("path", handler.base_url) # type:ignore[attr-defined]
cookie_name = self.get_cookie_name(handler)
handler.clear_cookie(cookie_name, path=path)
if path and path != "/":
Expand All @@ -390,7 +385,9 @@ def clear_login_cookie(self, handler: AuthenticatedHandler) -> None:
# two cookies with the same name. See the method above.
self._force_clear_cookie(handler, cookie_name)

def get_user_cookie(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]:
def get_user_cookie(
self, handler: web.RequestHandler
) -> User | None | t.Awaitable[User | None]:
"""Get user from a cookie

Calls user_from_cookie to deserialize cookie value
Expand All @@ -413,7 +410,7 @@ def get_user_cookie(self, handler: JupyterHandler) -> User | None | Awaitable[Us

auth_header_pat = re.compile(r"(token|bearer)\s+(.+)", re.IGNORECASE)

def get_token(self, handler: JupyterHandler) -> str | None:
def get_token(self, handler: web.RequestHandler) -> str | None:
"""Get the user token from a request

Default:
Expand All @@ -429,14 +426,14 @@ def get_token(self, handler: JupyterHandler) -> str | None:
user_token = m.group(2)
return user_token

async def get_user_token(self, handler: JupyterHandler) -> User | None:
async def get_user_token(self, handler: web.RequestHandler) -> User | None:
"""Identify the user based on a token in the URL or Authorization header

Returns:
- uuid if authenticated
- None if not
"""
token = handler.token
token = t.cast("str | None", handler.token) # type:ignore[attr-defined]
if not token:
return None
# check login token from URL argument or Authorization header
Expand All @@ -455,7 +452,7 @@ async def get_user_token(self, handler: JupyterHandler) -> User | None:
# which is stored in a cookie.
# still check the cookie for the user id
_user = self.get_user_cookie(handler)
if isinstance(_user, Awaitable):
if isinstance(_user, t.Awaitable):
_user = await _user
user: User | None = _user
if user is None:
Expand All @@ -464,7 +461,7 @@ async def get_user_token(self, handler: JupyterHandler) -> User | None:
else:
return None

def generate_anonymous_user(self, handler: JupyterHandler) -> User:
def generate_anonymous_user(self, handler: web.RequestHandler) -> User:
"""Generate a random anonymous user.

For use when a single shared token is used,
Expand All @@ -475,10 +472,10 @@ def generate_anonymous_user(self, handler: JupyterHandler) -> User:
name = display_name = f"Anonymous {moon}"
initials = f"A{moon[0]}"
color = None
handler.log.debug(f"Generating new user for token-authenticated request: {user_id}")
handler.log.debug(f"Generating new user for token-authenticated request: {user_id}") # type:ignore[attr-defined]
return User(user_id, name, display_name, initials, None, color)

def should_check_origin(self, handler: AuthenticatedHandler) -> bool:
def should_check_origin(self, handler: web.RequestHandler) -> bool:
"""Should the Handler check for CORS origin validation?

Origin check should be skipped for token-authenticated requests.
Expand All @@ -489,7 +486,7 @@ def should_check_origin(self, handler: AuthenticatedHandler) -> bool:
"""
return not self.is_token_authenticated(handler)

def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool:
def is_token_authenticated(self, handler: web.RequestHandler) -> bool:
"""Returns True if handler has been token authenticated. Otherwise, False.

Login with a token is used to signal certain things, such as:
Expand All @@ -504,8 +501,8 @@ def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool:

def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
app: t.Any,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Check the application's security.

Expand All @@ -526,7 +523,7 @@ def validate_security(
" Anyone who can connect to this server will be able to run code."
)

def process_login_form(self, handler: JupyterHandler) -> User | None:
def process_login_form(self, handler: web.RequestHandler) -> User | None:
"""Process login form data

Return authenticated User if successful, None if not.
Expand All @@ -538,7 +535,7 @@ def process_login_form(self, handler: JupyterHandler) -> User | None:
return self.generate_anonymous_user(handler)

if self.token and self.token == typed_password:
return self.user_for_token(typed_password) # type:ignore[attr-defined]
return t.cast(User, self.user_for_token(typed_password)) # type:ignore[attr-defined]

return user

Expand Down Expand Up @@ -633,7 +630,7 @@ def passwd_check(self, password):
"""Check password against our stored hashed password"""
return passwd_check(self.hashed_password, password)

def process_login_form(self, handler: JupyterHandler) -> User | None:
def process_login_form(self, handler: web.RequestHandler) -> User | None:
"""Process login form data

Return authenticated User if successful, None if not.
Expand All @@ -659,8 +656,8 @@ def process_login_form(self, handler: JupyterHandler) -> User | None:

def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
app: t.Any,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Handle security validation."""
super().validate_security(app, ssl_options)
Expand Down Expand Up @@ -700,31 +697,33 @@ def _default_login_handler_class(self):
def auth_enabled(self):
return self.login_available

def get_user(self, handler: JupyterHandler) -> User | None:
def get_user(self, handler: web.RequestHandler) -> User | None:
"""Get the user."""
user = self.login_handler_class.get_user(handler) # type:ignore[attr-defined]
if user is None:
return None
return _backward_compat_user(user)

@property
def login_available(self):
return self.login_handler_class.get_login_available( # type:ignore[attr-defined]
self.settings
def login_available(self) -> bool:
return bool(
self.login_handler_class.get_login_available( # type:ignore[attr-defined]
self.settings
)
)

def should_check_origin(self, handler: AuthenticatedHandler) -> bool:
def should_check_origin(self, handler: web.RequestHandler) -> bool:
"""Whether we should check origin."""
return self.login_handler_class.should_check_origin(handler) # type:ignore[attr-defined]
return bool(self.login_handler_class.should_check_origin(handler)) # type:ignore[attr-defined]

def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool:
def is_token_authenticated(self, handler: web.RequestHandler) -> bool:
"""Whether we are token authenticated."""
return self.login_handler_class.is_token_authenticated(handler) # type:ignore[attr-defined]
return bool(self.login_handler_class.is_token_authenticated(handler)) # type:ignore[attr-defined]

def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
app: t.Any,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Validate security."""
if self.password_required and (not self.hashed_password):
Expand All @@ -734,6 +733,6 @@ def validate_security(
self.log.critical(_i18n("Hint: run the following command to set a password"))
self.log.critical(_i18n("\t$ python -m jupyter_server.auth password"))
sys.exit(1)
return self.login_handler_class.validate_security( # type:ignore[attr-defined]
self.login_handler_class.validate_security( # type:ignore[attr-defined]
app, ssl_options
)
Loading
Loading