Skip to content

Commit

Permalink
fix: Stuck on Setting up Wire after canceling E2EI during login WPB-1…
Browse files Browse the repository at this point in the history
…0046] 🍒
  • Loading branch information
borichellow committed Jul 18, 2024
1 parent afefda7 commit 1438f23
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,14 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
val getProteusFingerprint: GetProteusFingerprintUseCase
get() = GetProteusFingerprintUseCaseImpl(preKeyRepository)

@OptIn(DelicateKaliumApi::class)
private val verifyExistingClientUseCase: VerifyExistingClientUseCase
get() = VerifyExistingClientUseCaseImpl(clientRepository)

get() = VerifyExistingClientUseCaseImpl(
selfUserId,
clientRepository,
isAllowedToRegisterMLSClient,
registerMLSClientUseCase
)
val importClient: ImportClientUseCase
get() = ImportClientUseCaseImpl(
clientRepository,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ internal class GetOrRegisterClientUseCaseImpl(
clearOldClientRelatedData()
null
}

is VerifyExistingClientResult.Failure.E2EICertificateRequired -> RegisterClientResult.E2EICertificateRequired(
result.client,
result.userId
)
}
}
) ?: registerClient(registerClientParam)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.map
import com.wire.kalium.util.DelicateKaliumApi

/**
* Checks if the given client is still exists on the backend, otherwise returns failure.
Expand All @@ -36,18 +40,46 @@ interface VerifyExistingClientUseCase {
suspend operator fun invoke(clientId: ClientId): VerifyExistingClientResult
}

internal class VerifyExistingClientUseCaseImpl(
private val clientRepository: ClientRepository
internal class VerifyExistingClientUseCaseImpl @OptIn(DelicateKaliumApi::class) constructor(
private val selfUserId: UserId,
private val clientRepository: ClientRepository,
private val isAllowedToRegisterMLSClient: IsAllowedToRegisterMLSClientUseCase,
private val registerMLSClientUseCase: RegisterMLSClientUseCase,
) : VerifyExistingClientUseCase {

@OptIn(DelicateKaliumApi::class)
override suspend fun invoke(clientId: ClientId): VerifyExistingClientResult {
return clientRepository.selfListOfClients()
.fold({
VerifyExistingClientResult.Failure.Generic(it)
}, { listOfClients ->
val client = listOfClients.firstOrNull { it.id == clientId }
when {
(client == null) -> VerifyExistingClientResult.Failure.ClientNotRegistered

isAllowedToRegisterMLSClient() -> {
registerMLSClientUseCase.invoke(clientId = client.id).map {
if (it is RegisterMLSClientResult.E2EICertificateRequired)
VerifyExistingClientResult.Failure.E2EICertificateRequired(client, selfUserId)
else VerifyExistingClientResult.Success(client)
}.getOrElse { VerifyExistingClientResult.Failure.Generic(it) }
}

else -> VerifyExistingClientResult.Success(client)
}

if (client != null) {
VerifyExistingClientResult.Success(client)
if (isAllowedToRegisterMLSClient()) {
registerMLSClientUseCase.invoke(clientId = client.id).fold({
VerifyExistingClientResult.Failure.Generic(it)
}) {
if (it is RegisterMLSClientResult.E2EICertificateRequired)
VerifyExistingClientResult.Failure.E2EICertificateRequired(client, selfUserId)
else VerifyExistingClientResult.Success(client)
}
} else {
VerifyExistingClientResult.Success(client)
}
} else {
VerifyExistingClientResult.Failure.ClientNotRegistered
}
Expand All @@ -61,5 +93,6 @@ sealed class VerifyExistingClientResult {
sealed class Failure : VerifyExistingClientResult() {
data object ClientNotRegistered : Failure()
data class Generic(val genericFailure: CoreFailure) : Failure()
class E2EICertificateRequired(val client: Client, val userId: UserId) : Failure()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@
* along with this program. If not, see http://www.gnu.org/licenses/.
*/


package com.wire.kalium.logic.feature.client

