Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -16230,10 +16230,6 @@ var CreateDatabaseScripts = []ScriptTest{
{
Name: "CREATE DATABASE error handling",
Assertions: []ScriptTestAssertion{
{
Query: "create database `abc/def`",
ExpectedErr: sql.ErrInvalidDatabaseName,
},
{
Query: "CREATE DATABASE newtestdb CHARACTER SET utf8mb4 ENCRYPTION='N'",
Expected: []sql.Row{{types.NewOkResult(1)}},
Expand Down
42 changes: 39 additions & 3 deletions sql/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,9 @@ var (

// ErrDistinctOnMatchOrderBy is returned when DISTINCT ON does not match the initial ORDER BY expressions
ErrDistinctOnMatchOrderBy = errors.NewKind("SELECT DISTINCT ON expressions must match initial ORDER BY expressions")

// ErrWrongDBName is returned for illegal database names with the [mysql.ERWrongDbName] error code and [mysql.SSClientError] SQLSTATE.
ErrWrongDBName = newMySQLKind("Incorrect database name '%s'", mysql.ERWrongDbName, mysql.SSClientError)
)

// CastSQLError returns a *mysql.SQLError with the error code and in some cases, also a SQL state, populated for the
Expand All @@ -983,10 +986,8 @@ func CastSQLError(err error) *mysql.SQLError {
if mysqlErr, ok := err.(*mysql.SQLError); ok {
return mysqlErr
}

var code int
var sqlState string = ""

var sqlState = ""
if w, ok := err.(WrappedInsertError); ok {
return CastSQLError(w.Cause)
}
Expand All @@ -995,6 +996,12 @@ func CastSQLError(err error) *mysql.SQLError {
return CastSQLError(wm.Err)
}

for _, mySQLErr := range mySQLErrors {
if mySQLErr.Kind.Is(err) {
return mysql.NewSQLError(mySQLErr.Code, mySQLErr.SQLState, "%s", err.Error())
}
}

switch {
case ErrTableNotFound.Is(err):
code = mysql.ERNoSuchTable
Expand Down Expand Up @@ -1064,6 +1071,35 @@ func CastSQLError(err error) *mysql.SQLError {
return mysql.NewSQLError(code, sqlState, "%s", err.Error())
}

// mySQLErrors contain MySQL-specific [sql.SQLError] with their other metadata.
var mySQLErrors []SQLError

// newMySQLKind creates [sql.SQLError] specifically for mySQLErrors that is automatically interpreted by
// [sql.CastSQLError]. If |SQLState| is omitted, an empty string takes its place.
func newMySQLKind(msg string, code int, sqlState ...string) *errors.Kind {
err := errors.NewKind(msg)
state := ""
if len(sqlState) > 0 {
state = sqlState[0]
}
mySQLErrors = append(mySQLErrors, SQLError{
Kind: err,
Code: code,
SQLState: state,
})
return err
}

// SQLError identifies the error family and other metadata for SQL errors.
type SQLError struct {
// Kind identifies the engine error family.
Kind *errors.Kind
// Code is the numeric error code, and is implementation specific (e.g., MySQL error codes are not cross-platform).
Code int
// SQLState is the five-character string taken from ANSI SQL and ODBC.
SQLState string
}

// UnwrapError removes any wrapping errors (e.g. WrappedInsertError) around the specified error and
// returns the first non-wrapped error type.
func UnwrapError(err error) error {
Expand Down
4 changes: 0 additions & 4 deletions sql/planbuilder/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -1820,10 +1820,6 @@ func (b *Builder) buildDBDDL(inScope *scope, c *ast.DBDDL) (outScope *scope) {
outScope = inScope.push()
switch strings.ToLower(c.Action) {
case ast.CreateStr:
if strings.ContainsRune(c.DBName, '/') {
b.handleErr(sql.ErrInvalidDatabaseName.New(c.DBName))
}

var charsetStr, collationStr string
if len(c.CharsetCollate) != 0 && b.ctx != nil && b.ctx.Session != nil {
b.ctx.Session.Warn(&sql.Warning{
Expand Down