Skip to content

Commit 1105d65

Browse files
ti-chi-bot, Defined2014
authored and committed
expression: let cast function supports explicit set charset (pingcap#55724) (pingcap#56088)
close pingcap#55677
1 parent cc04dd7 commit 1105d65

File tree

11 files changed

+139
-18
lines changed

11 files changed

+139
-18
lines changed

expression/bench_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1452,7 +1452,7 @@ func genVecBuiltinFuncBenchCase(ctx sessionctx.Context, funcName string, testCas
14521452
case types.ETJson:
14531453
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
14541454
case types.ETString:
1455-
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
1455+
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, false}
14561456
}
14571457
baseFunc, err = fc.getFunction(ctx, cols)
14581458
} else if funcName == ast.GetVar {

expression/builtin.go

+30
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,36 @@ func newBaseBuiltinCastFunc(builtinFunc baseBuiltinFunc, inUnion bool) baseBuilt
418418
}
419419
}
420420

421+
func newBaseBuiltinCastFunc4String(ctx sessionctx.Context, funcName string, args []Expression, tp *types.FieldType, isExplicitCharset bool) (baseBuiltinFunc, error) {
422+
var bf baseBuiltinFunc
423+
var err error
424+
if isExplicitCharset {
425+
bf = baseBuiltinFunc{
426+
bufAllocator: newLocalColumnPool(),
427+
childrenVectorizedOnce: new(sync.Once),
428+
429+
args: args,
430+
ctx: ctx,
431+
tp: tp,
432+
}
433+
bf.SetCharsetAndCollation(tp.GetCharset(), tp.GetCollate())
434+
bf.setCollator(collate.GetCollator(tp.GetCollate()))
435+
bf.SetCoercibility(CoercibilityExplicit)
436+
bf.SetExplicitCharset(true)
437+
if tp.GetCharset() == charset.CharsetASCII {
438+
bf.SetRepertoire(ASCII)
439+
} else {
440+
bf.SetRepertoire(UNICODE)
441+
}
442+
} else {
443+
bf, err = newBaseBuiltinFunc(ctx, funcName, args, tp)
444+
if err != nil {
445+
return baseBuiltinFunc{}, err
446+
}
447+
}
448+
return bf, nil
449+
}
450+
421451
// vecBuiltinFunc contains all vectorized methods for a builtin function.
422452
type vecBuiltinFunc interface {
423453
// vectorized returns if this builtin function itself supports vectorized evaluation.

expression/builtin_cast.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,15 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx sessionctx.Context, args []
272272
type castAsStringFunctionClass struct {
273273
baseFunctionClass
274274

275-
tp *types.FieldType
275+
tp *types.FieldType
276+
isExplicitCharset bool
276277
}
277278

278279
func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (sig builtinFunc, err error) {
279280
if err := c.verifyArgs(args); err != nil {
280281
return nil, err
281282
}
282-
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp)
283+
bf, err := newBaseBuiltinCastFunc4String(ctx, c.funcName, args, c.tp, c.isExplicitCharset)
283284
if err != nil {
284285
return nil, err
285286
}
@@ -2080,13 +2081,13 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp
20802081

20812082
// BuildCastFunction builds a CAST ScalarFunction from the Expression.
20822083
func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
2083-
res, err := BuildCastFunctionWithCheck(ctx, expr, tp)
2084+
res, err := BuildCastFunctionWithCheck(ctx, expr, tp, false)
20842085
terror.Log(err)
20852086
return
20862087
}
20872088

20882089
// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and returns an error if any.
2089-
func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression, err error) {
2090+
func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType, isExplicitCharset bool) (res Expression, err error) {
20902091
argType := expr.GetType()
20912092
// If source argument's nullable, then target type should be nullable
20922093
if !mysql.HasNotNullFlag(argType.GetFlag()) {
@@ -2112,7 +2113,7 @@ func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *typ
21122113
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
21132114
}
21142115
case types.ETString:
2115-
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
2116+
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, isExplicitCharset}
21162117
if expr.GetType().GetType() == mysql.TypeBit {
21172118
tp.SetFlen((expr.GetType().GetFlen() + 7) / 8)
21182119
}

