Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions errguard/errguard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2026 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package errguard

import (
"fmt"
"runtime/debug"

"golang.org/x/sync/errgroup"
)

// Go runs |fn| in the errgroup, converting any panic into an error, with
// a stack trace, that is later returned by errgroup.Group.Wait(). The intent
// of this function is to provide a standard function for spawning a goroutine
// in an errgroup that has consistent panic recovery handling.
func Go(g *errgroup.Group, fn func() error) {
g.Go(func() (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic recovered: %v\n%s", r, debug.Stack())
}
}()
return fn()
})
}
9 changes: 5 additions & 4 deletions server/golden/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"

"github.com/dolthub/go-mysql-server/errguard"
"github.com/dolthub/go-mysql-server/sql"
)

Expand Down Expand Up @@ -108,11 +109,11 @@ func (v Validator) ComMultiQuery(
ag := newResultAggregator(callback)
var remainder string
eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() (err error) {
errguard.Go(eg, func() (err error) {
remainder, err = v.handler.ComMultiQuery(ctx, c, query, ag.processResults)
return
})
eg.Go(func() error {
errguard.Go(eg, func() error {
// ignore errors from MySQL connection
_, _ = v.golden.ComMultiQuery(ctx, c, query, ag.processGoldenResults)
return nil
Expand All @@ -136,10 +137,10 @@ func (v Validator) ComQuery(
) error {
ag := newResultAggregator(callback)
eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
errguard.Go(eg, func() error {
return v.handler.ComQuery(ctx, c, query, ag.processResults)
})
eg.Go(func() error {
errguard.Go(eg, func() error {
// ignore errors from MySQL connection
_ = v.golden.ComQuery(ctx, c, query, ag.processGoldenResults)
return nil
Expand Down
48 changes: 11 additions & 37 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ package server
import (
"context"
"encoding/base64"
goerrors "errors"
"fmt"
"io"
"net"
"regexp"
"runtime/debug"
"runtime/trace"
"sync"
"time"
Expand All @@ -38,6 +36,7 @@ import (
"gopkg.in/src-d/go-errors.v1"

sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/errguard"
"github.com/dolthub/go-mysql-server/internal/sockstate"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
Expand Down Expand Up @@ -635,20 +634,11 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End()

eg, ctx := ctx.NewErrgroup()
pan2err := func(err *error) {
if recoveredPanic := recover(); recoveredPanic != nil {
stack := debug.Stack()
wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack)
*err = goerrors.Join(*err, wrappedErr)
}
}

// TODO: poll for closed connections should obviously also run even if
// we're doing something with an OK result or a single row result, etc.
// This should be in the caller.
pollCtx, cancelF := ctx.NewSubContext()
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() error {
return h.pollForClosedConnection(pollCtx, c)
})

Expand Down Expand Up @@ -681,8 +671,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s

// Read rows off the row iterator and send them to the row channel.
var rowChan = make(chan sql.Row, 512)
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() error {
defer wg.Done()
defer close(rowChan)
for {
Expand All @@ -709,8 +698,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
// Drain rows from rowChan, convert to wire format, and send to resChan
var resChan = make(chan *sqltypes.Result, 4)
var res *sqltypes.Result
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() error {
defer wg.Done()
defer close(resChan)

Expand Down Expand Up @@ -771,8 +759,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s

// Drain sqltypes.Result from resChan and call callback (send to client and potentially reset buffer)
var processedAtLeastOneBatch bool
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() (err error) {
defer cancelF()
defer wg.Done()
for {
Expand All @@ -794,8 +781,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s

// Close() kills this PID in the process list,
// wait until all rows have be sent over the wire
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() error {
wg.Wait()
return iter.Close(ctx)
})
Expand All @@ -815,19 +801,11 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema
defer trace.StartRegion(ctx, "Handler.resultForValueRowIter").End()

eg, ctx := ctx.NewErrgroup()
pan2err := func(err *error) {
if recoveredPanic := recover(); recoveredPanic != nil {
wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, debug.Stack())
*err = goerrors.Join(*err, wrappedErr)
}
}

// TODO: poll for closed connections should obviously also run even if
// we're doing something with an OK result or a single row result, etc.
// This should be in the caller.
pollCtx, cancelF := ctx.NewSubContext()
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() error {
return h.pollForClosedConnection(pollCtx, c)
})

Expand Down Expand Up @@ -858,8 +836,7 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema

// Drain rows from iter and send to rowsChan
var rowChan = make(chan sql.ValueRow, 512)
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() error {
defer wg.Done()
defer close(rowChan)
for {
Expand All @@ -886,8 +863,7 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema
// Drain rows from rowChan, convert to wire format, and send to resChan
var resChan = make(chan *sqltypes.Result, 4)
var res *sqltypes.Result
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() error {
defer close(resChan)
defer wg.Done()

Expand Down Expand Up @@ -940,8 +916,7 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema

// Drain sqltypes.Result from resChan and call callback (send to client and reset buffer)
var processedAtLeastOneBatch bool
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() (err error) {
defer cancelF()
defer wg.Done()
for {
Expand All @@ -963,8 +938,7 @@ func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema

// Close() kills this PID in the process list,
// wait until all rows have be sent over the wire
eg.Go(func() (err error) {
defer pan2err(&err)
errguard.Go(eg, func() error {
wg.Wait()
return iter.Close(ctx)
})
Expand Down
6 changes: 4 additions & 2 deletions server/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"syscall"

"golang.org/x/sync/errgroup"

"github.com/dolthub/go-mysql-server/errguard"
)

var UnixSocketInUseError = errors.New("bind address at given unix socket path is already in use")
Expand Down Expand Up @@ -84,7 +86,7 @@ func NewListener(protocol, address string, unixSocketPath string) (*Listener, er
shutdown: make(chan struct{}),
once: &sync.Once{},
}
l.eg.Go(func() error {
errguard.Go(l.eg, func() error {
for {
conn, err := l.netListener.Accept()
// connection can be closed already from the other goroutine
Expand All @@ -102,7 +104,7 @@ func NewListener(protocol, address string, unixSocketPath string) (*Listener, er
})

if l.unixListener != nil {
l.eg.Go(func() error {
errguard.Go(l.eg, func() error {
for {
conn, err := l.unixListener.Accept()
// connection can be closed already from the other goroutine
Expand Down
12 changes: 3 additions & 9 deletions sql/rowexec/agg.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ package rowexec

import (
"errors"
"fmt"
"io"

"github.com/dolthub/go-mysql-server/errguard"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
"github.com/dolthub/go-mysql-server/sql/hash"
Expand Down Expand Up @@ -171,7 +171,7 @@ func (i *groupByGroupingIter) compute(ctx *sql.Context) error {
eg, subCtx := ctx.NewErrgroup()

var rowChan = make(chan sql.Row, 512)
eg.Go(func() error {
errguard.Go(eg, func() error {
defer close(rowChan)
for {
row, err := i.child.Next(subCtx)
Expand All @@ -185,13 +185,7 @@ func (i *groupByGroupingIter) compute(ctx *sql.Context) error {
}
})

eg.Go(func() (err error) {
defer func() {
if recoveredPanic := recover(); recoveredPanic != nil {
err = fmt.Errorf("caught panic: %v", recoveredPanic)
}
}()

errguard.Go(eg, func() error {
for {
row, ok := <-rowChan
if !ok {
Expand Down
Loading