Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions postgres/parser/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2553,6 +2553,7 @@ var unreservedTypeTokens = map[string]*T{
"float4": Float,
"float8": Float,
"inet": INet,
"integer": Int4,
"int2": Int2,
"int4": Int4,
"int8": Int,
Expand Down
23 changes: 22 additions & 1 deletion server/plpgsql/interpreter_logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ package plpgsql

import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/typecollection"
"github.com/dolthub/doltgresql/postgres/parser/types"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

Expand Down Expand Up @@ -92,7 +94,26 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
if err != nil {
return nil, err
}
resolvedType, exists := typeCollection.GetType(id.NewType("pg_catalog", operation.PrimaryData))

// pg_query_go sets PrimaryData for implicit CASE statement variables to
// `pg_catalog."integer"`, so we remove double-quotes and extract the schema name.
typeName := operation.PrimaryData
typeName = strings.ReplaceAll(typeName, `"`, "")
schemaName := "pg_catalog"
if strings.Contains(typeName, ".") {
parts := strings.Split(typeName, ".")
schemaName = parts[0]
typeName = parts[1]
// Check the NonKeyword type names to see if we're looking at
// an alias of a type if we're in the pg_catalog schema.
if schemaName == "pg_catalog" {
typ, ok, _ := types.TypeForNonKeywordTypeName(typeName)
if ok && typ != nil {
typeName = typ.Name()
}
}
}
resolvedType, exists := typeCollection.GetType(id.NewType(schemaName, typeName))
if !exists {
return nil, pgtypes.ErrTypeDoesNotExist.New(operation.PrimaryData)
}
Expand Down
109 changes: 109 additions & 0 deletions server/plpgsql/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package plpgsql

import (
"fmt"
"strings"

"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -102,6 +103,25 @@ type plpgSQL_stmt_block struct {
LineNumber int32 `json:"lineno"`
}

// plpgSQL_stmt_case exists to match the expected JSON format.
type plpgSQL_stmt_case struct {
LineNumber int32 `json:"lineno"`
Expression expr `json:"t_expr"`
// VarNo indicates the ID for the __Case__Variable_N__ variable that holds the evaluated
// value of the case expression.
VarNo int32 `json:"t_varno"`
WhenList []statement `json:"case_when_list"`
HasElse bool `json:"have_else"`
Else []statement `json:"else_stmts"`
}

// plpgSQL_case_when exists to match the expected JSON format.
type plpgSQL_case_when struct {
LineNumber int32 `json:"lineno"`
Expression expr `json:"expr"`
Body []statement `json:"stmts"`
}

// plpgSQL_stmt_execsql exists to match the expected JSON format.
type plpgSQL_stmt_execsql struct {
SQLStmt sqlstmt `json:"sqlstmt"`
Expand Down Expand Up @@ -175,12 +195,14 @@ type sqlstmt struct {
// having a singular expected implementation.
type statement struct {
Assignment *plpgSQL_stmt_assign `json:"PLpgSQL_stmt_assign"`
Case *plpgSQL_stmt_case `json:"PLpgSQL_stmt_case"`
ExecSQL *plpgSQL_stmt_execsql `json:"PLpgSQL_stmt_execsql"`
Exit *plpgSQL_stmt_exit `json:"PLpgSQL_stmt_exit"`
If *plpgSQL_stmt_if `json:"PLpgSQL_stmt_if"`
Loop *plpgSQL_stmt_loop `json:"PLpgSQL_stmt_loop"`
Perform *plpgSQL_stmt_perform `json:"PLpgSQL_stmt_perform"`
Return *plpgSQL_stmt_return `json:"PLpgSQL_stmt_return"`
When *plpgSQL_case_when `json:"PLpgSQL_case_when"`
While *plpgSQL_stmt_while `json:"PLpgSQL_stmt_while"`
}

Expand All @@ -204,6 +226,93 @@ func (stmt *plpgSQL_stmt_assign) Convert() (Assignment, error) {
}, nil
}

func (stmt *plpgSQL_stmt_case) Convert() (block Block, err error) {
// If the CASE statement has a main expression, start by assigning it to a variable so
// we can evaluate it once and only once.
if stmt.Expression.Expression.Query != "" {
// TODO: pg_query_go creates the definitions for these variables, and
// ideally users shouldn't be able to reference them. We could
// update all the references to them (i.e. declaration, assignment,
// and WHEN block exprs) to change the name to include a \0 char to
// prevent users from referencing them or colliding with them.
block.Body = append(block.Body, Assignment{
VariableName: fmt.Sprintf("__Case__Variable_%d__", stmt.VarNo),
Expression: stmt.Expression.Expression.Query,
})
}

// Record indexes of all the GOTO ops that jump to the very end of the case block so we
// can update them later and plug in the correct offsets after we know the final size.
var gotoEndOpsIndexes []int

// Add operations for each WHEN statement...
for _, stmt := range stmt.WhenList {
when := stmt.When
if when == nil {
return Block{}, fmt.Errorf("case statement WHEN clause is nil")
}

// TODO: The generated expressions from pg_query_go uses double quotes
// around the variable name, which is valid for Postgres, but
// our engine doesn't currently resolve double-quoted strings to
// variables, so for now, we just extract the double quotes.
expressionString := when.Expression.Expression.Query
expressionString = strings.ReplaceAll(expressionString, `"`, "")

convertedWhenBodyStatements, err := jsonConvertStatements(when.Body)
if err != nil {
return Block{}, err
}

block.Body = append(block.Body,
If{
Condition: expressionString,
GotoOffset: 2,
},
Goto{
// This GOTO jumps to the next WHEN block, so step over all the statements
// from this WHEN block, plus 1 for the GOTO op we add at the end of each
// block, and plus 1 more to move to the next statement.
Offset: int32(len(convertedWhenBodyStatements) + 1 + 1),
})
block.Body = append(block.Body, convertedWhenBodyStatements...)

// Add a GOTO op to jump to the end of the entire CASE block, and record its position
// in the statement block so we can update it later.
block.Body = append(block.Body, Goto{})
gotoEndOpsIndexes = append(gotoEndOpsIndexes, len(block.Body)-1)
}

if stmt.HasElse {
convertElseBodyStatements, err := jsonConvertStatements(stmt.Else)
if err != nil {
return Block{}, err
}
block.Body = append(block.Body, convertElseBodyStatements...)
// TODO: If no cases match and there is no ELSE block, then add a RAISE statement
// to return an error.
//} else {
// Sample PostgreSQL error response:
// ERROR: case not found
// HINT: CASE statement is missing ELSE part.
// CONTEXT: PL/pgSQL function interpreted_case(integer) line 5 at CASE
}

// Update all the GOTO ops that jump to the very end of the case block.
for _, gotoEndOpIndex := range gotoEndOpsIndexes {
// Sanity check that we are looking at a GOTO statement
if _, ok := block.Body[gotoEndOpIndex].(Goto); !ok {
return Block{}, fmt.Errorf("expected Goto statement, got %T", block.Body[gotoEndOpIndex])
}

block.Body[gotoEndOpIndex] = Goto{
Offset: int32(len(block.Body) - gotoEndOpIndex),
}
}

return block, nil
}

// Convert converts the JSON statement into its output form.
func (stmt *plpgSQL_stmt_execsql) Convert() (ExecuteSQL, error) {
var target string
Expand Down
2 changes: 2 additions & 0 deletions server/plpgsql/json_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func jsonConvertStatement(stmt statement) (Statement, error) {
switch {
case stmt.Assignment != nil:
return stmt.Assignment.Convert()
case stmt.Case != nil:
return stmt.Case.Convert()
case stmt.ExecSQL != nil:
return stmt.ExecSQL.Convert()
case stmt.Exit != nil:
Expand Down
157 changes: 157 additions & 0 deletions testing/go/create_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,163 @@ $$ LANGUAGE plpgsql;`},
},
},
},
{
Name: "CASE, with ELSE",
SetUpScript: []string{`
CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$
DECLARE
msg TEXT;
BEGIN
CASE x
WHEN 1, 2 THEN
msg := 'one';
msg := msg || ' or two';
ELSE
msg := 'other';
msg := msg || ' value than one or two';
END CASE;
RETURN msg;
END;
$$ LANGUAGE plpgsql;`},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT interpreted_case(1);",
Expected: []sql.Row{{"one or two"}},
},
{
Query: "SELECT interpreted_case(2);",
Expected: []sql.Row{{"one or two"}},
},
{
Query: "SELECT interpreted_case(0);",
Expected: []sql.Row{{"other value than one or two"}},
},
},
},
{
// TODO: When no CASE statements match, and there is no ELSE block,
// Postgres raises an exception. Unskip this test after we
// add support for raising exceptions from functions.
Skip: true,
Name: "CASE, without ELSE",
SetUpScript: []string{`
CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$
DECLARE
msg TEXT;
BEGIN
CASE x
WHEN 1, 2 THEN
msg := 'one or two';
END CASE;
RETURN msg;
END;
$$ LANGUAGE plpgsql;`},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT interpreted_case(1);",
Expected: []sql.Row{{"one or two"}},
},
{
Query: "SELECT interpreted_case(2);",
Expected: []sql.Row{{"one or two"}},
},
{
Query: "SELECT interpreted_case(0);",
ExpectedErr: "case not found",
},
},
},
{
Name: "Searched CASE, with ELSE",
SetUpScript: []string{`
CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$
DECLARE
msg TEXT;
BEGIN
CASE
WHEN x BETWEEN 0 AND 10 THEN
msg := 'value is between zero';
msg := msg || ' and ten';
WHEN x BETWEEN 11 AND 20 THEN
msg := 'value is between eleven and twenty';
ELSE
msg := 'value';
msg := msg || ' is';
msg := msg || ' out of';
msg := msg || ' bounds';
END CASE;
RETURN msg;
END;
$$ LANGUAGE plpgsql;`},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT interpreted_case(0);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(1);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(10);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(11);",
Expected: []sql.Row{{"value is between eleven and twenty"}},
},
{
Query: "SELECT interpreted_case(21);",
Expected: []sql.Row{{"value is out of bounds"}},
},
},
},
{
// TODO: When no CASE statements match, and there is no ELSE block,
// Postgres raises an exception. Unskip this test after we
// add support for raising exceptions from functions.
Skip: true,
Name: "Searched CASE, without ELSE",
SetUpScript: []string{`
CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$
DECLARE
msg TEXT;
BEGIN
CASE
WHEN x BETWEEN 0 AND 10 THEN
msg := 'value is between zero and ten';
WHEN x BETWEEN 11 AND 20 THEN
msg := 'value';
msg := msg || ' is between';
msg := msg || ' eleven and';
msg := msg || ' twenty';
END CASE;
RETURN msg;
END;
$$ LANGUAGE plpgsql;`},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT interpreted_case(0);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(1);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(10);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(11);",
Expected: []sql.Row{{"value is between eleven and twenty"}},
},
{
Query: "SELECT interpreted_case(21);",
ExpectedErr: "case not found",
},
},
},
{
Name: "CONTINUE",
SetUpScript: []string{`CREATE FUNCTION interpreted_continue() RETURNS INT4 AS $$
Expand Down
Loading