diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index a3013b6b83..7a5c401d1e 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -4889,6 +4889,9 @@ CREATE TABLE tab3 ( SetUpScript: []string{ "create table t (b bool);", "insert into t values (false);", + "create table t_idx (b bool);", + "create index idx on t_idx(b);", + "insert into t_idx values (false);", }, Assertions: []ScriptTestAssertion{ { @@ -4897,6 +4900,92 @@ CREATE TABLE tab3 ( {0}, }, }, + { + Query: "select * from t where (b in (false/'1'));", + Expected: []sql.Row{ + {0}, + }, + }, + { + Query: "select * from t_idx where (b in (-''));", + Expected: []sql.Row{ + {0}, + }, + }, + { + Query: "select * from t_idx where (b in (false/'1'));", + Expected: []sql.Row{ + {0}, + }, + }, + }, + }, + { + Name: "strings in tuple are properly hashed", + SetUpScript: []string{ + "create table t (v varchar(100));", + "insert into t values (false);", + "create table t_idx (v varchar(100));", + "create index idx on t_idx(v);", + "insert into t_idx values (false);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from t where (v in (-''));", + Expected: []sql.Row{ + {"0"}, + }, + }, + { + Query: "select * from t where (v in (false/'1'));", + Expected: []sql.Row{ + {"0"}, + }, + }, + { + Query: "select * from t_idx where (v in (-''));", + Expected: []sql.Row{ + {"0"}, + }, + }, + { + Query: "select * from t_idx where (v in (false/'1'));", + Expected: []sql.Row{ + {"0"}, + }, + }, + }, + }, + { + Name: "strings vs decimals with trailing 0s in IN exprs", + SetUpScript: []string{ + "create table t (v varchar(100));", + "insert into t values ('0'), ('0.0'), ('123'), ('123.0');", + "create table t_idx (v varchar(100));", + "create index idx on t_idx(v);", + "insert into t_idx values ('0'), ('0.0'), ('123'), ('123.0');", + }, + Assertions: []ScriptTestAssertion{ + { + Skip: true, + Query: "select * from t where (v in (0.0, 123));", + Expected: []sql.Row{ + {"0"}, + {"0.0"}, + {"123"}, + {"123.0"}, + }, + }, + { + Skip: true, + Query: "select * from t_idx where (v in (0.0, 123));", + Expected: []sql.Row{ + {"0"}, + {"0.0"}, + {"123"}, + {"123.0"}, + }, + }, }, }, { diff --git a/sql/expression/in.go b/sql/expression/in.go index 0d09084def..3c3c46ca2b 100644 --- a/sql/expression/in.go +++ b/sql/expression/in.go @@ -18,8 +18,6 @@ import ( "fmt" "strconv" - "github.com/cespare/xxhash/v2" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -250,46 +248,44 @@ func hashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) { return 0, nil } - // Collated strings that are equivalent may have different runes, so we must make them hash to the same value + var str string + coll := sql.Collation_Default if types.IsTextOnly(t) { - if str, ok := i.(string); ok { - return t.(sql.StringType).Collation().HashToUint(str) + coll = t.(sql.StringType).Collation() + if s, ok := i.(string); ok { + str = s } else { converted, err := convertOrTruncate(ctx, i, t) if err != nil { return 0, err } - return t.(sql.StringType).Collation().HashToUint(converted.(string)) + str = converted.(string) } } else { - hash := xxhash.New() x, err := convertOrTruncate(ctx, i, t.Promote()) if err != nil { return 0, err } // Remove trailing 0s from floats - var s string switch v := x.(type) { case float32: - s = strconv.FormatFloat(float64(v), 'f', -1, 32) - if s == "-0" { - s = "0" + str = strconv.FormatFloat(float64(v), 'f', -1, 32) + if str == "-0" { + str = "0" } case float64: - s = strconv.FormatFloat(v, 'f', -1, 64) - if s == "-0" { - s = "0" + str = strconv.FormatFloat(v, 'f', -1, 64) + if str == "-0" { + str = "0" } default: - s = fmt.Sprintf("%v", v) + str = fmt.Sprintf("%v", v) } - - if _, err := hash.Write([]byte(fmt.Sprintf("%s,", s))); err != nil { - return 0, err - } - return hash.Sum64(), nil } + + // Collated strings that are equivalent may have different runes, so we must make them hash to the same value + return coll.HashToUint(str) } // Eval implements the Expression interface. diff --git a/sql/types/strings.go b/sql/types/strings.go index 74a1fbd41d..9391a0abb5 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -330,8 +330,14 @@ func ConvertToString(v interface{}, t sql.StringType) (string, error) { } case float64: val = strconv.FormatFloat(s, 'f', -1, 64) + if val == "-0" { + val = "0" + } case float32: val = strconv.FormatFloat(float64(s), 'f', -1, 32) + if val == "-0" { + val = "0" + } case int: val = strconv.FormatInt(int64(s), 10) case int8: