Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8732,6 +8732,22 @@ where
},
},
},
{
Name: "subquery with case insensitive collation",
Dialect: "mysql",
SetUpScript: []string{
"create table tbl (t text) collate=utf8mb4_0900_ai_ci;",
"insert into tbl values ('abcdef');",
},
Assertions: []ScriptTestAssertion{
{
Query: "select 'AbCdEf' in (select t from tbl);",
Expected: []sql.Row{
{true},
},
},
},
},
}

var SpatialScriptTests = []ScriptTest{
Expand Down
6 changes: 3 additions & 3 deletions memory/table_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package memory

import (
"context"
"fmt"
"sort"
"strconv"
Expand All @@ -25,6 +24,7 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/hash"
"github.com/dolthub/go-mysql-server/sql/transform"
"github.com/dolthub/go-mysql-server/sql/types"
)
Expand Down Expand Up @@ -275,7 +275,7 @@ func (td *TableData) numRows(ctx *sql.Context) (uint64, error) {
}

// throws an error if any two or more rows share the same |cols| values.
func (td *TableData) errIfDuplicateEntryExist(ctx context.Context, cols []string, idxName string) error {
func (td *TableData) errIfDuplicateEntryExist(ctx *sql.Context, cols []string, idxName string) error {
columnMapping, err := td.columnIndexes(cols)

// We currently skip validating duplicates on unique virtual columns.
Expand All @@ -297,7 +297,7 @@ func (td *TableData) errIfDuplicateEntryExist(ctx context.Context, cols []string
if hasNulls(idxPrefixKey) {
continue
}
h, err := sql.HashOf(ctx, idxPrefixKey)
h, err := hash.HashOf(ctx, td.schema.Schema, idxPrefixKey)
if err != nil {
return err
}
Expand Down
37 changes: 0 additions & 37 deletions sql/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,49 +15,12 @@
package sql

import (
"context"
"fmt"
"runtime"
"sync"

"github.com/cespare/xxhash/v2"

lru "github.com/hashicorp/golang-lru"
)

// HashOf returns a hash of the given value to be used as key in a cache.
// The hash is computed with a pooled xxhash digest; values are rendered via
// fmt's %v formatting and joined with a nil-byte separator.
func HashOf(ctx context.Context, v Row) (uint64, error) {
hash := digestPool.Get().(*xxhash.Digest)
hash.Reset()
// Return the digest to the pool once the sum is computed so it can be reused.
defer digestPool.Put(hash)
for i, x := range v {
if i > 0 {
// separate each value in the row with a nil byte
if _, err := hash.Write([]byte{0}); err != nil {
return 0, err
}
}
// Unwrap lazily-materialized values before formatting them.
x, err := UnwrapAny(ctx, x)
if err != nil {
return 0, err
}
// TODO: probably much faster to do this with a type switch
// TODO: we don't have the type info necessary to appropriately encode the value of a string with a non-standard
// collation, which means that two strings that differ only in their collations will hash to the same value.
// See rowexec/grouping_key()
if _, err := fmt.Fprintf(hash, "%v,", x); err != nil {
return 0, err
}
}
return hash.Sum64(), nil
}

// digestPool recycles xxhash digest instances so that HashOf does not
// allocate a fresh hasher on every call.
var digestPool = sync.Pool{
New: func() any {
return xxhash.New()
},
}

// ErrKeyNotFound is returned when the key could not be found in the cache.
var ErrKeyNotFound = fmt.Errorf("memory: key not found in cache")

Expand Down
33 changes: 0 additions & 33 deletions sql/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package sql

import (
"context"
"errors"
"testing"

Expand Down Expand Up @@ -178,35 +177,3 @@ func TestRowsCache(t *testing.T) {
require.True(freed)
})
}

// BenchmarkHashOf measures serial hashing of a small two-column row and
// verifies the digest against a known value to catch accidental changes.
func BenchmarkHashOf(b *testing.B) {
ctx := context.Background()
row := NewRow(1, "1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
sum, err := HashOf(ctx, row)
if err != nil {
b.Fatal(err)
}
// Known-good digest for NewRow(1, "1"); a mismatch means the hash
// encoding changed, which would invalidate existing cache keys.
if sum != 11268758894040352165 {
b.Fatalf("got %v", sum)
}
}
}

// BenchmarkParallelHashOf exercises HashOf from many goroutines at once,
// primarily stressing the shared digestPool, and verifies the digest value.
func BenchmarkParallelHashOf(b *testing.B) {
ctx := context.Background()
row := NewRow(1, "1")
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
sum, err := HashOf(ctx, row)
if err != nil {
b.Fatal(err)
}
// Known-good digest for NewRow(1, "1"); see BenchmarkHashOf.
if sum != 11268758894040352165 {
b.Fatalf("got %v", sum)
}
}
})
}
88 changes: 88 additions & 0 deletions sql/hash/hash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package hash

