diff --git a/go/vt/discovery/tablet_stats_cache_wait_test.go b/go/vt/discovery/tablet_stats_cache_wait_test.go index 1fec17ced81..ceed880eded 100644 --- a/go/vt/discovery/tablet_stats_cache_wait_test.go +++ b/go/vt/discovery/tablet_stats_cache_wait_test.go @@ -54,7 +54,8 @@ func (a TargetArray) Less(i, j int) bool { func TestFindAllKeyspaceShards(t *testing.T) { ctx := context.Background() ts := memorytopo.NewServer("cell1", "cell2") - flag.Set("srv_topo_cache_ttl", "0s") // No caching values + flag.Set("srv_topo_cache_ttl", "0s") // No caching values + flag.Set("srv_topo_cache_refresh", "0s") // No caching values rs := srvtopo.NewResilientServer(ts, "TestFindAllKeyspaceShards") // No keyspace / shards. diff --git a/go/vt/srvtopo/resilient_server.go b/go/vt/srvtopo/resilient_server.go index ed4548f5452..9f60f328615 100644 --- a/go/vt/srvtopo/resilient_server.go +++ b/go/vt/srvtopo/resilient_server.go @@ -35,7 +35,20 @@ import ( ) var ( - srvTopoCacheTTL = flag.Duration("srv_topo_cache_ttl", 1*time.Second, "how long to use cached entries for topology") + // srvTopoCacheTTL and srvTopoCacheRefresh control the behavior of + // the caching for both watched and unwatched values. + // + // For entries we don't watch (like the list of Keyspaces), we refresh + // the cached list from the topo after srv_topo_cache_refresh elapses. + // If the fetch fails, we hold onto the cached value until + // srv_topo_cache_ttl elapses. + // + // For entries we watch (like the SrvKeyspace for a given cell), if + // setting the watch fails, we will use the last known value until + // srv_topo_cache_ttl elapses and we only try to re-establish the watch + // once every srv_topo_cache_refresh interval. 
 	lastErrorCtx context.Context
+
+	// lastErrorTime records the time that the watch failed, so that
+	// requests within the refresh interval return the cached state
+	lastErrorTime time.Time
 }
 
 // NewResilientServer creates a new ResilientServer
 // based on the provided topo.Server.
 func NewResilientServer(base *topo.Server, counterPrefix string) *ResilientServer {
+	if *srvTopoCacheRefresh > *srvTopoCacheTTL {
+		log.Fatalf("srv_topo_cache_refresh must be less than or equal to srv_topo_cache_ttl")
+	}
+
 	return &ResilientServer{
-		topoServer: base,
-		cacheTTL:   *srvTopoCacheTTL,
-		counts:     stats.NewCounters(counterPrefix + "Counts"),
+		topoServer:   base,
+		cacheTTL:     *srvTopoCacheTTL,
+		cacheRefresh: *srvTopoCacheRefresh,
+		counts:       stats.NewCounters(counterPrefix + "Counts"),
 
 		srvKeyspaceNamesCache: make(map[string]*srvKeyspaceNamesEntry),
 		srvKeyspaceCache:      make(map[string]*srvKeyspaceEntry),
@@ -184,29 +213,45 @@ func (server *ResilientServer) GetSrvKeyspaceNames(ctx context.Context, cell str
 	entry.mutex.Lock()
 	defer entry.mutex.Unlock()
 
-	// If the entry is fresh enough, return it
-	if time.Now().Sub(entry.insertionTime) < server.cacheTTL {
-		return entry.value, entry.lastError
+	// If it is not time to check again, then return either the cached
+	// value or the cached error
+	cacheValid := entry.value != nil && time.Since(entry.insertionTime) < server.cacheTTL
+	shouldRefresh := time.Since(entry.lastQueryTime) > server.cacheRefresh
+
+	if !shouldRefresh {
+		if cacheValid {
+			return entry.value, nil
+		}
+		return nil, entry.lastError
 	}
 
-	// Not in cache or too old, get the real value. We use the context that issued
-	// the query here.
+	// Not in cache or needs refresh so try to get the real value.
+	// We use the context that issued the query here.
result, err := server.topoServer.GetSrvKeyspaceNames(ctx, cell) - if err != nil { + if err == nil { + // save the value we got and the current time in the cache + entry.insertionTime = time.Now() + entry.value = result + } else { if entry.insertionTime.IsZero() { server.counts.Add(errorCategory, 1) log.Errorf("GetSrvKeyspaceNames(%v, %v) failed: %v (no cached value, caching and returning error)", ctx, cell, err) - } else { + + } else if cacheValid { server.counts.Add(cachedCategory, 1) log.Warningf("GetSrvKeyspaceNames(%v, %v) failed: %v (returning cached value: %v %v)", ctx, cell, err, entry.value, entry.lastError) - return entry.value, entry.lastError + result = entry.value + err = nil + } else { + server.counts.Add(errorCategory, 1) + log.Errorf("GetSrvKeyspaceNames(%v, %v) failed: %v (cached value expired)", ctx, cell, err) + entry.insertionTime = time.Time{} + entry.value = nil } } - // save the value we got and the current time in the cache - entry.insertionTime = time.Now() - entry.value = result entry.lastError = err + entry.lastQueryTime = time.Now() entry.lastErrorCtx = ctx return result, err } @@ -259,37 +304,72 @@ func (server *ResilientServer) GetSrvKeyspace(ctx context.Context, cell, keyspac return entry.value, entry.lastError } - // Watch is not running, let's try to start it. + // Watch is not running, but check if the last time we got an error was + // more recent than the refresh interval. + // + // If so return either the last cached value or the last error we got. + cacheValid := entry.value != nil && time.Since(entry.lastValueTime) < server.cacheTTL + shouldRefresh := time.Since(entry.lastErrorTime) > server.cacheRefresh + + if !shouldRefresh { + if cacheValid { + server.counts.Add(cachedCategory, 1) + return entry.value, nil + } + return nil, entry.lastError + } + + // Time to try to start the watch again. // We use a background context, as starting the watch should keep going // even if the current query context is short-lived. 
newCtx := context.Background() - current, changes, _ := server.topoServer.WatchSrvKeyspace(newCtx, cell, keyspace) + current, changes, cancel := server.topoServer.WatchSrvKeyspace(newCtx, cell, keyspace) if current.Err != nil { // lastError and lastErrorCtx will be visible from the UI // until the next try - entry.value = nil entry.lastError = current.Err entry.lastErrorCtx = ctx - log.Errorf("WatchSrvKeyspace failed for %v/%v: %v", cell, keyspace, current.Err) + entry.lastErrorTime = time.Now() + + // if the node disappears, delete the cached value + if current.Err == topo.ErrNoNode { + entry.value = nil + } + + server.counts.Add(errorCategory, 1) + log.Errorf("Initial WatchSrvKeyspace failed for %v/%v: %v", cell, keyspace, current.Err) + + if cacheValid { + return entry.value, nil + } + return nil, current.Err } // we are now watching, cache the first notification entry.watchRunning = true entry.value = current.Value + entry.lastValueTime = time.Now() entry.lastError = nil entry.lastErrorCtx = nil - go func() { + defer cancel() + for c := range changes { if c.Err != nil { - // Watch errored out. We log it, clear - // our record, and return. - err := fmt.Errorf("watch for SrvKeyspace %v in cell %v failed: %v", keyspace, cell, c.Err) + // Watch errored out. + // + // Log it and store the error, but do not clear the value + // so it can be used until the ttl elapses unless the node + // was deleted. + err := fmt.Errorf("WatchSrvKeyspace failed for %v/%v: %v", cell, keyspace, c.Err) log.Errorf("%v", err) + server.counts.Add(errorCategory, 1) entry.mutex.Lock() + if c.Err == topo.ErrNoNode { + entry.value = nil + } entry.watchRunning = false - entry.value = nil entry.lastError = err entry.lastErrorCtx = nil entry.mutex.Unlock() @@ -299,6 +379,7 @@ func (server *ResilientServer) GetSrvKeyspace(ctx context.Context, cell, keyspac // We got a new value, save it. 
entry.mutex.Lock() entry.value = c.Value + entry.lastValueTime = time.Now() entry.lastError = nil entry.lastErrorCtx = nil entry.mutex.Unlock() diff --git a/go/vt/srvtopo/resilient_server_test.go b/go/vt/srvtopo/resilient_server_flaky_test.go similarity index 55% rename from go/vt/srvtopo/resilient_server_test.go rename to go/vt/srvtopo/resilient_server_flaky_test.go index 6e93cdfbbbd..ad71e31d548 100644 --- a/go/vt/srvtopo/resilient_server_test.go +++ b/go/vt/srvtopo/resilient_server_flaky_test.go @@ -18,7 +18,9 @@ package srvtopo import ( "bytes" + "fmt" "html/template" + "reflect" "sync" "testing" "time" @@ -35,7 +37,14 @@ import ( // TestGetSrvKeyspace will test we properly return updated SrvKeyspace. func TestGetSrvKeyspace(t *testing.T) { - ts := memorytopo.NewServer("test_cell") + ts, factory := memorytopo.NewServerAndFactory("test_cell") + *srvTopoCacheTTL = time.Duration(100 * time.Millisecond) + *srvTopoCacheRefresh = time.Duration(40 * time.Millisecond) + defer func() { + *srvTopoCacheTTL = 1 * time.Second + *srvTopoCacheRefresh = 1 * time.Second + }() + rs := NewResilientServer(ts, "TestGetSrvKeyspace") // Ask for a not-yet-created keyspace @@ -44,6 +53,9 @@ func TestGetSrvKeyspace(t *testing.T) { t.Fatalf("GetSrvKeyspace(not created) got unexpected error: %v", err) } + // Wait until the cached error expires. 
+ time.Sleep(*srvTopoCacheRefresh + 10*time.Millisecond) + // Set SrvKeyspace with value want := &topodatapb.SrvKeyspace{ ShardingColumnName: "id", @@ -56,6 +68,7 @@ func TestGetSrvKeyspace(t *testing.T) { expiry := time.Now().Add(5 * time.Second) for { got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks") + if err != nil { t.Fatalf("GetSrvKeyspace got unexpected error: %v", err) } @@ -65,7 +78,18 @@ func TestGetSrvKeyspace(t *testing.T) { if time.Now().After(expiry) { t.Fatalf("GetSrvKeyspace() timeout = %+v, want %+v", got, want) } - time.Sleep(10 * time.Millisecond) + time.Sleep(2 * time.Millisecond) + } + + // make sure the HTML template works + templ := template.New("").Funcs(status.StatusFuncs) + templ, err = templ.Parse(TopoTemplate) + if err != nil { + t.Fatalf("error parsing template: %v", err) + } + wr := &bytes.Buffer{} + if err := templ.Execute(wr, rs.CacheStatus()); err != nil { + t.Fatalf("error executing template: %v", err) } // Now delete the SrvKeyspace, wait until we get the error. @@ -89,8 +113,10 @@ func TestGetSrvKeyspace(t *testing.T) { ShardingColumnName: "id2", ShardingColumnType: topodatapb.KeyspaceIdType_UINT64, } + ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want) expiry = time.Now().Add(5 * time.Second) + updateTime := time.Now() for { got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks") if err == nil && proto.Equal(want, got) { @@ -102,15 +128,88 @@ func TestGetSrvKeyspace(t *testing.T) { time.Sleep(time.Millisecond) } - // make sure the HTML template works - templ := template.New("").Funcs(status.StatusFuncs) - templ, err = templ.Parse(TopoTemplate) - if err != nil { - t.Fatalf("error parsing template: %v", err) + // Now simulate a topo service error and see that the last value is + // cached for at least half of the expected ttl. 
+ errorReqsBefore, _ := rs.counts.Counts()[errorCategory] + forceErr := fmt.Errorf("test topo error") + factory.SetError(forceErr) + + expiry = updateTime.Add(*srvTopoCacheTTL / 2) + for { + got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks") + if err != nil || !proto.Equal(want, got) { + // On a slow test machine it is possible that we never end up + // verifying the value is cached because it could take too long to + // even get into this loop... so log this as an informative message + // but don't fail the test + if time.Now().After(expiry) { + t.Logf("test execution was too slow -- caching was not verified") + break + } + + t.Errorf("expected keyspace to be cached for at least %s seconds, got error %v", time.Since(updateTime), err) + } + + if time.Now().After(expiry) { + break + } + + time.Sleep(time.Millisecond) } - wr := &bytes.Buffer{} - if err := templ.Execute(wr, rs.CacheStatus()); err != nil { - t.Fatalf("error executing template: %v", err) + + // Now wait for the TTL to expire and we should get the expected error + expiry = time.Now().Add(1 * time.Second) + for { + _, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks") + if err != nil || err == forceErr { + break + } + + if time.Now().After(expiry) { + t.Fatalf("timed out waiting for error to be returned") + } + time.Sleep(time.Millisecond) + } + + // Clear the error away and check that the cached error is still returned + // until the refresh interval elapses + factory.SetError(nil) + _, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks") + if err == nil || err != forceErr { + t.Fatalf("expected error to be cached") + } + + // Now sleep for the rest of the interval and we should get the value again + time.Sleep(*srvTopoCacheRefresh) + got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks") + if err != nil || !proto.Equal(want, got) { + t.Errorf("expected value to be restored, got %v", err) + } + + // Check that there were 
four errors counted during the interval,
+	// one for the original watch failing, then three more attempts to
+	// re-establish the watch
+	errorReqs, _ := rs.counts.Counts()[errorCategory]
+	if errorReqs-errorReqsBefore != 4 {
+		t.Errorf("expected 4 error requests got %d", errorReqs-errorReqsBefore)
+	}
+
+	// Check that the watch now works to update the value
+	want = &topodatapb.SrvKeyspace{
+		ShardingColumnName: "id3",
+		ShardingColumnType: topodatapb.KeyspaceIdType_UINT64,
+	}
+	ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
+	expiry = time.Now().Add(5 * time.Second)
+	for {
+		got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
+		if err == nil && proto.Equal(want, got) {
+			break
+		}
+		if time.Now().After(expiry) {
+			t.Fatalf("timeout waiting for new keyspace value")
+		}
+		time.Sleep(time.Millisecond)
+	}
 }
@@ -118,6 +217,12 @@ func TestGetSrvKeyspace(t *testing.T) {
 // the topo server upon failure.
 func TestSrvKeyspaceCachedError(t *testing.T) {
 	ts := memorytopo.NewServer("test_cell")
+	*srvTopoCacheTTL = 100 * time.Millisecond
+	*srvTopoCacheRefresh = 40 * time.Millisecond
+	defer func() {
+		*srvTopoCacheTTL = 1 * time.Second
+		*srvTopoCacheRefresh = 1 * time.Second
+	}()
 	rs := NewResilientServer(ts, "TestSrvKeyspaceCachedErrors")
 
 	// Ask for an unknown keyspace, should get an error.
@@ -134,6 +239,7 @@ func TestSrvKeyspaceCachedError(t *testing.T) {
 		t.Errorf("Context wasn't saved properly")
 	}
 
+	time.Sleep(*srvTopoCacheTTL + 10*time.Millisecond)
 	// Ask again with a different context, should get an error and
 	// save that context.
ctx, cancel := context.WithCancel(ctx) @@ -186,10 +292,10 @@ func TestGetSrvKeyspaceCreated(t *testing.T) { } func TestWatchSrvVSchema(t *testing.T) { + watchSrvVSchemaSleepTime = 10 * time.Millisecond ctx := context.Background() ts := memorytopo.NewServer("test_cell") rs := NewResilientServer(ts, "TestWatchSrvVSchema") - watchSrvVSchemaSleepTime = 10 * time.Millisecond // mu protects watchValue and watchErr. mu := sync.Mutex{} @@ -268,3 +374,107 @@ func TestWatchSrvVSchema(t *testing.T) { time.Sleep(10 * time.Millisecond) } } + +func TestGetSrvKeyspaceNames(t *testing.T) { + ts, factory := memorytopo.NewServerAndFactory("test_cell") + *srvTopoCacheTTL = 100 * time.Millisecond + *srvTopoCacheRefresh = 40 * time.Millisecond + defer func() { + *srvTopoCacheTTL = 1 * time.Second + *srvTopoCacheRefresh = 1 * time.Second + }() + rs := NewResilientServer(ts, "TestGetSrvKeyspaceNames") + + // Set SrvKeyspace with value + want := &topodatapb.SrvKeyspace{ + ShardingColumnName: "id", + ShardingColumnType: topodatapb.KeyspaceIdType_UINT64, + } + ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want) + ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks2", want) + + ctx := context.Background() + names, err := rs.GetSrvKeyspaceNames(ctx, "test_cell") + if err != nil { + t.Errorf("GetSrvKeyspaceNames unexpected error %v", err) + } + wantNames := []string{"test_ks", "test_ks2"} + + if !reflect.DeepEqual(names, wantNames) { + t.Errorf("GetSrvKeyspaceNames got %v want %v", names, wantNames) + } + + forceErr := fmt.Errorf("force test error") + factory.SetError(forceErr) + + // Check that we get the cached value until at least the refresh interval + // elapses but before the TTL expires + start := time.Now() + for { + names, err = rs.GetSrvKeyspaceNames(ctx, "test_cell") + if err != nil { + t.Errorf("GetSrvKeyspaceNames unexpected error %v", err) + } + + if !reflect.DeepEqual(names, wantNames) { + t.Errorf("GetSrvKeyspaceNames got %v want %v", names, 
wantNames) + } + + if time.Since(start) >= *srvTopoCacheRefresh+10*time.Millisecond { + break + } + + time.Sleep(time.Millisecond) + } + + // Now wait for it to expire from cache + for { + _, err = rs.GetSrvKeyspaceNames(ctx, "test_cell") + if err != nil { + break + } + + time.Sleep(2 * time.Millisecond) + + if time.Since(start) > 2*time.Second { + t.Fatalf("expected error after TTL expires") + } + } + + if err != forceErr { + t.Errorf("got error %v want %v", err, forceErr) + } + + // Check that we only checked the topo service 1 or 2 times during the + // period where we got the cached error. + cachedReqs, ok := rs.counts.Counts()[cachedCategory] + if !ok || cachedReqs > 2 { + t.Errorf("expected <= 2 cached requests got %v", cachedReqs) + } + + // Clear the error and wait until the cached error state expires + factory.SetError(nil) + + start = time.Now() + for { + names, err = rs.GetSrvKeyspaceNames(ctx, "test_cell") + if err == nil { + break + } + + time.Sleep(2 * time.Millisecond) + + if time.Since(start) > 2*time.Second { + t.Fatalf("expected error after TTL expires") + } + } + + if !reflect.DeepEqual(names, wantNames) { + t.Errorf("GetSrvKeyspaceNames got %v want %v", names, wantNames) + } + + errorReqs, ok := rs.counts.Counts()[errorCategory] + if !ok || errorReqs != 1 { + t.Errorf("expected 1 error request got %v", errorReqs) + } +} diff --git a/go/vt/topo/memorytopo/directory.go b/go/vt/topo/memorytopo/directory.go index 2657f8f314c..b6884ce940c 100644 --- a/go/vt/topo/memorytopo/directory.go +++ b/go/vt/topo/memorytopo/directory.go @@ -29,6 +29,10 @@ func (c *Conn) ListDir(ctx context.Context, dirPath string, full bool) ([]topo.D c.factory.mu.Lock() defer c.factory.mu.Unlock() + if c.factory.err != nil { + return nil, c.factory.err + } + isRoot := false if dirPath == "" || dirPath == "/" { isRoot = true diff --git a/go/vt/topo/memorytopo/file.go b/go/vt/topo/memorytopo/file.go index 3cb04530393..e8514cebf07 100644 --- a/go/vt/topo/memorytopo/file.go +++ 
b/go/vt/topo/memorytopo/file.go @@ -34,6 +34,10 @@ func (c *Conn) Create(ctx context.Context, filePath string, contents []byte) (to c.factory.mu.Lock() defer c.factory.mu.Unlock() + if c.factory.err != nil { + return nil, c.factory.err + } + // Get the parent dir. dir, file := path.Split(filePath) p := c.factory.getOrCreatePath(c.cell, dir) @@ -61,6 +65,10 @@ func (c *Conn) Update(ctx context.Context, filePath string, contents []byte, ver c.factory.mu.Lock() defer c.factory.mu.Unlock() + if c.factory.err != nil { + return nil, c.factory.err + } + // Get the parent dir, we'll need it in case of creation. dir, file := path.Split(filePath) p := c.factory.nodeByPath(c.cell, dir) @@ -117,6 +125,10 @@ func (c *Conn) Get(ctx context.Context, filePath string) ([]byte, topo.Version, c.factory.mu.Lock() defer c.factory.mu.Unlock() + if c.factory.err != nil { + return nil, nil, c.factory.err + } + // Get the node. n := c.factory.nodeByPath(c.cell, filePath) if n == nil { @@ -134,6 +146,10 @@ func (c *Conn) Delete(ctx context.Context, filePath string, version topo.Version c.factory.mu.Lock() defer c.factory.mu.Unlock() + if c.factory.err != nil { + return c.factory.err + } + // Get the parent dir. 
dir, file := path.Split(filePath) p := c.factory.nodeByPath(c.cell, dir) diff --git a/go/vt/topo/memorytopo/lock.go b/go/vt/topo/memorytopo/lock.go index d2a8306be5b..507ca704da5 100644 --- a/go/vt/topo/memorytopo/lock.go +++ b/go/vt/topo/memorytopo/lock.go @@ -46,6 +46,11 @@ func (c *Conn) Lock(ctx context.Context, dirPath, contents string) (topo.LockDes for { c.factory.mu.Lock() + if c.factory.err != nil { + c.factory.mu.Unlock() + return nil, c.factory.err + } + n := c.factory.nodeByPath(c.cell, dirPath) if n == nil { c.factory.mu.Unlock() diff --git a/go/vt/topo/memorytopo/memorytopo.go b/go/vt/topo/memorytopo/memorytopo.go index be12c5ca197..d9c93a47728 100644 --- a/go/vt/topo/memorytopo/memorytopo.go +++ b/go/vt/topo/memorytopo/memorytopo.go @@ -61,6 +61,9 @@ type Factory struct { // version at 1. It is initialized with a random number, // so if we have two implementations, the numbers won't match. generation uint64 + // err is used for testing purposes to force queries / watches + // to return the given error + err error } // HasGlobalReadOnlyCell is part of the topo.Factory interface. @@ -81,6 +84,20 @@ func (f *Factory) Create(cell, serverAddr, root string) (topo.Conn, error) { }, nil } +// SetError forces the given error to be returned from all calls and propagates +// the error to all active watches. +func (f *Factory) SetError(err error) { + f.mu.Lock() + defer f.mu.Unlock() + + f.err = err + if err != nil { + for _, node := range f.cells { + node.PropagateWatchError(err) + } + } +} + // Conn implements the topo.Conn interface. It remembers the cell, and // points at the Factory that has all the data. type Conn struct { @@ -123,10 +140,24 @@ func (n *node) isDirectory() bool { return n.children != nil } -// NewServer returns a new MemoryTopo for all the cells. It will create one -// cell for each parameter passed in. It will log.Exit out in case -// of a problem. 
-func NewServer(cells ...string) *topo.Server { +// PropagateWatchError propagates the given error to all watches on this node +// and recursively applies to all children +func (n *node) PropagateWatchError(err error) { + for _, ch := range n.watches { + ch <- &topo.WatchData{ + Err: err, + } + } + + for _, c := range n.children { + c.PropagateWatchError(err) + } +} + +// NewServerAndFactory returns a new MemoryTopo and the backing factory for all +// the cells. It will create one cell for each parameter passed in. It will log.Exit out +// in case of a problem. +func NewServerAndFactory(cells ...string) (*topo.Server, *Factory) { f := &Factory{ cells: make(map[string]*node), generation: uint64(rand.Int63n(2 ^ 60)), @@ -144,7 +175,13 @@ func NewServer(cells ...string) *topo.Server { log.Exitf("ts.CreateCellInfo(%v) failed: %v", cell, err) } } - return ts + return ts, f +} + +// NewServer returns the new server +func NewServer(cells ...string) *topo.Server { + server, _ := NewServerAndFactory(cells...) 
+ return server } func (f *Factory) getNextVersion() uint64 { diff --git a/go/vt/topo/memorytopo/watch.go b/go/vt/topo/memorytopo/watch.go index e150939b233..a8e27ef7986 100644 --- a/go/vt/topo/memorytopo/watch.go +++ b/go/vt/topo/memorytopo/watch.go @@ -29,6 +29,10 @@ func (c *Conn) Watch(ctx context.Context, filePath string) (*topo.WatchData, <-c c.factory.mu.Lock() defer c.factory.mu.Unlock() + if c.factory.err != nil { + return &topo.WatchData{Err: c.factory.err}, nil, nil + } + n := c.factory.nodeByPath(c.cell, filePath) if n == nil { return &topo.WatchData{Err: topo.ErrNoNode}, nil, nil diff --git a/test/utils.py b/test/utils.py index fa7cecce373..cfe566b6d63 100644 --- a/test/utils.py +++ b/test/utils.py @@ -550,6 +550,7 @@ def start(self, cell='test_nj', retry_count=2, '-retry-count', str(retry_count), '-log_dir', environment.vtlogroot, '-srv_topo_cache_ttl', cache_ttl, + '-srv_topo_cache_refresh', cache_ttl, '-tablet_protocol', protocols_flavor().tabletconn_protocol(), '-stderrthreshold', get_log_level(), '-normalize_queries', @@ -799,6 +800,7 @@ def start(self, cell='test_nj', retry_count=2, '-retry-count', str(retry_count), '-log_dir', environment.vtlogroot, '-srv_topo_cache_ttl', cache_ttl, + '-srv_topo_cache_refresh', cache_ttl, '-tablet_protocol', protocols_flavor().tabletconn_protocol(), '-gateway_implementation', vtgate_gateway_flavor().flavor(), ]