diff --git a/main.go b/main.go index 9062dfa..aeb57a1 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "flag" "log/slog" "net" + "net/netip" "os" "syscall" "time" @@ -125,6 +126,21 @@ func main() { os.Exit(1) } + if _, err := netip.ParseAddr(Opts.ListenAddr); err != nil { + Opts.Logger.Error("listen address is malformed", "error", err) + os.Exit(1) + } + + if _, err := netip.ParseAddr(Opts.TargetAddr4); err != nil { + Opts.Logger.Error("ipv4 target address is malformed", "error", err) + os.Exit(1) + } + + if _, err := netip.ParseAddr(Opts.TargetAddr6); err != nil { + Opts.Logger.Error("ipv6 target address is malformed", "error", err) + os.Exit(1) + } + if Opts.udpCloseAfter < 0 { Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter)) os.Exit(1) diff --git a/tcp.go b/tcp.go index f8eb1ee..a0db342 100644 --- a/tcp.go +++ b/tcp.go @@ -9,6 +9,7 @@ import ( "io" "log/slog" "net" + "net/netip" ) func tcpCopyData(dst net.Conn, src net.Conn, ch chan<- error) { @@ -51,10 +52,10 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { targetAddr := Opts.TargetAddr6 if saddr == nil { - if AddrVersion(conn.RemoteAddr()) == 4 { + if netip.MustParseAddr(conn.RemoteAddr().String()).Is4() { targetAddr = Opts.TargetAddr4 } - } else if AddrVersion(saddr) == 4 { + } else if netip.MustParseAddr(saddr.String()).Is4() { targetAddr = Opts.TargetAddr4 } diff --git a/udp.go b/udp.go index 5eb2054..4373ab1 100644 --- a/udp.go +++ b/udp.go @@ -9,6 +9,7 @@ import ( "errors" "log/slog" "net" + "net/netip" "sync/atomic" "syscall" "time" @@ -93,7 +94,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad } targetAddr := Opts.TargetAddr6 - if AddrVersion(downstreamAddr) == 4 { + if netip.MustParseAddr(downstreamAddr.String()).Is4() { targetAddr = Opts.TargetAddr4 } diff --git a/utils.go b/utils.go index 02d7490..432c152 100644 --- a/utils.go +++ b/utils.go @@ -7,7 +7,6 @@ package main import ( "fmt" "net" - "strings" "syscall" ) @@ -87,11 +86,3 @@ func DialUpstreamControl(sport int) func(string, string, syscall.RawConn) error return syscallErr } } - -func AddrVersion(addr net.Addr) int { - // poor man's ipv6 check - golang makes it unnecessarily hard - if strings.ContainsRune(addr.String(), '.') { - return 4 - } - return 6 -}