diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index c49057fe22..afd1614fd6 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -8876,6 +8876,20 @@ from typestable`, {116, 97}, }, }, + { + Query: "select char(i, i + 10, pi()) from mytable;", + Expected: []sql.Row{ + {[]byte{0x01, 0x0B, 0x03}}, + {[]byte{0x02, 0x0C, 0x03}}, + {[]byte{0x03, 0x0D, 0x03}}, + }, + }, + { + Query: "select char(97, 98, 99 using utf8mb4);", + Expected: []sql.Row{ + {"abc"}, + }, + }, } var KeylessQueries = []QueryTest{ diff --git a/sql/expression/function/char.go b/sql/expression/function/char.go new file mode 100644 index 0000000000..e034db7564 --- /dev/null +++ b/sql/expression/function/char.go @@ -0,0 +1,140 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "fmt" + "strings" + + "github.com/dolthub/vitess/go/sqltypes" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// Char implements the sql function "char" which returns the character for each integer passed +type Char struct { + args []sql.Expression + Collation sql.CollationID +} + +var _ sql.FunctionExpression = (*Char)(nil) +var _ sql.CollationCoercible = (*Char)(nil) + +func NewChar(args ...sql.Expression) (sql.Expression, error) { + return &Char{args: args}, nil +} + +// FunctionName implements sql.FunctionExpression +func (c *Char) FunctionName() string { + return "char" +} + +// Resolved implements sql.FunctionExpression +func (c *Char) Resolved() bool { + for _, arg := range c.args { + if !arg.Resolved() { + return false + } + } + return true +} + +// String implements sql.Expression +func (c *Char) String() string { + args := make([]string, len(c.args)) + for i, arg := range c.args { + args[i] = arg.String() + } + str := strings.Join(args, ", ") + return fmt.Sprintf("%s(%s)", c.FunctionName(), str) +} + +// Type implements sql.Expression +func (c *Char) Type() sql.Type { + if c.Collation == sql.Collation_binary || c.Collation == sql.Collation_Unspecified { + return types.MustCreateString(sqltypes.VarBinary, int64(len(c.args)*4), sql.Collation_binary) + } + return types.MustCreateString(sqltypes.VarChar, int64(len(c.args)*16), c.Collation) +} + +// IsNullable implements sql.Expression +func (c *Char) IsNullable() bool { + return true +} + +// Description implements sql.FunctionExpression +func (c *Char) Description() string { + return "interprets each argument N as an integer and returns a string consisting of the characters given by the code values of those integers." +} + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (c *Char) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 5 +} + +// char converts num into a byte array +// This function is essentially converting the number to base 256 +func char(num uint32) []byte { + if num == 0 { + return []byte{} + } + return append(char(num>>8), byte(num&255)) +} + +// Eval implements the sql.Expression interface +func (c *Char) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + res := []byte{} + for _, arg := range c.args { + if arg == nil { + continue + } + + val, err := arg.Eval(ctx, row) + if err != nil { + return nil, err + } + + if val == nil { + continue + } + + v, _, err := types.Uint32.Convert(val) + if err != nil { + ctx.Warn(1292, "Truncated incorrect INTEGER value: '%v'", val) + res = append(res, 0) + continue + } + + res = append(res, char(v.(uint32))...) + } + + result, _, err := c.Type().Convert(res) + if err != nil { + return nil, err + } + + return result, nil +} + +// Children implements sql.Expression +func (c *Char) Children() []sql.Expression { + return c.args +} + +// WithChildren implements the sql.Expression interface +func (c *Char) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewChar(children...) +} diff --git a/sql/expression/function/char_test.go b/sql/expression/function/char_test.go new file mode 100644 index 0000000000..100aa5e535 --- /dev/null +++ b/sql/expression/function/char_test.go @@ -0,0 +1,176 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +func TestChar(t *testing.T) { + tests := []struct { + name string + args []sql.Expression + exp interface{} + err bool + skip bool + }{ + { + name: "null", + args: []sql.Expression{ + nil, + }, + exp: []byte{}, + }, + { + name: "null literal", + args: []sql.Expression{ + expression.NewLiteral(nil, types.Null), + }, + exp: []byte{}, + }, + { + name: "nulls are skipped", + args: []sql.Expression{ + expression.NewLiteral(int32(1), types.Int32), + expression.NewLiteral(nil, types.Null), + expression.NewLiteral(int32(300), types.Int32), + expression.NewLiteral(int32(4000), types.Int32), + }, + exp: []byte{0x1, 0x01, 0x2c, 0xf, 0xa0}, + }, + { + name: "-1", + args: []sql.Expression{ + expression.NewLiteral(int32(-1), types.Int32), + }, + exp: []byte{0xff, 0xff, 0xff, 0xff}, + }, + { + name: "256", + args: []sql.Expression{ + expression.NewLiteral(int32(256), types.Int32), + }, + exp: []byte{0x1, 0x0}, + }, + { + name: "512", + args: []sql.Expression{ + expression.NewLiteral(int32(512), types.Int32), + }, + exp: []byte{0x2, 0x0}, + }, + { + name: "256 * 256", + args: []sql.Expression{ + expression.NewLiteral(int32(256*256), types.Int32), + }, + exp: []byte{0x1, 0x0, 0x0}, + }, + { + name: "1 2 3 4", + args: []sql.Expression{ + expression.NewLiteral(int32(1), types.Int32), + expression.NewLiteral(int32(2), types.Int32), + expression.NewLiteral(int32(3), types.Int32), + expression.NewLiteral(int32(4), types.Int32), + }, + exp: []byte{0x1, 0x2, 0x3, 0x4}, + }, + { + name: "1 20 300 4000", + args: []sql.Expression{ + expression.NewLiteral(int32(1), types.Int32), + expression.NewLiteral(int32(20), types.Int32), + expression.NewLiteral(int32(300), types.Int32), + expression.NewLiteral(int32(4000), types.Int32), + }, + exp: []byte{0x1, 0x14, 0x1, 0x2c, 0xf, 0xa0}, + }, + { + name: "float32 1.99", + args: []sql.Expression{ + expression.NewLiteral(float32(1.99), types.Float32), + }, + exp: []byte{0x2}, + }, + { + name: "float64 1.99", + args: []sql.Expression{ + expression.NewLiteral(1.99, types.Float64), + }, + exp: []byte{0x2}, + }, + { + name: "decimal 1.99", + args: []sql.Expression{ + expression.NewLiteral(decimal.NewFromFloat(1.99), types.DecimalType_{}), + }, + exp: []byte{0x2}, + }, + { + name: "good string", + args: []sql.Expression{ + expression.NewLiteral("12", types.Text), + }, + exp: []byte{0x0C}, + }, + { + name: "bad string", + args: []sql.Expression{ + expression.NewLiteral("abc", types.Text), + }, + exp: []byte{0x0}, + }, + { + name: "mix types", + args: []sql.Expression{ + expression.NewLiteral(1, types.Int32), + expression.NewLiteral(9999, types.Int32), + expression.NewLiteral(1.23, types.Int32), + expression.NewLiteral("78", types.Text), + expression.NewLiteral("abc", types.Text), + }, + exp: []byte{0x01, 0x27, 0x0F, 0x01, 0x4E, 0x0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skip { + t.Skip() + } + + ctx := sql.NewEmptyContext() + f, err := NewChar(tt.args...) + require.NoError(t, err) + + res, err := f.Eval(ctx, nil) + if tt.err { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.exp, res) + }) + } +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index c406763209..0f98a7be6b 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -50,6 +50,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "bit_xor", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewBitXor(e) }}, sql.Function1{Name: "ceil", Fn: NewCeil}, sql.Function1{Name: "ceiling", Fn: NewCeil}, + sql.FunctionN{Name: "char", Fn: NewChar}, sql.Function1{Name: "char_length", Fn: NewCharLength}, sql.Function1{Name: "character_length", Fn: NewCharLength}, sql.FunctionN{Name: "coalesce", Fn: NewCoalesce}, diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 70102496bc..185c7c3564 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -175,6 +175,25 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) sql.Expression { b.handleErr(err) } return expression.NewConvertUsing(expr, charset) + case *ast.CharExpr: + args := make([]sql.Expression, len(v.Exprs)) + for i, e := range v.Exprs { + args[i] = b.selectExprToExpression(inScope, e) + } + + f, err := function.NewChar(args...) + if err != nil { + b.handleErr(err) + } + + collId, err := sql.ParseCollation(&v.Type, nil, true) + if err != nil { + b.handleErr(err) + } + + charFunc := f.(*function.Char) + charFunc.Collation = collId + return charFunc case *ast.ConvertExpr: var err error typeLength := 0 diff --git a/sql/types/number.go b/sql/types/number.go index 607a4628d4..e7870fa07c 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1031,14 +1031,14 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan return uint32(v), sql.InRange, nil case int32: if v < 0 { - return uint32(math.MaxUint32 + uint(v)), sql.OutOfRange, nil + return uint32(math.MaxUint32 - uint(-v-1)), sql.OutOfRange, nil } else if int(v) > math.MaxUint32 { return uint32(math.MaxUint32), sql.OutOfRange, nil } return uint32(v), sql.InRange, nil case int64: if v < 0 { - return uint32(math.MaxUint32 + uint(v)), sql.OutOfRange, nil + return uint32(math.MaxUint32 - uint(-v-1)), sql.OutOfRange, nil } else if v > math.MaxUint32 { return uint32(math.MaxUint32), sql.OutOfRange, nil }