Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Stuck on Setting up Wire after canceling E2EI during login [WPB-10046] 🍒 #2886

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,14 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
val getProteusFingerprint: GetProteusFingerprintUseCase
get() = GetProteusFingerprintUseCaseImpl(preKeyRepository)

@OptIn(DelicateKaliumApi::class)
private val verifyExistingClientUseCase: VerifyExistingClientUseCase
get() = VerifyExistingClientUseCaseImpl(clientRepository)

get() = VerifyExistingClientUseCaseImpl(
selfUserId,
clientRepository,
isAllowedToRegisterMLSClient,
registerMLSClientUseCase
)
val importClient: ImportClientUseCase
get() = ImportClientUseCaseImpl(
clientRepository,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ internal class GetOrRegisterClientUseCaseImpl(
clearOldClientRelatedData()
null
}

is VerifyExistingClientResult.Failure.E2EICertificateRequired -> RegisterClientResult.E2EICertificateRequired(
result.client,
result.userId
)
}
}
) ?: registerClient(registerClientParam)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.map
import com.wire.kalium.util.DelicateKaliumApi

/**
* Checks if the given client is still exists on the backend, otherwise returns failure.
Expand All @@ -36,18 +40,46 @@ interface VerifyExistingClientUseCase {
suspend operator fun invoke(clientId: ClientId): VerifyExistingClientResult
}

internal class VerifyExistingClientUseCaseImpl(
private val clientRepository: ClientRepository
internal class VerifyExistingClientUseCaseImpl @OptIn(DelicateKaliumApi::class) constructor(
private val selfUserId: UserId,
private val clientRepository: ClientRepository,
private val isAllowedToRegisterMLSClient: IsAllowedToRegisterMLSClientUseCase,
private val registerMLSClientUseCase: RegisterMLSClientUseCase,
) : VerifyExistingClientUseCase {

@OptIn(DelicateKaliumApi::class)
override suspend fun invoke(clientId: ClientId): VerifyExistingClientResult {
return clientRepository.selfListOfClients()
.fold({
VerifyExistingClientResult.Failure.Generic(it)
}, { listOfClients ->
val client = listOfClients.firstOrNull { it.id == clientId }
when {
(client == null) -> VerifyExistingClientResult.Failure.ClientNotRegistered

isAllowedToRegisterMLSClient() -> {
registerMLSClientUseCase.invoke(clientId = client.id).map {
if (it is RegisterMLSClientResult.E2EICertificateRequired)
VerifyExistingClientResult.Failure.E2EICertificateRequired(client, selfUserId)
else VerifyExistingClientResult.Success(client)
}.getOrElse { VerifyExistingClientResult.Failure.Generic(it) }
}

else -> VerifyExistingClientResult.Success(client)
}

if (client != null) {
VerifyExistingClientResult.Success(client)
if (isAllowedToRegisterMLSClient()) {
registerMLSClientUseCase.invoke(clientId = client.id).fold({
VerifyExistingClientResult.Failure.Generic(it)
}) {
if (it is RegisterMLSClientResult.E2EICertificateRequired)
VerifyExistingClientResult.Failure.E2EICertificateRequired(client, selfUserId)
else VerifyExistingClientResult.Success(client)
}
} else {
VerifyExistingClientResult.Success(client)
}
} else {
VerifyExistingClientResult.Failure.ClientNotRegistered
}
Expand All @@ -61,5 +93,6 @@ sealed class VerifyExistingClientResult {
sealed class Failure : VerifyExistingClientResult() {
data object ClientNotRegistered : Failure()
data class Generic(val genericFailure: CoreFailure) : Failure()
class E2EICertificateRequired(val client: Client, val userId: UserId) : Failure()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@
* along with this program. If not, see http://www.gnu.org/licenses/.
*/


package com.wire.kalium.logic.feature.client

