Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions go/vt/topo/etcd2topo/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ import (

// Create is part of the topo.Conn interface.
func (s *Server) Create(ctx context.Context, filePath string, contents []byte) (topo.Version, error) {
if err := s.checkClosed(); err != nil {
return nil, convertError(err, filePath)
}
nodePath := path.Join(s.root, filePath)

// We have to do a transaction, comparing existing version with 0.
Expand All @@ -47,6 +50,9 @@ func (s *Server) Create(ctx context.Context, filePath string, contents []byte) (

// Update is part of the topo.Conn interface.
func (s *Server) Update(ctx context.Context, filePath string, contents []byte, version topo.Version) (topo.Version, error) {
if err := s.checkClosed(); err != nil {
return nil, convertError(err, filePath)
}
nodePath := path.Join(s.root, filePath)

if version != nil {
Expand Down Expand Up @@ -75,6 +81,9 @@ func (s *Server) Update(ctx context.Context, filePath string, contents []byte, v

// Get is part of the topo.Conn interface.
func (s *Server) Get(ctx context.Context, filePath string) ([]byte, topo.Version, error) {
if err := s.checkClosed(); err != nil {
return nil, nil, convertError(err, filePath)
}
nodePath := path.Join(s.root, filePath)

resp, err := s.cli.Get(ctx, nodePath)
Expand All @@ -90,6 +99,9 @@ func (s *Server) Get(ctx context.Context, filePath string) ([]byte, topo.Version

// GetVersion is part of the topo.Conn interface.
func (s *Server) GetVersion(ctx context.Context, filePath string, version int64) ([]byte, error) {
if err := s.checkClosed(); err != nil {
return nil, convertError(err, filePath)
}
nodePath := path.Join(s.root, filePath)

resp, err := s.cli.Get(ctx, nodePath, clientv3.WithRev(version))
Expand All @@ -105,6 +117,9 @@ func (s *Server) GetVersion(ctx context.Context, filePath string, version int64)

// List is part of the topo.Conn interface.
func (s *Server) List(ctx context.Context, filePathPrefix string) ([]topo.KVInfo, error) {
if err := s.checkClosed(); err != nil {
return []topo.KVInfo{}, convertError(err, filePathPrefix)
}
nodePathPrefix := path.Join(s.root, filePathPrefix)

resp, err := s.cli.Get(ctx, nodePathPrefix, clientv3.WithPrefix())
Expand All @@ -127,6 +142,9 @@ func (s *Server) List(ctx context.Context, filePathPrefix string) ([]topo.KVInfo

// Delete is part of the topo.Conn interface.
func (s *Server) Delete(ctx context.Context, filePath string, version topo.Version) error {
if err := s.checkClosed(); err != nil {
return convertError(err, filePath)
}
nodePath := path.Join(s.root, filePath)

if version != nil {
Expand Down
11 changes: 11 additions & 0 deletions go/vt/topo/etcd2topo/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ We follow these conventions within this package:
package etcd2topo

import (
"context"
"crypto/tls"
"crypto/x509"
"strings"
Expand Down Expand Up @@ -93,6 +94,16 @@ func registerEtcd2TopoFlags(fs *pflag.FlagSet) {
utils.SetFlagStringVar(fs, &serverCaPath, "topo-etcd-tls-ca", serverCaPath, "path to the ca to use to validate the server cert when connecting to the etcd topo server")
}

// checkClosed returns context.Canceled if the server has been closed.
// This mimics the pattern used for context cancellation which gets converted
// to topo.Interrupted by convertError().
func (s *Server) checkClosed() error {
if s.cli == nil {
return context.Canceled
}
return nil
}

// Close implements topo.Server.Close.
// It will nil out the global and cells fields, so any attempt to
// re-use this server will panic.
Expand Down
69 changes: 69 additions & 0 deletions go/vt/topo/etcd2topo/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,75 @@ func TestEtcd2TopoGetTabletsPartialResults(t *testing.T) {
}
}

// TestEtcd2TopoServerClosed tests that operations on a closed server return
// appropriate errors instead of panicking due to nil pointer dereference.
func TestEtcd2TopoServerClosed(t *testing.T) {
// Start a single etcd in the background.
clientAddr, _ := startEtcd(t, 0)

testRoot := "/test-closed"

// Create the server on the new root.
ts, err := topo.OpenServer("etcd2", clientAddr, path.Join(testRoot, topo.GlobalCell))
require.NoError(t, err, "OpenServer() failed: %v", err)

// Create the CellInfo first.
ctx := context.Background()
err = ts.CreateCellInfo(ctx, "test_cell", &topodatapb.CellInfo{
ServerAddress: clientAddr,
Root: path.Join(testRoot, "test_cell"),
})
require.NoError(t, err, "CreateCellInfo() failed: %v", err)

// Get the connection for the cell
conn, err := ts.ConnForCell(ctx, "test_cell")
require.NoError(t, err, "ConnForCell() failed: %v", err)

// Test that operations work before closing
testPath := "test_key"
testContents := []byte("test_value")

_, err = conn.Create(ctx, testPath, testContents)
require.NoError(t, err, "Create() before close should succeed")

// Close the connection
ts.Close()

// Test that operations return appropriate errors after closing
_, err = conn.Create(ctx, "another_key", testContents)
require.Error(t, err, "Create() after close should fail")
require.True(t, topo.IsErrType(err, topo.Interrupted), "Error should be topo.Interrupted, got: %v", err)

_, _, err = conn.Get(ctx, testPath)
require.Error(t, err, "Get() after close should fail")
require.True(t, topo.IsErrType(err, topo.Interrupted), "Error should be topo.Interrupted, got: %v", err)

_, err = conn.GetVersion(ctx, testPath, 1)
require.Error(t, err, "GetVersion() after close should fail")
require.True(t, topo.IsErrType(err, topo.Interrupted), "Error should be topo.Interrupted, got: %v", err)

err = conn.Delete(ctx, testPath, nil)
require.Error(t, err, "Delete() after close should fail")
require.True(t, topo.IsErrType(err, topo.Interrupted), "Error should be topo.Interrupted, got: %v", err)

_, err = conn.List(ctx, "/")
require.Error(t, err, "List() after close should fail")
require.True(t, topo.IsErrType(err, topo.Interrupted), "Error should be topo.Interrupted, got: %v", err)

_, err = conn.Update(ctx, testPath, testContents, nil)
require.Error(t, err, "Update() after close should fail")
require.True(t, topo.IsErrType(err, topo.Interrupted), "Error should be topo.Interrupted, got: %v", err)

// Test watch operations after close
_, _, err = conn.Watch(ctx, testPath)
require.Error(t, err, "Watch() after close should fail")
require.True(t, topo.IsErrType(err, topo.Interrupted), "Error should be topo.Interrupted, got: %v", err)

_, _, err = conn.WatchRecursive(ctx, "/")
require.Error(t, err, "WatchRecursive() after close should fail")
require.True(t, topo.IsErrType(err, topo.Interrupted), "Error should be topo.Interrupted, got: %v", err)
}

// testKeyspaceLock tests etcd-specific heartbeat (TTL).
// Note TTL granularity is in seconds, even though the API uses time.Duration.
// So we have to wait a long time in these tests.
Expand Down
6 changes: 6 additions & 0 deletions go/vt/topo/etcd2topo/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ import (

// Watch is part of the topo.Conn interface.
func (s *Server) Watch(ctx context.Context, filePath string) (*topo.WatchData, <-chan *topo.WatchData, error) {
if err := s.checkClosed(); err != nil {
return nil, nil, convertError(err, filePath)
}
nodePath := path.Join(s.root, filePath)

// Get the initial version of the file
Expand Down Expand Up @@ -160,6 +163,9 @@ func (s *Server) Watch(ctx context.Context, filePath string) (*topo.WatchData, <

// WatchRecursive is part of the topo.Conn interface.
func (s *Server) WatchRecursive(ctx context.Context, dirpath string) ([]*topo.WatchDataRecursive, <-chan *topo.WatchDataRecursive, error) {
if err := s.checkClosed(); err != nil {
return nil, nil, convertError(err, dirpath)
}
nodePath := path.Join(s.root, dirpath)
if !strings.HasSuffix(nodePath, "/") {
nodePath = nodePath + "/"
Expand Down
Loading