diff --git a/network/wsNetwork.go b/network/wsNetwork.go index 7b1bdcfacf..d74084933e 100644 --- a/network/wsNetwork.go +++ b/network/wsNetwork.go @@ -1406,7 +1406,7 @@ func (wn *WebsocketNetwork) peersToPing() []*wsPeer { } func (wn *WebsocketNetwork) getDNSAddrs(dnsBootstrap string) []string { - srvPhonebook, err := tools_network.ReadFromSRV("algobootstrap", dnsBootstrap, wn.config.FallbackDNSResolverAddress) + srvPhonebook, err := tools_network.ReadFromSRV("algobootstrap", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress) if err != nil { // only log this warning on testnet or devnet if wn.NetworkID == config.Devnet || wn.NetworkID == config.Testnet { diff --git a/tools/network/bootstrap.go b/tools/network/bootstrap.go index 413962ef83..39fc46ccb9 100644 --- a/tools/network/bootstrap.go +++ b/tools/network/bootstrap.go @@ -25,12 +25,16 @@ import ( ) // ReadFromSRV is a helper to collect SRV addresses for a given name. -func ReadFromSRV(service string, name string, fallbackDNSResolverAddress string) (addrs []string, err error) { +func ReadFromSRV(service string, protocol string, name string, fallbackDNSResolverAddress string) (addrs []string, err error) { log := logging.Base() if name == "" { log.Debug("no dns lookup due to empty name") return } + if protocol != "tcp" && protocol != "udp" && protocol != "tls" { + err = fmt.Errorf("unsupported protocol '%s' specified", protocol) + return + } _, records, sysLookupErr := net.LookupSRV(service, "tcp", name) if sysLookupErr != nil { diff --git a/tools/network/telemetryURIUpdateService.go b/tools/network/telemetryURIUpdateService.go index 6c572d2c1c..7dfa08692a 100644 --- a/tools/network/telemetryURIUpdateService.go +++ b/tools/network/telemetryURIUpdateService.go @@ -17,6 +17,8 @@ package network import ( + "net/url" + "strings" "time" "github.com/algorand/go-algorand/config" @@ -24,17 +26,42 @@ import ( "github.com/algorand/go-algorand/protocol" ) +type telemetrySrvReader interface { + readFromSRV(protocol string, bootstrapID string) (addrs []string, err error) +} + +type telemetryURIUpdater struct { + interval time.Duration + cfg config.Local + genesisNetwork protocol.NetworkID + log logging.Logger + abort chan struct{} + srvReader telemetrySrvReader +} + // StartTelemetryURIUpdateService starts a go routine which queries SRV records for a telemetry URI every func StartTelemetryURIUpdateService(interval time.Duration, cfg config.Local, genesisNetwork protocol.NetworkID, log logging.Logger, abort chan struct{}) { + updater := &telemetryURIUpdater{ + interval: interval, + cfg: cfg, + genesisNetwork: genesisNetwork, + log: log, + abort: abort, + } + updater.srvReader = updater + updater.Start() + +} +func (t *telemetryURIUpdater) Start() { go func() { - ticker := time.NewTicker(interval) + ticker := time.NewTicker(t.interval) defer ticker.Stop() updateTelemetryURI := func() { - endpoint := lookupTelemetryEndpoint(cfg, genesisNetwork, log) + endpointURL := t.lookupTelemetryURL() - if endpoint != "" && endpoint != log.GetTelemetryURI() { - log.UpdateTelemetryURI(endpoint) + if endpointURL != nil && endpointURL.String() != t.log.GetTelemetryURI() { + t.log.UpdateTelemetryURI(endpointURL.String()) } } @@ -44,27 +71,62 @@ func StartTelemetryURIUpdateService(interval time.Duration, cfg config.Local, ge select { case <-ticker.C: updateTelemetryURI() - case <-abort: + case <-t.abort: return } } }() } -func lookupTelemetryEndpoint(cfg config.Local, genesisNetwork protocol.NetworkID, log logging.Logger) string { - bootstrapArray := cfg.DNSBootstrapArray(genesisNetwork) +func (t *telemetryURIUpdater) lookupTelemetryURL() (url *url.URL) { + bootstrapArray := t.cfg.DNSBootstrapArray(t.genesisNetwork) bootstrapArray = append(bootstrapArray, "default.algodev.network") for _, bootstrapID := range bootstrapArray { - addrs, err := ReadFromSRV("telemetry", bootstrapID, cfg.FallbackDNSResolverAddress) + addrs, err := t.srvReader.readFromSRV("tls", bootstrapID) if err != nil { - log.Infof("An issue occurred reading telemetry entry for '%s': %v", bootstrapID, err) + t.log.Infof("An issue occurred reading telemetry entry for '_telemetry._tls.%s': %v", bootstrapID, err) } else if len(addrs) == 0 { - log.Infof("No telemetry entry for: '%s'", bootstrapID) + t.log.Infof("No telemetry entry for: '_telemetry._tls.%s'", bootstrapID) } else { - return addrs[0] + for _, addr := range addrs { + // the addr that we received from ReadFromSRV contains host:port, we need to prefix that with the schema. since it's the tls, we want to use https. + url, err = url.Parse("https://" + addr) + if err != nil { + t.log.Infof("a telemetry endpoint '%s' was retrieved for '_telemerty._tls.%s'. This does not seems to be a valid endpoint and will be ignored(%v).", addr, bootstrapID, err) + continue + } + return url + } + } + + addrs, err = t.srvReader.readFromSRV("tcp", bootstrapID) + if err != nil { + t.log.Infof("An issue occurred reading telemetry entry for '_telemetry._tcp.%s': %v", bootstrapID, err) + } else if len(addrs) == 0 { + t.log.Infof("No telemetry entry for: '_telemetry._tcp.%s'", bootstrapID) + } else { + for _, addr := range addrs { + if strings.HasPrefix(addr, "https://") { + // the addr that we received from ReadFromSRV should contain host:port. however, in some cases, it might contain a https prefix, where we want to take it as is. + url, err = url.Parse(addr) + } else { + // the addr that we received from ReadFromSRV contains host:port, we need to prefix that with the schema. since it's the tcp, we want to use http. + url, err = url.Parse("http://" + addr) + } + + if err != nil { + t.log.Infof("a telemetry endpoint '%s' was retrieved for '_telemerty._tcp.%s'. This does not seems to be a valid endpoint and will be ignored(%v).", addr, bootstrapID, err) + continue + } + return url + } } } - log.Warn("No telemetry endpoint was found.") - return "" + t.log.Warn("No telemetry endpoint was found.") + return nil +} + +func (t *telemetryURIUpdater) readFromSRV(protocol string, bootstrapID string) (addrs []string, err error) { + return ReadFromSRV("telemetry", protocol, bootstrapID, t.cfg.FallbackDNSResolverAddress) } diff --git a/tools/network/telemetryURIUpdateService_test.go b/tools/network/telemetryURIUpdateService_test.go new file mode 100644 index 0000000000..7971bcd285 --- /dev/null +++ b/tools/network/telemetryURIUpdateService_test.go @@ -0,0 +1,90 @@ +// Copyright (C) 2019 Algorand, Inc. +// This file is part of go-algorand +// +// go-algorand is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// go-algorand is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with go-algorand. If not, see . + +package network + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/algorand/go-algorand/config" + "github.com/algorand/go-algorand/logging" + "github.com/algorand/go-algorand/protocol" +) + +type telemetryURIUpdaterTest struct { + telemetryURIUpdater + readFromSRVResults map[string][]string +} + +func (t *telemetryURIUpdaterTest) readFromSRV(protocol string, bootstrapID string) (addrs []string, err error) { + if addr, ok := t.readFromSRVResults[protocol+bootstrapID]; ok { + return addr, nil + } + fmt.Printf("no result for %s %s\n", protocol, bootstrapID) + return nil, fmt.Errorf("no cached results") +} + +func makeTelemetryURIUpdaterTest(genesisNetwork protocol.NetworkID) *telemetryURIUpdaterTest { + t := &telemetryURIUpdaterTest{ + telemetryURIUpdater: telemetryURIUpdater{ + cfg: config.GetDefaultLocal(), + log: logging.Base(), + genesisNetwork: genesisNetwork, + }, + readFromSRVResults: make(map[string][]string), + } + t.srvReader = t + return t +} + +func (t *telemetryURIUpdaterTest) add(protocol, bootstrap string, addrs []string) { + t.readFromSRVResults[protocol+bootstrap] = addrs +} + +func TestTelemetryURILookup(t *testing.T) { + + // trivial success case. + uriUpdater := makeTelemetryURIUpdaterTest(config.Devnet) + uriUpdater.add("tcp", "devnet.algodev.network", []string{"myhost:4160"}) + uri := uriUpdater.lookupTelemetryURL() + require.NotNil(t, uri) + require.Equal(t, "http://myhost:4160", uri.String()) + + // check https prefixing + uriUpdater = makeTelemetryURIUpdaterTest(config.Devnet) + uriUpdater.add("tcp", "devnet.algodev.network", []string{"https://myhost:4160"}) + uri = uriUpdater.lookupTelemetryURL() + require.NotNil(t, uri) + require.Equal(t, "https://myhost:4160", uri.String()) + + // check https priority + uriUpdater = makeTelemetryURIUpdaterTest(config.Devnet) + uriUpdater.add("tcp", "devnet.algodev.network", []string{"myhost2:4160"}) + uriUpdater.add("tls", "devnet.algodev.network", []string{"myhost1:4160"}) + uri = uriUpdater.lookupTelemetryURL() + require.NotNil(t, uri) + require.Equal(t, "https://myhost1:4160", uri.String()) + + // check fallback + uriUpdater = makeTelemetryURIUpdaterTest(config.Devnet) + uriUpdater.add("tcp", "default.algodev.network", []string{"fallbackhost:8123"}) + uri = uriUpdater.lookupTelemetryURL() + require.NotNil(t, uri) + require.Equal(t, "http://fallbackhost:8123", uri.String()) +}