diff --git a/engine.go b/engine.go index 229d02cfa7..192a33faa0 100644 --- a/engine.go +++ b/engine.go @@ -198,6 +198,7 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine { } ret.ReadOnly.Store(cfg.IsReadOnly) a.Runner = ret + a.ExecBuilder.Runner = ret return ret } diff --git a/enginetest/engine_only_test.go b/enginetest/engine_only_test.go index 38602ffae7..8caf1973aa 100644 --- a/enginetest/engine_only_test.go +++ b/enginetest/engine_only_test.go @@ -592,6 +592,7 @@ func TestTableFunctions(t *testing.T) { harness = harness.WithProvider(engine.Analyzer.Catalog.DbProvider) engine.EngineAnalyzer().ExecBuilder = rowexec.NewBuilder(nil, sql.EngineOverrides{}) + engine.EngineAnalyzer().ExecBuilder.Runner = engine engine, err := enginetest.RunSetupScripts(harness.NewContext(), engine, setup.MydbData, true) require.NoError(t, err) diff --git a/enginetest/initialization.go b/enginetest/initialization.go index bf1c6c11c3..5d43dfc651 100644 --- a/enginetest/initialization.go +++ b/enginetest/initialization.go @@ -92,6 +92,7 @@ func NewEngineWithProvider(_ *testing.T, harness Harness, provider sql.DatabaseP idh.InitializeIndexDriver(engine.Analyzer.Catalog.AllDatabases(NewContext(harness))) } analyzer.Runner = engine + analyzer.ExecBuilder.Runner = engine return engine } diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index 895cacdeb3..f2c21300ef 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -294,7 +294,7 @@ type Analyzer struct { // Parser is the parser used to parse SQL statements. Parser sql.Parser // ExecBuilder converts a sql.Node tree into an executable iterator. - ExecBuilder sql.NodeExecBuilder + ExecBuilder *rowexec.BaseBuilder // Runner represents the engine, which is represented as a separate interface to work around circular dependencies Runner sql.StatementRunner // SchemaFormatter is used to format the schema of a node to a string. diff --git a/sql/overrides.go b/sql/overrides.go index 2d8c61cc4f..2b41f42048 100644 --- a/sql/overrides.go +++ b/sql/overrides.go @@ -75,56 +75,56 @@ type ExecutionHooks struct { // CreateTable contains hooks related to CREATE TABLE statements. These will take a *plan.CreateTable. type CreateTable struct { // PreSQLExecution is called before the final step of statement execution, after analysis. - PreSQLExecution func(*Context, Node) (Node, error) + PreSQLExecution func(*Context, StatementRunner, Node) (Node, error) // PostSQLExecution is called after the final step of statement execution, after analysis. - PostSQLExecution func(*Context, Node) error + PostSQLExecution func(*Context, StatementRunner, Node) error } // RenameTable contains hooks related to RENAME TABLE statements. These will take a *plan.RenameTable. type RenameTable struct { // PreSQLExecution is called before the final step of statement execution, after analysis. - PreSQLExecution func(*Context, Node) (Node, error) + PreSQLExecution func(*Context, StatementRunner, Node) (Node, error) // PostSQLExecution is called after the final step of statement execution, after analysis. - PostSQLExecution func(*Context, Node) error + PostSQLExecution func(*Context, StatementRunner, Node) error } // DropTable contains hooks related to DROP TABLE statements. These will take a *plan.DropTable. type DropTable struct { // PreSQLExecution is called before the final step of statement execution, after analysis. - PreSQLExecution func(*Context, Node) (Node, error) + PreSQLExecution func(*Context, StatementRunner, Node) (Node, error) // PostSQLExecution is called after the final step of statement execution, after analysis. - PostSQLExecution func(*Context, Node) error + PostSQLExecution func(*Context, StatementRunner, Node) error } // TableAddColumn contains hooks related to ALTER TABLE ... ADD COLUMN statements. These will take a *plan.AddColumn. type TableAddColumn struct { // PreSQLExecution is called before the final step of statement execution, after analysis. - PreSQLExecution func(*Context, Node) (Node, error) + PreSQLExecution func(*Context, StatementRunner, Node) (Node, error) // PostSQLExecution is called after the final step of statement execution, after analysis. - PostSQLExecution func(*Context, Node) error + PostSQLExecution func(*Context, StatementRunner, Node) error } // TableRenameColumn contains hooks related to ALTER TABLE ... RENAME COLUMN statements. These will take a *plan.RenameColumn. type TableRenameColumn struct { // PreSQLExecution is called before the final step of statement execution, after analysis. - PreSQLExecution func(*Context, Node) (Node, error) + PreSQLExecution func(*Context, StatementRunner, Node) (Node, error) // PostSQLExecution is called after the final step of statement execution, after analysis. - PostSQLExecution func(*Context, Node) error + PostSQLExecution func(*Context, StatementRunner, Node) error } // TableModifyColumn contains hooks related to ALTER TABLE ... MODIFY COLUMN statements. These will take a // *plan.ModifyColumn. type TableModifyColumn struct { // PreSQLExecution is called before the final step of statement execution, after analysis. - PreSQLExecution func(*Context, Node) (Node, error) + PreSQLExecution func(*Context, StatementRunner, Node) (Node, error) // PostSQLExecution is called after the final step of statement execution, after analysis. - PostSQLExecution func(*Context, Node) error + PostSQLExecution func(*Context, StatementRunner, Node) error } // TableDropColumn contains hooks related to ALTER TABLE ... DROP COLUMN statements. These will take a *plan.DropColumn. type TableDropColumn struct { // PreSQLExecution is called before the final step of statement execution, after analysis. - PreSQLExecution func(*Context, Node) (Node, error) + PreSQLExecution func(*Context, StatementRunner, Node) (Node, error) // PostSQLExecution is called after the final step of statement execution, after analysis. - PostSQLExecution func(*Context, Node) error + PostSQLExecution func(*Context, StatementRunner, Node) error } diff --git a/sql/rowexec/builder.go b/sql/rowexec/builder.go index 7226b7424b..cc926c4e7e 100644 --- a/sql/rowexec/builder.go +++ b/sql/rowexec/builder.go @@ -27,6 +27,7 @@ import ( type BaseBuilder struct { PriorityBuilder sql.NodeExecBuilder EngineOverrides sql.EngineOverrides + Runner sql.StatementRunner schemaFormatter sql.SchemaFormatter } @@ -34,10 +35,11 @@ var _ sql.NodeExecBuilder = (*BaseBuilder)(nil) // NewBuilder creates a new builder. If a priority builder is given, then it is tried first, and only uses the internal // builder logic if the given one does not return a result (and does not error). -func NewBuilder(priority sql.NodeExecBuilder, overrides sql.EngineOverrides) sql.NodeExecBuilder { +func NewBuilder(priority sql.NodeExecBuilder, overrides sql.EngineOverrides) *BaseBuilder { return &BaseBuilder{ PriorityBuilder: priority, EngineOverrides: overrides, + Runner: nil, // This is often set later (directly on the variable), as it's not yet available during creation schemaFormatter: sql.GetSchemaFormatter(overrides), } } diff --git a/sql/rowexec/common_test.go b/sql/rowexec/common_test.go index a59fba6d62..54ead209e9 100644 --- a/sql/rowexec/common_test.go +++ b/sql/rowexec/common_test.go @@ -29,7 +29,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) -var DefaultBuilder = NewBuilder(nil, sql.EngineOverrides{}).(*BaseBuilder) +var DefaultBuilder = NewBuilder(nil, sql.EngineOverrides{}) func newContext(provider *memory.DbProvider) *sql.Context { return sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), provider))) diff --git a/sql/rowexec/ddl.go b/sql/rowexec/ddl.go index 8a4ab7dfb5..a3937e061c 100644 --- a/sql/rowexec/ddl.go +++ b/sql/rowexec/ddl.go @@ -241,7 +241,7 @@ func (b *BaseBuilder) buildDropCheck(ctx *sql.Context, n *plan.DropCheck, row sq func (b *BaseBuilder) buildRenameTable(ctx *sql.Context, n *plan.RenameTable, row sql.Row) (sql.RowIter, error) { if b.EngineOverrides.Hooks.RenameTable.PreSQLExecution != nil { - nn, err := b.EngineOverrides.Hooks.RenameTable.PreSQLExecution(ctx, n) + nn, err := b.EngineOverrides.Hooks.RenameTable.PreSQLExecution(ctx, b.Runner, n) if err != nil { return nil, err } @@ -268,7 +268,7 @@ func (b *BaseBuilder) buildRenameTable(ctx *sql.Context, n *plan.RenameTable, ro } } if b.EngineOverrides.Hooks.RenameTable.PostSQLExecution != nil { - if err := b.EngineOverrides.Hooks.RenameTable.PostSQLExecution(ctx, n); err != nil { + if err := b.EngineOverrides.Hooks.RenameTable.PostSQLExecution(ctx, b.Runner, n); err != nil { return nil, err } } @@ -278,7 +278,7 @@ func (b *BaseBuilder) buildRenameTable(ctx *sql.Context, n *plan.RenameTable, ro func (b *BaseBuilder) buildModifyColumn(ctx *sql.Context, n *plan.ModifyColumn, row sql.Row) (sql.RowIter, error) { if b.EngineOverrides.Hooks.TableModifyColumn.PreSQLExecution != nil { - nn, err := b.EngineOverrides.Hooks.TableModifyColumn.PreSQLExecution(ctx, n) + nn, err := b.EngineOverrides.Hooks.TableModifyColumn.PreSQLExecution(ctx, b.Runner, n) if err != nil { return nil, err } @@ -321,6 +321,7 @@ func (b *BaseBuilder) buildModifyColumn(ctx *sql.Context, n *plan.ModifyColumn, m: n, alterable: alterable, overrides: b.EngineOverrides, + runner: b.Runner, }, nil } @@ -951,7 +952,7 @@ func (b *BaseBuilder) buildDropSchema(ctx *sql.Context, n *plan.DropSchema, row func (b *BaseBuilder) buildRenameColumn(ctx *sql.Context, n *plan.RenameColumn, row sql.Row) (sql.RowIter, error) { if b.EngineOverrides.Hooks.TableRenameColumn.PreSQLExecution != nil { - nn, err := b.EngineOverrides.Hooks.TableRenameColumn.PreSQLExecution(ctx, n) + nn, err := b.EngineOverrides.Hooks.TableRenameColumn.PreSQLExecution(ctx, b.Runner, n) if err != nil { return nil, err } @@ -1002,7 +1003,7 @@ func (b *BaseBuilder) buildRenameColumn(ctx *sql.Context, n *plan.RenameColumn, return nil, err } if b.EngineOverrides.Hooks.TableRenameColumn.PostSQLExecution != nil { - if err = b.EngineOverrides.Hooks.TableRenameColumn.PostSQLExecution(ctx, n); err != nil { + if err = b.EngineOverrides.Hooks.TableRenameColumn.PostSQLExecution(ctx, b.Runner, n); err != nil { return nil, err } } @@ -1012,7 +1013,7 @@ func (b *BaseBuilder) buildRenameColumn(ctx *sql.Context, n *plan.RenameColumn, func (b *BaseBuilder) buildAddColumn(ctx *sql.Context, n *plan.AddColumn, row sql.Row) (sql.RowIter, error) { if b.EngineOverrides.Hooks.TableAddColumn.PreSQLExecution != nil { - nn, err := b.EngineOverrides.Hooks.TableAddColumn.PreSQLExecution(ctx, n) + nn, err := b.EngineOverrides.Hooks.TableAddColumn.PreSQLExecution(ctx, b.Runner, n) if err != nil { return nil, err } @@ -1096,7 +1097,7 @@ func (b *BaseBuilder) buildAlterDB(ctx *sql.Context, n *plan.AlterDB, row sql.Ro func (b *BaseBuilder) buildCreateTable(ctx *sql.Context, n *plan.CreateTable, row sql.Row) (sql.RowIter, error) { var err error if b.EngineOverrides.Hooks.CreateTable.PreSQLExecution != nil { - nn, err := b.EngineOverrides.Hooks.CreateTable.PreSQLExecution(ctx, n) + nn, err := b.EngineOverrides.Hooks.CreateTable.PreSQLExecution(ctx, b.Runner, n) if err != nil { return sql.RowsToRowIter(), err } @@ -1262,7 +1263,7 @@ func (b *BaseBuilder) buildCreateTable(ctx *sql.Context, n *plan.CreateTable, ro } if b.EngineOverrides.Hooks.CreateTable.PostSQLExecution != nil { - if err = b.EngineOverrides.Hooks.CreateTable.PostSQLExecution(ctx, n); err != nil { + if err = b.EngineOverrides.Hooks.CreateTable.PostSQLExecution(ctx, b.Runner, n); err != nil { return nil, err } } @@ -1345,7 +1346,7 @@ func (b *BaseBuilder) buildCreateTrigger(ctx *sql.Context, n *plan.CreateTrigger func (b *BaseBuilder) buildDropColumn(ctx *sql.Context, n *plan.DropColumn, row sql.Row) (sql.RowIter, error) { if b.EngineOverrides.Hooks.TableDropColumn.PreSQLExecution != nil { - nn, err := b.EngineOverrides.Hooks.TableDropColumn.PreSQLExecution(ctx, n) + nn, err := b.EngineOverrides.Hooks.TableDropColumn.PreSQLExecution(ctx, b.Runner, n) if err != nil { return nil, err } @@ -1370,6 +1371,7 @@ func (b *BaseBuilder) buildDropColumn(ctx *sql.Context, n *plan.DropColumn, row d: n, alterable: alterable, overrides: b.EngineOverrides, + runner: b.Runner, }, nil } diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index fe6f80e1b5..777adbe9bf 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -353,6 +353,7 @@ type modifyColumnIter struct { m *plan.ModifyColumn alterable sql.AlterableTable overrides sql.EngineOverrides + runner sql.StatementRunner runOnce bool } @@ -459,7 +460,7 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) { } if rewritten { if i.overrides.Hooks.TableModifyColumn.PostSQLExecution != nil { - if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.m); err != nil { + if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.runner, i.m); err != nil { return nil, err } } @@ -483,7 +484,7 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) { } } if i.overrides.Hooks.TableModifyColumn.PostSQLExecution != nil { - if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.m); err != nil { + if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.runner, i.m); err != nil { return nil, err } } @@ -1380,7 +1381,7 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) { } if rewritten { if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil { - if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.a); err != nil { + if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.b.Runner, i.a); err != nil { return nil, err } } @@ -1402,7 +1403,7 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) { // We only need to update all table rows if the new column is non-nil if i.a.Column().Nullable && i.a.Column().Default == nil { if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil { - if err = i.b.EngineOverrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.a); err != nil { + if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.b.Runner, i.a); err != nil { return nil, err } } @@ -1415,7 +1416,7 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) { } if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil { - if err = i.b.EngineOverrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.a); err != nil { + if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.b.Runner, i.a); err != nil { return nil, err } } @@ -1772,6 +1773,7 @@ type dropColumnIter struct { d *plan.DropColumn alterable sql.AlterableTable overrides sql.EngineOverrides + runner sql.StatementRunner runOnce bool } @@ -1799,7 +1801,7 @@ func (i *dropColumnIter) Next(ctx *sql.Context) (sql.Row, error) { } if rewritten { if i.overrides.Hooks.TableDropColumn.PostSQLExecution != nil { - if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.d); err != nil { + if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.runner, i.d); err != nil { return nil, err } } @@ -1826,7 +1828,7 @@ func (i *dropColumnIter) Next(ctx *sql.Context) (sql.Row, error) { } } if i.overrides.Hooks.TableDropColumn.PostSQLExecution != nil { - if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.d); err != nil { + if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.runner, i.d); err != nil { return nil, err } } diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index 5b102dd6d6..1d05e33f2c 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -201,7 +201,7 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, _ sql. var curdb sql.Database if b.EngineOverrides.Hooks.DropTable.PreSQLExecution != nil { - nn, err := b.EngineOverrides.Hooks.DropTable.PreSQLExecution(ctx, n) + nn, err := b.EngineOverrides.Hooks.DropTable.PreSQLExecution(ctx, b.Runner, n) if err != nil { return nil, err } @@ -274,7 +274,7 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, _ sql. } if b.EngineOverrides.Hooks.DropTable.PostSQLExecution != nil { - if err = b.EngineOverrides.Hooks.DropTable.PostSQLExecution(ctx, n); err != nil { + if err = b.EngineOverrides.Hooks.DropTable.PostSQLExecution(ctx, b.Runner, n); err != nil { return nil, err } }