import (
"fmt"
"sync"

"github.com/cespare/xxhash/v2"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

// digestPool recycles xxhash digest instances so that HashOf does not
// allocate a fresh hasher on every call.
var digestPool = sync.Pool{
New: func() any {
return xxhash.New()
},
}

// HashOf returns a hash of the given value to be used as key in a cache.
// When schema information is available, string columns are hashed by their
// collation weight string, so values that compare equal under the collation
// hash identically; all other values fall back to a %v-formatted encoding.
func HashOf(ctx *sql.Context, sch sql.Schema, row sql.Row) (uint64, error) {
	digest := digestPool.Get().(*xxhash.Digest)
	digest.Reset()
	defer digestPool.Put(digest)

	for i, raw := range row {
		// Values within the row are delimited by a single nil byte.
		if i > 0 {
			if _, err := digest.Write([]byte{0}); err != nil {
				return 0, err
			}
		}

		val, err := sql.UnwrapAny(ctx, raw)
		if err != nil {
			return 0, fmt.Errorf("error unwrapping value: %w", err)
		}

		// TODO: we may not always have the type information available, so we
		// check schema length. Then, defer to original behavior. NULLs are
		// also encoded generically regardless of type.
		if i >= len(sch) || val == nil {
			if _, werr := fmt.Fprintf(digest, "%v", val); werr != nil {
				return 0, werr
			}
			continue
		}

		switch typ := sch[i].Type.(type) {
		case types.ExtendedType:
			// TODO: Doltgres follows Postgres conventions which don't align
			// with the expectations of MySQL, so we're using the old
			// (probably incorrect) behavior for now
			if _, werr := fmt.Fprintf(digest, "%v", val); werr != nil {
				return 0, werr
			}
		case types.StringType:
			// Hash the collation weight string rather than the raw bytes so
			// collation-insensitive equality is reflected in the hash.
			strVal, cerr := types.ConvertToString(ctx, val, typ, nil)
			if cerr != nil {
				return 0, cerr
			}
			if werr := typ.Collation().WriteWeightString(digest, strVal); werr != nil {
				return 0, werr
			}
		default:
			// TODO: probably much faster to do this with a type switch
			if _, werr := fmt.Fprintf(digest, "%v", val); werr != nil {
				return 0, werr
			}
		}
	}
	return digest.Sum64(), nil
}
53 changes: 53 additions & 0 deletions sql/hash/hash_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package hash

import (
"testing"

"github.com/dolthub/go-mysql-server/sql"
)

// BenchmarkHashOf measures serial hashing of a small two-column row,
// validating each digest against a known value so that accidental changes
// to the hash encoding are caught.
func BenchmarkHashOf(b *testing.B) {
	const want = uint64(11268758894040352165)
	ctx := sql.NewEmptyContext()
	row := sql.NewRow(1, "1")
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		sum, err := HashOf(ctx, nil, row)
		if err != nil {
			b.Fatal(err)
		}
		if sum != want {
			b.Fatalf("got %v", sum)
		}
	}
}

// BenchmarkParallelHashOf exercises HashOf concurrently from many
// goroutines, stressing the shared digest pool while validating each
// digest against the known-good value.
func BenchmarkParallelHashOf(b *testing.B) {
	const want = uint64(11268758894040352165)
	ctx := sql.NewEmptyContext()
	row := sql.NewRow(1, "1")
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			sum, err := HashOf(ctx, nil, row)
			if err != nil {
				b.Fatal(err)
			}
			if sum != want {
				b.Fatalf("got %v", sum)
			}
		}
	})
}
20 changes: 10 additions & 10 deletions sql/iters/rel_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/hash"
"github.com/dolthub/go-mysql-server/sql/types"
)

Expand Down Expand Up @@ -571,7 +572,7 @@ func (di *distinctIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, err
}

hash, err := sql.HashOf(ctx, row)
hash, err := hash.HashOf(ctx, nil, row)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -643,22 +644,21 @@ func (ii *IntersectIter) Next(ctx *sql.Context) (sql.Row, error) {
ii.cache = make(map[uint64]int)
for {
res, err := ii.RIter.Next(ctx)
if err != nil && err != io.EOF {
if err != nil {
if err == io.EOF {
break
}
return nil, err
}

hash, herr := sql.HashOf(ctx, res)
hash, herr := hash.HashOf(ctx, nil, res)
if herr != nil {
return nil, herr
}
if _, ok := ii.cache[hash]; !ok {
ii.cache[hash] = 0
}
ii.cache[hash]++

if err == io.EOF {
break
}
}
ii.cached = true
}
Expand All @@ -669,7 +669,7 @@ func (ii *IntersectIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, err
}

hash, herr := sql.HashOf(ctx, res)
hash, herr := hash.HashOf(ctx, nil, res)
if herr != nil {
return nil, herr
}
Expand Down Expand Up @@ -714,7 +714,7 @@ func (ei *ExceptIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, err
}

hash, herr := sql.HashOf(ctx, res)
hash, herr := hash.HashOf(ctx, nil, res)
if herr != nil {
return nil, herr
}
Expand All @@ -736,7 +736,7 @@ func (ei *ExceptIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, err
}

hash, herr := sql.HashOf(ctx, res)
hash, herr := hash.HashOf(ctx, nil, res)
if herr != nil {
return nil, herr
}
Expand Down
6 changes: 3 additions & 3 deletions sql/plan/hash_lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
"fmt"
"sync"

"github.com/dolthub/go-mysql-server/sql/types"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/hash"
"github.com/dolthub/go-mysql-server/sql/types"
)

// NewHashLookup returns a node that performs an indexed hash lookup
Expand Down Expand Up @@ -127,7 +127,7 @@ func (n *HashLookup) GetHashKey(ctx *sql.Context, e sql.Expression, row sql.Row)
return nil, err
}
if s, ok := key.([]interface{}); ok {
return sql.HashOf(ctx, s)
return hash.HashOf(ctx, n.Schema(), s)
}
// byte slices are not hashable
if k, ok := key.([]byte); ok {
Expand Down
Loading