Skip to content

Commit

Permalink
feat: support new MLS cipher suite [WPB-8592] 🍒 (#2732) (#2752)
Browse files Browse the repository at this point in the history
* Commit with unresolved merge conflicts

* fix merge issues

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Mohamad Jaara <[email protected]>
Co-authored-by: Yamil Medina <[email protected]>
  • Loading branch information
3 people authored May 26, 2024
1 parent bbd2621 commit 81cf9dc
Show file tree
Hide file tree
Showing 48 changed files with 455 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,22 @@ import java.nio.file.Files

actual open class BaseMLSClientTest {

actual suspend fun createMLSClient(clientId: CryptoQualifiedClientId): MLSClient {
return createCoreCrypto(clientId).mlsClient(clientId)
actual suspend fun createMLSClient(
clientId: CryptoQualifiedClientId,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort
): MLSClient {
return createCoreCrypto(clientId, allowedCipherSuites, defaultCipherSuite).mlsClient(clientId)
}

actual suspend fun createCoreCrypto(clientId: CryptoQualifiedClientId): CoreCryptoCentral {
actual suspend fun createCoreCrypto(
clientId: CryptoQualifiedClientId,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort
): CoreCryptoCentral {
val root = Files.createTempDirectory("mls").toFile()
val keyStore = root.resolve("keystore-$clientId")
return coreCryptoCentral(keyStore.absolutePath, "test")
return coreCryptoCentral(keyStore.absolutePath, "test", allowedCipherSuites, defaultCipherSuite)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ actual open class BaseProteusClientTest {

actual suspend fun createProteusClient(proteusStore: ProteusStoreRef, databaseKey: ProteusDBSecret?): ProteusClient {
return databaseKey?.let {
coreCryptoCentral(proteusStore.value, it.value).proteusClient()
coreCryptoCentral(proteusStore.value, it.value, emptyList(), 0.toUShort()).proteusClient()
} ?: cryptoboxProteusClient(proteusStore.value, testCoroutineScheduler, testCoroutineScheduler)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,21 @@ package com.wire.kalium.cryptography
import java.nio.file.Files

actual open class BaseMLSClientTest {
actual suspend fun createMLSClient(clientId: CryptoQualifiedClientId): MLSClient {
return createCoreCrypto(clientId).mlsClient(clientId)
actual suspend fun createMLSClient(
clientId: CryptoQualifiedClientId,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort
): MLSClient {
return createCoreCrypto(clientId, allowedCipherSuites, defaultCipherSuite).mlsClient(clientId)
}

actual suspend fun createCoreCrypto(clientId: CryptoQualifiedClientId): CoreCryptoCentral {
actual suspend fun createCoreCrypto(
clientId: CryptoQualifiedClientId,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort
): CoreCryptoCentral {
val root = Files.createTempDirectory("mls").toFile()
val keyStore = root.resolve("keystore-$clientId")
return coreCryptoCentral(keyStore.absolutePath, "test")
return coreCryptoCentral(keyStore.absolutePath, "test", allowedCipherSuites, defaultCipherSuite)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ actual open class BaseProteusClientTest {

actual suspend fun createProteusClient(proteusStore: ProteusStoreRef, databaseKey: ProteusDBSecret?): ProteusClient {
return databaseKey?.let {
coreCryptoCentral(proteusStore.value, it.value).proteusClient()
coreCryptoCentral(proteusStore.value, it.value, emptyList(), null).proteusClient()
} ?: cryptoboxProteusClient(proteusStore.value, testCoroutineScheduler,testCoroutineScheduler)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ import com.wire.crypto.CoreCryptoCallbacks
import platform.Foundation.NSFileManager
import kotlin.time.Duration

actual suspend fun coreCryptoCentral(rootDir: String, databaseKey: String): CoreCryptoCentral {
actual suspend fun coreCryptoCentral(
rootDir: String,
databaseKey: String,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort?
): CoreCryptoCentral {
val path = "$rootDir/${CoreCryptoCentralImpl.KEYSTORE_NAME}"
NSFileManager.defaultManager.createDirectoryAtPath(rootDir, withIntermediateDirectories = true, null, null)
val coreCrypto = CoreCrypto.deferredInit(path, databaseKey, null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,22 @@ import platform.Foundation.NSURL
import platform.Foundation.URLByAppendingPathComponent

actual open class BaseMLSClientTest actual constructor() {
actual suspend fun createMLSClient(clientId: CryptoQualifiedClientId): MLSClient {
return createCoreCrypto(clientId).mlsClient(clientId)
actual suspend fun createMLSClient(
clientId: CryptoQualifiedClientId,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort
): MLSClient {
return createCoreCrypto(clientId, allowedCipherSuites, defaultCipherSuite).mlsClient(clientId)
}

actual suspend fun createCoreCrypto(clientId: CryptoQualifiedClientId): CoreCryptoCentral {
actual suspend fun createCoreCrypto(
clientId: CryptoQualifiedClientId,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort
): CoreCryptoCentral {
val rootDir = NSURL.fileURLWithPath(NSTemporaryDirectory() + "/mls", isDirectory = true)
NSFileManager.defaultManager.createDirectoryAtURL(rootDir, true, null, null)
val keyStore = rootDir.URLByAppendingPathComponent("keystore-$clientId")!!
return coreCryptoCentral(keyStore.path!!, "test")
return coreCryptoCentral(keyStore.path!!, "test", allowedCipherSuites, defaultCipherSuite)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ actual open class BaseProteusClientTest actual constructor() {
proteusStore: ProteusStoreRef,
databaseKey: ProteusDBSecret?
): ProteusClient {
return coreCryptoCentral(proteusStore.value, "secret").proteusClient()
return coreCryptoCentral(proteusStore.value, "secret", emptyList(), null).proteusClient()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,31 @@
*/
package com.wire.kalium.cryptography

import com.wire.crypto.Ciphersuites
import com.wire.crypto.ClientId
import com.wire.crypto.CoreCrypto
import com.wire.crypto.CoreCryptoCallbacks
import com.wire.crypto.client.Ciphersuites
import com.wire.crypto.coreCryptoDeferredInit
import com.wire.kalium.cryptography.MLSClientImpl.Companion.toCrlRegistration
import com.wire.kalium.cryptography.exceptions.CryptographyException
import java.io.File

actual suspend fun coreCryptoCentral(rootDir: String, databaseKey: String): CoreCryptoCentral {
actual suspend fun coreCryptoCentral(
rootDir: String,
databaseKey: String,
allowedCipherSuites: Ciphersuites,
defaultCipherSuite: UShort?
): CoreCryptoCentral {
val path = "$rootDir/${CoreCryptoCentralImpl.KEYSTORE_NAME}"
File(rootDir).mkdirs()
val coreCrypto = coreCryptoDeferredInit(path, databaseKey, Ciphersuites.DEFAULT.lower(), null)
val coreCrypto = coreCryptoDeferredInit(path, databaseKey, allowedCipherSuites, null)
coreCrypto.setCallbacks(Callbacks())
return CoreCryptoCentralImpl(coreCrypto, rootDir)
return CoreCryptoCentralImpl(
cc = coreCrypto,
rootDir = rootDir,
cipherSuite = allowedCipherSuites,
defaultCipherSuite = defaultCipherSuite
)
}

private class Callbacks : CoreCryptoCallbacks {
Expand Down Expand Up @@ -61,12 +71,18 @@ private class Callbacks : CoreCryptoCallbacks {
}
}

class CoreCryptoCentralImpl(private val cc: CoreCrypto, private val rootDir: String) : CoreCryptoCentral {
class CoreCryptoCentralImpl(
private val cc: CoreCrypto,
private val rootDir: String,
// TODO: remove one they are removed from the CC api
private val cipherSuite: Ciphersuites,
private val defaultCipherSuite: UShort?
) : CoreCryptoCentral {
fun getCoreCrypto() = cc

override suspend fun mlsClient(clientId: CryptoQualifiedClientId): MLSClient {
cc.mlsInit(clientId.toString().encodeToByteArray(), Ciphersuites.DEFAULT.lower(), null)
return MLSClientImpl(cc)
cc.mlsInit(clientId.toString().encodeToByteArray(), cipherSuite, null)
return MLSClientImpl(cc, defaultCipherSuite!!)
}

override suspend fun mlsClient(
Expand All @@ -79,7 +95,7 @@ class CoreCryptoCentralImpl(private val cc: CoreCrypto, private val rootDir: Str
(enrollment as E2EIClientImpl).wireE2eIdentity,
certificateChain, newMLSKeyPackageCount
)
return MLSClientImpl(cc)
return MLSClientImpl(cc, defaultCipherSuite!!)
}

override suspend fun proteusClient(): ProteusClient {
Expand All @@ -100,7 +116,7 @@ class CoreCryptoCentralImpl(private val cc: CoreCrypto, private val rootDir: Str
handle,
teamId,
expiry.inWholeSeconds.toUInt(),
Ciphersuites.DEFAULT.lower().first()
defaultCipherSuite!!
)

)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import com.wire.crypto.MlsCredentialType
import com.wire.crypto.MlsGroupInfoEncryptionType
import com.wire.crypto.MlsRatchetTreeType
import com.wire.crypto.MlsWirePolicy
import com.wire.crypto.client.Ciphersuites
import com.wire.crypto.Ciphersuite
import io.ktor.util.decodeBase64Bytes
import io.ktor.util.encodeBase64
import kotlin.time.Duration
Expand All @@ -41,25 +41,26 @@ typealias ConversationId = ByteArray
@Suppress("TooManyFunctions")
@OptIn(ExperimentalUnsignedTypes::class)
class MLSClientImpl(
private val coreCrypto: CoreCrypto
private val coreCrypto: CoreCrypto,
private val defaultCipherSuite: Ciphersuite
) : MLSClient {
private val keyRotationDuration: Duration = 30.toDuration(DurationUnit.DAYS)
private val defaultGroupConfiguration = CustomConfiguration(keyRotationDuration.toJavaDuration(), MlsWirePolicy.PLAINTEXT)
private val defaultCiphersuite = Ciphersuites.DEFAULT.lower().first()

override suspend fun close() {
coreCrypto.close()
}

override suspend fun getPublicKey(): ByteArray {
return coreCrypto.clientPublicKey(defaultCiphersuite, toCredentialType(getMLSCredentials()))
return coreCrypto.clientPublicKey(defaultCipherSuite, toCredentialType(getMLSCredentials()))
}

override suspend fun generateKeyPackages(amount: Int): List<ByteArray> {
return coreCrypto.clientKeypackages(defaultCiphersuite, toCredentialType(getMLSCredentials()), amount.toUInt())
return coreCrypto.clientKeypackages(defaultCipherSuite, toCredentialType(getMLSCredentials()), amount.toUInt())
}

override suspend fun validKeyPackageCount(): ULong {
return coreCrypto.clientValidKeypackagesCount(defaultCiphersuite, toCredentialType(getMLSCredentials()))
return coreCrypto.clientValidKeypackagesCount(defaultCipherSuite, toCredentialType(getMLSCredentials()))
}

override suspend fun updateKeyingMaterial(groupId: MLSGroupId): CommitBundle {
Expand All @@ -78,7 +79,7 @@ class MLSClientImpl(
return coreCrypto.newExternalAddProposal(
conversationId = groupId.decodeBase64Bytes(),
epoch = epoch,
ciphersuite = defaultCiphersuite,
ciphersuite = defaultCipherSuite,
credentialType = toCredentialType(getMLSCredentials())
)
}
Expand Down Expand Up @@ -106,7 +107,7 @@ class MLSClientImpl(
externalSenders: List<Ed22519Key>
) {
val conf = ConversationConfiguration(
defaultCiphersuite,
defaultCipherSuite,
externalSenders.map { it.value },
defaultGroupConfiguration
)
Expand Down Expand Up @@ -210,7 +211,7 @@ class MLSClientImpl(
handle,
teamId,
expiry.inWholeSeconds.toUInt(),
defaultCiphersuite
defaultCipherSuite
)
)
}
Expand All @@ -227,7 +228,7 @@ class MLSClientImpl(
handle,
teamId,
expiry.inWholeSeconds.toUInt(),
defaultCiphersuite
defaultCipherSuite
)
)
}
Expand All @@ -237,7 +238,7 @@ class MLSClientImpl(
}

override suspend fun isE2EIEnabled(): Boolean {
return coreCrypto.e2eiIsEnabled(defaultCiphersuite)
return coreCrypto.e2eiIsEnabled(defaultCipherSuite)
}

override suspend fun getMLSCredentials(): CredentialType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ import kotlin.time.Duration
interface CoreCryptoCentral {
suspend fun mlsClient(clientId: CryptoQualifiedClientId): MLSClient

suspend fun mlsClient(enrollment: E2EIClient, certificateChain: CertificateChain, newMLSKeyPackageCount: UInt): MLSClient
suspend fun mlsClient(
enrollment: E2EIClient,
certificateChain: CertificateChain,
newMLSKeyPackageCount: UInt
): MLSClient

suspend fun proteusClient(): ProteusClient

Expand Down Expand Up @@ -59,4 +63,9 @@ interface CoreCryptoCentral {
suspend fun registerIntermediateCa(pem: CertificateChain)
}

expect suspend fun coreCryptoCentral(rootDir: String, databaseKey: String): CoreCryptoCentral
expect suspend fun coreCryptoCentral(
rootDir: String,
databaseKey: String,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort?
): CoreCryptoCentral
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,16 @@ package com.wire.kalium.cryptography

expect open class BaseMLSClientTest() {

suspend fun createMLSClient(clientId: CryptoQualifiedClientId): MLSClient
suspend fun createMLSClient(
clientId: CryptoQualifiedClientId,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort
): MLSClient

suspend fun createCoreCrypto(clientId: CryptoQualifiedClientId): CoreCryptoCentral
suspend fun createCoreCrypto(
clientId: CryptoQualifiedClientId,
allowedCipherSuites: List<UShort>,
defaultCipherSuite: UShort
): CoreCryptoCentral

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class E2EIClientTest : BaseMLSClientTest() {
}

private suspend fun createE2EIClient(user: SampleUser): E2EIClient {
return createMLSClient(user.qualifiedClientId).e2eiNewActivationEnrollment(
return createMLSClient(user.qualifiedClientId, ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITE).e2eiNewActivationEnrollment(
user.name, user.handle, user.teamId,90.days
)
}
Expand Down Expand Up @@ -112,7 +112,7 @@ class E2EIClientTest : BaseMLSClientTest() {

@Test
fun givenClient_whenCallingCheckOrderRequest_ReturnNonEmptyResult() = runTest {
val coreCryptoCentral = createCoreCrypto(ALICE1.qualifiedClientId)
val coreCryptoCentral = createCoreCrypto(ALICE1.qualifiedClientId, ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITE)
val e2eiClient = createE2EIClient(ALICE1)
e2eiClient.directoryResponse(ACME_DIRECTORY_API_RESPONSE)
e2eiClient.setAccountResponse(NEW_ACCOUNT_API_RESPONSE)
Expand All @@ -130,7 +130,7 @@ class E2EIClientTest : BaseMLSClientTest() {

@Test
fun givenClient_whenCallingFinalizeRequest_ReturnNonEmptyResult() = runTest {
val coreCryptoCentral = createCoreCrypto(ALICE1.qualifiedClientId)
val coreCryptoCentral = createCoreCrypto(ALICE1.qualifiedClientId, ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITE)
val e2eiClient = createE2EIClient(ALICE1)
e2eiClient.directoryResponse(ACME_DIRECTORY_API_RESPONSE)
e2eiClient.setAccountResponse(NEW_ACCOUNT_API_RESPONSE)
Expand All @@ -149,7 +149,7 @@ class E2EIClientTest : BaseMLSClientTest() {

@Test
fun givenClient_whenCallingCertificateRequest_ReturnNonEmptyResult() = runTest {
val coreCryptoCentral = createCoreCrypto(ALICE1.qualifiedClientId)
val coreCryptoCentral = createCoreCrypto(ALICE1.qualifiedClientId, ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITE)
val e2eiClient = createE2EIClient(ALICE1)
e2eiClient.directoryResponse(ACME_DIRECTORY_API_RESPONSE)
e2eiClient.setAccountResponse(NEW_ACCOUNT_API_RESPONSE)
Expand All @@ -169,6 +169,8 @@ class E2EIClientTest : BaseMLSClientTest() {

companion object {

val DEFAULT_CIPHER_SUITE = 1.toUShort()
val ALLOWED_CIPHER_SUITES = listOf(1.toUShort())
val ALICE1 = SampleUser(
CryptoQualifiedID("837655f7-b448-465a-b4b2-93f0919b38f0", "elna.wire.link"),
CryptoClientId("fb4b58152e20"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MLSClientTest : BaseMLSClientTest() {
}

private suspend fun createClient(user: SampleUser): MLSClient {
return createMLSClient(user.qualifiedClientId)
return createMLSClient(user.qualifiedClientId, ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITES)
}

@Test
Expand Down Expand Up @@ -188,6 +188,8 @@ class MLSClientTest : BaseMLSClientTest() {
}

companion object {
val ALLOWED_CIPHER_SUITES = listOf(1.toUShort())
val DEFAULT_CIPHER_SUITES = 1.toUShort()
const val MLS_CONVERSATION_ID = "JfflcPtUivbg+1U3Iyrzsh5D2ui/OGS5Rvf52ipH5KY="
const val PLAIN_TEXT = "Hello World"
val ALICE1 = SampleUser(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
Expand Down
Loading

0 comments on commit 81cf9dc

Please sign in to comment.