import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.framework.TestClient
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import io.mockative.Mock
import com.wire.kalium.logic.util.arrangement.repository.ClientRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.ClientRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.usecase.IsAllowedToRegisterMLSClientUseCaseArrangement
import com.wire.kalium.logic.util.arrangement.usecase.IsAllowedToRegisterMLSClientUseCaseArrangementImpl
import com.wire.kalium.logic.util.arrangement.usecase.RegisterMLSClientUseCaseArrangement
import com.wire.kalium.logic.util.arrangement.usecase.RegisterMLSClientUseCaseArrangementImpl
import com.wire.kalium.util.DelicateKaliumApi
import io.mockative.any
import io.mockative.coEvery
import io.mockative.coVerify
import io.mockative.mock
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
Expand All @@ -38,12 +41,53 @@ import kotlin.test.assertIs
class VerifyExistingClientUseCaseTest {

@Test
fun givenRegisteredClientId_whenInvoking_thenReturnSuccess() = runTest {
fun givenRegisteredClientIdAndNoMLS_whenInvoking_thenReturnSuccess() = runTest {
val clientId = ClientId("clientId")
val client = TestClient.CLIENT.copy(id = clientId)
val (_, useCase) = Arrangement()
.withSelfClientsResult(Either.Right(listOf(client)))
.arrange()
val (_, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf(client)))
withIsAllowedToRegisterMLSClient(false)
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Success>(result)
assertEquals(client, result.client)
}

@Test
fun givenRegisteredClientIdAndMLSAllowed_whenRegisterMLSFails_thenReturnFailure() = runTest {
val clientId = ClientId("clientId")
val client = TestClient.CLIENT.copy(id = clientId)
val (_, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf(client)))
withIsAllowedToRegisterMLSClient(true)
withRegisterMLSClient(Either.Left(CoreFailure.Unknown(null)))
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Failure.Generic>(result)
}

@Test
fun givenRegisteredClientIdAndMLSAllowed_whenE2EIRequired_thenReturnE2EIRequiredFailure() = runTest {
val clientId = ClientId("clientId")
val client = TestClient.CLIENT.copy(id = clientId)
val (_, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf(client)))
withIsAllowedToRegisterMLSClient(true)
withRegisterMLSClient(Either.Right(RegisterMLSClientResult.E2EICertificateRequired))
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Failure.E2EICertificateRequired>(result)
}

