diff --git a/cmd/root.go b/cmd/root.go index 43419f125..ed7299357 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -32,6 +32,7 @@ import ( "contrib.go.opencensus.io/exporter/prometheus" "contrib.go.opencensus.io/exporter/stackdriver" "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/cloudsql" + "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/healthcheck" "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/log" "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/proxy" "github.com/spf13/cobra" @@ -76,6 +77,7 @@ type Command struct { telemetryProject string telemetryPrefix string prometheusNamespace string + healthCheck bool httpPort string } @@ -157,7 +159,6 @@ When this flag is not set, there is no limit.`) to close after receiving a TERM signal. The proxy will shut down when the number of open connections reaches 0 or when the maximum time has passed. Defaults to 0s.`) - cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "", "Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.") cmd.PersistentFlags().BoolVar(&c.disableTraces, "disable-traces", false, @@ -172,12 +173,16 @@ the maximum time has passed. Defaults to 0s.`) "Enable Prometheus for metric collection using the provided namespace") cmd.PersistentFlags().StringVar(&c.httpPort, "http-port", "9090", "Port for the Prometheus server to use") + cmd.PersistentFlags().BoolVar(&c.healthCheck, "health-check", false, + `Enables HTTP endpoints /startup, /liveness, and /readiness +that report on the proxy's health. Endpoints are available on localhost +only. 
Uses the port specified by the http-port flag.`) cmd.PersistentFlags().StringVar(&c.conf.APIEndpointURL, "sqladmin-api-endpoint", "", "When set, the proxy uses this url as the API endpoint for all Cloud SQL Admin API requests.\nExample: https://sqladmin.googleapis.com") cmd.PersistentFlags().StringVar(&c.conf.QuotaProject, "quota-project", "", `Specifies the project to use for Cloud SQL Admin API quota tracking. The IAM principal must have the "serviceusage.services.use" permission -for the given project. See https://cloud.google.com/service-usage/docs/overview and +for the given project. See https://cloud.google.com/service-usage/docs/overview and https://cloud.google.com/storage/docs/requester-pays`) // Global and per instance flags @@ -225,18 +230,18 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { return newBadCommandError("cannot specify --credentials-file and --gcloud-auth flags at the same time") } - if userHasSet("http-port") && !userHasSet("prometheus-namespace") { - return newBadCommandError("cannot specify --http-port without --prometheus-namespace") + if userHasSet("http-port") && !userHasSet("prometheus-namespace") && !userHasSet("health-check") { + cmd.logger.Infof("Ignoring --http-port because --prometheus-namespace or --health-check was not set") } if !userHasSet("telemetry-project") && userHasSet("telemetry-prefix") { - cmd.logger.Infof("Ignoring telementry-prefix as telemetry-project was not set") + cmd.logger.Infof("Ignoring --telemetry-prefix because --telemetry-project was not set") } if !userHasSet("telemetry-project") && userHasSet("disable-metrics") { - cmd.logger.Infof("Ignoring disable-metrics as telemetry-project was not set") + cmd.logger.Infof("Ignoring --disable-metrics because --telemetry-project was not set") } if !userHasSet("telemetry-project") && userHasSet("disable-traces") { - cmd.logger.Infof("Ignoring disable-traces as telemetry-project was not set") + cmd.logger.Infof("Ignoring --disable-traces
because --telemetry-project was not set") } if userHasSet("sqladmin-api-endpoint") && conf.APIEndpointURL != "" { @@ -364,9 +369,8 @@ func runSignalWrapper(cmd *Command) error { ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() - // Configure Cloud Trace and/or Cloud Monitoring based on command - // invocation. If a project has not been enabled, no traces or metrics are - // enabled. + // Configure collectors before the proxy has started to ensure we are + // collecting metrics before *ANY* Cloud SQL Admin API calls are made. enableMetrics := !cmd.disableMetrics enableTraces := !cmd.disableTraces if cmd.telemetryProject != "" && (enableMetrics || enableTraces) { @@ -394,40 +398,22 @@ func runSignalWrapper(cmd *Command) error { }() } - shutdownCh := make(chan error) - + var ( + needsHTTPServer bool + mux = http.NewServeMux() + ) if cmd.prometheusNamespace != "" { + needsHTTPServer = true e, err := prometheus.NewExporter(prometheus.Options{ Namespace: cmd.prometheusNamespace, }) if err != nil { return err } - mux := http.NewServeMux() mux.Handle("/metrics", e) - addr := fmt.Sprintf("localhost:%s", cmd.httpPort) - server := &http.Server{Addr: addr, Handler: mux} - go func() { - select { - case <-ctx.Done(): - // Give the HTTP server a second to shutdown cleanly. 
- ctx2, _ := context.WithTimeout(context.Background(), time.Second) - if err := server.Shutdown(ctx2); err != nil { - cmd.logger.Errorf("failed to shutdown Prometheus HTTP server: %v\n", err) - } - } - }() - go func() { - err := server.ListenAndServe() - if err == http.ErrServerClosed { - return - } - if err != nil { - shutdownCh <- fmt.Errorf("failed to start prometheus HTTP server: %v", err) - } - }() } + shutdownCh := make(chan error) // watch for sigterm / sigint signals signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) @@ -465,18 +451,55 @@ func runSignalWrapper(cmd *Command) error { cmd.logger.Errorf("The proxy has encountered a terminal error: %v", err) return err case p = <-startCh: + cmd.logger.Infof("The proxy has started successfully and is ready for new connections!") } - cmd.logger.Infof("The proxy has started successfully and is ready for new connections!") - defer p.Close() defer func() { if cErr := p.Close(); cErr != nil { cmd.logger.Errorf("error during shutdown: %v", cErr) } }() - go func() { - shutdownCh <- p.Serve(ctx) - }() + notify := func() {} + if cmd.healthCheck { + needsHTTPServer = true + hc := healthcheck.NewCheck(p, cmd.logger) + mux.HandleFunc("/startup", hc.HandleStartup) + mux.HandleFunc("/readiness", hc.HandleReadiness) + mux.HandleFunc("/liveness", hc.HandleLiveness) + notify = hc.NotifyStarted + } + + // Start the HTTP server if anything requiring HTTP is specified. + if needsHTTPServer { + server := &http.Server{ + Addr: fmt.Sprintf("localhost:%s", cmd.httpPort), + Handler: mux, + } + // Start the HTTP server. + go func() { + err := server.ListenAndServe() + if err == http.ErrServerClosed { + return + } + if err != nil { + shutdownCh <- fmt.Errorf("failed to start HTTP server: %v", err) + } + }() + // Handle shutdown of the HTTP server gracefully. + go func() { + select { + case <-ctx.Done(): + // Give the HTTP server a second to shutdown cleanly. 
+ ctx2, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := server.Shutdown(ctx2); err != nil { + cmd.logger.Errorf("failed to shutdown HTTP server: %v\n", err) + } + } + }() + } + + go func() { shutdownCh <- p.Serve(ctx, notify) }() err := <-shutdownCh switch { diff --git a/cmd/root_test.go b/cmd/root_test.go index bfb975906..880acc879 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -516,10 +516,6 @@ func TestNewCommandWithErrors(t *testing.T) { desc: "when the iam authn login query param is bogus", args: []string{"proj:region:inst?auto-iam-authn=nope"}, }, - { - desc: "enabling a Prometheus port without a namespace", - args: []string{"--http-port", "1111", "proj:region:inst"}, - }, { desc: "using an invalid url for sqladmin-api-endpoint", args: []string{"--sqladmin-api-endpoint", "https://user:abc{DEf1=ghi@example.com:5432/db?sslmode=require", "proj:region:inst"}, diff --git a/internal/healthcheck/healthcheck.go b/internal/healthcheck/healthcheck.go new file mode 100644 index 000000000..b55a1275f --- /dev/null +++ b/internal/healthcheck/healthcheck.go @@ -0,0 +1,109 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package healthcheck tests and communicates the health of the Cloud SQL Auth proxy.
+package healthcheck + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + + "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/cloudsql" + "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/proxy" +) + +// Check provides HTTP handlers for use as healthchecks typically in a +// Kubernetes context. +type Check struct { + once *sync.Once + started chan struct{} + proxy *proxy.Client + logger cloudsql.Logger +} + +// NewCheck is the initializer for Check. +func NewCheck(p *proxy.Client, l cloudsql.Logger) *Check { + return &Check{ + once: &sync.Once{}, + started: make(chan struct{}), + proxy: p, + logger: l, + } +} + +// NotifyStarted notifies the check that the proxy has started up successfully. +func (c *Check) NotifyStarted() { + c.once.Do(func() { close(c.started) }) +} + +// HandleStartup reports whether the Check has been notified of startup. +func (c *Check) HandleStartup(w http.ResponseWriter, _ *http.Request) { + select { + case <-c.started: + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + default: + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("error")) + } +} + +var errNotStarted = errors.New("proxy is not started") + +// HandleReadiness ensures the Check has been notified of successful startup, +// that the proxy has not reached maximum connections, and that all connections +// are healthy. 
+func (c *Check) HandleReadiness(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + select { + case <-c.started: + default: + c.logger.Errorf("[Health Check] Readiness failed: %v", errNotStarted) + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(errNotStarted.Error())) + return + } + + if open, max := c.proxy.ConnCount(); max > 0 && open == max { + err := fmt.Errorf("max connections reached (open = %v, max = %v)", open, max) + c.logger.Errorf("[Health Check] Readiness failed: %v", err) + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(err.Error())) + return + } + + err := c.proxy.CheckConnections(ctx) + if err != nil { + c.logger.Errorf("[Health Check] Readiness failed: %v", err) + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(err.Error())) + return + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) +} + +// HandleLiveness indicates the process is up and responding to HTTP requests. +// If this check fails (because it's not reachable), the process is in a bad +// state and should be restarted. +func (c *Check) HandleLiveness(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) +} diff --git a/internal/healthcheck/healthcheck_test.go b/internal/healthcheck/healthcheck_test.go new file mode 100644 index 000000000..156d5c60f --- /dev/null +++ b/internal/healthcheck/healthcheck_test.go @@ -0,0 +1,241 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +package healthcheck_test + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "cloud.google.com/go/cloudsqlconn" + "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/cloudsql" + "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/healthcheck" + "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/log" + "github.com/GoogleCloudPlatform/cloudsql-proxy/v2/internal/proxy" +) + +var ( + logger = log.NewStdLogger(os.Stdout, os.Stdout) + proxyHost = "127.0.0.1" + proxyPort = 9000 +) + +func proxyAddr() string { + return fmt.Sprintf("%s:%d", proxyHost, proxyPort) +} + +func dialTCP(t *testing.T, addr string) net.Conn { + for i := 0; i < 10; i++ { + conn, err := net.Dial("tcp", addr) + if err == nil { + return conn + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("failed to dial %v", addr) + return nil +} + +type fakeDialer struct{} + +func (*fakeDialer) Dial(ctx context.Context, inst string, opts ...cloudsqlconn.DialOption) (net.Conn, error) { + conn, _ := net.Pipe() + return conn, nil +} + +func (*fakeDialer) EngineVersion(ctx context.Context, inst string) (string, error) { + return "POSTGRES_14", nil +} + +func (*fakeDialer) Close() error { + return nil +} + +type errorDialer struct { + fakeDialer +} + +func (*errorDialer) Dial(ctx context.Context, inst string, opts ...cloudsqlconn.DialOption) (net.Conn, error) { + return nil, errors.New("errorDialer always errors") +} + +func newProxyWithParams(t *testing.T, maxConns uint64, dialer cloudsql.Dialer) *proxy.Client { + c := &proxy.Config{ + Addr: proxyHost, + Port: proxyPort, + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:pg"}, + }, + MaxConnections: maxConns, + } + p, err := proxy.NewClient(context.Background(), dialer, logger, c) + if err != nil { + t.Fatalf("proxy.NewClient: %v", err) + } 
+ return p +} + +func newTestProxyWithMaxConns(t *testing.T, maxConns uint64) *proxy.Client { + return newProxyWithParams(t, maxConns, &fakeDialer{}) +} + +func newTestProxyWithDialer(t *testing.T, d cloudsql.Dialer) *proxy.Client { + return newProxyWithParams(t, 0, d) +} + +func newTestProxy(t *testing.T) *proxy.Client { + return newProxyWithParams(t, 0, &fakeDialer{}) +} + +func TestHandleStartupWhenNotNotified(t *testing.T) { + p := newTestProxy(t) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + check := healthcheck.NewCheck(p, logger) + + rec := httptest.NewRecorder() + check.HandleStartup(rec, &http.Request{}) + + // Startup is not complete because the Check has not been notified of the + // proxy's startup. + resp := rec.Result() + if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } +} + +func TestHandleStartupWhenNotified(t *testing.T) { + p := newTestProxy(t) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + check := healthcheck.NewCheck(p, logger) + + check.NotifyStarted() + + rec := httptest.NewRecorder() + check.HandleStartup(rec, &http.Request{}) + + resp := rec.Result() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } +} + +func TestHandleReadinessWhenNotNotified(t *testing.T) { + p := newTestProxy(t) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + check := healthcheck.NewCheck(p, logger) + + rec := httptest.NewRecorder() + check.HandleReadiness(rec, &http.Request{}) + + resp := rec.Result() + if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } +} + +func TestHandleReadinessForMaxConns(t *testing.T) { + p := newTestProxyWithMaxConns(t, 1) + defer 
func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + started := make(chan struct{}) + check := healthcheck.NewCheck(p, logger) + go p.Serve(context.Background(), func() { + check.NotifyStarted() + close(started) + }) + select { + case <-started: + // proxy has started + case <-time.After(10 * time.Second): + t.Fatal("proxy has not started after 10 seconds") + } + + conn := dialTCP(t, proxyAddr()) + defer conn.Close() + + // The proxy calls the dialer in a separate goroutine. So wait for that + // goroutine to run before asserting on the readiness response. + waitForConnect := func(t *testing.T, wantCode int) *http.Response { + for i := 0; i < 10; i++ { + rec := httptest.NewRecorder() + check.HandleReadiness(rec, &http.Request{}) + resp := rec.Result() + if resp.StatusCode == wantCode { + return resp + } + time.Sleep(time.Second) + } + return nil + } + resp := waitForConnect(t, http.StatusServiceUnavailable) + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + if !strings.Contains(string(body), "max connections") { + t.Fatalf("want max connections error, got = %v", string(body)) + } +} + +func TestHandleReadinessWithConnectionProblems(t *testing.T) { + p := newTestProxyWithDialer(t, &errorDialer{}) // error dialer will error on dial + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + check := healthcheck.NewCheck(p, logger) + check.NotifyStarted() + + rec := httptest.NewRecorder() + check.HandleReadiness(rec, &http.Request{}) + + resp := rec.Result() + if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + if want := "errorDialer"; !strings.Contains(string(body), want) { + t.Fatalf("want substring 
with = %q, got = %v", want, string(body)) + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index fd5750cbd..b1da81cad 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -250,7 +250,6 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * // Check if the caller has configured a dialer. // Otherwise, initialize a new one. if d == nil { - var err error dialerOpts, err := conf.DialerOptions(l) if err != nil { return nil, fmt.Errorf("error initializing dialer: %v", err) @@ -298,9 +297,54 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * return c, nil } +// CheckConnections dials each registered instance and reports any errors that +// may have occurred. +func (c *Client) CheckConnections(ctx context.Context) error { + var ( + wg sync.WaitGroup + errCh = make(chan error, len(c.mnts)) + ) + for _, m := range c.mnts { + wg.Add(1) + go func(inst string) { + defer wg.Done() + conn, err := c.dialer.Dial(ctx, inst) + if err != nil { + errCh <- err + return + } + cErr := conn.Close() + if cErr != nil { + errCh <- fmt.Errorf("%v: %v", inst, cErr) + } + }(m.inst) + } + wg.Wait() + + var mErr MultiErr + for i := 0; i < len(c.mnts); i++ { + select { + case err := <-errCh: + mErr = append(mErr, err) + default: + continue + } + } + if len(mErr) > 0 { + return mErr + } + return nil +} + +// ConnCount returns the number of open connections and the maximum allowed +// connections. Returns 0 when the maximum allowed connections have not been set. +func (c *Client) ConnCount() (uint64, uint64) { + return atomic.LoadUint64(&c.connCount), c.maxConns +} + // Serve starts proxying connections for all configured instances using the // associated socket.
-func (c *Client) Serve(ctx context.Context) error { +func (c *Client) Serve(ctx context.Context, notify func()) error { ctx, cancel := context.WithCancel(ctx) defer cancel() exitCh := make(chan error) @@ -321,6 +365,7 @@ func (c *Client) Serve(ctx context.Context) error { } }(m) } + notify() return <-exitCh } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 40686e27d..16b1f5459 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -72,6 +72,10 @@ type errorDialer struct { fakeDialer } +func (*errorDialer) Dial(ctx context.Context, inst string, opts ...cloudsqlconn.DialOption) (net.Conn, error) { + return nil, errors.New("errorDialer returns error on Dial") +} + func (*errorDialer) Close() error { return errors.New("errorDialer returns error on Close") } @@ -243,7 +247,11 @@ func TestClientInitialization(t *testing.T) { if err != nil { t.Fatalf("want error = nil, got = %v", err) } - defer c.Close() + defer func() { + if err := c.Close(); err != nil { + t.Logf("failed to close client: %v", err) + } + }() for _, addr := range tc.wantTCPAddrs { conn, err := net.Dial("tcp", addr) if err != nil { @@ -285,7 +293,7 @@ func TestClientLimitsMaxConnections(t *testing.T) { t.Fatalf("proxy.NewClient error: %v", err) } defer c.Close() - go c.Serve(context.Background()) + go c.Serve(context.Background(), func() {}) conn1, err1 := net.Dial("tcp", "127.0.0.1:5000") if err1 != nil { @@ -299,11 +307,6 @@ func TestClientLimitsMaxConnections(t *testing.T) { } defer conn2.Close() - // try to read to check if the connection is closed - // wait only a second for the result (since nothing is writing to the - // socket) - conn2.SetReadDeadline(time.Now().Add(time.Second)) - wantEOF := func(t *testing.T, c net.Conn) { var got error for i := 0; i < 10; i++ { @@ -311,7 +314,7 @@ func TestClientLimitsMaxConnections(t *testing.T) { if got == io.EOF { return } - time.Sleep(500 * time.Millisecond) + time.Sleep(time.Second) } 
t.Fatalf("conn.Read should return io.EOF, got = %v", got) } @@ -357,7 +360,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) { if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } - go c.Serve(context.Background()) + go c.Serve(context.Background(), func() {}) conn := tryTCPDial(t, "127.0.0.1:5000") _ = conn.Close() @@ -372,7 +375,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) { if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } - go c.Serve(context.Background()) + go c.Serve(context.Background(), func() {}) var open []net.Conn for i := 0; i < 5; i++ { @@ -403,7 +406,7 @@ func TestClientClosesCleanly(t *testing.T) { if err != nil { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } - go c.Serve(context.Background()) + go c.Serve(context.Background(), func() {}) conn := tryTCPDial(t, "127.0.0.1:5000") _ = conn.Close() @@ -426,7 +429,7 @@ func TestClosesWithError(t *testing.T) { if err != nil { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } - go c.Serve(context.Background()) + go c.Serve(context.Background(), func() {}) conn := tryTCPDial(t, "127.0.0.1:5000") defer conn.Close() @@ -491,3 +494,124 @@ func TestClientInitializationWorksRepeatedly(t *testing.T) { } c.Close() } + +func TestClientNotifiesCallerOnServe(t *testing.T) { + ctx := context.Background() + in := &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:pg"}, + }, + } + logger := log.NewStdLogger(os.Stdout, os.Stdout) + c, err := proxy.NewClient(ctx, &fakeDialer{}, logger, in) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + done := make(chan struct{}) + notify := func() { close(done) } + + go c.Serve(ctx, notify) + + verifyNotification := func(t *testing.T, ch <-chan struct{}) { + for i := 0; i < 10; i++ { + select { + case <-ch: + return + default: + time.Sleep(100 * time.Millisecond) + } + } + t.Fatal("channel should have been closed but was not") + } + verifyNotification(t, 
done) +} + +func TestClientConnCount(t *testing.T) { + logger := log.NewStdLogger(os.Stdout, os.Stdout) + in := &proxy.Config{ + Addr: "127.0.0.1", + Port: 5000, + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:pg"}, + }, + MaxConnections: 10, + } + + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, logger, in) + if err != nil { + t.Fatalf("proxy.NewClient error: %v", err) + } + defer c.Close() + go c.Serve(context.Background(), func() {}) + + gotOpen, gotMax := c.ConnCount() + if gotOpen != 0 { + t.Fatalf("want 0 open connections, got = %v", gotOpen) + } + if gotMax != 10 { + t.Fatalf("want 10 max connections, got = %v", gotMax) + } + + conn := tryTCPDial(t, "127.0.0.1:5000") + defer conn.Close() + + verifyOpen := func(t *testing.T, want uint64) { + var got uint64 + for i := 0; i < 10; i++ { + got, _ = c.ConnCount() + if got == want { + return + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("open connections, want = %v, got = %v", want, got) + } + verifyOpen(t, 1) +} + +func TestCheckConnections(t *testing.T) { + logger := log.NewStdLogger(os.Stdout, os.Stdout) + in := &proxy.Config{ + Addr: "127.0.0.1", + Port: 5000, + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:pg"}, + }, + } + d := &fakeDialer{} + c, err := proxy.NewClient(context.Background(), d, logger, in) + if err != nil { + t.Fatalf("proxy.NewClient error: %v", err) + } + defer c.Close() + go c.Serve(context.Background(), func() {}) + + if err = c.CheckConnections(context.Background()); err != nil { + t.Fatalf("CheckConnections failed: %v", err) + } + + if want, got := 1, d.dialAttempts(); want != got { + t.Fatalf("dial attempts: want = %v, got = %v", want, got) + } + + in = &proxy.Config{ + Addr: "127.0.0.1", + Port: 6000, + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:pg1"}, + {Name: "proj:region:pg2"}, + }, + } + ed := &errorDialer{} + c, err = proxy.NewClient(context.Background(), ed, logger, in) + if err != nil { + 
t.Fatalf("proxy.NewClient error: %v", err) + } + defer c.Close() + go c.Serve(context.Background(), func() {}) + + err = c.CheckConnections(context.Background()) + if err == nil { + t.Fatal("CheckConnections should have failed, but did not") + } +} diff --git a/testsV2/common_test.go b/testsV2/common_test.go index ac83ede47..57e324af2 100644 --- a/testsV2/common_test.go +++ b/testsV2/common_test.go @@ -23,6 +23,7 @@ import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "os" @@ -127,7 +128,15 @@ func (p *proxyExec) WaitForServe(ctx context.Context) (output string, err error) errCh <- err return } - buf.WriteString(s) + if _, err = buf.WriteString(s); err != nil { + errCh <- err + return + } + // Check for an unrecognized flag + if strings.Contains(s, "Error") { + errCh <- errors.New(s) + return + } if strings.Contains(s, "ready for new connections") { errCh <- nil return @@ -137,10 +146,10 @@ func (p *proxyExec) WaitForServe(ctx context.Context) (output string, err error) // Wait for either the background thread of the context to complete select { case <-ctx.Done(): - return buf.String(), fmt.Errorf("context done: %w", ctx.Err()) + return buf.String(), ctx.Err() case err := <-errCh: if err != nil { - return buf.String(), fmt.Errorf("proxy start failed: %w", err) + return buf.String(), err } } return buf.String(), nil diff --git a/testsV2/connection_test.go b/testsV2/connection_test.go index 63e8a8bdb..7089cdb92 100644 --- a/testsV2/connection_test.go +++ b/testsV2/connection_test.go @@ -17,6 +17,8 @@ package tests import ( "context" "database/sql" + "io/ioutil" + "net/http" "os" "testing" "time" @@ -78,3 +80,55 @@ func proxyConnTest(t *testing.T, args []string, driver, dsn string) { t.Fatalf("unable to exec on db: %s", err) } } + +// testHealthCheck verifies that when a proxy client serves the given instance, +// the readiness endpoint serves http.StatusOK. 
+func testHealthCheck(t *testing.T, connName string) { + ctx, cancel := context.WithTimeout(context.Background(), connTestTimeout) + defer cancel() + + args := []string{connName, "--health-check"} + // Start the proxy. + p, err := StartProxy(ctx, args...) + if err != nil { + t.Fatalf("unable to start proxy: %v", err) + } + defer p.Close() + _, err = p.WaitForServe(ctx) + if err != nil { + t.Fatal(err) + } + + tryDial := func(t *testing.T) *http.Response { + var ( + err error + resp *http.Response + ) + for i := 0; i < 10; i++ { + resp, err = http.Get("http://localhost:9090/readiness") + if err != nil { + time.Sleep(100 * time.Millisecond) + } + if resp != nil { + return resp + } + } + t.Fatalf("HTTP GET failed: %v", err) + return nil + } + + resp := tryDial(t) + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read HTTP response body: %v", err) + } + defer resp.Body.Close() + if string(body) != "ok" { + t.Fatalf("response body was not ok, got = %v", string(body)) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("want %v, got %v", http.StatusOK, resp.StatusCode) + } +} diff --git a/testsV2/mysql_test.go b/testsV2/mysql_test.go index cb23e32ca..8373750c5 100644 --- a/testsV2/mysql_test.go +++ b/testsV2/mysql_test.go @@ -123,3 +123,7 @@ func TestMySQLAuthWithCredentialsFile(t *testing.T) { []string{"--credentials-file", path, *mysqlConnName}, "mysql", cfg.FormatDSN()) } + +func TestMySQLHealthCheck(t *testing.T) { + testHealthCheck(t, *mysqlConnName) +} diff --git a/testsV2/postgres_test.go b/testsV2/postgres_test.go index 994980d0a..fd8914d13 100644 --- a/testsV2/postgres_test.go +++ b/testsV2/postgres_test.go @@ -152,3 +152,7 @@ func TestPostgresIAMDBAuthn(t *testing.T) { []string{fmt.Sprintf("%s?auto-iam-authn=true", *postgresConnName)}, "pgx", dsn) } + +func TestPostgresHealthCheck(t *testing.T) { + testHealthCheck(t, *postgresConnName) +} diff --git a/testsV2/sqlserver_test.go b/testsV2/sqlserver_test.go index 
3ba683391..e243ff4e8 100644 --- a/testsV2/sqlserver_test.go +++ b/testsV2/sqlserver_test.go @@ -84,3 +84,7 @@ func TestSQLServerAuthWithCredentialsFile(t *testing.T) { []string{"--credentials-file", path, *sqlserverConnName}, "sqlserver", dsn) } + +func TestSQLServerHealthCheck(t *testing.T) { + testHealthCheck(t, *sqlserverConnName) +}