expression/builtin_cast_test.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ func TestCastFuncSig(t *testing.T) {
646646
tp := types.NewFieldType(mysql.TypeVarString)
647647
tp.SetCharset(charset.CharsetBin)
648648
args := []Expression{c.before}
649-
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp)
649+
stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false)
650650
require.NoError(t, err)
651651
switch i {
652652
case 0:
@@ -732,7 +732,7 @@ func TestCastFuncSig(t *testing.T) {
732732
tp := types.NewFieldType(mysql.TypeVarString)
733733
tp.SetFlen(c.flen)
734734
tp.SetCharset(charset.CharsetBin)
735-
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp)
735+
stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false)
736736
require.NoError(t, err)
737737
switch i {
738738
case 0:
@@ -1083,7 +1083,7 @@ func TestCastFuncSig(t *testing.T) {
10831083
// null case
10841084
args := []Expression{&Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}}
10851085
row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(nil)})
1086-
bf, err := newBaseBuiltinFunc(ctx, "", args, types.NewFieldType(mysql.TypeVarString))
1086+
bf, err := newBaseBuiltinCastFunc4String(ctx, "", args, types.NewFieldType(mysql.TypeVarString), false)
10871087
require.NoError(t, err)
10881088
sig = &builtinCastRealAsStringSig{bf}
10891089
sRes, isNull, err := sig.evalString(row.ToRow())
@@ -1677,10 +1677,8 @@ func TestCastArrayFunc(t *testing.T) {
16771677
},
16781678
}
16791679
for _, tt := range tbl {
1680-
f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp)
1681-
if tt.buildFuncSuccess {
1682-
require.NoError(t, err, tt.input)
1683-
} else {
1680+
f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp, false)
1681+
if !tt.buildFuncSuccess {
16841682
require.Error(t, err, tt.input)
16851683
continue
16861684
}

expression/collation.go

+18-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ type collationInfo struct {
4444

4545
charset string
4646
collation string
47+
48+
isExplicitCharset bool
4749
}
4850