import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.framework.TestClient
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import io.mockative.Mock
import com.wire.kalium.logic.util.arrangement.repository.ClientRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.ClientRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.usecase.IsAllowedToRegisterMLSClientUseCaseArrangement
import com.wire.kalium.logic.util.arrangement.usecase.IsAllowedToRegisterMLSClientUseCaseArrangementImpl
import com.wire.kalium.logic.util.arrangement.usecase.RegisterMLSClientUseCaseArrangement
import com.wire.kalium.logic.util.arrangement.usecase.RegisterMLSClientUseCaseArrangementImpl
import com.wire.kalium.util.DelicateKaliumApi
import io.mockative.any
import io.mockative.coEvery
import io.mockative.coVerify
import io.mockative.mock
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
Expand All @@ -38,12 +41,53 @@ import kotlin.test.assertIs
class VerifyExistingClientUseCaseTest {

@Test
fun givenRegisteredClientId_whenInvoking_thenReturnSuccess() = runTest {
fun givenRegisteredClientIdAndNoMLS_whenInvoking_thenReturnSuccess() = runTest {
val clientId = ClientId("clientId")
val client = TestClient.CLIENT.copy(id = clientId)
val (_, useCase) = Arrangement()
.withSelfClientsResult(Either.Right(listOf(client)))
.arrange()
val (_, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf(client)))
withIsAllowedToRegisterMLSClient(false)
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Success>(result)
assertEquals(client, result.client)
}

@Test
fun givenRegisteredClientIdAndMLSAllowed_whenRegisterMLSFails_thenReturnFailure() = runTest {
val clientId = ClientId("clientId")
val client = TestClient.CLIENT.copy(id = clientId)
val (_, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf(client)))
withIsAllowedToRegisterMLSClient(true)
withRegisterMLSClient(Either.Left(CoreFailure.Unknown(null)))
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Failure.Generic>(result)
}

@Test
fun givenRegisteredClientIdAndMLSAllowed_whenE2EIRequired_thenReturnE2EIRequiredFailure() = runTest {
val clientId = ClientId("clientId")
val client = TestClient.CLIENT.copy(id = clientId)
val (_, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf(client)))
withIsAllowedToRegisterMLSClient(true)
withRegisterMLSClient(Either.Right(RegisterMLSClientResult.E2EICertificateRequired))
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Failure.E2EICertificateRequired>(result)
}

@Test
fun givenRegisteredClientIdAndMLSAllowed_whenRegisterMLSSucceed_thenReturnSuccess() = runTest {
val clientId = ClientId("clientId")
val client = TestClient.CLIENT.copy(id = clientId)
val (_, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf(client)))
withIsAllowedToRegisterMLSClient(true)
withRegisterMLSClient(Either.Right(RegisterMLSClientResult.Success))
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Success>(result)
assertEquals(client, result.client)
Expand All @@ -52,30 +96,31 @@ class VerifyExistingClientUseCaseTest {
@Test
fun givenNotRegisteredClientId_whenInvoking_thenReturnClientNotRegisteredFailure() = runTest {
val clientId = ClientId("clientId")
val (arrangement, useCase) = Arrangement()
.withSelfClientsResult(Either.Right(listOf()))
.arrange()
val (arrangement, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf()))
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Failure.ClientNotRegistered>(result)
coVerify {
arrangement.clientRepository.persistClientId(any())
}.wasNotInvoked()
coVerify { arrangement.clientRepository.persistClientId(any()) }.wasNotInvoked()
}

