diff --git a/postgres/parser/types/types.go b/postgres/parser/types/types.go index e431bf8209..f0e95b2e6f 100644 --- a/postgres/parser/types/types.go +++ b/postgres/parser/types/types.go @@ -2553,6 +2553,7 @@ var unreservedTypeTokens = map[string]*T{ "float4": Float, "float8": Float, "inet": INet, + "integer": Int4, "int2": Int2, "int4": Int4, "int8": Int, diff --git a/server/plpgsql/interpreter_logic.go b/server/plpgsql/interpreter_logic.go index bded121758..9d0a31d2e1 100644 --- a/server/plpgsql/interpreter_logic.go +++ b/server/plpgsql/interpreter_logic.go @@ -16,12 +16,14 @@ package plpgsql import ( "fmt" + "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/core/typecollection" + "github.com/dolthub/doltgresql/postgres/parser/types" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -92,7 +94,26 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement if err != nil { return nil, err } - resolvedType, exists := typeCollection.GetType(id.NewType("pg_catalog", operation.PrimaryData)) + + // pg_query_go sets PrimaryData for implicit CASE statement variables to + // `pg_catalog."integer"`, so we remove double-quotes and extract the schema name. + typeName := operation.PrimaryData + typeName = strings.ReplaceAll(typeName, `"`, "") + schemaName := "pg_catalog" + if strings.Contains(typeName, ".") { + parts := strings.Split(typeName, ".") + schemaName = parts[0] + typeName = parts[1] + // Check the NonKeyword type names to see if we're looking at + // an alias of a type if we're in the pg_catalog schema. + if schemaName == "pg_catalog" { + typ, ok, _ := types.TypeForNonKeywordTypeName(typeName) + if ok && typ != nil { + typeName = typ.Name() + } + } + } + resolvedType, exists := typeCollection.GetType(id.NewType(schemaName, typeName)) if !exists { return nil, pgtypes.ErrTypeDoesNotExist.New(operation.PrimaryData) } diff --git a/server/plpgsql/json.go b/server/plpgsql/json.go index 08618a65de..fe0dea0102 100644 --- a/server/plpgsql/json.go +++ b/server/plpgsql/json.go @@ -15,6 +15,7 @@ package plpgsql import ( + "fmt" "strings" "github.com/cockroachdb/errors" @@ -102,6 +103,25 @@ type plpgSQL_stmt_block struct { LineNumber int32 `json:"lineno"` } +// plpgSQL_stmt_case exists to match the expected JSON format. +type plpgSQL_stmt_case struct { + LineNumber int32 `json:"lineno"` + Expression expr `json:"t_expr"` + // VarNo indicates the ID for the __Case__Variable_N__ variable that holds the evaluated + // value of the case expression. + VarNo int32 `json:"t_varno"` + WhenList []statement `json:"case_when_list"` + HasElse bool `json:"have_else"` + Else []statement `json:"else_stmts"` +} + +// plpgSQL_case_when exists to match the expected JSON format. +type plpgSQL_case_when struct { + LineNumber int32 `json:"lineno"` + Expression expr `json:"expr"` + Body []statement `json:"stmts"` +} + // plpgSQL_stmt_execsql exists to match the expected JSON format. type plpgSQL_stmt_execsql struct { SQLStmt sqlstmt `json:"sqlstmt"` @@ -175,12 +195,14 @@ type sqlstmt struct { // having a singular expected implementation. type statement struct { Assignment *plpgSQL_stmt_assign `json:"PLpgSQL_stmt_assign"` + Case *plpgSQL_stmt_case `json:"PLpgSQL_stmt_case"` ExecSQL *plpgSQL_stmt_execsql `json:"PLpgSQL_stmt_execsql"` Exit *plpgSQL_stmt_exit `json:"PLpgSQL_stmt_exit"` If *plpgSQL_stmt_if `json:"PLpgSQL_stmt_if"` Loop *plpgSQL_stmt_loop `json:"PLpgSQL_stmt_loop"` Perform *plpgSQL_stmt_perform `json:"PLpgSQL_stmt_perform"` Return *plpgSQL_stmt_return `json:"PLpgSQL_stmt_return"` + When *plpgSQL_case_when `json:"PLpgSQL_case_when"` While *plpgSQL_stmt_while `json:"PLpgSQL_stmt_while"` } @@ -204,6 +226,93 @@ func (stmt *plpgSQL_stmt_assign) Convert() (Assignment, error) { }, nil } +func (stmt *plpgSQL_stmt_case) Convert() (block Block, err error) { + // If the CASE statement has a main expression, start by assigning it to a variable so + // we can evaluate it once and only once. + if stmt.Expression.Expression.Query != "" { + // TODO: pg_query_go creates the definitions for these variables, and + // ideally users shouldn't be able to reference them. We could + // update all the references to them (i.e. declaration, assignment, + // and WHEN block exprs) to change the name to include a \0 char to + // prevent users from referencing them or colliding with them. + block.Body = append(block.Body, Assignment{ + VariableName: fmt.Sprintf("__Case__Variable_%d__", stmt.VarNo), + Expression: stmt.Expression.Expression.Query, + }) + } + + // Record indexes of all the GOTO ops that jump to the very end of the case block so we + // can update them later and plug in the correct offsets after we know the final size. + var gotoEndOpsIndexes []int + + // Add operations for each WHEN statement... + for _, stmt := range stmt.WhenList { + when := stmt.When + if when == nil { + return Block{}, fmt.Errorf("case statement WHEN clause is nil") + } + + // TODO: The generated expressions from pg_query_go uses double quotes + // around the variable name, which is valid for Postgres, but + // our engine doesn't currently resolve double-quoted strings to + // variables, so for now, we just extract the double quotes. + expressionString := when.Expression.Expression.Query + expressionString = strings.ReplaceAll(expressionString, `"`, "") + + convertedWhenBodyStatements, err := jsonConvertStatements(when.Body) + if err != nil { + return Block{}, err + } + + block.Body = append(block.Body, + If{ + Condition: expressionString, + GotoOffset: 2, + }, + Goto{ + // This GOTO jumps to the next WHEN block, so step over all the statements + // from this WHEN block, plus 1 for the GOTO op we add at the end of each + // block, and plus 1 more to move to the next statement. + Offset: int32(len(convertedWhenBodyStatements) + 1 + 1), + }) + block.Body = append(block.Body, convertedWhenBodyStatements...) + + // Add a GOTO op to jump to the end of the entire CASE block, and record its position + // in the statement block so we can update it later. + block.Body = append(block.Body, Goto{}) + gotoEndOpsIndexes = append(gotoEndOpsIndexes, len(block.Body)-1) + } + + if stmt.HasElse { + convertElseBodyStatements, err := jsonConvertStatements(stmt.Else) + if err != nil { + return Block{}, err + } + block.Body = append(block.Body, convertElseBodyStatements...) + // TODO: If no cases match and there is no ELSE block, then add a RAISE statement + // to return an error. + //} else { + // Sample PostgreSQL error response: + // ERROR: case not found + // HINT: CASE statement is missing ELSE part. + // CONTEXT: PL/pgSQL function interpreted_case(integer) line 5 at CASE + } + + // Update all the GOTO ops that jump to the very end of the case block. + for _, gotoEndOpIndex := range gotoEndOpsIndexes { + // Sanity check that we are looking at a GOTO statement + if _, ok := block.Body[gotoEndOpIndex].(Goto); !ok { + return Block{}, fmt.Errorf("expected Goto statement, got %T", block.Body[gotoEndOpIndex]) + } + + block.Body[gotoEndOpIndex] = Goto{ + Offset: int32(len(block.Body) - gotoEndOpIndex), + } + } + + return block, nil +} + // Convert converts the JSON statement into its output form. func (stmt *plpgSQL_stmt_execsql) Convert() (ExecuteSQL, error) { var target string diff --git a/server/plpgsql/json_convert.go b/server/plpgsql/json_convert.go index 4bf4ff04a2..55c186e2a7 100644 --- a/server/plpgsql/json_convert.go +++ b/server/plpgsql/json_convert.go @@ -52,6 +52,8 @@ func jsonConvertStatement(stmt statement) (Statement, error) { switch { case stmt.Assignment != nil: return stmt.Assignment.Convert() + case stmt.Case != nil: + return stmt.Case.Convert() case stmt.ExecSQL != nil: return stmt.ExecSQL.Convert() case stmt.Exit != nil: diff --git a/testing/go/create_function_test.go b/testing/go/create_function_test.go index bf1e06b470..002b613480 100644 --- a/testing/go/create_function_test.go +++ b/testing/go/create_function_test.go @@ -98,6 +98,163 @@ $$ LANGUAGE plpgsql;`}, }, }, }, + { + Name: "CASE, with ELSE", + SetUpScript: []string{` +CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$ +DECLARE + msg TEXT; +BEGIN + CASE x + WHEN 1, 2 THEN + msg := 'one'; + msg := msg || ' or two'; + ELSE + msg := 'other'; + msg := msg || ' value than one or two'; + END CASE; + RETURN msg; +END; +$$ LANGUAGE plpgsql;`}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT interpreted_case(1);", + Expected: []sql.Row{{"one or two"}}, + }, + { + Query: "SELECT interpreted_case(2);", + Expected: []sql.Row{{"one or two"}}, + }, + { + Query: "SELECT interpreted_case(0);", + Expected: []sql.Row{{"other value than one or two"}}, + }, + }, + }, + { + // TODO: When no CASE statements match, and there is no ELSE block, + // Postgres raises an exception. Unskip this test after we + // add support for raising exceptions from functions. + Skip: true, + Name: "CASE, without ELSE", + SetUpScript: []string{` +CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$ +DECLARE + msg TEXT; +BEGIN + CASE x + WHEN 1, 2 THEN + msg := 'one or two'; + END CASE; + RETURN msg; +END; +$$ LANGUAGE plpgsql;`}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT interpreted_case(1);", + Expected: []sql.Row{{"one or two"}}, + }, + { + Query: "SELECT interpreted_case(2);", + Expected: []sql.Row{{"one or two"}}, + }, + { + Query: "SELECT interpreted_case(0);", + ExpectedErr: "case not found", + }, + }, + }, + { + Name: "Searched CASE, with ELSE", + SetUpScript: []string{` +CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$ +DECLARE + msg TEXT; +BEGIN + CASE + WHEN x BETWEEN 0 AND 10 THEN + msg := 'value is between zero'; + msg := msg || ' and ten'; + WHEN x BETWEEN 11 AND 20 THEN + msg := 'value is between eleven and twenty'; + ELSE + msg := 'value'; + msg := msg || ' is'; + msg := msg || ' out of'; + msg := msg || ' bounds'; + END CASE; + RETURN msg; +END; +$$ LANGUAGE plpgsql;`}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT interpreted_case(0);", + Expected: []sql.Row{{"value is between zero and ten"}}, + }, + { + Query: "SELECT interpreted_case(1);", + Expected: []sql.Row{{"value is between zero and ten"}}, + }, + { + Query: "SELECT interpreted_case(10);", + Expected: []sql.Row{{"value is between zero and ten"}}, + }, + { + Query: "SELECT interpreted_case(11);", + Expected: []sql.Row{{"value is between eleven and twenty"}}, + }, + { + Query: "SELECT interpreted_case(21);", + Expected: []sql.Row{{"value is out of bounds"}}, + }, + }, + }, + { + // TODO: When no CASE statements match, and there is no ELSE block, + // Postgres raises an exception. Unskip this test after we + // add support for raising exceptions from functions. + Skip: true, + Name: "Searched CASE, without ELSE", + SetUpScript: []string{` +CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$ +DECLARE + msg TEXT; +BEGIN + CASE + WHEN x BETWEEN 0 AND 10 THEN + msg := 'value is between zero and ten'; + WHEN x BETWEEN 11 AND 20 THEN + msg := 'value'; + msg := msg || ' is between'; + msg := msg || ' eleven and'; + msg := msg || ' twenty'; + END CASE; + RETURN msg; +END; +$$ LANGUAGE plpgsql;`}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT interpreted_case(0);", + Expected: []sql.Row{{"value is between zero and ten"}}, + }, + { + Query: "SELECT interpreted_case(1);", + Expected: []sql.Row{{"value is between zero and ten"}}, + }, + { + Query: "SELECT interpreted_case(10);", + Expected: []sql.Row{{"value is between zero and ten"}}, + }, + { + Query: "SELECT interpreted_case(11);", + Expected: []sql.Row{{"value is between eleven and twenty"}}, + }, + { + Query: "SELECT interpreted_case(21);", + ExpectedErr: "case not found", + }, + }, + }, { Name: "CONTINUE", SetUpScript: []string{`CREATE FUNCTION interpreted_continue() RETURNS INT4 AS $$