Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
"github.com/dolthub/doltgresql/core/typecollection"
)

// contextValues contains a set of objects that will be passed alongside the context.
// contextValues contains a set of cached data passed alongside the context. This data is considered temporary
// and may be refreshed at any point, including during the middle of a query. Callers should not assume that
// data stored in contextValues is persisted, and other types of data should not be added to contextValues.
type contextValues struct {
collection *sequences.Collection
types *typecollection.TypeCollection
Expand Down
2 changes: 2 additions & 0 deletions core/functions/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (pgf *Collection) Serialize(ctx context.Context) ([]byte, error) {
writer.StringSlice(op.SecondaryData)
writer.String(op.Target)
writer.Int32(int32(op.Index))
writer.StringMap(op.Options)
}
}

Expand Down Expand Up @@ -100,6 +101,7 @@ func Deserialize(ctx context.Context, data []byte) (*Collection, error) {
op.SecondaryData = reader.StringSlice()
op.Target = reader.String()
op.Index = int(reader.Int32())
op.Options = reader.StringMap()
f.Operations[opIdx] = op
}
// Add the function to each map
Expand Down
21 changes: 17 additions & 4 deletions server/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.Start
if err != nil {
return err
}
err = h.doltgresHandler.ComQuery(context.Background(), h.mysqlConn, useStmt, parsed, func(res *Result) error {
err = h.doltgresHandler.ComQuery(context.Background(), h.mysqlConn, useStmt, parsed, func(_ *sql.Context, _ *Result) error {
return nil
})
// If a database isn't specified, then we attempt to connect to a database with the same name as the user,
Expand Down Expand Up @@ -774,7 +774,7 @@ func (h *ConnectionHandler) handleCopyDataHelper(copyState *copyFromStdinState,
return false, false, err
}

callback := func(res *Result) error { return nil }
callback := func(_ *sql.Context, _ *Result) error { return nil }
err = h.doltgresHandler.ComExecuteBound(sqlCtx, h.mysqlConn, "COPY FROM", copyState.insertNode, callback)
if err != nil {
return false, false, err
Expand Down Expand Up @@ -929,10 +929,23 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error {

// spoolRowsCallback returns a callback function that will send RowDescription message,
// then a DataRow message for each row in the result set.
func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute bool) func(res *Result) error {
func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute bool) func(ctx *sql.Context, res *Result) error {
// IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query.
isIUD := tag == "INSERT" || tag == "UPDATE" || tag == "DELETE"
return func(res *Result) error {
return func(ctx *sql.Context, res *Result) error {
sess := dsess.DSessFromSess(ctx.Session)
for _, notice := range sess.Notices() {
backendMsg, ok := notice.(pgproto3.BackendMessage)
if !ok {
return fmt.Errorf("unexpected notice message type: %T", notice)
}

if err := h.send(backendMsg); err != nil {
return err
}
}
sess.ClearNotices()

if returnsRow(tag) {
// EXECUTE does not send RowDescription; instead it should be sent from DESCRIBE prior to it
if !isExecute {
Expand Down
12 changes: 6 additions & 6 deletions server/doltgres_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (h *DoltgresHandler) ComBind(ctx context.Context, c *mysql.Conn, query stri
}

// ComExecuteBound implements the Handler interface.
func (h *DoltgresHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*Result) error) error {
func (h *DoltgresHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*sql.Context, *Result) error) error {
analyzedPlan, ok := boundQuery.(sql.Node)
if !ok {
return errors.Errorf("boundQuery must be a sql.Node, but got %T", boundQuery)
Expand Down Expand Up @@ -181,7 +181,7 @@ func (h *DoltgresHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, q
}

// ComQuery implements the Handler interface.
func (h *DoltgresHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(*Result) error) error {
func (h *DoltgresHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(*sql.Context, *Result) error) error {
// TODO: This technically isn't query start and underestimates query execution time
start := time.Now()
if h.sel != nil {
Expand Down Expand Up @@ -281,7 +281,7 @@ func (h *DoltgresHandler) convertBindParameters(ctx *sql.Context, types []uint32

var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`)

func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, analyzedPlan sql.Node, queryExec QueryExecutor, callback func(*Result) error) error {
func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, analyzedPlan sql.Node, queryExec QueryExecutor, callback func(*sql.Context, *Result) error) error {
sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
if err != nil {
return err
Expand Down Expand Up @@ -349,7 +349,7 @@ func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query stri
return nil
}

return callback(r)
return callback(sqlCtx, r)
}

// QueryExecutor is a function that executes a query and returns the result as a schema and iterator. Either of
Expand Down Expand Up @@ -500,7 +500,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,

// resultForDefaultIter reads batches of rows from the iterator
// and writes results into the callback function.
func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, callback func(*Result) error, resultFields []pgproto3.FieldDescription) (*Result, bool, error) {
func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, callback func(*sql.Context, *Result) error, resultFields []pgproto3.FieldDescription) (*Result, bool, error) {
defer trace.StartRegion(ctx, "DoltgresHandler.resultForDefaultIter").End()

var r *Result
Expand Down Expand Up @@ -567,7 +567,7 @@ func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Sche
r = &Result{Fields: resultFields}
}
if r.RowsAffected == rowsBatch {
if err := callback(r); err != nil {
if err := callback(ctx, r); err != nil {
return err
}
r = nil
Expand Down
4 changes: 2 additions & 2 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ type Handler interface {
// ComBind is called when a connection receives a request to bind a prepared statement to a set of values.
ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, bindVars BindVariables) (mysql.BoundQuery, []pgproto3.FieldDescription, error)
// ComExecuteBound is called when a connection receives a request to execute a prepared statement that has already bound to a set of values.
ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*Result) error) error
ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*sql.Context, *Result) error) error
// ComPrepareParsed is called when a connection receives a prepared statement query that has already been parsed.
ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement) (mysql.ParsedQuery, []pgproto3.FieldDescription, error)
// ComQuery is called when a connection receives a query. Note the contents of the query slice may change
// after the first call to callback. So the DoltgresHandler should not hang on to the byte slice.
ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(*Result) error) error
ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(*sql.Context, *Result) error) error
// ComResetConnection resets the connection's session, clearing out any cached prepared statements, locks, user and
// session variables. The currently selected database is preserved.
ComResetConnection(c *mysql.Conn) error
Expand Down
88 changes: 86 additions & 2 deletions server/plpgsql/interpreter_logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ package plpgsql

import (
"fmt"
"strconv"
"strings"

"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/jackc/pgx/v5/pgproto3"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/typecollection"
Expand Down Expand Up @@ -87,8 +90,6 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
if err != nil {
return nil, err
}
case OpCode_Case:
// TODO: implement
case OpCode_Declare:
typeCollection, err := GetTypesCollectionFromContext(ctx)
if err != nil {
Expand Down Expand Up @@ -188,6 +189,28 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
if _, err = sql.RowIterToRows(ctx, rowIter); err != nil {
return nil, err
}
case OpCode_Raise:
// TODO: Use the client_min_messages config param to determine which
// notice levels to send to the client.
// https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-CLIENT-MIN-MESSAGES

// TODO: Notices at the EXCEPTION level should also abort the current tx.

message, err := evaluteNoticeMessage(ctx, iFunc, operation, stack)
if err != nil {
return nil, err
}

noticeResponse := &pgproto3.NoticeResponse{
Severity: operation.PrimaryData,
Message: message,
}

if err = applyNoticeOptions(ctx, noticeResponse, operation.Options); err != nil {
return nil, err
}
sess := dsess.DSessFromSess(ctx.Session)
sess.Notice(noticeResponse)
case OpCode_Return:
if len(operation.PrimaryData) == 0 {
return nil, nil
Expand All @@ -207,3 +230,64 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
}
return nil, nil
}

// applyNoticeOptions adds the specified |options| to the |noticeResponse|.
func applyNoticeOptions(ctx *sql.Context, noticeResponse *pgproto3.NoticeResponse, options map[string]string) error {
for key, value := range options {
i, err := strconv.Atoi(key)
if err != nil {
return err
}

switch NoticeOptionType(i) {
case NoticeOptionTypeErrCode:
noticeResponse.Code = value
case NoticeOptionTypeMessage:
noticeResponse.Message = value
case NoticeOptionTypeDetail:
noticeResponse.Detail = value
case NoticeOptionTypeHint:
noticeResponse.Hint = value
case NoticeOptionTypeConstraint:
noticeResponse.ConstraintName = value
case NoticeOptionTypeDataType:
noticeResponse.DataTypeName = value
case NoticeOptionTypeTable:
noticeResponse.TableName = value
case NoticeOptionTypeSchema:
noticeResponse.SchemaName = value
default:
ctx.GetLogger().Warnf("unhandled notice option type: %s", key)
}
}
return nil
}

// evaluteNoticeMessage evaluates the message for a RAISE NOTICE statement, including
// evaluating any specified parameters and plugging them into the message in place of
// the % placeholders.
func evaluteNoticeMessage(ctx *sql.Context, iFunc InterpretedFunction,
operation InterpreterOperation, stack InterpreterStack) (string, error) {
message := operation.SecondaryData[0]
if len(operation.SecondaryData) > 1 {
params := operation.SecondaryData[1:]
currentParam := 0

parts := strings.Split(message, "%%")
for i, part := range parts {
for strings.Contains(part, "%") {
retVal, err := iFunc.QuerySingleReturn(ctx, stack, "SELECT "+params[currentParam], nil, nil)
if err != nil {
return "", err
}
currentParam += 1

s := fmt.Sprintf("%v", retVal)
part = strings.Replace(part, "%", s, 1)
}
parts[i] = part
}
message = strings.Join(parts, "%")
}
return message, nil
}
10 changes: 6 additions & 4 deletions server/plpgsql/interpreter_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS
OpCode_InsertInto // https://www.postgresql.org/docs/15/plpgsql-statements.html
OpCode_Perform // https://www.postgresql.org/docs/15/plpgsql-statements.html
OpCode_Raise // https://www.postgresql.org/docs/15/plpgsql-errors-and-messages.html
OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING
OpCode_ScopeBegin // This is used for scope control, specific to Doltgres
OpCode_ScopeEnd // This is used for scope control, specific to Doltgres
Expand All @@ -41,8 +42,9 @@ const (
// InterpreterOperation is an operation that will be performed by the interpreter.
type InterpreterOperation struct {
OpCode OpCode
PrimaryData string // This will represent the "main" data, such as the query for PERFORM, expression for IF, etc.
SecondaryData []string // This represents auxiliary data, such as bindings, strictness, etc.
Target string // This is the variable that will store the results (if applicable)
Index int // This is the index that should be set for operations that move the function counter
PrimaryData string // This will represent the "main" data, such as the query for PERFORM, expression for IF, etc.
SecondaryData []string // This represents auxiliary data, such as bindings, strictness, etc.
Target string // This is the variable that will store the results (if applicable)
Index int // This is the index that should be set for operations that move the function counter
Options map[string]string // This is extra data for operations that need it
}
42 changes: 42 additions & 0 deletions server/plpgsql/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package plpgsql

import (
"fmt"
"strconv"
"strings"

"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -160,6 +161,26 @@ type plpgSQL_stmt_perform struct {
LineNumber int32 `json:"lineno"`
}

// plpgSQL_stmt_raise exists to match the expected JSON format.
type plpgSQL_stmt_raise struct {
LineNumber int32 `json:"lineno"`
ELogLevel int32 `json:"elog_level"`
Message string `json:"message"`
Params []sqlstmt `json:"params"`
Options []plpgSQL_raise_option_wrapper `json:"options"`
}

// plpgSQL_raise_option_wrapper exists to match the expected JSON format.
type plpgSQL_raise_option_wrapper struct {
Option plpgSQL_raise_option `json:"PLpgSQL_raise_option"`
}

// plpgSQL_raise_option exists to match the expected JSON format.
type plpgSQL_raise_option struct {
OptionType int32 `json:"opt_type"`
Expression sqlstmt `json:"expr"`
}

// plpgSQL_stmt_return exists to match the expected JSON format.
type plpgSQL_stmt_return struct {
Expression expr `json:"expr"`
Expand Down Expand Up @@ -201,6 +222,7 @@ type statement struct {
If *plpgSQL_stmt_if `json:"PLpgSQL_stmt_if"`
Loop *plpgSQL_stmt_loop `json:"PLpgSQL_stmt_loop"`
Perform *plpgSQL_stmt_perform `json:"PLpgSQL_stmt_perform"`
Raise *plpgSQL_stmt_raise `json:"PLpgSQL_stmt_raise"`
Return *plpgSQL_stmt_return `json:"PLpgSQL_stmt_return"`
When *plpgSQL_case_when `json:"PLpgSQL_case_when"`
While *plpgSQL_stmt_while `json:"PLpgSQL_stmt_while"`
Expand Down Expand Up @@ -447,6 +469,26 @@ func (stmt *plpgSQL_stmt_perform) Convert() Perform {
}
}

// Convert converts the JSON statement into its output form.
func (stmt *plpgSQL_stmt_raise) Convert() Raise {
var params []string
for _, param := range stmt.Params {
params = append(params, param.Expr.Query)
}

options := make(map[string]string)
for _, option := range stmt.Options {
options[strconv.Itoa(int(option.Option.OptionType))] = option.Option.Expression.Expr.Query
}

return Raise{
Level: NoticeLevel(uint8(stmt.ELogLevel)).String(),
Message: stmt.Message,
Params: params,
Options: options,
}
}

// Convert converts the JSON statement into its output form.
func (stmt *plpgSQL_stmt_return) Convert() Return {
return Return{
Expand Down
2 changes: 2 additions & 0 deletions server/plpgsql/json_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func jsonConvertStatement(stmt statement) (Statement, error) {
return stmt.Loop.Convert()
case stmt.Perform != nil:
return stmt.Perform.Convert(), nil
case stmt.Raise != nil:
return stmt.Raise.Convert(), nil
case stmt.Return != nil:
return stmt.Return.Convert(), nil
case stmt.While != nil:
Expand Down
Loading