Skip to content

Commit

Permalink
Cache JWKSigner when possible (#172)
Browse files Browse the repository at this point in the history
* Cache JWKSigner when possible

* Format nit

* Update Sources/JWTKit/JWK/JWKSigner.swift

Co-authored-by: Tim Condon <[email protected]>

---------

Co-authored-by: Tim Condon <[email protected]>
  • Loading branch information
ptoffy and 0xTim authored Jul 24, 2024
1 parent 71d78cc commit 26555af
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 79 deletions.
151 changes: 82 additions & 69 deletions Sources/JWTKit/JWK/JWKSigner.swift
Original file line number Diff line number Diff line change
@@ -1,113 +1,126 @@
struct JWKSigner: Sendable {
final class JWKSigner: Sendable {
let jwk: JWK

var signer: JWTSigner?

let parser: any JWTParser
let serializer: any JWTSerializer

init(jwk: JWK, parser: some JWTParser, serializer: some JWTSerializer) {
init(
jwk: JWK,
parser: some JWTParser = DefaultJWTParser(),
serializer: some JWTSerializer = DefaultJWTSerializer()
) throws {
self.jwk = jwk
if let algorithm = try jwk.getKey() {
self.signer = .init(algorithm: algorithm, parser: parser, serializer: serializer)
} else {
self.signer = nil
}
self.parser = parser
self.serializer = serializer
}

func makeSigner(for algorithm: JWK.Algorithm) throws -> JWTSigner {
guard let key = try jwk.getKey(for: algorithm) else {
throw JWTError.invalidJWK(reason: "Unable to create signer with given algorithm")
}

let signer = JWTSigner(algorithm: key, parser: parser, serializer: serializer)
self.signer = signer
return signer
}
}

func signer(for algorithm: JWK.Algorithm? = nil) -> JWTSigner? {
switch jwk.keyType.backing {
extension JWK {
func getKey(for alg: JWK.Algorithm? = nil) throws -> (any JWTAlgorithm)? {
switch self.keyType.backing {
case .rsa:
guard
let modulus = self.jwk.modulus,
let exponent = self.jwk.exponent
let modulus = self.modulus,
let exponent = self.exponent
else {
return nil
throw JWTError.invalidJWK(reason: "Missing RSA primitives")
}

let rsaKey: RSAKey
do {
if let privateExponent = jwk.privateExponent {
rsaKey = try Insecure.RSA.PrivateKey(modulus: modulus, exponent: exponent, privateExponent: privateExponent)
} else {
rsaKey = try Insecure.RSA.PublicKey(modulus: modulus, exponent: exponent)
}
} catch {
return nil
}

guard let algorithm = algorithm ?? self.jwk.algorithm else {
return nil
if let privateExponent = self.privateExponent {
rsaKey = try Insecure.RSA.PrivateKey(modulus: modulus, exponent: exponent, privateExponent: privateExponent)
} else {
rsaKey = try Insecure.RSA.PublicKey(modulus: modulus, exponent: exponent)
}

let algorithm = alg ?? self.algorithm

switch algorithm {
case .rs256:
return .init(algorithm: RSASigner(key: rsaKey, algorithm: .sha256, name: "RS256", padding: .insecurePKCS1v1_5))
return RSASigner(key: rsaKey, algorithm: .sha256, name: "RS256", padding: .insecurePKCS1v1_5)
case .rs384:
return .init(algorithm: RSASigner(key: rsaKey, algorithm: .sha384, name: "RS384", padding: .insecurePKCS1v1_5))
return RSASigner(key: rsaKey, algorithm: .sha384, name: "RS384", padding: .insecurePKCS1v1_5)
case .rs512:
return .init(algorithm: RSASigner(key: rsaKey, algorithm: .sha512, name: "RS512", padding: .insecurePKCS1v1_5))
return RSASigner(key: rsaKey, algorithm: .sha512, name: "RS512", padding: .insecurePKCS1v1_5)
case .ps256:
return .init(algorithm: RSASigner(key: rsaKey, algorithm: .sha256, name: "PS256", padding: .PSS))
return RSASigner(key: rsaKey, algorithm: .sha256, name: "PS256", padding: .PSS)
case .ps384:
return .init(algorithm: RSASigner(key: rsaKey, algorithm: .sha384, name: "PS384", padding: .PSS))
return RSASigner(key: rsaKey, algorithm: .sha384, name: "PS384", padding: .PSS)
case .ps512:
return .init(algorithm: RSASigner(key: rsaKey, algorithm: .sha512, name: "PS512", padding: .PSS))
return RSASigner(key: rsaKey, algorithm: .sha512, name: "PS512", padding: .PSS)
default:
return nil
}

// ECDSA

case .ecdsa:
guard let x = self.jwk.x else {
return nil
}
guard let y = self.jwk.y else {
return nil
}

guard let algorithm = algorithm ?? self.jwk.algorithm else {
return nil
guard
let x = self.x,
let y = self.y
else {
throw JWTError.invalidJWK(reason: "Missing ECDSA coordinates")
}

let algorithm = alg ?? self.algorithm

do {
switch algorithm {
case .es256:
if let privateExponent = self.jwk.privateExponent {
return try .init(algorithm: ECDSASigner(key: ES256PrivateKey(key: privateExponent)))
} else {
return try .init(algorithm: ECDSASigner(key: ES256PublicKey(parameters: (x, y))))
}
switch algorithm {
case .es256:
if let privateExponent = self.privateExponent {
return try ECDSASigner(key: ES256PrivateKey(key: privateExponent))
} else {
return try ECDSASigner(key: ES256PublicKey(parameters: (x, y)))
}

case .es384:
if let privateExponent = self.jwk.privateExponent {
return try .init(algorithm: ECDSASigner(key: ES384PrivateKey(key: privateExponent)))
} else {
return try .init(algorithm: ECDSASigner(key: ES384PublicKey(parameters: (x, y))))
}
case .es512:
if let privateExponent = self.jwk.privateExponent {
return try .init(algorithm: ECDSASigner(key: ES512PrivateKey(key: privateExponent)))
} else {
return try .init(algorithm: ECDSASigner(key: ES512PublicKey(parameters: (x, y))))
}
default:
return nil
case .es384:
if let privateExponent = self.privateExponent {
return try ECDSASigner(key: ES384PrivateKey(key: privateExponent))
} else {
return try ECDSASigner(key: ES384PublicKey(parameters: (x, y)))
}
} catch {
case .es512:
if let privateExponent = self.privateExponent {
return try ECDSASigner(key: ES512PrivateKey(key: privateExponent))
} else {
return try ECDSASigner(key: ES512PublicKey(parameters: (x, y)))
}
default:
return nil
}

// EdDSA

case .octetKeyPair:
guard let algorithm = algorithm ?? self.jwk.algorithm else {
return nil
}

guard let curve = self.jwk.curve.flatMap({ EdDSACurve(rawValue: $0.rawValue) }) else {
return nil
guard let curve = self.curve.flatMap({ EdDSACurve(rawValue: $0.rawValue) }) else {
throw JWTError.invalidJWK(reason: "Invalid EdDSA curve")
}

let algorithm = alg ?? self.algorithm

switch (algorithm, self.jwk.x, self.jwk.privateExponent) {
switch (algorithm, self.x, self.privateExponent) {
case let (.eddsa, .some(x), .some(d)):
let key = try? EdDSA.PrivateKey(x: x, d: d, curve: curve)
return key.map { .init(algorithm: EdDSASigner(key: $0)) }
let key = try EdDSA.PrivateKey(x: x, d: d, curve: curve)
return EdDSASigner(key: key)

case let (.eddsa, .some(x), .none):
let key = try? EdDSA.PublicKey(x: x, curve: curve)
return key.map { .init(algorithm: EdDSASigner(key: $0)) }
let key = try EdDSA.PublicKey(x: x, curve: curve)
return EdDSASigner(key: key)

default:
return nil
Expand Down
4 changes: 3 additions & 1 deletion Sources/JWTKit/JWTError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ public struct JWTError: Error, Sendable {
.init(backing: .init(errorType: .unknownKID, kid: kid))
}

public static let invalidJWK = Self(errorType: .invalidJWK)
public static func invalidJWK(reason: String) -> Self {
.init(backing: .init(errorType: .invalidJWK, reason: reason))
}

public static func invalidBool(_ name: String) -> Self {
.init(backing: .init(errorType: .invalidBool, name: name))
Expand Down
18 changes: 11 additions & 7 deletions Sources/JWTKit/JWTKeyCollection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ public actor JWTKeyCollection: Sendable {

if let kid {
if self.storage[kid] != nil {
logger.debug("Overwriting existing JWT signer", metadata: ["kid": "\(kid)"])
self.logger.debug("Overwriting existing JWT signer", metadata: ["kid": "\(kid)"])
}
self.storage[kid] = .jwt(signer)
} else {
if self.default != nil {
logger.debug("Overwriting existing default JWT signer")
self.logger.debug("Overwriting existing default JWT signer")
}
self.default = .jwt(signer)
}
Expand Down Expand Up @@ -99,9 +99,10 @@ public actor JWTKeyCollection: Sendable {
isDefault: Bool? = nil
) throws -> Self {
guard let kid = jwk.keyIdentifier else {
throw JWTError.invalidJWK
throw JWTError.invalidJWK(reason: "Missing KID")
}
let signer = JWKSigner(jwk: jwk, parser: defaultJWTParser, serializer: defaultJWTSerializer)
let signer = try JWKSigner(jwk: jwk, parser: defaultJWTParser, serializer: defaultJWTSerializer)

self.storage[kid] = .jwk(signer)
switch (self.default, isDefault) {
case (.none, .none), (_, .some(true)):
Expand Down Expand Up @@ -131,10 +132,13 @@ public actor JWTKeyCollection: Sendable {
case let .jwt(jwt):
return jwt
case let .jwk(jwk):
if let signer = jwk.signer(for: alg.flatMap { JWK.Algorithm(rawValue: $0) }) {
if let signer = jwk.signer {
return signer
} else {
throw JWTError.generic(identifier: "Algorithm", reason: "Invalid algorithm or unable to create signer with provided algorithm.")
guard let alg, let jwkAlg = JWK.Algorithm(rawValue: alg) else {
throw JWTError.generic(identifier: "Algorithm", reason: "Invalid algorithm or unable to create signer with provided algorithm.")
}
return try jwk.makeSigner(for: jwkAlg)
}
}
}
Expand Down Expand Up @@ -182,7 +186,7 @@ public actor JWTKeyCollection: Sendable {
) throws -> Payload
where Payload: JWTPayload
{
try (parser ?? defaultJWTParser).parse(token, as: Payload.self).payload
try (parser ?? self.defaultJWTParser).parse(token, as: Payload.self).payload
}

/// Verifies and decodes a JWT token to extract the payload.
Expand Down
2 changes: 1 addition & 1 deletion Sources/JWTKit/JWTSigner.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ final class JWTSigner: Sendable {
let parser: any JWTParser
let serializer: any JWTSerializer

init(algorithm: JWTAlgorithm, parser: some JWTParser = DefaultJWTParser(), serializer: some JWTSerializer = DefaultJWTSerializer()) {
init(algorithm: JWTAlgorithm, parser: any JWTParser = DefaultJWTParser(), serializer: any JWTSerializer = DefaultJWTSerializer()) {
self.algorithm = algorithm
self.parser = parser
self.serializer = serializer
Expand Down
3 changes: 2 additions & 1 deletion Tests/JWTKitTests/JWTKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class JWTKitTests: XCTestCase {
exp: .init(value: .init(timeIntervalSince1970: 2_000_000_000))
)
let data = try await keyCollection.sign(payload, kid: "1234")

// test private signer decoding
try await XCTAssertEqualAsync(await keyCollection.verify(data, as: TestPayload.self), payload)
// test public signer decoding
Expand All @@ -212,7 +213,7 @@ class JWTKitTests: XCTestCase {
let keyCollection = try await JWTKeyCollection().use(jwksJSON: json)

await XCTAssertNoThrowAsync(try await keyCollection.getKey())
var a: JWTAlgorithm, b: JWTAlgorithm
let a: JWTAlgorithm, b: JWTAlgorithm
do {
a = try await keyCollection.getKey(for: "a")
} catch {
Expand Down

0 comments on commit 26555af

Please sign in to comment.