private class Arrangement {
private fun arrange(block: suspend Arrangement.() -> Unit) = Arrangement(block).arrange()

@Mock
val clientRepository = mock(ClientRepository::class)
@OptIn(DelicateKaliumApi::class)
private class Arrangement(private val block: suspend Arrangement.() -> Unit) :
RegisterMLSClientUseCaseArrangement by RegisterMLSClientUseCaseArrangementImpl(),
ClientRepositoryArrangement by ClientRepositoryArrangementImpl(),
IsAllowedToRegisterMLSClientUseCaseArrangement by IsAllowedToRegisterMLSClientUseCaseArrangementImpl() {

val verifyExistingClientUseCase: VerifyExistingClientUseCase = VerifyExistingClientUseCaseImpl(clientRepository)
fun arrange() = run {
runBlocking { block() }

suspend fun withSelfClientsResult(result: Either<NetworkFailure, List<Client>>): Arrangement {
coEvery {
clientRepository.selfListOfClients()
}.returns(result)
return this
this@Arrangement to VerifyExistingClientUseCaseImpl(
TestUser.USER_ID,
clientRepository,
isAllowedToRegisterMLSClientUseCase,
registerMLSClientUseCase
)
}

fun arrange() = this to verifyExistingClientUseCase
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
*/
package com.wire.kalium.logic.util.arrangement.repository

import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.client.OtherUserClient
import com.wire.kalium.logic.data.conversation.ClientId
Expand Down Expand Up @@ -58,6 +60,8 @@ internal interface ClientRepositoryArrangement {
result: Either<StorageFailure, Unit>,
clients: Matcher<List<InsertClientParam>> = AnyMatcher(valueOf())
)

suspend fun withSelfClientsResult(result: Either<NetworkFailure, List<Client>>)
}

internal open class ClientRepositoryArrangementImpl : ClientRepositoryArrangement {
Expand Down Expand Up @@ -112,4 +116,8 @@ internal open class ClientRepositoryArrangementImpl : ClientRepositoryArrangemen
clientRepository.storeUserClientListAndRemoveRedundantClients(any())
}.returns(result)
}

override suspend fun withSelfClientsResult(result: Either<NetworkFailure, List<Client>>) {
coEvery { clientRepository.selfListOfClients() }.returns(result)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.util.arrangement.usecase

import com.wire.kalium.logic.feature.client.IsAllowedToRegisterMLSClientUseCase
import com.wire.kalium.util.DelicateKaliumApi
import io.mockative.coEvery
import io.mockative.mock

@OptIn(DelicateKaliumApi::class)
interface IsAllowedToRegisterMLSClientUseCaseArrangement {

val isAllowedToRegisterMLSClientUseCase: IsAllowedToRegisterMLSClientUseCase

suspend fun withIsAllowedToRegisterMLSClient(isAllowed: Boolean)
}

@OptIn(DelicateKaliumApi::class)
class IsAllowedToRegisterMLSClientUseCaseArrangementImpl : IsAllowedToRegisterMLSClientUseCaseArrangement {

override val isAllowedToRegisterMLSClientUseCase: IsAllowedToRegisterMLSClientUseCase = mock(IsAllowedToRegisterMLSClientUseCase::class)

override suspend fun withIsAllowedToRegisterMLSClient(isAllowed: Boolean) {
coEvery { isAllowedToRegisterMLSClientUseCase() }.returns(isAllowed)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.util.arrangement.usecase

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.feature.client.RegisterMLSClientResult
import com.wire.kalium.logic.feature.client.RegisterMLSClientUseCase
import com.wire.kalium.logic.functional.Either
import io.mockative.any
import io.mockative.coEvery
import io.mockative.mock

interface RegisterMLSClientUseCaseArrangement {

val registerMLSClientUseCase: RegisterMLSClientUseCase

suspend fun withRegisterMLSClient(result: Either<CoreFailure, RegisterMLSClientResult>)
}

class RegisterMLSClientUseCaseArrangementImpl : RegisterMLSClientUseCaseArrangement {
override val registerMLSClientUseCase: RegisterMLSClientUseCase = mock(RegisterMLSClientUseCase::class)

override suspend fun withRegisterMLSClient(result: Either<CoreFailure, RegisterMLSClientResult>) {
coEvery { registerMLSClientUseCase(any()) }.returns(result)
}

}
Loading