-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor into separate packages & add tests. (#4)
- Loading branch information
Showing
14 changed files
with
902 additions
and
324 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
name: Test | ||
|
||
on: | ||
push: | ||
branches: ["main"] | ||
pull_request: | ||
branches: ["main"] | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- name: Set up Go | ||
uses: actions/setup-go@v4 | ||
with: | ||
go-version: "1.21" | ||
|
||
- name: Build | ||
run: go build -v | ||
|
||
- name: Prepare ip routes | ||
run: | | ||
sudo ip rule add from 127.0.0.1/8 iif lo table 123 | ||
sudo ip route add local 0.0.0.0/0 dev lo table 123 | ||
sudo ip -6 rule add from ::1/128 iif lo table 123 | ||
sudo ip -6 route add local ::/0 dev lo table 123 | ||
- name: Test | ||
run: sudo go test -v -timeout 30s ./tests |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Copyright 2019 Path Network, Inc. All rights reserved. | ||
// Copyright 2024 Konrad Zemek <[email protected]> | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package buffers | ||
|
||
import ( | ||
"math" | ||
"sync" | ||
) | ||
|
||
var buffers sync.Pool | ||
|
||
func init() { | ||
buffers.New = func() any { | ||
slice := make([]byte, math.MaxUint16) | ||
return &slice | ||
} | ||
} | ||
|
||
func Get() []byte { | ||
return *buffers.Get().(*[]byte) | ||
} | ||
|
||
func Put(buf []byte) { | ||
buffers.Put(&buf) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,58 @@ | ||
// Copyright 2019 Path Network, Inc. All rights reserved. | ||
// Copyright 2024 Konrad Zemek <[email protected]> | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package main | ||
|
||
import ( | ||
"bufio" | ||
"context" | ||
"flag" | ||
"log/slog" | ||
"net" | ||
"net/netip" | ||
"os" | ||
"syscall" | ||
"time" | ||
|
||
"github.com/kzemek/go-mmproxy/tcp" | ||
"github.com/kzemek/go-mmproxy/udp" | ||
"github.com/kzemek/go-mmproxy/utils" | ||
) | ||
|
||
type options struct { | ||
Protocol string | ||
ListenAddrStr string | ||
TargetAddr4Str string | ||
TargetAddr6Str string | ||
ListenAddr netip.AddrPort | ||
TargetAddr4 netip.AddrPort | ||
TargetAddr6 netip.AddrPort | ||
Mark int | ||
Verbose int | ||
allowedSubnetsPath string | ||
AllowedSubnets []*net.IPNet | ||
Listeners int | ||
Logger *slog.Logger | ||
udpCloseAfter int | ||
UDPCloseAfter time.Duration | ||
} | ||
var protocolStr string | ||
var listenAddrStr string | ||
var targetAddr4Str string | ||
var targetAddr6Str string | ||
var allowedSubnetsPath string | ||
var udpCloseAfterInt int | ||
var listeners int | ||
|
||
var Opts options | ||
var opts utils.Options | ||
|
||
func init() { | ||
flag.StringVar(&Opts.Protocol, "p", "tcp", "Protocol that will be proxied: tcp, udp") | ||
flag.StringVar(&Opts.ListenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on") | ||
flag.StringVar(&Opts.TargetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to") | ||
flag.StringVar(&Opts.TargetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to") | ||
flag.IntVar(&Opts.Mark, "mark", 0, "The mark that will be set on outbound packets") | ||
flag.IntVar(&Opts.Verbose, "v", 0, `0 - no logging of individual connections | ||
flag.StringVar(&protocolStr, "p", "tcp", "Protocol that will be proxied: tcp, udp") | ||
flag.StringVar(&listenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on") | ||
flag.StringVar(&targetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to") | ||
flag.StringVar(&targetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to") | ||
flag.IntVar(&opts.Mark, "mark", 0, "The mark that will be set on outbound packets") | ||
flag.IntVar(&opts.Verbose, "v", 0, `0 - no logging of individual connections | ||
1 - log errors occurring in individual connections | ||
2 - log all state changes of individual connections`) | ||
flag.StringVar(&Opts.allowedSubnetsPath, "allowed-subnets", "", | ||
flag.StringVar(&allowedSubnetsPath, "allowed-subnets", "", | ||
"Path to a file that contains allowed subnets of the proxy servers") | ||
flag.IntVar(&Opts.Listeners, "listeners", 1, | ||
flag.IntVar(&listeners, "listeners", 1, | ||
"Number of listener sockets that will be opened for the listen address (Linux 3.9+)") | ||
flag.IntVar(&Opts.udpCloseAfter, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up") | ||
flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up") | ||
} | ||
|
||
func listen(listenerNum int, errors chan<- error) { | ||
logger := Opts.Logger.With(slog.Int("listenerNum", listenerNum), | ||
slog.String("protocol", Opts.Protocol), slog.String("listenAdr", Opts.ListenAddr.String())) | ||
func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) { | ||
logger := parentLogger.With(slog.Int("listenerNum", listenerNum), | ||
slog.String("protocol", protocolStr), slog.String("listenAdr", opts.ListenAddr.String())) | ||
|
||
listenConfig := net.ListenConfig{} | ||
if Opts.Listeners > 1 { | ||
if listeners > 1 { | ||
listenConfig.Control = func(network, address string, c syscall.RawConn) error { | ||
return c.Control(func(fd uintptr) { | ||
soReusePort := 15 | ||
|
@@ -67,15 +63,15 @@ func listen(listenerNum int, errors chan<- error) { | |
} | ||
} | ||
|
||
if Opts.Protocol == "tcp" { | ||
tcpListen(&listenConfig, logger, errors) | ||
if opts.Protocol == utils.TCP { | ||
tcp.Listen(ctx, &listenConfig, &opts, logger, listenErrors) | ||
} else { | ||
udpListen(&listenConfig, logger, errors) | ||
udp.Listen(ctx, &listenConfig, &opts, logger, listenErrors) | ||
} | ||
} | ||
|
||
func loadAllowedSubnets() error { | ||
file, err := os.Open(Opts.allowedSubnetsPath) | ||
func loadAllowedSubnets(logger *slog.Logger) error { | ||
file, err := os.Open(allowedSubnetsPath) | ||
if err != nil { | ||
return err | ||
} | ||
|
@@ -84,12 +80,12 @@ func loadAllowedSubnets() error { | |
|
||
scanner := bufio.NewScanner(file) | ||
for scanner.Scan() { | ||
_, ipNet, err := net.ParseCIDR(scanner.Text()) | ||
ipNet, err := netip.ParsePrefix(scanner.Text()) | ||
if err != nil { | ||
return err | ||
} | ||
Opts.AllowedSubnets = append(Opts.AllowedSubnets, ipNet) | ||
Opts.Logger.Info("allowed subnet", slog.String("subnet", ipNet.String())) | ||
opts.AllowedSubnets = append(opts.AllowedSubnets, ipNet) | ||
logger.Info("allowed subnet", slog.String("subnet", ipNet.String())) | ||
} | ||
|
||
return nil | ||
|
@@ -98,72 +94,79 @@ func loadAllowedSubnets() error { | |
func main() { | ||
flag.Parse() | ||
lvl := slog.LevelInfo | ||
if Opts.Verbose > 0 { | ||
if opts.Verbose > 0 { | ||
lvl = slog.LevelDebug | ||
} | ||
Opts.Logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) | ||
|
||
if Opts.allowedSubnetsPath != "" { | ||
if err := loadAllowedSubnets(); err != nil { | ||
Opts.Logger.Error("failed to load allowed subnets file", "path", Opts.allowedSubnetsPath, "error", err) | ||
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) | ||
|
||
if allowedSubnetsPath != "" { | ||
if err := loadAllowedSubnets(logger); err != nil { | ||
logger.Error("failed to load allowed subnets file", "path", allowedSubnetsPath, "error", err) | ||
} | ||
} | ||
|
||
if Opts.Protocol != "tcp" && Opts.Protocol != "udp" { | ||
Opts.Logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", Opts.Protocol)) | ||
if protocolStr == "tcp" { | ||
opts.Protocol = utils.TCP | ||
} else if protocolStr == "udp" { | ||
opts.Protocol = utils.UDP | ||
} else { | ||
logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", protocolStr)) | ||
os.Exit(1) | ||
} | ||
|
||
if Opts.Mark < 0 { | ||
Opts.Logger.Error("--mark has to be >= 0", slog.Int("mark", Opts.Mark)) | ||
if opts.Mark < 0 { | ||
logger.Error("--mark has to be >= 0", slog.Int("mark", opts.Mark)) | ||
os.Exit(1) | ||
} | ||
|
||
if Opts.Verbose < 0 { | ||
Opts.Logger.Error("-v has to be >= 0", slog.Int("verbose", Opts.Verbose)) | ||
if opts.Verbose < 0 { | ||
logger.Error("-v has to be >= 0", slog.Int("verbose", opts.Verbose)) | ||
os.Exit(1) | ||
} | ||
|
||
if Opts.Listeners < 1 { | ||
Opts.Logger.Error("--listeners has to be >= 1") | ||
if listeners < 1 { | ||
logger.Error("--listeners has to be >= 1") | ||
os.Exit(1) | ||
} | ||
|
||
var err error | ||
if Opts.ListenAddr, err = parseHostPort(Opts.ListenAddrStr); err != nil { | ||
Opts.Logger.Error("listen address is malformed", "error", err) | ||
if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr); err != nil { | ||
logger.Error("listen address is malformed", "error", err) | ||
os.Exit(1) | ||
} | ||
|
||
if Opts.TargetAddr4, err = netip.ParseAddrPort(Opts.TargetAddr4Str); err != nil { | ||
Opts.Logger.Error("ipv4 target address is malformed", "error", err) | ||
if opts.TargetAddr4, err = netip.ParseAddrPort(targetAddr4Str); err != nil { | ||
logger.Error("ipv4 target address is malformed", "error", err) | ||
os.Exit(1) | ||
} | ||
if !Opts.TargetAddr4.Addr().Is4() { | ||
Opts.Logger.Error("ipv4 target address is not IPv4") | ||
if !opts.TargetAddr4.Addr().Is4() { | ||
logger.Error("ipv4 target address is not IPv4") | ||
os.Exit(1) | ||
} | ||
|
||
if Opts.TargetAddr6, err = netip.ParseAddrPort(Opts.TargetAddr6Str); err != nil { | ||
Opts.Logger.Error("ipv6 target address is malformed", "error", err) | ||
if opts.TargetAddr6, err = netip.ParseAddrPort(targetAddr6Str); err != nil { | ||
logger.Error("ipv6 target address is malformed", "error", err) | ||
os.Exit(1) | ||
} | ||
if !Opts.TargetAddr6.Addr().Is6() { | ||
Opts.Logger.Error("ipv6 target address is not IPv6") | ||
if !opts.TargetAddr6.Addr().Is6() { | ||
logger.Error("ipv6 target address is not IPv6") | ||
os.Exit(1) | ||
} | ||
|
||
if Opts.udpCloseAfter < 0 { | ||
Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter)) | ||
if udpCloseAfterInt < 0 { | ||
logger.Error("--close-after has to be >= 0", slog.Int("close-after", udpCloseAfterInt)) | ||
os.Exit(1) | ||
} | ||
Opts.UDPCloseAfter = time.Duration(Opts.udpCloseAfter) * time.Second | ||
opts.UDPCloseAfter = time.Duration(udpCloseAfterInt) * time.Second | ||
|
||
listenErrors := make(chan error, Opts.Listeners) | ||
for i := 0; i < Opts.Listeners; i++ { | ||
go listen(i, listenErrors) | ||
listenErrors := make(chan error, listeners) | ||
ctxs := make([]context.Context, listeners) | ||
for i := range ctxs { | ||
ctxs[i] = context.Background() | ||
go listen(ctxs[i], i, logger, listenErrors) | ||
} | ||
for i := 0; i < Opts.Listeners; i++ { | ||
for range ctxs { | ||
<-listenErrors | ||
} | ||
} |
Oops, something went wrong.