Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2389,6 +2389,8 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
}
}

relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())

offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Expand All @@ -2399,7 +2401,23 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
RelaySrvIP: relayIP,
SessionID: sessionID,
}
return &offerAnswer, nil
}

// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
// netip.Addr. Returns the zero value for empty input and logs a warning
// for malformed payloads.
func decodeRelayIP(b []byte) netip.Addr {
if len(b) == 0 {
return netip.Addr{}
}
ip, ok := netip.AddrFromSlice(b)
if !ok {
log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
return netip.Addr{}
}
return ip.Unmap()
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
8 changes: 7 additions & 1 deletion client/internal/peer/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package peer
import (
"context"
"errors"
"net/netip"
"sync"
"sync/atomic"

Expand Down Expand Up @@ -40,6 +41,10 @@ type OfferAnswer struct {

// relay server address
RelaySrvAddress string
// RelaySrvIP is the IP the remote peer is connected to on its
// relay server. Used as a dial target if DNS for RelaySrvAddress
// fails. Zero value if the peer did not advertise an IP.
RelaySrvIP netip.Addr
// SessionID is the unique identifier of the session, used to discard old messages
SessionID *ICESessionID
}
Expand Down Expand Up @@ -217,8 +222,9 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
answer.SessionID = &sid
}

if addr, err := h.relay.RelayInstanceAddress(); err == nil {
if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil {
answer.RelaySrvAddress = addr
answer.RelaySrvIP = ip
}

return answer
Expand Down
20 changes: 10 additions & 10 deletions client/internal/peer/signaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,19 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
log.Warnf("failed to get session ID bytes: %v", err)
}
}
msg, err := signal.MarshalCredential(
s.wgPrivateKey,
offerAnswer.WgListenPort,
remoteKey,
&signal.Credential{
msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{
Type: bodyType,
WgListenPort: offerAnswer.WgListenPort,
Credential: &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
},
bodyType,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress,
sessionIDBytes)
RosenpassPubKey: offerAnswer.RosenpassPubKey,
RosenpassAddr: offerAnswer.RosenpassAddr,
RelaySrvAddress: offerAnswer.RelaySrvAddress,
RelaySrvIP: offerAnswer.RelaySrvIP,
SessionID: sessionIDBytes,
})
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion client/internal/peer/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {

// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
for _, r := range d.relayMgr.ServerURLs() {
Expand Down
11 changes: 8 additions & 3 deletions client/internal/peer/worker_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"net/netip"
"sync"
"sync/atomic"

Expand Down Expand Up @@ -53,15 +54,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.relaySupportedOnRemotePeer.Store(true)

// the relayManager will return with error in case if the connection has lost with relay server
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress()
if err != nil {
w.log.Errorf("failed to handle new offer: %s", err)
return
}

srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
var serverIP netip.Addr
if srv == remoteOfferAnswer.RelaySrvAddress {
serverIP = remoteOfferAnswer.RelaySrvIP
}

relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP)
if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection")
Expand Down Expand Up @@ -90,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
})
}

func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) {
return w.relayManager.RelayInstanceAddress()
}

Expand Down
113 changes: 110 additions & 3 deletions shared/relay/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package client

import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -146,6 +150,7 @@ func (cc *connContainer) close() {
type Client struct {
log *log.Entry
connectionURL string
serverIP netip.Addr
authTokenStore *auth.TokenStore
hashedID messages.PeerID

Expand All @@ -170,13 +175,22 @@ type Client struct {
}

// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
// is called.
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
return NewClientWithServerIP(serverURL, netip.Addr{}, authTokenStore, peerID, mtu)
}

// NewClientWithServerIP creates a new client for the relay server with a known server IP. serverIP, when valid, is
// dialed directly first; the FQDN is only attempted if the IP-based dial fails. TLS verification still uses the
// FQDN from serverURL via SNI.
func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
hashedID := messages.HashID(peerID)
relayLog := log.WithFields(log.Fields{"relay": serverURL})

c := &Client{
log: relayLog,
connectionURL: serverURL,
serverIP: serverIP,
authTokenStore: authTokenStore,
hashedID: hashedID,
mtu: mtu,
Expand Down Expand Up @@ -304,6 +318,41 @@ func (c *Client) ServerInstanceURL() (string, error) {
return c.instanceURL.String(), nil
}

// ConnectedIP returns the IP address of the live relay-server connection,
// extracted from the underlying socket's RemoteAddr. Zero value if not
// connected or if the address is not an IP literal.
func (c *Client) ConnectedIP() netip.Addr {
c.mu.Lock()
conn := c.relayConn
c.mu.Unlock()
if conn == nil {
return netip.Addr{}
}
addr := conn.RemoteAddr()
if addr == nil {
return netip.Addr{}
}
return extractIPLiteral(addr.String())
}

// extractIPLiteral returns the IP from address forms produced by the relay
// dialers (URL or host:port). Zero value if the host is not an IP.
func extractIPLiteral(s string) netip.Addr {
if u, err := url.Parse(s); err == nil && u.Host != "" {
s = u.Host
}
host, _, err := net.SplitHostPort(s)
if err != nil {
host = s
}
host = strings.Trim(host, "[]")
ip, err := netip.ParseAddr(host)
if err != nil {
return netip.Addr{}
}
return ip.Unmap()
}

// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func(string)) {
c.listenerMutex.Lock()
Expand Down Expand Up @@ -332,10 +381,17 @@ func (c *Client) Close() error {
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
dialers := c.getDialers()

rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
conn, err := rd.Dial(ctx)
conn, err := c.dialDirect(ctx, dialers)
if err != nil {
return nil, err
if c.serverIP.IsValid() {
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
}
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
fqdnConn, fErr := rd.Dial(ctx)
if fErr != nil {
return nil, fmt.Errorf("dial via server IP: %w; dial via FQDN: %w", err, fErr)
}
conn = fqdnConn
}
c.relayConn = conn

Expand All @@ -351,6 +407,57 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
return instanceURL, nil
}

// dialDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI. Returns an error if no
// usable server IP is configured or if the substituted URL is malformed.
func (c *Client) dialDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
if !c.serverIP.IsValid() || c.serverIP.IsUnspecified() {
return nil, errors.New("no usable server IP configured")
}

directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
if err != nil {
return nil, fmt.Errorf("substitute host: %w", err)
}

c.log.Debugf("dialing via server IP %s (SNI=%s)", c.serverIP, serverName)

rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
WithServerName(serverName)
return rd.Dial(ctx)
}

// substituteHost replaces the host portion of a rel/rels URL with ip,
// preserving the scheme and port. Returns the rewritten URL and the
// original host to use as the TLS ServerName, or empty if the original
// host is itself an IP literal (SNI requires a DNS name).
func substituteHost(serverURL string, ip netip.Addr) (string, string, error) {
u, err := url.Parse(serverURL)
if err != nil {
return "", "", fmt.Errorf("parse %q: %w", serverURL, err)
}
if u.Scheme == "" || u.Host == "" {
return "", "", fmt.Errorf("invalid relay URL %q", serverURL)
}
if !ip.IsValid() {
return "", "", errors.New("invalid server IP")
}
origHost := u.Hostname()
if _, err := netip.ParseAddr(origHost); err == nil {
origHost = ""
}
ip = ip.Unmap()
newHost := ip.String()
if ip.Is6() {
newHost = "[" + newHost + "]"
}
if port := u.Port(); port != "" {
u.Host = newHost + ":" + port
} else {
u.Host = newHost
}
return u.String(), origHost, nil
}

func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
Expand Down
Loading
Loading