4951
func (c *collationInfo) HasCoercibility() bool {
@@ -76,6 +78,14 @@ func (c *collationInfo) CharsetAndCollation() (string, string) {
7678
return c.charset, c.collation
7779
}
7880

81+
func (c *collationInfo) IsExplicitCharset() bool {
82+
return c.isExplicitCharset
83+
}
84+
85+
func (c *collationInfo) SetExplicitCharset(explicit bool) {
86+
c.isExplicitCharset = explicit
87+
}
88+
7989
// CollationInfo contains all interfaces about dealing with collation.
8090
type CollationInfo interface {
8191
// HasCoercibility returns if the Coercibility value is initialized.
@@ -98,6 +108,12 @@ type CollationInfo interface {
98108

99109
// SetCharsetAndCollation sets charset and collation.
100110
SetCharsetAndCollation(chs, coll string)
111+
112+
// IsExplicitCharset returns whether the charset is explicitly set.
113+
IsExplicitCharset() bool
114+
115+
// SetExplicitCharset sets whether the charset is explicitly set.
116+
SetExplicitCharset(bool)
101117
}
102118

103119
// Coercibility values are used to check whether the collation of one item can be coerced to
@@ -246,9 +262,8 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression,
246262
case ast.Cast:
247263
// We assume all casts are implicit.
248264
ec = &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType().GetCharset(), args[0].GetType().GetCollate()}
249-
// Non-string type cast to string type should use @@character_set_connection and @@collation_connection.
250-
// String type cast to string type should keep its original charset and collation. It should not happen.
251-
if retType == types.ETString && argTps[0] != types.ETString {
265+
// Cast to string type should use @@character_set_connection and @@collation_connection.
266+
if retType == types.ETString {
252267
ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo()
253268
}
254269
return ec, nil

expression/scalar_function.go

+10
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,16 @@ func (sf *ScalarFunction) SetRepertoire(r Repertoire) {
626626
sf.Function.SetRepertoire(r)
627627
}
628628

629+
// IsExplicitCharset returns whether the charset is explicitly set.
630+
func (sf *ScalarFunction) IsExplicitCharset() bool {
631+
return sf.Function.IsExplicitCharset()
632+
}
633+
634+
// SetExplicitCharset sets whether the charset is explicitly set.
635+
func (sf *ScalarFunction) SetExplicitCharset(explicit bool) {
636+
sf.Function.SetExplicitCharset(explicit)
637+
}
638+
629639
const emptyScalarFunctionSize = int64(unsafe.Sizeof(ScalarFunction{}))
630640

631641
// MemoryUsage return the memory usage of ScalarFunction

expression/util.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,16 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression
453453
}
454454
if substituted {
455455
flag := v.RetType.GetFlag()
456-
e := BuildCastFunction(v.GetCtx(), newArg, v.RetType)
456+
var e Expression
457+
var err error
458+
if v.FuncName.L == ast.Cast {
459+
e, err = BuildCastFunctionWithCheck(v.GetCtx(), newArg, v.RetType, v.Function.IsExplicitCharset())
460+
terror.Log(err)
461+
} else {
462+
// for grouping function recreation, use clone (meta included) instead of newFunction
463+
e = v.Clone()
464+
e.(*ScalarFunction).Function.getArgs()[0] = newArg
465+
}
457466
e.SetCoercibility(v.Coercibility())
458467
e.GetType().SetFlag(flag)
459468
return true, false, e

expression/util_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,8 @@ func (m *MockExpr) Coercibility() Coercibility
588588
func (m *MockExpr) SetCoercibility(Coercibility) {}
589589
func (m *MockExpr) Repertoire() Repertoire { return UNICODE }
590590
func (m *MockExpr) SetRepertoire(Repertoire) {}
591+
func (m *MockExpr) IsExplicitCharset() bool { return false }
592+
func (m *MockExpr) SetExplicitCharset(bool) {}
591593

592594
func (m *MockExpr) CharsetAndCollation() (string, string) {
593595
return "", ""

planner/core/expression_rewriter.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
12091209
return retNode, false
12101210
}
12111211

1212-
castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp)
1212+
castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, v.ExplicitCharSet)
12131213
if err != nil {
12141214
er.err = err
12151215
return retNode, false

tests/integrationtest/r/expression/cast.result

+30
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,33 @@ select 1.194192591e9 > t0.c0 from t0;
2828
select 1.194192591e9 < t0.c0 from t0;
2929
1.194192591e9 < t0.c0
3030
0
31+
drop table if exists test;
32+
CREATE TABLE `test` (
33+
`id` bigint(20) NOT NULL,
34+
`update_user` varchar(32) DEFAULT NULL,
35+
PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */
36+
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
37+
insert into test values(1,'张三');
38+
insert into test values(2,'李四');
39+
insert into test values(3,'张三');
40+
insert into test values(4,'李四');
41+
select * from test order by cast(update_user as char character set gbk) desc , id limit 3;
42+
id update_user
43+
1 张三
44+
3 张三
45+
2 李四
46+
drop table test;
47+
CREATE TABLE `test` (
48+
`id` bigint NOT NULL,
49+
`update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL,
50+
PRIMARY KEY (`id`)
51+
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
52+
insert into test values(1,'张三');
53+
insert into test values(2,'李四');
54+
insert into test values(3,'张三');
55+
insert into test values(4,'李四');
56+
select * from test order by cast(update_user as char) desc , id limit 3;
57+
id update_user
58+
2 李四
59+
4 李四
60+
1 张三

tests/integrationtest/t/expression/cast.test

+26
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,29 @@ select t0.c0 > 1.194192591e9 from t0;
1111
select t0.c0 < 1.194192591e9 from t0;
1212
select 1.194192591e9 > t0.c0 from t0;
1313
select 1.194192591e9 < t0.c0 from t0;
14+
15+
# TestCastAsStringExplicitCharSet
16+
drop table if exists test;
17+
CREATE TABLE `test` (
18+
`id` bigint(20) NOT NULL,
19+
`update_user` varchar(32) DEFAULT NULL,
20+
PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */
21+
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
22+
insert into test values(1,'张三');
23+
insert into test values(2,'李四');
24+
insert into test values(3,'张三');
25+
insert into test values(4,'李四');
26+
select * from test order by cast(update_user as char character set gbk) desc , id limit 3;
27+
28+
drop table test;
29+
CREATE TABLE `test` (
30+
`id` bigint NOT NULL,
31+
`update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL,
32+
PRIMARY KEY (`id`)
33+
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
34+
insert into test values(1,'张三');
35+
insert into test values(2,'李四');
36+
insert into test values(3,'张三');
37+
insert into test values(4,'李四');
38+
select * from test order by cast(update_user as char) desc , id limit 3;
39+

0 commit comments

Comments
 (0)