Skip to content

Commit

Permalink
Add to/from_json methods to WebAuthn dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
dainnilsson committed Mar 5, 2024
1 parent 4c6f7b6 commit e983d4b
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 61 deletions.
18 changes: 10 additions & 8 deletions examples/server/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@
Navigate to https://localhost:5000 in a supported web browser.
"""
from fido2.webauthn import PublicKeyCredentialRpEntity, PublicKeyCredentialUserEntity
from fido2.webauthn import (
PublicKeyCredentialRpEntity,
PublicKeyCredentialUserEntity,
RegistrationResponse,
AuthenticationResponse,
)
from fido2.server import Fido2Server
from flask import Flask, session, request, redirect, abort, jsonify

import os
import fido2.features

fido2.features.webauthn_json_mapping.enabled = True


app = Flask(__name__, static_url_path="")
Expand Down Expand Up @@ -78,12 +80,12 @@ def register_begin():
print(options)
print("\n\n\n\n")

return jsonify(dict(options))
return jsonify(options.to_json())


@app.route("/api/register/complete", methods=["POST"])
def register_complete():
response = request.json
response = RegistrationResponse.from_json(request.json)
print("RegistrationResponse:", response)
auth_data = server.register_complete(session["state"], response)

Expand All @@ -100,15 +102,15 @@ def authenticate_begin():
options, state = server.authenticate_begin(credentials)
session["state"] = state

return jsonify(dict(options))
return jsonify(options.to_json())


@app.route("/api/authenticate/complete", methods=["POST"])
def authenticate_complete():
if not credentials:
abort(404)

response = request.json
response = AuthenticationResponse.from_json(request.json)
print("AuthenticationResponse:", response)
server.authenticate_complete(
session.pop("state"),
Expand Down
1 change: 1 addition & 0 deletions fido2/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def __init__(
verify_attestation: Optional[VerifyAttestation] = None,
):
self.rp = PublicKeyCredentialRpEntity.from_dict(rp)
assert self.rp.id is not None # nosec
self._verify = verify_origin or _verify_origin_for_rp(self.rp.id)
self.timeout = None
self.attestation = AttestationConveyancePreference(attestation)
Expand Down
25 changes: 21 additions & 4 deletions fido2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@
Sequence,
Mapping,
Any,
Type,
TypeVar,
Hashable,
get_type_hints,
overload,
)
import struct
import warnings
Expand Down Expand Up @@ -207,14 +209,19 @@ def _parse_value(t, value):
return t.from_dict(value)

# Convert to enum values, other wrappers
return t(value)
try:
return t(value)
except Exception:
print("EXCEPTION", t, value)
raise


_T = TypeVar("_T", bound=Hashable)
_T2 = TypeVar("_T2", bound="_DataClassMapping")


class _DataClassMapping(Mapping[_T, Any]):
# TODO: This requires Python 3.9, and fixes the tpye errors we now ignore
# TODO: This requires Python 3.9, and fixes the type errors we now ignore
# __dataclass_fields__: ClassVar[Dict[str, Field[Any]]]

def __post_init__(self):
Expand All @@ -233,7 +240,7 @@ def __post_init__(self):

@classmethod
@abstractmethod
def _get_field_key(cls, field: Field) -> _T:
def _get_field_key(cls: Type[_T2], field: Field) -> _T:
raise NotImplementedError()

def __getitem__(self, key):
Expand Down Expand Up @@ -262,8 +269,18 @@ def __iter__(self):
def __len__(self):
return len(list(iter(self)))

@overload
@classmethod
def from_dict(cls: Type[_T2], data: None) -> None:
pass

@overload
@classmethod
def from_dict(cls: Type[_T2], data: Mapping[_T, Any]) -> _T2:
pass

@classmethod
def from_dict(cls, data: Optional[Mapping[_T, Any]]):
def from_dict(cls: Type[_T2], data: Optional[Mapping[_T, Any]]) -> Optional[_T2]:
if data is None:
return None
if isinstance(data, cls):
Expand Down
126 changes: 101 additions & 25 deletions fido2/webauthn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,19 @@
)
from .features import webauthn_json_mapping
from enum import Enum, EnumMeta, unique, IntFlag
from dataclasses import dataclass, field
from typing import Any, Mapping, Optional, Sequence, Tuple, Union, cast
from dataclasses import dataclass, field, fields
from typing import (
Any,
Mapping,
Optional,
Sequence,
Tuple,
Union,
Type,
TypeVar,
cast,
get_type_hints,
)
import struct
import json

Expand Down Expand Up @@ -457,8 +468,81 @@ class PublicKeyCredentialType(_StringEnum):
PUBLIC_KEY = "public-key"


_T2 = TypeVar("_T2", bound="_JsonDataObject")


def _get_basetype(t):
if Optional[t] == t: # Optional, get the type
t = t.__args__[0]

# Handle list of values
if issubclass(getattr(t, "__origin__", object), Sequence):
t = t.__args__[0]

return t


def _dump_json(value):
if isinstance(value, _JsonDataObject):
return value.to_json()
if isinstance(value, bytes):
return websafe_encode(value)
if isinstance(value, list):
return [_dump_json(x) for x in value]
return value


def _load_json(hint, value):
t = _get_basetype(hint)
if isinstance(value, str):
if issubclass(t, bytes):
value = websafe_decode(value)
return t(value)

# Handle lists
if isinstance(value, Sequence):
return [_load_json(t, v) for v in value]

# Check for subclass of _JsonDataObject
try:
is_json = issubclass(t, _JsonDataObject)
except TypeError:
is_json = False

if is_json:
# Recursively call from_json for nested _JsonDataObject
return t.from_json(value)

return value


class _JsonDataObject(_CamelCaseDataObject):
def to_json(self) -> Mapping[str, Any]:
"""Returns a dict of the object which can be serialized to JSON."""
data = {}
for f in fields(self): # type: ignore
key = self._get_field_key(f)
value = getattr(self, f.name)
if value is not None:
data[key] = _dump_json(value)
return data

@classmethod
def from_json(cls: Type[_T2], data: Mapping[str, Any]) -> _T2:
"""Instantiates an object from a JSON-compatible dict representation."""
hints = get_type_hints(cls)
resp = {}
for f in fields(cls): # type: ignore
key = cls._get_field_key(f)
if key in data:
value = data[key]
hint = hints.get(f.name)
resp[key] = _load_json(hint, value)
return cls.from_dict(resp)


@dataclass(eq=False, frozen=True)
class PublicKeyCredentialRpEntity(_CamelCaseDataObject):
class PublicKeyCredentialRpEntity(_JsonDataObject):
name: str
id: Optional[str] = None

Expand All @@ -469,14 +553,14 @@ def id_hash(self) -> Optional[bytes]:


@dataclass(eq=False, frozen=True)
class PublicKeyCredentialUserEntity(_CamelCaseDataObject):
class PublicKeyCredentialUserEntity(_JsonDataObject):
name: str
id: bytes = field(metadata=_b64_metadata)
display_name: Optional[str] = None


@dataclass(eq=False, frozen=True)
class PublicKeyCredentialParameters(_CamelCaseDataObject):
class PublicKeyCredentialParameters(_JsonDataObject):
type: PublicKeyCredentialType
alg: int

Expand All @@ -489,7 +573,7 @@ def _deserialize_list(cls, value):


@dataclass(eq=False, frozen=True)
class PublicKeyCredentialDescriptor(_CamelCaseDataObject):
class PublicKeyCredentialDescriptor(_JsonDataObject):
type: PublicKeyCredentialType
id: bytes = field(metadata=_b64_metadata)
transports: Optional[Sequence[AuthenticatorTransport]] = None
Expand All @@ -503,7 +587,7 @@ def _deserialize_list(cls, value):


@dataclass(eq=False, frozen=True)
class AuthenticatorSelectionCriteria(_CamelCaseDataObject):
class AuthenticatorSelectionCriteria(_JsonDataObject):
authenticator_attachment: Optional[AuthenticatorAttachment] = None
resident_key: Optional[ResidentKeyRequirement] = None
user_verification: Optional[UserVerificationRequirement] = None
Expand All @@ -530,7 +614,7 @@ def __post_init__(self):


@dataclass(eq=False, frozen=True)
class PublicKeyCredentialCreationOptions(_CamelCaseDataObject):
class PublicKeyCredentialCreationOptions(_JsonDataObject):
rp: PublicKeyCredentialRpEntity
user: PublicKeyCredentialUserEntity
challenge: bytes = field(metadata=_b64_metadata)
Expand All @@ -548,7 +632,7 @@ class PublicKeyCredentialCreationOptions(_CamelCaseDataObject):


@dataclass(eq=False, frozen=True)
class PublicKeyCredentialRequestOptions(_CamelCaseDataObject):
class PublicKeyCredentialRequestOptions(_JsonDataObject):
challenge: bytes = field(metadata=_b64_metadata)
timeout: Optional[int] = None
rp_id: Optional[str] = None
Expand All @@ -561,7 +645,7 @@ class PublicKeyCredentialRequestOptions(_CamelCaseDataObject):


@dataclass(eq=False, frozen=True)
class AuthenticatorAttestationResponse(_CamelCaseDataObject):
class AuthenticatorAttestationResponse(_JsonDataObject):
client_data: CollectedClientData = field(
metadata=dict(
_b64_metadata,
Expand All @@ -578,15 +662,15 @@ def __getitem__(self, key):

@classmethod
def from_dict(cls, data: Optional[Mapping[str, Any]]):
if data is not None and not webauthn_json_mapping.enabled:
if data is not None and "clientData" in data:
value = dict(data)
value["clientDataJSON"] = value.pop("clientData", None)
data = value
return super().from_dict(data)


@dataclass(eq=False, frozen=True)
class AuthenticatorAssertionResponse(_CamelCaseDataObject):
class AuthenticatorAssertionResponse(_JsonDataObject):
client_data: CollectedClientData = field(
metadata=dict(
_b64_metadata,
Expand All @@ -606,44 +690,36 @@ def __getitem__(self, key):

@classmethod
def from_dict(cls, data: Optional[Mapping[str, Any]]):
if data is not None and not webauthn_json_mapping.enabled:
if data is not None and "clientData" in data:
value = dict(data)
value["clientDataJSON"] = value.pop("clientData", None)
data = value
return super().from_dict(data)


@dataclass(eq=False, frozen=True)
class RegistrationResponse(_CamelCaseDataObject):
class RegistrationResponse(_JsonDataObject):
id: bytes = field(metadata=_b64_metadata)
response: AuthenticatorAttestationResponse
authenticator_attachment: Optional[AuthenticatorAttachment] = None
client_extension_results: Optional[Mapping] = None
type: Optional[PublicKeyCredentialType] = None

def __post_init__(self):
webauthn_json_mapping.require()
super().__post_init__()


@dataclass(eq=False, frozen=True)
class AuthenticationResponse(_CamelCaseDataObject):
class AuthenticationResponse(_JsonDataObject):
id: bytes = field(metadata=_b64_metadata)
response: AuthenticatorAssertionResponse
authenticator_attachment: Optional[AuthenticatorAttachment] = None
client_extension_results: Optional[Mapping] = None
type: Optional[PublicKeyCredentialType] = None

def __post_init__(self):
webauthn_json_mapping.require()
super().__post_init__()


@dataclass(eq=False, frozen=True)
class CredentialCreationOptions(_CamelCaseDataObject):
class CredentialCreationOptions(_JsonDataObject):
public_key: PublicKeyCredentialCreationOptions


@dataclass(eq=False, frozen=True)
class CredentialRequestOptions(_CamelCaseDataObject):
class CredentialRequestOptions(_JsonDataObject):
public_key: PublicKeyCredentialRequestOptions
3 changes: 0 additions & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
import fido2.features

fido2.features.webauthn_json_mapping.enabled = True
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import unittest
from unittest import mock
from fido2 import cbor
from fido2.utils import sha256, websafe_encode
from fido2.utils import sha256
from fido2.hid import CAPABILITY
from fido2.ctap import CtapError
from fido2.ctap1 import RegistrationData
Expand All @@ -51,7 +51,7 @@
)

rp = {"id": "example.com", "name": "Example RP"}
user = {"id": websafe_encode(b"user_id"), "name": "A. User"}
user = {"id": b"user_id", "name": "A. User"}
challenge = b"Y2hhbGxlbmdl"
_INFO_NO_PIN = bytes.fromhex(
"a60182665532465f5632684649444f5f325f3002826375766d6b686d61632d7365637265740350f8a011f38c0a4d15800617111f9edc7d04a462726bf5627570f564706c6174f469636c69656e7450696ef4051904b0068101" # noqa E501
Expand Down
3 changes: 1 addition & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
AttestedCredentialData,
AuthenticatorData,
)
from fido2.utils import websafe_encode

from .test_ctap2 import _ATT_CRED_DATA, _CRED_ID
from .utils import U2FDevice
Expand Down Expand Up @@ -96,7 +95,7 @@ def test_register_begin_custom_challenge(self):
challenge = b"1234567890123456"
request, state = server.register_begin(USER, challenge=challenge)

self.assertEqual(request["publicKey"]["challenge"], websafe_encode(challenge))
self.assertEqual(request["publicKey"]["challenge"], challenge)

def test_register_begin_custom_challenge_too_short(self):
rp = PublicKeyCredentialRpEntity("Example", "example.com")
Expand Down
Loading

0 comments on commit e983d4b

Please sign in to comment.