Skip to content
37 changes: 37 additions & 0 deletions go/vt/vtorc/inst/shard_dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,30 @@ import (
// ErrShardNotFound is a fixed error message used when a shard is not found in the database.
var ErrShardNotFound = errors.New("shard not found")

// ReadShardNames reads the names of vitess shards for a single keyspace.
func ReadShardNames(keyspaceName string) (shardNames []string, err error) {
shardNames = make([]string, 0)
query := `select shard from vitess_shard where keyspace = ?`
args := sqlutils.Args(keyspaceName)
err = db.QueryVTOrc(query, args, func(row sqlutils.RowMap) error {
shardNames = append(shardNames, row.GetString("shard"))
return nil
})
return shardNames, err
}

// ReadAllShardNames reads the names of all vitess shards by keyspace.
func ReadAllShardNames() (shardNames map[string][]string, err error) {
shardNames = make(map[string][]string)
query := `select keyspace, shard from vitess_shard`
err = db.QueryVTOrc(query, nil, func(row sqlutils.RowMap) error {
ks := row.GetString("keyspace")
shardNames[ks] = append(shardNames[ks], row.GetString("shard"))
return nil
})
return shardNames, err
}

// ReadShardPrimaryInformation reads the vitess shard record and gets the shard primary alias and timestamp.
func ReadShardPrimaryInformation(keyspaceName, shardName string) (primaryAlias string, primaryTimestamp string, err error) {
if err = topo.ValidateKeyspaceName(keyspaceName); err != nil {
Expand Down Expand Up @@ -95,3 +119,16 @@ func getShardPrimaryTermStartTimeString(shard *topo.ShardInfo) string {
}
return protoutil.TimeFromProto(shard.PrimaryTermStartTime).UTC().String()
}

// DeleteShard deletes a shard using a keyspace and shard name.
func DeleteShard(keyspace, shard string) error {
_, err := db.ExecVTOrc(`DELETE FROM
vitess_shard
WHERE
keyspace = ?
AND shard = ?`,
keyspace,
shard,
)
return err
}
20 changes: 19 additions & 1 deletion go/vt/vtorc/inst/shard_dao_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
"vitess.io/vitess/go/vt/vtorc/db"
)

func TestSaveAndReadShard(t *testing.T) {
func TestSaveReadAndDeleteShard(t *testing.T) {
// Clear the database after the test. The easiest way to do that is to run all the initialization commands again.
defer func() {
db.ClearVTOrcDatabase()
Expand Down Expand Up @@ -93,6 +93,7 @@ func TestSaveAndReadShard(t *testing.T) {
require.NoError(t, err)
}

// ReadShardPrimaryInformation
shardPrimaryAlias, primaryTimestamp, err := ReadShardPrimaryInformation(tt.keyspaceName, tt.shardName)
if tt.err != "" {
require.EqualError(t, err, tt.err)
Expand All @@ -101,6 +102,23 @@ func TestSaveAndReadShard(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, tt.primaryAliasWanted, shardPrimaryAlias)
require.EqualValues(t, tt.primaryTimestampWanted, primaryTimestamp)

// ReadShardNames
shardNames, err := ReadShardNames(tt.keyspaceName)
require.NoError(t, err)
require.Equal(t, []string{tt.shardName}, shardNames)

// ReadAllShardNames
allShardNames, err := ReadAllShardNames()
require.NoError(t, err)
ksShards, found := allShardNames[tt.keyspaceName]
require.True(t, found)
require.Equal(t, []string{tt.shardName}, ksShards)

// DeleteShard
require.NoError(t, DeleteShard(tt.keyspaceName, tt.shardName))
_, _, err = ReadShardPrimaryInformation(tt.keyspaceName, tt.shardName)
require.EqualError(t, err, ErrShardNotFound.Error())
})
}
}
82 changes: 79 additions & 3 deletions go/vt/vtorc/logic/keyspace_shard_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,44 @@ import (

"golang.org/x/exp/maps"

"vitess.io/vitess/go/stats"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/log"

"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/vtorc/inst"
)

var statsShardsWatched = stats.NewGaugesFuncWithMultiLabels("ShardsWatched",
"Keyspace/shards currently watched",
[]string{"Keyspace", "Shard"},
getShardsWatchedStats)

// getShardsWatchedStats returns the keyspace/shards watched in a format for stats.
func getShardsWatchedStats() map[string]int64 {
shardsWatched := make(map[string]int64)
allShardNames, err := inst.ReadAllShardNames()
if err != nil {
log.Errorf("Failed to read all shard names: %+v", err)
return shardsWatched
}
for ks, shards := range allShardNames {
for _, shard := range shards {
shardsWatched[ks+"."+shard] = 1
}
}
return shardsWatched
}

// refreshAllKeyspacesAndShardsMu ensures RefreshAllKeyspacesAndShards
// is not executed concurrently.
var refreshAllKeyspacesAndShardsMu sync.Mutex

// RefreshAllKeyspacesAndShards reloads the keyspace and shard information for the keyspaces that vtorc is concerned with.
func RefreshAllKeyspacesAndShards(ctx context.Context) error {
refreshAllKeyspacesAndShardsMu.Lock()
Copy link
Copy Markdown
Contributor Author

@timvaillancourt timvaillancourt Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refreshAllShards does a save + read + delete, per keyspace. Lets make sure that doesn't happen concurrently

defer refreshAllKeyspacesAndShardsMu.Unlock()

var keyspaces []string
if len(shardsToWatch) == 0 { // all known keyspaces
ctx, cancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout)
Expand Down Expand Up @@ -79,6 +109,26 @@ func RefreshKeyspaceAndShard(keyspaceName string, shardName string) error {
return refreshShard(keyspaceName, shardName)
}

// shouldWatchShard returns true if a shard is within the shardsToWatch
// ranges for it's keyspace.
func shouldWatchShard(shard *topo.ShardInfo) bool {
if len(shardsToWatch) == 0 {
return true
}

watchRanges, found := shardsToWatch[shard.Keyspace()]
if !found {
return false
}

for _, keyRange := range watchRanges {
if key.KeyRangeContainsKeyRange(keyRange, shard.GetKeyRange()) {
return true
}
}
return false
}

// refreshKeyspace refreshes the keyspace's information for the given keyspace from the topo
func refreshKeyspace(keyspaceName string) error {
refreshCtx, refreshCancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout)
Expand Down Expand Up @@ -109,6 +159,7 @@ func refreshKeyspaceHelper(ctx context.Context, keyspaceName string) error {

// refreshAllShards refreshes all the shard records in the given keyspace.
func refreshAllShards(ctx context.Context, keyspaceName string) error {
// get all shards for keyspace name.
shardInfos, err := ts.FindAllShardsInKeyspace(ctx, keyspaceName, &topo.FindAllShardsInKeyspaceOptions{
// Fetch shard records concurrently to speed up discovery. A typical
// Vitess cluster will have 1-3 vtorc instances deployed, so there is
Expand All @@ -119,13 +170,38 @@ func refreshAllShards(ctx context.Context, keyspaceName string) error {
log.Error(err)
return err
}

// save shards that should be watched.
savedShards := make(map[string]bool, len(shardInfos))
for _, shardInfo := range shardInfos {
err = inst.SaveShard(shardInfo)
if err != nil {
if !shouldWatchShard(shardInfo) {
continue
}
if err = inst.SaveShard(shardInfo); err != nil {
log.Error(err)
return err
}
savedShards[shardInfo.ShardName()] = true
}

// delete shards that were not saved, indicating they are stale.
shards, err := inst.ReadShardNames(keyspaceName)
if err != nil {
log.Error(err)
return err
}
for _, shard := range shards {
if savedShards[shard] {
continue
}
shardName := topoproto.KeyspaceShardString(keyspaceName, shard)
log.Infof("Forgetting shard: %s", shardName)
if err = inst.DeleteShard(keyspaceName, shard); err != nil {
log.Errorf("Failed to delete shard %s: %+v", shardName, err)
return err
}
}

return nil
}

Expand Down
41 changes: 41 additions & 0 deletions go/vt/vtorc/logic/keyspace_shard_discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,44 @@ func verifyPrimaryAlias(t *testing.T, keyspaceName, shardName string, primaryAli
require.NoError(t, err)
require.Equal(t, primaryAliasWanted, primaryAlias)
}

func TestRefreshAllShards(t *testing.T) {
// Store the old flags and restore on test completion
oldClustersToWatch := clustersToWatch
oldTs := ts
defer func() {
clustersToWatch = oldClustersToWatch
ts = oldTs
db.ClearVTOrcDatabase()
}()

ctx := context.Background()
ts = memorytopo.NewServer(ctx, "zone1")
require.NoError(t, initializeShardsToWatch())
require.NoError(t, ts.CreateKeyspace(ctx, "ks1", keyspaceDurabilityNone))
shards := []string{"-40", "40-80", "80-c0", "c0-"}
for _, shard := range shards {
require.NoError(t, ts.CreateShard(ctx, "ks1", shard))
}

// test shard refresh
require.NoError(t, refreshAllShards(ctx, "ks1"))
shardNames, err := inst.ReadShardNames("ks1")
require.NoError(t, err)
require.Equal(t, []string{"-40", "40-80", "80-c0", "c0-"}, shardNames)

// test topo shard delete propagates
require.NoError(t, ts.DeleteShard(ctx, "ks1", "c0-"))
require.NoError(t, refreshAllShards(ctx, "ks1"))
shardNames, err = inst.ReadShardNames("ks1")
require.NoError(t, err)
require.Equal(t, []string{"-40", "40-80", "80-c0"}, shardNames)

// test clustersToWatch filters what shards are saved
clustersToWatch = []string{"ks1/-80"}
require.NoError(t, initializeShardsToWatch())
require.NoError(t, refreshAllShards(ctx, "ks1"))
shardNames, err = inst.ReadShardNames("ks1")
require.NoError(t, err)
require.Equal(t, []string{"-40", "40-80"}, shardNames)
}
Loading