Skip to content
14 changes: 14 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8876,6 +8876,20 @@ from typestable`,
{116, 97},
},
},
{
Query: "select char(i, i + 10, pi()) from mytable;",
Expected: []sql.Row{
{[]byte{0x01, 0x0B, 0x03}},
{[]byte{0x02, 0x0C, 0x03}},
{[]byte{0x03, 0x0D, 0x03}},
},
},
{
Query: "select char(97, 98, 99 using utf8mb4);",
Expected: []sql.Row{
{"abc"},
},
},
}

var KeylessQueries = []QueryTest{
Expand Down
140 changes: 140 additions & 0 deletions sql/expression/function/char.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
"fmt"
"strings"

"github.com/dolthub/vitess/go/sqltypes"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

// Char implements the sql function "char" which returns the character for each integer passed
type Char struct {
args []sql.Expression
Collation sql.CollationID
}

var _ sql.FunctionExpression = (*Char)(nil)
var _ sql.CollationCoercible = (*Char)(nil)

func NewChar(args ...sql.Expression) (sql.Expression, error) {
return &Char{args: args}, nil
}

// FunctionName implements sql.FunctionExpression
func (c *Char) FunctionName() string {
return "char"
}

// Resolved implements sql.FunctionExpression
func (c *Char) Resolved() bool {
for _, arg := range c.args {
if !arg.Resolved() {
return false
}
}
return true
}

// String implements sql.Expression
func (c *Char) String() string {
args := make([]string, len(c.args))
for i, arg := range c.args {
args[i] = arg.String()
}
str := strings.Join(args, ", ")
return fmt.Sprintf("%s(%s)", c.FunctionName(), str)
}

// Type implements sql.Expression
func (c *Char) Type() sql.Type {
if c.Collation == sql.Collation_binary || c.Collation == sql.Collation_Unspecified {
return types.MustCreateString(sqltypes.VarBinary, int64(len(c.args)*4), sql.Collation_binary)
}
return types.MustCreateString(sqltypes.VarChar, int64(len(c.args)*16), c.Collation)
}

// IsNullable implements sql.Expression
func (c *Char) IsNullable() bool {
return true
}

// Description implements sql.FunctionExpression
func (c *Char) Description() string {
return "interprets each argument N as an integer and returns a string consisting of the characters given by the code values of those integers."
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (c *Char) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
}

// char converts num into a byte array
// This function is essentially converting the number to base 256
func char(num uint32) []byte {
if num == 0 {
return []byte{}
}
return append(char(num>>8), byte(num&255))
}

// Eval implements the sql.Expression interface
func (c *Char) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
res := []byte{}
for _, arg := range c.args {
if arg == nil {
continue
}

val, err := arg.Eval(ctx, row)
if err != nil {
return nil, err
}

if val == nil {
continue
}

v, _, err := types.Uint32.Convert(val)
if err != nil {
ctx.Warn(1292, "Truncated incorrect INTEGER value: '%v'", val)
res = append(res, 0)
continue
}

res = append(res, char(v.(uint32))...)
}

result, _, err := c.Type().Convert(res)
if err != nil {
return nil, err
}

return result, nil
}

// Children implements sql.Expression
func (c *Char) Children() []sql.Expression {
return c.args
}

// WithChildren implements the sql.Expression interface
func (c *Char) WithChildren(children ...sql.Expression) (sql.Expression, error) {
return NewChar(children...)
}
176 changes: 176 additions & 0 deletions sql/expression/function/char_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
"testing"

"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
)

func TestChar(t *testing.T) {
tests := []struct {
name string
args []sql.Expression
exp interface{}
err bool
skip bool
}{
{
name: "null",
args: []sql.Expression{
nil,
},
exp: []byte{},
},
{
name: "null literal",
args: []sql.Expression{
expression.NewLiteral(nil, types.Null),
},
exp: []byte{},
},
{
name: "nulls are skipped",
args: []sql.Expression{
expression.NewLiteral(int32(1), types.Int32),
expression.NewLiteral(nil, types.Null),
expression.NewLiteral(int32(300), types.Int32),
expression.NewLiteral(int32(4000), types.Int32),
},
exp: []byte{0x1, 0x01, 0x2c, 0xf, 0xa0},
},
{
name: "-1",
args: []sql.Expression{
expression.NewLiteral(int32(-1), types.Int32),
},
exp: []byte{0xff, 0xff, 0xff, 0xff},
},
{
name: "256",
args: []sql.Expression{
expression.NewLiteral(int32(256), types.Int32),
},
exp: []byte{0x1, 0x0},
},
{
name: "512",
args: []sql.Expression{
expression.NewLiteral(int32(512), types.Int32),
},
exp: []byte{0x2, 0x0},
},
{
name: "256 * 256",
args: []sql.Expression{
expression.NewLiteral(int32(256*256), types.Int32),
},
exp: []byte{0x1, 0x0, 0x0},
},
{
name: "1 2 3 4",
args: []sql.Expression{
expression.NewLiteral(int32(1), types.Int32),
expression.NewLiteral(int32(2), types.Int32),
expression.NewLiteral(int32(3), types.Int32),
expression.NewLiteral(int32(4), types.Int32),
},
exp: []byte{0x1, 0x2, 0x3, 0x4},
},
{
name: "1 20 300 4000",
args: []sql.Expression{
expression.NewLiteral(int32(1), types.Int32),
expression.NewLiteral(int32(20), types.Int32),
expression.NewLiteral(int32(300), types.Int32),
expression.NewLiteral(int32(4000), types.Int32),
},
exp: []byte{0x1, 0x14, 0x1, 0x2c, 0xf, 0xa0},
},
{
name: "float32 1.99",
args: []sql.Expression{
expression.NewLiteral(float32(1.99), types.Float32),
},
exp: []byte{0x2},
},
{
name: "float64 1.99",
args: []sql.Expression{
expression.NewLiteral(1.99, types.Float64),
},
exp: []byte{0x2},
},
{
name: "decimal 1.99",
args: []sql.Expression{
expression.NewLiteral(decimal.NewFromFloat(1.99), types.DecimalType_{}),
},
exp: []byte{0x2},
},
{
name: "good string",
args: []sql.Expression{
expression.NewLiteral("12", types.Text),
},
exp: []byte{0x0C},
},
{
name: "bad string",
args: []sql.Expression{
expression.NewLiteral("abc", types.Text),
},
exp: []byte{0x0},
},
{
name: "mix types",
args: []sql.Expression{
expression.NewLiteral(1, types.Int32),
expression.NewLiteral(9999, types.Int32),
expression.NewLiteral(1.23, types.Int32),
expression.NewLiteral("78", types.Text),
expression.NewLiteral("abc", types.Text),
},
exp: []byte{0x01, 0x27, 0x0F, 0x01, 0x4E, 0x0},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skip {
t.Skip()
}

ctx := sql.NewEmptyContext()
f, err := NewChar(tt.args...)
require.NoError(t, err)

res, err := f.Eval(ctx, nil)
if tt.err {
require.Error(t, err)
return
}

require.NoError(t, err)
require.Equal(t, tt.exp, res)
})
}
}
1 change: 1 addition & 0 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var BuiltIns = []sql.Function{
sql.Function1{Name: "bit_xor", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewBitXor(e) }},
sql.Function1{Name: "ceil", Fn: NewCeil},
sql.Function1{Name: "ceiling", Fn: NewCeil},
sql.FunctionN{Name: "char", Fn: NewChar},
sql.Function1{Name: "char_length", Fn: NewCharLength},
sql.Function1{Name: "character_length", Fn: NewCharLength},
sql.FunctionN{Name: "coalesce", Fn: NewCoalesce},
Expand Down
19 changes: 19 additions & 0 deletions sql/planbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,25 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) sql.Expression {
b.handleErr(err)
}
return expression.NewConvertUsing(expr, charset)
case *ast.CharExpr:
args := make([]sql.Expression, len(v.Exprs))
for i, e := range v.Exprs {
args[i] = b.selectExprToExpression(inScope, e)
}

f, err := function.NewChar(args...)
if err != nil {
b.handleErr(err)
}

collId, err := sql.ParseCollation(&v.Type, nil, true)
if err != nil {
b.handleErr(err)
}

charFunc := f.(*function.Char)
charFunc.Collation = collId
return charFunc
case *ast.ConvertExpr:
var err error
typeLength := 0
Expand Down
4 changes: 2 additions & 2 deletions sql/types/number.go
Original file line number Diff line number Diff line change
Expand Up @@ -1031,14 +1031,14 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan
return uint32(v), sql.InRange, nil
case int32:
if v < 0 {
return uint32(math.MaxUint32 + uint(v)), sql.OutOfRange, nil
return uint32(math.MaxUint32 - uint(-v-1)), sql.OutOfRange, nil
} else if int(v) > math.MaxUint32 {
return uint32(math.MaxUint32), sql.OutOfRange, nil
}
return uint32(v), sql.InRange, nil
case int64:
if v < 0 {
return uint32(math.MaxUint32 + uint(v)), sql.OutOfRange, nil
return uint32(math.MaxUint32 - uint(-v-1)), sql.OutOfRange, nil
} else if v > math.MaxUint32 {
return uint32(math.MaxUint32), sql.OutOfRange, nil
}
Expand Down