Skip to content

Commit

Permalink
Replace old waitForTunnelUp function
Browse files Browse the repository at this point in the history
After invoking VpnService.establish() we will get a tunnel file
descriptor that corresponds to the interface that was created. However,
this has no guarantee of the routing table beeing up to date, and we
might thus send traffic outside the tunnel. Previously this was done
through looking at the tunFd to see that traffic is sent to verify that
the routing table has changed. If no traffic is seen some traffic is
induced to a random IP address to ensure traffic can be seen. This new
implementation is slower but won't risk sending UDP traffic to a random
public address at the internet.
  • Loading branch information
Rawa committed Jan 2, 2025
1 parent 71fe0b4 commit 93a2161
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 183 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package net.mullvad.talpid

import android.net.ConnectivityManager
import android.net.ConnectivityManager
import android.net.LinkProperties
import android.os.ParcelFileDescriptor
import android.system.Os.socket
import androidx.annotation.CallSuper
import androidx.core.content.getSystemService
import androidx.lifecycle.lifecycleScope
Expand All @@ -12,9 +15,23 @@ import java.net.InetAddress
import kotlin.properties.Delegates.observable
import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe
import net.mullvad.mullvadvpn.lib.model.PrepareError
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.measureTimedValue
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull
import net.mullvad.talpid.model.CreateTunResult
import net.mullvad.talpid.model.TunConfig
import net.mullvad.talpid.util.NetworkEvent
import net.mullvad.talpid.util.TalpidSdkUtils.setMeteredIfSupported
import net.mullvad.talpid.util.defaultCallbackFlow

open class TalpidVpnService : LifecycleVpnService() {
private var activeTunStatus by
Expand All @@ -36,11 +53,22 @@ open class TalpidVpnService : LifecycleVpnService() {
// Used by JNI
lateinit var connectivityListener: ConnectivityListener

private lateinit var defaultNetworkLinkProperties:
StateFlow<NetworkEvent.OnLinkPropertiesChanged?>

@CallSuper
override fun onCreate() {
super.onCreate()
connectivityListener = ConnectivityListener(getSystemService<ConnectivityManager>()!!)
connectivityListener.register(lifecycleScope)

val connectivityManager = getSystemService<ConnectivityManager>()!!

defaultNetworkLinkProperties =
connectivityManager
.defaultCallbackFlow()
.filterIsInstance<NetworkEvent.OnLinkPropertiesChanged>()
.stateIn(lifecycleScope, SharingStarted.Eagerly, null)
}

fun openTun(config: TunConfig): CreateTunResult {
Expand Down Expand Up @@ -95,7 +123,7 @@ open class TalpidVpnService : LifecycleVpnService() {
for (dnsServer in config.dnsServers) {
try {
addDnsServer(dnsServer)
} catch (exception: IllegalArgumentException) {
} catch (_: IllegalArgumentException) {
invalidDnsServerAddresses.add(dnsServer)
}
}
Expand Down Expand Up @@ -132,11 +160,18 @@ open class TalpidVpnService : LifecycleVpnService() {
return CreateTunResult.TunnelDeviceError
}

Logger.d("Vpn Interface Established")

if (vpnInterfaceFd == null) {
Logger.e("VpnInterface returned null")
return CreateTunResult.TunnelDeviceError
}

// Wait for android OS to respond back to us that the routes are setup so we don't send
// traffic before the routes are set up. Otherwise we might send traffic through the wrong
// interface
runBlocking { waitForRoutesWithTimeout(config) }

val tunFd = vpnInterfaceFd.detachFd()

waitForTunnelUp(tunFd, config.routes.any { route -> route.isIpv6 })
Expand All @@ -159,6 +194,30 @@ open class TalpidVpnService : LifecycleVpnService() {
is PrepareError.OtherAlwaysOnApp -> CreateTunResult.OtherAlwaysOnApp(appName)
}

@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun waitForRoutesWithTimeout(
config: TunConfig,
timeout: Duration = ROUTES_SETUP_TIMEOUT,
) {
val linkProperties =
withTimeoutOrNull(timeout = timeout) {
measureTimedValue {
defaultNetworkLinkProperties.filterNotNull().first {
it.linkProperties.matches(config)
}
}
.also { Logger.d("LinkProperties matching tunnel, took ${it.duration}") }
.value
}
if (linkProperties == null) {
Logger.w("Waiting for LinkProperties timed out")
}
}

// return true if LinkProperties matches the TunConfig
private fun LinkProperties.matches(tunConfig: TunConfig): Boolean =
linkAddresses.all { it.address in tunConfig.addresses }

private fun InetAddress.prefixLength(): Int =
when (this) {
is Inet4Address -> IPV4_PREFIX_LENGTH
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package net.mullvad.talpid.util

import android.net.ConnectivityManager
import android.net.ConnectivityManager.NetworkCallback
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.channels.trySendBlocking
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow

sealed interface NetworkEvent {
data class OnAvailable(val network: Network) : NetworkEvent

data object OnUnavailable : NetworkEvent

data class OnLinkPropertiesChanged(val network: Network, val linkProperties: LinkProperties) :
NetworkEvent

data class OnCapabilitiesChanged(
val network: Network,
val networkCapabilities: NetworkCapabilities,
) : NetworkEvent

data class OnBlockedStatusChanged(val network: Network, val blocked: Boolean) : NetworkEvent

data class OnLosing(val network: Network, val maxMsToLive: Int) : NetworkEvent

data class OnLost(val network: Network) : NetworkEvent
}

fun ConnectivityManager.defaultCallbackFlow(): Flow<NetworkEvent> =
callbackFlow<NetworkEvent> {
val callback =
object : NetworkCallback() {
override fun onLinkPropertiesChanged(
network: Network,
linkProperties: LinkProperties,
) {
super.onLinkPropertiesChanged(network, linkProperties)
trySendBlocking(NetworkEvent.OnLinkPropertiesChanged(network, linkProperties))
}

override fun onAvailable(network: Network) {
super.onAvailable(network)
trySendBlocking(NetworkEvent.OnAvailable(network))
}

override fun onCapabilitiesChanged(
network: Network,
networkCapabilities: NetworkCapabilities,
) {
super.onCapabilitiesChanged(network, networkCapabilities)
trySendBlocking(
NetworkEvent.OnCapabilitiesChanged(network, networkCapabilities)
)
}

override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
super.onBlockedStatusChanged(network, blocked)
trySendBlocking(NetworkEvent.OnBlockedStatusChanged(network, blocked))
}

override fun onLosing(network: Network, maxMsToLive: Int) {
super.onLosing(network, maxMsToLive)
trySendBlocking(NetworkEvent.OnLosing(network, maxMsToLive))
}

override fun onLost(network: Network) {
super.onLost(network)
trySendBlocking(NetworkEvent.OnLost(network))
}

override fun onUnavailable() {
super.onUnavailable()
trySendBlocking(NetworkEvent.OnUnavailable)
}
}
registerDefaultNetworkCallback(callback)

awaitClose { unregisterNetworkCallback(callback) }
}
1 change: 0 additions & 1 deletion mullvad-jni/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
mod api;
mod classes;
mod problem_report;
mod talpid_vpn_service;

use jnix::{
jni::{
Expand Down
181 changes: 0 additions & 181 deletions mullvad-jni/src/talpid_vpn_service.rs

This file was deleted.

0 comments on commit 93a2161

Please sign in to comment.