diff --git a/fido2/ctap2/extensions.py b/fido2/ctap2/extensions.py index 2cb3d94..9fd450d 100644 --- a/fido2/ctap2/extensions.py +++ b/fido2/ctap2/extensions.py @@ -30,6 +30,7 @@ from .base import AttestationResponse, AssertionResponse, Ctap2 from .pin import ClientPin, PinProtocol from .blob import LargeBlobs +from ..utils import sha256 from enum import Enum, unique from typing import Dict, Tuple, Any, Optional import abc @@ -91,6 +92,10 @@ def process_get_output( return None +def _prf_salt(secret): + return sha256(b"WebAuthn PRF\0" + secret) + + class HmacSecretExtension(Ctap2Extension): """ Implements the hmac-secret CTAP2 extension. @@ -104,23 +109,44 @@ def __init__(self, ctap, pin_protocol=None): self.pin_protocol = pin_protocol def process_create_input(self, inputs): - if self.is_supported() and inputs.get("hmacCreateSecret") is True: - return True + if self.is_supported(): + if inputs.get("hmacCreateSecret") is True: + self.prf = False + return True + elif inputs.get("prf") is not None: + self.prf = True + return True def process_create_output(self, attestation_response, *args): - if attestation_response.auth_data.extensions.get(self.NAME): - return {"hmacCreateSecret": True} + enabled = attestation_response.auth_data.extensions.get(self.NAME, False) + if self.prf: + return {"prf": {"enabled": enabled}} + else: + return {"hmacCreateSecret": enabled} def process_get_input(self, inputs): - data = self.is_supported() and inputs.get("hmacGetSecret") - if not data: + if not self.is_supported(): + return + + data = inputs.get("prf") + if data: + secrets = data.get("eval") + salts = ( + _prf_salt(secrets["first"]), + _prf_salt(secrets["second"]) if "second" in secrets else b"" + ) + self.prf = True + else: + data = inputs.get("hmacGetSecret") + salts = data["salt1"], data.get("salt2", b"") + self.prf = False + + if not salts: return - salt1 = data["salt1"] - salt2 = data.get("salt2", b"") if not ( - len(salt1) == HmacSecretExtension.SALT_LEN - and (not salt2 or len(salt2) == HmacSecretExtension.SALT_LEN) + len(salts[0]) == HmacSecretExtension.SALT_LEN + and (not salts[1] or len(salts[1]) == HmacSecretExtension.SALT_LEN) ): raise ValueError("Invalid salt length") @@ -129,7 +155,7 @@ def process_get_input(self, inputs): if self.pin_protocol is None: self.pin_protocol = client_pin.protocol - salt_enc = self.pin_protocol.encrypt(self.shared_secret, salt1 + salt2) + salt_enc = self.pin_protocol.encrypt(self.shared_secret, salts[0] + salts[1]) salt_auth = self.pin_protocol.authenticate(self.shared_secret, salt_enc) return { @@ -145,11 +171,17 @@ def process_get_output(self, assertion_response, *args): decrypted = self.pin_protocol.decrypt(self.shared_secret, value) output1 = decrypted[: HmacSecretExtension.SALT_LEN] output2 = decrypted[HmacSecretExtension.SALT_LEN :] - outputs = {"output1": output1} - if output2: - outputs["output2"] = output2 - return {"hmacGetSecret": outputs} + if self.prf: + results = {"first": output1} + if output2: + results["second"] = output2 + return {"prf": {"results": results}} + else: + results = {"output1": output1} + if output2: + results["output2"] = output2 + return {"hmacGetSecret": results} class LargeBlobExtension(Ctap2Extension):