Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions combined/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
Expand Down Expand Up @@ -523,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))

var relayAcceptFn func(conn net.Conn)
var relayAcceptFn func(conn listener.Conn)
if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept()
}
Expand Down Expand Up @@ -563,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
}

// handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
Expand All @@ -585,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(
return
}

lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}

log.Debugf("Relay WS client connected from: %s", rAddr)

conn := ws.NewConn(wsConn, lAddr, rAddr)
conn := ws.NewConn(wsConn, rAddr)
acceptFn(conn)
}

Expand Down
20 changes: 14 additions & 6 deletions relay/server/handshake.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
package server

import (
"context"
"fmt"
"net"
"time"

log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/shared/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
)

const (
// handshakeTimeout bounds how long a connection may remain in the
// pre-authentication handshake phase before being closed.
handshakeTimeout = 10 * time.Second
)

type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
Expand Down Expand Up @@ -58,17 +66,17 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
}

type handshake struct {
conn net.Conn
conn listener.Conn
validator Validator
preparedMsg *preparedMsg

handshakeMethodAuth bool
peerID *messages.PeerID
}

func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf)
n, err := h.conn.Read(ctx, buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
}
Expand Down Expand Up @@ -103,15 +111,15 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return peerID, nil
}

func (h *handshake) handshakeResponse() error {
func (h *handshake) handshakeResponse(ctx context.Context) error {
var responseMsg []byte
if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg
} else {
responseMsg = h.preparedMsg.responseHelloMsg
}

if _, err := h.conn.Write(responseMsg); err != nil {
if _, err := h.conn.Write(ctx, responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
}

Expand Down
14 changes: 14 additions & 0 deletions relay/server/listener/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package listener

import (
"context"
"net"
)

// Conn is the relay connection contract implemented by WS and QUIC transports.
type Conn interface {
Read(ctx context.Context, b []byte) (n int, err error)
Write(ctx context.Context, b []byte) (n int, err error)
RemoteAddr() net.Addr
Close() error
}
14 changes: 0 additions & 14 deletions relay/server/listener/listener.go

This file was deleted.

39 changes: 7 additions & 32 deletions relay/server/listener/quic/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,26 @@ package quic
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/quic-go/quic-go"
)

type Conn struct {
session *quic.Conn
closed bool
closedMu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
session *quic.Conn
closed bool
closedMu sync.Mutex
}

func NewConn(session *quic.Conn) *Conn {
ctx, cancel := context.WithCancel(context.Background())
return &Conn{
session: session,
ctx: ctx,
ctxCancel: cancel,
session: session,
}
}

func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
Expand All @@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
return n, nil
}

func (c *Conn) Write(b []byte) (int, error) {
func (c *Conn) Write(_ context.Context, b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil {
return 0, c.remoteCloseErrHandling(err)
}
return len(b), nil
Comment thread
pappz marked this conversation as resolved.
}

func (c *Conn) LocalAddr() net.Addr {
return c.session.LocalAddr()
}

func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}

func (c *Conn) SetReadDeadline(t time.Time) error {
return nil
}

func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}

func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}

func (c *Conn) Close() error {
c.closedMu.Lock()
if c.closed {
Expand All @@ -74,8 +51,6 @@ func (c *Conn) Close() error {
c.closed = true
c.closedMu.Unlock()

c.ctxCancel() // Cancel the context

sessionErr := c.session.CloseWithError(0, "normal closure")
return sessionErr
}
Expand Down
4 changes: 2 additions & 2 deletions relay/server/listener/quic/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"

"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/relay/protocol"
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
nbRelay "github.com/netbirdio/netbird/shared/relay"
)

Expand All @@ -25,7 +25,7 @@ type Listener struct {
listener *quic.Listener
}

func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
quicCfg := &quic.Config{
EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
Expand Down
30 changes: 5 additions & 25 deletions relay/server/listener/ws/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,21 @@ const (

type Conn struct {
*websocket.Conn
lAddr *net.TCPAddr
rAddr *net.TCPAddr

closed bool
closedMu sync.Mutex
ctx context.Context
}

func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn {
return &Conn{
Conn: wsConn,
lAddr: lAddr,
rAddr: rAddr,
ctx: context.Background(),
}
}

func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.Reader(c.ctx)
func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
t, r, err := c.Reader(ctx)
if err != nil {
return 0, c.ioErrHandling(err)
}
Expand All @@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// Write writes a binary message with the given payload.
// It does not block until fill the internal buffer.
// If the buffer filled up, wait until the buffer is drained or timeout.
func (c *Conn) Write(b []byte) (int, error) {
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
func (c *Conn) Write(ctx context.Context, b []byte) (int, error) {
ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout)
defer ctxCancel()

err := c.Conn.Write(ctx, websocket.MessageBinary, b)
return len(b), err
Comment thread
pappz marked this conversation as resolved.
}

func (c *Conn) LocalAddr() net.Addr {
return c.lAddr
}

func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr
}

func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}

func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}

func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}

func (c *Conn) Close() error {
c.closedMu.Lock()
c.closed = true
Expand Down
24 changes: 9 additions & 15 deletions relay/server/listener/ws/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
"fmt"
"net"
"net/http"
"time"

"github.com/coder/websocket"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/relay/protocol"
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay"
)

Expand All @@ -27,18 +29,19 @@ type Listener struct {
TLSConfig *tls.Config

server *http.Server
acceptFn func(conn net.Conn)
acceptFn func(conn relaylistener.Conn)
}

func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
l.acceptFn = acceptFn
mux := http.NewServeMux()
mux.HandleFunc(URLPath, l.onAccept)

l.server = &http.Server{
Addr: l.Address,
Handler: mux,
TLSConfig: l.TLSConfig,
Addr: l.Address,
Handler: mux,
TLSConfig: l.TLSConfig,
ReadHeaderTimeout: 5 * time.Second,
}

log.Infof("WS server listening address: %s", l.Address)
Expand Down Expand Up @@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return
}

lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil {
log.Errorf("failed to close ws connection: %s", err)
}
return
}

log.Infof("WS client connected from: %s", rAddr)

conn := NewConn(wsConn, lAddr, rAddr)
conn := NewConn(wsConn, rAddr)
l.acceptFn(conn)
}

Expand Down
Loading
Loading