@Test
fun givenRegisteredClientIdAndMLSAllowed_whenRegisterMLSSucceed_thenReturnSuccess() = runTest {
val clientId = ClientId("clientId")
val client = TestClient.CLIENT.copy(id = clientId)
val (_, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf(client)))
withIsAllowedToRegisterMLSClient(true)
withRegisterMLSClient(Either.Right(RegisterMLSClientResult.Success))
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Success>(result)
assertEquals(client, result.client)
Expand All @@ -52,30 +96,31 @@ class VerifyExistingClientUseCaseTest {
@Test
fun givenNotRegisteredClientId_whenInvoking_thenReturnClientNotRegisteredFailure() = runTest {
val clientId = ClientId("clientId")
val (arrangement, useCase) = Arrangement()
.withSelfClientsResult(Either.Right(listOf()))
.arrange()
val (arrangement, useCase) = arrange {
withSelfClientsResult(Either.Right(listOf()))
}
val result = useCase.invoke(clientId)
assertIs<VerifyExistingClientResult.Failure.ClientNotRegistered>(result)
coVerify {
arrangement.clientRepository.persistClientId(any())
}.wasNotInvoked()
coVerify { arrangement.clientRepository.persistClientId(any()) }.wasNotInvoked()
}

private class Arrangement {
private fun arrange(block: suspend Arrangement.() -> Unit) = Arrangement(block).arrange()

@Mock
val clientRepository = mock(ClientRepository::class)
@OptIn(DelicateKaliumApi::class)
private class Arrangement(private val block: suspend Arrangement.() -> Unit) :
RegisterMLSClientUseCaseArrangement by RegisterMLSClientUseCaseArrangementImpl(),
ClientRepositoryArrangement by ClientRepositoryArrangementImpl(),
IsAllowedToRegisterMLSClientUseCaseArrangement by IsAllowedToRegisterMLSClientUseCaseArrangementImpl() {

val verifyExistingClientUseCase: VerifyExistingClientUseCase = VerifyExistingClientUseCaseImpl(clientRepository)
fun arrange() = run {
runBlocking { block() }

suspend fun withSelfClientsResult(result: Either<NetworkFailure, List<Client>>): Arrangement {
coEvery {
clientRepository.selfListOfClients()
}.returns(result)
return this
this@Arrangement to VerifyExistingClientUseCaseImpl(
TestUser.USER_ID,
clientRepository,
isAllowedToRegisterMLSClientUseCase,
registerMLSClientUseCase
)
}

fun arrange() = this to verifyExistingClientUseCase
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
*/
package com.wire.kalium.logic.util.arrangement.repository

import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.client.OtherUserClient
import com.wire.kalium.logic.data.conversation.ClientId
Expand Down Expand Up @@ -58,6 +60,8 @@ internal interface ClientRepositoryArrangement {
result: Either<StorageFailure, Unit>,
clients: Matcher<List<InsertClientParam>> = AnyMatcher(valueOf())
)

suspend fun withSelfClientsResult(result: Either<NetworkFailure, List<Client>>)
}

internal open class ClientRepositoryArrangementImpl : ClientRepositoryArrangement {
Expand Down Expand Up @@ -112,4 +116,8 @@ internal open class ClientRepositoryArrangementImpl : ClientRepositoryArrangemen
clientRepository.storeUserClientListAndRemoveRedundantClients(any())
}.returns(result)
}

override suspend fun withSelfClientsResult(result: Either<NetworkFailure, List<Client>>) {
coEvery { clientRepository.selfListOfClients() }.returns(result)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.util.arrangement.usecase

import com.wire.kalium.logic.feature.client.IsAllowedToRegisterMLSClientUseCase
import com.wire.kalium.util.DelicateKaliumApi
import io.mockative.coEvery
import io.mockative.mock

@OptIn(DelicateKaliumApi::class)
interface IsAllowedToRegisterMLSClientUseCaseArrangement {

val isAllowedToRegisterMLSClientUseCase: IsAllowedToRegisterMLSClientUseCase

suspend fun withIsAllowedToRegisterMLSClient(isAllowed: Boolean)
}

@OptIn(DelicateKaliumApi::class)
class IsAllowedToRegisterMLSClientUseCaseArrangementImpl : IsAllowedToRegisterMLSClientUseCaseArrangement {

override val isAllowedToRegisterMLSClientUseCase: IsAllowedToRegisterMLSClientUseCase = mock(IsAllowedToRegisterMLSClientUseCase::class)

override suspend fun withIsAllowedToRegisterMLSClient(isAllowed: Boolean) {
coEvery { isAllowedToRegisterMLSClientUseCase() }.returns(isAllowed)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.util.arrangement.usecase

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.feature.client.RegisterMLSClientResult
import com.wire.kalium.logic.feature.client.RegisterMLSClientUseCase
import com.wire.kalium.logic.functional.Either
import io.mockative.any
import io.mockative.coEvery
import io.mockative.mock

interface RegisterMLSClientUseCaseArrangement {

val registerMLSClientUseCase: RegisterMLSClientUseCase

suspend fun withRegisterMLSClient(result: Either<CoreFailure, RegisterMLSClientResult>)
}

class RegisterMLSClientUseCaseArrangementImpl : RegisterMLSClientUseCaseArrangement {
override val registerMLSClientUseCase: RegisterMLSClientUseCase = mock(RegisterMLSClientUseCase::class)

override suspend fun withRegisterMLSClient(result: Either<CoreFailure, RegisterMLSClientResult>) {
coEvery { registerMLSClientUseCase(any()) }.returns(result)
}

}

0 comments on commit 1438f23

Please sign in to comment.