diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index 637b2066ecb..0b6cd0cc323 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -49,12 +49,13 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" "github.com/dolthub/dolt/go/libraries/utils/config" "github.com/dolthub/dolt/go/libraries/utils/filesys" + "github.com/dolthub/dolt/go/libraries/utils/valctx" ) // SqlEngine packages up the context necessary to run sql queries against dsqle. type SqlEngine struct { provider sql.DatabaseProvider - contextFactory contextFactory + ContextFactory sql.ContextFactory dsessFactory sessionFactory engine *gms.Engine fs filesys.Filesys @@ -92,6 +93,19 @@ func NewSqlEngine( mrEnv *env.MultiRepoEnv, config *SqlEngineConfig, ) (*SqlEngine, error) { + // Context validation is a testing mode that we run Dolt in + // during integration tests. It asserts that `context.Context` + // instances which reach the storage layer have gone through + // GC session lifecycle callbacks. This is only relevant in + // sql mode, so we only enable it here. This is potentially + // relevant in non-sql-server contexts, because things like + // replication and events can still cause concurrency during a + // GC, so we put this here instead of in sql-server. + const contextValidationEnabledEnvVar = "DOLT_CONTEXT_VALIDATION_ENABLED" + if val := os.Getenv(contextValidationEnabledEnvVar); val != "" && val != "0" && strings.ToLower(val) != "false" { + valctx.EnableContextValidation() + } + gcSafepointController := gcctx.NewGCSafepointController() ctx = gcctx.WithGCSafepointController(ctx, gcSafepointController) @@ -230,8 +244,8 @@ func NewSqlEngine( engine.Analyzer.ExecBuilder = rowexec.NewOverrideBuilder(kvexec.Builder{}) sessFactory := doltSessionFactory(pro, statsPro, mrEnv.Config(), bcController, gcSafepointController, config.Autocommit) sqlEngine.provider = pro - sqlEngine.contextFactory = sqlContextFactory sqlEngine.dsessFactory = sessFactory + sqlEngine.ContextFactory = sqlContextFactory sqlEngine.engine = engine sqlEngine.fs = pro.FileSystem() @@ -263,7 +277,7 @@ func NewSqlEngine( } if engine.EventScheduler == nil { - err = configureEventScheduler(config, engine, sqlEngine.contextFactory, sessFactory, pro) + err = configureEventScheduler(config, engine, sqlEngine.ContextFactory, sessFactory, pro) if err != nil { return nil, err } @@ -275,7 +289,7 @@ func NewSqlEngine( return nil, err } - err = configureBinlogReplicaController(config, engine, sqlEngine.contextFactory, binLogSession) + err = configureBinlogReplicaController(config, engine, sqlEngine.ContextFactory, binLogSession) if err != nil { return nil, err } @@ -314,6 +328,9 @@ func (se *SqlEngine) InitStats(ctx context.Context) error { if err != nil { return err } + defer sql.SessionEnd(sqlCtx.Session) + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) dbs := pro.AllDatabases(sqlCtx) statsPro := se.GetUnderlyingEngine().Analyzer.Catalog.StatsProvider if sc, ok := statsPro.(*statspro.StatsController); ok { @@ -328,7 +345,7 @@ func (se *SqlEngine) InitStats(ctx context.Context) error { sqlDbs = append(sqlDbs, db) } - err = sc.Init(ctx, pro, se.NewDefaultContext, sqlDbs) + err = sc.Init(sqlCtx, pro, se.NewDefaultContext, sqlDbs) if err != nil { return err } @@ -355,7 +372,7 @@ func (se *SqlEngine) Databases(ctx *sql.Context) []dsess.SqlDatabase { // NewContext returns a new sql.Context with the given session. func (se *SqlEngine) NewContext(ctx context.Context, session sql.Session) (*sql.Context, error) { - return se.contextFactory(ctx, session) + return se.ContextFactory(ctx, sql.WithSession(session)), nil } // NewDefaultContext returns a new sql.Context with a new default dolt session. @@ -364,7 +381,7 @@ func (se *SqlEngine) NewDefaultContext(ctx context.Context) (*sql.Context, error if err != nil { return nil, err } - return se.contextFactory(ctx, session) + return se.ContextFactory(ctx, sql.WithSession(session)), nil } // NewLocalContext returns a new |sql.Context| with its client set to |root| @@ -416,11 +433,8 @@ func (se *SqlEngine) Close() error { } // configureBinlogReplicaController configures the binlog replication controller with the |engine|. -func configureBinlogReplicaController(config *SqlEngineConfig, engine *gms.Engine, ctxFactory contextFactory, session *dsess.DoltSession) error { - executionCtx, err := ctxFactory(context.Background(), session) - if err != nil { - return err - } +func configureBinlogReplicaController(config *SqlEngineConfig, engine *gms.Engine, ctxFactory sql.ContextFactory, session *dsess.DoltSession) error { + executionCtx := ctxFactory(context.Background(), sql.WithSession(session)) dblr.DoltBinlogReplicaController.SetExecutionContext(executionCtx) dblr.DoltBinlogReplicaController.SetEngine(engine) engine.Analyzer.Catalog.BinlogReplicaController = config.BinlogReplicaController @@ -442,14 +456,14 @@ func configureBinlogPrimaryController(engine *gms.Engine) error { // configureEventScheduler configures the event scheduler with the |engine| for executing events, a |sessFactory| // for creating sessions, and a DoltDatabaseProvider, |pro|. -func configureEventScheduler(config *SqlEngineConfig, engine *gms.Engine, ctxFactory contextFactory, sessFactory sessionFactory, pro *dsqle.DoltDatabaseProvider) error { +func configureEventScheduler(config *SqlEngineConfig, engine *gms.Engine, ctxFactory sql.ContextFactory, sessFactory sessionFactory, pro *dsqle.DoltDatabaseProvider) error { // getCtxFunc is used to create new session with a new context for event scheduler. getCtxFunc := func() (*sql.Context, error) { sess, err := sessFactory(sql.NewBaseSession(), pro) if err != nil { return nil, err } - return ctxFactory(context.Background(), sess) + return ctxFactory(context.Background(), sql.WithSession(sess)), nil } // A hidden env var allows overriding the event scheduler period for testing. This option is not @@ -471,9 +485,13 @@ func configureEventScheduler(config *SqlEngineConfig, engine *gms.Engine, ctxFac } // sqlContextFactory returns a contextFactory that creates a new sql.Context with the given session -func sqlContextFactory(ctx context.Context, session sql.Session) (*sql.Context, error) { - sqlCtx := sql.NewContext(ctx, sql.WithSession(session)) - return sqlCtx, nil +func sqlContextFactory(ctx context.Context, opts ...sql.ContextOption) *sql.Context { + ctx = valctx.WithContextValidation(ctx) + sqlCtx := sql.NewContext(ctx, opts...) + if sqlCtx.Session != nil { + valctx.SetContextValidation(ctx, dsess.DSessFromSess(sqlCtx.Session).Validate) + } + return sqlCtx } // doltSessionFactory returns a sessionFactory that creates a new DoltSession @@ -521,6 +539,7 @@ func NewSqlEngineForEnv(ctx context.Context, dEnv *env.DoltEnv, options ...Confi if err != nil { return nil, "", err } + if err := engine.InitStats(ctx); err != nil { return nil, "", err } diff --git a/go/cmd/dolt/commands/show.go b/go/cmd/dolt/commands/show.go index a435684296f..d504708a028 100644 --- a/go/cmd/dolt/commands/show.go +++ b/go/cmd/dolt/commands/show.go @@ -171,7 +171,7 @@ func (cmd ShowCmd) Exec(ctx context.Context, commandStr string, args []string, d return 1 } - if !opts.pretty && !dEnv.DoltDB(ctx).Format().UsesFlatbuffers() { + if !opts.pretty && !dEnv.DoltDB(sqlCtx).Format().UsesFlatbuffers() { cli.PrintErrln("`dolt show --no-pretty` or `dolt show (BRANCHNAME)` is not supported when using old LD_1 storage format.") return 1 } @@ -180,7 +180,7 @@ func (cmd ShowCmd) Exec(ctx context.Context, commandStr string, args []string, d for _, specRef := range resolvedRefs { // If --no-pretty was supplied, always display the raw contents of the referenced object. if !opts.pretty { - err := printRawValue(ctx, dEnv, specRef) + err := printRawValue(sqlCtx, dEnv, specRef) if err != nil { return handleErrAndExit(err) } @@ -202,12 +202,12 @@ func (cmd ShowCmd) Exec(ctx context.Context, commandStr string, args []string, d cli.PrintErrln("`dolt show (NON_COMMIT_HASH)` requires a local environment. Not intended for common use.") return 1 } - if !dEnv.DoltDB(ctx).Format().UsesFlatbuffers() { + if !dEnv.DoltDB(sqlCtx).Format().UsesFlatbuffers() { cli.PrintErrln("`dolt show (NON_COMMIT_HASH)` is not supported when using old LD_1 storage format.") return 1 } - value, err := getValueFromRefSpec(ctx, dEnv, specRef) + value, err := getValueFromRefSpec(sqlCtx, dEnv, specRef) if err != nil { err = fmt.Errorf("error resolving spec ref '%s': %w", specRef, err) if err != nil { diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index 02f5735b603..8823b21084f 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -783,7 +783,10 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu subCtx, stop := signal.NotifyContext(initialCtx, os.Interrupt, syscall.SIGTERM) defer stop() - sqlCtx := sql.NewContext(subCtx, sql.WithSession(sqlCtx.Session)) + var cancel func() + sqlCtx, cancel = sqlCtx.NewSubContext() + stopAfter := context.AfterFunc(subCtx, cancel) + defer stopAfter() cmdType, subCmd, newQuery, err := preprocessQuery(query, lastSqlCmd, cliCtx) if err != nil { diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 05443d84dd3..e336a7a1168 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -721,7 +721,7 @@ func ConfigureServices( mySQLServer, err = server.NewServerWithHandler( serverConf, sqlEngine.GetUnderlyingEngine(), - sql.NewContext, + sqlEngine.ContextFactory, newSessionBuilder(sqlEngine, cfg.ServerConfig), metListener, func(h mysql.Handler) (mysql.Handler, error) { @@ -732,7 +732,7 @@ func ConfigureServices( mySQLServer, err = server.NewServer( serverConf, sqlEngine.GetUnderlyingEngine(), - sql.NewContext, + sqlEngine.ContextFactory, newSessionBuilder(sqlEngine, cfg.ServerConfig), metListener, ) diff --git a/go/cmd/dolt/commands/tblcmds/export.go b/go/cmd/dolt/commands/tblcmds/export.go index c461db0cc68..6c55fcac935 100644 --- a/go/cmd/dolt/commands/tblcmds/export.go +++ b/go/cmd/dolt/commands/tblcmds/export.go @@ -218,7 +218,7 @@ func (cmd ExportCmd) Exec(ctx context.Context, commandStr string, args []string, return commands.HandleVErrAndExitCode(errhand.BuildDError("Error creating reader for %s.", exOpts.SrcName()).AddCause(err).Build(), usage) } - wr, verr := getTableWriter(ctx, root, dEnv, rd.GetSchema(), exOpts) + wr, verr := getTableWriter(sqlCtx, root, dEnv, rd.GetSchema(), exOpts) if verr != nil { return commands.HandleVErrAndExitCode(verr, usage) } diff --git a/go/cmd/dolt/commands/utils.go b/go/cmd/dolt/commands/utils.go index 7813bcde2d2..7eeb79a8877 100644 --- a/go/cmd/dolt/commands/utils.go +++ b/go/cmd/dolt/commands/utils.go @@ -246,14 +246,14 @@ func newLateBindingEngine( Autocommit: true, } - var lateBinder cli.LateBindQueryist = func(ctx2 context.Context) (cli.Queryist, *sql.Context, func(), error) { + var lateBinder cli.LateBindQueryist = func(ctx context.Context) (cli.Queryist, *sql.Context, func(), error) { // We've deferred loading the database as long as we can. // If we're binding the Queryist, that means that engine is actually // going to be used. - mrEnv.ReloadDBs(ctx2) + mrEnv.ReloadDBs(ctx) se, err := engine.NewSqlEngine( - ctx2, + ctx, mrEnv, config, ) @@ -261,22 +261,15 @@ func newLateBindingEngine( return nil, nil, nil, err } - if err := se.InitStats(ctx2); err != nil { + if err := se.InitStats(ctx); err != nil { + se.Close() return nil, nil, nil, err } - sqlCtx, err := se.NewDefaultContext(ctx2) - if err != nil { - return nil, nil, nil, err - } - - // Whether we're running in shell mode or some other mode, sql commands from the command line always have a current - // database set when you begin using them. - sqlCtx.SetCurrentDatabase(database) - rawDb := se.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb salt, err := mysql.NewSalt() if err != nil { + se.Close() return nil, nil, nil, err } @@ -292,6 +285,7 @@ func newLateBindingEngine( err := passwordValidate(rawDb, salt, dbUser, authResponse) if err != nil { + se.Close() return nil, nil, nil, err } @@ -303,9 +297,28 @@ func newLateBindingEngine( rawDb.AddEphemeralSuperUser(ed, dbUser, config.ServerHost, "") } + sqlCtx, err := se.NewDefaultContext(ctx) + if err != nil { + se.Close() + return nil, nil, nil, err + } + // Whether we're running in shell mode or some other mode, sql commands from the command line always have a current + // database set when you begin using them. + sqlCtx.SetCurrentDatabase(database) + + // For now, we treat the entire lifecycle of this + // sqlCtx as one big session-in-use window. + sql.SessionCommandBegin(sqlCtx.Session) + + close := func() { + sql.SessionCommandEnd(sqlCtx.Session) + sql.SessionEnd(sqlCtx.Session) + se.Close() + } + // Set client to specified user sqlCtx.Session.SetClient(sql.Client{User: dbUser, Address: config.ServerHost, Capabilities: 0}) - return se, sqlCtx, func() { se.Close() }, nil + return se, sqlCtx, close, nil } return lateBinder, nil diff --git a/go/libraries/doltcore/doltdb/gcctx/context.go b/go/libraries/doltcore/doltdb/gcctx/context.go index 120b401698f..5b6631de4cb 100644 --- a/go/libraries/doltcore/doltdb/gcctx/context.go +++ b/go/libraries/doltcore/doltdb/gcctx/context.go @@ -17,6 +17,7 @@ package gcctx import ( "context" + "github.com/dolthub/dolt/go/libraries/utils/valctx" "github.com/dolthub/dolt/go/store/hash" ) @@ -48,6 +49,8 @@ func WithGCSafepointController(ctx context.Context, controller *GCSafepointContr controller: controller, } ret := context.WithValue(ctx, safepointControllerkey, state) + ret = valctx.WithContextValidation(ret) + valctx.SetContextValidation(ret, state.Validate) return ret } diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index d6525549d3d..f5d68466b7b 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -43,6 +43,7 @@ import ( "github.com/dolthub/dolt/go/libraries/utils/concurrentmap" "github.com/dolthub/dolt/go/libraries/utils/filesys" "github.com/dolthub/dolt/go/libraries/utils/lockutil" + "github.com/dolthub/dolt/go/libraries/utils/valctx" "github.com/dolthub/dolt/go/store/datas" "github.com/dolthub/dolt/go/store/types" ) @@ -136,6 +137,9 @@ func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Files for _, esp := range dprocedures.DoltProcedures { externalProcedures.Register(esp) } + if valctx.IsEnabled() { + externalProcedures.Register(dprocedures.NewTestValctxProcedure()) + } // If the specified |fs| is an in mem file system, default to using the InMemDoltDB dbFactoryUrl so that all // databases are created with the same file system type. diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_test_valctx.go b/go/libraries/doltcore/sqle/dprocedures/dolt_test_valctx.go new file mode 100644 index 00000000000..97fa5674b1f --- /dev/null +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_test_valctx.go @@ -0,0 +1,60 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dprocedures + +import ( + "context" + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/store/hash" +) + +// Only installed if valctx is enabled. This is only used in testing, and +// it lets tests assert that valctx is registered and working as expected. +// +// This stored procedure intentionally calls into the DoltDB layer with +// an unregistered valctx and allows a caller to look for the error which +// should get surfaced. +func doltTestValctx(ctx *sql.Context, args ...string) (sql.RowIter, error) { + dbName := ctx.GetCurrentDatabase() + if len(dbName) == 0 { + return rowToIter(int64(1)), fmt.Errorf("Empty database name.") + } + dSess := dsess.DSessFromSess(ctx.Session) + ddb, ok := dSess.GetDoltDB(ctx, dbName) + if !ok { + return rowToIter(int64(1)), fmt.Errorf("Unable to get DoltDB") + } + // With valctx enabled, this should panic. We are passing in + // |context.Background()| here intentionally. Note, that if + // this does not panic, it will return an error, because a + // RootValue with a |0| hash should not exist in the + // database. We ignore that error here, and always return + // success. If valctx changed to return an error instead of + // panic, this would need to be reworked. + ddb.ReadRootValue(context.Background(), hash.Hash{}) + return rowToIter(int64(0)), nil +} + +func NewTestValctxProcedure() sql.ExternalStoredProcedureDetails { + return sql.ExternalStoredProcedureDetails{ + Name: "dolt_test_valctx", + Schema: int64Schema("status"), + Function: doltTestValctx, + } +} diff --git a/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go b/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go index 106a33de725..39067084b9d 100644 --- a/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go +++ b/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go @@ -68,8 +68,8 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb mm: mutexmap.NewMutexMap(), init: make(chan struct{}), } - ctx = context.Background() gcSafepointController := getGCSafepointController(ctx) + ctx = context.Background() if gcSafepointController != nil { ctx = gcctx.WithGCSafepointController(ctx, gcSafepointController) } diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index 3fda7d44d4e..cdaf1691004 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -1764,6 +1764,17 @@ func (d *DoltSession) SessionEnd() { } } +func (d *DoltSession) Validate() { + // If this gets called, valctx context validation is enabled + // and the purpose is to validate that this session is + // registered with an open command on our current + // gcSafepointController. + if d.gcSafepointController == nil { + panic("DoltSession.Validate called. Expected to have a gcSafepointController but did not.") + } + d.gcSafepointController.Validate(d) +} + // dolt_gc accesses the safepoint controller for the current // sql engine through here. func (d *DoltSession) GCSafepointController() *gcctx.GCSafepointController { diff --git a/go/libraries/doltcore/sqle/statspro/controller.go b/go/libraries/doltcore/sqle/statspro/controller.go index dc36c95c338..7af721cc548 100644 --- a/go/libraries/doltcore/sqle/statspro/controller.go +++ b/go/libraries/doltcore/sqle/statspro/controller.go @@ -300,7 +300,16 @@ func (sc *StatsController) AnalyzeTable(ctx *sql.Context, table sql.Table, dbNam } newStats := newRootStats() - err = sc.updateTable(ctx, newStats, table.Name(), sqlDb, nil) + + // XXX: Use a new context for this operation. |updateTable| does GC + // lifecycle callbacks on the context. |ctx| already has lifecycle + // callbacks registered because we are part of a SQL handler. + newCtx, err := sc.ctxGen(ctx.Context) + if err != nil { + return err + } + newCtx.SetCurrentDatabase(ctx.GetCurrentDatabase()) + err = sc.updateTable(newCtx, newStats, table.Name(), sqlDb, nil) if err != nil { return err } diff --git a/go/libraries/doltcore/sqle/statspro/jobqueue/serialqueue.go b/go/libraries/doltcore/sqle/statspro/jobqueue/serialqueue.go index e8a55bad4ba..f0cfd62a940 100644 --- a/go/libraries/doltcore/sqle/statspro/jobqueue/serialqueue.go +++ b/go/libraries/doltcore/sqle/statspro/jobqueue/serialqueue.go @@ -261,21 +261,23 @@ func (s *SerialQueue) DoSync(ctx context.Context, f func() error) error { // No return leaves the session in an incomplete state. func (s *SerialQueue) DoSyncSessionAware(ctx *sql.Context, f func() error) error { started := atomic.Bool{} + var err error nf := func() error { if started.Swap(true) { return nil } sql.SessionCommandBegin(ctx.Session) defer sql.SessionCommandEnd(ctx.Session) - return f() - } - w, err := s.submitWork(schedPriority_Normal, nf) - if err != nil { + err = f() return err } + w, serr := s.submitWork(schedPriority_Normal, nf) + if serr != nil { + return serr + } select { case <-w.done: - return nil + return err case <-ctx.Done(): if started.Swap(true) { <-w.done diff --git a/go/libraries/utils/valctx/valctx.go b/go/libraries/utils/valctx/valctx.go index c7b8431e0c3..92a8692d537 100644 --- a/go/libraries/utils/valctx/valctx.go +++ b/go/libraries/utils/valctx/valctx.go @@ -26,6 +26,10 @@ func EnableContextValidation() { enabled = true } +func IsEnabled() bool { + return enabled +} + type ctxKey int var validationKey ctxKey diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index 1ae5e7f912b..dfa668b23b6 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -43,6 +43,7 @@ import ( "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" + "github.com/dolthub/dolt/go/libraries/utils/valctx" "github.com/dolthub/dolt/go/store/blobstore" "github.com/dolthub/dolt/go/store/chunks" "github.com/dolthub/dolt/go/store/hash" @@ -156,6 +157,7 @@ func (nbs *NomsBlockStore) ChunkJournal() *ChunkJournal { } func (nbs *NomsBlockStore) GetChunkLocationsWithPaths(ctx context.Context, hashes hash.HashSet) (map[string]map[hash.Hash]Range, error) { + valctx.ValidateContext(ctx) sourcesToRanges, err := nbs.getChunkLocations(ctx, hashes) if err != nil { return nil, err @@ -222,6 +224,7 @@ func (nbs *NomsBlockStore) getChunkLocations(ctx context.Context, hashes hash.Ha } func (nbs *NomsBlockStore) GetChunkLocations(ctx context.Context, hashes hash.HashSet) (map[hash.Hash]map[hash.Hash]Range, error) { + valctx.ValidateContext(ctx) sourcesToRanges, err := nbs.getChunkLocations(ctx, hashes) if err != nil { return nil, err @@ -288,6 +291,7 @@ func (nbs *NomsBlockStore) conjoinIfRequired(ctx context.Context) (bool, error) } func (nbs *NomsBlockStore) UpdateManifest(ctx context.Context, updates map[hash.Hash]uint32) (ManifestInfo, error) { + valctx.ValidateContext(ctx) sources, err := nbs.openChunkSourcesForAddTableFiles(ctx, updates) if err != nil { return manifestContents{}, err @@ -406,6 +410,7 @@ func (nbs *NomsBlockStore) updateManifestAddFiles(ctx context.Context, updates m } func (nbs *NomsBlockStore) UpdateManifestWithAppendix(ctx context.Context, updates map[hash.Hash]uint32, option ManifestAppendixOption) (ManifestInfo, error) { + valctx.ValidateContext(ctx) sources, err := nbs.openChunkSourcesForAddTableFiles(ctx, updates) if err != nil { return manifestContents{}, err @@ -709,6 +714,7 @@ func (nbs *NomsBlockStore) waitForGC(ctx context.Context) error { } func (nbs *NomsBlockStore) Put(ctx context.Context, c chunks.Chunk, getAddrs chunks.GetAddrsCurry) error { + valctx.ValidateContext(ctx) return nbs.putChunk(ctx, c, getAddrs, nbs.refCheck) } @@ -831,6 +837,7 @@ func (nbs *NomsBlockStore) errorIfDangling(root hash.Hash, checker refCheck) err } func (nbs *NomsBlockStore) Get(ctx context.Context, h hash.Hash) (chunks.Chunk, error) { + valctx.ValidateContext(ctx) ctx, span := tracer.Start(ctx, "nbs.Get") defer span.End() @@ -881,6 +888,7 @@ func (nbs *NomsBlockStore) Get(ctx context.Context, h hash.Hash) (chunks.Chunk, } func (nbs *NomsBlockStore) GetMany(ctx context.Context, hashes hash.HashSet, found func(context.Context, *chunks.Chunk)) error { + valctx.ValidateContext(ctx) ctx, span := tracer.Start(ctx, "nbs.GetMany", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) defer span.End() return nbs.getManyWithFunc(ctx, hashes, gcDependencyMode_TakeDependency, @@ -891,6 +899,7 @@ func (nbs *NomsBlockStore) GetMany(ctx context.Context, hashes hash.HashSet, fou } func (nbs *NomsBlockStore) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, ToChunker)) error { + valctx.ValidateContext(ctx) return nbs.getManyCompressed(ctx, hashes, found, gcDependencyMode_TakeDependency) } @@ -1015,6 +1024,7 @@ func (nbs *NomsBlockStore) Count() (uint32, error) { } func (nbs *NomsBlockStore) Has(ctx context.Context, h hash.Hash) (bool, error) { + valctx.ValidateContext(ctx) t1 := time.Now() defer func() { nbs.stats.HasLatency.SampleTimeSince(t1) @@ -1059,6 +1069,7 @@ func (nbs *NomsBlockStore) Has(ctx context.Context, h hash.Hash) (bool, error) { } func (nbs *NomsBlockStore) HasMany(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { + valctx.ValidateContext(ctx) return nbs.hasManyDep(ctx, hashes, gcDependencyMode_TakeDependency) } @@ -1211,6 +1222,7 @@ func toHasRecords(hashes hash.HashSet) []hasRecord { } func (nbs *NomsBlockStore) Rebase(ctx context.Context) error { + valctx.ValidateContext(ctx) nbs.mu.Lock() defer nbs.mu.Unlock() return nbs.rebase(ctx) @@ -1246,12 +1258,14 @@ func (nbs *NomsBlockStore) rebase(ctx context.Context) error { } func (nbs *NomsBlockStore) Root(ctx context.Context) (hash.Hash, error) { + valctx.ValidateContext(ctx) nbs.mu.RLock() defer nbs.mu.RUnlock() return nbs.upstream.root, nil } func (nbs *NomsBlockStore) Commit(ctx context.Context, current, last hash.Hash) (success bool, err error) { + valctx.ValidateContext(ctx) return nbs.commit(ctx, current, last, nbs.refCheck) } @@ -1503,6 +1517,7 @@ func (tf tableFile) Open(ctx context.Context) (io.ReadCloser, uint64, error) { // Sources retrieves the current root hash, a list of all table files (which may include appendix tablefiles), // and a second list of only the appendix table files func (nbs *NomsBlockStore) Sources(ctx context.Context) (hash.Hash, []chunks.TableFile, []chunks.TableFile, error) { + valctx.ValidateContext(ctx) nbs.mu.Lock() defer nbs.mu.Unlock() @@ -1627,6 +1642,7 @@ func (nbs *NomsBlockStore) Path() (string, bool) { // WriteTableFile will read a table file from the provided reader and write it to the TableFileStore func (nbs *NomsBlockStore) WriteTableFile(ctx context.Context, fileName string, numChunks int, contentHash []byte, getRd func() (io.ReadCloser, uint64, error)) error { + valctx.ValidateContext(ctx) tfp, ok := nbs.p.(tableFilePersister) if !ok { return errors.New("Not implemented") @@ -1642,6 +1658,7 @@ func (nbs *NomsBlockStore) WriteTableFile(ctx context.Context, fileName string, // AddTableFilesToManifest adds table files to the manifest func (nbs *NomsBlockStore) AddTableFilesToManifest(ctx context.Context, fileIdToNumChunks map[string]int, getAddrs chunks.GetAddrsCurry) error { + valctx.ValidateContext(ctx) return nbs.addTableFilesToManifest(ctx, fileIdToNumChunks, getAddrs, nbs.refCheck) } @@ -1816,6 +1833,7 @@ func (nbs *NomsBlockStore) openChunkSourcesForAddTableFiles(ctx context.Context, // PruneTableFiles deletes old table files that are no longer referenced in the manifest. func (nbs *NomsBlockStore) PruneTableFiles(ctx context.Context) (err error) { + valctx.ValidateContext(ctx) return nbs.pruneTableFiles(ctx) } @@ -1884,6 +1902,7 @@ func (nbs *NomsBlockStore) beginRead() (endRead func()) { } func (nbs *NomsBlockStore) MarkAndSweepChunks(ctx context.Context, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, dest chunks.ChunkStore, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { + valctx.ValidateContext(ctx) return markAndSweepChunks(ctx, nbs, nbs, dest, getAddrs, filter, mode) } @@ -1966,6 +1985,7 @@ type markAndSweeper struct { } func (i *markAndSweeper) SaveHashes(ctx context.Context, hashes []hash.Hash) error { + valctx.ValidateContext(ctx) toVisit := make(hash.HashSet, len(hashes)) for _, h := range hashes { if _, ok := i.visited[h]; !ok { @@ -2049,6 +2069,7 @@ func (i *markAndSweeper) SaveHashes(ctx context.Context, hashes []hash.Hash) err } func (i *markAndSweeper) Finalize(ctx context.Context) (chunks.GCFinalizer, error) { + valctx.ValidateContext(ctx) specs, err := i.gcc.copyTablesToDir(ctx) if err != nil { return nil, err diff --git a/integration-tests/bats/garbage_collection.bats b/integration-tests/bats/garbage_collection.bats index bad97f6f836..d83b73be7f3 100644 --- a/integration-tests/bats/garbage_collection.bats +++ b/integration-tests/bats/garbage_collection.bats @@ -34,6 +34,14 @@ teardown() { dolt gc -s } +@test "garbage_collection: valctx is enabled" { + run dolt sql -q "call dolt_test_valctx();" + # Calling dolt_test_valctx should exit non-zero. + [ "$status" -ne "0" ] + # It should have surfaced a panic. + [[ "$output" =~ "panic: " ]] || false +} + @test "garbage_collection: smoke test" { dolt sql <