From 33f70bd85e5f4d0f2fec0617fbd89d89f29e865d Mon Sep 17 00:00:00 2001 From: WaterLemons2k <62788816+WaterLemons2k@users.noreply.github.com> Date: Thu, 7 Mar 2024 18:34:57 +0800 Subject: [PATCH] perf: check internet with DNS lookup DNS lookups are much cheaper than HTTP requests since we only need to check if the Internet is available. See: https://stackoverflow.com/a/50058255 --- dns/index.go | 4 +-- dns/internal/wait_net.go | 48 --------------------------------- dns/internet/wait.go | 56 +++++++++++++++++++++++++++++++++++++++ main.go | 7 +---- util/net_resolver.go | 31 +++++++++++----------- util/net_resolver_test.go | 44 ++++++++++++++++++------------ util/string.go | 15 +++++++++-- util/string_test.go | 43 ++++++++++++++++++++++++++++++ 8 files changed, 158 insertions(+), 90 deletions(-) delete mode 100644 dns/internal/wait_net.go create mode 100644 dns/internet/wait.go create mode 100644 util/string_test.go diff --git a/dns/index.go b/dns/index.go index f7b40fea2..c90b7ef0a 100644 --- a/dns/index.go +++ b/dns/index.go @@ -4,7 +4,7 @@ import ( "time" "github.com/jeessy2/ddns-go/v6/config" - "github.com/jeessy2/ddns-go/v6/dns/internal" + "github.com/jeessy2/ddns-go/v6/dns/internet" "github.com/jeessy2/ddns-go/v6/util" ) @@ -34,7 +34,7 @@ var ( // RunTimer 定时运行 func RunTimer(delay time.Duration) { - internal.WaitForNetworkConnected(addresses) + internet.Wait(addresses) for { RunOnce() diff --git a/dns/internal/wait_net.go b/dns/internal/wait_net.go deleted file mode 100644 index 2c8d780f3..000000000 --- a/dns/internal/wait_net.go +++ /dev/null @@ -1,48 +0,0 @@ -package internal - -import ( - "strings" - "time" - - "github.com/jeessy2/ddns-go/v6/util" -) - -// waitForNetworkConnected 等待网络连接后继续 -// -// addresses:用于测试网络是否连接的域名 -func WaitForNetworkConnected(addresses []string) { - // 延时 5 秒 - timeout := time.Second * 5 - - loopbackServer := "[::1]:53" - find := false - - for { - for _, addr := range addresses { - // https://github.com/jeessy2/ddns-go/issues/736 - client := util.CreateHTTPClient() - resp, err := client.Get(addr) - if err != nil { - - // 如果 err 包含回环地址([::1]:53)则表示没有 DNS 服务器,设置 DNS 服务器 - if strings.Contains(err.Error(), loopbackServer) && !find { - server := "1.1.1.1:53" - util.Log("本机DNS异常! 将默认使用 %s, 可参考文档通过 -dns 自定义 DNS 服务器", loopbackServer, server) - util.NewDialerResolver(server) - find = true - continue - } - - util.Log("等待网络连接: %s", err) - util.Log("%s 后重试...", timeout) - // 等待 5 秒后重试 - time.Sleep(timeout) - continue - } - - // 网络已连接 - resp.Body.Close() - return - } - } -} diff --git a/dns/internet/wait.go b/dns/internet/wait.go new file mode 100644 index 000000000..b74e3e12c --- /dev/null +++ b/dns/internet/wait.go @@ -0,0 +1,56 @@ +// Package internet implements utilities for checking the Internet connection. +package internet + +import ( + "strings" + "time" + + "github.com/jeessy2/ddns-go/v6/util" +) + +const ( + // fallbackDNS used when a fallback occurs. + fallbackDNS = "1.1.1.1" + + // delay is the delay time for each DNS lookup. + delay = time.Second * 5 +) + +// Wait blocks until the Internet is connected. +// +// See also: +// +// - https://stackoverflow.com/a/50058255 +// - https://github.com/ddev/ddev/blob/v1.22.7/pkg/globalconfig/global_config.go#L776 +func Wait(addresses []string) { + // fallbase in case loopback DNS is unavailable and only once. + fallback := false + + for { + for _, addr := range addresses { + err := util.LookupHost(addr) + // Internet is connected. + if err == nil { + return + } + + if !fallback && isLoopback(err) { + util.Log("本机DNS异常! 将默认使用 %s, 可参考文档通过 -dns 自定义 DNS 服务器", fallbackDNS) + util.SetDNS(fallbackDNS) + + fallback = true + continue + } + + util.Log("等待网络连接: %s", err) + + util.Log("%s 后重试...", delay) + time.Sleep(delay) + } + } +} + +// isLoopback checks if the error is a loopback error. +func isLoopback(e error) bool { + return strings.Contains(e.Error(), "[::1]:53") +} diff --git a/main.go b/main.go index 412021493..3b36dc2a1 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,6 @@ import ( "os/exec" "path/filepath" "strconv" - "strings" "time" "github.com/jeessy2/ddns-go/v6/config" @@ -85,11 +84,7 @@ func main() { util.SetInsecureSkipVerify() } if *customDNSServer != "" { - if !strings.Contains(*customDNSServer, ":") { - util.NewDialerResolver(*customDNSServer + ":53") - } else { - util.NewDialerResolver(*customDNSServer) - } + util.SetDNS(*customDNSServer) } os.Setenv(util.IPCacheTimesENV, strconv.Itoa(*ipCacheTimes)) switch *serviceType { diff --git a/util/net_resolver.go b/util/net_resolver.go index dd81aee1a..b8d511b9c 100644 --- a/util/net_resolver.go +++ b/util/net_resolver.go @@ -5,25 +5,26 @@ import ( "net" ) -// NewDialerResolver 使用 s 将 dialer.Resolver 设置为新的 net.Resolver。 -// -// s:用于创建新 net.Resolver 的字符串。 -func NewDialerResolver(s string) { - dialer.Resolver = newNetResolver(s) -} - -// newNetResolver 当 s 不为空时返回使用 s 的 Go 内置 DNS 解析器。 -// -// s:net.Resolver 的 DNS 服务器地址。 -func newNetResolver(s string) *net.Resolver { - if s == "" { - return net.DefaultResolver +// SetDNS sets the dialer.Resolver to use the given DNS server. +func SetDNS(dns string) { + // Error means that the given DNS doesn't have a port. Add it. + if _, _, err := net.SplitHostPort(dns); err != nil { + dns = net.JoinHostPort(dns, "53") } - return &net.Resolver{ + dialer.Resolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - return net.Dial("udp", s) + return net.Dial(network, dns) }, } } + +// LookupHost looks up the host based on the given URL using the dialer.Resolver. +// A wrapper for [net.Resolver.LookupHost]. +func LookupHost(url string) error { + name := toHostname(url) + + _, err := dialer.Resolver.LookupHost(context.Background(), name) + return err +} diff --git a/util/net_resolver_test.go b/util/net_resolver_test.go index 436aa2ebf..a565c6ed8 100644 --- a/util/net_resolver_test.go +++ b/util/net_resolver_test.go @@ -1,28 +1,38 @@ package util -import ( - "context" - "testing" +import "testing" + +const ( + testDNS = "1.1.1.1" + testURL = "https://cloudflare.com" ) -// TestNewDialerResolver 测试传递 DNS 服务器地址时能否设置 dialer.Resolver。 -func TestNewDialerResolver(t *testing.T) { - // 测试前重置以确保正常设置 - dialer.Resolver = nil +func TestSetDNS(t *testing.T) { + SetDNS(testDNS) - NewDialerResolver("1.1.1.1:53") if dialer.Resolver == nil { t.Error("Failed to set dialer.Resolver") } - - // 测试后重置以确保与测试前的值一致 - dialer.Resolver = nil } -// TestNewNetResolver 测试能否通过 newNetResolver 返回的 net.Resolver 解析域名的 IP。 -func TestNewNetResolver(t *testing.T) { - _, err := newNetResolver("1.1.1.1:53").LookupIP(context.Background(), "ip", "cloudflare.com") - if err != nil { - t.Errorf("Failed to lookup IP, err: %v", err) - } +func TestLookupHost(t *testing.T) { + t.Run("Valid URL", func(t *testing.T) { + if err := LookupHost(testURL); err != nil { + t.Errorf("Expected nil error, got %v", err) + } + }) + + t.Run("Invalid URL", func(t *testing.T) { + if err := LookupHost("invalidurl"); err == nil { + t.Error("Expected error, got nil") + } + }) + + t.Run("After SetDNS", func(t *testing.T) { + SetDNS(testDNS) + + if err := LookupHost(testURL); err != nil { + t.Errorf("Expected nil error, got %v", err) + } + }) } diff --git a/util/string.go b/util/string.go index 0d8496e5d..75c0e8941 100644 --- a/util/string.go +++ b/util/string.go @@ -2,8 +2,7 @@ package util import "strings" -// WriteString 使用 strings.Builder 生成字符串并返回 string -// https://pkg.go.dev/strings#Builder +// WriteString creates a new string using [strings.Builder]. func WriteString(strs ...string) string { var b strings.Builder for _, str := range strs { @@ -12,3 +11,15 @@ func WriteString(strs ...string) string { return b.String() } + +// toHostname normalizes a URL with a https scheme to just its hostname. +// +// See also: +// +// - https://github.com/moby/moby/blob/v25.0.3/registry/auth.go#L132 +func toHostname(url string) string { + stripped := url + stripped = strings.TrimPrefix(stripped, "https://") + + return strings.Split(stripped, "/")[0] +} diff --git a/util/string_test.go b/util/string_test.go new file mode 100644 index 000000000..736dc4234 --- /dev/null +++ b/util/string_test.go @@ -0,0 +1,43 @@ +package util + +import "testing" + +func TestWriteString(t *testing.T) { + tests := []struct { + input []string + expected string + }{ + {[]string{"hello", "world"}, "helloworld"}, + {[]string{"", "test"}, "test"}, + {[]string{"hello", " ", "world"}, "hello world"}, + {[]string{""}, ""}, + } + + for _, tt := range tests { + result := WriteString(tt.input...) + if result != tt.expected { + t.Errorf("Expected %s, but got %s", tt.expected, result) + } + } +} + +func TestToHostname(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"With https scheme", "https://www.example.com", "www.example.com"}, + {"With path", "www.example.com/path", "www.example.com"}, + {"With https scheme and path", "https://www.example.com/path", "www.example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := toHostname(tt.input) + if result != tt.expected { + t.Errorf("Expected %s, but got %s", tt.expected, result) + } + }) + } +}