Skip to content

Commit 8967d6e

Browse files
WIP implement IP proxying
1 parent d5c6bb0 commit 8967d6e

File tree

5 files changed

+330
-20
lines changed

5 files changed

+330
-20
lines changed

capsule.go

+19-17
Original file line numberDiff line numberDiff line change
@@ -164,20 +164,22 @@ func parseAddress(r io.Reader) (requestID uint64, prefix netip.Prefix, _ error)
164164

165165
// routeAdvertisementCapsule represents a ROUTE_ADVERTISEMENT capsule
166166
type routeAdvertisementCapsule struct {
167-
IPAddressRanges []IPAddressRange
167+
IPAddressRanges []IPRoute
168168
}
169169

170-
// IPAddressRange represents an IP Address Range within a ROUTE_ADVERTISEMENT capsule
171-
type IPAddressRange struct {
172-
StartIP netip.Addr
173-
EndIP netip.Addr
170+
// IPRoute represents an IP Address Range
171+
type IPRoute struct {
172+
StartIP netip.Addr
173+
EndIP netip.Addr
174+
// IPProtocol is the Internet Protocol Number for traffic that can be sent to this range.
175+
// If the value is 0, all protocols are allowed.
174176
IPProtocol uint8
175177
}
176178

177-
func (r IPAddressRange) len() int { return 1 + r.StartIP.BitLen()/8 + r.EndIP.BitLen()/8 + 1 }
179+
func (r IPRoute) len() int { return 1 + r.StartIP.BitLen()/8 + r.EndIP.BitLen()/8 + 1 }
178180

179181
func parseRouteAdvertisementCapsule(r io.Reader) (*routeAdvertisementCapsule, error) {
180-
var ranges []IPAddressRange
182+
var ranges []IPRoute
181183
for {
182184
ipRange, err := parseIPAddressRange(r)
183185
if err != nil {
@@ -213,47 +215,47 @@ func (c *routeAdvertisementCapsule) append(b []byte) []byte {
213215
return b
214216
}
215217

216-
func parseIPAddressRange(r io.Reader) (IPAddressRange, error) {
218+
func parseIPAddressRange(r io.Reader) (IPRoute, error) {
217219
var ipVersion uint8
218220
if err := binary.Read(r, binary.LittleEndian, &ipVersion); err != nil {
219-
return IPAddressRange{}, err
221+
return IPRoute{}, err
220222
}
221223

222224
var startIP, endIP netip.Addr
223225
switch ipVersion {
224226
case 4:
225227
var start, end [4]byte
226228
if _, err := io.ReadFull(r, start[:]); err != nil {
227-
return IPAddressRange{}, err
229+
return IPRoute{}, err
228230
}
229231
if _, err := io.ReadFull(r, end[:]); err != nil {
230-
return IPAddressRange{}, err
232+
return IPRoute{}, err
231233
}
232234
startIP = netip.AddrFrom4(start)
233235
endIP = netip.AddrFrom4(end)
234236
case 6:
235237
var start, end [16]byte
236238
if _, err := io.ReadFull(r, start[:]); err != nil {
237-
return IPAddressRange{}, err
239+
return IPRoute{}, err
238240
}
239241
if _, err := io.ReadFull(r, end[:]); err != nil {
240-
return IPAddressRange{}, err
242+
return IPRoute{}, err
241243
}
242244
startIP = netip.AddrFrom16(start)
243245
endIP = netip.AddrFrom16(end)
244246
default:
245-
return IPAddressRange{}, fmt.Errorf("invalid IP version: %d", ipVersion)
247+
return IPRoute{}, fmt.Errorf("invalid IP version: %d", ipVersion)
246248
}
247249

248250
if startIP.Compare(endIP) > 0 {
249-
return IPAddressRange{}, errors.New("start IP is greater than end IP")
251+
return IPRoute{}, errors.New("start IP is greater than end IP")
250252
}
251253

252254
var ipProtocol uint8
253255
if err := binary.Read(r, binary.LittleEndian, &ipProtocol); err != nil {
254-
return IPAddressRange{}, err
256+
return IPRoute{}, err
255257
}
256-
return IPAddressRange{
258+
return IPRoute{
257259
StartIP: startIP,
258260
EndIP: endIP,
259261
IPProtocol: ipProtocol,

capsule_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func TestParseRouteAdvertisementCapsule(t *testing.T) {
221221
capsule, err := parseRouteAdvertisementCapsule(cr)
222222
require.NoError(t, err)
223223
require.Equal(t,
224-
[]IPAddressRange{
224+
[]IPRoute{
225225
{StartIP: netip.MustParseAddr("1.1.1.1"), EndIP: netip.MustParseAddr("1.2.3.4"), IPProtocol: 13},
226226
{StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 37},
227227
},
@@ -232,7 +232,7 @@ func TestParseRouteAdvertisementCapsule(t *testing.T) {
232232

233233
func TestWriteRouteAdvertisementCapsule(t *testing.T) {
234234
c := &routeAdvertisementCapsule{
235-
IPAddressRanges: []IPAddressRange{
235+
IPAddressRanges: []IPRoute{
236236
{StartIP: netip.MustParseAddr("1.1.1.1"), EndIP: netip.MustParseAddr("1.2.3.4"), IPProtocol: 13},
237237
{StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 37},
238238
},

conn.go

+280
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
package connectip
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"log"
8+
"net/netip"
9+
"slices"
10+
"sync/atomic"
11+
12+
"golang.org/x/net/ipv4"
13+
"golang.org/x/net/ipv6"
14+
15+
"github.com/quic-go/quic-go/http3"
16+
"github.com/quic-go/quic-go/quicvarint"
17+
)
18+
19+
type appendable interface{ append([]byte) []byte }
20+
21+
type writeCapsule struct {
22+
capsule appendable
23+
result chan error
24+
}
25+
26+
// Conn is a connection that proxies IP packets over HTTP/3.
27+
type Conn struct {
28+
str http3.Stream
29+
writes chan writeCapsule
30+
31+
peerAddresses []netip.Prefix // IP prefixes that we assigned to the peer
32+
localRoutes []IPRoute // IP routes that we advertised to the peer
33+
34+
assignedAddressNotify chan struct{}
35+
assignedAddresses atomic.Pointer[[]netip.Prefix]
36+
availableRoutesNotify chan struct{}
37+
availableRoutes atomic.Pointer[[]IPRoute]
38+
}
39+
40+
func newProxiedConn(str http3.Stream) *Conn {
41+
c := &Conn{
42+
str: str,
43+
assignedAddressNotify: make(chan struct{}, 1),
44+
availableRoutesNotify: make(chan struct{}, 1),
45+
}
46+
go func() {
47+
if err := c.readFromStream(); err != nil {
48+
log.Printf("handling stream failed: %v", err)
49+
}
50+
}()
51+
go func() {
52+
if err := c.writeToStream(); err != nil {
53+
log.Printf("writing to stream failed: %v", err)
54+
}
55+
}()
56+
return c
57+
}
58+
59+
// AdvertiseRoute informs the peer about available routes.
60+
// This function can be called multiple times, but only the routes from the most recent call will be active.
61+
// Previous route advertisements are overwritten by each new call to this function.
62+
func (c *Conn) AdvertiseRoute(ctx context.Context, routes []IPRoute) error {
63+
c.localRoutes = slices.Clone(routes)
64+
for _, route := range routes {
65+
if route.StartIP.Compare(route.EndIP) == 1 {
66+
return fmt.Errorf("invalid route advertising start_ip: %s larger than %s", route.StartIP, route.EndIP)
67+
}
68+
}
69+
return c.sendCapsule(ctx, &routeAdvertisementCapsule{IPAddressRanges: routes})
70+
}
71+
72+
// AssignAddresses assigned address prefixes to the peer.
73+
// This function can be called multiple times, but only the addresses from the most recent call will be active.
74+
// Previous address assignments are overwritten by each new call to this function.
75+
func (c *Conn) AssignAddresses(ctx context.Context, prefixes []netip.Prefix) error {
76+
c.peerAddresses = slices.Clone(prefixes)
77+
capsule := &addressAssignCapsule{AssignedAddresses: make([]AssignedAddress, 0, len(prefixes))}
78+
for _, p := range prefixes {
79+
capsule.AssignedAddresses = append(capsule.AssignedAddresses, AssignedAddress{IPPrefix: p})
80+
}
81+
return c.sendCapsule(ctx, capsule)
82+
}
83+
84+
func (c *Conn) sendCapsule(ctx context.Context, capsule appendable) error {
85+
res := make(chan error, 1)
86+
select {
87+
case c.writes <- writeCapsule{
88+
capsule: capsule,
89+
result: res,
90+
}:
91+
select {
92+
case <-ctx.Done():
93+
return ctx.Err()
94+
case err := <-res:
95+
return err
96+
}
97+
case <-ctx.Done():
98+
return ctx.Err()
99+
}
100+
}
101+
102+
// LocalPrefixes returns the prefixes that the peer currently assigned.
103+
// Note that at any point during the connection, the peer can change the assignment.
104+
// It is therefore recommended to call this function in a loop.
105+
func (c *Conn) LocalPrefixes(ctx context.Context) ([]netip.Prefix, error) {
106+
select {
107+
case <-ctx.Done():
108+
return nil, ctx.Err()
109+
case <-c.assignedAddressNotify:
110+
return *c.assignedAddresses.Load(), nil
111+
}
112+
}
113+
114+
// Routes returns the routes that the peer currently advertised.
115+
// Note that at any point during the connection, the peer can change the advertised routes.
116+
// It is therefore recommended to call this function in a loop.
117+
func (c *Conn) Routes(ctx context.Context) ([]IPRoute, error) {
118+
select {
119+
case <-ctx.Done():
120+
return nil, ctx.Err()
121+
case <-c.assignedAddressNotify:
122+
return *c.availableRoutes.Load(), nil
123+
}
124+
}
125+
126+
func (c *Conn) readFromStream() error {
127+
defer c.str.Close()
128+
r := quicvarint.NewReader(c.str)
129+
for {
130+
t, cr, err := http3.ParseCapsule(r)
131+
if err != nil {
132+
return err
133+
}
134+
switch t {
135+
case capsuleTypeAddressAssign:
136+
capsule, err := parseAddressAssignCapsule(cr)
137+
if err != nil {
138+
return err
139+
}
140+
prefixes := make([]netip.Prefix, 0, len(capsule.AssignedAddresses))
141+
for _, assigned := range capsule.AssignedAddresses {
142+
prefixes = append(prefixes, assigned.IPPrefix)
143+
}
144+
c.assignedAddresses.Store(&prefixes)
145+
select {
146+
case c.assignedAddressNotify <- struct{}{}:
147+
default:
148+
}
149+
case capsuleTypeAddressRequest:
150+
if _, err := parseAddressRequestCapsule(r); err != nil {
151+
return err
152+
}
153+
return errors.New("masque: address request not yet supported")
154+
case capsuleTypeRouteAdvertisement:
155+
capsule, err := parseRouteAdvertisementCapsule(r)
156+
if err != nil {
157+
return err
158+
}
159+
c.availableRoutes.Store(&capsule.IPAddressRanges)
160+
select {
161+
case c.availableRoutesNotify <- struct{}{}:
162+
default:
163+
}
164+
default:
165+
return fmt.Errorf("unknown capsule type: %d", t)
166+
}
167+
}
168+
}
169+
170+
func (c *Conn) writeToStream() error {
171+
buf := make([]byte, 0, 1024)
172+
for {
173+
req, ok := <-c.writes
174+
if !ok {
175+
return nil
176+
}
177+
buf = req.capsule.append(buf)
178+
_, err := c.str.Write(buf)
179+
req.result <- err
180+
if err != nil {
181+
return err
182+
}
183+
buf = buf[:0]
184+
}
185+
}
186+
187+
func (c *Conn) Read(b []byte) (n int, err error) {
188+
start:
189+
data, err := c.str.ReceiveDatagram(context.Background())
190+
if err != nil {
191+
return 0, err
192+
}
193+
contextID, n, err := quicvarint.Parse(data)
194+
if err != nil {
195+
return 0, fmt.Errorf("masque: malformed datagram: %w", err)
196+
}
197+
if contextID != 0 {
198+
// Drop this datagram. We currently only support proxying of IP payloads.
199+
goto start
200+
}
201+
if err := c.handleIncomingPacket(data[n:]); err != nil {
202+
log.Printf("dropping proxied packet: %s", err)
203+
goto start
204+
}
205+
return copy(b, data[n:]), nil
206+
}
207+
208+
func (c *Conn) handleIncomingPacket(data []byte) error {
209+
if len(data) == 0 {
210+
return errors.New("empty packet")
211+
}
212+
var src, dst netip.Addr
213+
var ipProto uint8
214+
switch ipVersion(data) {
215+
default:
216+
return fmt.Errorf("masque: unknown IP versions: %d", data[0])
217+
case 4:
218+
if len(data) < ipv4.HeaderLen {
219+
return fmt.Errorf("masque: malformed datagram: too short")
220+
}
221+
src = netip.AddrFrom4([4]byte(data[12:16]))
222+
ipProto = data[9]
223+
case 6:
224+
if len(data) < ipv6.HeaderLen {
225+
return fmt.Errorf("masque: malformed datagram: too short")
226+
}
227+
src = netip.AddrFrom16([16]byte(data[8:24]))
228+
dst = netip.AddrFrom16([16]byte(data[24:40]))
229+
ipProto = data[6]
230+
}
231+
232+
if !slices.ContainsFunc(c.peerAddresses, func(p netip.Prefix) bool { return p.Contains(src) }) {
233+
// TODO: send ICMP
234+
return fmt.Errorf("masque: datagram source address not allowed: %s", src)
235+
}
236+
isAllowedDest := slices.ContainsFunc(c.localRoutes, func(r IPRoute) bool {
237+
if r.StartIP.Compare(dst) > 0 || dst.Compare(r.EndIP) > 0 {
238+
return false
239+
}
240+
if r.IPProtocol != 0 && r.IPProtocol != ipProto {
241+
return false
242+
}
243+
return true
244+
})
245+
if !isAllowedDest {
246+
// TODO: send ICMP
247+
return fmt.Errorf("masque: datagram destination address / IP protocol not allowed: %s (protocol %d)", dst, ipProto)
248+
}
249+
return nil
250+
}
251+
252+
func (c *Conn) Write(b []byte) (n int, err error) {
253+
// TODO: implement src, dst and ipproto checks
254+
if len(b) == 0 {
255+
return 0, nil
256+
}
257+
switch ipVersion(b) {
258+
default:
259+
return 0, fmt.Errorf("masque: unknown IP versions: %d", b[0])
260+
case 4:
261+
if len(b) < 20 {
262+
return 0, fmt.Errorf("masque: IPv4 packet too short")
263+
}
264+
ttl := b[8]
265+
if ttl <= 1 {
266+
return 0, fmt.Errorf("masque: datagram TTL too small: %d", ttl)
267+
}
268+
b[8]-- // Decrement TTL
269+
// TODO: maybe recalculate the checksum?
270+
case 6:
271+
// TODO: IPv6 support
272+
return 0, errors.New("IPv6 currently not supported")
273+
}
274+
data := make([]byte, 0, len(contextIDZero)+len(b))
275+
data = append(data, contextIDZero...)
276+
data = append(data, b...)
277+
return len(b), c.str.SendDatagram(data)
278+
}
279+
280+
func ipVersion(b []byte) uint8 { return b[0] >> 4 }

0 commit comments

Comments
 (0)