Skip to content

Commit

Permalink
Make JWKSigner an actor (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoffy authored Jul 24, 2024
1 parent 26555af commit 0bb40ec
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 20 deletions.
23 changes: 11 additions & 12 deletions Sources/JWTKit/JWK/JWKSigner.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
final class JWKSigner: Sendable {
actor JWKSigner: Sendable {
let jwk: JWK
var signer: JWTSigner?
private(set) var signer: JWTSigner?

let parser: any JWTParser
let serializer: any JWTSerializer

Expand All @@ -19,12 +19,12 @@ final class JWKSigner: Sendable {
self.parser = parser
self.serializer = serializer
}

func makeSigner(for algorithm: JWK.Algorithm) throws -> JWTSigner {
guard let key = try jwk.getKey(for: algorithm) else {
throw JWTError.invalidJWK(reason: "Unable to create signer with given algorithm")
}

let signer = JWTSigner(algorithm: key, parser: parser, serializer: serializer)
self.signer = signer
return signer
Expand All @@ -48,7 +48,7 @@ extension JWK {
} else {
rsaKey = try Insecure.RSA.PublicKey(modulus: modulus, exponent: exponent)
}

let algorithm = alg ?? self.algorithm

switch algorithm {
Expand All @@ -69,15 +69,15 @@ extension JWK {
}

// ECDSA

case .ecdsa:
guard
let x = self.x,
let y = self.y
else {
throw JWTError.invalidJWK(reason: "Missing ECDSA coordinates")
}

let algorithm = alg ?? self.algorithm

switch algorithm {
Expand All @@ -87,7 +87,6 @@ extension JWK {
} else {
return try ECDSASigner(key: ES256PublicKey(parameters: (x, y)))
}

case .es384:
if let privateExponent = self.privateExponent {
return try ECDSASigner(key: ES384PrivateKey(key: privateExponent))
Expand All @@ -103,14 +102,14 @@ extension JWK {
default:
return nil
}

// EdDSA

case .octetKeyPair:
guard let curve = self.curve.flatMap({ EdDSACurve(rawValue: $0.rawValue) }) else {
throw JWTError.invalidJWK(reason: "Invalid EdDSA curve")
}

let algorithm = alg ?? self.algorithm

switch (algorithm, self.x, self.privateExponent) {
Expand Down
16 changes: 8 additions & 8 deletions Sources/JWTKit/JWTKeyCollection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public actor JWTKeyCollection: Sendable {
/// - kid: An optional ``JWKIdentifier``. If not provided, the default signer is returned.
/// - alg: An optional algorithm identifier.
/// - Returns: A ``JWTSigner`` if one is found; otherwise, `nil`.
func getSigner(for kid: JWKIdentifier? = nil, alg: String? = nil) throws -> JWTSigner {
func getSigner(for kid: JWKIdentifier? = nil, alg: String? = nil) async throws -> JWTSigner {
let signer: Signer
if let kid = kid, let stored = self.storage[kid] {
signer = stored
Expand All @@ -132,13 +132,13 @@ public actor JWTKeyCollection: Sendable {
case let .jwt(jwt):
return jwt
case let .jwk(jwk):
if let signer = jwk.signer {
if let signer = await jwk.signer {
return signer
} else {
guard let alg, let jwkAlg = JWK.Algorithm(rawValue: alg) else {
throw JWTError.generic(identifier: "Algorithm", reason: "Invalid algorithm or unable to create signer with provided algorithm.")
}
return try jwk.makeSigner(for: jwkAlg)
return try await jwk.makeSigner(for: jwkAlg)
}
}
}
Expand All @@ -149,8 +149,8 @@ public actor JWTKeyCollection: Sendable {
/// - alg: An optional algorithm identifier.
/// - Returns: A ``JWTKey`` if one is found; otherwise, `nil`.
/// - Throws: ``JWTError/generic`` if the algorithm cannot be retrieved.
public func getKey(for kid: JWKIdentifier? = nil, alg: String? = nil) throws -> JWTAlgorithm {
try self.getSigner(for: kid, alg: alg).algorithm
public func getKey(for kid: JWKIdentifier? = nil, alg: String? = nil) async throws -> JWTAlgorithm {
try await self.getSigner(for: kid, alg: alg).algorithm
}

/// Decodes an unverified JWT payload.
Expand Down Expand Up @@ -224,15 +224,15 @@ public actor JWTKeyCollection: Sendable {
{
let header = try defaultJWTParser.parseHeader(token)
let kid = header.kid.flatMap { JWKIdentifier(string: $0) }
var signer = try self.getSigner(for: kid, alg: header.alg)
var signer = try await self.getSigner(for: kid, alg: header.alg)

do {
return try await signer.verify(token)
} catch {
if iteratingKeys == true {
for (_kid, _) in self.storage where _kid != kid {
do {
signer = try self.getSigner(for: _kid, alg: header.alg)
signer = try await self.getSigner(for: _kid, alg: header.alg)
return try await signer.verify(token)
} catch {}
}
Expand Down Expand Up @@ -265,7 +265,7 @@ public actor JWTKeyCollection: Sendable {
updatedHeader.kid = effectiveKidValue
}

let signer = try self.getSigner(for: effectiveKid, alg: updatedHeader.alg)
let signer = try await self.getSigner(for: effectiveKid, alg: updatedHeader.alg)
return try await signer.sign(payload, with: updatedHeader)
}
}

0 comments on commit 0bb40ec

Please sign in to comment.