diff --git a/go/vt/vttablet/tabletmanager/vreplication/vstreamer_client.go b/go/vt/vttablet/tabletmanager/vreplication/vstreamer_client.go index 309d901ef55..87e1944e70d 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vstreamer_client.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vstreamer_client.go @@ -62,7 +62,8 @@ type TabletVStreamerClient struct { // mu protects isOpen, streamers, streamIdx and kschema. mu sync.Mutex - isOpen bool + isOpen bool + openConnection int32 tablet *topodatapb.Tablet target *querypb.Target @@ -74,7 +75,8 @@ type MySQLVStreamerClient struct { // mu protects isOpen, streamers, streamIdx and kschema. mu sync.Mutex - isOpen bool + isOpen bool + openConnection int32 sourceConnParams dbconfigs.Connector sourceSe *schema.Engine @@ -96,11 +98,11 @@ func NewTabletVStreamerClient(tablet *topodatapb.Tablet) *TabletVStreamerClient func (vsClient *TabletVStreamerClient) Open(ctx context.Context) (err error) { vsClient.mu.Lock() defer vsClient.mu.Unlock() - if vsClient.isOpen { + vsClient.openConnection++ + if vsClient.openConnection > 1 { return nil } vsClient.isOpen = true - vsClient.tsQueryService, err = tabletconn.GetDialer()(vsClient.tablet, grpcclient.FailFast(true)) return err } @@ -109,7 +111,10 @@ func (vsClient *TabletVStreamerClient) Open(ctx context.Context) (err error) { func (vsClient *TabletVStreamerClient) Close(ctx context.Context) (err error) { vsClient.mu.Lock() defer vsClient.mu.Unlock() - if !vsClient.isOpen { + if vsClient.openConnection > 0 { + vsClient.openConnection-- + } + if vsClient.openConnection > 0 { return nil } vsClient.isOpen = false @@ -150,10 +155,10 @@ func NewMySQLVStreamerClient() *MySQLVStreamerClient { func (vsClient *MySQLVStreamerClient) Open(ctx context.Context) (err error) { vsClient.mu.Lock() defer vsClient.mu.Unlock() - if vsClient.isOpen { + vsClient.openConnection++ + if vsClient.openConnection > 1 { return nil } - vsClient.isOpen = true // Let's create all the required components by vstreamer @@ -171,10 +176,12 @@ func (vsClient *MySQLVStreamerClient) Open(ctx context.Context) (err error) { func (vsClient *MySQLVStreamerClient) Close(ctx context.Context) (err error) { vsClient.mu.Lock() defer vsClient.mu.Unlock() - if !vsClient.isOpen { + if vsClient.openConnection > 0 { + vsClient.openConnection-- + } + if vsClient.openConnection > 0 { return nil } - vsClient.isOpen = false vsClient.sourceSe.Close() return nil diff --git a/go/vt/vttablet/tabletmanager/vreplication/vstreamer_client_test.go b/go/vt/vttablet/tabletmanager/vreplication/vstreamer_client_test.go index af3e762383b..3947697f566 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vstreamer_client_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vstreamer_client_test.go @@ -142,6 +142,75 @@ func TestTabletVStreamerClientClose(t *testing.T) { } } +func TestTabletVStreamerClientCloseTwice(t *testing.T) { + tablet := addTablet(100) + defer deleteTablet(tablet) + + type fields struct { + tablet *topodatapb.Tablet + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + err string + }{ + { + name: "closes engine correctly", + fields: fields{ + tablet: tablet, + }, + args: args{ + ctx: context.Background(), + }, + }, + } + + for _, tcase := range tests { + t.Run(tcase.name, func(t *testing.T) { + vsClient := &TabletVStreamerClient{ + tablet: tcase.fields.tablet, + } + + err := vsClient.Open(tcase.args.ctx) + if err != nil { + t.Errorf("Failed to Open vsClient") + return + } + + // open again + err = vsClient.Open(tcase.args.ctx) + if err != nil { + t.Errorf("Failed to Open vsClient") + return + } + + err = vsClient.Close(tcase.args.ctx) + + if tcase.err != "" { + t.Errorf("MySQLVStreamerClient.Close() error:\n%v, want\n%v", err, tcase.err) + } + + if !vsClient.isOpen { + t.Errorf("MySQLVStreamerClient.Close() should not close the connection opened by other") + } + + err = vsClient.Close(tcase.args.ctx) + + if tcase.err != "" { + t.Errorf("MySQLVStreamerClient.Close() error:\n%v, want\n%v", err, tcase.err) + } + + if vsClient.isOpen { + t.Errorf("MySQLVStreamerClient.Close() isOpen set to true, expected false") + } + }) + } +} + func TestTabletVStreamerClientVStream(t *testing.T) { tablet := addTablet(100) defer deleteTablet(tablet) @@ -410,6 +479,83 @@ func TestMySQLVStreamerClientClose(t *testing.T) { } } +func TestMySQLVStreamerClientCloseTwice(t *testing.T) { + type fields struct { + isOpen bool + sourceConnParams dbconfigs.Connector + } + type args struct { + ctx context.Context + } + + tests := []struct { + name string + fields fields + args args + err string + }{ + { + name: "closes engine correctly", + fields: fields{ + sourceConnParams: dbcfgs.ExternalReplWithDB(), + }, + args: args{ + ctx: context.Background(), + }, + }, + } + + for _, tcase := range tests { + t.Run(tcase.name, func(t *testing.T) { + vsClient := &MySQLVStreamerClient{ + isOpen: tcase.fields.isOpen, + sourceConnParams: tcase.fields.sourceConnParams, + } + + err := vsClient.Open(tcase.args.ctx) + if err != nil { + t.Errorf("Failed to Open vsClient") + return + } + + // open again + err = vsClient.Open(tcase.args.ctx) + if err != nil { + t.Errorf("Failed to Open vsClient") + return + } + + err = vsClient.Close(tcase.args.ctx) + + if tcase.err != "" { + t.Errorf("MySQLVStreamerClient.Close() error:\n%v, want\n%v", err, tcase.err) + } + + if vsClient.isOpen { + t.Errorf("MySQLVStreamerClient.Close() should not close the connection opened by other") + } + + if !vsClient.sourceSe.IsOpen() { + t.Errorf("MySQLVStreamerClient.Close() expected sourceSe not to be closed") + } + + err = vsClient.Close(tcase.args.ctx) + + if tcase.err != "" { + t.Errorf("MySQLVStreamerClient.Close() error:\n%v, want\n%v", err, tcase.err) + } + + if vsClient.isOpen { + t.Errorf("MySQLVStreamerClient.Close() isOpen set to true, expected false") + } + + if vsClient.sourceSe.IsOpen() { + t.Errorf("MySQLVStreamerClient.Close() expected sourceSe to be closed") + } + }) + } +} + func TestMySQLVStreamerClientVStream(t *testing.T) { vsClient := &MySQLVStreamerClient{ sourceConnParams: dbcfgs.ExternalReplWithDB(),