Skip to content

Commit

Permalink
fix: resolve data race on closing client (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom authored Jul 8, 2022
1 parent bca543c commit e2b40c4
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 82 deletions.
6 changes: 5 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,11 @@ func runSignalWrapper(cmd *Command) error {
case p = <-startCh:
}
cmd.Println("The proxy has started successfully and is ready for new connections!")
defer p.Close()
defer func() {
if cErr := p.Close(); cErr != nil {
cmd.PrintErrf("error during shutdown: %v\n", cErr)
}
}()

go func() {
shutdownCh <- p.Serve(ctx)
Expand Down
2 changes: 1 addition & 1 deletion cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ func (*spyDialer) Close() error {
}

func TestCommandWithCustomDialer(t *testing.T) {
want := "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
want := "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
s := &spyDialer{}
c := NewCommand(WithDialer(s))
// Keep the test output quiet
Expand Down
186 changes: 107 additions & 79 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,76 +176,21 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client,
}
}

pc := newPortConfig(conf.Port)
var mnts []*socketMount
pc := newPortConfig(conf.Port)
for _, inst := range conf.Instances {
var (
// network is one of "tcp" or "unix"
network string
// address is either a TCP host port, or a Unix socket
address string
)
// IF
// a global Unix socket directory is NOT set AND
// an instance-level Unix socket is NOT set
// (e.g., I didn't set a Unix socket globally or for this instance)
// OR
// an instance-level TCP address or port IS set
// (e.g., I'm overriding any global settings to use TCP for this
// instance)
// use a TCP listener.
// Otherwise, use a Unix socket.
if (conf.UnixSocket == "" && inst.UnixSocket == "") ||
(inst.Addr != "" || inst.Port != 0) {
network = "tcp"

a := conf.Addr
if inst.Addr != "" {
a = inst.Addr
}

var np int
switch {
case inst.Port != 0:
np = inst.Port
case conf.Port != 0:
np = pc.nextPort()
default:
np = pc.nextPort()
}

address = net.JoinHostPort(a, fmt.Sprint(np))
} else {
network = "unix"

dir := conf.UnixSocket
if dir == "" {
dir = inst.UnixSocket
}
ud, err := UnixSocketDir(dir, inst.Name)
if err != nil {
return nil, err
}
// Create the parent directory that will hold the socket.
if _, err := os.Stat(ud); err != nil {
if err = os.Mkdir(ud, 0777); err != nil {
return nil, err
}
}
// use the Postgres-specific socket name
address = filepath.Join(ud, ".s.PGSQL.5432")
}

m := &socketMount{inst: inst.Name}
addr, err := m.listen(ctx, network, address)
m, err := newSocketMount(ctx, conf, pc, inst)
if err != nil {
for _, m := range mnts {
m.close()
mErr := m.Close()
if mErr != nil {
cmd.PrintErrf("failed to close mount: %v", mErr)
}
}
return nil, fmt.Errorf("[%v] Unable to mount socket: %v", inst.Name, err)
}

cmd.Printf("[%s] Listening on %s\n", inst.Name, addr.String())
cmd.Printf("[%s] Listening on %s\n", inst.Name, m.Addr())
mnts = append(mnts, m)
}

Expand Down Expand Up @@ -277,22 +222,45 @@ func (c *Client) Serve(ctx context.Context) error {
return <-exitCh
}

// Close triggers the proxyClient to shutdown.
func (c *Client) Close() {
defer c.dialer.Close()
// MultiErr is a group of errors wrapped into one.
type MultiErr []error

// Error returns a single string representing one or more errors.
func (m MultiErr) Error() string {
l := len(m)
if l == 1 {
return m[0].Error()
}
var errs []string
for _, e := range m {
errs = append(errs, e.Error())
}
return strings.Join(errs, ", ")
}

func (c *Client) Close() error {
var mErr MultiErr
for _, m := range c.mnts {
m.close()
err := m.Close()
if err != nil {
mErr = append(mErr, err)
}
}
cErr := c.dialer.Close()
if cErr != nil {
mErr = append(mErr, cErr)
}
if len(mErr) > 0 {
return mErr
}
return nil
}

// serveSocketMount persistently listens to the socketMounts listener and proxies connections to a
// given AlloyDB instance.
func (c *Client) serveSocketMount(ctx context.Context, s *socketMount) error {
if s.listener == nil {
return fmt.Errorf("[%s] mount doesn't have a listener set", s.inst)
}
for {
cConn, err := s.listener.Accept()
cConn, err := s.Accept()
if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
c.cmd.PrintErrf("[%s] Error accepting connection: %v\n", s.inst, err)
Expand Down Expand Up @@ -327,22 +295,82 @@ type socketMount struct {
listener net.Listener
}

// listen causes a socketMount to create a Listener at the specified network address.
func (s *socketMount) listen(ctx context.Context, network string, address string) (net.Addr, error) {
func newSocketMount(ctx context.Context, conf *Config, pc *portConfig, inst InstanceConnConfig) (*socketMount, error) {
var (
// network is one of "tcp" or "unix"
network string
// address is either a TCP host port, or a Unix socket
address string
)
// IF
// a global Unix socket directory is NOT set AND
// an instance-level Unix socket is NOT set
// (e.g., I didn't set a Unix socket globally or for this instance)
// OR
// an instance-level TCP address or port IS set
// (e.g., I'm overriding any global settings to use TCP for this
// instance)
// use a TCP listener.
// Otherwise, use a Unix socket.
if (conf.UnixSocket == "" && inst.UnixSocket == "") ||
(inst.Addr != "" || inst.Port != 0) {
network = "tcp"

a := conf.Addr
if inst.Addr != "" {
a = inst.Addr
}

var np int
switch {
case inst.Port != 0:
np = inst.Port
default:
np = pc.nextPort()
}

address = net.JoinHostPort(a, fmt.Sprint(np))
} else {
network = "unix"

dir := conf.UnixSocket
if dir == "" {
dir = inst.UnixSocket
}
ud, err := UnixSocketDir(dir, inst.Name)
if err != nil {
return nil, err
}
// Create the parent directory that will hold the socket.
if _, err := os.Stat(ud); err != nil {
if err = os.Mkdir(ud, 0777); err != nil {
return nil, err
}
}
// use the Postgres-specific socket name
address = filepath.Join(ud, ".s.PGSQL.5432")
}

lc := net.ListenConfig{KeepAlive: 30 * time.Second}
l, err := lc.Listen(ctx, network, address)
ln, err := lc.Listen(ctx, network, address)
if err != nil {
return nil, err
}
s.listener = l
return s.listener.Addr(), nil
m := &socketMount{inst: inst.Name, listener: ln}
return m, nil
}

func (s *socketMount) Addr() net.Addr {
return s.listener.Addr()
}

func (s *socketMount) Accept() (net.Conn, error) {
return s.listener.Accept()
}

// close stops the mount from listening for any more connections
func (s *socketMount) close() error {
err := s.listener.Close()
s.listener = nil
return err
func (s *socketMount) Close() error {
return s.listener.Close()
}

// proxyConn sets up a bidirectional copy between two open connections
Expand Down
88 changes: 87 additions & 1 deletion internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ package proxy_test

import (
"context"
"errors"
"io/ioutil"
"net"
"os"
"path/filepath"
"testing"
"time"

"cloud.google.com/go/alloydbconn"
"github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy"
Expand All @@ -37,13 +39,22 @@ type testCase struct {
}

func (fakeDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) {
return nil, nil
conn, _ := net.Pipe()
return conn, nil
}

func (fakeDialer) Close() error {
return nil
}

type errorDialer struct {
fakeDialer
}

func (errorDialer) Close() error {
return errors.New("errorDialer returns error on Close")
}

func createTempDir(t *testing.T) (string, func()) {
testDir, err := ioutil.TempDir("", "*")
if err != nil {
Expand Down Expand Up @@ -216,6 +227,81 @@ func TestClientInitialization(t *testing.T) {
}
}

func TestClientClosesCleanly(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Port: 5000,
Instances: []proxy.InstanceConnConfig{
{Name: "proj:reg:inst"},
},
Dialer: fakeDialer{},
}
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
go c.Serve(context.Background())
time.Sleep(time.Second) // allow the socket to start listening

conn, dErr := net.Dial("tcp", "127.0.0.1:5000")
if dErr != nil {
t.Fatalf("net.Dial error = %v", dErr)
}
_ = conn.Close()

if err := c.Close(); err != nil {
t.Fatalf("c.Close() error = %v", err)
}
}

func TestClosesWithError(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Port: 5000,
Instances: []proxy.InstanceConnConfig{
{Name: "proj:reg:inst"},
},
Dialer: errorDialer{},
}
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
go c.Serve(context.Background())
time.Sleep(time.Second) // allow the socket to start listening

if err = c.Close(); err == nil {
t.Fatal("c.Close() should error, got nil")
}
}

func TestMultiErrorFormatting(t *testing.T) {
tcs := []struct {
desc string
in proxy.MultiErr
want string
}{
{
desc: "with one error",
in: proxy.MultiErr{errors.New("woops")},
want: "woops",
},
{
desc: "with many errors",
in: proxy.MultiErr{errors.New("woops"), errors.New("another error")},
want: "woops, another error",
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
if got := tc.in.Error(); got != tc.want {
t.Errorf("want = %v, got = %v", tc.want, got)
}
})
}
}

func TestClientInitializationWorksRepeatedly(t *testing.T) {
// The client creates a Unix socket on initial startup and does not remove
// it on shutdown. This test ensures the existing socket does not cause
Expand Down

0 comments on commit e2b40c4

Please sign in to comment.