Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
Packages
.build
.index-build
.DS_Store
*.xcodeproj
Package.pins
Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ let package = Package(
.library(name: "JWTKit", targets: ["JWTKit"])
],
dependencies: [
.package(url: "https://github.com/apple/swift-crypto.git", "3.8.0"..<"5.0.0"),
.package(url: "https://github.com/apple/swift-crypto.git", branch: "main"),
.package(url: "https://github.com/apple/swift-certificates.git", from: "1.2.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"),
],
Expand Down
6 changes: 2 additions & 4 deletions Sources/JWTKit/JWTKeyCollection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ public actor JWTKeyCollection: Sendable {
/// - Returns: Self for chaining.
@discardableResult
func add(_ signer: JWTSigner, for kid: JWKIdentifier? = nil) -> Self {
let signer = JWTSigner(
algorithm: signer.algorithm, parser: signer.parser, serializer: signer.serializer)
let signer = JWTSigner(algorithm: signer.algorithm, parser: signer.parser, serializer: signer.serializer)

if let kid {
if self.storage[kid] != nil {
Expand Down Expand Up @@ -106,8 +105,7 @@ public actor JWTKeyCollection: Sendable {
guard let kid = jwk.keyIdentifier else {
throw JWTError.invalidJWK(reason: "Missing KID")
}
let signer = try JWKSigner(
jwk: jwk, parser: defaultJWTParser, serializer: defaultJWTSerializer)
let signer = try JWKSigner(jwk: jwk, parser: defaultJWTParser, serializer: defaultJWTSerializer)

self.storage[kid] = .jwk(signer)
switch (self.default, isDefault) {
Expand Down
15 changes: 15 additions & 0 deletions Sources/JWTKit/MLDSA/JWTKeyCollection+MLDSA.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
extension JWTKeyCollection {
@_spi(PostQuantum)
@discardableResult
public func add(
mldsa key: some MLDSAKey,
kid: JWKIdentifier? = nil,
parser: some JWTParser = DefaultJWTParser(),
serializer: some JWTSerializer = DefaultJWTSerializer()
) -> Self {
self.add(
.init(algorithm: MLDSASigner(key: key), parser: parser, serializer: serializer),
for: kid
)
}
}
49 changes: 49 additions & 0 deletions Sources/JWTKit/MLDSA/MLDSA.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import _CryptoExtras

#if !canImport(Darwin)
import FoundationEssentials
#else
import Foundation
#endif

@_spi(PostQuantum) public enum MLDSA: Sendable {}

extension MLDSA {
public struct PublicKey<KeyType>: MLDSAKey where KeyType: MLDSAType {
public typealias MLDSAType = KeyType

typealias PublicKey = KeyType.PrivateKey.PublicKey

let backing: any MLDSAPublicKey

public init(backing: some MLDSAPublicKey) {
self.backing = backing
}

public init(rawRepresentation: some DataProtocol) throws {
self.backing = try PublicKey(rawRepresentation: rawRepresentation)
}
}
}

extension MLDSA {
public struct PrivateKey<KeyType>: MLDSAKey where KeyType: MLDSAType {
public typealias MLDSAType = KeyType

typealias PrivateKey = KeyType.PrivateKey

let backing: any MLDSAPrivateKey

public var publicKey: MLDSA.PublicKey<KeyType> {
.init(backing: self.backing.publicKey)
}

public init(backing: some MLDSAPrivateKey) {
self.backing = backing
}

public init(seedRepresentation: some DataProtocol) throws {
self.backing = try PrivateKey(seedRepresentation: seedRepresentation)
}
}
}
14 changes: 14 additions & 0 deletions Sources/JWTKit/MLDSA/MLDSA65+MLDSAKey.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import _CryptoExtras

@_spi(PostQuantum)
extension MLDSA65.PublicKey: MLDSAPublicKey {
public typealias MLDSAType = MLDSA65
}

@_spi(PostQuantum)
extension MLDSA65.PrivateKey: MLDSAPrivateKey {
public typealias MLDSAType = MLDSA65
}

@_spi(PostQuantum) public typealias MLDSA65PublicKey = MLDSA.PublicKey<MLDSA65>
@_spi(PostQuantum) public typealias MLDSA65PrivateKey = MLDSA.PrivateKey<MLDSA65>
14 changes: 14 additions & 0 deletions Sources/JWTKit/MLDSA/MLDSA87+MLDSAKey.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import _CryptoExtras

@_spi(PostQuantum)
extension MLDSA87.PublicKey: MLDSAPublicKey {
public typealias MLDSAType = MLDSA87
}

@_spi(PostQuantum)
extension MLDSA87.PrivateKey: MLDSAPrivateKey {
public typealias MLDSAType = MLDSA87
}

@_spi(PostQuantum) public typealias MLDSA87PublicKey = MLDSA.PublicKey<MLDSA87>
@_spi(PostQuantum) public typealias MLDSA87PrivateKey = MLDSA.PrivateKey<MLDSA87>
5 changes: 5 additions & 0 deletions Sources/JWTKit/MLDSA/MLDSAError.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
enum MLDSAError: Error {
case noPrivateKey
case noPublicKey
case failedToSign(Error)
}
34 changes: 34 additions & 0 deletions Sources/JWTKit/MLDSA/MLDSAKey.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

@_spi(PostQuantum)
public protocol MLDSAKey: Sendable {
associatedtype MLDSAType: JWTKit.MLDSAType
}

@_spi(PostQuantum)
public protocol MLDSAPublicKey: Sendable {
associatedtype MLDSAType

init(rawRepresentation: some DataProtocol) throws
var rawRepresentation: Data { get }
func isValidSignature<S: DataProtocol, D: DataProtocol>(_ signature: S, for data: D) -> Bool
func isValidSignature<S: DataProtocol, D: DataProtocol, C: DataProtocol>(
_ signature: S, for data: D, context: C
) -> Bool
}

@_spi(PostQuantum)
public protocol MLDSAPrivateKey: Sendable {
associatedtype MLDSAType
associatedtype PublicKey: MLDSAPublicKey

var seedRepresentation: Data { get }
var publicKey: PublicKey { get }
init(seedRepresentation: some DataProtocol) throws
func signature<D: DataProtocol>(for data: D) throws -> Data
func signature<D: DataProtocol, C: DataProtocol>(for data: D, context: C) throws -> Data
}
46 changes: 46 additions & 0 deletions Sources/JWTKit/MLDSA/MLDSASigner.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import _CryptoExtras

#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

struct MLDSASigner<Key: MLDSAKey>: JWTAlgorithm, Sendable {
let privateKey: MLDSA.PrivateKey<Key.MLDSAType>?
let publicKey: MLDSA.PublicKey<Key.MLDSAType>

var name: String = Key.MLDSAType.name

init(key: Key) {
switch key {
case let key as MLDSA.PrivateKey<Key.MLDSAType>:
self.privateKey = key
self.publicKey = key.publicKey
case let key as MLDSA.PublicKey<Key.MLDSAType>:
self.privateKey = nil
self.publicKey = key
default:
fatalError()
}
}

func sign(_ plaintext: some DataProtocol) throws -> [UInt8] {
guard let privateKey else {
throw JWTError.signingAlgorithmFailure(MLDSAError.noPrivateKey)
}

let signature: Data
do {
signature = try privateKey.backing.signature(for: plaintext)
} catch {
throw JWTError.signingAlgorithmFailure(MLDSAError.failedToSign(error))
}

return signature.copyBytes()
}

func verify(_ signature: some DataProtocol, signs plaintext: some DataProtocol) throws -> Bool {
publicKey.backing.isValidSignature(signature, for: plaintext)
}
}
18 changes: 18 additions & 0 deletions Sources/JWTKit/MLDSA/MLDSAType.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import _CryptoExtras

@_spi(PostQuantum)
public protocol MLDSAType {
associatedtype PrivateKey: MLDSAPrivateKey

static var name: String { get }
}

@_spi(PostQuantum)
extension MLDSA65: MLDSAType {
public static var name: String { "ML-DSA-65" }
}

@_spi(PostQuantum)
extension MLDSA87: MLDSAType {
public static var name: String { "ML-DSA-87" }
}
78 changes: 78 additions & 0 deletions Tests/JWTKitTests/MLDSATests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import Crypto
import Foundation
@_spi(PostQuantum) import JWTKit
import Testing

@Suite("MLDSA Tests")
struct MLDSATests {
@Test("MLDSA65 Signing")
func sign65() async throws {
struct Foo: JWTPayload {
var bar: Int
func verify(using _: some JWTAlgorithm) throws {}
}

let key = try MLDSA65PrivateKey(
seedRepresentation: Data(fromHexEncodedString: mldsa65PrivateKeySeedRepresentation)!)

let keyCollection = JWTKeyCollection()
await keyCollection.add(mldsa: key)

let jwt = try await keyCollection.sign(Foo(bar: 42))
let verified = try await keyCollection.verify(jwt, as: Foo.self)

#expect(verified.bar == 42)
}

@Test("MLDSA87 Signing")
func sign87() async throws {
struct Foo: JWTPayload {
var bar: Int
func verify(using _: some JWTAlgorithm) throws {}
}

let key = try MLDSA87PrivateKey(
seedRepresentation: Data(fromHexEncodedString: mldsa65PrivateKeySeedRepresentation)!)

let keyCollection = JWTKeyCollection()
await keyCollection.add(mldsa: key)

let jwt = try await keyCollection.sign(Foo(bar: 42))
let verified = try await keyCollection.verify(jwt, as: Foo.self)

#expect(verified.bar == 42)

print(jwt)
}
}

let mldsa65PrivateKeySeedRepresentation =
"70cefb9aed5b68e018b079da8284b9d5cad5499ed9c265ff73588005d85c225c"

let mldsa87PrivateKeySeedRepresentation =
"19e9e5efe0c1549ddb1d72213636d16fe2faeb2428257004ae464094ca536a66"

extension Data {
init?(fromHexEncodedString string: String) {
func decodeNibble(u: UInt8) -> UInt8? {
switch u {
case 0x30...0x39: u - 0x30
case 0x41...0x46: u - 0x41 + 10
case 0x61...0x66: u - 0x61 + 10
default: nil
}
}

self.init(capacity: string.utf8.count / 2)

var iter = string.utf8.makeIterator()
while let c1 = iter.next() {
guard
let val1 = decodeNibble(u: c1),
let c2 = iter.next(),
let val2 = decodeNibble(u: c2)
else { return nil }
self.append(val1 << 4 + val2)
}
}
}
Loading