From 2c1ee84d49ece07a4b118cac84a24eae00a84241 Mon Sep 17 00:00:00 2001 From: jennifersp <44716627+jennifersp@users.noreply.github.com> Date: Thu, 12 May 2022 14:24:24 -0700 Subject: [PATCH] Allow unresolved tables and procedures in trigger body in CREATE TRIGGER (#992) --- enginetest/trigger_queries.go | 184 +++++++++++++++++++++++++++++++--- sql/analyzer/triggers.go | 43 ++++---- sql/errors.go | 3 + sql/plan/call.go | 2 +- sql/plan/ddl_trigger.go | 10 +- 5 files changed, 207 insertions(+), 35 deletions(-) diff --git a/enginetest/trigger_queries.go b/enginetest/trigger_queries.go index 2299ec4e94..5dd92848ae 100644 --- a/enginetest/trigger_queries.go +++ b/enginetest/trigger_queries.go @@ -1862,6 +1862,168 @@ end;`, }, }, }, + { + Name: "simple trigger with non-existent table in trigger body", + SetUpScript: []string{ + "create table a (x int primary key)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "create trigger insert_into_b after insert on a for each row insert into b values (new.x + 1)", + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "insert into a values (1), (3), (5)", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "create table b (y int primary key)", + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "insert into a values (1), (3), (5)", + Expected: []sql.Row{ + {sql.OkResult{RowsAffected: 3}}, + }, + }, + { + Query: "select x from a order by 1", + Expected: []sql.Row{ + {1}, {3}, {5}, + }, + }, + { + Query: "select y from b order by 1", + Expected: []sql.Row{ + {2}, {4}, {6}, + }, + }, + }, + }, + { + Name: "insert, update, delete triggers with non-existent table in trigger body", + SetUpScript: []string{ + "CREATE TABLE film (film_id smallint unsigned NOT NULL AUTO_INCREMENT, title varchar(128) NOT NULL, description text, PRIMARY KEY (film_id))", + "INSERT INTO `film` VALUES (1,'ACADEMY DINOSAUR','A Epic Drama in The Canadian Rockies'),(2,'ACE GOLDFINGER','An Astounding Epistle of a Database Administrator in Ancient China');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CREATE TRIGGER ins_film AFTER INSERT ON film FOR EACH ROW BEGIN INSERT INTO film_text (film_id, title, description) VALUES (new.film_id, new.title, new.description); END;", + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: `CREATE TRIGGER upd_film AFTER UPDATE ON film FOR EACH ROW BEGIN + IF (old.title != new.title) OR (old.description != new.description) OR (old.film_id != new.film_id) + THEN + UPDATE film_text + SET title=new.title, + description=new.description, + film_id=new.film_id + WHERE film_id=old.film_id; + END IF; END;`, + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "CREATE TRIGGER del_film AFTER DELETE ON film FOR EACH ROW BEGIN DELETE FROM film_text WHERE film_id = old.film_id; END;", + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "INSERT INTO `film` VALUES (3,'ADAPTATION HOLES','An Astounding Reflection in A Baloon Factory'),(4,'AFFAIR PREJUDICE','A Fanciful Documentary in A Shark Tank')", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "UPDATE film SET title = 'THE ACADEMY DINOSAUR' WHERE title = 'ACADEMY DINOSAUR'", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "DELETE FROM film WHERE title = 'ACE GOLDFINGER'", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "CREATE TABLE film_text (film_id smallint NOT NULL, title varchar(255) NOT NULL, description text, PRIMARY KEY (film_id))", + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "SELECT COUNT(*) FROM film", + Expected: []sql.Row{{2}}, + }, + { + Query: "INSERT INTO `film` VALUES (3,'ADAPTATION HOLES','An Astounding Reflection in A Baloon Factory'),(4,'AFFAIR PREJUDICE','A Fanciful Documentary in A Shark Tank')", + Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 3}}}, + }, + { + Query: "SELECT COUNT(*) FROM film", + Expected: []sql.Row{{4}}, + }, + { + Query: "SELECT COUNT(*) FROM film_text", + Expected: []sql.Row{{2}}, + }, + { + Query: "UPDATE film SET title = 'DIFFERENT MOVIE' WHERE title = 'ADAPTATION HOLES'", + Expected: []sql.Row{{sql.OkResult{RowsAffected: 1, InsertID: 0, Info: plan.UpdateInfo{Matched: 1, Updated: 1, Warnings: 0}}}}, + }, + { + Query: "SELECT COUNT(*) FROM film_text WHERE title = 'DIFFERENT MOVIE'", + Expected: []sql.Row{{1}}, + }, + { + Query: "DELETE FROM film WHERE title = 'DIFFERENT MOVIE'", + Expected: []sql.Row{{sql.OkResult{RowsAffected: 1}}}, + }, + { + Query: "SELECT COUNT(*) FROM film_text WHERE title = 'DIFFERENT MOVIE'", + Expected: []sql.Row{{0}}, + }, + }, + }, { + Name: "non-existent procedure in trigger body", + SetUpScript: []string{ + "CREATE TABLE t0 (id INT PRIMARY KEY AUTO_INCREMENT, v1 INT, v2 TEXT);", + "CREATE TABLE t1 (id INT PRIMARY KEY AUTO_INCREMENT, v1 INT, v2 TEXT);", + "INSERT INTO t0 VALUES (1, 2, 'abc'), (2, 3, 'def');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t0;", + Expected: []sql.Row{{1, 2, "abc"}, {2, 3, "def"}}, + }, + { + Query: `CREATE PROCEDURE add_entry(i INT, s TEXT) BEGIN IF i > 50 THEN +SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'too big number'; END IF; +INSERT INTO t0 (v1, v2) VALUES (i, s); END;`, + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "CREATE TRIGGER trig AFTER INSERT ON t0 FOR EACH ROW BEGIN CALL back_up(NEW.v1, NEW.v2); END;", + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "INSERT INTO t0 (v1, v2) VALUES (5, 'ggg');", + ExpectedErr: sql.ErrStoredProcedureDoesNotExist, + }, + { + Query: "CREATE PROCEDURE back_up(num INT, msg TEXT) INSERT INTO t1 (v1, v2) VALUES (num*2, msg);", + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "CALL add_entry(4, 'aaa');", + Expected: []sql.Row{{sql.OkResult{RowsAffected: 1, InsertID: 1}}}, + }, + { + Query: "SELECT * FROM t0;", + Expected: []sql.Row{{1, 2, "abc"}, {2, 3, "def"}, {3, 4, "aaa"}}, + }, + { + Query: "SELECT * FROM t1;", + Expected: []sql.Row{{1, 8, "aaa"}}, + }, + { + Query: "CALL add_entry(54, 'bbb');", + ExpectedErrStr: "too big number (errno 1644) (sqlstate 45000)", + }, + }, + }, } // RollbackTriggerTests are trigger tests that require rollback logic to work correctly @@ -2782,22 +2944,20 @@ var TriggerErrorTests = []ScriptTest{ Query: "create trigger update_new after update on x for each row BEGIN set new.c = new.a + 1; END", ExpectedErr: sql.ErrInvalidUpdateInAfterTrigger, }, - // This isn't an error in MySQL until runtime, but we catch it earlier because why not { Name: "source column doesn't exist", SetUpScript: []string{ "create table x (a int primary key, b int, c int)", }, Query: "create trigger not_found before insert on x for each row set new.d = new.d + 1", - ExpectedErr: sql.ErrTableColumnNotFound, - }, - // TODO: this isn't an error in MySQL, but we could catch it and make it one - // { - // Name: "target column doesn't exist", - // SetUpScript: []string{ - // "create table x (a int primary key, b int, c int)", - // }, - // Query: "create trigger not_found before insert on x for each row set new.d = new.a + 1", - // ExpectedErr: sql.ErrTableColumnNotFound, - // }, + ExpectedErr: sql.ErrUnknownColumn, + }, + { + Name: "target column doesn't exist", + SetUpScript: []string{ + "create table x (a int primary key, b int, c int)", + }, + Query: "create trigger not_found before insert on x for each row set new.d = new.a + 1", + ExpectedErr: sql.ErrUnknownColumn, + }, } diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index 6d52c67775..316124fc0c 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -40,7 +40,7 @@ func validateCreateTrigger(ctx *sql.Context, a *Analyzer, node sql.Node, scope * // UnresolvedColumn expressions with placeholder expressions that say they are Resolved(). // TODO: this might work badly for databases with tables named new and old. Needs tests. var err error - transform.InspectExpressions(ct, func(e sql.Expression) bool { + transform.InspectExpressions(ct.Body, func(e sql.Expression) bool { switch e := e.(type) { case *expression.UnresolvedColumn: if strings.ToLower(e.Table()) == "new" { @@ -73,7 +73,7 @@ func validateCreateTrigger(ctx *sql.Context, a *Analyzer, node sql.Node, scope * } // Check to see if the plan sets a value for "old" rows, or if an AFTER trigger assigns to NEW. Both are illegal. - transform.InspectExpressionsWithNode(node, func(n sql.Node, e sql.Expression) bool { + transform.InspectExpressionsWithNode(ct.Body, func(n sql.Node, e sql.Expression) bool { if _, ok := n.(*plan.Set); !ok { return true } @@ -98,23 +98,32 @@ func validateCreateTrigger(ctx *sql.Context, a *Analyzer, node sql.Node, scope * return nil, transform.SameTree, err } - // Finally analyze the entire trigger body with an appropriate scope for any "old" and "new" table references. This - // will catch (most) other errors in a trigger body. We set the trigger body at the end to pass to final validation - // steps at the end of analysis. - scopeNode := plan.NewProject( - []sql.Expression{expression.NewStar()}, - plan.NewCrossJoin( - plan.NewTableAlias("old", getResolvedTable(ct.Table)), - plan.NewTableAlias("new", getResolvedTable(ct.Table)), - ), - ) - - triggerLogic, _, err := a.analyzeWithSelector(ctx, ct.Body, (*Scope)(nil).newScope(scopeNode), SelectAllBatches, sel) - if err != nil { - return nil, transform.SameTree, err + trigTable := getResolvedTable(ct.Table) + sch := trigTable.Schema() + colsList := make(map[string]struct{}) + for _, c := range sch { + colsList[c.Name] = struct{}{} } - node, err = ct.WithChildren(ct.Table, StripPassthroughNodes(triggerLogic)) + // Check to see if the columns with "new" and "old" table reference are valid columns from the trigger table. + transform.InspectExpressionsWithNode(ct.Body, func(n sql.Node, e sql.Expression) bool { + switch e := e.(type) { + case *expression.UnresolvedColumn: + if strings.ToLower(e.Table()) == "old" || strings.ToLower(e.Table()) == "new" { + if _, ok := colsList[e.Name()]; !ok { + err = sql.ErrUnknownColumn.New(e.Name(), e.Table()) + } + } + case *deferredColumn: + if strings.ToLower(e.Table()) == "old" || strings.ToLower(e.Table()) == "new" { + if _, ok := colsList[e.Name()]; !ok { + err = sql.ErrUnknownColumn.New(e.Name(), e.Table()) + } + } + } + return true + }) + if err != nil { return nil, transform.SameTree, err } diff --git a/sql/errors.go b/sql/errors.go index 97846f62b4..31046e2a9d 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -232,6 +232,9 @@ var ( // ErrInvalidUpdateInAfterTrigger is returned when a trigger attempts to assign to a new row in an AFTER trigger ErrInvalidUpdateInAfterTrigger = errors.NewKind("Updating of new row is not allowed in after trigger") + // ErrUnknownColumn is returned when the given column is not found in referenced table + ErrUnknownColumn = errors.NewKind("Unknown column '%s' in '%s'") + // ErrUnboundPreparedStatementVariable is returned when a query is executed without a binding for one its variables. ErrUnboundPreparedStatementVariable = errors.NewKind(`unbound variable "%s" in query`) diff --git a/sql/plan/call.go b/sql/plan/call.go index 4e11b968a1..4a419ca163 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -124,7 +124,7 @@ func (c *Call) String() string { // RowIter implements the sql.Node interface. func (c *Call) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { for i, paramExpr := range c.Params { - val, err := paramExpr.Eval(ctx, nil) + val, err := paramExpr.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/plan/ddl_trigger.go b/sql/plan/ddl_trigger.go index 1ef3476513..79a5d6503e 100644 --- a/sql/plan/ddl_trigger.go +++ b/sql/plan/ddl_trigger.go @@ -81,7 +81,8 @@ func (c *CreateTrigger) WithDatabase(database sql.Database) (sql.Node, error) { } func (c *CreateTrigger) Resolved() bool { - return c.ddlNode.Resolved() && c.Table.Resolved() && c.Body.Resolved() + // c.Body can be unresolved since it can have unresolved table reference to non-existent table + return c.ddlNode.Resolved() && c.Table.Resolved() } func (c *CreateTrigger) Schema() sql.Schema { @@ -89,17 +90,16 @@ func (c *CreateTrigger) Schema() sql.Schema { } func (c *CreateTrigger) Children() []sql.Node { - return []sql.Node{c.Table, c.Body} + return []sql.Node{c.Table} } func (c *CreateTrigger) WithChildren(children ...sql.Node) (sql.Node, error) { - if len(children) != 2 { - return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 2) + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } nc := *c nc.Table = children[0] - nc.Body = children[1] return &nc, nil }