Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/triggers/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ import (
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/prolly"
"github.com/dolthub/dolt/go/store/prolly/tree"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/rootobject/objinterface"
"github.com/dolthub/doltgresql/server/plpgsql"
)

// Collection contains a collection of triggers.
Expand Down Expand Up @@ -83,8 +83,8 @@ type Trigger struct {
Function id.Function
Timing TriggerTiming
Events []TriggerEvent
ForEachRow bool // When false, represents FOR EACH STATEMENT
When sql.Expression // TODO: should this be PLpgSQL operations?
ForEachRow bool // When false, represents FOR EACH STATEMENT
When []plpgsql.InterpreterOperation
Deferrable TriggerDeferrable
ReferencedTableName id.Table // FROM referenced_table_name
Constraint bool
Expand Down
26 changes: 24 additions & 2 deletions core/triggers/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/cockroachdb/errors"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/plpgsql"
"github.com/dolthub/doltgresql/utils"
)

Expand All @@ -37,14 +38,23 @@ func (trigger Trigger) Serialize(ctx context.Context) ([]byte, error) {
writer.Id(trigger.Function.AsId())
writer.Uint8(uint8(trigger.Timing))
writer.Bool(trigger.ForEachRow)
// TODO: writer.Unknown(trigger.When)
writer.Uint8(uint8(trigger.Deferrable))
writer.Id(trigger.ReferencedTableName.AsId())
writer.Bool(trigger.Constraint)
writer.String(trigger.OldTransitionName)
writer.String(trigger.NewTransitionName)
writer.StringSlice(trigger.Arguments)
writer.String(trigger.Definition)
// Write the WHEN operations
writer.VariableUint(uint64(len(trigger.When)))
for _, op := range trigger.When {
writer.Uint16(uint16(op.OpCode))
writer.String(op.PrimaryData)
writer.StringSlice(op.SecondaryData)
writer.String(op.Target)
writer.Int32(int32(op.Index))
writer.StringMap(op.Options)
}
// Write the events
writer.VariableUint(uint64(len(trigger.Events)))
for _, event := range trigger.Events {
Expand Down Expand Up @@ -73,14 +83,26 @@ func DeserializeTrigger(ctx context.Context, data []byte) (Trigger, error) {
t.Function = id.Function(reader.Id())
t.Timing = TriggerTiming(reader.Uint8())
t.ForEachRow = reader.Bool()
// TODO: trigger.When = reader.Unknown()
t.Deferrable = TriggerDeferrable(reader.Uint8())
t.ReferencedTableName = id.Table(reader.Id())
t.Constraint = reader.Bool()
t.OldTransitionName = reader.String()
t.NewTransitionName = reader.String()
t.Arguments = reader.StringSlice()
t.Definition = reader.String()
// Read the WHEN operations
opCount := reader.VariableUint()
t.When = make([]plpgsql.InterpreterOperation, opCount)
for opIdx := uint64(0); opIdx < opCount; opIdx++ {
op := plpgsql.InterpreterOperation{}
op.OpCode = plpgsql.OpCode(reader.Uint16())
op.PrimaryData = reader.String()
op.SecondaryData = reader.StringSlice()
op.Target = reader.String()
op.Index = int(reader.Int32())
op.Options = reader.StringMap()
t.When[opIdx] = op
}
// Read the events
eventCount := reader.VariableUint()
t.Events = make([]TriggerEvent, eventCount)
Expand Down
36 changes: 32 additions & 4 deletions server/ast/create_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,25 @@
package ast

import (
"fmt"
"regexp"

"github.com/cockroachdb/errors"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/triggers"
pgnodes "github.com/dolthub/doltgresql/server/node"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
pgnodes "github.com/dolthub/doltgresql/server/node"
"github.com/dolthub/doltgresql/server/plpgsql"
)

// createTriggerWhenCapture is a regex that should only capture the contents of the WHEN expression. Although a bit
// complex, this is done to ensure that the capture group contains only the WHEN expression and nothing else.
var createTriggerWhenCapture = regexp.MustCompile(`(?is)create\s+(?:or\s+replace\s+)?(?:constraint\s+)?trigger\s+.*\s+for\s+(?:each\s+)?(?:row|statement)\s+when\s+\((.*)\)\s+execute\s+(?:function|procedure).*`)

// nodeCreateTrigger handles *tree.CreateTrigger nodes.
func nodeCreateTrigger(ctx *Context, node *tree.CreateTrigger) (vitess.Statement, error) {
func nodeCreateTrigger(ctx *Context, node *tree.CreateTrigger) (_ vitess.Statement, err error) {
if node.Constraint {
return NotYetSupportedError("CREATE CONSTRAINT TRIGGER is not yet supported")
}
Expand Down Expand Up @@ -76,6 +84,26 @@ func nodeCreateTrigger(ctx *Context, node *tree.CreateTrigger) (vitess.Statement
return NotYetSupportedError("UNKNOWN EVENT TYPE is not yet supported for CREATE TRIGGER")
}
}
// WHEN expressions seem to behave identically to interpreted functions, so we'll parse them as interpreted functions.
// To do this, we need the raw string, and we wrap it as though it were a trigger function (which has special logic
// for handling NEW and OLD rows). Using a regex for this rather than modifying the parser may seem suboptimal, but
// we want to retain the parser validation of using an expression, however we cannot rely on the expression's
// String() function to return the **exact** same string, so we capture it with a regex.
var whenOps []plpgsql.InterpreterOperation
if node.When != nil {
matches := createTriggerWhenCapture.FindStringSubmatch(ctx.originalQuery)
if len(matches) != 2 {
return nil, errors.New("unable to parse WHEN expression from CREATE TRIGGER")
}
whenOps, err = plpgsql.Parse(fmt.Sprintf(`CREATE FUNCTION when_wrapper() RETURNS TRIGGER AS $$
BEGIN
RETURN %s;
END;
$$ LANGUAGE plpgsql;`, matches[1]))
if err != nil {
return nil, err
}
}
return vitess.InjectedStatement{
Statement: pgnodes.NewCreateTrigger(
id.NewTrigger(node.OnTable.Schema(), node.OnTable.Table(), node.Name.String()),
Expand All @@ -84,7 +112,7 @@ func nodeCreateTrigger(ctx *Context, node *tree.CreateTrigger) (vitess.Statement
timing,
events,
node.ForEachRow,
nil, // TODO: node.When (expr)
whenOps,
node.Args.ToStrings(),
ctx.originalQuery,
),
Expand Down
12 changes: 6 additions & 6 deletions server/node/create_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import (
"github.com/dolthub/go-mysql-server/sql/plan"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/functions"
"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/triggers"
"github.com/dolthub/doltgresql/server/plpgsql"
pgtypes "github.com/dolthub/doltgresql/server/types"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/id"
)

// CreateTrigger implements CREATE TRIGGER.
Expand All @@ -37,7 +37,7 @@ type CreateTrigger struct {
Timing triggers.TriggerTiming
Events []triggers.TriggerEvent
ForEachRow bool
When sql.Expression
When []plpgsql.InterpreterOperation
Arguments []string
Definition string
}
Expand All @@ -53,7 +53,7 @@ func NewCreateTrigger(
timing triggers.TriggerTiming,
events []triggers.TriggerEvent,
forEachRow bool,
when sql.Expression,
when []plpgsql.InterpreterOperation,
arguments []string,
definition string) *CreateTrigger {
return &CreateTrigger{
Expand Down Expand Up @@ -125,7 +125,7 @@ func (c *CreateTrigger) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error
Timing: c.Timing,
Events: c.Events,
ForEachRow: c.ForEachRow,
When: nil,
When: c.When,
Deferrable: triggers.TriggerDeferrable_NotDeferrable,
ReferencedTableName: "",
Constraint: false,
Expand Down
32 changes: 31 additions & 1 deletion server/node/trigger_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package node

import (
"fmt"
"strings"

"github.com/cockroachdb/errors"
"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -83,14 +84,24 @@ func (te *TriggerExecution) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, e
return sourceIter, nil
}
trigFuncs := make([]framework.InterpretedFunction, len(te.Triggers))
whens := make([]framework.InterpretedFunction, len(te.Triggers))
for i, trig := range te.Triggers {
trigFuncs[i], err = te.loadTriggerFunction(ctx, trig)
if err != nil {
return nil, err
}
// If we have a WHEN expression, then we need to build a "function" to execute the expression
if len(trig.When) > 0 {
whens[i] = framework.InterpretedFunction{
ID: trigFuncs[i].ID, // Assign the same ID just so we have a valid one for later
ReturnType: pgtypes.Bool,
Statements: trig.When,
}
}
}
return &triggerExecutionIter{
functions: trigFuncs,
whens: whens,
split: te.Split,
treturn: te.Return,
runner: te.Runner.Runner,
Expand Down Expand Up @@ -156,6 +167,7 @@ func (te *TriggerExecution) loadTriggerFunction(ctx *sql.Context, trigger trigge
// triggerExecutionIter is the iterator for TriggerExecution.
type triggerExecutionIter struct {
functions []framework.InterpretedFunction
whens []framework.InterpretedFunction
split TriggerExecutionRowHandling
treturn TriggerExecutionRowHandling
runner analyzer.StatementRunner
Expand Down Expand Up @@ -185,7 +197,25 @@ func (t *triggerExecutionIter) Next(ctx *sql.Context) (sql.Row, error) {
case TriggerExecutionRowHandling_New:
newRow = nextRow
}
for _, function := range t.functions {
for funcIdx, function := range t.functions {
if t.whens[funcIdx].ID.IsValid() {
whenValue, err := plpgsql.TriggerCall(ctx, t.whens[funcIdx], t.runner, t.sch, oldRow, newRow)
if err != nil {
if strings.Contains(err.Error(), "no valid cast for return value") {
// TODO: this error should technically be caught during parsing, but interpreted functions don't
// have the ability to determine types during parsing yet (also applies to the same error below)
return nil, fmt.Errorf("argument of WHEN must be type boolean")
}
return nil, err
}
whenBool, ok := whenValue.(bool)
if !ok {
return nil, fmt.Errorf("argument of WHEN must be type boolean")
}
if !whenBool {
continue
}
}
returnedValue, err := plpgsql.TriggerCall(ctx, function, t.runner, t.sch, oldRow, newRow)
if err != nil {
return nil, err
Expand Down
53 changes: 53 additions & 0 deletions testing/go/trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,5 +586,58 @@ $$ LANGUAGE plpgsql;`,
},
},
},
{
Name: "WHEN on BEFORE INSERT",
SetUpScript: []string{
"CREATE TABLE test (pk INT PRIMARY KEY, v1 TEXT);",
`CREATE FUNCTION trigger_func1() RETURNS TRIGGER AS $$
BEGIN
NEW.v1 := NEW.pk::text || '_' || NEW.v1;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;`,
`CREATE FUNCTION trigger_func2() RETURNS TRIGGER AS $$
BEGIN
NEW.v1 := NEW.v1 || '_' || NEW.pk::text;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;`,
`CREATE TRIGGER test_trigger1 BEFORE INSERT ON test FOR EACH ROW WHEN (NEW.pk < 1) EXECUTE FUNCTION trigger_func1();`,
`CREATE TRIGGER test_trigger2 BEFORE INSERT ON test FOR EACH ROW WHEN (NEW.pk > 1) EXECUTE FUNCTION trigger_func2();`,
},
Assertions: []ScriptTestAssertion{
{
Query: "INSERT INTO test VALUES (0, 'hi'), (1, 'there'), (2, 'dude');",
Expected: []sql.Row{},
},
{
Query: "SELECT * FROM test;",
Expected: []sql.Row{
{0, "0_hi"},
{1, "there"},
{2, "dude_2"},
},
},
},
},
{
Name: "WHEN with non-boolean expression",
SetUpScript: []string{
"CREATE TABLE test (pk INT PRIMARY KEY, v1 TEXT);",
`CREATE FUNCTION trigger_func() RETURNS TRIGGER AS $$
BEGIN
NEW.v1 := NEW.pk::text || '_' || NEW.v1;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;`,
`CREATE TRIGGER test_trigger BEFORE INSERT ON test FOR EACH ROW WHEN (NEW.pk + 1) EXECUTE FUNCTION trigger_func();`,
},
Assertions: []ScriptTestAssertion{
{
Query: "INSERT INTO test VALUES (1, 'hi'), (2, 'there');",
ExpectedErr: "argument of WHEN must be type boolean",
},
},
},
})
}