Skip to content

Commit f87b287

Browse files
committed
ipn/wg: merge amnezia and bepass
1 parent 757269f commit f87b287

File tree

3 files changed

+52
-80
lines changed

3 files changed

+52
-80
lines changed

intra/ipn/wg/amnezia.go

+15-10
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ func (a *Amnezia) String() string {
4949
if a == nil {
5050
return "<nil>"
5151
}
52+
if !a.Set() {
53+
return "<unset>"
54+
}
5255
return fmt.Sprintf("%s: amnezia: jc(%d), jmin(%d), jmax(%d), s1(%d), s2(%d), h1(%d), h2(%d), h3(%d), h4(%d)",
5356
a.id, a.Jc, a.Jmin, a.Jmax, a.S1, a.S2, a.H1, a.H2, a.H3, a.H4)
5457
}
@@ -113,16 +116,16 @@ func (a *Amnezia) recv(pktptr *[]byte) (ok bool) {
113116
pkt, typ = a.strip(pkt)
114117

115118
switch typ {
116-
case a.H1:
119+
case device.MessageInitiationType, a.H1:
117120
typ = device.MessageInitiationType
118121
binary.LittleEndian.PutUint32(pkt[:h], device.MessageInitiationType)
119-
case a.H2:
122+
case device.MessageResponseType, a.H2:
120123
typ = device.MessageResponseType
121124
binary.LittleEndian.PutUint32(pkt[:h], device.MessageResponseType)
122-
case a.H3:
125+
case device.MessageCookieReplyType, a.H3:
123126
typ = device.MessageCookieReplyType
124127
binary.LittleEndian.PutUint32(pkt[:h], device.MessageCookieReplyType)
125-
case a.H4:
128+
case device.MessageTransportType, a.H4: // must be default?
126129
typ = device.MessageTransportType
127130
binary.LittleEndian.PutUint32(pkt[:h], device.MessageTransportType)
128131
}
@@ -192,7 +195,9 @@ func (a *Amnezia) instate(pkt []byte) ([]byte, uint32) {
192195
func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) {
193196
size := uint16(len(pkt))
194197
h := uint16(device.MessageTransportOffsetReceiver)
195-
defaultType := binary.LittleEndian.Uint32(pkt[:h])
198+
// assume the correct msg type is in just the first byte:
199+
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56
200+
defaultType := uint8(pkt[0])
196201

197202
var discard uint16 = 0
198203
var possibleType uint32 = 0
@@ -211,13 +216,13 @@ func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) {
211216

212217
if maybeStrip {
213218
hdr := pkt[discard : discard+h]
214-
strippedType := binary.LittleEndian.Uint32(hdr)
215-
if strippedType == possibleType {
216-
return pkt[discard:], strippedType
217-
} // else: sizes match but msg types do not
219+
obsType := binary.LittleEndian.Uint32(hdr)
220+
if obsType == possibleType {
221+
return pkt[discard:], obsType
222+
} // else: msg type mismatch, but size matched
218223
} // else: nothing to discard
219224

220-
return pkt, defaultType
225+
return pkt, uint32(defaultType)
221226
}
222227

223228
func (a *Amnezia) logIfNeeded(dir string, typ uint32, n int) {

intra/ipn/wg/wgconn.go

+7-46
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,8 @@ type StdNetBind struct {
102102
connect connector
103103
mh *multihost.MH
104104

105-
reserved []byte // overwrite the 3 wg reserved bytes
106-
overwriteReserve bool
107-
amnezia *Amnezia
108-
floodBa *core.Barrier[int, netip.AddrPort]
105+
amnezia *Amnezia
106+
floodBa *core.Barrier[int, netip.AddrPort]
109107

110108
mu sync.Mutex // protects following fields
111109
ipv4 *net.UDPConn
@@ -118,18 +116,16 @@ type StdNetBind struct {
118116
}
119117

120118
// TODO: get d, ep, f, rb through an Opts bag?
121-
func NewEndpoint(id string, d connector, ep *multihost.MH, f rwobserver, a *Amnezia, rb [3]byte) *StdNetBind {
119+
func NewEndpoint(id string, d connector, ep *multihost.MH, f rwobserver, a *Amnezia) *StdNetBind {
122120
s := &StdNetBind{
123121
id: id,
124122
connect: d,
125123
mh: ep,
126124
observer: f,
127125
amnezia: a,
128-
reserved: rb[:3], // github.com/bepass-org/warp-plus/blob/19ac233cc6/wiresocks/config.go#L184
129126
floodBa: core.NewKeyedBarrier[int, netip.AddrPort](minFloodInterval),
130127
sendAddr: core.NewZeroVolatile[netip.AddrPort](),
131128
}
132-
s.overwriteReserve = a.Set() || isReservedOverwitten(s.reserved)
133129
return s
134130
}
135131

@@ -326,15 +322,7 @@ func (s *StdNetBind) makeReceiveFn(uc *net.UDPConn) conn.ReceiveFunc {
326322
extend(uc, wgtimeout)
327323
n, addr, err := uc.ReadFromUDPAddrPort(b)
328324
if err == nil {
329-
if isReservedOverwitten(b) {
330-
if s.amnezia.Set() {
331-
recvOverwritten = s.amnezia.recv(&b)
332-
} else if n > 3 && isWgMsgType(b[0]) && recvOverwritten {
333-
// github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/receive.go#L138
334-
copy(b[1:4], reservedZeros)
335-
recvOverwritten = true
336-
}
337-
}
325+
recvOverwritten = s.amnezia.recv(&b)
338326
numMsgs++
339327
}
340328

@@ -344,7 +332,7 @@ func (s *StdNetBind) makeReceiveFn(uc *net.UDPConn) conn.ReceiveFunc {
344332
}
345333

346334
s := fmt.Sprintf("wg: bind: %s recvFrom(%v): %d / ov? %t<=%t / err? %v",
347-
s.id, addr, n, s.overwriteReserve, recvOverwritten, err)
335+
s.id, addr, n, s.amnezia.Set(), recvOverwritten, err)
348336
if err == nil || timedout(err) {
349337
log.V(s)
350338
} else {
@@ -405,18 +393,9 @@ func (s *StdNetBind) Send(buf [][]byte, peer conn.Endpoint) (err error) {
405393

406394
datalen := len(data) // grab the length before we overwrite it
407395

408-
if s.overwriteReserve {
409-
if s.amnezia.Set() {
410-
overwritten = s.amnezia.send(&data)
411-
} else if datalen > 3 && isWgMsgType(data[0]) {
412-
// overwrite the 3 reserved bytes on non-random packets
413-
// from: github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/peer.go#L138
414-
copy(data[1:4], s.reserved)
415-
overwritten = true
416-
}
417-
}
396+
overwritten = s.amnezia.send(&data)
418397

419-
if !flooded && (experimentalWg || s.overwriteReserve) {
398+
if !flooded && (experimentalWg || s.amnezia.Set()) {
420399
if datalen == device.MessageInitiationSize {
421400
s.flood(uc, dst, fkHandshake) // was probably a handshake
422401
flooded = true
@@ -440,24 +419,6 @@ func (s *StdNetBind) Send(buf [][]byte, peer conn.Endpoint) (err error) {
440419
return err
441420
}
442421

443-
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/send.go#L456
444-
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56
445-
func isWgMsgType(x byte) bool {
446-
// 1: MsgInitiation, 2: MsgResponse, 3: MsgCookieReply, 4: MsgTransport
447-
// blog.cloudflare.com/warp-technical-challenges/
448-
// Handshakes have to be performed every two minutes to rotate keys making
449-
// them insufficiently persistent. We could have forked the protocol to add
450-
// any number of additional fields, but it is important to us to remain wire
451-
// compatible with other WireGuard clients. Fortunately, WireGuard has a three
452-
// byte block in its header which is not currently used by other clients.
453-
// We decided to put our identifier in this region and still support messages
454-
// from other WireGuard clients (albeit with less reliable routing than we can
455-
// offer).
456-
// Though the open source Cloudflare WARP boring-tun impl does not do so:
457-
// github.com/cloudflare/boringtun/blob/64a2fc7c63/boringtun/src/noise/handshake.rs#L734
458-
return x >= device.MessageInitiationType && x <= device.MessageTransportType
459-
}
460-
461422
// flood c with random-sized, non-sense (unencrypted) packets.
462423
// this is okay to do because wireguard silently drops packets that won't decrypt.
463424
// github.com/WireGuard/wireguard-go/blob/19ac233cc6/wireguard/device/send.go#L96

intra/ipn/wgproxy.go

+30-24
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ package ipn
1717

1818
import (
1919
"bufio"
20-
"bytes"
2120
"context"
2221
"encoding/base64"
22+
"encoding/binary"
2323
"fmt"
2424
"net"
2525
"net/netip"
@@ -77,7 +77,6 @@ type wgifopts struct {
7777
peers map[string]device.NoisePublicKey
7878
dns, ep *multihost.MH
7979
mtu int
80-
clientid [3]byte
8180
amnezia *wg.Amnezia
8281
}
8382

@@ -91,7 +90,6 @@ type wgtun struct {
9190
ingress chan *buffer.View // pipes ep writes to wg
9291
events chan tun.Event // wg specific tun (interface) events
9392
amnezia *wg.Amnezia // amnezia config, if any
94-
clientid [3]byte // client id; applicable only for warp
9593
finalize chan struct{} // close signal for incomingPacket
9694
once sync.Once // closer fn; exec exactly once
9795
preferOffload bool // UDP GRO/GSO offloads
@@ -334,11 +332,6 @@ func (w *wgproxy) update(id, txt string) bool {
334332
return anew
335333
}
336334

337-
if !bytes.Equal(opts.clientid[:], w.clientid[:]) {
338-
log.D("proxy: wg: !update(%s): clientid %v != %v", w.id, opts.clientid, w.clientid)
339-
return anew
340-
}
341-
342335
if err := w.setRoutes(opts.ifaddrs); err != nil {
343336
log.W("proxy: wg: !update(%s): setRoutes: %v", w.id, err)
344337
return anew
@@ -430,15 +423,6 @@ func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) {
430423
if opts.mtu, err = strconv.Atoi(v); err != nil {
431424
return
432425
}
433-
case "client_id":
434-
// only for warp: blog.cloudflare.com/warp-technical-challenges
435-
// When we begin a WireGuard session we include our clientid field
436-
// which is provided by our authentication server which has to be
437-
// communicated with to begin a WARP session.
438-
if b, err := base64.StdEncoding.DecodeString(v); err == nil {
439-
n := copy(opts.clientid[:], b)
440-
log.D("proxy: wg: %s ifconfig: clientid(%d) %v", id, n, opts.clientid)
441-
}
442426
case "allowed_ip": // may exist more than once
443427
if err = loadIPNets(&opts.allowed, v); err != nil {
444428
return
@@ -473,6 +457,31 @@ func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) {
473457
// peer config: carry over public keys
474458
log.D("proxy: wg: %s ifconfig: processing key %q, err? %v", id, k, exx)
475459
pcfg.WriteString(line + "\n")
460+
case "client_id":
461+
// only for warp: blog.cloudflare.com/warp-technical-challenges
462+
// When we begin a WireGuard session we include our clientid field
463+
// which is provided by our authentication server which has to be
464+
// communicated with to begin a WARP session.
465+
// Though the open source Cloudflare WARP boring-tun impl does not do so:
466+
// github.com/cloudflare/boringtun/blob/64a2fc7c63/boringtun/src/noise/handshake.rs#L734
467+
if b, err := base64.StdEncoding.DecodeString(v); err == nil && len(b) == 3 {
468+
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/send.go#L456
469+
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56
470+
h1 := append([]byte{device.MessageInitiationType}, b...)
471+
h2 := append([]byte{device.MessageResponseType}, b...)
472+
h3 := append([]byte{device.MessageCookieReplyType}, b...)
473+
h4 := append([]byte{device.MessageTransportType}, b...)
474+
// overwrite the 3 reserved bytes on all packets
475+
// github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/receive.go#L138
476+
opts.amnezia.H1 = binary.LittleEndian.Uint32(h1)
477+
opts.amnezia.H2 = binary.LittleEndian.Uint32(h2)
478+
opts.amnezia.H3 = binary.LittleEndian.Uint32(h3)
479+
opts.amnezia.H4 = binary.LittleEndian.Uint32(h4)
480+
log.D("proxy: wg: %s ifconfig: clientid(%d) %v", id, len(b), b)
481+
} else {
482+
log.W("proxy: wg: %s ifconfig: clientid(%v) %d == 3?; err: %v",
483+
id, v, len(b), err)
484+
}
476485
case "jc":
477486
// github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/uapi.go#L286
478487
jc, _ := strconv.Atoi(v)
@@ -506,9 +515,7 @@ func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) {
506515
pcfg.WriteString(line + "\n")
507516
}
508517
}
509-
if opts.amnezia.Set() {
510-
log.I("proxy: wg: %s amnezia: %s", id, opts.amnezia)
511-
}
518+
log.D("proxy: wg: %s amnezia: %s", id, opts.amnezia)
512519
*txtptr = pcfg.String()
513520
if err == nil && len(opts.ifaddrs) <= 0 || opts.dns.Len() <= 0 || opts.mtu <= 0 {
514521
err = errProxyConfig
@@ -574,7 +581,7 @@ func NewWgProxy(id string, ctl protect.Controller, rev netstack.GConnHandler, cf
574581
// todo: use wgtun.serve fn instead of ctl
575582
wgep = wg.NewEndpoint2(id, ctl, opts.ep, wgtun.listener)
576583
} else {
577-
wgep = wg.NewEndpoint(id, wgtun.serve, opts.ep, wgtun.listener, wgtun.amnezia, wgtun.clientid)
584+
wgep = wg.NewEndpoint(id, wgtun.serve, opts.ep, wgtun.listener, wgtun.amnezia)
578585
}
579586

580587
wgdev := device.NewDevice(wgtun, wgep, wglogger(id))
@@ -648,7 +655,6 @@ func makeWgTun(id, cfg string, ctl protect.Controller, rev netstack.GConnHandler
648655
rt: x.NewIpTree(), // must be set to allowedaddrs
649656
ba: core.NewBarrier[[]netip.Addr](wgbarrierttl),
650657
amnezia: ifopts.amnezia,
651-
clientid: ifopts.clientid,
652658
status: core.NewVolatile(TUP),
653659
preferOffload: preferOffload(id),
654660
refreshBa: core.NewBarrier[bool](2 * time.Minute),
@@ -678,8 +684,8 @@ func makeWgTun(id, cfg string, ctl protect.Controller, rev netstack.GConnHandler
678684
t.events <- tun.EventUp
679685

680686
if4, if6 := netstack.StackAddrs(s, wgnic)
681-
log.I("proxy: wg: %s tun: created; dns[%s]; dst[%s]; mtu[%d]; ifaddrs[%v / %v]; clientid[%v]; amnezia[%t]",
682-
t.id, ifopts.dns, ifopts.ep, tunmtu, if4, if6, ifopts.clientid, ifopts.amnezia.Set())
687+
log.I("proxy: wg: %s tun: created; dns[%s]; dst[%s]; mtu[%d]; ifaddrs[%v / %v]; amnezia[%t]",
688+
t.id, ifopts.dns, ifopts.ep, tunmtu, if4, if6, ifopts.amnezia.Set())
683689

684690
return t, nil
685691
}

0 commit comments

Comments
 (0)