diff --git a/main.go b/main.go index aa67bc5..c978aac 100644 --- a/main.go +++ b/main.go @@ -130,7 +130,7 @@ func main() { } var err error - if Opts.ListenAddr, err = netip.ParseAddrPort(Opts.ListenAddrStr); err != nil { + if Opts.ListenAddr, err = ParseHostPort(Opts.ListenAddrStr); err != nil { Opts.Logger.Error("listen address is malformed", "error", err) os.Exit(1) } diff --git a/utils.go b/utils.go index 432c152..00b806d 100644 --- a/utils.go +++ b/utils.go @@ -7,6 +7,8 @@ package main import ( "fmt" "net" + "net/netip" + "strconv" "syscall" ) @@ -30,6 +32,29 @@ func CheckOriginAllowed(remoteIP net.IP) bool { return false } +func ParseHostPort(hostport string) (netip.AddrPort, error) { + host, portStr, err := net.SplitHostPort(hostport) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("failed to parse host and port: %w", err) + } + + ips, err := net.LookupIP(host) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("failed to lookup IP addresses: %w", err) + } + if len(ips) == 0 { + return netip.AddrPort{}, fmt.Errorf("no IP addresses found") + } + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("failed to parse port: %w", err) + } + + ip, _ := netip.AddrFromSlice(ips[0]) + return netip.AddrPortFrom(ip, uint16(port)), nil +} + func DialUpstreamControl(sport int) func(string, string, syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { var syscallErr error