diff --git a/README.md b/README.md index a80db8cba3..0859681244 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ func main() { Protocol: "tcp", Address: fmt.Sprintf("%s:%d", address, port), } - s, err := server.NewServer(config, engine, memory.NewSessionBuilder(pro), nil) + s, err := server.NewServer(config, engine, sql.NewContext, memory.NewSessionBuilder(pro), nil) if err != nil { panic(err) } diff --git a/_integration/go/go.mod b/_integration/go/go.mod index 6c0e25190b..66f989318b 100644 --- a/_integration/go/go.mod +++ b/_integration/go/go.mod @@ -1,8 +1,25 @@ module github.com/dolthub/go-mysql-server/integration/go -go 1.14 +go 1.22 + +toolchain go1.24.1 + +require ( + github.com/go-mysql-org/go-mysql v1.12.0 + github.com/go-sql-driver/mysql v1.7.1 +) require ( - github.com/go-sql-driver/mysql v1.4.0 - google.golang.org/appengine v1.2.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect + github.com/Masterminds/semver v1.5.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/klauspost/compress v1.17.8 // indirect + github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb // indirect + github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22 // indirect + github.com/pingcap/tidb/pkg/parser v0.0.0-20241118164214-4f047be191be // indirect + go.uber.org/atomic v1.11.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.27.0 // indirect + golang.org/x/text v0.20.0 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/_integration/go/go.sum b/_integration/go/go.sum index 33d775ef12..fd58b7190f 100644 --- a/_integration/go/go.sum +++ b/_integration/go/go.sum @@ -1,7 +1,75 @@ -github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= -github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -golang.org/x/net 
v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-mysql-org/go-mysql v1.12.0 h1:tyToNggfCfl11OY7GbWa2Fq3ofyScO9GY8b5f5wAmE4= +github.com/go-mysql-org/go-mysql v1.12.0/go.mod h1:/XVjs1GlT6NPSf13UgXLv/V5zMNricTCqeNaehSBghs= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= +github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb 
h1:3pSi4EDG6hg0orE1ndHkXvX6Qdq2cZn8gAPir8ymKZk= +github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= +github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22 h1:2SOzvGvE8beiC1Y4g9Onkvu6UmuBBOeWRGQEjJaT/JY= +github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= +github.com/pingcap/tidb/pkg/parser v0.0.0-20241118164214-4f047be191be h1:t5EkCmZpxLCig5GQA0AZG47aqsuL5GTsJeeUD+Qfies= +github.com/pingcap/tidb/pkg/parser v0.0.0-20241118164214-4f047be191be/go.mod h1:Hju1TEWZvrctQKbztTRwXH7rd41Yq0Pgmq4PrEKcq7o= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= +go.uber.org/goleak v1.3.0 
h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -google.golang.org/appengine v1.2.0 h1:S0iUepdCWODXRvtE+gcRDd15L+k+k1AiHlMiMjefH24= -google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools 
v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/_integration/go/mysql_test.go b/_integration/go/mysql_test.go index eb13aee95f..1b15ed95e0 100644 --- a/_integration/go/mysql_test.go +++ b/_integration/go/mysql_test.go @@ -19,6 +19,8 @@ import ( "reflect" "testing" + "github.com/go-mysql-org/go-mysql/client" + "github.com/go-mysql-org/go-mysql/mysql" _ "github.com/go-sql-driver/mysql" ) @@ -120,6 +122,101 @@ func TestGrafana(t *testing.T) { } } +func TestMySQLStreaming(t *testing.T) { + conn, err := client.Connect("127.0.0.1:3306", "root", "", "mydb") + if err != nil { + t.Fatalf("can't connect to mysql: %s", err) + } + defer func() { + err = conn.Close() + if err != nil { + t.Fatalf("error closing mysql connection: %s", err) + } + }() + + var result mysql.Result + var rows [][2]string + err = 
conn.ExecuteSelectStreaming("SELECT name, email FROM mytable ORDER BY name, email", &result, func(row []mysql.FieldValue) error { + if len(row) != 2 { + t.Fatalf("expected 2 columns, got %d", len(row)) + } + rows = append(rows, [2]string{row[0].String(), row[1].String()}) + return nil + }, nil) + + expected := [][2]string{ + {"Evil Bob", "evilbob@gmail.com"}, + {"Jane Doe", "jane@doe.com"}, + {"John Doe", "john@doe.com"}, + {"John Doe", "johnalt@doe.com"}, + } + + if len(expected) != len(rows) { + t.Errorf("got %d rows, expecting %d", len(rows), len(expected)) + } + + for i := range rows { + if rows[i][0] != expected[i][0] || rows[i][1] != expected[i][1] { + t.Errorf( + "incorrect row %d, got: {%s, %s}, expected: {%s, %s}", + i, + rows[i][0], rows[i][1], + expected[i][0], expected[i][1], + ) + } + } +} + +func TestMySQLStreamingPrepared(t *testing.T) { + conn, err := client.Connect("127.0.0.1:3306", "root", "", "mydb") + if err != nil { + t.Fatalf("can't connect to mysql: %s", err) + } + defer func() { + err = conn.Close() + if err != nil { + t.Fatalf("error closing mysql connection: %s", err) + } + }() + + stmt, err := conn.Prepare("SELECT name, email, ? 
FROM mytable ORDER BY name, email") + if err != nil { + t.Fatalf("error preparing statement: %s", err) + } + + var result mysql.Result + var rows [][3]string + err = stmt.ExecuteSelectStreaming(&result, func(row []mysql.FieldValue) error { + if len(row) != 3 { + t.Fatalf("expected 3 columns, got %d", len(row)) + } + rows = append(rows, [3]string{row[0].String(), row[1].String(), row[2].String()}) + return nil + }, nil, "abc") + + expected := [][3]string{ + {"Evil Bob", "evilbob@gmail.com", "abc"}, + {"Jane Doe", "jane@doe.com", "abc"}, + {"John Doe", "john@doe.com", "abc"}, + {"John Doe", "johnalt@doe.com", "abc"}, + } + + if len(expected) != len(rows) { + t.Errorf("got %d rows, expecting %d", len(rows), len(expected)) + } + + for i := range rows { + if rows[i][0] != expected[i][0] || rows[i][1] != expected[i][1] || rows[i][2] != expected[i][2] { + t.Errorf( + "incorrect row %d, got: {%s, %s, %s}, expected: {%s, %s, %s}", + i, + rows[i][0], rows[i][1], rows[i][2], + expected[i][0], expected[i][1], expected[i][2], + ) + } + } +} + func getResult(t *testing.T, rs *sql.Rows) [][]string { t.Helper() diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index d6fca96fd9..cde528279c 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -74,6 +74,20 @@ func TestQueries(t *testing.T, harness Harness) { }) } + for _, tt := range queries.FunctionQueryTests { + t.Run(tt.Query, func(t *testing.T) { + if sh, ok := harness.(SkippingHarness); ok { + if sh.SkipQueryTest(tt.Query) { + t.Skipf("Skipping query plan for %s", tt.Query) + } + } + if IsServerEngine(e) && tt.SkipServerEngine { + t.Skip("skipping for server engine") + } + TestQueryWithContext(t, ctx, e, harness, tt.Query, tt.Expected, tt.ExpectedColumns, nil, nil) + }) + } + // TODO: move this into its own test method if keyless, ok := harness.(KeylessTableHarness); ok && keyless.SupportsKeylessTables() { for _, tt := range queries.KeylessQueries { @@ -218,6 +232,17 @@ func 
TestQueriesPrepared(t *testing.T, harness Harness) { } }) + t.Run("function query prepared tests", func(t *testing.T) { + for _, tt := range queries.FunctionQueryTests { + if tt.SkipPrepared { + continue + } + t.Run(tt.Query, func(t *testing.T) { + TestPreparedQueryWithEngine(t, harness, e, tt) + }) + } + }) + t.Run("keyless prepared tests", func(t *testing.T) { harness.Setup(setup.MydbData, setup.KeylessData, setup.Keyless_idxData, setup.MytableData) for _, tt := range queries.KeylessQueries { @@ -487,6 +512,7 @@ func TestReadOnlyDatabases(t *testing.T, harness ReadOnlyDatabaseHarness) { for _, querySet := range [][]queries.QueryTest{ queries.QueryTests, + queries.FunctionQueryTests, queries.KeylessQueries, } { for _, tt := range querySet { @@ -496,7 +522,7 @@ func TestReadOnlyDatabases(t *testing.T, harness ReadOnlyDatabaseHarness) { for _, querySet := range [][]queries.WriteQueryTest{ queries.InsertQueries, - queries.UpdateTests, + queries.UpdateWriteQueryTests, queries.DeleteTests, queries.ReplaceQueries, } { @@ -1352,9 +1378,12 @@ func TestReplaceIntoErrors(t *testing.T, harness Harness) { func TestUpdate(t *testing.T, harness Harness) { harness.Setup(setup.MydbData, setup.MytableData, setup.Mytable_del_idxData, setup.FloattableData, setup.NiltableData, setup.TypestableData, setup.Pk_tablesData, setup.OthertableData, setup.TabletestData) - for _, tt := range queries.UpdateTests { + for _, tt := range queries.UpdateWriteQueryTests { RunWriteQueryTest(t, harness, tt) } + for _, tt := range queries.UpdateScriptTests { + TestScript(t, harness, tt) + } } func TestUpdateIgnore(t *testing.T, harness Harness) { @@ -1421,9 +1450,12 @@ func TestDelete(t *testing.T, harness Harness) { func TestUpdateQueriesPrepared(t *testing.T, harness Harness) { harness.Setup(setup.MydbData, setup.MytableData, setup.Mytable_del_idxData, setup.OthertableData, setup.TypestableData, setup.Pk_tablesData, setup.FloattableData, setup.NiltableData, setup.TabletestData) - for _, tt := range 
queries.UpdateTests { + for _, tt := range queries.UpdateWriteQueryTests { runWriteQueryTestPrepared(t, harness, tt) } + for _, tt := range queries.UpdateScriptTests { + TestScriptPrepared(t, harness, tt) + } } func TestDeleteQueriesPrepared(t *testing.T, harness Harness) { @@ -4112,7 +4144,7 @@ func TestVariables(t *testing.T, harness Harness) { }, { Query: "SET GLOBAL select_into_buffer_size = 9001", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT @@SESSION.select_into_buffer_size", @@ -4124,7 +4156,7 @@ func TestVariables(t *testing.T, harness Harness) { }, { Query: "SET @@GLOBAL.select_into_buffer_size = 9002", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT @@GLOBAL.select_into_buffer_size", @@ -4133,7 +4165,7 @@ func TestVariables(t *testing.T, harness Harness) { { // For boolean types, OFF/ON is converted Query: "SET @@GLOBAL.activate_all_roles_on_login = 'ON'", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT @@GLOBAL.activate_all_roles_on_login", @@ -4142,7 +4174,7 @@ func TestVariables(t *testing.T, harness Harness) { { // For non-boolean types, OFF/ON is not converted Query: "SET @@GLOBAL.delay_key_write = 'OFF'", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT @@GLOBAL.delay_key_write", @@ -4168,7 +4200,7 @@ func TestVariables(t *testing.T, harness Harness) { }, { Query: "SET GLOBAL select_into_buffer_size = 131072", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, } { t.Run(assertion.Query, func(t *testing.T) { @@ -5271,17 +5303,17 @@ func TestPersist(t *testing.T, harness Harness, newPersistableSess func(ctx *sql }{ { Query: "SET PERSIST max_connections = 1000;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, ExpectedGlobal: int64(1000), ExpectedPersist: int64(1000), }, { Query: "SET @@PERSIST.max_connections = 
1000;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, ExpectedGlobal: int64(1000), ExpectedPersist: int64(1000), }, { Query: "SET PERSIST_ONLY max_connections = 1000;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, ExpectedGlobal: int64(151), ExpectedPersist: int64(1000), }, diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index 4c2678a058..bfc799d4d2 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -84,14 +84,8 @@ func TestScriptWithEngine(t *testing.T, e QueryEngine, harness Harness, script q require.NoError(t, err, nil) t.Run(script.Name, func(t *testing.T) { - if sh, ok := harness.(SkippingHarness); ok { - if sh.SkipQueryTest(script.Name) { - t.Skip() - } - - if !supportedDialect(harness, script.Dialect) { - t.Skip() - } + if skipScript(harness, script, false) { + t.Skip() } for _, statement := range script.SetUpScript { @@ -126,7 +120,7 @@ func TestScriptWithEngine(t *testing.T, e QueryEngine, harness Harness, script q ctx = th.NewSession() } - if skipAssertion(t, harness, assertion) { + if skipAssertion(harness, assertion) { t.Skip() } @@ -161,30 +155,26 @@ func TestScriptWithEngine(t *testing.T, e QueryEngine, harness Harness, script q }) } -func skipAssertion(t *testing.T, harness Harness, assertion queries.ScriptTestAssertion) bool { - if sh, ok := harness.(SkippingHarness); ok && sh.SkipQueryTest(assertion.Query) { +func skipScript(harness Harness, script queries.ScriptTest, prepared bool) bool { + if sh, ok := harness.(SkippingHarness); ok && sh.SkipQueryTest(script.Name) { return true } - if !supportedDialect(harness, assertion.Dialect) { - return true - } + return script.Skip || !supportedDialect(harness, script.Dialect) || (prepared && script.SkipPrepared) +} - if assertion.Skip { +func skipAssertion(harness Harness, assertion queries.ScriptTestAssertion) bool { + if sh, ok := harness.(SkippingHarness); ok && sh.SkipQueryTest(assertion.Query) { return true } 
- return false + return assertion.Skip || !supportedDialect(harness, assertion.Dialect) } // TestScriptPrepared substitutes literals for bindvars, runs the test script given, // and makes any assertions given func TestScriptPrepared(t *testing.T, harness Harness, script queries.ScriptTest) bool { return t.Run(script.Name, func(t *testing.T) { - if script.SkipPrepared { - t.Skip() - } - e := mustNewEngine(t, harness) defer e.Close() TestScriptWithEnginePrepared(t, e, harness, script) @@ -194,6 +184,10 @@ func TestScriptPrepared(t *testing.T, harness Harness, script queries.ScriptTest // TestScriptWithEnginePrepared runs the test script with bindvars substituted for literals // using the engine provided. func TestScriptWithEnginePrepared(t *testing.T, e QueryEngine, harness Harness, script queries.ScriptTest) { + if skipScript(harness, script, true) { + t.Skip() + } + ctx := NewContext(harness) err := CreateNewConnectionForServerEngine(ctx, e) require.NoError(t, err, nil) @@ -223,13 +217,7 @@ func TestScriptWithEnginePrepared(t *testing.T, e QueryEngine, harness Harness, for _, assertion := range assertions { t.Run(assertion.Query, func(t *testing.T) { - - if sh, ok := harness.(SkippingHarness); ok { - if sh.SkipQueryTest(assertion.Query) { - t.Skip() - } - } - if assertion.Skip { + if skipAssertion(harness, assertion) { t.Skip() } diff --git a/enginetest/join_planning_tests.go b/enginetest/join_planning_tests.go index 3deccf8551..753bbec61b 100644 --- a/enginetest/join_planning_tests.go +++ b/enginetest/join_planning_tests.go @@ -28,6 +28,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/planbuilder" "github.com/dolthub/go-mysql-server/sql/transform" + "github.com/dolthub/go-mysql-server/sql/types" ) type JoinPlanTest struct { @@ -103,7 +104,7 @@ var JoinPlanningTests = []joinPlanScript{ }, { q: "set @@SESSION.disable_merge_join = 1", - exp: []sql.Row{{}}, + exp: []sql.Row{{types.NewOkResult(0)}}, }, { q: "select /*+ 
JOIN_ORDER(ab, xy) MERGE_JOIN(ab, xy)*/ * from ab join xy on y = a order by 1, 3", diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index cb5455eed8..ae8994fb7a 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -226,7 +226,7 @@ func TestSingleScript(t *testing.T) { for _, test := range scripts { harness := enginetest.NewMemoryHarness("", 1, testNumPartitions, true, nil) - harness.UseServer() + //harness.UseServer() engine, err := harness.NewEngine(t) if err != nil { panic(err) diff --git a/enginetest/queries/alter_table_queries.go b/enginetest/queries/alter_table_queries.go index a6103a8aff..62d235f651 100644 --- a/enginetest/queries/alter_table_queries.go +++ b/enginetest/queries/alter_table_queries.go @@ -135,7 +135,7 @@ var AlterTableScripts = []ScriptTest{ { Query: "SELECT * FROM information_schema.CHECK_CONSTRAINTS", Expected: []sql.Row{ - {"def", "mydb", "v1gt0", "(v1 > 0)"}, + {"def", "mydb", "v1gt0", "(`v1` > 0)"}, }, }, }, @@ -1033,16 +1033,258 @@ var AlterTableScripts = []ScriptTest{ Name: "alter table comments are escaped", SetUpScript: []string{ "create table t (i int);", - `alter table t modify column i int comment "newline \n | return \r | backslash \\ | NUL \0 \x00"`, - `alter table t add column j int comment "newline \n | return \r | backslash \\ | NUL \0 \x00"`, + `alter table t modify column i int comment "newline \n | return \r | backslash \\ | NUL \0 \x00 | ctrlz \Z \x1A"`, + `alter table t add column j int comment "newline \n | return \r | backslash \\ | NUL \0 \x00 | ctrlz \Z \x1A"`, }, Assertions: []ScriptTestAssertion{ { Query: "show create table t", Expected: []sql.Row{{ "t", - "CREATE TABLE `t` (\n `i` int COMMENT 'newline \\n | return \\r | backslash \\\\ | NUL \\0 x00'," + - "\n `j` int COMMENT 'newline \\n | return \\r | backslash \\\\ | NUL \\0 x00'\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + "CREATE TABLE `t` (\n `i` int COMMENT 'newline 
\\n | return \\r | backslash \\\\ | NUL \\0 x00 | ctrlz \x1A x1A'," + + "\n `j` int COMMENT 'newline \\n | return \\r | backslash \\\\ | NUL \\0 x00 | ctrlz \x1A x1A'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + }, + }, + { + Name: "alter table supports non-escaped \\Z", + SetUpScript: []string{ + "create table t (i int);", + `alter table t modify column i int comment "ctrlz \Z \\Z"`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "show create table t", + Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n" + + " `i` int COMMENT 'ctrlz \x1A \\\\Z'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + }, + }, + + // Enum tests + { + Name: "alter nil enum", + Dialect: "mysql", + SetUpScript: []string{ + "create table xy (x int primary key, y enum ('a', 'b'));", + "insert into xy values (0, NULL),(1, 'b')", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "alter table xy modify y enum('a','b','c')", + }, + { + Query: "alter table xy modify y enum('a')", + ExpectedErr: types.ErrDataTruncatedForColumn, + }, + }, + }, + { + Name: "alter keyless table", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (c1 int, c2 varchar(200), c3 enum('one', 'two'));", + "insert into t values (1, 'one', NULL);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `alter table t modify column c1 int unsigned`, + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "describe t;", + Expected: []sql.Row{ + {"c1", "int unsigned", "YES", "", nil, ""}, + {"c2", "varchar(200)", "YES", "", nil, ""}, + {"c3", "enum('one','two')", "YES", "", nil, ""}, + }, + }, + { + Query: `alter table t drop column c1;`, + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "describe t;", + Expected: []sql.Row{ + {"c2", "varchar(200)", "YES", "", nil, ""}, + {"c3", "enum('one','two')", "YES", "", nil, ""}, + }, + }, + { + Query: "alter table t add column new3 int;", + Expected: 
[]sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: `insert into t values ('two', 'two', -2);`, + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "describe t;", + Expected: []sql.Row{ + {"c2", "varchar(200)", "YES", "", nil, ""}, + {"c3", "enum('one','two')", "YES", "", nil, ""}, + {"new3", "int", "YES", "", nil, ""}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{"one", nil, nil}, {"two", "two", -2}}, + }, + }, + }, + { + Name: "preserve enums through alter statements", + SetUpScript: []string{ + "create table t (i int primary key, e enum('a', 'b', 'c'));", + "insert ignore into t values (0, 'error');", + "insert into t values (1, 'a');", + "insert into t values (2, 'b');", + "insert into t values (3, 'c');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {0, "", float64(0)}, + {1, "a", float64(1)}, + {2, "b", float64(2)}, + {3, "c", float64(3)}, + }, + }, + { + Query: "alter table t modify column e enum('c', 'a', 'b');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {0, "", float64(0)}, + {1, "a", float64(2)}, + {2, "b", float64(3)}, + {3, "c", float64(1)}, + }, + }, + { + Query: "alter table t modify column e enum('asdf', 'a', 'b', 'c');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {0, "", float64(0)}, + {1, "a", float64(2)}, + {2, "b", float64(3)}, + {3, "c", float64(4)}, + }, + }, + { + Query: "alter table t modify column e enum('asdf', 'a', 'b', 'c', 'd');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {0, "", float64(0)}, + {1, "a", float64(2)}, + {2, "b", float64(3)}, + {3, "c", float64(4)}, + }, + }, + { + Query: "alter table t modify column e enum('a', 'b', 'c');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + 
}, + }, + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {0, "", float64(0)}, + {1, "a", float64(1)}, + {2, "b", float64(2)}, + {3, "c", float64(3)}, + }, + }, + { + Query: "alter table t modify column e enum('abc');", + ExpectedErr: types.ErrDataTruncatedForColumn, + }, + }, + }, + + // Set tests + { + Name: "modify set column", + SetUpScript: []string{ + "create table t (i int primary key, s set('a', 'b', 'c'));", + "insert ignore into t values (0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select i, s + 0, s from t;", + Expected: []sql.Row{ + {0, float64(0), ""}, + {1, float64(1), "a"}, + {2, float64(2), "b"}, + {3, float64(3), "a,b"}, + {4, float64(4), "c"}, + {5, float64(5), "a,c"}, + {6, float64(6), "b,c"}, + {7, float64(7), "a,b,c"}, + }, + }, + { + Query: "alter table t modify column s set('a', 'b', 'c', 'd');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select i, s + 0, s from t;", + Expected: []sql.Row{ + {0, float64(0), ""}, + {1, float64(1), "a"}, + {2, float64(2), "b"}, + {3, float64(3), "a,b"}, + {4, float64(4), "c"}, + {5, float64(5), "a,c"}, + {6, float64(6), "b,c"}, + {7, float64(7), "a,b,c"}, + }, + }, + { + Skip: true, + Query: "alter table t modify column s set('c', 'b', 'a');", + Expected: []sql.Row{ + {types.NewOkResult(8)}, // We currently return 0 RowsAffected + }, + }, + { + Skip: true, + Query: "select i, s + 0, s from t;", + Expected: []sql.Row{ + {0, 0, ""}, + {1, 2, "a"}, + {2, 4, "b"}, + {3, 6, "a,b"}, + {4, 1, "c"}, + {5, 3, "c,a"}, + {6, 5, "c,b"}, + {7, 7, "c,a,b"}, + }, + }, + { + Skip: true, + Query: "alter table t modify column s set('a');", + ExpectedErrStr: "Data truncated for column", // We currently throw value 2 is not valid for this set }, }, }, @@ -1214,6 +1456,54 @@ var AlterTableAddAutoIncrementScripts = []ScriptTest{ }, }, }, + { + Name: "ALTER AUTO INCREMENT TABLE ADD column", + SetUpScript: 
[]string{ + "CREATE TABLE test (pk int primary key, uk int UNIQUE KEY auto_increment);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "alter table test add column j int;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + }, + }, + { + Name: "ALTER TABLE MODIFY column with compound UNIQUE KEYS", + Dialect: "mysql", + SetUpScript: []string{ + "CREATE table test (pk int primary key, uk1 int, uk2 int, unique(uk1, uk2))", + "ALTER TABLE `test` MODIFY column uk1 int auto_increment", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "describe test", + Expected: []sql.Row{ + {"pk", "int", "NO", "PRI", nil, ""}, + {"uk1", "int", "NO", "MUL", nil, "auto_increment"}, + {"uk2", "int", "YES", "", nil, ""}, + }, + }, + }, + }, + { + Name: "ALTER TABLE MODIFY column with compound KEYS", + Dialect: "mysql", + SetUpScript: []string{ + "CREATE table test (pk int primary key, mk1 int, mk2 int, index(mk1, mk2))", + "ALTER TABLE `test` MODIFY column mk1 int auto_increment", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "describe test", + Expected: []sql.Row{ + {"pk", "int", "NO", "PRI", nil, ""}, + {"mk1", "int", "NO", "MUL", nil, "auto_increment"}, + {"mk2", "int", "YES", "", nil, ""}, + }, + }, + }, + }, } var AddDropPrimaryKeyScripts = []ScriptTest{ @@ -1864,7 +2154,7 @@ var RenameColumnScripts = []ScriptTest{ Query: `SELECT TC.CONSTRAINT_NAME, CC.CHECK_CLAUSE, TC.ENFORCED FROM information_schema.TABLE_CONSTRAINTS TC, information_schema.CHECK_CONSTRAINTS CC WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 'mytable' AND TC.TABLE_SCHEMA = CC.CONSTRAINT_SCHEMA AND TC.CONSTRAINT_NAME = CC.CONSTRAINT_NAME AND TC.CONSTRAINT_TYPE = 'CHECK';`, - Expected: []sql.Row{{"test_check", "(i2 < 12345)", "YES"}}, + Expected: []sql.Row{{"test_check", "(`i2` < 12345)", "YES"}}, }, }, }, diff --git a/enginetest/queries/ansi_quotes_queries.go b/enginetest/queries/ansi_quotes_queries.go index 060160b01a..d9f7bb1c03 100644 --- a/enginetest/queries/ansi_quotes_queries.go +++ 
b/enginetest/queries/ansi_quotes_queries.go @@ -71,7 +71,7 @@ var AnsiQuotesTests = []ScriptTest{ { // Disable ANSI_QUOTES and make sure we can still run queries Query: `SET @@sql_mode='NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: `select "data" from auctions order by "ai" desc;`, @@ -154,7 +154,7 @@ var AnsiQuotesTests = []ScriptTest{ { // Disable ANSI_QUOTES mode Query: `SET @@sql_mode='NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: `show create table view1;`, @@ -197,7 +197,7 @@ var AnsiQuotesTests = []ScriptTest{ { // Disable ANSI_QUOTES mode Query: `SET @@sql_mode='NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: `insert into t values (2, 'George', 'SomethingElse');`, @@ -237,7 +237,7 @@ var AnsiQuotesTests = []ScriptTest{ { // Disable ANSI_QUOTES mode Query: `SET @@sql_mode='NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // Assert the procedure runs correctly with ANSI_QUOTES mode disabled @@ -269,7 +269,7 @@ var AnsiQuotesTests = []ScriptTest{ { // Disable ANSI_QUOTES mode Query: `SET @@sql_mode='NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // Insert a row with ANSI_QUOTES mode disabled @@ -298,7 +298,7 @@ var AnsiQuotesTests = []ScriptTest{ { // Disable ANSI_QUOTES mode Query: `SET @@sql_mode='NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // Assert the check constraint runs correctly when ANSI_QUOTES mode is disabled @@ -328,7 +328,7 @@ var AnsiQuotesTests 
= []ScriptTest{ { // Disable ANSI_QUOTES mode and make sure we can still list and run events Query: `SET @@sql_mode='NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: `SHOW EVENTS;`, diff --git a/enginetest/queries/charset_collation_engine.go b/enginetest/queries/charset_collation_engine.go index e409a0cffc..8c5b8a6278 100644 --- a/enginetest/queries/charset_collation_engine.go +++ b/enginetest/queries/charset_collation_engine.go @@ -463,7 +463,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "set @@session.character_set_connection = 'latin1';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@session.character_set_connection, @@session.collation_connection;", @@ -473,7 +473,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "set @@session.collation_connection = 'utf8mb4_0900_bin';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@session.character_set_connection, @@session.collation_connection;", @@ -490,7 +490,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "set @@global.character_set_connection = 'latin1';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@global.character_set_connection, @@global.collation_connection;", @@ -500,7 +500,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "set @@global.collation_connection = 'utf8mb4_0900_bin';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@global.character_set_connection, @@global.collation_connection;", @@ -517,7 +517,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "set @@session.character_set_server = 'latin1';", - Expected: []sql.Row{{}}, + Expected: 
[]sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@session.character_set_server, @@session.collation_server;", @@ -527,7 +527,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "set @@session.collation_server = 'utf8mb4_0900_bin';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@session.character_set_server, @@session.collation_server;", @@ -544,7 +544,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "set @@global.character_set_server = 'latin1';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@global.character_set_server, @@global.collation_server;", @@ -554,7 +554,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "set @@global.collation_server = 'utf8mb4_0900_bin';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@global.character_set_server, @@global.collation_server;", @@ -605,6 +605,10 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ {int64(2), uint16(2)}, }, }, + { + Query: "create table t (e enum('abc', 'ABC') collate utf8mb4_0900_ai_ci))", + Error: true, + }, }, }, { @@ -696,7 +700,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "SET collation_connection = 'utf8mb4_0900_bin';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT COUNT(*) FROM test WHERE v1 LIKE 'ABC';", @@ -756,7 +760,7 @@ var CharsetCollationEngineTests = []CharsetCollationEngineTest{ }, { Query: "SET collation_connection = 'utf8mb4_0900_bin';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT 'abc' LIKE 'ABC';", diff --git a/enginetest/queries/charset_collation_wire.go b/enginetest/queries/charset_collation_wire.go index 9a2351feee..8e953dd029 100644 --- a/enginetest/queries/charset_collation_wire.go +++ 
b/enginetest/queries/charset_collation_wire.go @@ -476,7 +476,7 @@ var CharsetCollationWireTests = []CharsetCollationWireTest{ }, { Query: "SET collation_connection = 'utf8mb4_0900_bin';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT COUNT(*) FROM test WHERE v1 LIKE 'ABC';", @@ -536,7 +536,7 @@ var CharsetCollationWireTests = []CharsetCollationWireTest{ }, { Query: "SET collation_connection = 'utf8mb4_0900_bin';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT 'abc' LIKE 'ABC';", diff --git a/enginetest/queries/check_scripts.go b/enginetest/queries/check_scripts.go index bb7b6e89dd..32637b3b7a 100644 --- a/enginetest/queries/check_scripts.go +++ b/enginetest/queries/check_scripts.go @@ -29,7 +29,11 @@ var CreateCheckConstraintsScripts = []ScriptTest{ Query: `SELECT TC.CONSTRAINT_NAME, CC.CHECK_CLAUSE, TC.ENFORCED FROM information_schema.TABLE_CONSTRAINTS TC, information_schema.CHECK_CONSTRAINTS CC WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 'checks' AND TC.TABLE_SCHEMA = CC.CONSTRAINT_SCHEMA AND TC.CONSTRAINT_NAME = CC.CONSTRAINT_NAME AND TC.CONSTRAINT_TYPE = 'CHECK';`, - Expected: []sql.Row{{"chk1", "(B > 0)", "YES"}, {"chk2", "(b > 0)", "NO"}, {"chk3", "(B > 1)", "YES"}, {"chk4", "(upper(C) = c)", "YES"}}, + Expected: []sql.Row{ + {"chk1", "(`B` > 0)", "YES"}, + {"chk2", "(`b` > 0)", "NO"}, + {"chk3", "(`B` > 1)", "YES"}, + {"chk4", "(upper(`C`) = `c`)", "YES"}}, }, }, }, @@ -40,9 +44,7 @@ WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 'checks' AND TC.TABLE_SCHEMA = CC.C }, Assertions: []ScriptTestAssertion{ { - Query: `SELECT LENGTH(TC.CONSTRAINT_NAME) > 0 -FROM information_schema.TABLE_CONSTRAINTS TC, information_schema.CHECK_CONSTRAINTS CC -WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 'checks' AND TC.TABLE_SCHEMA = CC.CONSTRAINT_SCHEMA AND TC.CONSTRAINT_NAME = CC.CONSTRAINT_NAME AND TC.CONSTRAINT_TYPE = 'CHECK' AND CC.CHECK_CLAUSE = '(b > 100)';`, + Query: "SELECT 
LENGTH(TC.CONSTRAINT_NAME) > 0 FROM information_schema.TABLE_CONSTRAINTS TC, information_schema.CHECK_CONSTRAINTS CC WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 'checks' AND TC.TABLE_SCHEMA = CC.CONSTRAINT_SCHEMA AND TC.CONSTRAINT_NAME = CC.CONSTRAINT_NAME AND TC.CONSTRAINT_TYPE = 'CHECK' AND CC.CHECK_CLAUSE = '(`b` > 100)';", Expected: []sql.Row{{true}}, }, }, @@ -66,7 +68,13 @@ CREATE TABLE T2 Query: `SELECT CC.CHECK_CLAUSE FROM information_schema.TABLE_CONSTRAINTS TC, information_schema.CHECK_CONSTRAINTS CC WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 't2' AND TC.TABLE_SCHEMA = CC.CONSTRAINT_SCHEMA AND TC.CONSTRAINT_NAME = CC.CONSTRAINT_NAME AND TC.CONSTRAINT_TYPE = 'CHECK';`, - Expected: []sql.Row{{"(c1 = c2)"}, {"(c1 > 10)"}, {"(c2 > 0)"}, {"(c3 < 100)"}, {"(c1 = 0)"}, {"(C1 > C3)"}}, + Expected: []sql.Row{ + {"(`c1` = `c2`)"}, + {"(`c1` > 10)"}, + {"(`c2` > 0)"}, + {"(`c3` < 100)"}, + {"(`c1` = 0)"}, + {"(`C1` > `C3`)"}}, }, }, }, @@ -256,8 +264,8 @@ CREATE TABLE t4 { Query: "SELECT * from information_schema.check_constraints where constraint_name IN ('mycheck', 'hcheck') ORDER BY constraint_name", Expected: []sql.Row{ - {"def", "mydb", "hcheck", "(height < 10)"}, - {"def", "mydb", "mycheck", "(test_score >= 50)"}, + {"def", "mydb", "hcheck", "(`height` < 10)"}, + {"def", "mydb", "mycheck", "(`test_score` >= 50)"}, }, }, { @@ -318,6 +326,36 @@ CREATE TABLE t4 }, }, }, + { + Name: "check constraints using keywords", + SetUpScript: []string{ + "create table t (`order` int primary key, constraint chk check (`order` > 0));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (0);", + ExpectedErr: sql.ErrCheckConstraintViolated, + }, + { + Query: "insert into t values (100);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {100}, + }, + }, + { + Query: "show create table t;", + Expected: []sql.Row{ + {"t", "CREATE TABLE `t` (\n `order` int NOT NULL,\n PRIMARY KEY 
(`order`),\n CONSTRAINT `chk` CHECK ((`order` > 0))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + }, + }, } var DropCheckConstraintsScripts = []ScriptTest{ @@ -336,7 +374,7 @@ var DropCheckConstraintsScripts = []ScriptTest{ Query: `SELECT TC.CONSTRAINT_NAME, CC.CHECK_CLAUSE, TC.ENFORCED FROM information_schema.TABLE_CONSTRAINTS TC, information_schema.CHECK_CONSTRAINTS CC WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 't1' AND TC.TABLE_SCHEMA = CC.CONSTRAINT_SCHEMA AND TC.CONSTRAINT_NAME = CC.CONSTRAINT_NAME AND TC.CONSTRAINT_TYPE = 'CHECK';`, - Expected: []sql.Row{{"chk3", "(c > 0)", "YES"}}, + Expected: []sql.Row{{"chk3", "(`c` > 0)", "YES"}}, }, }, }, @@ -495,7 +533,7 @@ var ChecksOnUpdateScriptTests = []ScriptTest{ }, }, { - Name: "Update join updates", + Name: "Update join - single table", SetUpScript: []string{ "CREATE TABLE sales (year_built int primary key, CONSTRAINT `valid_year_built` CHECK (year_built <= 2022));", "INSERT INTO sales VALUES (1981);", @@ -535,6 +573,45 @@ var ChecksOnUpdateScriptTests = []ScriptTest{ }, }, }, + { + Name: "Update join - multiple tables", + SetUpScript: []string{ + "CREATE TABLE sales (year_built int primary key, CONSTRAINT `valid_year_built` CHECK (year_built <= 2022));", + "INSERT INTO sales VALUES (1981);", + "CREATE TABLE locations (state char(2) primary key, CONSTRAINT `state` CHECK (state != 'GA'));", + "INSERT INTO locations VALUES ('WA');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "UPDATE sales JOIN locations SET sales.year_built = 2000, locations.state = 'GA';", + ExpectedErr: sql.ErrCheckConstraintViolated, + }, + { + Query: "UPDATE sales JOIN locations SET sales.year_built = 2025, locations.state = 'CA';", + ExpectedErr: sql.ErrCheckConstraintViolated, + }, + { + Query: "select * from sales;", + Expected: []sql.Row{{1981}}, + }, + { + Query: "select * from locations;", + Expected: []sql.Row{{"WA"}}, + }, + { + Query: "UPDATE sales JOIN locations SET sales.year_built = 
2000, locations.state = 'CA';", + Expected: []sql.Row{{types.OkResult{2, 0, plan.UpdateInfo{2, 2, 0}}}}, + }, + { + Query: "select * from sales;", + Expected: []sql.Row{{2000}}, + }, + { + Query: "select * from locations;", + Expected: []sql.Row{{"CA"}}, + }, + }, + }, } var DisallowedCheckConstraintsScripts = []ScriptTest{ diff --git a/enginetest/queries/column_default_queries.go b/enginetest/queries/column_default_queries.go index 23e93f132b..8ab0b9222e 100644 --- a/enginetest/queries/column_default_queries.go +++ b/enginetest/queries/column_default_queries.go @@ -940,4 +940,120 @@ var ColumnDefaultTests = []ScriptTest{ }, }, }, + { + Name: "User variables in column defaults are not allowed", + SetUpScript: []string{ + "set @a = 1;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CREATE TABLE t(i int DEFAULT (@a));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + { + Query: "CREATE TABLE t(i int DEFAULT ((@a)));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + { + Query: "CREATE TABLE t(i int DEFAULT (@a + 1));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + }, + }, + { + Name: "System variables in column defaults are not allowed", + Assertions: []ScriptTestAssertion{ + { + Query: "CREATE TABLE t(i int DEFAULT (@@version));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + { + Query: "CREATE TABLE t(i int DEFAULT (@@session.sql_mode));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + { + Query: "CREATE TABLE t(i int DEFAULT (@@global.max_connections));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + }, + }, + { + Name: "User variables in generated columns are not allowed", + SetUpScript: []string{ + "set @a = 1;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CREATE TABLE t(i int GENERATED ALWAYS AS (@a));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + { + Query: "CREATE TABLE t(i int GENERATED ALWAYS AS (@a + 1));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + }, + }, 
+ { + Name: "System variables in generated columns are not allowed", + Assertions: []ScriptTestAssertion{ + { + Query: "CREATE TABLE t(i int GENERATED ALWAYS AS (@@version));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + { + Query: "CREATE TABLE t(i int GENERATED ALWAYS AS (@@session.sql_mode));", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + }, + }, + { + Name: "User variables in ALTER TABLE ADD COLUMN defaults are not allowed", + SetUpScript: []string{ + "CREATE TABLE t(pk int PRIMARY KEY);", + "set @a = 1;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER TABLE t ADD COLUMN i int DEFAULT (@a);", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + }, + }, + { + Name: "System variables in ALTER TABLE ADD COLUMN defaults are not allowed", + SetUpScript: []string{ + "CREATE TABLE t(pk int PRIMARY KEY);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER TABLE t ADD COLUMN i int DEFAULT (@@version);", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + }, + }, + { + Name: "User variables in ALTER TABLE ALTER COLUMN defaults are not allowed", + SetUpScript: []string{ + "CREATE TABLE t(pk int PRIMARY KEY, i int);", + "set @a = 1;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER TABLE t ALTER COLUMN i SET DEFAULT (@a);", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + }, + }, + { + Name: "System variables in ALTER TABLE ALTER COLUMN defaults are not allowed", + SetUpScript: []string{ + "CREATE TABLE t(pk int PRIMARY KEY, i int);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER TABLE t ALTER COLUMN i SET DEFAULT (@@version);", + ExpectedErr: sql.ErrColumnDefaultUserVariable, + }, + }, + }, } diff --git a/enginetest/queries/create_table_queries.go b/enginetest/queries/create_table_queries.go index 0a0ba057b0..48ba0b2d49 100644 --- a/enginetest/queries/create_table_queries.go +++ b/enginetest/queries/create_table_queries.go @@ -53,10 +53,16 @@ var CreateTableQueries = 
[]WriteQueryTest{ ExpectedSelect: []sql.Row{{"tableWithComment", "CREATE TABLE `tableWithComment` (\n `pk` int\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin COMMENT=''''"}}, }, { - WriteQuery: `create table tableWithComment (pk int) COMMENT "newline \n | return \r | backslash \\ | NUL \0 \x00"`, + WriteQuery: `create table tableWithComment (pk int) COMMENT "newline \n | return \r | backslash \\ | NUL \0 \x00 | ctrlz \Z \x1A"`, ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}}, SelectQuery: "SHOW CREATE TABLE tableWithComment", - ExpectedSelect: []sql.Row{{"tableWithComment", "CREATE TABLE `tableWithComment` (\n `pk` int\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin COMMENT='newline \\n | return \\r | backslash \\\\ | NUL \\0 x00'"}}, + ExpectedSelect: []sql.Row{{"tableWithComment", "CREATE TABLE `tableWithComment` (\n `pk` int\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin COMMENT='newline \\n | return \\r | backslash \\\\ | NUL \\0 x00 | ctrlz \x1A x1A'"}}, + }, + { + WriteQuery: `create table tableWithComment (pk int) COMMENT "ctrlz \Z \x1A \\Z \\\Z"`, + ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}}, + SelectQuery: "SHOW CREATE TABLE tableWithComment", + ExpectedSelect: []sql.Row{{"tableWithComment", "CREATE TABLE `tableWithComment` (\n `pk` int\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin COMMENT='ctrlz \x1A x1A \\\\Z \\\\\x1A'"}}, }, { WriteQuery: `create table tableWithColumnComment (pk int COMMENT "'")`, @@ -71,10 +77,10 @@ var CreateTableQueries = []WriteQueryTest{ ExpectedSelect: []sql.Row{{"tableWithColumnComment", "CREATE TABLE `tableWithColumnComment` (\n `pk` int COMMENT ''''\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, }, { - WriteQuery: `create table tableWithColumnComment (pk int COMMENT "newline \n | return \r | backslash \\ | NUL \0 \x00")`, + WriteQuery: `create table tableWithColumnComment (pk int COMMENT "newline \n | return \r 
| backslash \\ | NUL \0 \x00 | ctrlz \Z \x1A")`, ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}}, SelectQuery: "SHOW CREATE TABLE tableWithColumnComment", - ExpectedSelect: []sql.Row{{"tableWithColumnComment", "CREATE TABLE `tableWithColumnComment` (\n `pk` int COMMENT 'newline \\n | return \\r | backslash \\\\ | NUL \\0 x00'\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + ExpectedSelect: []sql.Row{{"tableWithColumnComment", "CREATE TABLE `tableWithColumnComment` (\n `pk` int COMMENT 'newline \\n | return \\r | backslash \\\\ | NUL \\0 x00 | ctrlz \x1A x1A'\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, }, { WriteQuery: `create table floattypedefs (a float(10), b float(10, 2), c double(10, 2))`, diff --git a/enginetest/queries/foreign_key_queries.go b/enginetest/queries/foreign_key_queries.go index 1f26a03c81..fe45f845a3 100644 --- a/enginetest/queries/foreign_key_queries.go +++ b/enginetest/queries/foreign_key_queries.go @@ -1485,7 +1485,7 @@ var ForeignKeyTests = []ScriptTest{ }, { Query: "SET FOREIGN_KEY_CHECKS=0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "TRUNCATE parent;", @@ -1497,7 +1497,7 @@ var ForeignKeyTests = []ScriptTest{ }, { Query: "SET FOREIGN_KEY_CHECKS=1;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "INSERT INTO child VALUES (4, 5, 6);", @@ -2777,7 +2777,7 @@ var CreateForeignKeyTests = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: "SET FOREIGN_KEY_CHECKS=0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "CREATE TABLE child4 (pk BIGINT PRIMARY KEY, CONSTRAINT fk_child4 FOREIGN KEY (pk) REFERENCES delayed_parent4 (pk))", diff --git a/enginetest/queries/function_queries.go b/enginetest/queries/function_queries.go new file mode 100644 index 0000000000..46197a1f4c --- /dev/null +++ b/enginetest/queries/function_queries.go @@ -0,0 +1,1031 @@ +// Copyright 
2020-2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queries + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// FunctionQueryTests contains queries that primarily test SQL function calls +var FunctionQueryTests = []QueryTest{ + // String Functions + { + Query: `SELECT CONCAT("a", "b", "c")`, + Expected: []sql.Row{ + {string("abc")}, + }, + }, + { + Query: `SELECT INSERT("Quadratic", 3, 4, "What")`, + Expected: []sql.Row{ + {string("QuWhattic")}, + }, + }, + { + Query: `SELECT INSERT("hello", 2, 2, "xyz")`, + Expected: []sql.Row{ + {string("hxyzlo")}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, 2, "xyz")`, + Expected: []sql.Row{ + {string("xyzllo")}, + }, + }, + { + Query: `SELECT INSERT("hello", 5, 1, "xyz")`, + Expected: []sql.Row{ + {string("hellxyz")}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, 5, "world")`, + Expected: []sql.Row{ + {string("world")}, + }, + }, + { + Query: `SELECT INSERT("hello", 3, 10, "world")`, + Expected: []sql.Row{ + {string("heworld")}, + }, + }, + { + Query: `SELECT INSERT("hello", 2, 2, "")`, + Expected: []sql.Row{ + {string("hlo")}, + }, + }, + { + Query: `SELECT INSERT("hello", 3, 0, "xyz")`, + Expected: []sql.Row{ + {string("hexyzllo")}, + }, + }, + { + Query: `SELECT INSERT("hello", 0, 2, "xyz")`, + Expected: []sql.Row{ + {string("hello")}, + }, + }, + { + Query: `SELECT INSERT("hello", -1, 2, "xyz")`, + Expected: 
[]sql.Row{ + {string("hello")}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, -1, "xyz")`, + Expected: []sql.Row{ + {string("xyz")}, + }, + }, + { + Query: `SELECT INSERT("hello", 3, -1, "xyz")`, + Expected: []sql.Row{ + {string("hexyz")}, + }, + }, + { + Query: `SELECT INSERT("hello", 2, 100, "xyz")`, + Expected: []sql.Row{ + {string("hxyz")}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, 50, "world")`, + Expected: []sql.Row{ + {string("world")}, + }, + }, + { + Query: `SELECT INSERT("hello", 10, 2, "xyz")`, + Expected: []sql.Row{ + {string("hello")}, + }, + }, + { + Query: `SELECT INSERT("", 1, 2, "xyz")`, + Expected: []sql.Row{ + {string("")}, + }, + }, + { + Query: `SELECT INSERT(NULL, 1, 2, "xyz")`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT INSERT("hello", NULL, 2, "xyz")`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, NULL, "xyz")`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, 2, NULL)`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT COALESCE(NULL, NULL, NULL, 'example', NULL, 1234567890)`, + Expected: []sql.Row{ + {string("example")}, + }, + }, + { + Query: `SELECT COALESCE(NULL, NULL, NULL, COALESCE(NULL, 1234567890))`, + Expected: []sql.Row{ + {int32(1234567890)}, + }, + }, + { + Query: "SELECT COALESCE (NULL, NULL)", + Expected: []sql.Row{{nil}}, + ExpectedColumns: []*sql.Column{ + { + Name: "COALESCE (NULL, NULL)", + Type: types.Null, + }, + }, + }, + { + Query: `SELECT COALESCE(CAST('{"a": "one \\n two"}' as json), '');`, + Expected: []sql.Row{ + {"{\"a\": \"one \\n two\"}"}, + }, + }, + { + Query: "SELECT concat(s, i) FROM mytable", + Expected: []sql.Row{ + {string("first row1")}, + {string("second row2")}, + {string("third row3")}, + }, + }, + { + Query: `SELECT INSERT(s, 1, 5, "new") FROM mytable ORDER BY i`, + Expected: []sql.Row{ + {string("new row")}, + {string("newd row")}, + {string("new row")}, + }, + }, + { + Query: `SELECT 
INSERT(s, i, 2, "XY") FROM mytable ORDER BY i`, + Expected: []sql.Row{ + {string("XYrst row")}, + {string("sXYond row")}, + {string("thXYd row")}, + }, + }, + { + Query: `SELECT INSERT(s, i + 1, i, UPPER(s)) FROM mytable ORDER BY i`, + Expected: []sql.Row{ + {string("fFIRST ROWrst row")}, + {string("seSECOND ROWnd row")}, + {string("thiTHIRD ROWrow")}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "Y", "N", ",", 4)`, + Expected: []sql.Row{ + {string("Y,N,Y,N")}, + }, + }, + { + Query: `SELECT EXPORT_SET(6, "1", "0", ",", 10)`, + Expected: []sql.Row{ + {string("0,1,1,0,0,0,0,0,0,0")}, + }, + }, + { + Query: `SELECT EXPORT_SET(0, "1", "0", ",", 4)`, + Expected: []sql.Row{ + {string("0,0,0,0")}, + }, + }, + { + Query: `SELECT EXPORT_SET(15, "1", "0", ",", 4)`, + Expected: []sql.Row{ + {string("1,1,1,1")}, + }, + }, + { + Query: `SELECT EXPORT_SET(1, "T", "F", ",", 3)`, + Expected: []sql.Row{ + {string("T,F,F")}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "1", "0", "|", 4)`, + Expected: []sql.Row{ + {string("1|0|1|0")}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "1", "0", "", 4)`, + Expected: []sql.Row{ + {string("1010")}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "ON", "OFF", ",", 4)`, + Expected: []sql.Row{ + {string("ON,OFF,ON,OFF")}, + }, + }, + { + Query: `SELECT EXPORT_SET(255, "1", "0", ",", 8)`, + Expected: []sql.Row{ + {string("1,1,1,1,1,1,1,1")}, + }, + }, + { + Query: `SELECT EXPORT_SET(1024, "1", "0", ",", 12)`, + Expected: []sql.Row{ + {string("0,0,0,0,0,0,0,0,0,0,1,0")}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "1", "0")`, + Expected: []sql.Row{ + {string("1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0")}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "1", "0", ",", 1)`, + Expected: []sql.Row{ + {string("1")}, + }, + }, + { + Query: `SELECT EXPORT_SET(-1, "1", "0", ",", 4)`, + Expected: []sql.Row{ + {string("1,1,1,1")}, + }, + }, + { + Query: `SELECT EXPORT_SET(NULL, 
"1", "0", ",", 4)`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, NULL, "0", ",", 4)`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "1", NULL, ",", 4)`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "1", "0", NULL, 4)`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT EXPORT_SET(5, "1", "0", ",", NULL)`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT EXPORT_SET("5", "1", "0", ",", 4)`, + Expected: []sql.Row{ + {string("1,0,1,0")}, + }, + }, + { + Query: `SELECT EXPORT_SET(5.7, "1", "0", ",", 4)`, + Expected: []sql.Row{ + {string("0,1,1,0")}, + }, + }, + { + Query: `SELECT EXPORT_SET(i, "1", "0", ",", 4) FROM mytable ORDER BY i`, + Expected: []sql.Row{ + {string("1,0,0,0")}, + {string("0,1,0,0")}, + {string("1,1,0,0")}, + }, + }, + { + Query: `SELECT MAKE_SET(1, "a", "b", "c")`, + Expected: []sql.Row{ + {string("a")}, + }, + }, + { + Query: `SELECT MAKE_SET(1 | 4, "hello", "nice", "world")`, + Expected: []sql.Row{ + {string("hello,world")}, + }, + }, + { + Query: `SELECT MAKE_SET(0, "a", "b", "c")`, + Expected: []sql.Row{ + {string("")}, + }, + }, + { + Query: `SELECT MAKE_SET(3, "a", "b", "c")`, + Expected: []sql.Row{ + {string("a,b")}, + }, + }, + { + Query: `SELECT MAKE_SET(5, "a", "b", "c")`, + Expected: []sql.Row{ + {string("a,c")}, + }, + }, + { + Query: `SELECT MAKE_SET(7, "a", "b", "c")`, + Expected: []sql.Row{ + {string("a,b,c")}, + }, + }, + { + Query: `SELECT MAKE_SET(1024, "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k")`, + Expected: []sql.Row{ + {string("k")}, + }, + }, + { + Query: `SELECT MAKE_SET(1025, "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k")`, + Expected: []sql.Row{ + {string("a,k")}, + }, + }, + { + Query: `SELECT MAKE_SET(7, "a", NULL, "c")`, + Expected: []sql.Row{ + {string("a,c")}, + }, + }, + { + Query: `SELECT MAKE_SET(7, NULL, "b", "c")`, + Expected: []sql.Row{ + {string("b,c")}, + }, + }, + { + 
Query: `SELECT MAKE_SET(NULL, "a", "b", "c")`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT MAKE_SET("5", "a", "b", "c")`, + Expected: []sql.Row{ + {string("a,c")}, + }, + }, + { + Query: `SELECT MAKE_SET(5.7, "a", "b", "c")`, + Expected: []sql.Row{ + {string("b,c")}, + }, + }, + { + Query: `SELECT MAKE_SET(-1, "a", "b", "c")`, + Expected: []sql.Row{ + {string("a,b,c")}, + }, + }, + { + Query: `SELECT MAKE_SET(16, "a", "b", "c")`, + Expected: []sql.Row{ + {string("")}, + }, + }, + { + Query: `SELECT MAKE_SET(3, "", "test", "")`, + Expected: []sql.Row{ + {string(",test")}, + }, + }, + { + Query: `SELECT MAKE_SET(i, "first", "second", "third") FROM mytable ORDER BY i`, + Expected: []sql.Row{ + {string("first")}, + {string("second")}, + {string("first,second")}, + }, + }, + { + Query: "SELECT version()", + Expected: []sql.Row{ + {"8.0.31"}, + }, + }, + { + Query: `SELECT RAND(100)`, + Expected: []sql.Row{ + {float64(0.8165026937796166)}, + }, + }, + { + Query: `SELECT RAND(i) from mytable order by i`, + Expected: []sql.Row{{0.6046602879796196}, {0.16729663442585624}, {0.7199826688373036}}, + }, + { + Query: `SELECT RAND(100) = RAND(100)`, + Expected: []sql.Row{ + {true}, + }, + }, + { + Query: `SELECT RAND() = RAND()`, + Expected: []sql.Row{ + {false}, + }, + }, + { + Query: "SELECT MOD(i, 2) from mytable order by i limit 1", + Expected: []sql.Row{ + {"1"}, + }, + }, + { + Query: "SELECT SIN(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {0.8414709848078965}, + }, + }, + { + Query: "SELECT COS(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {0.5403023058681398}, + }, + }, + { + Query: "SELECT TAN(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {1.557407724654902}, + }, + }, + { + Query: "SELECT ASIN(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {1.5707963267948966}, + }, + }, + { + Query: "SELECT ACOS(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {0.0}, + }, + }, + { + Query: 
"SELECT ATAN(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {0.7853981633974483}, + }, + }, + { + Query: "SELECT COT(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {0.6420926159343308}, + }, + }, + { + Query: "SELECT DEGREES(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {57.29577951308232}, + }, + }, + { + Query: "SELECT RADIANS(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {0.017453292519943295}, + }, + }, + { + Query: "SELECT CRC32(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {uint64(0x83dcefb7)}, + }, + }, + { + Query: "SELECT SIGN(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "SELECT ASCII(s) from mytable order by i limit 1", + Expected: []sql.Row{ + {uint64(0x66)}, + }, + }, + { + Query: "SELECT HEX(s) from mytable order by i limit 1", + Expected: []sql.Row{ + {"666972737420726F77"}, + }, + }, + { + Query: "SELECT UNHEX(s) from mytable order by i limit 1", + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: "SELECT BIN(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {"1"}, + }, + }, + { + Query: "SELECT BIT_LENGTH(i) from mytable order by i limit 1", + Expected: []sql.Row{ + {64}, + }, + }, + { + Query: "select date_format(datetime_col, '%D') from datetime_table order by 1", + Expected: []sql.Row{ + {"1st"}, + {"4th"}, + {"7th"}, + }, + }, + { + Query: "select time_format(time_col, '%h%p') from datetime_table order by 1", + Expected: []sql.Row{ + {"03AM"}, + {"03PM"}, + {"04AM"}, + }, + }, + { + Query: "select from_unixtime(i) from mytable order by 1", + Expected: []sql.Row{ + {UnixTimeInLocal(1, 0)}, + {UnixTimeInLocal(2, 0)}, + {UnixTimeInLocal(3, 0)}, + }, + }, + + // FORMAT Function Tests + { + Query: `SELECT FORMAT(val, 2) FROM + (values row(4328904), row(432053.4853), row(5.93288775208e+08), row("5784029.372"), row(-4229842.122), row(-0.009)) a (val)`, + Expected: []sql.Row{ + {"4,328,904.00"}, + {"432,053.49"}, + 
{"593,288,775.21"}, + {"5,784,029.37"}, + {"-4,229,842.12"}, + {"-0.01"}, + }, + }, + { + Query: "SELECT FORMAT(i, 3) FROM mytable;", + Expected: []sql.Row{ + {"1.000"}, + {"2.000"}, + {"3.000"}, + }, + }, + { + Query: `SELECT FORMAT(val, 2, 'da_DK') FROM + (values row(4328904), row(432053.4853), row(5.93288775208e+08), row("5784029.372"), row(-4229842.122), row(-0.009)) a (val)`, + Expected: []sql.Row{ + {"4.328.904,00"}, + {"432.053,49"}, + {"593.288.775,21"}, + {"5.784.029,37"}, + {"-4.229.842,12"}, + {"-0,01"}, + }, + }, + { + Query: "SELECT FORMAT(i, 3, 'da_DK') FROM mytable;", + Expected: []sql.Row{ + {"1,000"}, + {"2,000"}, + {"3,000"}, + }, + }, + + // Date/Time Function Tests + { + Query: "SELECT DATEDIFF(date_col, '2019-12-28') FROM datetime_table where date_col = date('2019-12-31T12:00:00');", + Expected: []sql.Row{ + {3}, + }, + }, + { + Query: `SELECT DATEDIFF(val, '2019/12/28') FROM + (values row('2017-11-30 22:59:59'), row('2020/01/02'), row('2021-11-30'), row('2020-12-31T12:00:00')) a (val)`, + Expected: []sql.Row{ + {-758}, + {5}, + {703}, + {369}, + }, + }, + { + Query: "SELECT TIMESTAMPDIFF(SECOND,'2007-12-31 23:59:58', '2007-12-31 00:00:00');", + Expected: []sql.Row{ + {-86398}, + }, + }, + { + Query: `SELECT TIMESTAMPDIFF(MINUTE, val, '2019/12/28') FROM + (values row('2017-11-30 22:59:59'), row('2020/01/02'), row('2019-12-27 23:15:55'), row('2019-12-31T12:00:00')) a (val);`, + Expected: []sql.Row{ + {1090140}, + {-7200}, + {44}, + {-5040}, + }, + }, + { + Query: "SELECT TIMEDIFF(null, '2017-11-30 22:59:59');", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT DATEDIFF('2019/12/28', null);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT TIMESTAMPDIFF(SECOND, null, '2007-12-31 00:00:00');", + Expected: []sql.Row{{nil}}, + }, + + // TRIM Function Tests + { + Query: `SELECT TRIM(mytable.s) AS s FROM mytable`, + Expected: []sql.Row{{"first row"}, {"second row"}, {"third row"}}, + }, + { + Query: `SELECT TRIM("row" from mytable.s) AS 
s FROM mytable`, + Expected: []sql.Row{{"first "}, {"second "}, {"third "}}, + }, + { + Query: `SELECT TRIM(mytable.s from "first row") AS s FROM mytable`, + Expected: []sql.Row{{""}, {"first row"}, {"first row"}}, + }, + { + Query: `SELECT TRIM(mytable.s from mytable.s) AS s FROM mytable`, + Expected: []sql.Row{{""}, {""}, {""}}, + }, + { + Query: `SELECT TRIM(" foo ")`, + Expected: []sql.Row{{"foo"}}, + }, + { + Query: `SELECT TRIM(" " FROM " foo ")`, + Expected: []sql.Row{{"foo"}}, + }, + { + Query: `SELECT TRIM(LEADING " " FROM " foo ")`, + Expected: []sql.Row{{"foo "}}, + }, + { + Query: `SELECT TRIM(TRAILING " " FROM " foo ")`, + Expected: []sql.Row{{" foo"}}, + }, + { + Query: `SELECT TRIM(BOTH " " FROM " foo ")`, + Expected: []sql.Row{{"foo"}}, + }, + { + Query: `SELECT TRIM("" FROM " foo")`, + Expected: []sql.Row{{" foo"}}, + }, + { + Query: `SELECT TRIM("bar" FROM "barfoobar")`, + Expected: []sql.Row{{"foo"}}, + }, + { + Query: `SELECT TRIM(TRAILING "bar" FROM "barfoobar")`, + Expected: []sql.Row{{"barfoo"}}, + }, + { + Query: `SELECT TRIM(TRAILING "foo" FROM "foo")`, + Expected: []sql.Row{{""}}, + }, + { + Query: `SELECT TRIM(LEADING "ooo" FROM TRIM("oooo"))`, + Expected: []sql.Row{{"o"}}, + }, + { + Query: `SELECT TRIM(BOTH "foo" FROM TRIM("barfoobar"))`, + Expected: []sql.Row{{"barfoobar"}}, + }, + { + Query: `SELECT TRIM(LEADING "bar" FROM TRIM("foobar"))`, + Expected: []sql.Row{{"foobar"}}, + }, + { + Query: `SELECT TRIM(TRAILING "oo" FROM TRIM("oof"))`, + Expected: []sql.Row{{"oof"}}, + }, + { + Query: `SELECT TRIM(LEADING "test" FROM TRIM(" test "))`, + Expected: []sql.Row{{""}}, + }, + { + Query: `SELECT TRIM(LEADING CONCAT("a", "b") FROM TRIM("ababab"))`, + Expected: []sql.Row{{""}}, + }, + { + Query: `SELECT TRIM(TRAILING CONCAT("a", "b") FROM CONCAT("test","ab"))`, + Expected: []sql.Row{{"test"}}, + }, + { + Query: `SELECT TRIM(LEADING 1 FROM "11111112")`, + Expected: []sql.Row{{"2"}}, + }, + { + Query: `SELECT TRIM(LEADING 1 FROM 11111112)`, + 
Expected: []sql.Row{{"2"}}, + }, + + // SUBSTRING_INDEX Function Tests + { + Query: `SELECT SUBSTRING_INDEX('a.b.c.d.e.f', '.', 2)`, + Expected: []sql.Row{ + {"a.b"}, + }, + }, + { + Query: `SELECT SUBSTRING_INDEX('a.b.c.d.e.f', '.', -2)`, + Expected: []sql.Row{ + {"e.f"}, + }, + }, + { + Query: `SELECT SUBSTRING_INDEX(SUBSTRING_INDEX('source{d}', '{d}', 1), 'r', -1)`, + Expected: []sql.Row{ + {"ce"}, + }, + }, + { + Query: `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) AS s FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY 1 HAVING s = 'secon';`, + Expected: []sql.Row{{"secon"}}, + }, + { + Query: `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) AS s FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY s HAVING s = 'secon';`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) AS ss FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY s HAVING s = 'secon';`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) AS ss FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY ss HAVING ss = 'secon';`, + Expected: []sql.Row{ + {"secon"}, + }, + }, + + // INET Function Tests + { + Query: `SELECT INET_ATON("10.0.5.10")`, + Expected: []sql.Row{{uint64(167773450)}}, + }, + { + Query: `SELECT INET_NTOA(167773450)`, + Expected: []sql.Row{{"10.0.5.10"}}, + }, + { + Query: `SELECT INET_ATON("10.0.5.11")`, + Expected: []sql.Row{{uint64(167773451)}}, + }, + { + Query: `SELECT INET_NTOA(167773451)`, + Expected: []sql.Row{{"10.0.5.11"}}, + }, + { + Query: `SELECT INET_NTOA(INET_ATON("12.34.56.78"))`, + Expected: []sql.Row{{"12.34.56.78"}}, + }, + { + Query: `SELECT INET_ATON(INET_NTOA("12345678"))`, + Expected: []sql.Row{{uint64(12345678)}}, + 
}, + { + Query: `SELECT INET_ATON("notanipaddress")`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT INET_NTOA("spaghetti")`, + Expected: []sql.Row{{"0.0.0.0"}}, + }, + + // INET6 Function Tests + { + Query: `SELECT HEX(INET6_ATON("10.0.5.9"))`, + Expected: []sql.Row{{"0A000509"}}, + }, + { + Query: `SELECT HEX(INET6_ATON("::10.0.5.9"))`, + Expected: []sql.Row{{"0000000000000000000000000A000509"}}, + }, + { + Query: `SELECT HEX(INET6_ATON("1.2.3.4"))`, + Expected: []sql.Row{{"01020304"}}, + }, + { + Query: `SELECT HEX(INET6_ATON("fdfe::5455:caff:fefa:9098"))`, + Expected: []sql.Row{{"FDFE0000000000005455CAFFFEFA9098"}}, + }, + { + Query: `SELECT HEX(INET6_ATON("1111:2222:3333:4444:5555:6666:7777:8888"))`, + Expected: []sql.Row{{"11112222333344445555666677778888"}}, + }, + { + Query: `SELECT INET6_ATON("notanipaddress")`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT INET6_NTOA(UNHEX("1234ffff5678ffff1234ffff5678ffff"))`, + Expected: []sql.Row{{"1234:ffff:5678:ffff:1234:ffff:5678:ffff"}}, + }, + { + Query: `SELECT INET6_NTOA(UNHEX("ffffffff"))`, + Expected: []sql.Row{{"255.255.255.255"}}, + }, + { + Query: `SELECT INET6_NTOA(UNHEX("000000000000000000000000ffffffff"))`, + Expected: []sql.Row{{"::255.255.255.255"}}, + }, + { + Query: `SELECT INET6_NTOA(UNHEX("00000000000000000000ffffffffffff"))`, + Expected: []sql.Row{{"::ffff:255.255.255.255"}}, + }, + { + Query: `SELECT INET6_NTOA(UNHEX("0000000000000000000000000000ffff"))`, + Expected: []sql.Row{{"::ffff"}}, + }, + { + Query: `SELECT INET6_NTOA(UNHEX("00000000000000000000000000000000"))`, + Expected: []sql.Row{{"::"}}, + }, + { + Query: `SELECT INET6_NTOA("notanipaddress")`, + Expected: []sql.Row{{nil}}, + }, + + // IS_IPV4/IS_IPV6 Function Tests + { + Query: `SELECT IS_IPV4("10.0.1.10")`, + Expected: []sql.Row{{true}}, + }, + { + Query: `SELECT IS_IPV4("::10.0.1.10")`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT IS_IPV4("notanipaddress")`, + Expected: []sql.Row{{false}}, + }, + { + 
Query: `SELECT IS_IPV6("10.0.1.10")`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT IS_IPV6("::10.0.1.10")`, + Expected: []sql.Row{{true}}, + }, + { + Query: `SELECT IS_IPV6("notanipaddress")`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("10.0.1.10"))`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("::10.0.1.10"))`, + Expected: []sql.Row{{true}}, + }, + { + Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("::ffff:10.0.1.10"))`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("notanipaddress"))`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT IS_IPV4_MAPPED(INET6_ATON("10.0.1.10"))`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT IS_IPV4_MAPPED(INET6_ATON("::10.0.1.10"))`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT IS_IPV4_MAPPED(INET6_ATON("::ffff:10.0.1.10"))`, + Expected: []sql.Row{{true}}, + }, + { + Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("notanipaddress"))`, + Expected: []sql.Row{{nil}}, + }, + + // Additional Date/Time Function Tests + { + Query: "SELECT YEAR('2007-12-11') FROM mytable", + Expected: []sql.Row{{int32(2007)}, {int32(2007)}, {int32(2007)}}, + }, + { + Query: "SELECT MONTH('2007-12-11') FROM mytable", + Expected: []sql.Row{{int32(12)}, {int32(12)}, {int32(12)}}, + }, + { + Query: "SELECT DAY('2007-12-11') FROM mytable", + Expected: []sql.Row{{int32(11)}, {int32(11)}, {int32(11)}}, + }, + { + Query: "SELECT HOUR('2007-12-11 20:21:22') FROM mytable", + Expected: []sql.Row{{int32(20)}, {int32(20)}, {int32(20)}}, + }, + { + Query: "SELECT MINUTE('2007-12-11 20:21:22') FROM mytable", + Expected: []sql.Row{{int32(21)}, {int32(21)}, {int32(21)}}, + }, + { + Query: "SELECT SECOND('2007-12-11 20:21:22') FROM mytable", + Expected: []sql.Row{{int32(22)}, {int32(22)}, {int32(22)}}, + }, + { + Query: "SELECT DAYOFYEAR('2007-12-11 20:21:22') FROM mytable", + Expected: []sql.Row{{int32(345)}, {int32(345)}, 
{int32(345)}}, + }, + { + Query: "SELECT SECOND('2007-12-11T20:21:22Z') FROM mytable", + Expected: []sql.Row{{int32(22)}, {int32(22)}, {int32(22)}}, + }, + { + Query: "SELECT DAYOFYEAR('2007-12-11') FROM mytable", + Expected: []sql.Row{{int32(345)}, {int32(345)}, {int32(345)}}, + }, + { + Query: "SELECT DAYOFYEAR('20071211') FROM mytable", + Expected: []sql.Row{{int32(345)}, {int32(345)}, {int32(345)}}, + }, + { + Query: "SELECT YEARWEEK('0000-01-01')", + Expected: []sql.Row{{int32(1)}}, + }, + { + Query: "SELECT YEARWEEK('9999-12-31')", + Expected: []sql.Row{{int32(999952)}}, + }, + { + Query: "SELECT YEARWEEK('2008-02-20', 1)", + Expected: []sql.Row{{int32(200808)}}, + }, + { + Query: "SELECT YEARWEEK('1987-01-01')", + Expected: []sql.Row{{int32(198652)}}, + }, + { + Query: "SELECT YEARWEEK('1987-01-01', 20), YEARWEEK('1987-01-01', 1), YEARWEEK('1987-01-01', 2), YEARWEEK('1987-01-01', 3), YEARWEEK('1987-01-01', 4), YEARWEEK('1987-01-01', 5), YEARWEEK('1987-01-01', 6), YEARWEEK('1987-01-01', 7)", + Expected: []sql.Row{{int32(198653), int32(198701), int32(198652), int32(198701), int32(198653), int32(198652), int32(198653), int32(198652)}}, + }, + + // Additional String Function Tests + { + Query: `SELECT CHAR_LENGTH('áé'), LENGTH('àè')`, + Expected: []sql.Row{{int32(2), int32(4)}}, + }, + { + Query: `SELECT SUBSTR(SUBSTRING('0123456789ABCDEF', 1, 10), -4)`, + Expected: []sql.Row{{"6789"}}, + }, +} diff --git a/enginetest/queries/generated_columns.go b/enginetest/queries/generated_columns.go index 03485cd304..f44c268df5 100644 --- a/enginetest/queries/generated_columns.go +++ b/enginetest/queries/generated_columns.go @@ -65,6 +65,14 @@ var GeneratedColumnTests = []ScriptTest{ Query: "select * from t1 where b = 5 order by a", Expected: []sql.Row{{4, 5}}, }, + { + Query: "insert into t1 values (5, DEFAULT)", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "select * from t1 where a = 5", + Expected: []sql.Row{{5, 6}}, + }, { Query: "update t1 set b = b 
+ 1", ExpectedErr: sql.ErrGeneratedColumnValue, @@ -75,7 +83,7 @@ var GeneratedColumnTests = []ScriptTest{ }, { Query: "select * from t1 order by a", - Expected: []sql.Row{{2, 3}, {3, 4}, {4, 5}, {10, 11}}, + Expected: []sql.Row{{2, 3}, {3, 4}, {4, 5}, {5, 6}, {10, 11}}, }, { Query: "delete from t1 where b = 11", @@ -83,7 +91,80 @@ var GeneratedColumnTests = []ScriptTest{ }, { Query: "select * from t1 order by a", - Expected: []sql.Row{{2, 3}, {3, 4}, {4, 5}}, + Expected: []sql.Row{{2, 3}, {3, 4}, {4, 5}, {5, 6}}, + }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{4}}, + }, + }, + }, + { + Name: "generated column with DEFAULT in UPDATE clause (issue #9438)", + SetUpScript: []string{ + "create table t (i int primary key, j int generated always as (i + 10))", + "insert into t (i) values (1), (2), (3)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from t order by i", + Expected: []sql.Row{{1, 11}, {2, 12}, {3, 13}}, + }, + { + Query: "update t set j = default", + Expected: []sql.Row{{NewUpdateResult(3, 0)}}, // 3 rows matched, 0 changed (values already correct) + }, + { + Query: "select * from t order by i", + Expected: []sql.Row{{1, 11}, {2, 12}, {3, 13}}, // Values should remain the same + }, + { + Query: "update t set i = 5 where i = 1", // This should update both i and j (through generation) + Expected: []sql.Row{{NewUpdateResult(1, 1)}}, + }, + { + Query: "select * from t order by i", + Expected: []sql.Row{{2, 12}, {3, 13}, {5, 15}}, // j should be updated to i + 10 = 15 + }, + { + Query: "update t set j = default where i = 5", // Explicit DEFAULT on specific row + Expected: []sql.Row{{NewUpdateResult(1, 0)}}, // 1 row matched, 0 changed (value already correct) + }, + { + Query: "select * from t where i = 5", + Expected: []sql.Row{{5, 15}}, // Value should still be correct + }, + { + Query: "update t set j = 99", // Should still fail for non-DEFAULT values + ExpectedErr: sql.ErrGeneratedColumnValue, + }, + }, + }, + { + Name: 
"generated column with DEFAULT in VALUES clause (issue #9428)", + SetUpScript: []string{ + "create table t (i int generated always as (1 + 1))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (default)", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "select * from t", + Expected: []sql.Row{{2}}, + }, + { + Query: "insert into t values (default), (default)", + Expected: []sql.Row{{types.NewOkResult(2)}}, + }, + { + Query: "select * from t order by i", + Expected: []sql.Row{{2}, {2}, {2}}, + }, + { + Query: "insert into t values (5)", + ExpectedErr: sql.ErrGeneratedColumnValue, }, }, }, @@ -94,9 +175,15 @@ var GeneratedColumnTests = []ScriptTest{ "INSERT INTO t16 (pk) VALUES (1), (2)", "ALTER TABLE t16 ADD COLUMN v2 BIGINT AS (5) STORED FIRST", }, - Assertions: []ScriptTestAssertion{{ - Query: "SELECT * FROM t16", - Expected: []sql.Row{{5, 1, 4}, {5, 2, 4}}}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t16", + Expected: []sql.Row{{5, 1, 4}, {5, 2, 4}}, + }, + { + Query: "select count(*) from t16", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -106,9 +193,15 @@ var GeneratedColumnTests = []ScriptTest{ "INSERT INTO t17 VALUES (1, 3), (2, 4)", "ALTER TABLE t17 ADD COLUMN v2 BIGINT AS (v1 + 2) STORED FIRST", }, - Assertions: []ScriptTestAssertion{{ - Query: "SELECT * FROM t17", - Expected: []sql.Row{{5, 1, 3}, {6, 2, 4}}}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t17", + Expected: []sql.Row{{5, 1, 3}, {6, 2, 4}}, + }, + { + Query: "select count(*) from t17", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -198,6 +291,10 @@ var GeneratedColumnTests = []ScriptTest{ Query: "select * from t1 order by b", Expected: []sql.Row{{1, 2}, {2, 3}}, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -270,6 +367,10 @@ var GeneratedColumnTests = []ScriptTest{ Query: "select * from t1 order by b", Expected: []sql.Row{{1, 2, 3, 4}, {2, 3, 4, 5}}, }, + { + Query: 
"select count(*) from t1", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -350,6 +451,10 @@ var GeneratedColumnTests = []ScriptTest{ Query: "select * from t1 order by b", Expected: []sql.Row{{1, 2}, {2, 3}}, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -503,6 +608,10 @@ var GeneratedColumnTests = []ScriptTest{ " PRIMARY KEY (`a`)\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{3}}, + }, }, }, { @@ -539,6 +648,10 @@ var GeneratedColumnTests = []ScriptTest{ {1, 2, 4}, }, }, + { + Query: "select count(*) from t", + Expected: []sql.Row{{1}}, + }, { Query: "alter table tt add column `col 3` int generated always as (`col 1` + `col 2` + pow(`col 1`, `col 2`)) stored;", Expected: []sql.Row{ @@ -567,6 +680,10 @@ var GeneratedColumnTests = []ScriptTest{ {1, 2, 4}, }, }, + { + Query: "select count(*) from tt", + Expected: []sql.Row{{1}}, + }, }, }, { @@ -603,6 +720,10 @@ var GeneratedColumnTests = []ScriptTest{ {1, 2, 4}, }, }, + { + Query: "select count(*) from t", + Expected: []sql.Row{{1}}, + }, { Query: "alter table tt add column `col 3` int generated always as (`col 1` + `col 2` + pow(`col 1`, `col 2`)) virtual;", Expected: []sql.Row{ @@ -631,6 +752,10 @@ var GeneratedColumnTests = []ScriptTest{ {1, 2, 4}, }, }, + { + Query: "select count(*) from tt", + Expected: []sql.Row{{1}}, + }, }, }, { @@ -640,9 +765,15 @@ var GeneratedColumnTests = []ScriptTest{ "INSERT INTO t16 (pk) VALUES (1), (2)", "ALTER TABLE t16 ADD COLUMN v2 BIGINT AS (5) VIRTUAL FIRST", }, - Assertions: []ScriptTestAssertion{{ - Query: "SELECT * FROM t16", - Expected: []sql.Row{{5, 1, 4}, {5, 2, 4}}}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t16", + Expected: []sql.Row{{5, 1, 4}, {5, 2, 4}}, + }, + { + Query: "select count(*) from t16", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -652,9 +783,15 @@ var GeneratedColumnTests = []ScriptTest{ "INSERT INTO 
t17 VALUES (1, 3), (2, 4)", "ALTER TABLE t17 ADD COLUMN v2 BIGINT AS (v1 + 2) VIRTUAL FIRST", }, - Assertions: []ScriptTestAssertion{{ - Query: "SELECT * FROM t17", - Expected: []sql.Row{{5, 1, 3}, {6, 2, 4}}}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t17", + Expected: []sql.Row{{5, 1, 3}, {6, 2, 4}}, + }, + { + Query: "SELECT count(*) FROM t17", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -795,6 +932,14 @@ var GeneratedColumnTests = []ScriptTest{ Query: "select * from t2 order by c", Expected: []sql.Row{{1, 0}, {2, 1}, {3, 2}, {6, 5}, {7, 6}}, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{5}}, + }, + { + Query: "select count(*) from t2", + Expected: []sql.Row{{5}}, + }, }, }, { @@ -814,6 +959,10 @@ var GeneratedColumnTests = []ScriptTest{ {2, types.MustJSON(`{"a": 1}`), nil}, {3, types.MustJSON(`{"b": "300"}`), 300}}, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{3}}, + }, }, }, { @@ -834,6 +983,10 @@ var GeneratedColumnTests = []ScriptTest{ {"ghi", "", "ghi"}, }, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{3}}, + }, }, }, { @@ -874,6 +1027,10 @@ var GeneratedColumnTests = []ScriptTest{ {2, 3, 4, 5}, }, }, + { + Query: "select count(*) from t", + Expected: []sql.Row{{3}}, + }, }, }, { @@ -951,6 +1108,10 @@ var GeneratedColumnTests = []ScriptTest{ Query: "select * from t1 order by a", Expected: []sql.Row{{1, 2, 3}, {3, 4, 7}}, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -1015,6 +1176,10 @@ var GeneratedColumnTests = []ScriptTest{ {3, 4, 7}, }, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{2}}, + }, { Query: "select * from t1 where c = 6", Expected: []sql.Row{ @@ -1044,6 +1209,10 @@ var GeneratedColumnTests = []ScriptTest{ Query: "select * from t1 where v = 2", Expected: []sql.Row{{"{\"a\": 2}", 2}}, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{3}}, + }, { Query: "update t1 set j = '{\"a\": 5}' 
where v = 2", Expected: []sql.Row{{NewUpdateResult(1, 1)}}, @@ -1140,6 +1309,10 @@ var GeneratedColumnTests = []ScriptTest{ Query: "select * from t1 order by b", Expected: []sql.Row{{1, 2, 3, 4}, {2, 3, 4, 5}}, }, + { + Query: "select count(*) from t1", + Expected: []sql.Row{{2}}, + }, }, }, { @@ -1224,6 +1397,10 @@ var GeneratedColumnTests = []ScriptTest{ Query: "insert into t2 (a) values (1), (2)", Expected: []sql.Row{{types.NewOkResult(2)}}, }, + { + Query: "select count(*) from t2", + Expected: []sql.Row{{2}}, + }, { Query: "select * from t2 order by a", Expected: []sql.Row{ @@ -1241,6 +1418,10 @@ var GeneratedColumnTests = []ScriptTest{ Query: "insert into t3 (a) values (1), (2)", Expected: []sql.Row{{types.NewOkResult(2)}}, }, + { + Query: "select count(*) from t3", + Expected: []sql.Row{{2}}, + }, { Query: "select * from t3 order by a", Expected: []sql.Row{ @@ -1256,6 +1437,45 @@ var GeneratedColumnTests = []ScriptTest{ }, }, }, + { + // https://github.com/dolthub/dolt/issues/8968 + Name: "can select all columns from table with generated column", + SetUpScript: []string{ + "create table t(pk int primary key, j1 json)", + `insert into t values (1, '{"name": "foo"}')`, + "alter table t add column g1 varchar(100) generated always as (json_unquote(json_extract(`j1`, '$.name')))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from t", + Expected: []sql.Row{{1, `{"name":"foo"}`, "foo"}}, + }, + { + Query: "select pk, j1, g1 from t", + Expected: []sql.Row{{1, `{"name":"foo"}`, "foo"}}, + }, + { + Query: "select pk, g1 from t", + Expected: []sql.Row{{1, "foo"}}, + }, + { + Query: "select g1 from t", + Expected: []sql.Row{{"foo"}}, + }, + { + Query: "select j1, g1 from t", + Expected: []sql.Row{{`{"name":"foo"}`, "foo"}}, + }, + { + Query: "select j1 from t", + Expected: []sql.Row{{`{"name":"foo"}`}}, + }, + { + Query: "select pk, j1 from t", + Expected: []sql.Row{{1, `{"name":"foo"}`}}, + }, + }, + }, } var BrokenGeneratedColumnTests = 
[]ScriptTest{ diff --git a/enginetest/queries/index_queries.go b/enginetest/queries/index_queries.go index fdb72b9be7..fc3e8431a1 100644 --- a/enginetest/queries/index_queries.go +++ b/enginetest/queries/index_queries.go @@ -4011,7 +4011,7 @@ var IndexPrefixQueries = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: "set @@strict_mysql_compatibility = true;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@strict_mysql_compatibility;", @@ -4065,6 +4065,121 @@ var IndexPrefixQueries = []ScriptTest{ }, }, }, + { + Name: "multiple nullable index prefixes", + SetUpScript: []string{ + "create table test(pk int primary key, shared1 int, shared2 int, a3 int, a4 int, b3 int, b4 int, unique key a_idx(shared1, shared2, a3, a4), unique key b_idx(shared1, shared2, b3, b4))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and a3 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"a_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and b3 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"b_idx"}, + }, + }, + }, + { + Name: "multiple non-unique index prefixes", + SetUpScript: []string{ + "create table test(pk int primary key, shared1 int not null, shared2 int not null, a3 int not null, a4 int not null, b3 int not null, b4 int not null, key a_idx(shared1, shared2, a3, a4), key b_idx(shared1, shared2, b3, b4))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and a3 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"a_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and a3 > 3 and a3 < 5;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"a_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and b3 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"b_idx"}, + }, + { + Query: 
"select * from test where shared1 = 1 and shared2 = 2 and b3 > 3 and b3 < 5;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"b_idx"}, + }, + }, + }, + { + Name: "multiple non-unique nullable index prefixes", + SetUpScript: []string{ + "create table test(pk int primary key, shared1 int, shared2 int, a3 int, a4 int, b3 int, b4 int, key a_idx(shared1, shared2, a3, a4), key b_idx(shared1, shared2, b3, b4))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and a3 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"a_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and a3 > 3 and a3 < 5;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"a_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and b3 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"b_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and b3 > 3 and b3 < 5;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"b_idx"}, + }, + }, + }, + { + Name: "unique and non-unique nullable index prefixes", + SetUpScript: []string{ + "create table test(pk int primary key, shared1 int, shared2 int, a3 int, a4 int, b3 int, b4 int, unique key a_idx(shared1, shared2, a3, a4), key b_idx(shared1, shared2, b3, b4))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and a3 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"a_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and a3 > 3 and a3 < 5;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"a_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and b3 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"b_idx"}, + }, + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and b3 > 3 and b3 < 5;", + Expected: []sql.Row{}, + ExpectedIndexes: 
[]string{"b_idx"}, + }, + }, + }, + { + Name: "avoid picking an index simply because it matches more filters if those filters are not in the prefix.", + SetUpScript: []string{ + "create table test(pk int primary key, shared1 int, shared2 int, a3 int, a4 int, b3 int, b4 int, unique key a_idx(shared1, a3, a4, shared2), key b_idx(shared1, shared2, b3, b4))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from test where shared1 = 1 and shared2 = 2 and a4 = 3;", + Expected: []sql.Row{}, + ExpectedIndexes: []string{"b_idx"}, + }, + }, + }, } var IndexQueries = []ScriptTest{ diff --git a/enginetest/queries/information_schema_queries.go b/enginetest/queries/information_schema_queries.go index b3540bc199..4b8c584793 100644 --- a/enginetest/queries/information_schema_queries.go +++ b/enginetest/queries/information_schema_queries.go @@ -30,6 +30,17 @@ var InfoSchemaQueries = []QueryTest{ Query: "SHOW KEYS FROM `columns` FROM `information_schema`;", Expected: []sql.Row{}, }, + { + Query: `SELECT table_schema AS TABLE_CAT, + NULL AS TABLE_SCHEM, + table_name, + CASE WHEN table_type = 'BASE TABLE' THEN + CASE WHEN table_schema = 'mysql' OR table_schema = 'performance_schema' THEN 'SYSTEM TABLE' + ELSE 'TABLE' END + WHEN table_type = 'TEMPORARY' THEN 'LOCAL_TEMPORARY' + ELSE table_type END AS TABLE_TYPE FROM information_schema.tables ORDER BY table_name LIMIT 1;`, + Expected: []sql.Row{{"information_schema", nil, "administrable_role_authorizations", "SYSTEM VIEW"}}, + }, { Query: `SELECT table_name, index_name, comment, non_unique, GROUP_CONCAT(column_name ORDER BY seq_in_index) AS COLUMNS @@ -1532,10 +1543,10 @@ FROM information_schema.COLUMNS WHERE TABLE_SCHEMA='mydb' AND TABLE_NAME='all_ty FROM information_schema.TABLE_CONSTRAINTS TC, information_schema.CHECK_CONSTRAINTS CC WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 'checks' AND TC.TABLE_SCHEMA = CC.CONSTRAINT_SCHEMA AND TC.CONSTRAINT_NAME = CC.CONSTRAINT_NAME AND TC.CONSTRAINT_TYPE = 'CHECK';`, Expected: 
[]sql.Row{ - {"chk1", "(B > 0)", "YES"}, - {"chk2", "(b > 0)", "NO"}, - {"chk3", "(B > 1)", "YES"}, - {"chk4", "(upper(C) = c)", "YES"}, + {"chk1", "(`B` > 0)", "YES"}, + {"chk2", "(`b` > 0)", "NO"}, + {"chk3", "(`B` > 1)", "YES"}, + {"chk4", "(upper(`C`) = `c`)", "YES"}, }, }, { @@ -1551,10 +1562,10 @@ WHERE TABLE_SCHEMA = 'mydb' AND TABLE_NAME = 'checks' AND TC.TABLE_SCHEMA = CC.C { Query: `select * from information_schema.check_constraints where constraint_schema = 'mydb';`, Expected: []sql.Row{ - {"def", "mydb", "chk1", "(B > 0)"}, - {"def", "mydb", "chk2", "(b > 0)"}, - {"def", "mydb", "chk3", "(B > 1)"}, - {"def", "mydb", "chk4", "(upper(C) = c)"}, + {"def", "mydb", "chk1", "(`B` > 0)"}, + {"def", "mydb", "chk2", "(`b` > 0)"}, + {"def", "mydb", "chk3", "(`B` > 1)"}, + {"def", "mydb", "chk4", "(upper(`C`) = `c`)"}, }, }, { diff --git a/enginetest/queries/insert_queries.go b/enginetest/queries/insert_queries.go index 5b3e6ce1c2..0e898b4935 100644 --- a/enginetest/queries/insert_queries.go +++ b/enginetest/queries/insert_queries.go @@ -2276,6 +2276,40 @@ var InsertScripts = []ScriptTest{ }, }, }, + { + Name: "insert...returning... 
statements", + Dialect: "mysql", // actually mariadb + SetUpScript: []string{ + "CREATE TABLE animals (id int, name varchar(20))", + "CREATE TABLE auto_pk (`pk` int NOT NULL AUTO_INCREMENT, `name` varchar(20), PRIMARY KEY (`pk`))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into animals (id) values (2) returning id", + Expected: []sql.Row{{2}}, + }, + { + Query: "insert into animals(id,name) values (1, 'Dog'),(2,'Lion'),(3,'Tiger'),(4,'Leopard') returning id, id+id", + Expected: []sql.Row{{1, 2}, {2, 4}, {3, 6}, {4, 8}}, + }, + { + Query: "insert into animals set id=1,name='Bear' returning id,name", + Expected: []sql.Row{{1, "Bear"}}, + }, + { + Query: "insert into auto_pk (name) values ('Cat') returning pk,name", + Expected: []sql.Row{{1, "Cat"}}, + }, + { + Query: "insert into auto_pk values (NULL, 'Dog'),(5, 'Fish'),(NULL, 'Horse') returning *", + Expected: []sql.Row{{2, "Dog"}, {5, "Fish"}, {6, "Horse"}}, + }, + { + Query: "insert into auto_pk (name) select name from animals where id = 3 returning *", + Expected: []sql.Row{{7, "Tiger"}}, + }, + }, + }, } var InsertDuplicateKeyKeyless = []ScriptTest{ @@ -2802,7 +2836,7 @@ var InsertIgnoreScripts = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: "insert into test_table values (1, 'invalid'), (2, 'comparative politics'), (3, null)", - ExpectedErr: types.ErrConvertingToEnum, // TODO: should be ErrDataTruncatedForColumn + ExpectedErr: types.ErrDataTruncatedForColumnAtRow, }, { Query: "insert ignore into test_table values (1, 'invalid'), (2, 'bye'), (3, null)", diff --git a/enginetest/queries/integration_plans.go b/enginetest/queries/integration_plans.go index 97cd20f2e9..1f6a8f6362 100644 --- a/enginetest/queries/integration_plans.go +++ b/enginetest/queries/integration_plans.go @@ -7148,7 +7148,7 @@ WHERE " │ │ └─ 0.5 (decimal(2,1))\n" + " │ └─ Eq\n" + " │ ├─ nrfj3.YHYLK:6\n" + - " │ └─ 0 (bigint)\n" + + " │ └─ 0 (tinyint)\n" + " │ THEN 1 (tinyint) ELSE 0 (tinyint) END), 
nrfj3.T4IBQ:0!null, nrfj3.ECUWU:1!null, nrfj3.GSTQA:2!null, nrfj3.B5OUF:3\n" + " ├─ group: nrfj3.T4IBQ:0!null, nrfj3.ECUWU:1!null, nrfj3.GSTQA:2!null\n" + " └─ SubqueryAlias\n" + @@ -8023,7 +8023,7 @@ WHERE " │ │ └─ 0.5 (decimal(2,1))\n" + " │ └─ Eq\n" + " │ ├─ nrfj3.YHYLK:6\n" + - " │ └─ 0 (bigint)\n" + + " │ └─ 0 (tinyint)\n" + " │ THEN 1 (tinyint) ELSE 0 (tinyint) END), nrfj3.T4IBQ:0!null, nrfj3.ECUWU:1!null, nrfj3.GSTQA:2!null, nrfj3.B5OUF:3\n" + " ├─ group: nrfj3.T4IBQ:0!null, nrfj3.ECUWU:1!null, nrfj3.GSTQA:2!null\n" + " └─ SubqueryAlias\n" + @@ -10401,7 +10401,7 @@ WHERE ON aac.id = MJR3D.M22QN`, ExpectedPlan: "Project\n" + " ├─ columns: [mf.FTQLQ:21!null->T4IBQ:0, CASE WHEN NOT\n" + - " │ └─ mjr3d.QNI57:9!null IS NULL\n" + + " │ └─ mjr3d.QNI57:9 IS NULL\n" + " │ THEN Subquery\n" + " │ ├─ cacheable: false\n" + " │ ├─ alias-string: select ei.M6T2N from FZFVD as ei where ei.id = MJR3D.QNI57\n" + @@ -10410,7 +10410,7 @@ WHERE " │ └─ Filter\n" + " │ ├─ Eq\n" + " │ │ ├─ ei.id:34!null\n" + - " │ │ └─ mjr3d.QNI57:9!null\n" + + " │ │ └─ mjr3d.QNI57:9\n" + " │ └─ SubqueryAlias\n" + " │ ├─ name: ei\n" + " │ ├─ outerVisibility: true\n" + @@ -10429,7 +10429,7 @@ WHERE " │ ├─ colSet: (1-10)\n" + " │ └─ tableId: 1\n" + " │ WHEN NOT\n" + - " │ └─ mjr3d.TDEIU:10!null IS NULL\n" + + " │ └─ mjr3d.TDEIU:10 IS NULL\n" + " │ THEN Subquery\n" + " │ ├─ cacheable: false\n" + " │ ├─ alias-string: select ei.M6T2N from FZFVD as ei where ei.id = MJR3D.TDEIU\n" + @@ -10438,7 +10438,7 @@ WHERE " │ └─ Filter\n" + " │ ├─ Eq\n" + " │ │ ├─ ei.id:34!null\n" + - " │ │ └─ mjr3d.TDEIU:10!null\n" + + " │ │ └─ mjr3d.TDEIU:10\n" + " │ └─ SubqueryAlias\n" + " │ ├─ name: ei\n" + " │ ├─ outerVisibility: true\n" + @@ -10458,8 +10458,8 @@ WHERE " │ └─ tableId: 1\n" + " │ END->M6T2N:0, mjr3d.GE5EL:4->GE5EL:0, mjr3d.F7A4Q:5->F7A4Q:0, mjr3d.CC4AX:7->CC4AX:0, mjr3d.SL76B:8!null->SL76B:0, aac.BTXC5:25->YEBDJ:0, mjr3d.PSMU6:2!null]\n" + " └─ Project\n" + - " ├─ columns: [mjr3d.FJDP5:0!null, 
mjr3d.BJUF2:1!null, mjr3d.PSMU6:2!null, mjr3d.M22QN:3!null, mjr3d.GE5EL:4, mjr3d.F7A4Q:5, mjr3d.ESFVY:6!null, mjr3d.CC4AX:7, mjr3d.SL76B:8!null, mjr3d.QNI57:9!null, mjr3d.TDEIU:10!null, sn.id:11!null, sn.BRQP2:12!null, sn.FFTBJ:13!null, sn.A7XO2:14, sn.KBO7R:15!null, sn.ECDKM:16, sn.NUMK2:17!null, sn.LETOE:18!null, sn.YKSSU:19, sn.FHCYT:20, mf.FTQLQ:21!null, mf.LUEVY:22!null, mf.M22QN:23!null, aac.id:24!null, aac.BTXC5:25, aac.FHCYT:26, mf.FTQLQ:21!null->T4IBQ:0, CASE WHEN NOT\n" + - " │ └─ mjr3d.QNI57:9!null IS NULL\n" + + " ├─ columns: [mjr3d.FJDP5:0!null, mjr3d.BJUF2:1!null, mjr3d.PSMU6:2!null, mjr3d.M22QN:3!null, mjr3d.GE5EL:4, mjr3d.F7A4Q:5, mjr3d.ESFVY:6!null, mjr3d.CC4AX:7, mjr3d.SL76B:8!null, mjr3d.QNI57:9, mjr3d.TDEIU:10, sn.id:11!null, sn.BRQP2:12!null, sn.FFTBJ:13!null, sn.A7XO2:14, sn.KBO7R:15!null, sn.ECDKM:16, sn.NUMK2:17!null, sn.LETOE:18!null, sn.YKSSU:19, sn.FHCYT:20, mf.FTQLQ:21!null, mf.LUEVY:22!null, mf.M22QN:23!null, aac.id:24!null, aac.BTXC5:25, aac.FHCYT:26, mf.FTQLQ:21!null->T4IBQ:0, CASE WHEN NOT\n" + + " │ └─ mjr3d.QNI57:9 IS NULL\n" + " │ THEN Subquery\n" + " │ ├─ cacheable: false\n" + " │ ├─ alias-string: select ei.M6T2N from FZFVD as ei where ei.id = MJR3D.QNI57\n" + @@ -10468,7 +10468,7 @@ WHERE " │ └─ Filter\n" + " │ ├─ Eq\n" + " │ │ ├─ ei.id:27!null\n" + - " │ │ └─ mjr3d.QNI57:9!null\n" + + " │ │ └─ mjr3d.QNI57:9\n" + " │ └─ SubqueryAlias\n" + " │ ├─ name: ei\n" + " │ ├─ outerVisibility: true\n" + @@ -10487,7 +10487,7 @@ WHERE " │ ├─ colSet: (1-10)\n" + " │ └─ tableId: 1\n" + " │ WHEN NOT\n" + - " │ └─ mjr3d.TDEIU:10!null IS NULL\n" + + " │ └─ mjr3d.TDEIU:10 IS NULL\n" + " │ THEN Subquery\n" + " │ ├─ cacheable: false\n" + " │ ├─ alias-string: select ei.M6T2N from FZFVD as ei where ei.id = MJR3D.TDEIU\n" + @@ -10496,7 +10496,7 @@ WHERE " │ └─ Filter\n" + " │ ├─ Eq\n" + " │ │ ├─ ei.id:27!null\n" + - " │ │ └─ mjr3d.TDEIU:10!null\n" + + " │ │ └─ mjr3d.TDEIU:10\n" + " │ └─ SubqueryAlias\n" + " │ ├─ name: ei\n" + " │ ├─ outerVisibility: 
true\n" + @@ -10534,15 +10534,15 @@ WHERE " │ │ │ │ │ ├─ AND\n" + " │ │ │ │ │ │ ├─ AND\n" + " │ │ │ │ │ │ │ ├─ NOT\n" + - " │ │ │ │ │ │ │ │ └─ mjr3d.QNI57:9!null IS NULL\n" + + " │ │ │ │ │ │ │ │ └─ mjr3d.QNI57:9 IS NULL\n" + " │ │ │ │ │ │ │ └─ Eq\n" + " │ │ │ │ │ │ │ ├─ sn.id:11!null\n" + - " │ │ │ │ │ │ │ └─ mjr3d.QNI57:9!null\n" + + " │ │ │ │ │ │ │ └─ mjr3d.QNI57:9\n" + " │ │ │ │ │ │ └─ mjr3d.BJUF2:1!null IS NULL\n" + " │ │ │ │ │ └─ AND\n" + " │ │ │ │ │ ├─ AND\n" + " │ │ │ │ │ │ ├─ NOT\n" + - " │ │ │ │ │ │ │ └─ mjr3d.QNI57:9!null IS NULL\n" + + " │ │ │ │ │ │ │ └─ mjr3d.QNI57:9 IS NULL\n" + " │ │ │ │ │ │ └─ NOT\n" + " │ │ │ │ │ │ └─ mjr3d.BJUF2:1!null IS NULL\n" + " │ │ │ │ │ └─ InSubquery\n" + @@ -10568,7 +10568,7 @@ WHERE " │ │ │ │ └─ AND\n" + " │ │ │ │ ├─ AND\n" + " │ │ │ │ │ ├─ NOT\n" + - " │ │ │ │ │ │ └─ mjr3d.TDEIU:10!null IS NULL\n" + + " │ │ │ │ │ │ └─ mjr3d.TDEIU:10 IS NULL\n" + " │ │ │ │ │ └─ mjr3d.BJUF2:1!null IS NULL\n" + " │ │ │ │ └─ InSubquery\n" + " │ │ │ │ ├─ left: sn.id:11!null\n" + @@ -10593,7 +10593,7 @@ WHERE " │ │ │ └─ AND\n" + " │ │ │ ├─ AND\n" + " │ │ │ │ ├─ NOT\n" + - " │ │ │ │ │ └─ mjr3d.TDEIU:10!null IS NULL\n" + + " │ │ │ │ │ └─ mjr3d.TDEIU:10 IS NULL\n" + " │ │ │ │ └─ NOT\n" + " │ │ │ │ └─ mjr3d.BJUF2:1!null IS NULL\n" + " │ │ │ └─ InSubquery\n" + @@ -20091,13 +20091,7 @@ FROM " ├─ columns: [id:0!null, FV24E:1!null, UJ6XY:2!null, M22QN:3!null, NZ4MQ:4!null, ETPQV:5, PRUV2:6, YKSSU:7, FHCYT:8]\n" + " └─ Union distinct\n" + " ├─ Project\n" + - " │ ├─ columns: [id:0!null, convert\n" + - " │ │ ├─ type: char\n" + - " │ │ └─ FV24E:1!null\n" + - " │ │ ->FV24E:0, convert\n" + - " │ │ ├─ type: char\n" + - " │ │ └─ UJ6XY:2!null\n" + - " │ │ ->UJ6XY:0, M22QN:3!null, NZ4MQ:4, ETPQV:5!null, convert\n" + + " │ ├─ columns: [id:0!null, FV24E:1!null, UJ6XY:2!null, M22QN:3!null, NZ4MQ:4, ETPQV:5!null, convert\n" + " │ │ ├─ type: char\n" + " │ │ └─ PRUV2:6\n" + " │ │ ->PRUV2:0, YKSSU:7, FHCYT:8]\n" + @@ -20227,7 +20221,7 @@ FROM " │ ├─ name: E2I7U\n" + " 
│ └─ columns: [id dkcaj kng7t tw55n qrqxw ecxaj fgg57 zh72s fsk67 xqdyt tce7a iwv2h hpcms n5cc2 fhcyt etaq7 a75x7]\n" + " └─ Project\n" + - " ├─ columns: [id:0!null, FV24E:1->FV24E:0, UJ6XY:2->UJ6XY:0, M22QN:3, NZ4MQ:4, ETPQV:5!null, convert\n" + + " ├─ columns: [id:0!null, FV24E:1, UJ6XY:2, M22QN:3, NZ4MQ:4, ETPQV:5!null, convert\n" + " │ ├─ type: char\n" + " │ └─ PRUV2:6!null\n" + " │ ->PRUV2:0, YKSSU:7, FHCYT:8]\n" + diff --git a/enginetest/queries/join_queries.go b/enginetest/queries/join_queries.go index ab7aef1901..2cd269c03a 100644 --- a/enginetest/queries/join_queries.go +++ b/enginetest/queries/join_queries.go @@ -1161,6 +1161,45 @@ var JoinScriptTests = []ScriptTest{ }, }, }, + { + // After this change: https://github.com/dolthub/go-mysql-server/pull/3038 + // hash.HashOf takes in a sql.Schema to convert and hash keys, so + // we need to pass in the schema of the join key. + // This tests a bug introduced in that same PR where we incorrectly pass in the entire schema, + // resulting in incorrect conversions. 
+ Name: "HashLookup on multiple columns with tables with different schemas", + SetUpScript: []string{ + "create table t1 (i int primary key, k int);", + "create table t2 (i int primary key, j varchar(1), k int);", + "insert into t1 values (111111, 111111);", + "insert into t2 values (111111, 'a', 111111);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select /*+ HASH_JOIN(t1, t2) */ * from t1 join t2 on t1.i = t2.i and t1.k = t2.k;", + Expected: []sql.Row{ + {111111, 111111, 111111, "a", 111111}, + }, + }, + }, + }, + { + Name: "HashLookup on multiple columns with collations", + SetUpScript: []string{ + "create table t1 (i int primary key, j varchar(128) collate utf8mb4_0900_ai_ci);", + "create table t2 (i int primary key, j varchar(128) collate utf8mb4_0900_ai_ci);", + "insert into t1 values (1, 'ABCDE');", + "insert into t2 values (1, 'abcde');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select /*+ HASH_JOIN(t1, t2) */ * from t1 join t2 on t1.i = t2.i and t1.j = t2.j;", + Expected: []sql.Row{ + {1, "ABCDE", 1, "abcde"}, + }, + }, + }, + }, } var LateralJoinScriptTests = []ScriptTest{ diff --git a/enginetest/queries/json_scripts.go b/enginetest/queries/json_scripts.go index 475b625ac5..aa78939c09 100644 --- a/enginetest/queries/json_scripts.go +++ b/enginetest/queries/json_scripts.go @@ -187,6 +187,44 @@ var JsonScripts = []ScriptTest{ }, }, }, + { + Name: "json_object preserves escaped characters in key and values", + Assertions: []ScriptTestAssertion{ + { + Query: `select cast(JSON_OBJECT('key"with"quotes\n','3"\\') as char);`, + Expected: []sql.Row{ + {`{"key\"with\"quotes\n": "3\"\\"}`}, + }, + }, + }, + }, + { + Name: "json conversion works with escaped characters", + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT CAST(CAST(JSON_OBJECT('key"with"quotes', 1) as CHAR) as JSON);`, + Expected: []sql.Row{ + {`{"key\"with\"quotes": 1}`}, + }, + }, + }, + }, + { + Name: "json_object with escaped k:v pairs from table", + SetUpScript: 
[]string{ + `CREATE TABLE IF NOT EXISTS textt_7998 (t text);`, + `INSERT INTO textt_7998 VALUES ('first row\n\\'), ('second row"');`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT JSON_OBJECT(t, t) FROM textt_7998;`, + Expected: []sql.Row{ + {types.MustJSON(`{"first row\n\\": "first row\n\\"}`)}, + {types.MustJSON(`{"second row\"": "second row\""}`)}, + }, + }, + }, + }, { Name: "json_value preserves types", Assertions: []ScriptTestAssertion{ @@ -966,4 +1004,25 @@ var JsonScripts = []ScriptTest{ }, }, }, + { + Name: "Comparisons with JSON values containing non-JSON types", + SetUpScript: []string{ + "CREATE TABLE test (j json);", + "insert into test VALUES ('{ \"key\": 1.0 }');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from test where JSON_OBJECT(\"key\", 0.0) < test.j;", + Expected: []sql.Row{{types.MustJSON("{\"key\": 1.0}")}}, + }, + { + Query: `select * from test where JSON_OBJECT("key", 1.0) = test.j;`, + Expected: []sql.Row{{types.MustJSON("{\"key\": 1.0}")}}, + }, + { + Query: `select * from test where JSON_OBJECT("key", 2.0) > test.j;`, + Expected: []sql.Row{{types.MustJSON("{\"key\": 1.0}")}}, + }, + }, + }, } diff --git a/enginetest/queries/json_table_queries.go b/enginetest/queries/json_table_queries.go index 7059eae9af..c4d7e13bcc 100644 --- a/enginetest/queries/json_table_queries.go +++ b/enginetest/queries/json_table_queries.go @@ -139,6 +139,12 @@ var JSONTableQueryTests = []QueryTest{ {9}, }, }, + { + Query: "select * from json_table('[\"foo\", \"bar\"]', \"$[*]\" columns(tag text path '$')) as tags where tag like 'foo';", + Expected: []sql.Row{ + {"foo"}, + }, + }, } var JSONTableScriptTests = []ScriptTest{ diff --git a/enginetest/queries/logic_test_scripts.go b/enginetest/queries/logic_test_scripts.go index 08013a0484..f5bab9e92e 100644 --- a/enginetest/queries/logic_test_scripts.go +++ b/enginetest/queries/logic_test_scripts.go @@ -1004,39 +1004,41 @@ var SQLLogicSubqueryTests = []ScriptTest{ }, }, }, - 
//{ - // Name: "multiple nested subquery", - // SetUpScript: []string{ - // "CREATE TABLE `groups`(id SERIAL PRIMARY KEY, data JSON);", - // "INSERT INTO `groups`(data) VALUES('{\"name\": \"Group 1\", \"members\": [{\"name\": \"admin\", \"type\": \"USER\"}, {\"name\": \"user\", \"type\": \"USER\"}]}');", - // "INSERT INTO `groups`(data) VALUES('{\"name\": \"Group 2\", \"members\": [{\"name\": \"admin2\", \"type\": \"USER\"}]}');", - // "CREATE TABLE t32786 (id VARCHAR(36) PRIMARY KEY, parent_id VARCHAR(36), parent_path text);", - // "INSERT INTO t32786 VALUES ('3AAA2577-DBC3-47E7-9E85-9CC7E19CF48A', null, null);", - // "INSERT INTO t32786 VALUES ('5AE7EAFD-8277-4F41-83DE-0FD4B4482169', '3AAA2577-DBC3-47E7-9E85-9CC7E19CF48A', null);", - // "CREATE TABLE users (id INT8 NOT NULL, name VARCHAR(50), PRIMARY KEY (id));", - // "INSERT INTO users(id, name) VALUES (1, 'user1');", - // "INSERT INTO users(id, name) VALUES (2, 'user2');", - // "INSERT INTO users(id, name) VALUES (3, 'user3');", - // "CREATE TABLE stuff(id INT8 NOT NULL, date DATE, user_id INT8, PRIMARY KEY (id), FOREIGN KEY (user_id) REFERENCES users (id));", - // "INSERT INTO stuff(id, date, user_id) VALUES (1, '2007-10-15', 1);", - // "INSERT INTO stuff(id, date, user_id) VALUES (2, '2007-12-15', 1);", - // "INSERT INTO stuff(id, date, user_id) VALUES (3, '2007-11-15', 1);", - // "INSERT INTO stuff(id, date, user_id) VALUES (4, '2008-01-15', 2);", - // "INSERT INTO stuff(id, date, user_id) VALUES (5, '2007-06-15', 3);", - // "INSERT INTO stuff(id, date, user_id) VALUES (6, '2007-03-15', 3);", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Skip: true, - // Query: "SELECT users.id AS users_id, users.name AS users_name, stuff_1.id AS stuff_1_id, stuff_1.date AS stuff_1_date, stuff_1.user_id AS stuff_1_user_id FROM users LEFT JOIN stuff AS stuff_1 ON users.id = stuff_1.user_id AND stuff_1.id = (SELECT stuff_2.id FROM stuff AS stuff_2 WHERE stuff_2.user_id = users.id ORDER BY stuff_2.date DESC LIMIT 
1) ORDER BY users.name;", - // Expected: []sql.Row{ - // {1, "user1", 2, 2007-12-15, 1}, - // {2, "user2", 4, 2008-01-15, 2}, - // {3, "user3", 5, 2007-06-15, 3}, - // }, - // }, - // }, - //}, + { + // Skipping because we don't convert Time objects to strings in enginetests + Skip: true, + Name: "multiple nested subquery", + SetUpScript: []string{ + "CREATE TABLE `groups`(id SERIAL PRIMARY KEY, data JSON);", + "INSERT INTO `groups`(data) VALUES('{\"name\": \"Group 1\", \"members\": [{\"name\": \"admin\", \"type\": \"USER\"}, {\"name\": \"user\", \"type\": \"USER\"}]}');", + "INSERT INTO `groups`(data) VALUES('{\"name\": \"Group 2\", \"members\": [{\"name\": \"admin2\", \"type\": \"USER\"}]}');", + "CREATE TABLE t32786 (id VARCHAR(36) PRIMARY KEY, parent_id VARCHAR(36), parent_path text);", + "INSERT INTO t32786 VALUES ('3AAA2577-DBC3-47E7-9E85-9CC7E19CF48A', null, null);", + "INSERT INTO t32786 VALUES ('5AE7EAFD-8277-4F41-83DE-0FD4B4482169', '3AAA2577-DBC3-47E7-9E85-9CC7E19CF48A', null);", + "CREATE TABLE users (id INT8 NOT NULL, name VARCHAR(50), PRIMARY KEY (id));", + "INSERT INTO users(id, name) VALUES (1, 'user1');", + "INSERT INTO users(id, name) VALUES (2, 'user2');", + "INSERT INTO users(id, name) VALUES (3, 'user3');", + "CREATE TABLE stuff(id INT8 NOT NULL, date DATE, user_id INT8, PRIMARY KEY (id), FOREIGN KEY (user_id) REFERENCES users (id));", + "INSERT INTO stuff(id, date, user_id) VALUES (1, '2007-10-15', 1);", + "INSERT INTO stuff(id, date, user_id) VALUES (2, '2007-12-15', 1);", + "INSERT INTO stuff(id, date, user_id) VALUES (3, '2007-11-15', 1);", + "INSERT INTO stuff(id, date, user_id) VALUES (4, '2008-01-15', 2);", + "INSERT INTO stuff(id, date, user_id) VALUES (5, '2007-06-15', 3);", + "INSERT INTO stuff(id, date, user_id) VALUES (6, '2007-03-15', 3);", + }, + Assertions: []ScriptTestAssertion{ + { + Skip: true, + Query: "SELECT users.id AS users_id, users.name AS users_name, stuff_1.id AS stuff_1_id, stuff_1.date AS stuff_1_date, 
stuff_1.user_id AS stuff_1_user_id FROM users LEFT JOIN stuff AS stuff_1 ON users.id = stuff_1.user_id AND stuff_1.id = (SELECT stuff_2.id FROM stuff AS stuff_2 WHERE stuff_2.user_id = users.id ORDER BY stuff_2.date DESC LIMIT 1) ORDER BY users.name;", + Expected: []sql.Row{ + {1, "user1", 2, "2007 - 12 - 15", 1}, + {2, "user2", 4, "2008 - 01 - 15", 2}, + {3, "user3", 5, "2007 - 06 - 15", 3}, + }, + }, + }, + }, { Name: "multiple nested subquery again", SetUpScript: []string{ diff --git a/enginetest/queries/order_by_group_by_queries.go b/enginetest/queries/order_by_group_by_queries.go index 84de8445ea..c08fd73b77 100644 --- a/enginetest/queries/order_by_group_by_queries.go +++ b/enginetest/queries/order_by_group_by_queries.go @@ -305,4 +305,45 @@ var OrderByGroupByScriptTests = []ScriptTest{ }, }, }, + { + Name: "Group by true and 1", + // https://github.com/dolthub/dolt/issues/9320 + Dialect: "mysql", + SetUpScript: []string{ + "create table t0(c0 int)", + "insert into t0(c0) values(1),(123)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select if(t0.c0 = 123, TRUE, t0.c0) AS ref0, min(t0.c0) as ref1 from t0 group by ref0", + Expected: []sql.Row{{1, 1}}, + }, + }, + }, + { + Name: "Group by null = 1", + // https://github.com/dolthub/dolt/issues/9035 + SetUpScript: []string{ + "create table t0(c0 int, c1 int)", + "insert into t0(c0, c1) values(NULL,1),(1,NULL)", + "create table t1(id int primary key, c0 int, c1 int)", + "insert into t1(id, c0, c1) values(1,NULL,NULL),(2,1,1),(3,1,NULL),(4,2,1),(5,NULL,1)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select t0.c0 = t0.c1 as ref0, sum(1) as ref1 from t0 group by ref0", + Expected: []sql.Row{ + {nil, float64(2)}, + }, + }, + { + Query: "select t1.c0 = t1.c1 as ref0, sum(1) as ref1 from t1 group by ref0", + Expected: []sql.Row{ + {nil, float64(3)}, + {true, float64(1)}, + {false, float64(1)}, + }, + }, + }, + }, } diff --git a/enginetest/queries/procedure_queries.go 
b/enginetest/queries/procedure_queries.go index 350fda5343..0986385c46 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -325,20 +325,20 @@ END`, // need to filter out Result Sets that should be completely omitted. { Query: "CALL p1(0)", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "CALL p1(1)", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "CALL p1(2)", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // https://github.com/dolthub/dolt/issues/6230 Query: "CALL p1(200)", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, }, }, @@ -359,15 +359,15 @@ END`, // need to filter out Result Sets that should be completely omitted. { Query: "CALL p1(0)", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "CALL p1(1)", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "CALL p1(2)", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, }, }, @@ -985,7 +985,7 @@ END;`, Assertions: []ScriptTestAssertion{ { Query: "SET @x = 2;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // TODO: Set statements don't return anything for whatever reason @@ -2270,7 +2270,7 @@ end; Assertions: []ScriptTestAssertion{ { Query: "call proc();", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @v;", @@ -2309,6 +2309,49 @@ end; }, }, }, + { + Name: "stored procedure with exists subquery", + SetUpScript: []string{ + ` +create procedure exists_proc1(in x int) +begin + select 1 where exists (select x); +end; +`, + ` +create procedure exists_proc2(in x int) +begin + select exists (select x); +end; +`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "call exists_proc1(1);", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "call 
exists_proc1(0);", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "call exists_proc2(1);", + Expected: []sql.Row{ + {true}, + }, + }, + { + Query: "call exists_proc2(0);", + Expected: []sql.Row{ + {true}, + }, + }, + }, + }, } var ProcedureCallTests = []ScriptTest{ diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 9e350bbeee..26298ba4ed 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -813,7 +813,7 @@ var QueryTests = []QueryTest{ { // Assert that SYSDATE() returns different times on each call in a query (unlike NOW()) // Using the maximum precision for fractional seconds, lets us see a difference. - Query: "select now() = sysdate(), sleep(0.5), now(6) < sysdate(6);", + Query: "select sysdate() - now() <= 1, sleep(2), sysdate() - now() > 0;", Expected: []sql.Row{{true, 0, true}}, }, { @@ -828,10 +828,6 @@ var QueryTests = []QueryTest{ Query: "select y as x from xy group by (y) having AVG(x) > 0", Expected: []sql.Row{{0}, {1}, {3}}, }, - // { - // Query: "select y as z from xy group by (y) having AVG(z) > 0", - // Expected: []sql.Row{{1}, {2}, {3}}, - // }, { Query: "SELECT * FROM mytable t0 INNER JOIN mytable t1 ON (t1.i IN (((true)%(''))));", Expected: []sql.Row{}, @@ -1643,78 +1639,6 @@ SELECT * FROM cte WHERE d = 2;`, Query: `SELECT column_0 FROM (values row(1.5,2+2), row(floor(1.5),concat("a","b"))) a order by 1;`, Expected: []sql.Row{{"1.0"}, {"1.5"}}, }, - { - Query: `SELECT FORMAT(val, 2) FROM - (values row(4328904), row(432053.4853), row(5.93288775208e+08), row("5784029.372"), row(-4229842.122), row(-0.009)) a (val)`, - Expected: []sql.Row{ - {"4,328,904.00"}, - {"432,053.49"}, - {"593,288,775.21"}, - {"5,784,029.37"}, - {"-4,229,842.12"}, - {"-0.01"}, - }, - }, - { - Query: "SELECT FORMAT(i, 3) FROM mytable;", - Expected: []sql.Row{ - {"1.000"}, - {"2.000"}, - {"3.000"}, - }, - }, - { - Query: `SELECT FORMAT(val, 2, 'da_DK') FROM - (values row(4328904), row(432053.4853), 
row(5.93288775208e+08), row("5784029.372"), row(-4229842.122), row(-0.009)) a (val)`, - Expected: []sql.Row{ - {"4.328.904,00"}, - {"432.053,49"}, - {"593.288.775,21"}, - {"5.784.029,37"}, - {"-4.229.842,12"}, - {"-0,01"}, - }, - }, - { - Query: "SELECT FORMAT(i, 3, 'da_DK') FROM mytable;", - Expected: []sql.Row{ - {"1,000"}, - {"2,000"}, - {"3,000"}, - }, - }, - { - Query: "SELECT DATEDIFF(date_col, '2019-12-28') FROM datetime_table where date_col = date('2019-12-31T12:00:00');", - Expected: []sql.Row{ - {3}, - }, - }, - { - Query: `SELECT DATEDIFF(val, '2019/12/28') FROM - (values row('2017-11-30 22:59:59'), row('2020/01/02'), row('2021-11-30'), row('2020-12-31T12:00:00')) a (val)`, - Expected: []sql.Row{ - {-758}, - {5}, - {703}, - {369}, - }, - }, - { - Query: "SELECT TIMESTAMPDIFF(SECOND,'2007-12-31 23:59:58', '2007-12-31 00:00:00');", - Expected: []sql.Row{ - {-86398}, - }, - }, - { - Query: `SELECT TIMESTAMPDIFF(MINUTE, val, '2019/12/28') FROM - (values row('2017-11-30 22:59:59'), row('2020/01/02'), row('2019-12-27 23:15:55'), row('2019-12-31T12:00:00')) a (val);`, - Expected: []sql.Row{ - {1090140}, - {-7200}, - {44}, - {-5040}, - }, - }, { Query: "values row(1, 3), row(2, 2), row(3, 1);", Expected: []sql.Row{ @@ -1858,18 +1782,6 @@ SELECT * FROM cte WHERE d = 2;`, }, }, - { - Query: "SELECT TIMEDIFF(null, '2017-11-30 22:59:59');", - Expected: []sql.Row{{nil}}, - }, - { - Query: "SELECT DATEDIFF('2019/12/28', null);", - Expected: []sql.Row{{nil}}, - }, - { - Query: "SELECT TIMESTAMPDIFF(SECOND, null, '2007-12-31 00:00:00');", - Expected: []sql.Row{{nil}}, - }, { Query: `SELECT JSON_MERGE_PRESERVE('{ "a": 1, "b": 2 }','{ "a": 3, "c": 4 }','{ "a": 5, "d": 6 }')`, Expected: []sql.Row{ @@ -3837,331 +3749,7 @@ SELECT * FROM cte WHERE d = 2;`, Query: `SELECT substring("foo", 2, 2)`, Expected: []sql.Row{{"oo"}}, }, - { - Query: `SELECT SUBSTRING_INDEX('a.b.c.d.e.f', '.', 2)`, - Expected: []sql.Row{ - {"a.b"}, - }, - }, - { - Query: `SELECT 
SUBSTRING_INDEX('a.b.c.d.e.f', '.', -2)`, - Expected: []sql.Row{ - {"e.f"}, - }, - }, - { - Query: `SELECT SUBSTRING_INDEX(SUBSTRING_INDEX('source{d}', '{d}', 1), 'r', -1)`, - Expected: []sql.Row{ - {"ce"}, - }, - }, - { - Query: `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) AS s FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY 1 HAVING s = 'secon';`, - Expected: []sql.Row{{"secon"}}, - }, - { - Query: `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) AS s FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY s HAVING s = 'secon';`, - Expected: []sql.Row{}, - }, - { - Query: `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) AS ss FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY s HAVING s = 'secon';`, - Expected: []sql.Row{}, - }, - { - Query: `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) AS ss FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY ss HAVING ss = 'secon';`, - Expected: []sql.Row{ - {"secon"}, - }, - }, - { - Query: `SELECT TRIM(mytable.s) AS s FROM mytable`, - Expected: []sql.Row{{"first row"}, {"second row"}, {"third row"}}, - }, - { - Query: `SELECT TRIM("row" from mytable.s) AS s FROM mytable`, - Expected: []sql.Row{{"first "}, {"second "}, {"third "}}, - }, - { - Query: `SELECT TRIM(mytable.s from "first row") AS s FROM mytable`, - Expected: []sql.Row{{""}, {"first row"}, {"first row"}}, - }, - { - Query: `SELECT TRIM(mytable.s from mytable.s) AS s FROM mytable`, - Expected: []sql.Row{{""}, {""}, {""}}, - }, - { - Query: `SELECT TRIM(" foo ")`, - Expected: []sql.Row{{"foo"}}, - }, - { - Query: `SELECT TRIM(" " FROM " foo ")`, - Expected: []sql.Row{{"foo"}}, - }, - { - Query: `SELECT TRIM(LEADING " " FROM " foo ")`, - Expected: []sql.Row{{"foo "}}, - }, - { - Query: `SELECT 
TRIM(TRAILING " " FROM " foo ")`, - Expected: []sql.Row{{" foo"}}, - }, - { - Query: `SELECT TRIM(BOTH " " FROM " foo ")`, - Expected: []sql.Row{{"foo"}}, - }, - { - Query: `SELECT TRIM("" FROM " foo")`, - Expected: []sql.Row{{" foo"}}, - }, - { - Query: `SELECT TRIM("bar" FROM "barfoobar")`, - Expected: []sql.Row{{"foo"}}, - }, - { - Query: `SELECT TRIM(TRAILING "bar" FROM "barfoobar")`, - Expected: []sql.Row{{"barfoo"}}, - }, - { - Query: `SELECT TRIM(TRAILING "foo" FROM "foo")`, - Expected: []sql.Row{{""}}, - }, - { - Query: `SELECT TRIM(LEADING "ooo" FROM TRIM("oooo"))`, - Expected: []sql.Row{{"o"}}, - }, - { - Query: `SELECT TRIM(BOTH "foo" FROM TRIM("barfoobar"))`, - Expected: []sql.Row{{"barfoobar"}}, - }, - { - Query: `SELECT TRIM(LEADING "bar" FROM TRIM("foobar"))`, - Expected: []sql.Row{{"foobar"}}, - }, - { - Query: `SELECT TRIM(TRAILING "oo" FROM TRIM("oof"))`, - Expected: []sql.Row{{"oof"}}, - }, - { - Query: `SELECT TRIM(LEADING "test" FROM TRIM(" test "))`, - Expected: []sql.Row{{""}}, - }, - { - Query: `SELECT TRIM(LEADING CONCAT("a", "b") FROM TRIM("ababab"))`, - Expected: []sql.Row{{""}}, - }, - { - Query: `SELECT TRIM(TRAILING CONCAT("a", "b") FROM CONCAT("test","ab"))`, - Expected: []sql.Row{{"test"}}, - }, - { - Query: `SELECT TRIM(LEADING 1 FROM "11111112")`, - Expected: []sql.Row{{"2"}}, - }, - { - Query: `SELECT TRIM(LEADING 1 FROM 11111112)`, - Expected: []sql.Row{{"2"}}, - }, - { - Query: `SELECT INET_ATON("10.0.5.10")`, - Expected: []sql.Row{{uint64(167773450)}}, - }, - { - Query: `SELECT INET_NTOA(167773450)`, - Expected: []sql.Row{{"10.0.5.10"}}, - }, - { - Query: `SELECT INET_ATON("10.0.5.11")`, - Expected: []sql.Row{{uint64(167773451)}}, - }, - { - Query: `SELECT INET_NTOA(167773451)`, - Expected: []sql.Row{{"10.0.5.11"}}, - }, - { - Query: `SELECT INET_NTOA(INET_ATON("12.34.56.78"))`, - Expected: []sql.Row{{"12.34.56.78"}}, - }, - { - Query: `SELECT INET_ATON(INET_NTOA("12345678"))`, - Expected: []sql.Row{{uint64(12345678)}}, - }, - 
{ - Query: `SELECT INET_ATON("notanipaddress")`, - Expected: []sql.Row{{nil}}, - }, - { - Query: `SELECT INET_NTOA("spaghetti")`, - Expected: []sql.Row{{"0.0.0.0"}}, - }, - { - Query: `SELECT HEX(INET6_ATON("10.0.5.9"))`, - Expected: []sql.Row{{"0A000509"}}, - }, - { - Query: `SELECT HEX(INET6_ATON("::10.0.5.9"))`, - Expected: []sql.Row{{"0000000000000000000000000A000509"}}, - }, - { - Query: `SELECT HEX(INET6_ATON("1.2.3.4"))`, - Expected: []sql.Row{{"01020304"}}, - }, - { - Query: `SELECT HEX(INET6_ATON("fdfe::5455:caff:fefa:9098"))`, - Expected: []sql.Row{{"FDFE0000000000005455CAFFFEFA9098"}}, - }, - { - Query: `SELECT HEX(INET6_ATON("1111:2222:3333:4444:5555:6666:7777:8888"))`, - Expected: []sql.Row{{"11112222333344445555666677778888"}}, - }, - { - Query: `SELECT INET6_ATON("notanipaddress")`, - Expected: []sql.Row{{nil}}, - }, - { - Query: `SELECT INET6_NTOA(UNHEX("1234ffff5678ffff1234ffff5678ffff"))`, - Expected: []sql.Row{{"1234:ffff:5678:ffff:1234:ffff:5678:ffff"}}, - }, - { - Query: `SELECT INET6_NTOA(UNHEX("ffffffff"))`, - Expected: []sql.Row{{"255.255.255.255"}}, - }, - { - Query: `SELECT INET6_NTOA(UNHEX("000000000000000000000000ffffffff"))`, - Expected: []sql.Row{{"::255.255.255.255"}}, - }, - { - Query: `SELECT INET6_NTOA(UNHEX("00000000000000000000ffffffffffff"))`, - Expected: []sql.Row{{"::ffff:255.255.255.255"}}, - }, - { - Query: `SELECT INET6_NTOA(UNHEX("0000000000000000000000000000ffff"))`, - Expected: []sql.Row{{"::ffff"}}, - }, - { - Query: `SELECT INET6_NTOA(UNHEX("00000000000000000000000000000000"))`, - Expected: []sql.Row{{"::"}}, - }, - { - Query: `SELECT INET6_NTOA("notanipaddress")`, - Expected: []sql.Row{{nil}}, - }, - { - Query: `SELECT IS_IPV4("10.0.1.10")`, - Expected: []sql.Row{{true}}, - }, - { - Query: `SELECT IS_IPV4("::10.0.1.10")`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT IS_IPV4("notanipaddress")`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT IS_IPV6("10.0.1.10")`, - Expected: []sql.Row{{false}}, 
- }, - { - Query: `SELECT IS_IPV6("::10.0.1.10")`, - Expected: []sql.Row{{true}}, - }, - { - Query: `SELECT IS_IPV6("notanipaddress")`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("10.0.1.10"))`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("::10.0.1.10"))`, - Expected: []sql.Row{{true}}, - }, - { - Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("::ffff:10.0.1.10"))`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("notanipaddress"))`, - Expected: []sql.Row{{nil}}, - }, - { - Query: `SELECT IS_IPV4_MAPPED(INET6_ATON("10.0.1.10"))`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT IS_IPV4_MAPPED(INET6_ATON("::10.0.1.10"))`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT IS_IPV4_MAPPED(INET6_ATON("::ffff:10.0.1.10"))`, - Expected: []sql.Row{{true}}, - }, - { - Query: `SELECT IS_IPV4_COMPAT(INET6_ATON("notanipaddress"))`, - Expected: []sql.Row{{nil}}, - }, - { - Query: "SELECT YEAR('2007-12-11') FROM mytable", - Expected: []sql.Row{{int32(2007)}, {int32(2007)}, {int32(2007)}}, - }, - { - Query: "SELECT MONTH('2007-12-11') FROM mytable", - Expected: []sql.Row{{int32(12)}, {int32(12)}, {int32(12)}}, - }, - { - Query: "SELECT DAY('2007-12-11') FROM mytable", - Expected: []sql.Row{{int32(11)}, {int32(11)}, {int32(11)}}, - }, - { - Query: "SELECT HOUR('2007-12-11 20:21:22') FROM mytable", - Expected: []sql.Row{{int32(20)}, {int32(20)}, {int32(20)}}, - }, - { - Query: "SELECT MINUTE('2007-12-11 20:21:22') FROM mytable", - Expected: []sql.Row{{int32(21)}, {int32(21)}, {int32(21)}}, - }, - { - Query: "SELECT SECOND('2007-12-11 20:21:22') FROM mytable", - Expected: []sql.Row{{int32(22)}, {int32(22)}, {int32(22)}}, - }, - { - Query: "SELECT DAYOFYEAR('2007-12-11 20:21:22') FROM mytable", - Expected: []sql.Row{{int32(345)}, {int32(345)}, {int32(345)}}, - }, - { - Query: "SELECT SECOND('2007-12-11T20:21:22Z') FROM mytable", - Expected: 
[]sql.Row{{int32(22)}, {int32(22)}, {int32(22)}}, - }, - { - Query: "SELECT DAYOFYEAR('2007-12-11') FROM mytable", - Expected: []sql.Row{{int32(345)}, {int32(345)}, {int32(345)}}, - }, - { - Query: "SELECT DAYOFYEAR('20071211') FROM mytable", - Expected: []sql.Row{{int32(345)}, {int32(345)}, {int32(345)}}, - }, - { - Query: "SELECT YEARWEEK('0000-01-01')", - Expected: []sql.Row{{int32(1)}}, - }, - { - Query: "SELECT YEARWEEK('9999-12-31')", - Expected: []sql.Row{{int32(999952)}}, - }, - { - Query: "SELECT YEARWEEK('2008-02-20', 1)", - Expected: []sql.Row{{int32(200808)}}, - }, - { - Query: "SELECT YEARWEEK('1987-01-01')", - Expected: []sql.Row{{int32(198652)}}, - }, - { - Query: "SELECT YEARWEEK('1987-01-01', 20), YEARWEEK('1987-01-01', 1), YEARWEEK('1987-01-01', 2), YEARWEEK('1987-01-01', 3), YEARWEEK('1987-01-01', 4), YEARWEEK('1987-01-01', 5), YEARWEEK('1987-01-01', 6), YEARWEEK('1987-01-01', 7)", - Expected: []sql.Row{{int32(198653), int32(198701), int32(198652), int32(198701), int32(198653), int32(198652), int32(198653), int32(198652)}}, - }, { Query: `select 'a'+4;`, Expected: []sql.Row{{4.0}}, @@ -4830,6 +4418,47 @@ SELECT * FROM cte WHERE d = 2;`, Query: "SELECT subdate(da, f32/10) from typestable;", Expected: []sql.Row{{time.Date(2019, time.December, 30, 0, 0, 0, 0, time.UTC)}}, }, + { + Query: "SELECT date_add('4444-01-01', INTERVAL 5400000 DAY);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT date_add('4444-01-01', INTERVAL -5300000 DAY);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT subdate('2008-01-02', 12e10);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT date_add('2008-01-02', INTERVAL 1000000 day);", + Expected: []sql.Row{{"4745-11-29"}}, + }, + { + Query: "SELECT subdate('2008-01-02', INTERVAL 700000 day);", + Expected: []sql.Row{{"0091-06-20"}}, + }, + { + Query: "SELECT date_add('0000-01-01:01:00:00', INTERVAL 0 day);", + // MYSQL uses a proleptic gregorian, however, Go's time package does normal gregorian. 
+ Expected: []sql.Row{{"0000-01-01 01:00:00"}}, + }, + { + Query: "SELECT date_add('9999-12-31:23:59:59.9999994', INTERVAL 0 day);", + Expected: []sql.Row{{"9999-12-31 23:59:59.999999"}}, + }, + { + Query: "SELECT date_add('9999-12-31:23:59:59.9999995', INTERVAL 0 day);", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT date_add('9999-12-31:23:59:59.99999945', INTERVAL 0 day);", + Expected: []sql.Row{{"9999-12-31 23:59:59.999999"}}, + }, + { + Query: "SELECT date_add('9999-12-31:23:59:59.99999944444444444-', INTERVAL 0 day);", + Expected: []sql.Row{{nil}}, + }, { Query: `SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM othertable) othertable_one) othertable_two) othertable_three WHERE s2 = 'first'`, Expected: []sql.Row{ @@ -5384,204 +5013,6 @@ SELECT * FROM cte WHERE d = 2;`, {int64(1)}, }, }, - { - Query: `SELECT CONCAT("a", "b", "c")`, - Expected: []sql.Row{ - {string("abc")}, - }, - }, - { - Query: `SELECT COALESCE(NULL, NULL, NULL, 'example', NULL, 1234567890)`, - Expected: []sql.Row{ - {string("example")}, - }, - }, - { - Query: `SELECT COALESCE(NULL, NULL, NULL, COALESCE(NULL, 1234567890))`, - Expected: []sql.Row{ - {int32(1234567890)}, - }, - }, - { - Query: "SELECT COALESCE (NULL, NULL)", - Expected: []sql.Row{{nil}}, - ExpectedColumns: []*sql.Column{ - { - Name: "COALESCE (NULL, NULL)", - Type: types.Null, - }, - }, - }, - { - Query: `SELECT COALESCE(CAST('{"a": "one \\n two"}' as json), '');`, - Expected: []sql.Row{ - {"{\"a\": \"one \\n two\"}"}, - }, - }, - { - Query: "SELECT concat(s, i) FROM mytable", - Expected: []sql.Row{ - {string("first row1")}, - {string("second row2")}, - {string("third row3")}, - }, - }, - { - Query: "SELECT version()", - Expected: []sql.Row{ - {"8.0.31"}, - }, - }, - { - Query: `SELECT RAND(100)`, - Expected: []sql.Row{ - {float64(0.8165026937796166)}, - }, - }, - { - Query: `SELECT RAND(i) from mytable order by i`, - Expected: []sql.Row{{0.6046602879796196}, {0.16729663442585624}, {0.7199826688373036}}, - }, 
- { - Query: `SELECT RAND(100) = RAND(100)`, - Expected: []sql.Row{ - {true}, - }, - }, - { - Query: `SELECT RAND() = RAND()`, - Expected: []sql.Row{ - {false}, - }, - }, - { - Query: "SELECT MOD(i, 2) from mytable order by i limit 1", - Expected: []sql.Row{ - {"1"}, - }, - }, - { - Query: "SELECT SIN(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {0.8414709848078965}, - }, - }, - { - Query: "SELECT COS(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {0.5403023058681398}, - }, - }, - { - Query: "SELECT TAN(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {1.557407724654902}, - }, - }, - { - Query: "SELECT ASIN(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {1.5707963267948966}, - }, - }, - { - Query: "SELECT ACOS(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {0.0}, - }, - }, - { - Query: "SELECT ATAN(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {0.7853981633974483}, - }, - }, - { - Query: "SELECT COT(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {0.6420926159343308}, - }, - }, - { - Query: "SELECT DEGREES(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {57.29577951308232}, - }, - }, - { - Query: "SELECT RADIANS(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {0.017453292519943295}, - }, - }, - { - Query: "SELECT CRC32(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {uint64(0x83dcefb7)}, - }, - }, - { - Query: "SELECT SIGN(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {1}, - }, - }, - { - Query: "SELECT ASCII(s) from mytable order by i limit 1", - Expected: []sql.Row{ - {uint64(0x66)}, - }, - }, - { - Query: "SELECT HEX(s) from mytable order by i limit 1", - Expected: []sql.Row{ - {"666972737420726F77"}, - }, - }, - { - Query: "SELECT UNHEX(s) from mytable order by i limit 1", - Expected: []sql.Row{ - {nil}, - }, - }, - { - Query: "SELECT BIN(i) from mytable order by i limit 1", - Expected: []sql.Row{ - 
{"1"}, - }, - }, - { - Query: "SELECT BIT_LENGTH(i) from mytable order by i limit 1", - Expected: []sql.Row{ - {64}, - }, - }, - { - Query: "select date_format(datetime_col, '%D') from datetime_table order by 1", - Expected: []sql.Row{ - {"1st"}, - {"4th"}, - {"7th"}, - }, - }, - { - Query: "select time_format(time_col, '%h%p') from datetime_table order by 1", - Expected: []sql.Row{ - {"03AM"}, - {"03PM"}, - {"04AM"}, - }, - }, - { - Query: "select from_unixtime(i) from mytable order by 1", - Expected: []sql.Row{ - {UnixTimeInLocal(1, 0)}, - {UnixTimeInLocal(2, 0)}, - {UnixTimeInLocal(3, 0)}, - }, - }, - // TODO: add additional tests for other functions. Every function needs an engine test to ensure it works correctly - // with the analyzer. { Query: "SELECT * FROM mytable WHERE 1 > 5", Expected: nil, @@ -5687,7 +5118,7 @@ SELECT * FROM cte WHERE d = 2;`, sql.Collation_Default.CharacterSet().String() + " */", Expected: []sql.Row{ - {}, + {types.NewOkResult(0)}, }, }, { @@ -5695,7 +5126,7 @@ SELECT * FROM cte WHERE d = 2;`, sql.Collation_Default.String() + "';", Expected: []sql.Row{ - {}, + {types.NewOkResult(0)}, }, }, { @@ -6092,7 +5523,7 @@ SELECT * FROM cte WHERE d = 2;`, Query: `SELECT if(123 = 123, NULL, NULL = 1)`, Expected: []sql.Row{{nil}}, ExpectedColumns: []*sql.Column{ - {Name: "if(123 = 123, NULL, NULL = 1)", Type: types.Int64}, // TODO: this should be getting coerced to bool + {Name: "if(123 = 123, NULL, NULL = 1)", Type: types.Boolean}, }, }, { @@ -6123,7 +5554,7 @@ SELECT * FROM cte WHERE d = 2;`, { Query: `SELECT if(0, "abc", 456)`, Expected: []sql.Row{ - {456}, + {"456"}, }, }, { @@ -6852,10 +6283,6 @@ SELECT * FROM cte WHERE d = 2;`, Query: `SELECT LEAST(@@back_log,@@auto_increment_offset)`, Expected: []sql.Row{{-1}}, }, - { - Query: `SELECT CHAR_LENGTH('áé'), LENGTH('àè')`, - Expected: []sql.Row{{int32(2), int32(4)}}, - }, { Query: "SELECT i, COUNT(i) AS `COUNT(i)` FROM (SELECT i FROM mytable) t GROUP BY i ORDER BY i, `COUNT(i)` DESC", Expected: 
[]sql.Row{{int64(1), int64(1)}, {int64(2), int64(1)}, {int64(3), int64(1)}}, @@ -6901,10 +6328,6 @@ SELECT * FROM cte WHERE d = 2;`, Query: `SELECT STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - (STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - INTERVAL 1 SECOND)`, Expected: []sql.Row{{int64(1)}}, }, - { - Query: `SELECT SUBSTR(SUBSTRING('0123456789ABCDEF', 1, 10), -4)`, - Expected: []sql.Row{{"6789"}}, - }, { Query: `SELECT CASE i WHEN 1 THEN i ELSE NULL END FROM mytable`, Expected: []sql.Row{{int64(1)}, {nil}, {nil}}, @@ -8387,6 +7810,78 @@ SELECT * FROM cte WHERE d = 2;`, Query: "SELECT CONV(i, 10, 2) FROM mytable", Expected: []sql.Row{{"1"}, {"10"}, {"11"}}, }, + { + Query: "SELECT OCT(8)", + Expected: []sql.Row{{"10"}}, + }, + { + Query: "SELECT OCT(255)", + Expected: []sql.Row{{"377"}}, + }, + { + Query: "SELECT OCT(0)", + Expected: []sql.Row{{"0"}}, + }, + { + Query: "SELECT OCT(1)", + Expected: []sql.Row{{"1"}}, + }, + { + Query: "SELECT OCT(NULL)", + Expected: []sql.Row{{nil}}, + }, + { + Query: "SELECT OCT(-1)", + Expected: []sql.Row{{"1777777777777777777777"}}, + }, + { + Query: "SELECT OCT(-8)", + Expected: []sql.Row{{"1777777777777777777770"}}, + }, + { + Query: "SELECT OCT(OCT(4))", + Expected: []sql.Row{{"4"}}, + }, + { + Query: "SELECT OCT('16')", + Expected: []sql.Row{{"20"}}, + }, + { + Query: "SELECT OCT('abc')", + Expected: []sql.Row{{"0"}}, + }, + { + Query: "SELECT OCT(15.7)", + Expected: []sql.Row{{"17"}}, + }, + { + Query: "SELECT OCT(-15.2)", + Expected: []sql.Row{{"1777777777777777777761"}}, + }, + { + Query: "SELECT OCT(HEX(SUBSTRING('127.0', 1, 3)))", + Expected: []sql.Row{{"1143625"}}, + }, + { + Query: "SELECT i, OCT(i), OCT(-i), OCT(i * 2) FROM mytable ORDER BY i", + Expected: []sql.Row{ + {1, "1", "1777777777777777777777", "2"}, + {2, "2", "1777777777777777777776", "4"}, + {3, "3", "1777777777777777777775", "6"}, + }, + }, + { + Query: "SELECT OCT(i) FROM mytable ORDER BY CONV(i, 10, 16)", + Expected: []sql.Row{{"1"}, 
{"2"}, {"3"}}, + }, + { + Query: "SELECT i FROM mytable WHERE OCT(s) > 0", + Expected: []sql.Row{}, + }, + { + Query: "SELECT s FROM mytable WHERE OCT(i*123) < 400", + Expected: []sql.Row{{"first row"}, {"second row"}}, + }, { Query: `SELECT t1.pk from one_pk join (one_pk t1 join one_pk t2 on t1.pk = t2.pk) on t1.pk = one_pk.pk and one_pk.pk = 1 join (one_pk t3 join one_pk t4 on t3.c1 is not null) on t3.pk = one_pk.pk and one_pk.c1 = 10`, Expected: []sql.Row{{1}, {1}, {1}, {1}}, @@ -9696,7 +9191,7 @@ from typestable`, { Query: "select if('', 1, char(''));", Expected: []sql.Row{ - {[]byte{0}}, + {"\x00"}, }, }, { @@ -10309,6 +9804,28 @@ from typestable`, {2, "second row"}, }, }, + { + Query: "select * from two_pk group by pk1, pk2", + Expected: []sql.Row{ + {0, 0, 0, 1, 2, 3, 4}, + {0, 1, 10, 11, 12, 13, 14}, + {1, 0, 20, 21, 22, 23, 24}, + {1, 1, 30, 31, 32, 33, 34}, + }, + }, + { + Query: "select pk1+1 from two_pk group by pk1 + 1, mod(pk2, 2)", + Expected: []sql.Row{ + {1}, {1}, {2}, {2}, + }, + }, + { + Query: "select mod(pk2, 2) from two_pk group by pk1 + 1, mod(pk2, 2)", + Expected: []sql.Row{ + // mod is a Decimal type, which we convert to a string in our enginetests + {"0"}, {"1"}, {"0"}, {"1"}, + }, + }, } var KeylessQueries = []QueryTest{ @@ -10560,6 +10077,20 @@ FROM mytable;`, {"DECIMAL"}, }, }, + // https://github.com/dolthub/dolt/issues/7095 + // References in group by and having should be allowed to match select aliases + { + Query: "select y as z from xy group by (y) having AVG(z) > 0", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + Query: "select y as z from xy group by (z) having AVG(z) > 0", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + Query: "select y + 1 as z from xy group by (z) having AVG(z) > 1", + Expected: []sql.Row{{2}, {3}, {4}}, + }, } var VersionedQueries = []QueryTest{ @@ -11327,6 +10858,15 @@ var ErrorQueries = []QueryErrorTest{ Query: "SELECT 1 INTO mytable;", ExpectedErr: sql.ErrUndeclaredVariable, }, + { + Query: "select * 
from two_pk group by pk1", + ExpectedErr: analyzererrors.ErrValidationGroupBy, + }, + { + // Grouping over functions and math expressions over PK does not count, and must appear in select + Query: "select * from two_pk group by pk1 + 1, mod(pk2, 2)", + ExpectedErr: analyzererrors.ErrValidationGroupBy, + }, } var BrokenErrorQueries = []QueryErrorTest{ @@ -11343,10 +10883,10 @@ var BrokenErrorQueries = []QueryErrorTest{ ExpectedErr: sql.ErrTableNotFound, }, - // Our behavior in when sql_mode = ONLY_FULL_GROUP_BY is inconsistent with MySQL + // Our behavior in when sql_mode = ONLY_FULL_GROUP_BY is inconsistent with MySQL. This is because we skip validation + // for GroupBys wrapped in a Project since we are not able to validate selected expressions that get optimized as an + // alias. // Relevant issue: https://github.com/dolthub/dolt/issues/4998 - // Special case: If you are grouping by every field of the PK, then you can select anything - // Otherwise, whatever you are selecting must be in the Group By (with the exception of aggregations) { Query: "SELECT col0, floor(col1) FROM tab1 GROUP by col0;", ExpectedErr: analyzererrors.ErrValidationGroupBy, @@ -11355,39 +10895,11 @@ var BrokenErrorQueries = []QueryErrorTest{ Query: "SELECT floor(cor0.col1) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", ExpectedErr: analyzererrors.ErrValidationGroupBy, }, - { - Query: "select * from two_pk group by pk1, pk2", - // No error - }, - { - Query: "select * from two_pk group by pk1", - ExpectedErr: analyzererrors.ErrValidationGroupBy, - }, - { - // Grouping over functions and math expressions over PK does not count, and must appear in select - Query: "select * from two_pk group by pk1 + 1, mod(pk2, 2)", - ExpectedErr: analyzererrors.ErrValidationGroupBy, - }, - { - // Grouping over functions and math expressions over PK does not count, and must appear in select - Query: "select pk1+1 from two_pk group by pk1 + 1, mod(pk2, 2)", - // No error - }, - { - // Grouping 
over functions and math expressions over PK does not count, and must appear in select - Query: "select mod(pk2, 2) from two_pk group by pk1 + 1, mod(pk2, 2)", - // No error - }, - { - // Grouping over functions and math expressions over PK does not count, and must appear in select - Query: "select mod(pk2, 2) from two_pk group by pk1 + 1, mod(pk2, 2)", - // No error - }, { Query: `SELECT any_value(pk), (SELECT max(pk) FROM one_pk WHERE pk < opk.pk) AS x FROM one_pk opk WHERE (SELECT max(pk) FROM one_pk WHERE pk < opk.pk) > 0 GROUP BY (SELECT max(pk) FROM one_pk WHERE pk < opk.pk) ORDER BY x`, - // No error, but we get opk.pk does not exist + // No error, but we get opk.pk does not exist (aliasing error) }, // Unimplemented JSON functions { diff --git a/enginetest/queries/query_plans.go b/enginetest/queries/query_plans.go index c293f63241..a8e9b9724c 100644 --- a/enginetest/queries/query_plans.go +++ b/enginetest/queries/query_plans.go @@ -1440,19 +1440,15 @@ where " ├─ columns: [style.assetId:1]\n" + " └─ LookupJoin\n" + " ├─ LookupJoin\n" + - " │ ├─ Filter\n" + - " │ │ ├─ Eq\n" + - " │ │ │ ├─ style.val:3\n" + - " │ │ │ └─ curve (longtext)\n" + - " │ │ └─ TableAlias(style)\n" + - " │ │ └─ IndexedTableAccess(asset)\n" + - " │ │ ├─ index: [asset.orgId,asset.name,asset.assetId]\n" + - " │ │ ├─ static: [{[org1, org1], [style, style], [NULL, ∞)}]\n" + - " │ │ ├─ colSet: (1-5)\n" + - " │ │ ├─ tableId: 1\n" + - " │ │ └─ Table\n" + - " │ │ ├─ name: asset\n" + - " │ │ └─ columns: [orgid assetid name val]\n" + + " │ ├─ TableAlias(style)\n" + + " │ │ └─ IndexedTableAccess(asset)\n" + + " │ │ ├─ index: [asset.orgId,asset.name,asset.val]\n" + + " │ │ ├─ static: [{[org1, org1], [style, style], [curve, curve]}]\n" + + " │ │ ├─ colSet: (1-5)\n" + + " │ │ ├─ tableId: 1\n" + + " │ │ └─ Table\n" + + " │ │ ├─ name: asset\n" + + " │ │ └─ columns: [orgid assetid name val]\n" + " │ └─ Filter\n" + " │ ├─ AND\n" + " │ │ ├─ AND\n" + @@ -1498,15 +1494,13 @@ where "", ExpectedEstimates: 
"Project\n" + " ├─ columns: [style.assetId]\n" + - " └─ LookupJoin (estimated cost=16.500 rows=5)\n" + - " ├─ LookupJoin (estimated cost=16.500 rows=5)\n" + - " │ ├─ Filter\n" + - " │ │ ├─ (style.val = 'curve')\n" + - " │ │ └─ TableAlias(style)\n" + - " │ │ └─ IndexedTableAccess(asset)\n" + - " │ │ ├─ index: [asset.orgId,asset.name,asset.assetId]\n" + - " │ │ ├─ filters: [{[org1, org1], [style, style], [NULL, ∞)}]\n" + - " │ │ └─ columns: [orgid assetid name val]\n" + + " └─ LookupJoin (estimated cost=19.800 rows=6)\n" + + " ├─ LookupJoin (estimated cost=19.800 rows=6)\n" + + " │ ├─ TableAlias(style)\n" + + " │ │ └─ IndexedTableAccess(asset)\n" + + " │ │ ├─ index: [asset.orgId,asset.name,asset.val]\n" + + " │ │ ├─ filters: [{[org1, org1], [style, style], [curve, curve]}]\n" + + " │ │ └─ columns: [orgid assetid name val]\n" + " │ └─ Filter\n" + " │ ├─ (((dimension.val = 'wide') AND (dimension.name = 'dimension')) AND (dimension.orgId = 'org1'))\n" + " │ └─ TableAlias(dimension)\n" + @@ -1524,15 +1518,13 @@ where "", ExpectedAnalysis: "Project\n" + " ├─ columns: [style.assetId]\n" + - " └─ LookupJoin (estimated cost=16.500 rows=5) (actual rows=1 loops=1)\n" + - " ├─ LookupJoin (estimated cost=16.500 rows=5) (actual rows=1 loops=1)\n" + - " │ ├─ Filter\n" + - " │ │ ├─ (style.val = 'curve')\n" + - " │ │ └─ TableAlias(style)\n" + - " │ │ └─ IndexedTableAccess(asset)\n" + - " │ │ ├─ index: [asset.orgId,asset.name,asset.assetId]\n" + - " │ │ ├─ filters: [{[org1, org1], [style, style], [NULL, ∞)}]\n" + - " │ │ └─ columns: [orgid assetid name val]\n" + + " └─ LookupJoin (estimated cost=19.800 rows=6) (actual rows=1 loops=1)\n" + + " ├─ LookupJoin (estimated cost=19.800 rows=6) (actual rows=1 loops=1)\n" + + " │ ├─ TableAlias(style)\n" + + " │ │ └─ IndexedTableAccess(asset)\n" + + " │ │ ├─ index: [asset.orgId,asset.name,asset.val]\n" + + " │ │ ├─ filters: [{[org1, org1], [style, style], [curve, curve]}]\n" + + " │ │ └─ columns: [orgid assetid name val]\n" + " │ └─ Filter\n" 
+ " │ ├─ (((dimension.val = 'wide') AND (dimension.name = 'dimension')) AND (dimension.orgId = 'org1'))\n" + " │ └─ TableAlias(dimension)\n" + @@ -6724,32 +6716,24 @@ inner join pq on true }, { Query: `SELECT * FROM one_pk_two_idx WHERE v1 IN (1, 2) AND v2 <= 2`, - ExpectedPlan: "Filter\n" + - " ├─ LessThanOrEqual\n" + - " │ ├─ one_pk_two_idx.v2:2\n" + - " │ └─ 2 (bigint)\n" + - " └─ IndexedTableAccess(one_pk_two_idx)\n" + - " ├─ index: [one_pk_two_idx.v1]\n" + - " ├─ static: [{[1, 1]}, {[2, 2]}]\n" + - " ├─ colSet: (1-3)\n" + - " ├─ tableId: 1\n" + - " └─ Table\n" + - " ├─ name: one_pk_two_idx\n" + - " └─ columns: [pk v1 v2]\n" + - "", - ExpectedEstimates: "Filter\n" + - " ├─ (one_pk_two_idx.v2 <= 2)\n" + - " └─ IndexedTableAccess(one_pk_two_idx)\n" + - " ├─ index: [one_pk_two_idx.v1]\n" + - " ├─ filters: [{[1, 1]}, {[2, 2]}]\n" + + ExpectedPlan: "IndexedTableAccess(one_pk_two_idx)\n" + + " ├─ index: [one_pk_two_idx.v1,one_pk_two_idx.v2]\n" + + " ├─ static: [{[1, 1], (NULL, 2]}, {[2, 2], (NULL, 2]}]\n" + + " ├─ colSet: (1-3)\n" + + " ├─ tableId: 1\n" + + " └─ Table\n" + + " ├─ name: one_pk_two_idx\n" + " └─ columns: [pk v1 v2]\n" + "", - ExpectedAnalysis: "Filter\n" + - " ├─ (one_pk_two_idx.v2 <= 2)\n" + - " └─ IndexedTableAccess(one_pk_two_idx)\n" + - " ├─ index: [one_pk_two_idx.v1]\n" + - " ├─ filters: [{[1, 1]}, {[2, 2]}]\n" + - " └─ columns: [pk v1 v2]\n" + + ExpectedEstimates: "IndexedTableAccess(one_pk_two_idx)\n" + + " ├─ index: [one_pk_two_idx.v1,one_pk_two_idx.v2]\n" + + " ├─ filters: [{[1, 1], (NULL, 2]}, {[2, 2], (NULL, 2]}]\n" + + " └─ columns: [pk v1 v2]\n" + + "", + ExpectedAnalysis: "IndexedTableAccess(one_pk_two_idx)\n" + + " ├─ index: [one_pk_two_idx.v1,one_pk_two_idx.v2]\n" + + " ├─ filters: [{[1, 1], (NULL, 2]}, {[2, 2], (NULL, 2]}]\n" + + " └─ columns: [pk v1 v2]\n" + "", }, { diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 1123a5b546..c5b97637f4 100644 --- a/enginetest/queries/script_queries.go 
+++ b/enginetest/queries/script_queries.go @@ -52,6 +52,9 @@ type ScriptTest struct { // Dialect is the supported dialect for this script, which must match the dialect of the harness if specified. // The script is skipped if the dialect doesn't match. Dialect string + // Skip is used to completely skip a test, not execute any part of the script, and to record it as a skipped test in + // the test suite results. + Skip bool } type ScriptTestAssertion struct { @@ -168,23 +171,6 @@ CREATE TABLE teams ( }, }, }, - { - Name: "alter nil enum", - Dialect: "mysql", - SetUpScript: []string{ - "create table xy (x int primary key, y enum ('a', 'b'));", - "insert into xy values (0, NULL),(1, 'b')", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "alter table xy modify y enum('a','b','c')", - }, - { - Query: "alter table xy modify y enum('a')", - ExpectedErr: types.ErrDataTruncatedForColumn, - }, - }, - }, { Name: "issue 7958, update join uppercase table name validation", SetUpScript: []string{ @@ -259,7 +245,8 @@ CREATE TABLE sourceTable_test ( }, }, { - Name: "GMS issue 2369", + // https://github.com/dolthub/go-mysql-server/issues/2369 + Name: "auto_increment with self-referencing foreign key", SetUpScript: []string{ `CREATE TABLE table1 ( id int NOT NULL AUTO_INCREMENT, @@ -292,6 +279,31 @@ CREATE TABLE sourceTable_test ( }, }, }, + { + // https://github.com/dolthub/go-mysql-server/issues/2349 + Name: "auto_increment with foreign key", + SetUpScript: []string{ + "CREATE TABLE table1 (id int NOT NULL AUTO_INCREMENT primary key, name text)", + ` +CREATE TABLE table2 ( + id int NOT NULL AUTO_INCREMENT, + name text, + fk int, + PRIMARY KEY (id), + CONSTRAINT myConstraint FOREIGN KEY (fk) REFERENCES table1 (id) +)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO table1 (name) VALUES ('tbl1 row 1');", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 1}}}, + }, + { + Query: "INSERT INTO table1 (name) VALUES ('tbl1 row 2');", + Expected: 
[]sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 2}}}, + }, + }, + }, { Name: "index match only exact string, no prefix", SetUpScript: []string{ @@ -514,7 +526,7 @@ SET entity_test.value = joined.value;`, Expected: []sql.Row{{1, "john", "doe", 0, 42}}, }, { - Query: "UPDATE test_users JOIN (SELECT id, 1 FROM test_users) AS tu SET test_users.favorite_number = 420;", + Query: "UPDATE test_users JOIN (SELECT 1 FROM test_users) AS tu SET test_users.favorite_number = 420;", Expected: []sql.Row{{NewUpdateResult(1, 1)}}, }, { @@ -522,7 +534,7 @@ SET entity_test.value = joined.value;`, Expected: []sql.Row{{1, "john", "doe", 0, 420}}, }, { - Query: "UPDATE test_users JOIN (SELECT id, 1 FROM test_users) AS tu SET test_users.deleted = 1;", + Query: "UPDATE test_users JOIN (SELECT 1 FROM test_users) AS tu SET test_users.deleted = 1;", Expected: []sql.Row{{NewUpdateResult(1, 1)}}, }, { @@ -531,30 +543,6 @@ SET entity_test.value = joined.value;`, }, }, }, - { - Name: "GMS issue 2349", - SetUpScript: []string{ - "CREATE TABLE table1 (id int NOT NULL AUTO_INCREMENT primary key, name text)", - ` -CREATE TABLE table2 ( - id int NOT NULL AUTO_INCREMENT, - name text, - fk int, - PRIMARY KEY (id), - CONSTRAINT myConstraint FOREIGN KEY (fk) REFERENCES table1 (id) -)`, - }, - Assertions: []ScriptTestAssertion{ - { - Query: "INSERT INTO table1 (name) VALUES ('tbl1 row 1');", - Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 1}}}, - }, - { - Query: "INSERT INTO table1 (name) VALUES ('tbl1 row 2');", - Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 2}}}, - }, - }, - }, { Name: "missing indexes", SetUpScript: []string{ @@ -1094,59 +1082,6 @@ CREATE TABLE tab3 ( }, }, }, - { - Name: "alter keyless table", - Dialect: "mysql", - SetUpScript: []string{ - "create table t (c1 int, c2 varchar(200), c3 enum('one', 'two'));", - "insert into t values (1, 'one', NULL);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: `alter table t modify column c1 int 
unsigned`, - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - { - Query: "describe t;", - Expected: []sql.Row{ - {"c1", "int unsigned", "YES", "", nil, ""}, - {"c2", "varchar(200)", "YES", "", nil, ""}, - {"c3", "enum('one','two')", "YES", "", nil, ""}, - }, - }, - { - Query: `alter table t drop column c1;`, - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - { - Query: "describe t;", - Expected: []sql.Row{ - {"c2", "varchar(200)", "YES", "", nil, ""}, - {"c3", "enum('one','two')", "YES", "", nil, ""}, - }, - }, - { - Query: "alter table t add column new3 int;", - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - { - Query: `insert into t values ('two', 'two', -2);`, - Expected: []sql.Row{{types.NewOkResult(1)}}, - }, - { - Query: "describe t;", - Expected: []sql.Row{ - {"c2", "varchar(200)", "YES", "", nil, ""}, - {"c3", "enum('one','two')", "YES", "", nil, ""}, - {"new3", "int", "YES", "", nil, ""}, - }, - }, - { - Query: "select * from t;", - Expected: []sql.Row{{"one", nil, nil}, {"two", "two", -2}}, - }, - }, - }, { Name: "topN stable output", SetUpScript: []string{ @@ -1180,99 +1115,6 @@ CREATE TABLE tab3 ( }, }, }, - { - Name: "enums with default, case-sensitive collation (utf8mb4_0900_bin)", - Dialect: "mysql", - SetUpScript: []string{ - "CREATE TABLE enumtest1 (pk int primary key, e enum('abc', 'XYZ'));", - "CREATE TABLE enumtest2 (pk int PRIMARY KEY, e enum('x ', 'X ', 'y', 'Y'));", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "INSERT INTO enumtest1 VALUES (1, 'abc'), (2, 'abc'), (3, 'XYZ');", - Expected: []sql.Row{{types.NewOkResult(3)}}, - }, - { - Query: "SELECT * FROM enumtest1;", - Expected: []sql.Row{{1, "abc"}, {2, "abc"}, {3, "XYZ"}}, - }, - { - // enum values must match EXACTLY for case-sensitive collations - Query: "INSERT INTO enumtest1 VALUES (10, 'ABC'), (11, 'aBc'), (12, 'xyz');", - ExpectedErrStr: "value ABC is not valid for this Enum", - }, - { - Query: "SHOW CREATE TABLE enumtest1;", - Expected: []sql.Row{{ - "enumtest1", 
- "CREATE TABLE `enumtest1` (\n `pk` int NOT NULL,\n `e` enum('abc','XYZ'),\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, - }, - { - // Trailing whitespace should be removed from enum values, except when using the "binary" charset and collation - Query: "SHOW CREATE TABLE enumtest2;", - Expected: []sql.Row{{ - "enumtest2", - "CREATE TABLE `enumtest2` (\n `pk` int NOT NULL,\n `e` enum('x','X','y','Y'),\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, - }, - { - Query: "DESCRIBE enumtest1;", - Expected: []sql.Row{ - {"pk", "int", "NO", "PRI", nil, ""}, - {"e", "enum('abc','XYZ')", "YES", "", nil, ""}}, - }, - { - Query: "DESCRIBE enumtest2;", - Expected: []sql.Row{ - {"pk", "int", "NO", "PRI", nil, ""}, - {"e", "enum('x','X','y','Y')", "YES", "", nil, ""}}, - }, - { - Query: "select data_type, column_type from information_schema.columns where table_name='enumtest1' and column_name='e';", - Expected: []sql.Row{{"enum", "enum('abc','XYZ')"}}, - }, - { - Query: "select data_type, column_type from information_schema.columns where table_name='enumtest2' and column_name='e';", - Expected: []sql.Row{{"enum", "enum('x','X','y','Y')"}}, - }, - }, - }, - { - Name: "enums with case-insensitive collation (utf8mb4_0900_ai_ci)", - Dialect: "mysql", - SetUpScript: []string{ - "CREATE TABLE enumtest1 (pk int primary key, e enum('abc', 'XYZ') collate utf8mb4_0900_ai_ci);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "INSERT INTO enumtest1 VALUES (1, 'abc'), (2, 'abc'), (3, 'XYZ');", - Expected: []sql.Row{{types.NewOkResult(3)}}, - }, - { - Query: "SHOW CREATE TABLE enumtest1;", - Expected: []sql.Row{{ - "enumtest1", - "CREATE TABLE `enumtest1` (\n `pk` int NOT NULL,\n `e` enum('abc','XYZ') COLLATE utf8mb4_0900_ai_ci,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, - }, - { - Query: "DESCRIBE enumtest1;", - Expected: []sql.Row{ - {"pk", "int", 
"NO", "PRI", nil, ""}, - {"e", "enum('abc','XYZ') COLLATE utf8mb4_0900_ai_ci", "YES", "", nil, ""}}, - }, - { - Query: "select data_type, column_type from information_schema.columns where table_name='enumtest1' and column_name='e';", - Expected: []sql.Row{{"enum", "enum('abc','XYZ')"}}, - }, - { - Query: "CREATE TABLE enumtest2 (pk int PRIMARY KEY, e enum('x ', 'X ', 'y', 'Y'));", - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - { - Query: "INSERT INTO enumtest1 VALUES (10, 'ABC'), (11, 'aBc'), (12, 'xyz');", - Expected: []sql.Row{{types.NewOkResult(3)}}, - }, - }, - }, { Name: "failed statements data validation for INSERT, UPDATE", SetUpScript: []string{ @@ -2775,12 +2617,10 @@ CREATE TABLE tab3 ( }, Assertions: []ScriptTestAssertion{ { - Skip: true, Query: "SELECT category, group_concat(name ORDER BY (SELECT COUNT(*) FROM test_data t2 WHERE t2.category = test_data.category AND t2.age < test_data.age)) FROM test_data GROUP BY category ORDER BY category", Expected: []sql.Row{{"A", "Charlie,Alice,Frank"}, {"B", "Bob,Eve"}, {"C", "Diana"}}, }, { - Skip: true, Query: "SELECT group_concat(name ORDER BY (SELECT AVG(age) FROM test_data t2 WHERE t2.category = test_data.category), id) FROM test_data;", Expected: []sql.Row{{"Alice,Charlie,Frank,Diana,Bob,Eve"}}, }, @@ -2804,22 +2644,18 @@ CREATE TABLE tab3 ( }, Assertions: []ScriptTestAssertion{ { - Skip: true, Query: "SELECT category_id, GROUP_CONCAT(name ORDER BY (SELECT rating FROM suppliers WHERE suppliers.id = products.supplier_id) DESC, id ASC) FROM products GROUP BY category_id ORDER BY category_id", Expected: []sql.Row{{1, "Laptop,Keyboard,Mouse,Monitor"}, {2, "Chair,Desk"}}, }, { - Skip: true, Query: "SELECT GROUP_CONCAT(name ORDER BY (SELECT COUNT(*) FROM products p2 WHERE p2.price < products.price), id) FROM products", Expected: []sql.Row{{"Mouse,Keyboard,Chair,Monitor,Desk,Laptop"}}, }, { - Skip: true, Query: "SELECT category_id, GROUP_CONCAT(DISTINCT supplier_id ORDER BY (SELECT rating FROM suppliers WHERE 
suppliers.id = products.supplier_id)) FROM products GROUP BY category_id", Expected: []sql.Row{{1, "2,1"}, {2, "3"}}, }, { - Skip: true, Query: "SELECT GROUP_CONCAT(name ORDER BY (SELECT priority FROM categories WHERE categories.id = products.category_id), price) FROM products", Expected: []sql.Row{{"Mouse,Keyboard,Monitor,Laptop,Chair,Desk"}}, }, @@ -2861,21 +2697,31 @@ CREATE TABLE tab3 ( Assertions: []ScriptTestAssertion{ { // Test with subquery returning NULL values - Skip: true, - Query: "SELECT category, GROUP_CONCAT(name ORDER BY (SELECT CASE WHEN complex_test.value > 80 THEN NULL ELSE complex_test.value END), name) FROM complex_test GROUP BY category ORDER BY category", - Expected: []sql.Row{{"X", "Alpha,Gamma"}, {"Y", "Epsilon,Beta"}, {"Z", "Delta"}}, + Query: "SELECT category, GROUP_CONCAT(name ORDER BY (SELECT CASE WHEN complex_test.value > 80 THEN NULL ELSE complex_test.value END), name) FROM complex_test GROUP BY category ORDER BY category", + Expected: []sql.Row{ + {"X", "Alpha,Gamma"}, + {"Y", "Epsilon,Beta"}, + {"Z", "Delta"}, + }, }, { // Test with correlated subquery using multiple tables - Skip: true, Query: "SELECT GROUP_CONCAT(name ORDER BY (SELECT COUNT(*) FROM complex_test c2 WHERE c2.category = complex_test.category AND c2.value > complex_test.value), name) FROM complex_test", Expected: []sql.Row{{"Alpha,Delta,Epsilon,Beta,Gamma"}}, }, + { + // Test with subquery using multiple columns errors + Query: "SELECT category, GROUP_CONCAT(name ORDER BY (SELECT AVG(value), name FROM complex_test c2 WHERE c2.id <= complex_test.id HAVING AVG(value) > 50) DESC) FROM complex_test GROUP BY category ORDER BY category", + ExpectedErr: sql.ErrInvalidOperandColumns, + }, { // Test with subquery using aggregate functions with HAVING - Skip: true, - Query: "SELECT category, GROUP_CONCAT(name ORDER BY (SELECT AVG(value), name FROM complex_test c2 WHERE c2.id <= complex_test.id HAVING AVG(value) > 50) DESC) FROM complex_test GROUP BY category ORDER BY category", 
- Expected: []sql.Row{{"X", "Alpha,Gamma"}, {"Y", "Beta,Epsilon"}, {"Z", "Delta"}}, + Query: "SELECT category, GROUP_CONCAT(name ORDER BY (SELECT AVG(value) FROM complex_test c2 WHERE c2.id <= complex_test.id HAVING AVG(value) > 50) DESC) FROM complex_test GROUP BY category ORDER BY category", + Expected: []sql.Row{ + {"X", "Alpha,Gamma"}, + {"Y", "Beta,Epsilon"}, + {"Z", "Delta"}, + }, }, { // Test with DISTINCT and complex subquery @@ -2884,9 +2730,8 @@ CREATE TABLE tab3 ( }, { // Test with nested subqueries - Skip: true, - Query: "SELECT GROUP_CONCAT(name ORDER BY (SELECT COUNT(*) FROM complex_test c2 WHERE c2.value > (SELECT MIN(value) FROM complex_test c3 WHERE c3.category = complex_test.category))) FROM complex_test", - Expected: []sql.Row{{"Gamma,Alpha,Epsilon,Beta,Delta"}}, + Query: "SELECT GROUP_CONCAT(name ORDER BY (SELECT SUM(value) FROM complex_test c2 WHERE c2.value != (SELECT MIN(value) FROM complex_test c3 where c3.id = complex_test.id))) FROM complex_test;", + Expected: []sql.Row{{"Alpha,Epsilon,Gamma,Beta,Delta"}}, }, }, }, @@ -2905,13 +2750,11 @@ CREATE TABLE tab3 ( }, { // Test with subquery using LIMIT - Skip: true, Query: "SELECT GROUP_CONCAT(data ORDER BY (SELECT weight FROM perf_test p2 WHERE p2.id = perf_test.id LIMIT 1)) FROM perf_test", Expected: []sql.Row{{"C,A,E,B,D"}}, }, { // Test with very small decimal differences in ORDER BY subquery - Skip: true, Query: "SELECT GROUP_CONCAT(data ORDER BY (SELECT weight + 0.001 * perf_test.id FROM perf_test p2 WHERE p2.id = perf_test.id)) FROM perf_test", Expected: []sql.Row{{"C,A,E,B,D"}}, }, @@ -3237,7 +3080,7 @@ CREATE TABLE tab3 ( // in +8:00 { Query: "set @@session.time_zone='+08:00'", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select from_unixtime(1)", @@ -3254,7 +3097,7 @@ CREATE TABLE tab3 ( // in utc { Query: "set @@session.time_zone='UTC'", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select 
from_unixtime(1)", @@ -3323,36 +3166,6 @@ CREATE TABLE tab3 ( // todo(max): fix arithmatic on bindvar typing SkipPrepared: true, }, - { - Name: "WHERE clause considers ENUM/SET types for comparisons", - Dialect: "mysql", - SetUpScript: []string{ - "CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 ENUM('a', 'b', 'c'), v2 SET('a', 'b', 'c'));", - "INSERT INTO test VALUES (1, 2, 2), (2, 1, 1);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "b", "b"}, {2, "a", "a"}}, - }, - { - Query: "UPDATE test SET v1 = 3 WHERE v1 = 2;", - Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 0, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}}, - }, - { - Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "c", "b"}, {2, "a", "a"}}, - }, - { - Query: "UPDATE test SET v2 = 3 WHERE 2 = v2;", - Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 0, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}}, - }, - { - Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "c", "a,b"}, {2, "a", "a"}}, - }, - }, - }, { Name: "Slightly more complex example for the Exists Clause", SetUpScript: []string{ @@ -3864,18 +3677,6 @@ CREATE TABLE tab3 ( }, }, }, - { - Name: "ALTER AUTO INCREMENT TABLE ADD column", - SetUpScript: []string{ - "CREATE TABLE test (pk int primary key, uk int UNIQUE KEY auto_increment);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "alter table test add column j int;", - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - }, - }, { Name: "alter json column default; from scorewarrior: https://github.com/dolthub/dolt/issues/4543", SetUpScript: []string{ @@ -4087,51 +3888,15 @@ CREATE TABLE tab3 ( }, }, { - Name: "ALTER TABLE MODIFY column with multiple UNIQUE KEYS", - Dialect: "mysql", + // https://github.com/dolthub/dolt/issues/3065 + Name: "join index lookups do not handle filters", SetUpScript: []string{ - "CREATE table test (pk int primary key, uk1 int, uk2 int, unique(uk1, uk2))", - 
"ALTER TABLE `test` MODIFY column uk1 int auto_increment", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "describe test", - Expected: []sql.Row{ - {"pk", "int", "NO", "PRI", nil, ""}, - {"uk1", "int", "NO", "MUL", nil, "auto_increment"}, - {"uk2", "int", "YES", "", nil, ""}, - }, - }, - }, - }, - { - Name: "ALTER TABLE MODIFY column with multiple KEYS", - Dialect: "mysql", - SetUpScript: []string{ - "CREATE table test (pk int primary key, mk1 int, mk2 int, index(mk1, mk2))", - "ALTER TABLE `test` MODIFY column mk1 int auto_increment", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "describe test", - Expected: []sql.Row{ - {"pk", "int", "NO", "PRI", nil, ""}, - {"mk1", "int", "NO", "MUL", nil, "auto_increment"}, - {"mk2", "int", "YES", "", nil, ""}, - }, - }, - }, - }, - { - // https://github.com/dolthub/dolt/issues/3065 - Name: "join index lookups do not handle filters", - SetUpScript: []string{ - "create table a (x int primary key)", - "create table b (y int primary key, x int, index idx_x(x))", - "create table c (z int primary key, x int, y int, index idx_x(x))", - "insert into a values (0),(1),(2),(3)", - "insert into b values (0,1), (1,1), (2,2), (3,2)", - "insert into c values (0,1,0), (1,1,0), (2,2,1), (3,2,1)", + "create table a (x int primary key)", + "create table b (y int primary key, x int, index idx_x(x))", + "create table c (z int primary key, x int, y int, index idx_x(x))", + "insert into a values (0),(1),(2),(3)", + "insert into b values (0,1), (1,1), (2,2), (3,2)", + "insert into c values (0,1,0), (1,1,0), (2,2,1), (3,2,1)", }, Query: "select a.* from a join b on a.x = b.x join c where c.x = a.x and b.x = 1", Expected: []sql.Row{ @@ -4721,72 +4486,6 @@ CREATE TABLE tab3 ( }, }, }, - { - Name: "enum columns work as expected in when clauses", - Dialect: "mysql", - SetUpScript: []string{ - "create table enums (e enum('a'));", - "insert into enums values ('a');", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "select (case e 
when 'a' then 42 end) from enums", - Expected: []sql.Row{{42}}, - }, - { - Query: "select (case 'a' when e then 42 end) from enums", - Expected: []sql.Row{{42}}, - }, - }, - }, - { - Name: "SET and ENUM properly handle integers using UPDATE and DELETE statements", - Dialect: "mysql", - SetUpScript: []string{ - "CREATE TABLE setenumtest (pk INT PRIMARY KEY, v1 ENUM('a', 'b', 'c'), v2 SET('a', 'b', 'c'));", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "INSERT INTO setenumtest VALUES (1, 1, 1), (2, 1, 1), (3, 3, 1), (4, 1, 3);", - Expected: []sql.Row{{types.NewOkResult(4)}}, - }, - { - Query: "UPDATE setenumtest SET v1 = 2, v2 = 2 WHERE pk = 2;", - Expected: []sql.Row{{types.OkResult{ - RowsAffected: 1, - Info: plan.UpdateInfo{ - Matched: 1, - Updated: 1, - Warnings: 0, - }, - }}}, - }, - { - Query: "SELECT * FROM setenumtest ORDER BY pk;", - Expected: []sql.Row{ - {1, "a", "a"}, - {2, "b", "b"}, - {3, "c", "a"}, - {4, "a", "a,b"}, - }, - }, - { - Query: "DELETE FROM setenumtest WHERE v1 = 3;", - Expected: []sql.Row{{types.NewOkResult(1)}}, - }, - { - Query: "DELETE FROM setenumtest WHERE v2 = 3;", - Expected: []sql.Row{{types.NewOkResult(1)}}, - }, - { - Query: "SELECT * FROM setenumtest ORDER BY pk;", - Expected: []sql.Row{ - {1, "a", "a"}, - {2, "b", "b"}, - }, - }, - }, - }, { Name: "identical expressions over different windows should produce different results", SetUpScript: []string{ @@ -4885,94 +4584,7 @@ CREATE TABLE tab3 ( }, }, }, - { - Name: "find_in_set tests", - Dialect: "mysql", - SetUpScript: []string{ - "create table set_tbl (i int primary key, s set('a','b','c'));", - "insert into set_tbl values (0, '');", - "insert into set_tbl values (1, 'a');", - "insert into set_tbl values (2, 'b');", - "insert into set_tbl values (3, 'c');", - "insert into set_tbl values (4, 'a,b');", - "insert into set_tbl values (6, 'b,c');", - "insert into set_tbl values (7, 'a,c');", - "insert into set_tbl values (8, 'a,b,c');", - - "create table collate_tbl (i int 
primary key, s varchar(10) collate utf8mb4_0900_ai_ci);", - "insert into collate_tbl values (0, '');", - "insert into collate_tbl values (1, 'a');", - "insert into collate_tbl values (2, 'b');", - "insert into collate_tbl values (3, 'c');", - "insert into collate_tbl values (4, 'a,b');", - "insert into collate_tbl values (6, 'b,c');", - "insert into collate_tbl values (7, 'a,c');", - "insert into collate_tbl values (8, 'a,b,c');", - "create table text_tbl (i int primary key, s text);", - "insert into text_tbl values (0, '');", - "insert into text_tbl values (1, 'a');", - "insert into text_tbl values (2, 'b');", - "insert into text_tbl values (3, 'c');", - "insert into text_tbl values (4, 'a,b');", - "insert into text_tbl values (6, 'b,c');", - "insert into text_tbl values (7, 'a,c');", - "insert into text_tbl values (8, 'a,b,c');", - - "create table enum_tbl (i int primary key, s enum('a','b','c'));", - "insert into enum_tbl values (0, 'a'), (1, 'b'), (2, 'c');", - "select i, s, find_in_set('a', s) from enum_tbl;", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "select i, find_in_set('a', s) from set_tbl;", - Expected: []sql.Row{ - {0, 0}, - {1, 1}, - {2, 0}, - {3, 0}, - {4, 1}, - {6, 0}, - {7, 1}, - {8, 1}, - }, - }, - { - Query: "select i, find_in_set('A', s) from collate_tbl;", - Expected: []sql.Row{ - {0, 0}, - {1, 1}, - {2, 0}, - {3, 0}, - {4, 1}, - {6, 0}, - {7, 1}, - {8, 1}, - }, - }, - { - Query: "select i, find_in_set('a', s) from text_tbl;", - Expected: []sql.Row{ - {0, 0}, - {1, 1}, - {2, 0}, - {3, 0}, - {4, 1}, - {6, 0}, - {7, 1}, - {8, 1}, - }, - }, - { - Query: "select i, find_in_set('a', s) from enum_tbl;", - Expected: []sql.Row{ - {0, 1}, - {1, 0}, - {2, 0}, - }, - }, - }, - }, { Name: "coalesce tests", Dialect: "mysql", @@ -5100,7 +4712,7 @@ CREATE TABLE tab3 ( { // Set the timezone set to UTC as an offset Query: `set @@time_zone='+00:00';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // When the session's 
time zone is set to UTC, NOW() and UTC_TIMESTAMP() should return the same value @@ -5114,7 +4726,7 @@ CREATE TABLE tab3 ( }, { Query: `set @@time_zone='+02:00';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // When the session's time zone is set to +2:00, NOW() should report two hours ahead of UTC_TIMESTAMP() @@ -5147,7 +4759,7 @@ CREATE TABLE tab3 ( }, { Query: `set @@time_zone='-08:00';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // TODO: Unskip after adding support for converting timestamp values to/from session time_zone @@ -5161,7 +4773,7 @@ CREATE TABLE tab3 ( }, { Query: `set @@time_zone='+5:00';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // Test with explicit timezone in datetime literal @@ -5180,7 +4792,7 @@ CREATE TABLE tab3 ( }, { Query: `set @@time_zone='+0:00';`, - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { // TODO: Unskip after adding support for converting timestamp values to/from session time_zone @@ -5338,7 +4950,7 @@ CREATE TABLE tab3 ( Assertions: []ScriptTestAssertion{ { Query: "SET time_zone = '+07:00';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT UNIX_TIMESTAMP('2023-09-25 07:02:57');", @@ -5350,7 +4962,7 @@ CREATE TABLE tab3 ( }, { Query: "SET time_zone = '+00:00';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT UNIX_TIMESTAMP('2023-09-25 07:02:57');", @@ -5358,7 +4970,7 @@ CREATE TABLE tab3 ( }, { Query: "SET time_zone = '-06:00';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT UNIX_TIMESTAMP('2023-09-25 07:02:57');", @@ -7860,336 +7472,130 @@ where }, }, { - Name: "preserve enums through alter statements", + Name: "coalesce with system types", SetUpScript: []string{ - "create table t (i int primary key, e enum('a', 'b', 'c'));", - "insert ignore into t 
values (0, 'error');", - "insert into t values (1, 'a');", - "insert into t values (2, 'b');", - "insert into t values (3, 'c');", + "create table t as select @@admin_port as port1, @@port as port2, COALESCE(@@admin_port, @@port) as\n port3;", }, Assertions: []ScriptTestAssertion{ { - Query: "select i, e, e + 0 from t;", - Expected: []sql.Row{ - {0, "", float64(0)}, - {1, "a", float64(1)}, - {2, "b", float64(2)}, - {3, "c", float64(3)}, - }, - }, - { - Query: "alter table t modify column e enum('c', 'a', 'b');", + Query: "describe t;", Expected: []sql.Row{ - {types.NewOkResult(0)}, + {"port1", "bigint", "NO", "", nil, ""}, + {"port2", "bigint", "NO", "", nil, ""}, + {"port3", "bigint", "NO", "", nil, ""}, }, }, + }, + }, + + { + Name: "not expression optimization", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (i int);", + "insert into t values (123);", + }, + Assertions: []ScriptTestAssertion{ { - Query: "select i, e, e + 0 from t;", + Query: "select * from t where 1 = (not(not(i)))", Expected: []sql.Row{ - {0, "", float64(0)}, - {1, "a", float64(2)}, - {2, "b", float64(3)}, - {3, "c", float64(1)}, + {123}, }, }, { - Query: "alter table t modify column e enum('asdf', 'a', 'b', 'c');", + Query: "select * from t where true = (not(not(i)))", Expected: []sql.Row{ - {types.NewOkResult(0)}, + {123}, }, }, { - Query: "select i, e, e + 0 from t;", + Query: "select * from t where true = (not(not(i = 123)))", Expected: []sql.Row{ - {0, "", float64(0)}, - {1, "a", float64(2)}, - {2, "b", float64(3)}, - {3, "c", float64(4)}, + {123}, }, }, { - Query: "alter table t modify column e enum('asdf', 'a', 'b', 'c', 'd');", + Query: "select * from t where false = (not(not(i != 123)))", Expected: []sql.Row{ - {types.NewOkResult(0)}, + {123}, }, }, { - Query: "select i, e, e + 0 from t;", + Query: "select * from t where i != (false or i);", Expected: []sql.Row{ - {0, "", float64(0)}, - {1, "a", float64(2)}, - {2, "b", float64(3)}, - {3, "c", float64(4)}, + {123}, }, }, 
{ - Query: "alter table t modify column e enum('a', 'b', 'c');", + Query: "select * from t where ((true and -1) >= 0);", Expected: []sql.Row{ - {types.NewOkResult(0)}, + {123}, }, }, + }, + }, + { + Name: "negative int limits", + Dialect: "mysql", + SetUpScript: []string{ + "CREATE TABLE t(i8 tinyint, i16 smallint, i24 mediumint, i32 int, i64 bigint);", + "INSERT INTO t VALUES(-128, -32768, -8388608, -2147483648, -9223372036854775808);", + }, + Assertions: []ScriptTestAssertion{ { - Query: "select i, e, e + 0 from t;", + SkipResultCheckOnServerEngine: true, + Query: "SELECT -i8, -i16, -i24, -i32 from t;", Expected: []sql.Row{ - {0, "", float64(0)}, - {1, "a", float64(1)}, - {2, "b", float64(2)}, - {3, "c", float64(3)}, + {128, 32768, 8388608, 2147483648}, }, }, { - Query: "alter table t modify column e enum('abc');", - ExpectedErr: types.ErrDataTruncatedForColumn, + Query: "SELECT -i64 from t;", + ExpectedErrStr: "BIGINT out of range for -9223372036854775808", }, }, }, { - Name: "coalesce with system types", + Name: "negative int limits", + Dialect: "mysql", SetUpScript: []string{ - "create table t as select @@admin_port as port1, @@port as port2, COALESCE(@@admin_port, @@port) as\n port3;", + "CREATE TABLE t(i8 tinyint, i16 smallint, i24 mediumint, i32 int, i64 bigint);", + "INSERT INTO t VALUES(-128, -32768, -8388608, -2147483648, -9223372036854775808);", }, Assertions: []ScriptTestAssertion{ { - Query: "describe t;", + SkipResultCheckOnServerEngine: true, + Query: "SELECT -i8, -i16, -i24, -i32 from t;", Expected: []sql.Row{ - {"port1", "bigint", "NO", "", nil, ""}, - {"port2", "bigint", "NO", "", nil, ""}, - {"port3", "bigint", "NO", "", nil, ""}, + {128, 32768, 8388608, 2147483648}, }, }, + { + Query: "SELECT -i64 from t;", + ExpectedErrStr: "BIGINT out of range for -9223372036854775808", + }, }, }, { - Name: "multi enum return types", + Name: "std, stdev, stddev_pop, variance, var_pop, var_samp tests", + Dialect: "mysql", SetUpScript: []string{ - "create table 
t (i int primary key, e enum('abc', 'def', 'ghi'));", - "insert into t values (1, 'abc'), (2, 'def'), (3, 'ghi');", + "create table t (i int);", + "create table tt (i int, j int);", + "insert into tt values (0, 1), (0, 2), (0, 3);", + "insert into tt values (1, 123), (1, 456), (1, 789);", }, Assertions: []ScriptTestAssertion{ { - Query: "select i, (case e when 'abc' then e when 'def' then e when 'ghi' then e end) as e from t;", + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", Expected: []sql.Row{ - {1, "abc"}, - {2, "def"}, - {3, "ghi"}, + {nil, nil, nil, nil}, }, }, { - // https://github.com/dolthub/dolt/issues/8598 - Skip: true, - Query: "select i, (case e when 'abc' then e when 'def' then e when 'ghi' then 'something' end) as e from t;", + Query: "select variance(i), var_pop(i), var_samp(i) from t;", Expected: []sql.Row{ - {1, "abc"}, - {2, "def"}, - {3, "something"}, - }, - }, - { - // https://github.com/dolthub/dolt/issues/8598 - Skip: true, - Query: "select i, (case e when 'abc' then e when 'def' then e when 'ghi' then 123 end) as e from t;", - Expected: []sql.Row{ - {1, "abc"}, - {2, "def"}, - {3, "123"}, - }, - }, - }, - }, - { - // https://github.com/dolthub/dolt/issues/8598 - Name: "enum cast to int and string", - Dialect: "mysql", - SetUpScript: []string{ - "create table t (i int primary key, e enum('abc', 'def', 'ghi'));", - "insert into t values (1, 'abc'), (2, 'def'), (3, 'ghi');", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "select i, cast(e as signed) from t;", - Expected: []sql.Row{ - {1, 1}, - {2, 2}, - {3, 3}, - }, - }, - { - Query: "select i, cast(e as char) from t;", - Expected: []sql.Row{ - {1, "abc"}, - {2, "def"}, - {3, "ghi"}, - }, - }, - { - Query: "select i, cast(e as binary) from t;", - Expected: []sql.Row{ - {1, []uint8("abc")}, - {2, []uint8("def")}, - {3, []uint8("ghi")}, - }, - }, - { - Query: "select case when e = 'abc' then 'abc' when e = 'def' then 123 else e end from t", - Expected: 
[]sql.Row{ - {"abc"}, - {"123"}, - {"ghi"}, - }, - }, - }, - }, - { - Name: "enum errors", - Dialect: "mysql", - SetUpScript: []string{ - "create table t (i int primary key, e enum('abc', 'def', 'ghi'));", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "insert into t values (1, 500)", - ExpectedErrStr: "value 500 is not valid for this Enum", - }, - { - Query: "insert into t values (1, -1)", - ExpectedErrStr: "value -1 is not valid for this Enum", - }, - }, - }, - { - Name: "special case for not null default enum", - Dialect: "mysql", - SetUpScript: []string{ - "create table t (i int primary key, e enum('abc', 'def', 'ghi') not null);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "insert into t(i) values (1)", - Expected: []sql.Row{ - {types.NewOkResult(1)}, - }, - }, - { - Query: "insert into t values (2, null)", - ExpectedErr: sql.ErrInsertIntoNonNullableProvidedNull, - }, - { - Query: "select * from t;", - Expected: []sql.Row{ - {1, "abc"}, - }, - }, - }, - }, - { - Name: "not expression optimization", - Dialect: "mysql", - SetUpScript: []string{ - "create table t (i int);", - "insert into t values (123);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "select * from t where 1 = (not(not(i)))", - Expected: []sql.Row{ - {123}, - }, - }, - { - Query: "select * from t where true = (not(not(i)))", - Expected: []sql.Row{ - {123}, - }, - }, - { - Query: "select * from t where true = (not(not(i = 123)))", - Expected: []sql.Row{ - {123}, - }, - }, - { - Query: "select * from t where false = (not(not(i != 123)))", - Expected: []sql.Row{ - {123}, - }, - }, - { - Query: "select * from t where i != (false or i);", - Expected: []sql.Row{ - {123}, - }, - }, - { - Query: "select * from t where ((true and -1) >= 0);", - Expected: []sql.Row{ - {123}, - }, - }, - }, - }, - { - Name: "negative int limits", - Dialect: "mysql", - SetUpScript: []string{ - "CREATE TABLE t(i8 tinyint, i16 smallint, i24 mediumint, i32 int, i64 bigint);", - "INSERT INTO t 
VALUES(-128, -32768, -8388608, -2147483648, -9223372036854775808);", - }, - Assertions: []ScriptTestAssertion{ - { - SkipResultCheckOnServerEngine: true, - Query: "SELECT -i8, -i16, -i24, -i32 from t;", - Expected: []sql.Row{ - {128, 32768, 8388608, 2147483648}, - }, - }, - { - Query: "SELECT -i64 from t;", - ExpectedErrStr: "BIGINT out of range for -9223372036854775808", - }, - }, - }, - { - Name: "negative int limits", - Dialect: "mysql", - SetUpScript: []string{ - "CREATE TABLE t(i8 tinyint, i16 smallint, i24 mediumint, i32 int, i64 bigint);", - "INSERT INTO t VALUES(-128, -32768, -8388608, -2147483648, -9223372036854775808);", - }, - Assertions: []ScriptTestAssertion{ - { - SkipResultCheckOnServerEngine: true, - Query: "SELECT -i8, -i16, -i24, -i32 from t;", - Expected: []sql.Row{ - {128, 32768, 8388608, 2147483648}, - }, - }, - { - Query: "SELECT -i64 from t;", - ExpectedErrStr: "BIGINT out of range for -9223372036854775808", - }, - }, - }, - { - Name: "std, stdev, stddev_pop, variance, var_pop, var_samp tests", - Dialect: "mysql", - SetUpScript: []string{ - "create table t (i int);", - "create table tt (i int, j int);", - "insert into tt values (0, 1), (0, 2), (0, 3);", - "insert into tt values (1, 123), (1, 456), (1, 789);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", - Expected: []sql.Row{ - {nil, nil, nil, nil}, - }, - }, - { - Query: "select variance(i), var_pop(i), var_samp(i) from t;", - Expected: []sql.Row{ - {nil, nil, nil}, + {nil, nil, nil}, }, }, { @@ -8471,219 +7877,2368 @@ where }, }, { - Query: "select i, ntile(7) over() from t;", + Query: "select i, ntile(7) over() from t;", + Expected: []sql.Row{ + {1, uint64(1)}, + {2, uint64(1)}, + {3, uint64(2)}, + {4, uint64(2)}, + {5, uint64(3)}, + {6, uint64(3)}, + {7, uint64(4)}, + {8, uint64(5)}, + {9, uint64(6)}, + {10, uint64(7)}, + }, + }, + { + Query: "select i, ntile(6) over() from t;", + Expected: []sql.Row{ + {1, 
uint64(1)}, + {2, uint64(1)}, + {3, uint64(2)}, + {4, uint64(2)}, + {5, uint64(3)}, + {6, uint64(3)}, + {7, uint64(4)}, + {8, uint64(4)}, + {9, uint64(5)}, + {10, uint64(6)}, + }, + }, + { + Query: "select i, ntile(5) over() from t;", + Expected: []sql.Row{ + {1, uint64(1)}, + {2, uint64(1)}, + {3, uint64(2)}, + {4, uint64(2)}, + {5, uint64(3)}, + {6, uint64(3)}, + {7, uint64(4)}, + {8, uint64(4)}, + {9, uint64(5)}, + {10, uint64(5)}, + }, + }, + { + Query: "select i, ntile(4) over() from t;", + Expected: []sql.Row{ + {1, uint64(1)}, + {2, uint64(1)}, + {3, uint64(1)}, + {4, uint64(2)}, + {5, uint64(2)}, + {6, uint64(2)}, + {7, uint64(3)}, + {8, uint64(3)}, + {9, uint64(4)}, + {10, uint64(4)}, + }, + }, + { + Query: "select i, ntile(3) over() from t;", + Expected: []sql.Row{ + {1, uint64(1)}, + {2, uint64(1)}, + {3, uint64(1)}, + {4, uint64(1)}, + {5, uint64(2)}, + {6, uint64(2)}, + {7, uint64(2)}, + {8, uint64(3)}, + {9, uint64(3)}, + {10, uint64(3)}, + }, + }, + { + Query: "select i, ntile(2) over() from t;", + Expected: []sql.Row{ + {1, uint64(1)}, + {2, uint64(1)}, + {3, uint64(1)}, + {4, uint64(1)}, + {5, uint64(1)}, + {6, uint64(2)}, + {7, uint64(2)}, + {8, uint64(2)}, + {9, uint64(2)}, + {10, uint64(2)}, + }, + }, + { + Query: "select i, ntile(1) over() from t;", + Expected: []sql.Row{ + {1, uint64(1)}, + {2, uint64(1)}, + {3, uint64(1)}, + {4, uint64(1)}, + {5, uint64(1)}, + {6, uint64(1)}, + {7, uint64(1)}, + {8, uint64(1)}, + {9, uint64(1)}, + {10, uint64(1)}, + }, + }, + { + Query: "select i, j, ntile(2) over(partition by j) from t;", + Expected: []sql.Row{ + {1, 1, uint64(1)}, + {2, 1, uint64(1)}, + {3, 1, uint64(1)}, + {4, 1, uint64(2)}, + {5, 1, uint64(2)}, + {6, 2, uint64(1)}, + {7, 2, uint64(1)}, + {8, 2, uint64(1)}, + {9, 2, uint64(2)}, + {10, 2, uint64(2)}, + }, + }, + }, + }, + { + Name: "bit default value", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (i int primary key, b bit(2) default 2);", + "insert into t(i) values (1);", 
+ "create table tt (b bit(2) default 2 primary key);", + "insert into tt values ();", + }, + Assertions: []ScriptTestAssertion{ + { + Skip: true, // this fails on server engine, even when skipped + Query: "select * from t;", + Expected: []sql.Row{ + {1, uint8(2)}, + }, + }, + { + Skip: true, // this fails on server engine, even when skipped + Query: "select * from tt;", + Expected: []sql.Row{ + {uint8(2)}, + }, + }, + }, + }, + { + Name: "hash tuples", + Dialect: "mysql", + SetUpScript: []string{ + "CREATE TABLE test (id longtext);", + "INSERT INTO test (id) VALUES ('test_id');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM test WHERE id IN ('test_id');", + Expected: []sql.Row{ + {"test_id"}, + }, + }, + }, + }, + { + // This is a script test here because every table in the harness setup data is in all lowercase + Name: "case insensitive update with insubqueries and update joins", + Dialect: "mysql", + SetUpScript: []string{ + "create table MiXeDcAsE (i int primary key, j int)", + "insert into mixedcase values (1, 1);", + "insert into mixedcase values (2, 2);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "update mixedcase set j = 999 where i in (select 1)", + Expected: []sql.Row{ + {types.OkResult{ + RowsAffected: 1, + Info: plan.UpdateInfo{ + Matched: 1, + Updated: 1, + }, + }}, + }, + }, + { + Query: "select * from mixedcase;", + Expected: []sql.Row{ + {1, 999}, + {2, 2}, + }, + }, + { + Query: " with cte(x) as (select 2) update mixedcase set j = 999 where i in (select x from cte)", + Expected: []sql.Row{ + {types.OkResult{ + RowsAffected: 1, + Info: plan.UpdateInfo{ + Matched: 1, + Updated: 1, + }, + }}, + }, + }, + { + Query: "select * from mixedcase;", + Expected: []sql.Row{ + {1, 999}, + {2, 999}, + }, + }, + }, + }, + { + Name: "substring function tests with wrappers", + Dialect: "mysql", + SetUpScript: []string{ + "create table tbl (t text);", + "insert into tbl values ('abcdef');", + }, + Assertions: 
[]ScriptTestAssertion{ + { + Query: "select left(t, 3) from tbl;", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "select right(t, 3) from tbl;", + Expected: []sql.Row{ + {"def"}, + }, + }, + { + Query: "select instr(t, 'bcd') from tbl;", + Expected: []sql.Row{ + {2}, + }, + }, + }, + }, + { + Name: "tinyint column does not restrict IF or IFNULL output", + // https://github.com/dolthub/dolt/issues/9321 + SetUpScript: []string{ + "create table t0 (c0 tinyint);", + "insert into t0 values (null);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select ifnull(t0.c0, 128) as ref0 from t0", + Expected: []sql.Row{ + {128}, + }, + }, + { + Query: "select if(t0.c0 = 1, t0.c0, 128) as ref0 from t0", + Expected: []sql.Row{{128}}, + }, + }, + }, + { + Name: "subquery with case insensitive collation", + Dialect: "mysql", + SetUpScript: []string{ + "create table tbl (t text) collate=utf8mb4_0900_ai_ci;", + "insert into tbl values ('abcdef');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select 'AbCdEf' in (select t from tbl);", + Expected: []sql.Row{ + {true}, + }, + }, + }, + }, + + // Char tests + { + Skip: true, + Name: "char with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (c char primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'c'", + }, + }, + }, + + // Varchar tests + { + Skip: true, + Name: "varchar with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (vc char(100) primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'vc'", // We throw the wrong error + }, + }, + }, + + // Binary tests + { + Skip: true, + Name: "binary with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (b binary(100) primary key auto_increment);", 
+ ExpectedErrStr: "Incorrect column specifier for column 'b'", + }, + }, + }, + + // Varbinary tests + { + Skip: true, + Name: "varbinary with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (vb varbinary(100) primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'vb'", + }, + }, + }, + + // Blob tests + { + Skip: true, + Name: "blob with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (b blob primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'b'", + }, + { + Query: "create table bad (tb tinyblob primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'b'", + }, + { + Query: "create table bad (mb mediumblob primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'b'", + }, + { + Query: "create table bad (lb longblob primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'b'", + }, + }, + }, + + // Text Tests + { + Skip: true, + Name: "text with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (t text primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 't'", // We throw the wrong error + }, + { + Query: "create table bad (tt tinytext primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'tt'", // We throw the wrong error + }, + { + Query: "create table bad (mt mediumtext primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'mt'", // We throw the wrong error + }, + { + Query: "create table bad (lt longtext primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'lt'", // We throw the wrong error + }, + }, + }, + + // 
Enum tests + { + Name: "enum errors", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (i int primary key, e enum('abc', 'def', 'ghi'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1, 500)", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + { + Query: "insert into t values (1, -1)", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + }, + }, + { + Name: "enums with default, case-sensitive collation (utf8mb4_0900_bin)", + Dialect: "mysql", + SetUpScript: []string{ + "CREATE TABLE enumtest1 (pk int primary key, e enum('abc', 'XYZ'));", + "CREATE TABLE enumtest2 (pk int PRIMARY KEY, e enum('x ', 'X ', 'y', 'Y'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO enumtest1 VALUES (1, 'abc'), (2, 'abc'), (3, 'XYZ');", + Expected: []sql.Row{{types.NewOkResult(3)}}, + }, + { + Query: "SELECT * FROM enumtest1;", + Expected: []sql.Row{{1, "abc"}, {2, "abc"}, {3, "XYZ"}}, + }, + { + // enum values must match EXACTLY for case-sensitive collations + Query: "INSERT INTO enumtest1 VALUES (10, 'ABC'), (11, 'aBc'), (12, 'xyz');", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + { + Query: "SHOW CREATE TABLE enumtest1;", + Expected: []sql.Row{{ + "enumtest1", + "CREATE TABLE `enumtest1` (\n `pk` int NOT NULL,\n `e` enum('abc','XYZ'),\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + // Trailing whitespace should be removed from enum values, except when using the "binary" charset and collation + Query: "SHOW CREATE TABLE enumtest2;", + Expected: []sql.Row{{ + "enumtest2", + "CREATE TABLE `enumtest2` (\n `pk` int NOT NULL,\n `e` enum('x','X','y','Y'),\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "DESCRIBE enumtest1;", + Expected: []sql.Row{ + {"pk", "int", "NO", "PRI", nil, ""}, + {"e", "enum('abc','XYZ')", "YES", "", nil, ""}}, + }, + { + Query: 
"DESCRIBE enumtest2;", + Expected: []sql.Row{ + {"pk", "int", "NO", "PRI", nil, ""}, + {"e", "enum('x','X','y','Y')", "YES", "", nil, ""}}, + }, + { + Query: "select data_type, column_type from information_schema.columns where table_name='enumtest1' and column_name='e';", + Expected: []sql.Row{{"enum", "enum('abc','XYZ')"}}, + }, + { + Query: "select data_type, column_type from information_schema.columns where table_name='enumtest2' and column_name='e';", + Expected: []sql.Row{{"enum", "enum('x','X','y','Y')"}}, + }, + }, + }, + { + Name: "enum columns work as expected in when clauses", + Dialect: "mysql", + SetUpScript: []string{ + "create table enums (e enum('a'));", + "insert into enums values ('a');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select (case e when 'a' then 42 end) from enums", + Expected: []sql.Row{{42}}, + }, + { + Query: "select (case 'a' when e then 42 end) from enums", + Expected: []sql.Row{{42}}, + }, + }, + }, + { + Name: "enums with case-insensitive collation (utf8mb4_0900_ai_ci)", + Dialect: "mysql", + SetUpScript: []string{ + "CREATE TABLE enumtest1 (pk int primary key, e enum('abc', 'XYZ') collate utf8mb4_0900_ai_ci);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO enumtest1 VALUES (1, 'abc'), (2, 'abc'), (3, 'XYZ');", + Expected: []sql.Row{{types.NewOkResult(3)}}, + }, + { + Query: "SHOW CREATE TABLE enumtest1;", + Expected: []sql.Row{{ + "enumtest1", + "CREATE TABLE `enumtest1` (\n `pk` int NOT NULL,\n `e` enum('abc','XYZ') COLLATE utf8mb4_0900_ai_ci,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "DESCRIBE enumtest1;", + Expected: []sql.Row{ + {"pk", "int", "NO", "PRI", nil, ""}, + {"e", "enum('abc','XYZ') COLLATE utf8mb4_0900_ai_ci", "YES", "", nil, ""}}, + }, + { + Query: "select data_type, column_type from information_schema.columns where table_name='enumtest1' and column_name='e';", + Expected: []sql.Row{{"enum", 
"enum('abc','XYZ')"}}, + }, + { + Query: "CREATE TABLE enumtest2 (pk int PRIMARY KEY, e enum('x ', 'X ', 'y', 'Y'));", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "INSERT INTO enumtest1 VALUES (10, 'ABC'), (11, 'aBc'), (12, 'xyz');", + Expected: []sql.Row{{types.NewOkResult(3)}}, + }, + }, + }, + { + Name: "special case for not null default enum", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (i int primary key, e enum('abc', 'def', 'ghi') not null);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "show create table t;", + Expected: []sql.Row{ + {"t", "CREATE TABLE `t` (\n" + + " `i` int NOT NULL,\n" + + " `e` enum('abc','def','ghi') NOT NULL,\n" + + " PRIMARY KEY (`i`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "insert into t(i) values (1)", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into t values (2, null)", + ExpectedErr: sql.ErrInsertIntoNonNullableProvidedNull, + }, + { + Skip: true, + Query: "insert into t values (2, default)", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1, "abc"}, + }, + }, + }, + }, + { + Name: "ensure that special case does not apply for nullable enums", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (i int primary key, e enum('abc', 'def', 'ghi'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t(i) values (1)", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1, nil}, + }, + }, + }, + }, + { + Name: "enums with default values", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (e enum('a') primary key default null);", + ExpectedErr: sql.ErrIncompatibleDefaultType, + }, + { + Query: "create table bad (e enum('a') default 0);", + ExpectedErr: 
sql.ErrInvalidColumnDefaultValue, + }, + { + Query: "create table bad (e enum('a') default '');", + ExpectedErr: sql.ErrIncompatibleDefaultType, + }, + { + Query: "create table bad (e enum('a') default '1');", + ExpectedErr: sql.ErrInvalidColumnDefaultValue, + }, + { + Query: "create table bad (e enum('a') default 1);", + ExpectedErr: sql.ErrInvalidColumnDefaultValue, + }, + + { + Query: "create table t1 (e enum('a') default 'a');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + // TODO: while this is round-trippable, it doesn't match MySQL + Skip: true, + Query: "show create table t1;", + Expected: []sql.Row{ + {"t1", "CREATE TABLE `t1` (\n" + + " `e` enum('a') DEFAULT 'a'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "insert into t1 values (default);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into t1 values ();", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into t1() values ();", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t1 order by e;", + Expected: []sql.Row{ + {"a"}, + {"a"}, + {"a"}, + }, + }, + { + Query: "insert into t1 values (null)", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t1 order by e;", + Expected: []sql.Row{ + {nil}, + {"a"}, + {"a"}, + {"a"}, + }, + }, + + { + Query: "create table t2 (e enum('a') default (1));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "show create table t2;", + Expected: []sql.Row{ + {"t2", "CREATE TABLE `t2` (\n" + + " `e` enum('a') DEFAULT (1)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "insert into t2 values (default);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into t2 values ();", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into t2() 
values ();", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t2 order by e;", + Expected: []sql.Row{ + {"a"}, + {"a"}, + {"a"}, + }, + }, + { + Query: "insert into t2 values (null)", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t2 order by e;", + Expected: []sql.Row{ + {nil}, + {"a"}, + {"a"}, + {"a"}, + }, + }, + + { + Query: "create table t3 (e enum('a') default ('1'));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + // TODO: we don't print the collation before the string + Skip: true, + Query: "show create table t3;", + Expected: []sql.Row{ + {"t3", "CREATE TABLE `t3` (\n" + + " `e` enum('a') DEFAULT (_utf8mb4'1')\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "insert into t3 values (default);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into t3 values ();", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into t3() values ();", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t3 order by e;", + Expected: []sql.Row{ + {"a"}, + {"a"}, + {"a"}, + }, + }, + { + Query: "insert into t3 values (null)", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t3 order by e;", + Expected: []sql.Row{ + {nil}, + {"a"}, + {"a"}, + {"a"}, + }, + }, + }, + }, + { + Name: "enums with auto increment", + Dialect: "mysql", + SetUpScript: []string{ + "CREATE TABLE t (e enum('a', 'b', 'c') PRIMARY KEY)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CREATE TABLE t2 (e enum('a', 'b', 'c') PRIMARY KEY AUTO_INCREMENT)", + ExpectedErrStr: "Incorrect column specifier for column 'e'", + }, + { + Query: "ALTER TABLE t MODIFY e enum('a', 'b', 'c') AUTO_INCREMENT", + ExpectedErrStr: "Incorrect column specifier for column 'e'", + }, + { + Query: "ALTER TABLE t MODIFY COLUMN e enum('a', 'b', 'c') 
AUTO_INCREMENT", + ExpectedErrStr: "Incorrect column specifier for column 'e'", + }, + { + Query: "ALTER TABLE t CHANGE e e enum('a', 'b', 'c') AUTO_INCREMENT", + ExpectedErrStr: "Incorrect column specifier for column 'e'", + }, + { + Query: "ALTER TABLE t CHANGE COLUMN e e enum('a', 'b', 'c') AUTO_INCREMENT", + ExpectedErrStr: "Incorrect column specifier for column 'e'", + }, + }, + }, + { + // This is with STRICT_TRANS_TABLES or STRICT_ALL_TABLES in sql_mode + Name: "enums with zero", + Dialect: "mysql", + SetUpScript: []string{ + "SET sql_mode = 'STRICT_TRANS_TABLES';", + "create table t (e enum('a', 'b', 'c'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (0);", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + { + Query: "insert into t values ('a'), (0), ('b');", + ExpectedErrStr: "Data truncated for column 'e' at row 2", + }, + { + Query: "create table tt (e enum('a', 'b', 'c') default 0)", + ExpectedErr: sql.ErrInvalidColumnDefaultValue, + }, + { + Query: "create table et (e enum('a', 'b', '', 'c'));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into et values (0);", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + }, + }, + { + Name: "enums with zero strict all tables", + Dialect: "mysql", + SetUpScript: []string{ + "SET sql_mode = 'STRICT_ALL_TABLES';", + "create table t (e enum('a', 'b', 'c'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (0);", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + { + Query: "insert into t values ('a'), (0), ('b');", + ExpectedErrStr: "Data truncated for column 'e' at row 2", + }, + { + Query: "create table tt (e enum('a', 'b', 'c') default 0)", + ExpectedErr: sql.ErrInvalidColumnDefaultValue, + }, + }, + }, + { + Name: "enums with zero non-strict mode", + Dialect: "mysql", + SetUpScript: []string{ + "SET sql_mode = '';", + "create table t (e enum('a', 'b', 'c'));", + }, 
+ Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (0);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {""}, + }, + }, + }, + }, + { + Name: "enum import error message validation", + Dialect: "mysql", + SetUpScript: []string{ + "SET sql_mode = 'STRICT_TRANS_TABLES';", + "CREATE TABLE shirts (name VARCHAR(40), size ENUM('x-small', 'small', 'medium', 'large', 'x-large'), color ENUM('red', 'blue'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO shirts VALUES ('shirt1', 'x-small', 'red');", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "INSERT INTO shirts VALUES ('shirt2', 'other', 'green');", + ExpectedErrStr: "Data truncated for column 'size' at row 1", + }, + }, + }, + { + Name: "enum default null validation", + Dialect: "mysql", + SetUpScript: []string{ + "SET sql_mode = 'STRICT_TRANS_TABLES';", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CREATE TABLE test_enum (pk int NOT NULL, e enum('a','b') DEFAULT NULL, PRIMARY KEY (pk));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "INSERT INTO test_enum (pk) VALUES (1);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "SELECT pk, e FROM test_enum;", + Expected: []sql.Row{ + {1, nil}, + }, + }, + }, + }, + { + // This is with STRICT_TRANS_TABLES or STRICT_ALL_TABLES in sql_mode + Skip: true, // TODO: Fix error type to match MySQL exactly (should be ErrInvalidColumnDefaultValue) + Name: "enums with empty string", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (e enum('a', 'b', 'c'));", + "create table et (e enum('a', 'b', '', 'c'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values ('');", + ExpectedErrStr: "Data truncated for column 'e'", // TODO should be truncated error + }, + { + Query: "create table tt (e enum('a', 'b', 'c') default '')", + ExpectedErr: 
sql.ErrInvalidColumnDefaultValue, + }, + { + Query: "insert into et values (1), (2), (3), (4), ('');", + Expected: []sql.Row{ + {types.NewOkResult(5)}, + }, + }, + { + Query: "select e, cast(e as signed) from et order by e;", + Expected: []sql.Row{ + {"a", 1}, + {"b", 2}, + {"", 3}, + {"", 3}, + {"c", 4}, + }, + }, + }, + }, + { + Name: "enum conversion to strings", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (e enum('abc', 'defg', 'hijkl'));", + "insert into t values(1), (2), (3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select e, length(e) from t order by e;", + Expected: []sql.Row{ + {"abc", 3}, + {"defg", 4}, + {"hijkl", 5}, + }, + }, + { + Query: "select e, concat(e, 'test') from t order by e;", + Expected: []sql.Row{ + {"abc", "abctest"}, + {"defg", "defgtest"}, + {"hijkl", "hijkltest"}, + }, + }, + { + Query: "select e, e like 'a%', e like '%g' from t order by e;", + Expected: []sql.Row{ + {"abc", true, false}, + {"defg", false, true}, + {"hijkl", false, false}, + }, + }, + { + Skip: true, + Query: "select e from t where e like 'a%' order by e;", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "select group_concat(e order by e) as grouped from t;", + Expected: []sql.Row{ + {"abc,defg,hijkl"}, + }, + }, + { + Query: "select e from t where e = 'abc';", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "select count(*) from t where e = 'defg';", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "select (case e when 'abc' then 42 end) from t order by e;", + Expected: []sql.Row{ + {42}, + {nil}, + {nil}, + }, + }, + { + Query: "select case when e = 'abc' then 'abc' when e = 'defg' then 123 else e end from t order by e;", + Expected: []sql.Row{ + {"abc"}, + {"123"}, + {"hijkl"}, + }, + }, + { + Query: "select (case 'abc' when e then 42 end) from t order by e;", + Expected: []sql.Row{ + {42}, + {nil}, + {nil}, + }, + }, + { + Query: "select (case e when 'abc' then e when 'defg' then e when 'hijkl' then e end) 
as e from t order by e;", + Expected: []sql.Row{ + {"abc"}, + {"defg"}, + {"hijkl"}, + }, + }, + { + // https://github.com/dolthub/dolt/issues/8598 + Skip: true, + Query: "select (case e when 'abc' then e when 'defg' then e when 'hijkl' then 'something' end) as e from t order by e;", + Expected: []sql.Row{ + {"abc"}, + {"defg"}, + {"something"}, + }, + }, + { + // https://github.com/dolthub/dolt/issues/8598 + Skip: true, + Query: "select (case e when 'abc' then e when 'defg' then e when 'hijkl' then 123 end) as e from t order by e;", + Expected: []sql.Row{ + {"123"}, + {"abc"}, + {"def"}, + }, + }, + { + Query: "select e, cast(e as signed) from t order by e;", + Expected: []sql.Row{ + {"abc", 1}, + {"defg", 2}, + {"hijkl", 3}, + }, + }, + { + Query: "select e, cast(e as char) from t order by e;", + Expected: []sql.Row{ + {"abc", "abc"}, + {"defg", "defg"}, + {"hijkl", "hijkl"}, + }, + }, + { + Query: "select e, cast(e as binary) from t order by e;", + Expected: []sql.Row{ + {"abc", []uint8("abc")}, + {"defg", []uint8("defg")}, + {"hijkl", []uint8("hijkl")}, + }, + }, + }, + }, + { + Name: "enum conversion with system variables", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (e enum('ON', 'OFF', 'AUTO'));", + "set autocommit = 'ON';", + "insert into t values(@@autocommit), ('OFF'), ('AUTO');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select e, @@autocommit, e = @@autocommit from t order by e;", + Expected: []sql.Row{ + {"ON", 1, true}, + {"OFF", 1, false}, + {"AUTO", 1, false}, + }, + }, + { + Query: "select e, concat(e, @@version_comment) from t order by e;", + Expected: []sql.Row{ + {"ON", "ONDolt"}, + {"OFF", "OFFDolt"}, + {"AUTO", "AUTODolt"}, + }, + }, + }, + }, + { + Name: "enums with foreign keys", + Dialect: "mysql", + SetUpScript: []string{ + "create table parent (e enum('a', 'b', 'c') primary key);", + "insert into parent values (1), (2);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "create table child0 (e 
enum('a', 'b', 'c'), foreign key (e) references parent (e));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into child0 values (1), (2), (NULL);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from child0 order by e", + Expected: []sql.Row{ + {nil}, + {"a"}, + {"b"}, + }, + }, + + { + Query: "create table child1 (e enum('x', 'y', 'z'), foreign key (e) references parent (e));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into child1 values (1), (2);", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "insert into child1 values (3);", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child1 values ('x'), ('y');", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "insert into child1 values ('z');", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child1 values ('a');", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + { + Query: "select * from child1 order by e;", + Expected: []sql.Row{ + {"x"}, + {"x"}, + {"y"}, + {"y"}, + }, + }, + + { + Query: "create table child2 (e enum('b', 'c', 'a'), foreign key (e) references parent (e));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into child2 values (1), (2);", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "insert into child2 values (3);", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child2 values ('c');", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into child2 values ('a');", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "select * from child2 order by e;", + Expected: []sql.Row{ + {"b"}, + {"c"}, + {"c"}, + }, + }, + + { + Query: "create table child3 (e enum('x', 'y', 'z', 'a', 'b', 'c'), foreign key (e) references parent (e));", + Expected: 
[]sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into child3 values (1), (2);", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "insert into child3 values (3);", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child3 values ('x'), ('y');", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "insert into child3 values ('z');", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child3 values ('a');", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "select * from child3 order by e;", + Expected: []sql.Row{ + {"x"}, + {"x"}, + {"y"}, + {"y"}, + }, + }, + + { + Query: "create table child4 (e enum('q'), foreign key (e) references parent (e));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into child4 values (1);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into child4 values (3);", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + { + Query: "insert into child4 values ('q');", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into child4 values ('a');", + ExpectedErrStr: "Data truncated for column 'e' at row 1", + }, + { + Query: "select * from child4 order by e;", + Expected: []sql.Row{ + {"q"}, + {"q"}, + }, + }, + }, + }, + { + Skip: true, + Name: "enums with foreign keys and cascade", + Dialect: "mysql", + SetUpScript: []string{ + "create table parent (e enum('a', 'b', 'c') primary key);", + "insert into parent values (1), (2);", + "create table child (e enum('x', 'y', 'z'), foreign key (e) references parent (e) on update cascade on delete cascade);", + "insert into child values (1), (2);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "update parent set e = 'c' where e = 'a';", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + 
Query: "select * from child order by e;", + Expected: []sql.Row{ + {"y"}, + {"z"}, + }, + }, + { + Query: "delete from parent where e = 'b';", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from child order by e;", + Expected: []sql.Row{ + {"z"}, + }, + }, + }, + }, + { + Name: "enums in update and delete statements", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (pk int primary key, e enum('abc', 'def', 'ghi'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1, 1), (2, 3), (3, 2);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "update t set e = 2 where e = 'ghi';", + Expected: []sql.Row{ + {types.OkResult{ + RowsAffected: 1, + Info: plan.UpdateInfo{ + Matched: 1, + Updated: 1, + Warnings: 0, + }, + }}, + }, + }, + { + Query: "update t set e = 'ghi' where e = '3';", + Expected: []sql.Row{ + {types.OkResult{ + RowsAffected: 0, + Info: plan.UpdateInfo{ + Matched: 0, + Updated: 0, + Warnings: 0, + }, + }}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1, "abc"}, + {2, "def"}, + {3, "def"}, + }, + }, + { + Query: "delete from t where e = 2;", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "select * from t", + Expected: []sql.Row{ + {1, "abc"}, + }, + }, + }, + }, + { + // https://github.com/dolthub/dolt/issues/9024 + Name: "subqueries should coerce union types to enum", + Dialect: "mysql", + SetUpScript: []string{ + "create table enum_table (i int primary key, e enum('a','b') not null)", + "insert into enum_table values (1,'a'),(2,'b')", + "create table uv (u int primary key, v varchar(10))", + "insert into uv values (0, 'bug'),(1,'ant'),(3, null)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from (select e from enum_table union select v from uv) sq", + Expected: []sql.Row{{"a"}, {"b"}, {"bug"}, {"ant"}, {nil}}, + }, + { + Query: "with a as (select e from enum_table union select v from uv) 
select * from a", + Expected: []sql.Row{{"a"}, {"b"}, {"bug"}, {"ant"}, {nil}}, + }, + }, + }, + + // Set tests + { + Name: "find_in_set tests", + Dialect: "mysql", + SetUpScript: []string{ + "create table set_tbl (i int primary key, s set('a','b','c'));", + "insert into set_tbl values (0, '');", + "insert into set_tbl values (1, 'a');", + "insert into set_tbl values (2, 'b');", + "insert into set_tbl values (3, 'c');", + "insert into set_tbl values (4, 'a,b');", + "insert into set_tbl values (6, 'b,c');", + "insert into set_tbl values (7, 'a,c');", + "insert into set_tbl values (8, 'a,b,c');", + + "create table collate_tbl (i int primary key, s varchar(10) collate utf8mb4_0900_ai_ci);", + "insert into collate_tbl values (0, '');", + "insert into collate_tbl values (1, 'a');", + "insert into collate_tbl values (2, 'b');", + "insert into collate_tbl values (3, 'c');", + "insert into collate_tbl values (4, 'a,b');", + "insert into collate_tbl values (6, 'b,c');", + "insert into collate_tbl values (7, 'a,c');", + "insert into collate_tbl values (8, 'a,b,c');", + + "create table text_tbl (i int primary key, s text);", + "insert into text_tbl values (0, '');", + "insert into text_tbl values (1, 'a');", + "insert into text_tbl values (2, 'b');", + "insert into text_tbl values (3, 'c');", + "insert into text_tbl values (4, 'a,b');", + "insert into text_tbl values (6, 'b,c');", + "insert into text_tbl values (7, 'a,c');", + "insert into text_tbl values (8, 'a,b,c');", + + "create table enum_tbl (i int primary key, s enum('a','b','c'));", + "insert into enum_tbl values (0, 'a'), (1, 'b'), (2, 'c');", + "select i, s, find_in_set('a', s) from enum_tbl;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select i, find_in_set('a', s) from set_tbl;", + Expected: []sql.Row{ + {0, 0}, + {1, 1}, + {2, 0}, + {3, 0}, + {4, 1}, + {6, 0}, + {7, 1}, + {8, 1}, + }, + }, + { + Query: "select i, find_in_set('A', s) from collate_tbl;", + Expected: []sql.Row{ + {0, 0}, + {1, 1}, + {2, 
0}, + {3, 0}, + {4, 1}, + {6, 0}, + {7, 1}, + {8, 1}, + }, + }, + { + Query: "select i, find_in_set('a', s) from text_tbl;", + Expected: []sql.Row{ + {0, 0}, + {1, 1}, + {2, 0}, + {3, 0}, + {4, 1}, + {6, 0}, + {7, 1}, + {8, 1}, + }, + }, + { + Query: "select i, find_in_set('a', s) from enum_tbl;", + Expected: []sql.Row{ + {0, 1}, + {1, 0}, + {2, 0}, + }, + }, + }, + }, + { + Name: "set with empty string", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (i int primary key, s set(''));", + "insert into t values (0, 0), (1, 1), (2, '');", + "create table tt (i int primary key, s set('something',''));", + "insert into tt values (0, 'something,'), (1, ',something,'), (2, ',,,,,,');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select i, s + 0, s from t;", + Expected: []sql.Row{ + {0, float64(0), ""}, + {1, float64(1), ""}, + {2, float64(0), ""}, + }, + }, + { + Query: "select i, s + 0, s from t where s = 0;", + Expected: []sql.Row{ + {0, float64(0), ""}, + {2, float64(0), ""}, + }, + }, + { + Skip: true, + Query: "select i, s + 0, s from t where s = '';", + Expected: []sql.Row{ + {0, float64(0), ""}, + {1, float64(1), ""}, // We miss this one + {2, float64(0), ""}, + }, + }, + { + Skip: true, + Query: "select i, s + 0, s from tt;", + Expected: []sql.Row{ + {0, float64(0), "something,"}, + {1, float64(1), "something,"}, + {2, float64(2), ""}, + }, + }, + }, + }, + { + Skip: true, + Name: "set conversion to strings", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (s set('abc', 'defg', 'hijkl'));", + "insert into t values(1), (2), (3), (7);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select s, length(s) from t order by s;", + Expected: []sql.Row{ + {"abc", 3}, + {"defg", 4}, + {"abc,defg", 8}, + {"abc,defg,hijkl", 14}, + }, + }, + { + Query: "select s, concat(s, 'test') from t order by s;", + Expected: []sql.Row{ + {"abc", "abctest"}, + {"defg", "defgtest"}, + {"abc,defg", "abc,defgtest"}, + {"abc,defg,hijkl", 
"abc,defg,hijkltest"}, + }, + }, + { + Query: "select s, s like 'a%', s like '%g' from t order by s;", + Expected: []sql.Row{ + {"abc", true, false}, + {"defg", false, true}, + {"abc,defg", true, true}, + {"abc,defg,hijkl", true, false}, + }, + }, + { + Query: "select s from t where s like 'a%' order by s;", + Expected: []sql.Row{ + {"abc"}, + {"abc,defg"}, + {"abc,defg,hijkl"}, + }, + }, + { + Query: "select group_concat(s order by s) as grouped from t;", + Expected: []sql.Row{ + {"abc,defg,abc,defg,abc,defg,hijkl"}, + }, + }, + { + Query: "select s from t where s = 'abc';", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "select count(*) from t where s = 'defg';", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "select (case s when 'abc' then 42 end) from t order by s;", + Expected: []sql.Row{ + {42}, + {nil}, + {nil}, + {nil}, + }, + }, + { + Query: "select case when s = 'abc' then 'abc' when s = 'defg' then 123 else s end from t order by s;", + Expected: []sql.Row{ + {"abc"}, + {"123"}, + {"abc,defg"}, + {"abc,defg,hijkl"}, + }, + }, + { + Query: "select (case 'abc' when s then 42 end) from t order by s;", + Expected: []sql.Row{ + {42}, + {nil}, + {nil}, + {nil}, + }, + }, + { + Query: "select (case s when 'abc' then s when 'defg' then s when 'hijkl' then s end) as s from t order by s;", + Expected: []sql.Row{ + {nil}, + {nil}, + {"abc"}, + {"defg"}, + }, + }, + { + Query: "select (case s when 'abc' then s when 'defg' then s when 'hijkl' then 'something' end) as s from t order by s;", + Expected: []sql.Row{ + {nil}, + {nil}, + {"abc"}, + {"defg"}, + }, + }, + { + Query: "select (case s when 'abc' then s when 'defg' then s when 'hijkl' then 123 end) as s from t order by s;", + Expected: []sql.Row{ + {nil}, + {nil}, + {"abc"}, + {"defg"}, + }, + }, + { + Query: "select s, cast(s as signed) from t order by s;", + Expected: []sql.Row{ + {"abc", 1}, + {"defg", 2}, + {"abc,defg", 3}, + {"abc,defg,hijkl", 7}, + }, + }, + { + Query: "select s, cast(s as 
char) from t order by s;", + Expected: []sql.Row{ + {"abc", "abc"}, + {"abc,defg", "abc,defg"}, + {"abc,defg,hijkl", "abc,defg,hijkl"}, + }, + }, + { + Query: "select s, cast(s as binary) from t order by s;", + Expected: []sql.Row{ + {"abc", []uint8("abc")}, + {"abc,defg", []uint8("abc,defg")}, + {"abc,defg,hijkl", []uint8("abc,defg,hijkl")}, + }, + }, + }, + }, + { + Name: "set with duplicates", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (s set('a', 'b', 'c'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values ('a,b,a,c,a,b,b,b,c,c,c,a,a');", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select s + 0, s from t;", + Expected: []sql.Row{ + {float64(7), "a,b,c"}, + }, + }, + { + // This is with STRICT_TRANS_TABLES; errors are warnings when not strict + Query: "create table tt (s set('a', 'a'));", + ExpectedErr: sql.ErrDuplicateEntrySet, + }, + }, + }, + { + Name: "set in update and delete statements", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (pk int primary key, s set('abc', 'def'));", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (0, 0), (1, 1), (2, 3), (3, 2);", + Expected: []sql.Row{ + {types.NewOkResult(4)}, + }, + }, + { + Query: "update t set s = 3 where s = 2;", + Expected: []sql.Row{ + {types.OkResult{ + RowsAffected: 1, + Info: plan.UpdateInfo{ + Matched: 1, + Updated: 1, + Warnings: 0, + }, + }}, + }, + }, + { + Query: "select * from t", + Expected: []sql.Row{ + {0, ""}, + {1, "abc"}, + {2, "abc,def"}, + {3, "abc,def"}, + }, + }, + { + Query: "delete from t where s = 'abc,def'", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "select * from t", + Expected: []sql.Row{ + {0, ""}, + {1, "abc"}, + }, + }, + }, + }, + { + Skip: true, + Name: "set with auto increment", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (s set('a', 'b', 'c') primary key);", + }, + Assertions: []ScriptTestAssertion{ + { 
+ Query: "create table t2 (s set('a', 'b', 'c') primary key auto_increment)", + ExpectedErrStr: "Incorrect column specifier for column 's'", + }, + { + Query: "alter table t modify s set('a', 'b', 'c') auto_increment;", + ExpectedErrStr: "Incorrect column specifier for column 's'", + }, + { + Query: "alter table t modify column s set('a', 'b', 'c') auto_increment;", + ExpectedErrStr: "Incorrect column specifier for column 's'", + }, + { + Query: "alter table t change s s set('a', 'b', 'c') auto_increment;", + ExpectedErrStr: "Incorrect column specifier for column 's'", + }, + { + Query: "alter table t change column s s set('a', 'b', 'c') auto_increment;", + ExpectedErrStr: "Incorrect column specifier for column 's'", + }, + }, + }, + { + Name: "set with default values", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Skip: true, + Query: "create table bad (s set('a', 'b', 'c') default 0);", + ExpectedErr: sql.ErrInvalidColumnDefaultValue, + }, + { + Skip: true, + Query: "create table bad (s set('a', 'b', 'c') default 1);", + ExpectedErr: sql.ErrInvalidColumnDefaultValue, + }, + { + Skip: true, + Query: "create table bad (s set('a', 'b', 'c') default 'notexists');", + ExpectedErr: sql.ErrInvalidColumnDefaultValue, + }, + { + Query: "create table t0 (s set('a', 'b', 'c') default (0));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into t0 values ();", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t0", + Expected: []sql.Row{ + {""}, + }, + }, + + { + Query: "create table t (s set('a', 'b', 'c') not null);", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Skip: true, + Query: "insert into t values ();", + ExpectedErr: sql.ErrInsertIntoNonNullableDefaultNullColumn, // wrong error + }, + { + Skip: true, + Query: "insert into t values (default);", + ExpectedErr: sql.ErrInsertIntoNonNullableDefaultNullColumn, // wrong error + }, + }, + }, + 
{ + Name: "set with collations", + Dialect: "mysql", + SetUpScript: []string{ + "create table t1 (s set('a', 'b', 'c') collate utf8mb4_0900_ai_ci);", + "create table t2 (s set('a', 'b', 'c') collate utf8mb4_0900_bin);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "show create table t1;", + Expected: []sql.Row{ + {"t1", "CREATE TABLE `t1` (\n" + + " `s` set('a','b','c') COLLATE utf8mb4_0900_ai_ci\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "insert into t1 values ('A,B,c');", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t1", + Expected: []sql.Row{ + {"a,b,c"}, + }, + }, + { + Query: "show create table t2;", + Expected: []sql.Row{ + {"t2", "CREATE TABLE `t2` (\n" + + " `s` set('a','b','c')\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "insert into t2 values ('A,B,c');", + ExpectedErr: sql.ErrInvalidSetValue, + }, + { + Query: "select * from t2", + Expected: []sql.Row{}, + }, + { + Query: "create table bad (s set('a', 'A') collate utf8mb4_0900_ai_ci);", + ExpectedErr: sql.ErrDuplicateEntrySet, + }, + }, + }, + { + Skip: true, + Name: "set with foreign keys", + Dialect: "mysql", + SetUpScript: []string{ + "create table parent (s set('a', 'b', 'c') primary key);", + "insert into parent values (1), (2);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "create table child0 (s set('a', 'b', 'c'), foreign key (s) references parent (s));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into child0 values (1), (2), (NULL);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from child0 order by s;", + Expected: []sql.Row{ + {nil}, + {"a"}, + {"b"}, + }, + }, + + { + Query: "create table child1 (s set('x', 'y', 'z'), foreign key (s) references parent (s));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into child1 
values (1), (2);", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "insert into child1 values (3);", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child1 values ('x'), ('y');", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "insert into child1 values ('z');", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child1 values ('a');", + ExpectedErrStr: "Data truncated for column 's' at row 1", + }, + { + Query: "select * from child1 order by s;", + Expected: []sql.Row{ + {"x"}, + {"x"}, + {"y"}, + {"y"}, + }, + }, + + { + Query: "create table child2 (s set('b', 'c', 'a'), foreign key (s) references parent (s));", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "insert into child2 values (1), (2);", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "insert into child2 values (3);", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child2 values ('c');", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "insert into child2 values ('a');", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "select * from child2 order by s;", + Expected: []sql.Row{ + {"b"}, + {"c"}, + {"c"}, + }, + }, + + { + Query: "create table child3 (s set('x', 'y', 'z', 'a', 'b', 'c'), foreign key (s) references parent (s));", Expected: []sql.Row{ - {1, uint64(1)}, - {2, uint64(1)}, - {3, uint64(2)}, - {4, uint64(2)}, - {5, uint64(3)}, - {6, uint64(3)}, - {7, uint64(4)}, - {8, uint64(5)}, - {9, uint64(6)}, - {10, uint64(7)}, + {types.NewOkResult(0)}, }, }, { - Query: "select i, ntile(6) over() from t;", + Query: "insert into child3 values (1), (2);", Expected: []sql.Row{ - {1, uint64(1)}, - {2, uint64(1)}, - {3, uint64(2)}, - {4, uint64(2)}, - {5, uint64(3)}, - {6, uint64(3)}, - {7, uint64(4)}, - {8, uint64(4)}, - {9, uint64(5)}, - {10, uint64(6)}, + {types.NewOkResult(2)}, }, }, { 
- Query: "select i, ntile(5) over() from t;", + Query: "insert into child3 values (3);", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child3 values ('x'), ('y');", Expected: []sql.Row{ - {1, uint64(1)}, - {2, uint64(1)}, - {3, uint64(2)}, - {4, uint64(2)}, - {5, uint64(3)}, - {6, uint64(3)}, - {7, uint64(4)}, - {8, uint64(4)}, - {9, uint64(5)}, - {10, uint64(5)}, + {types.NewOkResult(2)}, }, }, { - Query: "select i, ntile(4) over() from t;", + Query: "insert into child3 values ('z');", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into child3 values ('a');", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "select * from child3 order by s;", Expected: []sql.Row{ - {1, uint64(1)}, - {2, uint64(1)}, - {3, uint64(1)}, - {4, uint64(2)}, - {5, uint64(2)}, - {6, uint64(2)}, - {7, uint64(3)}, - {8, uint64(3)}, - {9, uint64(4)}, - {10, uint64(4)}, + {"x"}, + {"x"}, + {"y"}, + {"y"}, }, }, + { - Query: "select i, ntile(3) over() from t;", + Query: "create table child4 (s set('q'), foreign key (s) references parent (s));", Expected: []sql.Row{ - {1, uint64(1)}, - {2, uint64(1)}, - {3, uint64(1)}, - {4, uint64(1)}, - {5, uint64(2)}, - {6, uint64(2)}, - {7, uint64(2)}, - {8, uint64(3)}, - {9, uint64(3)}, - {10, uint64(3)}, + {types.NewOkResult(0)}, }, }, { - Query: "select i, ntile(2) over() from t;", + Query: "insert into child4 values (1);", Expected: []sql.Row{ - {1, uint64(1)}, - {2, uint64(1)}, - {3, uint64(1)}, - {4, uint64(1)}, - {5, uint64(1)}, - {6, uint64(2)}, - {7, uint64(2)}, - {8, uint64(2)}, - {9, uint64(2)}, - {10, uint64(2)}, + {types.NewOkResult(1)}, }, }, { - Query: "select i, ntile(1) over() from t;", + Query: "insert into child4 values (3);", + ExpectedErrStr: "Data truncated for column 's' at row 1", + }, + { + Query: "insert into child4 values ('q');", Expected: []sql.Row{ - {1, uint64(1)}, - {2, uint64(1)}, - {3, uint64(1)}, - {4, uint64(1)}, - {5, uint64(1)}, - {6, 
uint64(1)}, - {7, uint64(1)}, - {8, uint64(1)}, - {9, uint64(1)}, - {10, uint64(1)}, + {types.NewOkResult(1)}, }, }, { - Query: "select i, j, ntile(2) over(partition by j) from t;", + Query: "insert into child4 values ('a');", + ExpectedErrStr: "Data truncated for column 's' at row 1", + }, + { + Query: "select * from child4 order by s;", Expected: []sql.Row{ - {1, 1, uint64(1)}, - {2, 1, uint64(1)}, - {3, 1, uint64(1)}, - {4, 1, uint64(2)}, - {5, 1, uint64(2)}, - {6, 2, uint64(1)}, - {7, 2, uint64(1)}, - {8, 2, uint64(1)}, - {9, 2, uint64(2)}, - {10, 2, uint64(2)}, + {"q"}, + {"q"}, }, }, }, }, { - Name: "bit default value", + Skip: true, + Name: "set with foreign keys and cascade", Dialect: "mysql", SetUpScript: []string{ - "create table t (i int primary key, b bit(2) default 2);", - "insert into t(i) values (1);", - "create table tt (b bit(2) default 2 primary key);", - "insert into tt values ();", + "create table parent (s set('a', 'b', 'c') primary key);", + "insert into parent values (1), (2);", + "create table child (s set('x', 'y', 'z'), foreign key (s) references parent (s) on update cascade on delete cascade);", + "insert into child values (1), (2);", }, Assertions: []ScriptTestAssertion{ { - Skip: true, // this fails on server engine, even when skipped - Query: "select * from t;", + Query: "update parent set s = 'c' where s = 'a';", Expected: []sql.Row{ - {1, uint8(2)}, + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, }, }, { - Skip: true, // this fails on server engine, even when skipped - Query: "select * from tt;", + Query: "select * from child order by s;", Expected: []sql.Row{ - {uint8(2)}, + {"y"}, + {"z"}, + }, + }, + { + Query: "delete from parent where s = 'b';", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from child order by s;", + Expected: []sql.Row{ + {"z"}, }, }, }, }, + + // Bit Tests { - Name: "hash tuples", + Skip: true, + Name: "bit with auto_increment", + Dialect: 
"mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (b bit(1) primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'b'", + }, + { + Query: "create table bad (b bit(64) primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'b'", + }, + }, + }, + + // Bool Tests + { + Name: "bool with auto_increment", Dialect: "mysql", SetUpScript: []string{ - "CREATE TABLE test (id longtext);", - "INSERT INTO test (id) VALUES ('test_id');", + "create table bool_tbl (b bool primary key auto_increment);", }, Assertions: []ScriptTestAssertion{ { - Query: "SELECT * FROM test WHERE id IN ('test_id');", + Query: "show create table bool_tbl;", Expected: []sql.Row{ - {"test_id"}, + {"bool_tbl", "CREATE TABLE `bool_tbl` (\n" + + " `b` tinyint(1) NOT NULL AUTO_INCREMENT,\n" + + " PRIMARY KEY (`b`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, }, }, + + // Int Tests { - // This is a script test here because every table in the harness setup data is in all lowercase - Name: "case insensitive update with insubqueries and update joins", + Name: "int with auto_increment", Dialect: "mysql", SetUpScript: []string{ - "create table MiXeDcAsE (i int primary key, j int)", - "insert into mixedcase values (1, 1);", - "insert into mixedcase values (2, 2);", + "create table int_tbl (i int primary key auto_increment);", + "create table tinyint_tbl (i tinyint primary key auto_increment);", + "create table smallint_tbl (i smallint primary key auto_increment);", + "create table mediumint_tbl (i mediumint primary key auto_increment);", + "create table bigint_tbl (i bigint primary key auto_increment);", }, Assertions: []ScriptTestAssertion{ { - Query: "update mixedcase set j = 999 where i in (select 1)", + Query: "show create table int_tbl;", Expected: []sql.Row{ - {types.OkResult{ - RowsAffected: 1, - Info: plan.UpdateInfo{ - Matched: 1, - Updated: 1, - }, - 
}}, + {"int_tbl", "CREATE TABLE `int_tbl` (\n" + + " `i` int NOT NULL AUTO_INCREMENT,\n" + + " PRIMARY KEY (`i`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, { - Query: "select * from mixedcase;", + Query: "show create table tinyint_tbl;", Expected: []sql.Row{ - {1, 999}, - {2, 2}, + {"tinyint_tbl", "CREATE TABLE `tinyint_tbl` (\n" + + " `i` tinyint NOT NULL AUTO_INCREMENT,\n" + + " PRIMARY KEY (`i`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, { - Query: " with cte(x) as (select 2) update mixedcase set j = 999 where i in (select x from cte)", + Query: "show create table smallint_tbl;", Expected: []sql.Row{ - {types.OkResult{ - RowsAffected: 1, - Info: plan.UpdateInfo{ - Matched: 1, - Updated: 1, - }, - }}, + {"smallint_tbl", "CREATE TABLE `smallint_tbl` (\n" + + " `i` smallint NOT NULL AUTO_INCREMENT,\n" + + " PRIMARY KEY (`i`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, { - Query: "select * from mixedcase;", + Query: "show create table mediumint_tbl;", Expected: []sql.Row{ - {1, 999}, - {2, 999}, + {"mediumint_tbl", "CREATE TABLE `mediumint_tbl` (\n" + + " `i` mediumint NOT NULL AUTO_INCREMENT,\n" + + " PRIMARY KEY (`i`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "show create table bigint_tbl;", + Expected: []sql.Row{ + {"bigint_tbl", "CREATE TABLE `bigint_tbl` (\n" + + " `i` bigint NOT NULL AUTO_INCREMENT,\n" + + " PRIMARY KEY (`i`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, }, }, + + // Float Tests + { + Skip: true, + Name: "float with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (f float primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'f'", + }, + }, + }, + + // Double Tests + { + Skip: true, + Name: "double with auto_increment", + 
Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (d double primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'vc'", + }, + }, + }, + + // Decimal Tests + { + Skip: true, + Name: "decimal with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (d decimal primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'd'", + }, + { + Query: "create table bad (d decimal(65,30) primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'd'", + }, + }, + }, + + // Date Tests + { + Skip: true, + Name: "date with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (d date primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'd'", + }, + }, + }, + + // Datetime Tests + { + Skip: true, + Name: "datetime with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (dt datetime primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'dt'", + }, + { + Query: "create table bad (dt datetime(6) primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'dt'", + }, + }, + }, + + // Timestamp Tests + { + Skip: true, + Name: "timestamp with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (ts timestamp primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'ts'", + }, + { + Query: "create table bad (ts timestamp(6) primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'ts'", + }, + }, + }, + + // Time Tests + { + Skip: true, + Name: "time with 
auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (t time primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 't'", + }, + { + Query: "create table bad (t time(6) primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 't'", + }, + }, + }, + + // Year Tests + { + Skip: true, + Name: "year with auto_increment", + Dialect: "mysql", + SetUpScript: []string{}, + Assertions: []ScriptTestAssertion{ + { + Query: "create table bad (y year primary key auto_increment);", + ExpectedErrStr: "Incorrect column specifier for column 'y'", + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ @@ -9892,7 +11447,7 @@ var BrokenScriptTests = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: "SET SESSION time_zone = '-05:00';", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "SELECT DATE_FORMAT(ts, '%H:%i:%s'), DATE_FORMAT(dt, '%H:%i:%s') from timezone_test;", diff --git a/enginetest/queries/transaction_queries.go b/enginetest/queries/transaction_queries.go index bdc1fb753a..b06ae92bb2 100644 --- a/enginetest/queries/transaction_queries.go +++ b/enginetest/queries/transaction_queries.go @@ -40,11 +40,11 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ select @@autocommit;", @@ -120,11 +120,11 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: 
[]sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ select * from t order by x", @@ -191,11 +191,11 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ insert into t values (2,2)", @@ -208,7 +208,7 @@ var TransactionTests = []TransactionTest{ // should commit any pending transaction { Query: "/* client b */ set autocommit = on", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ select * from t order by x", @@ -217,7 +217,7 @@ var TransactionTests = []TransactionTest{ // client a sees the committed transaction from client b when it begins a new transaction { Query: "/* client a */ set autocommit = on", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ select * from t order by x", @@ -283,11 +283,11 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction", @@ -360,11 +360,11 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction", @@ -529,11 +529,11 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { 
Query: "/* client a */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction", @@ -666,15 +666,15 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client b */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client c */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, // Client a starts by insert into t { @@ -958,7 +958,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set autocommit = off", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ create temporary table tmp(pk int primary key)", @@ -1074,7 +1074,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", @@ -1131,7 +1131,7 @@ var TransactionTests = []TransactionTest{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", @@ -1243,7 +1243,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", @@ -1285,7 +1285,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a 
*/ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", @@ -1327,7 +1327,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", @@ -1365,7 +1365,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", @@ -1386,7 +1386,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", @@ -1408,7 +1408,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", @@ -1430,7 +1430,7 @@ var TransactionTests = []TransactionTest{ Assertions: []ScriptTestAssertion{ { Query: "/* client a */ set @@autocommit = 0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "/* client a */ start transaction;", diff --git a/enginetest/queries/trigger_queries.go b/enginetest/queries/trigger_queries.go index f817eac804..c18a78ed6e 100644 --- a/enginetest/queries/trigger_queries.go +++ b/enginetest/queries/trigger_queries.go @@ -3784,6 +3784,42 @@ end; }, }, }, + + // Invalid triggers + { + Name: "insert trigger with subquery projections", + SetUpScript: []string{ + "create table t (i int);", + "create trigger trig before insert on t for each row begin replace into t select 1; end;", + "alter table 
t add column j int;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "show create trigger trig", + Expected: []sql.Row{ + { + "trig", + "", + "create trigger trig before insert on t for each row begin replace into t select 1; end", + sql.Collation_Default.CharacterSet().String(), + sql.Collation_Default.String(), + sql.Collation_Default.String(), + time.Unix(0, 0).UTC(), + }, + }, + }, + { + Query: "insert into t values (1, 2);", + ExpectedErr: sql.ErrInsertIntoMismatchValueCount, + }, + { + Query: "drop trigger trig;", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + }, + }, } var TriggerCreateInSubroutineTests = []ScriptTest{ diff --git a/enginetest/queries/update_queries.go b/enginetest/queries/update_queries.go index b5c313bce4..28490233c7 100644 --- a/enginetest/queries/update_queries.go +++ b/enginetest/queries/update_queries.go @@ -24,7 +24,7 @@ import ( "github.com/dolthub/vitess/go/mysql" ) -var UpdateTests = []WriteQueryTest{ +var UpdateWriteQueryTests = []WriteQueryTest{ { WriteQuery: "UPDATE mytable SET s = 'updated';", ExpectedWriteResult: []sql.Row{{NewUpdateResult(3, 3)}}, @@ -470,6 +470,191 @@ var UpdateTests = []WriteQueryTest{ }, } +var UpdateScriptTests = []ScriptTest{ + { + Dialect: "mysql", + Name: "UPDATE join – single table, with FK constraint", + SetUpScript: []string{ + "CREATE TABLE customers (id INT PRIMARY KEY, name TEXT);", + "CREATE TABLE orders (id INT PRIMARY KEY, customer_id INT, amount INT, FOREIGN KEY (customer_id) REFERENCES customers(id));", + "INSERT INTO customers VALUES (1, 'Alice'), (2, 'Bob');", + "INSERT INTO orders VALUES (101, 1, 50), (102, 2, 75);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "UPDATE orders o JOIN customers c ON o.customer_id = c.id SET o.customer_id = 123 where o.customer_id != 1;", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "SELECT * FROM orders;", + Expected: []sql.Row{ + {101, 1, 50}, {102, 2, 75}, + }, + }, + }, + }, + { + Dialect: "mysql", + 
Name: "UPDATE join – multiple tables, with FK constraint", + SetUpScript: []string{ + "CREATE TABLE parent1 (id INT PRIMARY KEY);", + "CREATE TABLE parent2 (id INT PRIMARY KEY);", + "CREATE TABLE child1 (id INT PRIMARY KEY, p1_id INT, FOREIGN KEY (p1_id) REFERENCES parent1(id));", + "CREATE TABLE child2 (id INT PRIMARY KEY, p2_id INT, FOREIGN KEY (p2_id) REFERENCES parent2(id));", + "INSERT INTO parent1 VALUES (1), (3);", + "INSERT INTO parent2 VALUES (1), (3);", + "INSERT INTO child1 VALUES (10, 1);", + "INSERT INTO child2 VALUES (20, 1);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `UPDATE child1 c1 + JOIN child2 c2 ON c1.id = 10 AND c2.id = 20 + SET c1.p1_id = 999, c2.p2_id = 3;`, + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: `UPDATE child1 c1 + JOIN child2 c2 ON c1.id = 10 AND c2.id = 20 + SET c1.p1_id = 3, c2.p2_id = 999;`, + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "SELECT * FROM child1;", + Expected: []sql.Row{{10, 1}}, + }, + { + Query: "SELECT * FROM child2;", + Expected: []sql.Row{{20, 1}}, + }, + }, + }, + { + Dialect: "mysql", + Name: "UPDATE join – multiple tables, with trigger", + SetUpScript: []string{ + "CREATE TABLE a (id INT PRIMARY KEY, x INT);", + "CREATE TABLE b (pk INT PRIMARY KEY, y INT);", + "CREATE TABLE logbook (entry TEXT);", + `CREATE TRIGGER trig_a AFTER UPDATE ON a FOR EACH ROW + BEGIN + INSERT INTO logbook VALUES ('a updated'); + END;`, + `CREATE TRIGGER trig_b AFTER UPDATE ON b FOR EACH ROW + BEGIN + INSERT INTO logbook VALUES ('b updated'); + END;`, + "INSERT INTO a VALUES (5, 100);", + "INSERT INTO b VALUES (6, 200);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `UPDATE a + JOIN b ON a.id = 5 AND b.pk = 6 + SET a.x = 101, b.y = 201;`, + }, + { + Query: "SELECT * FROM logbook ORDER BY entry;", + Expected: []sql.Row{ + {"a updated"}, + {"b updated"}, + }, + }, + }, + }, + { + Dialect: "mysql", + Name: "UPDATE join – multiple tables with triggers that reference row 
values", + SetUpScript: []string{ + "create table customers (id int primary key, name text, tier text)", + "create table orders (order_id int primary key, customer_id int, status text)", + "create table trigger_log (msg text)", + `CREATE TRIGGER after_orders_update after update on orders for each row + begin + insert into trigger_log (msg) values( + concat('Order ', OLD.order_id, ' status changed from ', OLD.status, ' to ', NEW.status)); + end;`, + `Create trigger after_customers_update after update on customers for each row + begin + insert into trigger_log (msg) values( + concat('Customer ', OLD.id, ' tier changed from ', OLD.tier, ' to ', NEW.tier)); + end;`, + "insert into customers values(1, 'Alice', 'silver'), (2, 'Bob', 'gold');", + "insert into orders values (101, 1, 'pending'), (102, 2, 'pending');", + "update customers c join orders o on c.id = o.customer_id " + + "set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM trigger_log order by msg;", + Expected: []sql.Row{ + {"Customer 1 tier changed from silver to platinum"}, + {"Customer 2 tier changed from gold to platinum"}, + {"Order 101 status changed from pending to shipped"}, + {"Order 102 status changed from pending to shipped"}, + }, + }, + }, + }, + { + Dialect: "mysql", + Name: "UPDATE join – multiple tables with same column names with triggers", + SetUpScript: []string{ + "create table customers (id int primary key, name text, tier text)", + "create table orders (id int primary key, customer_id int, status text)", + "create table trigger_log (msg text)", + `CREATE TRIGGER after_orders_update after update on orders for each row + begin + insert into trigger_log (msg) values( + concat('Order ', OLD.id, ' status changed from ', OLD.status, ' to ', NEW.status)); + end;`, + `Create trigger after_customers_update after update on customers for each row + begin + insert into trigger_log (msg) values( + 
concat('Customer ', OLD.id, ' tier changed from ', OLD.tier, ' to ', NEW.tier)); + end;`, + "insert into customers values(1, 'Alice', 'silver'), (2, 'Bob', 'gold');", + "insert into orders values (101, 1, 'pending'), (102, 2, 'pending');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "update customers c join orders o on c.id = o.customer_id " + + "set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'", + // TODO: we shouldn't expect an error once we're able to handle conflicting column names + // https://github.com/dolthub/dolt/issues/9403 + ExpectedErrStr: "Unable to apply triggers when joined tables have columns with the same name", + }, + { + // TODO: unskip once we're able to handle conflicting column names + // https://github.com/dolthub/dolt/issues/9403 + Skip: true, + Query: "SELECT * FROM trigger_log order by msg;", + Expected: []sql.Row{ + {"Customer 1 tier changed from silver to platinum"}, + {"Customer 2 tier changed from gold to platinum"}, + {"Order 101 status changed from pending to shipped"}, + {"Order 102 status changed from pending to shipped"}, + }, + }, + }, + }, + { + Name: "UPDATE with subquery in keyless tables", + // https://github.com/dolthub/dolt/issues/9334 + SetUpScript: []string{ + "create table t (i int)", + "insert into t values (1)", + "update t set i = 10 where i in (select 1)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from t", + Expected: []sql.Row{{10}}, + }, + }, + }, +} + var SpatialUpdateTests = []WriteQueryTest{ { WriteQuery: "UPDATE point_table SET p = point(123.456,789);", diff --git a/enginetest/queries/variable_queries.go b/enginetest/queries/variable_queries.go index 173be4222a..f530e216e3 100644 --- a/enginetest/queries/variable_queries.go +++ b/enginetest/queries/variable_queries.go @@ -32,7 +32,7 @@ var VariableQueries = []ScriptTest{ Name: "use string name for foreign_key checks", SetUpScript: []string{}, Query: "set @@foreign_key_checks = off;", - Expected: 
[]sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Name: "set system variables", @@ -115,15 +115,15 @@ var VariableQueries = []ScriptTest{ }, { Query: "set @@server_id=123;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "set @@GLOBAL.server_id=123;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "set @@GLOBAL.server_id=0;", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, }, }, @@ -523,7 +523,7 @@ var VariableQueries = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: "set transaction isolation level serializable, read only", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@transaction_isolation, @@transaction_read_only", @@ -531,7 +531,7 @@ var VariableQueries = []ScriptTest{ }, { Query: "set transaction read write, isolation level read uncommitted", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@transaction_isolation, @@transaction_read_only", @@ -539,7 +539,7 @@ var VariableQueries = []ScriptTest{ }, { Query: "set transaction isolation level read committed", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@transaction_isolation", @@ -547,7 +547,7 @@ var VariableQueries = []ScriptTest{ }, { Query: "set transaction isolation level repeatable read", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@transaction_isolation", @@ -555,7 +555,7 @@ var VariableQueries = []ScriptTest{ }, { Query: "set session transaction isolation level serializable, read only", - Expected: []sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@transaction_isolation, @@transaction_read_only", @@ -563,7 +563,7 @@ var VariableQueries = []ScriptTest{ }, { Query: "set global transaction read write, isolation level read uncommitted", - Expected: 
[]sql.Row{{}}, + Expected: []sql.Row{{types.NewOkResult(0)}}, }, { Query: "select @@transaction_isolation, @@transaction_read_only", diff --git a/enginetest/server_engine.go b/enginetest/server_engine.go index 0a222ef534..9d56dbfb2b 100644 --- a/enginetest/server_engine.go +++ b/enginetest/server_engine.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "strconv" "strings" @@ -206,39 +207,61 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa return s.queryOrExec(ctx, stmt, parsed, query, args) } +func (s *ServerQueryEngine) query(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + var rows *gosql.Rows + var err error + if stmt != nil { + rows, err = stmt.Query(args...) + } else { + rows, err = s.conn.Query(query, args...) + } + if err != nil { + return nil, nil, nil, trimMySQLErrCodePrefix(err) + } + return convertRowsResult(ctx, rows, query) +} + +func (s *ServerQueryEngine) exec(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + var res gosql.Result + var err error + if stmt != nil { + res, err = stmt.Exec(args...) + } else { + res, err = s.conn.Exec(query, args...) + } + if err != nil { + return nil, nil, nil, trimMySQLErrCodePrefix(err) + } + return convertExecResult(res) +} + // queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan. // If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot // be run as prepared. // TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds. // -// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result. +// for now, we use `query` to get the row results for these statements. 
For statements that needs `exec`, the result is OkResult. func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { - var err error - switch parsed.(type) { // TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned. - case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain: - var rows *gosql.Rows - if stmt != nil { - rows, err = stmt.Query(args...) - } else { - rows, err = s.conn.Query(query, args...) - } - if err != nil { - return nil, nil, nil, trimMySQLErrCodePrefix(err) - } - return convertRowsResult(ctx, rows) + var shouldQuery bool + switch p := parsed.(type) { + // Insert statements with a returning clause return rows, not OkResult, so we need to call stmt.Query instead of stmt.Exec + case *sqlparser.Insert: + if p.Returning != nil { + shouldQuery = true + } + case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, + *sqlparser.Call, *sqlparser.Begin, + *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, + *sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain: + shouldQuery = true default: - var res gosql.Result - if stmt != nil { - res, err = stmt.Exec(args...) - } else { - res, err = s.conn.Exec(query, args...) - } - if err != nil { - return nil, nil, nil, trimMySQLErrCodePrefix(err) - } - return convertExecResult(res) } + + if shouldQuery { + return s.query(ctx, stmt, query, args) + } + return s.exec(ctx, stmt, query, args) } // trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server. 
@@ -280,7 +303,7 @@ func convertExecResult(exec gosql.Result) (sql.Schema, sql.RowIter, *sql.QueryFl return types.OkResultSchema, sql.RowsToRowIter(sql.NewRow(okResult)), nil, nil } -func convertRowsResult(ctx *sql.Context, rows *gosql.Rows) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { +func convertRowsResult(ctx *sql.Context, rows *gosql.Rows, query string) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { sch, err := schemaForRows(rows) if err != nil { return nil, nil, nil, err @@ -291,6 +314,36 @@ func convertRowsResult(ctx *sql.Context, rows *gosql.Rows) (sql.Schema, sql.RowI return nil, nil, nil, err } + // If we have no columns and no rows, this might mean a CALL statement that should return OkResult + // (like a CALL to a stored procedure that only does SET operations) + // But we should NOT convert USE, SHOW, etc. statements to OkResult + // Also, external procedures (starting with "memory_") should return empty results, not OkResult + if len(sch) == 0 && strings.HasPrefix(strings.ToUpper(strings.TrimSpace(query)), "CALL") && + !strings.Contains(strings.ToLower(query), "memory_") { + // Check if we actually have any rows by trying to get the first row + firstRow, err := rowIter.Next(ctx) + if err == io.EOF { + // No rows available for a CALL statement, this should be OkResult + okResult := types.NewOkResult(0) + return types.OkResultSchema, sql.RowsToRowIter(sql.NewRow(okResult)), nil, nil + } else if err == nil { + // We do have a row, so create a new iterator that includes this row plus the rest + restRows := []sql.Row{firstRow} + for { + row, err := rowIter.Next(ctx) + if err != nil { + break + } + restRows = append(restRows, row) + } + rowIter.Close(ctx) + return sch, sql.RowsToRowIter(restRows...), nil, nil + } + // Some other error occurred, close the iterator and return the error + rowIter.Close(ctx) + return nil, nil, nil, err + } + return sch, rowIter, nil, nil } diff --git a/go.mod b/go.mod index 81211175d8..c0e7e7befc 100644 --- a/go.mod 
+++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250605180032-fa2a634c215b + github.com/dolthub/vitess v0.0.0-20250611225316-90a5898bfe26 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 0879605e74..d1ac8982bc 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250605180032-fa2a634c215b h1:rgZXgRYZ3SZbb4Tz5Y6vnzvB7P9pFvEP+Q7UGfRC9uY= -github.com/dolthub/vitess v0.0.0-20250605180032-fa2a634c215b/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250611225316-90a5898bfe26 h1:9Npf0JYVCrwe9edTfYD/pjIncCePNDiu4j50xLcV334= +github.com/dolthub/vitess v0.0.0-20250611225316-90a5898bfe26/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/memory/table_data.go b/memory/table_data.go index 79c0e79dba..87f74ccd98 100644 --- a/memory/table_data.go +++ b/memory/table_data.go @@ 
-15,7 +15,6 @@ package memory import ( - "context" "fmt" "sort" "strconv" @@ -25,6 +24,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/hash" "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -275,7 +275,7 @@ func (td *TableData) numRows(ctx *sql.Context) (uint64, error) { } // throws an error if any two or more rows share the same |cols| values. -func (td *TableData) errIfDuplicateEntryExist(ctx context.Context, cols []string, idxName string) error { +func (td *TableData) errIfDuplicateEntryExist(ctx *sql.Context, cols []string, idxName string) error { columnMapping, err := td.columnIndexes(cols) // We currently skip validating duplicates on unique virtual columns. @@ -297,7 +297,7 @@ func (td *TableData) errIfDuplicateEntryExist(ctx context.Context, cols []string if hasNulls(idxPrefixKey) { continue } - h, err := sql.HashOf(ctx, idxPrefixKey) + h, err := hash.HashOf(ctx, td.schema.Schema, idxPrefixKey) if err != nil { return err } diff --git a/server/handler.go b/server/handler.go index e3c7d57a50..113e9cc978 100644 --- a/server/handler.go +++ b/server/handler.go @@ -157,7 +157,10 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p // than they will at execution time. 
func nodeReturnsOkResultSchema(node sql.Node) bool { switch node.(type) { - case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom: + case *plan.InsertInto: + insertNode, _ := node.(*plan.InsertInto) + return insertNode.Returning == nil + case *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom: return true } return false diff --git a/server/handler_test.go b/server/handler_test.go index 969c1b408c..03bf918754 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -212,7 +212,7 @@ func TestHandlerOutput(t *testing.T) { }) require.NoError(t, err) require.Equal(t, 1, len(result.Rows)) - require.Equal(t, sqltypes.Int64, result.Rows[0][0].Type()) + require.Equal(t, sqltypes.Int16, result.Rows[0][0].Type()) require.Equal(t, []byte("456"), result.Rows[0][0].ToBytes()) }) } @@ -471,7 +471,8 @@ func TestHandlerComPrepareExecute(t *testing.T) { }, }, schema: []*query.Field{ - {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, + Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {0}, {1}, {2}, {3}, {4}, @@ -550,7 +551,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) { }, }, schema: []*query.Field{ - {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, + Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {0}, {1}, {2}, {3}, {4}, @@ -567,7 +569,8 @@ 
func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) { BindVars: nil, }, schema: []*query.Field{ - {Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, + Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {1000}, @@ -584,7 +587,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) { BindVars: nil, }, schema: []*query.Field{ - {Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, + Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {-129}, diff --git a/sql/aggregates.go b/sql/aggregates.go index e414316c35..09f28314ad 100644 --- a/sql/aggregates.go +++ b/sql/aggregates.go @@ -115,7 +115,7 @@ type WindowFrame interface { StartNFollowing() Expression // EndNPreceding returns whether a frame end preceding Expression or nil EndNPreceding() Expression - // EndNPreceding returns whether a frame end following Expression or nil + // EndNFollowing returns whether a frame end following Expression or nil EndNFollowing() Expression } @@ -135,3 +135,9 @@ type AggregationBuffer interface { type WindowAggregation interface { WindowAdaptableExpression } + +// OrderedAggregation are aggregate functions that modify the current working row with additional result columns. +type OrderedAggregation interface { + // OutputExpressions gets a list of return expressions. 
+ OutputExpressions() []Expression +} diff --git a/sql/analyzer/apply_foreign_keys.go b/sql/analyzer/apply_foreign_keys.go index e958799bcc..4deb0897c8 100644 --- a/sql/analyzer/apply_foreign_keys.go +++ b/sql/analyzer/apply_foreign_keys.go @@ -122,30 +122,34 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f if plan.IsEmptyTable(n.Child) { return n, transform.SameTree, nil } - updateDest, err := plan.GetUpdatable(n.Child) - if err != nil { - return nil, transform.SameTree, err - } - fkTbl, ok := updateDest.(sql.ForeignKeyTable) - // If foreign keys aren't supported then we return - if !ok { - return n, transform.SameTree, nil + if uj, ok := n.Child.(*plan.UpdateJoin); ok { + updateTargets := uj.UpdateTargets + fkHandlerMap := make(map[string]sql.Node, len(updateTargets)) + for tableName, updateTarget := range updateTargets { + fkHandlerMap[tableName] = updateTarget + fkHandler, err := + getForeignKeyHandlerFromUpdateTarget(ctx, a, updateTarget, cache, fkChain) + if err != nil { + return nil, transform.SameTree, err + } + if fkHandler == nil { + fkHandlerMap[tableName] = updateTarget + } else { + fkHandlerMap[tableName] = fkHandler + } + } + uj = plan.NewUpdateJoin(fkHandlerMap, uj.Child) + nn, err := n.WithChildren(uj) + return nn, transform.NewTree, err } - - fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false) + fkHandler, err := getForeignKeyHandlerFromUpdateTarget(ctx, a, n.Child, cache, fkChain) if err != nil { return nil, transform.SameTree, err } - if fkEditor == nil { + if fkHandler == nil { return n, transform.SameTree, nil } - nn, err := n.WithChildren(&plan.ForeignKeyHandler{ - Table: fkTbl, - Sch: updateDest.Schema(), - OriginalNode: n.Child, - Editor: fkEditor, - AllUpdaters: fkChain.GetUpdaters(), - }) + nn, err := n.WithChildren(fkHandler) return nn, transform.NewTree, err case *plan.DeleteFrom: if plan.IsEmptyTable(n.Child) { @@ -443,6 +447,36 @@ func getForeignKeyRefActions(ctx *sql.Context, a 
*Analyzer, tbl sql.ForeignKeyTa return fkEditor, nil } +// getForeignKeyHandlerFromUpdateTarget creates a ForeignKeyHandler from a given update target Node. It is used for +// applying foreign key constraints to Update nodes +func getForeignKeyHandlerFromUpdateTarget(ctx *sql.Context, a *Analyzer, updateTarget sql.Node, + cache *foreignKeyCache, fkChain foreignKeyChain) (*plan.ForeignKeyHandler, error) { + updateDest, err := plan.GetUpdatable(updateTarget) + if err != nil { + return nil, err + } + fkTbl, ok := updateDest.(sql.ForeignKeyTable) + if !ok { + return nil, nil + } + + fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false) + if err != nil { + return nil, err + } + if fkEditor == nil { + return nil, nil + } + + return &plan.ForeignKeyHandler{ + Table: fkTbl, + Sch: updateDest.Schema(), + OriginalNode: updateTarget, + Editor: fkEditor, + AllUpdaters: fkChain.GetUpdaters(), + }, nil +} + // resolveSchemaDefaults resolves the default values for the schema of |table|. This is primarily needed for column // default value expressions, since those don't get resolved during the planbuilder phase and assignExecIndexes // doesn't traverse through the ForeignKeyEditors and referential actions to find all of them. 
In addition to resolving diff --git a/sql/analyzer/apply_hash_in.go b/sql/analyzer/apply_hash_in.go index f89334ae0e..b51e2378e3 100644 --- a/sql/analyzer/apply_hash_in.go +++ b/sql/analyzer/apply_hash_in.go @@ -56,16 +56,13 @@ func applyHashIn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, s // hasSingleOutput checks if an expression evaluates to a single output func hasSingleOutput(e sql.Expression) bool { - return !transform.InspectExpr(e, func(expr sql.Expression) bool { + return transform.InspectExpr(e, func(expr sql.Expression) bool { switch expr.(type) { - case expression.Tuple, *expression.Literal, *expression.GetField, - expression.Comparer, *expression.Convert, sql.FunctionExpression, - *expression.IsTrue, *expression.IsNull, expression.ArithmeticOp: + case *plan.Subquery: return false default: return true } - return false }) } diff --git a/sql/analyzer/assign_update_join.go b/sql/analyzer/assign_update_join.go index 982fee825e..9c7d088560 100644 --- a/sql/analyzer/assign_update_join.go +++ b/sql/analyzer/assign_update_join.go @@ -34,12 +34,12 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope * return n, transform.SameTree, nil } - updaters, err := rowUpdatersByTable(ctx, us, jn) + updateTargets, err := getUpdateTargetsByTable(us, jn, n.IsJoin) if err != nil { return nil, transform.SameTree, err } - uj := plan.NewUpdateJoin(updaters, us) + uj := plan.NewUpdateJoin(updateTargets, us) ret, err := n.WithChildren(uj) if err != nil { return nil, transform.SameTree, err @@ -51,12 +51,12 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope * return n, transform.SameTree, nil } -// rowUpdatersByTable maps a set of tables to their RowUpdater objects. 
-func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) { +// getUpdateTargetsByTable maps a set of table names and aliases to their corresponding update target Node +func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[string]sql.Node, error) { namesOfTableToBeUpdated := getTablesToBeUpdated(node) resolvedTables := getTablesByName(ij) - rowUpdatersByTable := make(map[string]sql.RowUpdater) + updateTargets := make(map[string]sql.Node) for tableToBeUpdated, _ := range namesOfTableToBeUpdated { resolvedTable, ok := resolvedTables[tableToBeUpdated] if !ok { @@ -72,14 +72,14 @@ func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[strin } keyless := sql.IsKeyless(updatable.Schema()) - if keyless { + if keyless && isJoin { return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN") } - rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx) + updateTargets[tableToBeUpdated] = resolvedTable } - return rowUpdatersByTable, nil + return updateTargets, nil } // getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField. 
diff --git a/sql/analyzer/catalog.go b/sql/analyzer/catalog.go index b6c5b9a0c4..c07d1199c0 100644 --- a/sql/analyzer/catalog.go +++ b/sql/analyzer/catalog.go @@ -467,6 +467,8 @@ func getStatisticsTable(table sql.Table, prevTable sql.Table) (sql.StatisticsTab return t, true case sql.TableNode: return getStatisticsTable(t.UnderlyingTable(), table) + case sql.TableWrapper: + return getStatisticsTable(t.Underlying(), table) default: return nil, false } diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go index 10b4418c92..f8589bbce5 100644 --- a/sql/analyzer/costed_index_scan.go +++ b/sql/analyzer/costed_index_scan.go @@ -16,6 +16,7 @@ package analyzer import ( "fmt" + "slices" "sort" "strings" "time" @@ -203,6 +204,9 @@ func getCostedIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, rt sql.Ta if !ok { stat, err = uniformDistStatisticsForIndex(ctx, statsProv, iat, idx) } + if err != nil { + return nil, nil, nil, err + } err := c.cost(root, stat, idx) if err != nil { return nil, nil, nil, err @@ -446,6 +450,8 @@ type indexCoster struct { // prefix key of the best indexScan bestPrefix int underlyingName string + // whether the column following the prefix key is limited to a subrange + hasRange bool } // cost tries to build the lowest cardinality index scan for an expression @@ -459,10 +465,11 @@ func (c *indexCoster) cost(f indexFilter, stat sql.Statistic, idx sql.Index) err var prefix int var err error var ok bool + hasRange := false switch f := f.(type) { case *iScanAnd: - newHist, newFds, filters, prefix, err = c.costIndexScanAnd(c.ctx, f, stat, stat.Histogram(), ordinals, idx) + newHist, newFds, filters, prefix, hasRange, err = c.costIndexScanAnd(c.ctx, f, stat, stat.Histogram(), ordinals, idx) if err != nil { return err } @@ -491,12 +498,12 @@ func (c *indexCoster) cost(f indexFilter, stat sql.Statistic, idx sql.Index) err newFds = &sql.FuncDepSet{} } - c.updateBest(stat, newHist, newFds, filters, prefix) + c.updateBest(stat, 
newHist, newFds, filters, prefix, hasRange) return nil } -func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fds *sql.FuncDepSet, filters sql.FastIntSet, prefix int) { +func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fds *sql.FuncDepSet, filters sql.FastIntSet, prefix int, hasRange bool) { if s == nil || filters.Len() == 0 { return } @@ -510,6 +517,7 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd c.bestCnt = rowCnt c.bestFilters = filters c.bestPrefix = prefix + c.hasRange = hasRange } }() @@ -534,6 +542,26 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd return } + // If one index uses a strict superset of the filters of the other, we should always pick the superset. + // This is true even if the index with more filters isn't unique. + if prefix > c.bestPrefix && slices.Equal(c.bestStat.Columns()[:c.bestPrefix], s.Columns()[:c.bestPrefix]) { + update = true + return + } + + if prefix == c.bestPrefix && slices.Equal(c.bestStat.Columns()[:c.bestPrefix], s.Columns()[:c.bestPrefix]) && hasRange && !c.hasRange { + update = true + return + } + + if c.bestPrefix > prefix && slices.Equal(c.bestStat.Columns()[:prefix], s.Columns()[:prefix]) { + return + } + + if c.bestPrefix == prefix && slices.Equal(c.bestStat.Columns()[:prefix], s.Columns()[:prefix]) && !hasRange && c.hasRange { + return + } + bestKey, bok := best.StrictKey() cmpKey, cok := cmp.StrictKey() if cok && !bok { @@ -575,6 +603,10 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd return } + if filters.Len() < c.bestFilters.Len() { + return + } + if s.ColSet().Len()-filters.Len() < c.bestStat.ColSet().Len()-c.bestFilters.Len() { // prefer 1 range filter over 1 column index (1 - 1 = 0) // vs. 
1 range filter over 2 column index (2 - 1 = 1) @@ -620,7 +652,7 @@ func (c *indexCoster) getConstAndNullFilters(filters sql.FastIntSet) (sql.FastIn switch e.(type) { case *expression.Equals: isConst.Add(i) - case *expression.IsNull: + case sql.IsNullExpression: isNull.Add(i) case *expression.NullSafeEquals: isConst.Add(i) @@ -1199,7 +1231,7 @@ func ordinalsForStat(stat sql.Statistic) map[string]int { // updated statistic, the subset of applicable filters, the maximum prefix // key created by a subset of equality filters (from conjunction only), // or an error if applicable. -func (c *indexCoster) costIndexScanAnd(ctx *sql.Context, filter *iScanAnd, s sql.Statistic, buckets []sql.HistogramBucket, ordinals map[string]int, idx sql.Index) ([]sql.HistogramBucket, *sql.FuncDepSet, sql.FastIntSet, int, error) { +func (c *indexCoster) costIndexScanAnd(ctx *sql.Context, filter *iScanAnd, s sql.Statistic, buckets []sql.HistogramBucket, ordinals map[string]int, idx sql.Index) ([]sql.HistogramBucket, *sql.FuncDepSet, sql.FastIntSet, int, bool, error) { // first step finds the conjunctions that match index prefix columns. 
// we divide into eqFilters and rangeFilters @@ -1210,13 +1242,13 @@ func (c *indexCoster) costIndexScanAnd(ctx *sql.Context, filter *iScanAnd, s sql for _, or := range filter.orChildren { childStat, _, ok, err := c.costIndexScanOr(or.(*iScanOr), s, buckets, ordinals, idx) if err != nil { - return nil, nil, sql.FastIntSet{}, 0, err + return nil, nil, sql.FastIntSet{}, 0, false, err } // if valid, INTERSECT if ok { ret, err = stats.Intersect(c.ctx, ret, childStat, s.Types()) if err != nil { - return nil, nil, sql.FastIntSet{}, 0, err + return nil, nil, sql.FastIntSet{}, 0, false, err } exact.Add(int(or.Id())) } @@ -1237,12 +1269,8 @@ func (c *indexCoster) costIndexScanAnd(ctx *sql.Context, filter *iScanAnd, s sql conjFDs = conj.getFds() } - if exact.Len()+conj.applied.Len() == filter.childCnt() { - // matched all filters - return conj.hist, conjFDs, sql.NewFastIntSet(int(filter.id)), conj.missingPrefix, nil - } - - return conj.hist, conjFDs, exact.Union(conj.applied), conj.missingPrefix, nil + hasRange := conj.ineqCols.Contains(conj.missingPrefix) + return conj.hist, conjFDs, exact.Union(conj.applied), conj.missingPrefix, hasRange, nil } func (c *indexCoster) costIndexScanOr(filter *iScanOr, s sql.Statistic, buckets []sql.HistogramBucket, ordinals map[string]int, idx sql.Index) ([]sql.HistogramBucket, *sql.FuncDepSet, bool, error) { @@ -1253,11 +1281,11 @@ func (c *indexCoster) costIndexScanOr(filter *iScanOr, s sql.Statistic, buckets for _, child := range filter.children { switch child := child.(type) { case *iScanAnd: - childBuckets, _, ids, _, err := c.costIndexScanAnd(c.ctx, child, s, buckets, ordinals, idx) + childBuckets, _, ids, _, _, err := c.costIndexScanAnd(c.ctx, child, s, buckets, ordinals, idx) if err != nil { return nil, nil, false, err } - if ids.Len() != 1 || !ids.Contains(int(child.Id())) { + if ids.Len() != child.childCnt() { // scan option missed some filters return nil, nil, false, nil } @@ -1485,14 +1513,20 @@ func IndexLeafChildren(e 
sql.Expression) (IndexScanOp, sql.Expression, sql.Expre left = e.Left() right = e.Right() op = IndexScanOpLte - case *expression.IsNull: - left = e.Child + case sql.IsNullExpression: + left = e.Children()[0] op = IndexScanOpIsNull + case sql.IsNotNullExpression: + left = e.Children()[0] + op = IndexScanOpIsNotNull case *expression.Not: switch e := e.Child.(type) { - case *expression.IsNull: - left = e.Child + case sql.IsNullExpression: + left = e.Children()[0] op = IndexScanOpIsNotNull + // TODO: In Postgres, Not(IS NULL) is valid, but doesn't necessarily always mean the + // same thing as IS NOT NULL, particularly for the case of records or composite + // values. case *expression.Equals: left = e.Left() right = e.Right() @@ -1664,6 +1698,7 @@ type conjCollector struct { ordinals map[string]int missingPrefix int constant sql.FastIntSet + ineqCols sql.FastIntSet eqVals []interface{} nullable []bool applied sql.FastIntSet @@ -1732,6 +1767,7 @@ func (c *conjCollector) addEq(ctx *sql.Context, col string, val interface{}, nul func (c *conjCollector) addIneq(ctx *sql.Context, op IndexScanOp, col string, val interface{}) error { ord := c.ordinals[col] + c.ineqCols.Add(ord) if ord > 0 { return nil } diff --git a/sql/analyzer/costed_index_scan_test.go b/sql/analyzer/costed_index_scan_test.go index 8f4d4f106f..3464e7d023 100644 --- a/sql/analyzer/costed_index_scan_test.go +++ b/sql/analyzer/costed_index_scan_test.go @@ -477,67 +477,53 @@ func TestRangeBuilder(t *testing.T) { }, // nulls { - or2( - and2(isNull(x), gt2(y, 5)), - ), + and2(isNull(x), gt2(y, 5)), sql.MySQLRangeCollection{ r(null2(), rgt(5)), }, - 1, + 2, }, { - or2( - and2(isNull(x), isNotNull(y)), - ), + and2(isNull(x), isNotNull(y)), sql.MySQLRangeCollection{ r(null2(), notNull()), }, - 1, + 2, }, { - or2( - and2(isNull(x), lt2(y, 5)), - ), + and2(isNull(x), lt2(y, 5)), sql.MySQLRangeCollection{ r(null2(), rlt(5)), }, - 1, + 2, }, { - or2( - and(isNull(x), gte2(y, 5)), - ), + and(isNull(x), gte2(y, 5)), 
sql.MySQLRangeCollection{ r(null2(), rgte(5)), }, - 1, + 2, }, { - or2( - and(isNull(x), lte2(y, 5)), - ), + and(isNull(x), lte2(y, 5)), sql.MySQLRangeCollection{ r(null2(), rlte(5)), }, - 1, + 2, }, { - or2( - and(isNull(x), lte2(y, 5)), - ), + and(isNull(x), lte2(y, 5)), sql.MySQLRangeCollection{ r(null2(), rlte(5)), }, - 1, + 2, }, { - or2( - and2(isNull(x), eq2(y, 1)), - ), + and2(isNull(x), eq2(y, 1)), sql.MySQLRangeCollection{ r(null2(), req(1)), }, - 1, + 2, }, } @@ -590,8 +576,6 @@ func TestRangeBuilder(t *testing.T) { require.NoError(t, err) include := c.bestFilters - // most tests are designed so that all filters are supported - // |included| = |root.id| require.Equal(t, tt.cnt, include.Len()) if tt.cnt == 1 { require.True(t, include.Contains(1)) diff --git a/sql/analyzer/fix_exec_indexes.go b/sql/analyzer/fix_exec_indexes.go index e73503c12b..90c94e848d 100644 --- a/sql/analyzer/fix_exec_indexes.go +++ b/sql/analyzer/fix_exec_indexes.go @@ -578,9 +578,23 @@ func (s *idxScope) visitSelf(n sql.Node) error { } if ne, ok := n.(sql.Expressioner); ok { scope := append(s.parentScopes, s.childScopes...) 
+ // default nodes can't see lateral join nodes, unless we're in lateral + // join and lateral scopes are promoted to parent status for _, e := range ne.Expressions() { - // default nodes can't see lateral join nodes, unless we're in lateral - // join and lateral scopes are promoted to parent status + // OrderedAggregations are special as they append results to the outer scope row + // We need to account for this extra column in the rows when assigning indexes + // Example: gms/expression/function/aggregation/group_concat.go:groupConcatBuffer.Update() + if ordAgg, isOrdAgg := e.(sql.OrderedAggregation); isOrdAgg { + selExprs := ordAgg.OutputExpressions() + selScope := &idxScope{} + for _, expr := range selExprs { + selScope.columns = append(selScope.columns, expr.String()) + if gf, isGf := expr.(*expression.GetField); isGf { + selScope.ids = append(selScope.ids, gf.Id()) + } + } + scope = append(scope, selScope) + } s.expressions = append(s.expressions, fixExprToScope(e, scope...)) } } diff --git a/sql/analyzer/indexed_joins.go b/sql/analyzer/indexed_joins.go index b12735916f..f9dd2e69aa 100644 --- a/sql/analyzer/indexed_joins.go +++ b/sql/analyzer/indexed_joins.go @@ -158,6 +158,9 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco qFlags.Set(sql.QFlagInnerJoin) + hints := m.SessionHints() + hints = append(hints, memo.ExtractJoinHint(n)...) 
+ err = addIndexScans(ctx, m) if err != nil { return nil, err @@ -180,9 +183,11 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco return nil, err } - err = addMergeJoins(ctx, m) - if err != nil { - return nil, err + if !mergeJoinsDisabled(hints) { + err = addMergeJoins(ctx, m) + if err != nil { + return nil, err + } } memo.CardMemoGroups(ctx, m.Root()) @@ -200,11 +205,9 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco return nil, err } - m.SetDefaultHints() - hints := memo.ExtractJoinHint(n) + // Once we've enumerated all expression groups, we can apply hints. This must be done after expression + // groups have been identified, so that the applied hints use the correct metadata. for _, h := range hints { - // this should probably happen earlier, but the root is not - // populated before reordering m.ApplyHint(h) } @@ -223,6 +226,16 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco return m.BestRootPlan(ctx) } +// mergeJoinsDisabled returns true if merge joins have been disabled in the specified |hints|. +func mergeJoinsDisabled(hints []memo.Hint) bool { + for _, hint := range hints { + if hint.Typ == memo.HintTypeNoMergeJoin { + return true + } + } + return false +} + // addLookupJoins prefixes memo join group expressions with indexed join // alternatives to join plans added by joinOrderBuilder. 
We can assume that a // join with a non-nil join filter is not degenerate, and we can apply indexed @@ -554,7 +567,7 @@ func convertAntiToLeftJoin(m *memo.Memo) error { // drop null projected columns on right table nullFilters := make([]sql.Expression, len(nullify)) for i, e := range nullify { - nullFilters[i] = expression.NewIsNull(e) + nullFilters[i] = expression.DefaultExpressionFactory.NewIsNull(e) } filterGrp := m.MemoizeFilter(nil, joinGrp, nullFilters) @@ -1399,7 +1412,7 @@ func isWeaklyMonotonic(e sql.Expression) bool { } return false case *expression.Equals, *expression.NullSafeEquals, *expression.Literal, *expression.GetField, - *expression.Tuple, *expression.IsNull, *expression.BindVar: + *expression.Tuple, *expression.BindVar, sql.IsNullExpression, sql.IsNotNullExpression: return false default: if e, ok := e.(expression.Equality); ok && e.RepresentsEquality() { diff --git a/sql/analyzer/load_triggers.go b/sql/analyzer/load_triggers.go index bcbc652444..32cef54438 100644 --- a/sql/analyzer/load_triggers.go +++ b/sql/analyzer/load_triggers.go @@ -33,7 +33,7 @@ func loadTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, switch node := n.(type) { case *plan.ShowTriggers: newShowTriggers := *node - loadedTriggers, err := loadTriggersFromDb(ctx, a, newShowTriggers.Database()) + loadedTriggers, err := loadTriggersFromDb(ctx, a, newShowTriggers.Database(), false) if err != nil { return nil, transform.SameTree, err } @@ -44,16 +44,16 @@ func loadTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, } return &newShowTriggers, transform.NewTree, nil case *plan.DropTrigger: - loadedTriggers, err := loadTriggersFromDb(ctx, a, node.Database()) + loadedTriggers, err := loadTriggersFromDb(ctx, a, node.Database(), true) if err != nil { return nil, transform.SameTree, err } - lowercasedTriggerName := strings.ToLower(node.TriggerName) for _, trigger := range loadedTriggers { - if strings.ToLower(trigger.TriggerName) == 
lowercasedTriggerName { + if strings.EqualFold(node.TriggerName, trigger.TriggerName) { node.TriggerName = trigger.TriggerName - } else if trigger.TriggerOrder != nil && - strings.ToLower(trigger.TriggerOrder.OtherTriggerName) == lowercasedTriggerName { + continue + } + if trigger.TriggerOrder != nil && strings.EqualFold(node.TriggerName, trigger.TriggerOrder.OtherTriggerName) { return nil, transform.SameTree, sql.ErrTriggerCannotBeDropped.New(node.TriggerName, trigger.TriggerName) } } @@ -70,7 +70,7 @@ func loadTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, dropTableDb = t.SqlDatabase } - loadedTriggers, err := loadTriggersFromDb(ctx, a, dropTableDb) + loadedTriggers, err := loadTriggersFromDb(ctx, a, dropTableDb, false) if err != nil { return nil, transform.SameTree, err } @@ -95,7 +95,7 @@ func loadTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, }) } -func loadTriggersFromDb(ctx *sql.Context, a *Analyzer, db sql.Database) ([]*plan.CreateTrigger, error) { +func loadTriggersFromDb(ctx *sql.Context, a *Analyzer, db sql.Database, ignoreParseErrors bool) ([]*plan.CreateTrigger, error) { var loadedTriggers []*plan.CreateTrigger if triggerDb, ok := db.(sql.TriggerDatabase); ok { triggers, err := triggerDb.GetTriggers(ctx) @@ -108,7 +108,17 @@ func loadTriggersFromDb(ctx *sql.Context, a *Analyzer, db sql.Database) ([]*plan // TODO: should perhaps add the auth query handler to the analyzer? does this even use auth? parsedTrigger, _, err = planbuilder.ParseWithOptions(ctx, a.Catalog, trigger.CreateStatement, sqlMode.ParserOptions()) if err != nil { - return nil, err + // We want to be able to drop invalid triggers, so ignore any parser errors and return the name of the trigger + if !ignoreParseErrors { + return nil, err + } + // TODO: we won't have TriggerOrder information for this unparseable trigger, + // but it will still be referenced by any valid triggers. 
+ fakeTrigger := &plan.CreateTrigger{ + TriggerName: trigger.Name, + } + loadedTriggers = append(loadedTriggers, fakeTrigger) + continue } triggerPlan, ok := parsedTrigger.(*plan.CreateTrigger) if !ok { diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index d4b172f56f..06900efec8 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -172,7 +172,7 @@ func expressionSources(expr sql.Expression) (sql.FastIntSet, bool) { switch e := e.(type) { case *expression.GetField: tables.Add(int(e.TableId())) - case *expression.IsNull: + case sql.IsNullExpression, sql.IsNotNullExpression: nullRejecting = false case *expression.NullSafeEquals: nullRejecting = false @@ -188,7 +188,7 @@ func expressionSources(expr sql.Expression) (sql.FastIntSet, bool) { switch e := innerExpr.(type) { case *expression.GetField: tables.Add(int(e.TableId())) - case *expression.IsNull: + case sql.IsNullExpression, sql.IsNotNullExpression: nullRejecting = false case *expression.NullSafeEquals: nullRejecting = false diff --git a/sql/analyzer/process_truncate.go b/sql/analyzer/process_truncate.go index 57ad3deda6..dd1ec8eb38 100644 --- a/sql/analyzer/process_truncate.go +++ b/sql/analyzer/process_truncate.go @@ -100,7 +100,7 @@ func deleteToTruncate(ctx *sql.Context, a *Analyzer, deletePlan *plan.DeleteFrom return deletePlan, transform.SameTree, nil } - triggers, err := loadTriggersFromDb(ctx, a, currentDb) + triggers, err := loadTriggersFromDb(ctx, a, currentDb, false) if err != nil { return nil, transform.SameTree, err } diff --git a/sql/analyzer/replace_count_star.go b/sql/analyzer/replace_count_star.go index cccedb9f01..d86c644dc2 100644 --- a/sql/analyzer/replace_count_star.go +++ b/sql/analyzer/replace_count_star.go @@ -95,7 +95,7 @@ func replaceCountStar(ctx *sql.Context, a *Analyzer, n sql.Node, _ *plan.Scope, return n, transform.SameTree, nil } - if statsTable, ok := rt.Table.(sql.StatisticsTable); ok { + if statsTable, ok 
:= getStatisticsTable(rt.Table, nil); ok { rowCnt, exact, err := statsTable.RowCount(ctx) if err == nil && exact { return plan.NewProject( diff --git a/sql/analyzer/resolve_column_defaults.go b/sql/analyzer/resolve_column_defaults.go index 93e24737ce..705c1dc592 100644 --- a/sql/analyzer/resolve_column_defaults.go +++ b/sql/analyzer/resolve_column_defaults.go @@ -233,6 +233,9 @@ func validateColumnDefault(ctx *sql.Context, col *sql.Column, colDefault *sql.Co var err error sql.Inspect(colDefault.Expr, func(e sql.Expression) bool { switch e.(type) { + case *expression.UserVar, *expression.SystemVar: + err = sql.ErrColumnDefaultUserVariable.New(col.Name) + return false case sql.FunctionExpression, *expression.UnresolvedFunction: var funcName string switch expr := e.(type) { @@ -275,9 +278,49 @@ func validateColumnDefault(ctx *sql.Context, col *sql.Column, colDefault *sql.Co return err } + if enumType, isEnum := col.Type.(sql.EnumType); isEnum && colDefault.IsLiteral() { + if err = validateEnumLiteralDefault(enumType, colDefault, col.Name, ctx); err != nil { + return err + } + } + return nil } +// validateEnumLiteralDefault validates enum literal defaults more strictly than runtime conversions +// MySQL doesn't allow numeric index references for literal enum defaults +func validateEnumLiteralDefault(enumType sql.EnumType, colDefault *sql.ColumnDefaultValue, columnName string, ctx *sql.Context) error { + val, err := colDefault.Expr.Eval(ctx, nil) + if err != nil { + return err + } + + switch v := val.(type) { + case nil: + // NULL is a valid default for enum columns + return nil + case string: + // For string values, check if it's a direct enum value match + enumValues := enumType.Values() + for _, enumVal := range enumValues { + if enumVal == v { + return nil // Valid enum value + } + } + // String doesn't match any enum value, return appropriate error + if v == "" { + return sql.ErrIncompatibleDefaultType.New() + } + return 
sql.ErrInvalidColumnDefaultValue.New(columnName) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + // MySQL doesn't allow numeric enum indices as literal defaults + return sql.ErrInvalidColumnDefaultValue.New(columnName) + default: + // Other types not supported for enum defaults + return sql.ErrIncompatibleDefaultType.New() + } +} + func stripTableNamesFromDefault(e *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) { newDefault, ok := e.Unwrap().(*sql.ColumnDefaultValue) if !ok { diff --git a/sql/analyzer/symbol_resolution.go b/sql/analyzer/symbol_resolution.go index ce2d6f150e..c7e66881f9 100644 --- a/sql/analyzer/symbol_resolution.go +++ b/sql/analyzer/symbol_resolution.go @@ -202,37 +202,23 @@ func pruneTableCols( return n, transform.SameTree, nil } + // columns don't need to be pruned if there's a star _, selectStar := parentStars[table.Name()] - if unqualifiedStar { - selectStar = true + if selectStar || unqualifiedStar { + return n, transform.SameTree, nil } - // Don't prune columns if they're needed by a virtual column - virtualColDeps := make(map[tableCol]int) - if !selectStar { // if selectStar, we're adding all columns anyway - if vct, isVCT := n.WrappedTable().(*plan.VirtualColumnTable); isVCT { - for _, projection := range vct.Projections { - transform.InspectExpr(projection, func(e sql.Expression) bool { - if cd, isCD := e.(*sql.ColumnDefaultValue); isCD { - transform.InspectExpr(cd.Expr, func(e sql.Expression) bool { - if gf, ok := e.(*expression.GetField); ok { - c := newTableCol(gf.Table(), gf.Name()) - virtualColDeps[c]++ - } - return false - }) - } - return false - }) - } - } + // pruning VirtualColumnTable underlying tables causes indexing errors when VirtualColumnTable.Projections (which are sql.Expression) + // are evaluated + if _, isVCT := n.WrappedTable().(*plan.VirtualColumnTable); isVCT { + return n, transform.SameTree, nil } cols := make([]string, 0) source := strings.ToLower(table.Name()) 
for _, col := range table.Schema() { c := newTableCol(source, col.Name) - if selectStar || parentCols[c] > 0 || virtualColDeps[c] > 0 { + if parentCols[c] > 0 { cols = append(cols, c.col) } } diff --git a/sql/analyzer/tables.go b/sql/analyzer/tables.go index 21ed2fd8ec..463dabea19 100644 --- a/sql/analyzer/tables.go +++ b/sql/analyzer/tables.go @@ -22,7 +22,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/transform" ) -// Returns the underlying table name for the node given +// Returns the underlying table name, unaliased, for the node given func getTableName(node sql.Node) string { var tableName string transform.Inspect(node, func(node sql.Node) bool { @@ -43,27 +43,6 @@ func getTableName(node sql.Node) string { return tableName } -// Returns the underlying table name for the node given, ignoring table aliases -func getUnaliasedTableName(node sql.Node) string { - var tableName string - transform.Inspect(node, func(node sql.Node) bool { - switch node := node.(type) { - case *plan.ResolvedTable: - tableName = node.Name() - return false - case *plan.UnresolvedTable: - tableName = node.Name() - return false - case *plan.IndexedTableAccess: - tableName = node.Name() - return false - } - return true - }) - - return tableName -} - // Finds first table node that is a descendant of the node given func getTable(node sql.Node) sql.Table { var table sql.Table diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index 4f9f62b64b..8c8eee61c3 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -15,6 +15,7 @@ package analyzer import ( + "errors" "fmt" "strings" @@ -158,7 +159,15 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, db = n.Database().Name() } case *plan.Update: - affectedTables = append(affectedTables, getTableName(n)) + if n.IsJoin { + uj := n.Child.(*plan.UpdateJoin) + updateTargets := uj.UpdateTargets + for _, updateTarget := range updateTargets { + affectedTables = append(affectedTables, 
getTableName(updateTarget)) + } + } else { + affectedTables = append(affectedTables, getTableName(n)) + } triggerEvent = plan.UpdateTrigger if n.Database() != "" { db = n.Database() @@ -355,18 +364,18 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope } } - return transform.NodeWithCtx(n, nil, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) { + canApplyTriggerExecutor := func(c transform.Context) bool { // Don't double-apply trigger executors to the bodies of triggers. To avoid this, don't apply the trigger if the - // parent is a trigger body. - // TODO: this won't work for BEGIN END blocks, stored procedures, etc. For those, we need to examine all ancestors, - // not just the immediate parent. Alternately, we could do something like not walk all children of some node types - // (probably better). + // parent is a trigger body. Having this as a selector function will also prevent walking the child nodes in the + // trigger execution logic. if _, ok := c.Parent.(*plan.TriggerExecutor); ok { if c.ChildNum == 1 { // Right child is the trigger execution logic - return c.Node, transform.SameTree, nil + return false } } - + return true + } + return transform.NodeWithCtx(n, canApplyTriggerExecutor, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) { switch n := c.Node.(type) { case *plan.InsertInto: qFlags.Set(sql.QFlagTrigger) @@ -404,9 +413,9 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope // like we need something like a MultipleTriggerExecutor node // that could execute multiple triggers on the same row from its // wrapped iterator. There is also an issue with running triggers - // because their field indexes assume the row they evalute will + // because their field indexes assume the row they evaluate will // only ever contain the columns from the single table the trigger - // is based on, but this isn't true with UPDATE JOIN or DELETE JOIN. + // is based on. 
if n.HasExplicitTargets() { return nil, transform.SameTree, fmt.Errorf("delete from with explicit target tables " + "does not support triggers; retry with single table deletes") @@ -472,6 +481,12 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop ), ) } else { + // TODO: We should be able to handle duplicate column names by masking columns that aren't part of the + // triggered table https://github.com/dolthub/dolt/issues/9403 + err = validateNoConflictingColumnNames(updateSrc.Child.Schema()) + if err != nil { + return nil, err + } // The scopeNode for an UpdateJoin should contain every node in the updateSource as new and old. scopeNode = plan.NewProject( []sql.Expression{expression.NewStar()}, @@ -497,6 +512,19 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop return triggerLogic, err } +// validateNoConflictingColumnNames checks the columns of a joined table to make sure there are no conflicting column +// names +func validateNoConflictingColumnNames(sch sql.Schema) error { + columnNames := make(map[string]struct{}) + for _, col := range sch { + if _, ok := columnNames[col.Name]; ok { + return errors.New("Unable to apply triggers when joined tables have columns with the same name") + } + columnNames[col.Name] = struct{}{} + } + return nil +} + // validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any // table being updated in an outer scope of this analysis) func validateNoCircularUpdates(trigger *plan.CreateTrigger, n sql.Node, scope *plan.Scope) error { @@ -505,8 +533,8 @@ func validateNoCircularUpdates(trigger *plan.CreateTrigger, n sql.Node, scope *p switch node := node.(type) { case *plan.Update, *plan.InsertInto, *plan.DeleteFrom: for _, n := range append([]sql.Node{n}, scope.MemoNodes()...) 
{ - invokingTableName := getUnaliasedTableName(n) - updatedTable := getUnaliasedTableName(node) + invokingTableName := getTableName(n) + updatedTable := getTableName(node) // TODO: need to compare DB as well if updatedTable == invokingTableName { circularRef = sql.ErrTriggerTableInUse.New(updatedTable) diff --git a/sql/analyzer/validate_create_table.go b/sql/analyzer/validate_create_table.go index edda379530..ceaec20d6e 100644 --- a/sql/analyzer/validate_create_table.go +++ b/sql/analyzer/validate_create_table.go @@ -791,6 +791,10 @@ func validateAutoIncrementModify(schema sql.Schema, keyedColumns map[string]bool seen := false for _, col := range schema { if col.AutoIncrement { + // Check if column type is valid for auto_increment + if types.IsEnum(col.Type) { + return sql.ErrInvalidColumnSpecifier.New(col.Name) + } // keyedColumns == nil means they are trying to add auto_increment column if !col.PrimaryKey && !keyedColumns[col.Name] { // AUTO_INCREMENT col must be a key @@ -815,6 +819,10 @@ func validateAutoIncrementAdd(schema sql.Schema, keyColumns map[string]bool) err for _, col := range schema { if col.AutoIncrement { { + // Check if column type is valid for auto_increment + if types.IsEnum(col.Type) { + return sql.ErrInvalidColumnSpecifier.New(col.Name) + } if !col.PrimaryKey && !keyColumns[col.Name] { // AUTO_INCREMENT col must be a key return sql.ErrInvalidAutoIncCols.New() diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 85db26bad8..b8e0cc50b5 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -262,8 +262,9 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop switch parent.(type) { case *plan.Having, *plan.Project, *plan.Sort: - // TODO: these shouldn't be skipped; you can group by primary key without problem b/c only one value - // 
https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html#:~:text=The%20query%20is%20valid%20if%20name%20is%20a%20primary%20key + // TODO: these shouldn't be skipped but we currently aren't able to validate GroupBys with selected aliased + // expressions and a lot of our tests group by aliases + // https://github.com/dolthub/dolt/issues/4998 return true } @@ -273,17 +274,34 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop return true } - var groupBys []string + primaryKeys := make(map[string]bool) + for _, col := range gb.Child.Schema() { + if col.PrimaryKey { + primaryKeys[strings.ToLower(col.String())] = true + } + } + + groupBys := make(map[string]bool) + groupByPrimaryKeys := 0 for _, expr := range gb.GroupByExprs { - groupBys = append(groupBys, expr.String()) + exprStr := strings.ToLower(expr.String()) + groupBys[exprStr] = true + if primaryKeys[exprStr] { + groupByPrimaryKeys++ + } + } + + // TODO: also allow grouping by unique non-nullable columns + if len(primaryKeys) != 0 && groupByPrimaryKeys == len(primaryKeys) { + return true } for _, expr := range gb.SelectedExprs { - if _, ok := expr.(sql.Aggregation); !ok { - if !expressionReferencesOnlyGroupBys(groupBys, expr) { - err = analyzererrors.ErrValidationGroupBy.New(expr.String()) - return false - } + if !expressionReferencesOnlyGroupBys(groupBys, expr) { + // TODO: this is currently too restrictive. 
Dependent columns are fine to reference + // https://dev.mysql.com/doc/refman/8.4/en/group-by-functional-dependence.html + err = analyzererrors.ErrValidationGroupBy.New(expr.String()) + return false } } return true @@ -292,22 +310,14 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop return n, transform.SameTree, err } -func expressionReferencesOnlyGroupBys(groupBys []string, expr sql.Expression) bool { +func expressionReferencesOnlyGroupBys(groupBys map[string]bool, expr sql.Expression) bool { valid := true sql.Inspect(expr, func(expr sql.Expression) bool { switch expr := expr.(type) { case nil, sql.Aggregation, *expression.Literal: return false - case *expression.Alias, sql.FunctionExpression: - if stringContains(groupBys, expr.String()) { - return false - } - return true - // cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html - // Each part of the SelectExpr must refer to the aggregated columns in some way - // TODO: this isn't complete, it's overly restrictive. Dependant columns are fine to reference. default: - if stringContains(groupBys, expr.String()) { + if groupBys[strings.ToLower(expr.String())] { return false } diff --git a/sql/cache.go b/sql/cache.go index 260e4a8bac..c794cba491 100644 --- a/sql/cache.go +++ b/sql/cache.go @@ -15,49 +15,12 @@ package sql import ( - "context" "fmt" "runtime" - "sync" - - "github.com/cespare/xxhash/v2" lru "github.com/hashicorp/golang-lru" ) -// HashOf returns a hash of the given value to be used as key in a cache. 
-func HashOf(ctx context.Context, v Row) (uint64, error) { - hash := digestPool.Get().(*xxhash.Digest) - hash.Reset() - defer digestPool.Put(hash) - for i, x := range v { - if i > 0 { - // separate each value in the row with a nil byte - if _, err := hash.Write([]byte{0}); err != nil { - return 0, err - } - } - x, err := UnwrapAny(ctx, x) - if err != nil { - return 0, err - } - // TODO: probably much faster to do this with a type switch - // TODO: we don't have the type info necessary to appropriately encode the value of a string with a non-standard - // collation, which means that two strings that differ only in their collations will hash to the same value. - // See rowexec/grouping_key() - if _, err := fmt.Fprintf(hash, "%v,", x); err != nil { - return 0, err - } - } - return hash.Sum64(), nil -} - -var digestPool = sync.Pool{ - New: func() any { - return xxhash.New() - }, -} - // ErrKeyNotFound is returned when the key could not be found in the cache. var ErrKeyNotFound = fmt.Errorf("memory: key not found in cache") diff --git a/sql/cache_test.go b/sql/cache_test.go index 7f77d668cd..1f6dd58f43 100644 --- a/sql/cache_test.go +++ b/sql/cache_test.go @@ -15,7 +15,6 @@ package sql import ( - "context" "errors" "testing" @@ -178,35 +177,3 @@ func TestRowsCache(t *testing.T) { require.True(freed) }) } - -func BenchmarkHashOf(b *testing.B) { - ctx := context.Background() - row := NewRow(1, "1") - b.ResetTimer() - for i := 0; i < b.N; i++ { - sum, err := HashOf(ctx, row) - if err != nil { - b.Fatal(err) - } - if sum != 11268758894040352165 { - b.Fatalf("got %v", sum) - } - } -} - -func BenchmarkParallelHashOf(b *testing.B) { - ctx := context.Background() - row := NewRow(1, "1") - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - sum, err := HashOf(ctx, row) - if err != nil { - b.Fatal(err) - } - if sum != 11268758894040352165 { - b.Fatalf("got %v", sum) - } - } - }) -} diff --git a/sql/column.go b/sql/column.go index 87121bf23c..5fe8af3fa1 
100644 --- a/sql/column.go +++ b/sql/column.go @@ -130,6 +130,10 @@ func (c Column) Copy() *Column { return &c } +func (c *Column) String() string { + return c.Source + "." + c.Name +} + // TableId is the unique identifier of a table or table alias in a multi-db environment. // The long-term goal is to migrate all uses of table name strings to this and minimize places where we // construct/inspect TableIDs. By treating this as an opaque identifier, it will be easier to migrate to diff --git a/sql/core.go b/sql/core.go index e2e7d0152d..235e744d51 100644 --- a/sql/core.go +++ b/sql/core.go @@ -45,6 +45,16 @@ type Expression interface { WithChildren(children ...Expression) (Expression, error) } +// RowIterExpression is an Expression that returns a RowIter rather than a scalar, used to implement functions that +// return sets. +type RowIterExpression interface { + Expression + // EvalRowIter evaluates the expression, which must be a RowIter + EvalRowIter(ctx *Context, r Row) (RowIter, error) + // ReturnsRowIter returns whether this expression returns a RowIter + ReturnsRowIter() bool +} + // ExpressionWithNodes is an expression that contains nodes as children. type ExpressionWithNodes interface { Expression @@ -65,6 +75,19 @@ type NonDeterministicExpression interface { IsNonDeterministic() bool } +// IsNullExpression indicates that this expression tests for IS NULL. +type IsNullExpression interface { + Expression + IsNullExpression() bool +} + +// IsNotNullExpression indicates that this expression tests for IS NOT NULL. Note that in some cases in some +// database engines, such as records in Postgres, IS NOT NULL is not identical to NOT(IS NULL). +type IsNotNullExpression interface { + Expression + IsNotNullExpression() bool +} + // Node is a node in the execution plan tree. 
type Node interface { Resolvable diff --git a/sql/errors.go b/sql/errors.go index ffe33278c5..a0fc2aff24 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -147,6 +147,9 @@ var ( // ErrInvalidColumnDefaultValue is returned when column default function value is not wrapped in parentheses for column types excluding datetime and timestamp ErrInvalidColumnDefaultValue = errors.NewKind("Invalid default value for '%s'") + // ErrColumnDefaultUserVariable is returned when a column default expression contains user or system variables + ErrColumnDefaultUserVariable = errors.NewKind("Default value expression of column '%s' cannot refer user or system variables.") + // ErrInvalidDefaultValueOrder is returned when a default value references a column that comes after it and contains a default expression. ErrInvalidDefaultValueOrder = errors.NewKind(`default value of column "%s" cannot refer to a column defined after it if those columns have an expression default value`) @@ -669,6 +672,9 @@ var ( // ErrInvalidAutoIncCols is returned when an auto_increment column cannot be applied ErrInvalidAutoIncCols = errors.NewKind("there can be only one auto_increment column and it must be defined as a key") + // ErrInvalidColumnSpecifier is returned when an invalid column specifier is used + ErrInvalidColumnSpecifier = errors.NewKind("Incorrect column specifier for column '%s'") + // ErrUnknownConstraintDefinition is returned when an unknown constraint type is used ErrUnknownConstraintDefinition = errors.NewKind("unknown constraint definition: %s, %T") diff --git a/sql/expression/case.go b/sql/expression/case.go index 30b6cf3e06..57d1566c9f 100644 --- a/sql/expression/case.go +++ b/sql/expression/case.go @@ -43,71 +43,15 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression return &Case{expr, branches, elseExpr} } -// From the description of operator typing here: -// https://dev.mysql.com/doc/refman/8.0/en/flow-control-functions.html#operator_case -func 
combinedCaseBranchType(left, right sql.Type) sql.Type { - if left == types.Null { - return right - } - if right == types.Null { - return left - } - - // Our current implementation of StringType.Convert(enum), does not match MySQL's behavior. - // So, we make sure to return Enums in this particular case. - // More details: https://github.com/dolthub/dolt/issues/8598 - if types.IsEnum(left) && types.IsEnum(right) { - return right - } - if types.IsSet(left) && types.IsSet(right) { - return right - } - if types.IsTextOnly(left) && types.IsTextOnly(right) { - return types.LongText - } - if types.IsTextBlob(left) && types.IsTextBlob(right) { - return types.LongBlob - } - if types.IsTime(left) && types.IsTime(right) { - if left == right { - return left - } - return types.DatetimeMaxPrecision - } - if types.IsNumber(left) && types.IsNumber(right) { - if left == types.Float64 || right == types.Float64 { - return types.Float64 - } - if left == types.Float32 || right == types.Float32 { - return types.Float32 - } - if types.IsDecimal(left) || types.IsDecimal(right) { - return types.MustCreateDecimalType(65, 10) - } - if left == types.Uint64 && types.IsSigned(right) || - right == types.Uint64 && types.IsSigned(left) { - return types.MustCreateDecimalType(65, 10) - } - if !types.IsSigned(left) && !types.IsSigned(right) { - return types.Uint64 - } else { - return types.Int64 - } - } - if types.IsJSON(left) && types.IsJSON(right) { - return types.JSON - } - return types.LongText -} - // Type implements the sql.Expression interface. 
func (c *Case) Type() sql.Type { - curr := types.Null + var curr sql.Type + curr = types.Null for _, b := range c.Branches { - curr = combinedCaseBranchType(curr, b.Value.Type()) + curr = types.GeneralizeTypes(curr, b.Value.Type()) } if c.Else != nil { - curr = combinedCaseBranchType(curr, c.Else.Type()) + curr = types.GeneralizeTypes(curr, c.Else.Type()) } return curr } diff --git a/sql/expression/case_test.go b/sql/expression/case_test.go index 68033649e9..27b5afdacd 100644 --- a/sql/expression/case_test.go +++ b/sql/expression/case_test.go @@ -161,8 +161,8 @@ func TestCaseType(t *testing.T) { } } - decimalType := types.MustCreateDecimalType(65, 10) - + decimalType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale) + uint64DecimalType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 0) testCases := []struct { name string c *Case @@ -175,13 +175,13 @@ func TestCaseType(t *testing.T) { }, { "unsigned promoted and unsigned", - caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint32)), + caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint64)), types.Uint64, }, { "signed promoted and signed", caseExpr(NewLiteral(int8(0), types.Int8), NewLiteral(int32(1), types.Int32)), - types.Int64, + types.Int32, }, { "int and float to float", @@ -216,7 +216,7 @@ func TestCaseType(t *testing.T) { { "uint64 and int8 to decimal", caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral(int8(0), types.Int8)), - decimalType, + uint64DecimalType, }, { "int and text to text", diff --git a/sql/expression/enum.go b/sql/expression/enum.go index 8637863bc7..36b4af9c22 100644 --- a/sql/expression/enum.go +++ b/sql/expression/enum.go @@ -69,7 +69,19 @@ func (e *EnumToString) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } enumType := e.Enum.Type().(types.EnumType) - str, _ := enumType.At(int(val.(uint16))) + var str string + val, err = sql.UnwrapAny(ctx, val) + if err != nil { + 
return nil, err + } + switch v := val.(type) { + case uint16: + str, _ = enumType.At(int(v)) + case string: + str = v + default: + return nil, sql.ErrInvalidType.New(val, types.Text) + } return str, nil } diff --git a/sql/expression/expr-factory.go b/sql/expression/expr-factory.go new file mode 100644 index 0000000000..4d77094d9f --- /dev/null +++ b/sql/expression/expr-factory.go @@ -0,0 +1,50 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import "github.com/dolthub/go-mysql-server/sql" + +// ExpressionFactory allows integrators to provide custom implementations of +// expressions, such as IS NULL and IS NOT NULL. +type ExpressionFactory interface { + // NewIsNull returns a sql.Expression implementation that handles + // the IS NULL expression. + NewIsNull(e sql.Expression) sql.Expression + // NewIsNotNull returns a sql.Expression implementation that handles + // the IS NOT NULL expression. + NewIsNotNull(e sql.Expression) sql.Expression +} + +// DefaultExpressionFactory is the ExpressionFactory used when the analyzer +// needs to create new expressions during analysis, such as IS NULL or +// IS NOT NULL. Integrators can swap in their own implementation if they need +// to customize the existing logic for these expressions. 
+var DefaultExpressionFactory ExpressionFactory = MySqlExpressionFactory{} + +// MySqlExpressionFactory is the ExpressionFactory that creates expressions +// that follow MySQL's logic. +type MySqlExpressionFactory struct{} + +var _ ExpressionFactory = (*MySqlExpressionFactory)(nil) + +// NewIsNull implements the ExpressionFactory interface. +func (m MySqlExpressionFactory) NewIsNull(e sql.Expression) sql.Expression { + return NewIsNull(e) +} + +// NewIsNotNull implements the ExpressionFactory interface. +func (m MySqlExpressionFactory) NewIsNotNull(e sql.Expression) sql.Expression { + return NewNot(NewIsNull(e)) +} diff --git a/sql/expression/filter-range.go b/sql/expression/filter-range.go index 5e74b16ae6..231e8043c8 100644 --- a/sql/expression/filter-range.go +++ b/sql/expression/filter-range.go @@ -48,24 +48,24 @@ func NewRangeFilterExpr(exprs []sql.Expression, ranges []sql.MySQLRange) (sql.Ex case sql.RangeType_All: rangeColumnExpr = NewEquals(NewLiteral(1, types.Int8), NewLiteral(1, types.Int8)) case sql.RangeType_EqualNull: - rangeColumnExpr = NewIsNull(exprs[i]) + rangeColumnExpr = DefaultExpressionFactory.NewIsNull(exprs[i]) case sql.RangeType_GreaterThan: if sql.MySQLRangeCutIsBinding(rce.LowerBound) { rangeColumnExpr = NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())) } else { - rangeColumnExpr = NewNot(NewIsNull(exprs[i])) + rangeColumnExpr = DefaultExpressionFactory.NewIsNotNull(exprs[i]) } case sql.RangeType_GreaterOrEqual: rangeColumnExpr = NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())) case sql.RangeType_LessThanOrNull: rangeColumnExpr = JoinOr( NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), - NewIsNull(exprs[i]), + DefaultExpressionFactory.NewIsNull(exprs[i]), ) case sql.RangeType_LessOrEqualOrNull: rangeColumnExpr = JoinOr( NewLessThanOrEqual(exprs[i], 
NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), - NewIsNull(exprs[i]), + DefaultExpressionFactory.NewIsNull(exprs[i]), ) case sql.RangeType_ClosedClosed: rangeColumnExpr = JoinAnd( diff --git a/sql/expression/function/aggregation/group_concat.go b/sql/expression/function/aggregation/group_concat.go index 4a763b1843..8e896031d6 100644 --- a/sql/expression/function/aggregation/group_concat.go +++ b/sql/expression/function/aggregation/group_concat.go @@ -40,6 +40,7 @@ type GroupConcat struct { var _ sql.FunctionExpression = &GroupConcat{} var _ sql.Aggregation = &GroupConcat{} var _ sql.WindowAdaptableExpression = (*GroupConcat)(nil) +var _ sql.OrderedAggregation = (*GroupConcat)(nil) func NewEmptyGroupConcat() sql.Expression { return &GroupConcat{} @@ -153,6 +154,40 @@ func (g *GroupConcat) String() string { return sb.String() } +func (g *GroupConcat) DebugString() string { + sb := strings.Builder{} + sb.WriteString("group_concat(") + if g.distinct != "" { + sb.WriteString(fmt.Sprintf("distinct %s", g.distinct)) + } + + if g.selectExprs != nil { + var exprs = make([]string, len(g.selectExprs)) + for i, expr := range g.selectExprs { + exprs[i] = sql.DebugString(expr) + } + + sb.WriteString(strings.Join(exprs, ", ")) + } + + if len(g.sf) > 0 { + sb.WriteString(" order by ") + for i, ob := range g.sf { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(sql.DebugString(ob)) + } + } + + sb.WriteString(" separator ") + sb.WriteString(fmt.Sprintf("'%s'", g.separator)) + + sb.WriteString(")") + + return sb.String() +} + // Type implements the Expression interface. // cc: https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html#function_group-concat for explanations // on return type. 
@@ -195,6 +230,11 @@ func (g *GroupConcat) WithChildren(children ...sql.Expression) (sql.Expression, return NewGroupConcat(g.distinct, g.sf.FromExpressions(orderByExpr...), g.separator, children[sortFieldMarker:], g.maxLen), nil } +// OutputExpressions implements the OrderedAggregation interface. +func (g *GroupConcat) OutputExpressions() []sql.Expression { + return g.selectExprs +} + type groupConcatBuffer struct { gc *GroupConcat rows []sql.Row @@ -231,16 +271,27 @@ func (g *groupConcatBuffer) Update(ctx *sql.Context, originalRow sql.Row) error return nil } } else { - v, _, err = types.LongText.Convert(ctx, evalRow[0]) - if err != nil { - return err - } - if v == nil { - return nil - } - vs, _, err = sql.Unwrap[string](ctx, v) - if err != nil { - return err + // Use type-aware conversion for enum types + if len(g.gc.selectExprs) > 0 { + vs, _, err = types.ConvertToCollatedString(ctx, evalRow[0], g.gc.selectExprs[0].Type()) + if err != nil { + return err + } + if vs == "" { + return nil + } + } else { + v, _, err = types.LongText.Convert(ctx, evalRow[0]) + if err != nil { + return err + } + if v == nil { + return nil + } + vs, _, err = sql.Unwrap[string](ctx, v) + if err != nil { + return err + } } } @@ -257,7 +308,7 @@ func (g *groupConcatBuffer) Update(ctx *sql.Context, originalRow sql.Row) error // Append the current value to the end of the row. We want to preserve the row's original structure for // for sort ordering in the final step. 
- g.rows = append(g.rows, append(originalRow, nil, vs)) + g.rows = append(g.rows, append(originalRow, vs)) return nil } diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index df2b1c82fb..c484b5321a 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -500,17 +500,19 @@ func (c *countBuffer) Dispose() { } type firstBuffer struct { - val interface{} - expr sql.Expression + val interface{} + // writtenNil means that val is supposed to be nil and should not be overwritten + writtenNil bool + expr sql.Expression } func NewFirstBuffer(child sql.Expression) *firstBuffer { - return &firstBuffer{nil, child} + return &firstBuffer{nil, false, child} } // Update implements the AggregationBuffer interface. func (f *firstBuffer) Update(ctx *sql.Context, row sql.Row) error { - if f.val != nil { + if f.val != nil || f.writtenNil { return nil } @@ -520,6 +522,7 @@ func (f *firstBuffer) Update(ctx *sql.Context, row sql.Row) error { } if v == nil { + f.writtenNil = true return nil } diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go index ea2a9691c3..14c353a843 100644 --- a/sql/expression/function/coalesce.go +++ b/sql/expression/function/coalesce.go @@ -58,7 +58,9 @@ func (c *Coalesce) Type() sql.Type { if c.typ != nil { return c.typ } - retType := types.Null + + var retType sql.Type + retType = types.Null for i, arg := range c.args { if arg == nil { continue diff --git a/sql/expression/function/concat.go b/sql/expression/function/concat.go index e2541a62d1..1dc96a951e 100644 --- a/sql/expression/function/concat.go +++ b/sql/expression/function/concat.go @@ -123,17 +123,13 @@ func (c *Concat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - val, _, err = types.LongText.Convert(ctx, val) + // Use type-aware conversion for enum types + content, _, err := 
types.ConvertToCollatedString(ctx, val, arg.Type()) if err != nil { return nil, err } - val, _, err = sql.Unwrap[string](ctx, val) - if err != nil { - return nil, err - } - - parts = append(parts, val.(string)) + parts = append(parts, content) } return strings.Join(parts, ""), nil diff --git a/sql/expression/function/conv.go b/sql/expression/function/conv.go index 82dcbb02d0..517490df1c 100644 --- a/sql/expression/function/conv.go +++ b/sql/expression/function/conv.go @@ -136,62 +136,66 @@ func (c *Conv) WithChildren(children ...sql.Expression) (sql.Expression, error) // This conversion truncates nVal as its first subpart that is convertable. // nVal is treated as unsigned except nVal is negative. func convertFromBase(ctx *sql.Context, nVal string, fromBase interface{}) interface{} { - fromBase, _, err := types.Int64.Convert(ctx, fromBase) - if err != nil { + if len(nVal) == 0 { return nil } - fromVal := int(math.Abs(float64(fromBase.(int64)))) + // Convert and validate fromBase + baseVal, _, err := types.Int64.Convert(ctx, fromBase) + if err != nil { + return nil + } + fromVal := int(math.Abs(float64(baseVal.(int64)))) if fromVal < 2 || fromVal > 36 { return nil } + // Handle sign negative := false - var upper string - var lower string - if nVal[0] == '-' { + switch nVal[0] { + case '-': + if len(nVal) == 1 { + return uint64(0) + } negative = true nVal = nVal[1:] - } else if nVal[0] == '+' { + case '+': + if len(nVal) == 1 { + return uint64(0) + } nVal = nVal[1:] } - // check for upper and lower bound for given fromBase + // Determine bounds based on sign + var maxLen int if negative { - upper = strconv.FormatInt(math.MaxInt64, fromVal) - lower = strconv.FormatInt(math.MinInt64, fromVal) - if len(nVal) > len(lower) { - nVal = lower - } else if len(nVal) > len(upper) { - nVal = upper + maxLen = len(strconv.FormatInt(math.MinInt64, fromVal)) + if len(nVal) > maxLen { + // Use MinInt64 representation in the given base + nVal = strconv.FormatInt(math.MinInt64, 
fromVal)[1:] // remove minus sign } } else { - upper = strconv.FormatUint(math.MaxUint64, fromVal) - lower = "0" - if len(nVal) < len(lower) { - nVal = lower - } else if len(nVal) > len(upper) { - nVal = upper + maxLen = len(strconv.FormatUint(math.MaxUint64, fromVal)) + if len(nVal) > maxLen { + // Use MaxUint64 representation in the given base + nVal = strconv.FormatUint(math.MaxUint64, fromVal) } } - truncate := false - result := uint64(0) - i := 1 - for !truncate && i <= len(nVal) { + // Find the longest valid prefix that can be converted + var result uint64 + for i := 1; i <= len(nVal); i++ { val, err := strconv.ParseUint(nVal[:i], fromVal, 64) if err != nil { - truncate = true - return result + break } result = val - i++ } if negative { + // MySQL returns signed value for negative inputs return int64(result) * -1 } - return result } diff --git a/sql/expression/function/conv_test.go b/sql/expression/function/conv_test.go index 664701be0d..05b00adb38 100644 --- a/sql/expression/function/conv_test.go +++ b/sql/expression/function/conv_test.go @@ -35,6 +35,8 @@ func TestConv(t *testing.T) { {"n is nil", types.Int32, sql.NewRow(nil, 16, 2), nil}, {"fromBase is nil", types.LongText, sql.NewRow('a', nil, 2), nil}, {"toBase is nil", types.LongText, sql.NewRow('a', 16, nil), nil}, + {"empty n string", types.LongText, sql.NewRow("", 3, 4), nil}, + {"empty arg strings", types.LongText, sql.NewRow(4, "", ""), nil}, // invalid inputs {"invalid N", types.LongText, sql.NewRow("r", 16, 2), "0"}, diff --git a/sql/expression/function/export_set.go b/sql/expression/function/export_set.go new file mode 100644 index 0000000000..acff3ff7ac --- /dev/null +++ b/sql/expression/function/export_set.go @@ -0,0 +1,230 @@ +// Copyright 2020-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "fmt" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// ExportSet implements the SQL function EXPORT_SET() which returns a string representation of bits in a number +type ExportSet struct { + bits sql.Expression + on sql.Expression + off sql.Expression + separator sql.Expression + numberOfBits sql.Expression +} + +var _ sql.FunctionExpression = (*ExportSet)(nil) +var _ sql.CollationCoercible = (*ExportSet)(nil) + +// NewExportSet creates a new ExportSet expression +func NewExportSet(args ...sql.Expression) (sql.Expression, error) { + if len(args) < 3 || len(args) > 5 { + return nil, sql.ErrInvalidArgumentNumber.New("EXPORT_SET", "3, 4, or 5", len(args)) + } + + var separator, numberOfBits sql.Expression + if len(args) >= 4 { + separator = args[3] + } + if len(args) == 5 { + numberOfBits = args[4] + } + + return &ExportSet{ + bits: args[0], + on: args[1], + off: args[2], + separator: separator, + numberOfBits: numberOfBits, + }, nil +} + +// FunctionName implements sql.FunctionExpression +func (e *ExportSet) FunctionName() string { + return "export_set" +} + +// Description implements sql.FunctionExpression +func (e *ExportSet) Description() string { + return "returns a string such that for every bit set in the value bits, you get an on string and for every unset bit, you get an off string." 
+} + +// Children implements the Expression interface +func (e *ExportSet) Children() []sql.Expression { + children := []sql.Expression{e.bits, e.on, e.off} + if e.separator != nil { + children = append(children, e.separator) + } + if e.numberOfBits != nil { + children = append(children, e.numberOfBits) + } + return children +} + +// Resolved implements the Expression interface +func (e *ExportSet) Resolved() bool { + for _, child := range e.Children() { + if !child.Resolved() { + return false + } + } + return true +} + +// IsNullable implements the Expression interface +func (e *ExportSet) IsNullable() bool { + for _, child := range e.Children() { + if child.IsNullable() { + return true + } + } + return false +} + +// Type implements the Expression interface +func (e *ExportSet) Type() sql.Type { + return types.LongText +} + +// CollationCoercibility implements the interface sql.CollationCoercible +func (e *ExportSet) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + collation, coercibility = sql.GetCoercibility(ctx, e.on) + otherCollation, otherCoercibility := sql.GetCoercibility(ctx, e.off) + collation, coercibility = sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility) + if e.separator != nil { + otherCollation, otherCoercibility = sql.GetCoercibility(ctx, e.separator) + collation, coercibility = sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility) + } + return collation, coercibility +} + +// String implements the Expression interface +func (e *ExportSet) String() string { + children := e.Children() + childStrs := make([]string, len(children)) + for i, child := range children { + childStrs[i] = child.String() + } + return fmt.Sprintf("export_set(%s)", strings.Join(childStrs, ", ")) +} + +// WithChildren implements the Expression interface +func (e *ExportSet) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewExportSet(children...) 
+} + +// Eval implements the Expression interface +func (e *ExportSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + bitsVal, err := e.bits.Eval(ctx, row) + if err != nil { + return nil, err + } + if bitsVal == nil { + return nil, nil + } + + onVal, err := e.on.Eval(ctx, row) + if err != nil { + return nil, err + } + if onVal == nil { + return nil, nil + } + + offVal, err := e.off.Eval(ctx, row) + if err != nil { + return nil, err + } + if offVal == nil { + return nil, nil + } + + // Default separator is comma + separatorVal := "," + if e.separator != nil { + sepVal, err := e.separator.Eval(ctx, row) + if err != nil { + return nil, err + } + if sepVal == nil { + return nil, nil + } + sepStr, _, err := types.LongText.Convert(ctx, sepVal) + if err != nil { + return nil, err + } + separatorVal = sepStr.(string) + } + + // Default number of bits is 64 + numberOfBitsVal := int64(64) + if e.numberOfBits != nil { + numBitsVal, err := e.numberOfBits.Eval(ctx, row) + if err != nil { + return nil, err + } + if numBitsVal == nil { + return nil, nil + } + numBitsInt, _, err := types.Int64.Convert(ctx, numBitsVal) + if err != nil { + return nil, err + } + numberOfBitsVal = numBitsInt.(int64) + // MySQL silently clips to 64 if larger, treats negative as 64 + if numberOfBitsVal > 64 || numberOfBitsVal < 0 { + numberOfBitsVal = 64 + } + } + + // Convert arguments to proper types + bitsInt, _, err := types.Uint64.Convert(ctx, bitsVal) + if err != nil { + return nil, err + } + + onStr, _, err := types.LongText.Convert(ctx, onVal) + if err != nil { + return nil, err + } + + offStr, _, err := types.LongText.Convert(ctx, offVal) + if err != nil { + return nil, err + } + + bits := bitsInt.(uint64) + on := onStr.(string) + off := offStr.(string) + + // Build the result by examining bits from right to left (LSB to MSB) + // but adding strings from left to right + result := make([]string, numberOfBitsVal) + for i := int64(0); i < numberOfBitsVal; i++ { + if (bits & (1 << 
uint(i))) != 0 { + result[i] = on + } else { + result[i] = off + } + } + + return strings.Join(result, separatorVal), nil +} diff --git a/sql/expression/function/export_set_test.go b/sql/expression/function/export_set_test.go new file mode 100644 index 0000000000..c6425211f3 --- /dev/null +++ b/sql/expression/function/export_set_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +func TestExportSet(t *testing.T) { + testCases := []struct { + name string + args []interface{} + expected interface{} + err bool + }{ + // MySQL documentation examples + {"mysql example 1", []interface{}{5, "Y", "N", ",", 4}, "Y,N,Y,N", false}, + {"mysql example 2", []interface{}{6, "1", "0", ",", 10}, "0,1,1,0,0,0,0,0,0,0", false}, + + // Basic functionality tests + {"zero value", []interface{}{0, "1", "0", ",", 4}, "0,0,0,0", false}, + {"all bits set", []interface{}{15, "1", "0", ",", 4}, "1,1,1,1", false}, + {"single bit", []interface{}{1, "T", "F", ",", 3}, "T,F,F", false}, + {"single bit position 2", []interface{}{2, "T", "F", ",", 3}, "F,T,F", false}, + {"single bit position 3", []interface{}{4, "T", "F", ",", 3}, "F,F,T", false}, + + // Different separators + {"pipe separator", 
[]interface{}{5, "1", "0", "|", 4}, "1|0|1|0", false}, + {"space separator", []interface{}{5, "1", "0", " ", 4}, "1 0 1 0", false}, + {"empty separator", []interface{}{5, "1", "0", "", 4}, "1010", false}, + {"no separator specified", []interface{}{5, "1", "0"}, "1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", false}, + + // Different on/off strings + {"word strings", []interface{}{5, "ON", "OFF", ",", 4}, "ON,OFF,ON,OFF", false}, + {"empty on string", []interface{}{5, "", "0", ",", 4}, ",0,,0", false}, + {"empty off string", []interface{}{5, "1", "", ",", 4}, "1,,1,", false}, + + // Number of bits tests + {"no number of bits specified", []interface{}{5, "1", "0"}, "1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", false}, + {"1 bit", []interface{}{5, "1", "0", ",", 1}, "1", false}, + {"8 bits", []interface{}{255, "1", "0", ",", 8}, "1,1,1,1,1,1,1,1", false}, + {"large number of bits", []interface{}{5, "1", "0", ",", 100}, "1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", false}, + {"negative number of bits", []interface{}{5, "1", "0", ",", -5}, "1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", false}, + + // Large numbers + {"large number", []interface{}{4294967295, "1", "0", ",", 32}, "1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1", false}, + {"powers of 2", []interface{}{1024, "1", "0", ",", 12}, "0,0,0,0,0,0,0,0,0,0,1,0", false}, + + // NULL handling + {"null bits", []interface{}{nil, "1", "0", ",", 4}, nil, false}, + {"null on", []interface{}{5, nil, "0", ",", 4}, nil, false}, + {"null off", []interface{}{5, "1", nil, ",", 4}, nil, false}, + {"null separator", []interface{}{5, "1", "0", nil, 4}, nil, false}, + {"null number 
of bits", []interface{}{5, "1", "0", ",", nil}, nil, false}, + + // Type conversion + {"string number", []interface{}{"5", "1", "0", ",", 4}, "1,0,1,0", false}, + {"float number", []interface{}{5.7, "1", "0", ",", 4}, "0,1,1,0", false}, + {"negative number", []interface{}{-1, "1", "0", ",", 4}, "1,1,1,1", false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + // Convert test args to expressions + args := make([]sql.Expression, len(tt.args)) + for i, arg := range tt.args { + if arg == nil { + args[i] = expression.NewLiteral(nil, types.Null) + } else { + switch v := arg.(type) { + case int: + args[i] = expression.NewLiteral(int64(v), types.Int64) + case string: + args[i] = expression.NewLiteral(v, types.LongText) + default: + args[i] = expression.NewLiteral(v, types.LongText) + } + } + } + + f, err := NewExportSet(args...) + require.NoError(err) + + v, err := f.Eval(ctx, nil) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} + +func TestExportSetArguments(t *testing.T) { + require := require.New(t) + + // Test invalid number of arguments + _, err := NewExportSet() + require.Error(err) + + _, err = NewExportSet(expression.NewLiteral(1, types.Int64)) + require.Error(err) + + _, err = NewExportSet(expression.NewLiteral(1, types.Int64), expression.NewLiteral("1", types.Text)) + require.Error(err) + + // Test too many arguments + args := make([]sql.Expression, 6) + for i := range args { + args[i] = expression.NewLiteral(1, types.Int64) + } + _, err = NewExportSet(args...) 
+ require.Error(err) + + // Test valid argument counts + validArgs := [][]sql.Expression{ + {expression.NewLiteral(1, types.Int64), expression.NewLiteral("1", types.Text), expression.NewLiteral("0", types.Text)}, + {expression.NewLiteral(1, types.Int64), expression.NewLiteral("1", types.Text), expression.NewLiteral("0", types.Text), expression.NewLiteral(",", types.Text)}, + {expression.NewLiteral(1, types.Int64), expression.NewLiteral("1", types.Text), expression.NewLiteral("0", types.Text), expression.NewLiteral(",", types.Text), expression.NewLiteral(4, types.Int64)}, + } + + for _, args := range validArgs { + _, err := NewExportSet(args...) + require.NoError(err) + } +} diff --git a/sql/expression/function/if.go b/sql/expression/function/if.go index ebbe34a02b..c019357f39 100644 --- a/sql/expression/function/if.go +++ b/sql/expression/function/if.go @@ -77,26 +77,27 @@ func (f *If) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + var eval interface{} if asBool { - return f.ifTrue.Eval(ctx, row) + eval, err = f.ifTrue.Eval(ctx, row) + if err != nil { + return nil, err + } } else { - return f.ifFalse.Eval(ctx, row) + eval, err = f.ifFalse.Eval(ctx, row) + if err != nil { + return nil, err + } + } + if ret, _, err := f.Type().Convert(ctx, eval); err == nil { + return ret, nil } + return eval, err } // Type implements the Expression interface. func (f *If) Type() sql.Type { - // if either type is string type, this should be a string type, regardless need to promote - typ1 := f.ifTrue.Type() - typ2 := f.ifFalse.Type() - if types.IsText(typ1) || types.IsText(typ2) { - return types.Text - } - - if typ1 == types.Null { - return typ2.Promote() - } - return typ1.Promote() + return types.GeneralizeTypes(f.ifTrue.Type(), f.ifFalse.Type()) } // CollationCoercibility implements the interface sql.CollationCoercible. 
diff --git a/sql/expression/function/if_test.go b/sql/expression/function/if_test.go index 2559f40438..946912ff46 100644 --- a/sql/expression/function/if_test.go +++ b/sql/expression/function/if_test.go @@ -29,20 +29,22 @@ func TestIf(t *testing.T) { expr sql.Expression row sql.Row expected interface{} + type1 sql.Type + type2 sql.Type }{ - {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "a"}, - {eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{"a", "b"}, "b"}, - {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{1, 2}, 1}, - {eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{1, 2}, 2}, - {eq(lit(nil, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "b"}, - {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{nil, "b"}, nil}, + {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "a", types.Text, types.Text}, + {eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{"a", "b"}, "b", types.Text, types.Text}, + {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{1, 2}, int64(1), types.Int64, types.Int64}, + {eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{1, 2}, int64(2), types.Int64, types.Int64}, + {eq(lit(nil, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "b", types.Text, types.Text}, + {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{nil, "b"}, nil, nil, types.Text}, } for _, tc := range testCases { f := NewIf( tc.expr, - expression.NewGetField(0, types.LongText, "true", true), - expression.NewGetField(1, types.LongText, "false", true), + expression.NewGetField(0, tc.type1, "true", true), + expression.NewGetField(1, tc.type2, "false", true), ) v, err := f.Eval(sql.NewEmptyContext(), tc.row) diff --git a/sql/expression/function/ifnull.go b/sql/expression/function/ifnull.go index 9f5e4f8709..9e80a16337 100644 --- a/sql/expression/function/ifnull.go +++ b/sql/expression/function/ifnull.go @@ -52,30 +52,32 @@ func (f *IfNull) Description() string { // Eval implements the Expression interface. 
func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + t := f.Type() + left, err := f.LeftChild.Eval(ctx, row) if err != nil { return nil, err } if left != nil { - return left, nil + if ret, _, err := t.Convert(ctx, left); err == nil { + return ret, nil + } + return left, err } right, err := f.RightChild.Eval(ctx, row) if err != nil { return nil, err } - return right, nil + if ret, _, err := t.Convert(ctx, right); err == nil { + return ret, nil + } + return right, err } // Type implements the Expression interface. func (f *IfNull) Type() sql.Type { - if types.IsNull(f.LeftChild) { - if types.IsNull(f.RightChild) { - return types.Null - } - return f.RightChild.Type() - } - return f.LeftChild.Type() + return types.GeneralizeTypes(f.LeftChild.Type(), f.RightChild.Type()) } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/function/ifnull_test.go b/sql/expression/function/ifnull_test.go index ed6acc3336..507b8e7bdd 100644 --- a/sql/expression/function/ifnull_test.go +++ b/sql/expression/function/ifnull_test.go @@ -26,25 +26,28 @@ import ( func TestIfNull(t *testing.T) { testCases := []struct { - expression interface{} - value interface{} - expected interface{} + expression interface{} + expressionType sql.Type + value interface{} + valueType sql.Type + expected interface{} + expectedType sql.Type }{ - {"foo", "bar", "foo"}, - {"foo", "foo", "foo"}, - {nil, "foo", "foo"}, - {"foo", nil, "foo"}, - {nil, nil, nil}, - {"", nil, ""}, + {"foo", types.LongText, "bar", types.LongText, "foo", types.LongText}, + {"foo", types.LongText, "foo", types.LongText, "foo", types.LongText}, + {nil, types.LongText, "foo", types.LongText, "foo", types.LongText}, + {"foo", types.LongText, nil, types.LongText, "foo", types.LongText}, + {nil, types.LongText, nil, types.LongText, nil, types.LongText}, + {"", types.LongText, nil, types.LongText, "", types.LongText}, + {nil, types.Int8, 128, types.Int64, int64(128), 
types.Int64}, } - f := NewIfNull( - expression.NewGetField(0, types.LongText, "expression", true), - expression.NewGetField(1, types.LongText, "value", true), - ) - require.Equal(t, types.LongText, f.Type()) - for _, tc := range testCases { + f := NewIfNull( + expression.NewGetField(0, tc.expressionType, "expression", true), + expression.NewGetField(1, tc.valueType, "value", true), + ) + require.Equal(t, tc.expectedType, f.Type()) v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value)) require.NoError(t, err) require.Equal(t, tc.expected, v) diff --git a/sql/expression/function/insert.go b/sql/expression/function/insert.go new file mode 100644 index 0000000000..55029521bc --- /dev/null +++ b/sql/expression/function/insert.go @@ -0,0 +1,179 @@ +// Copyright 2020-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package function + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// Insert implements the SQL function INSERT() which inserts a substring at a specified position +type Insert struct { + str sql.Expression + pos sql.Expression + length sql.Expression + newStr sql.Expression +} + +var _ sql.FunctionExpression = (*Insert)(nil) +var _ sql.CollationCoercible = (*Insert)(nil) + +// NewInsert creates a new Insert expression +func NewInsert(str, pos, length, newStr sql.Expression) sql.Expression { + return &Insert{str, pos, length, newStr} +} + +// FunctionName implements sql.FunctionExpression +func (i *Insert) FunctionName() string { + return "insert" +} + +// Description implements sql.FunctionExpression +func (i *Insert) Description() string { + return "returns the string str, with the substring beginning at position pos and len characters long replaced by the string newstr." +} + +// Children implements the Expression interface +func (i *Insert) Children() []sql.Expression { + return []sql.Expression{i.str, i.pos, i.length, i.newStr} +} + +// Resolved implements the Expression interface +func (i *Insert) Resolved() bool { + return i.str.Resolved() && i.pos.Resolved() && i.length.Resolved() && i.newStr.Resolved() +} + +// IsNullable implements the Expression interface +func (i *Insert) IsNullable() bool { + return i.str.IsNullable() || i.pos.IsNullable() || i.length.IsNullable() || i.newStr.IsNullable() +} + +// Type implements the Expression interface +func (i *Insert) Type() sql.Type { + return types.LongText +} + +// CollationCoercibility implements the interface sql.CollationCoercible +func (i *Insert) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + collation, coercibility = sql.GetCoercibility(ctx, i.str) + otherCollation, otherCoercibility := sql.GetCoercibility(ctx, i.newStr) + return sql.ResolveCoercibility(collation, coercibility, 
otherCollation, otherCoercibility) +} + +// String implements the Expression interface +func (i *Insert) String() string { + return fmt.Sprintf("insert(%s, %s, %s, %s)", i.str, i.pos, i.length, i.newStr) +} + +// WithChildren implements the Expression interface +func (i *Insert) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 4 { + return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 4) + } + return NewInsert(children[0], children[1], children[2], children[3]), nil +} + +// Eval implements the Expression interface +func (i *Insert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := i.str.Eval(ctx, row) + if err != nil { + return nil, err + } + if str == nil { + return nil, nil + } + + pos, err := i.pos.Eval(ctx, row) + if err != nil { + return nil, err + } + if pos == nil { + return nil, nil + } + + length, err := i.length.Eval(ctx, row) + if err != nil { + return nil, err + } + if length == nil { + return nil, nil + } + + newStr, err := i.newStr.Eval(ctx, row) + if err != nil { + return nil, err + } + if newStr == nil { + return nil, nil + } + + // Convert all arguments to their expected types + strVal, _, err := types.LongText.Convert(ctx, str) + if err != nil { + return nil, err + } + + posVal, _, err := types.Int64.Convert(ctx, pos) + if err != nil { + return nil, err + } + + lengthVal, _, err := types.Int64.Convert(ctx, length) + if err != nil { + return nil, err + } + + newStrVal, _, err := types.LongText.Convert(ctx, newStr) + if err != nil { + return nil, err + } + + s := strVal.(string) + p := posVal.(int64) + l := lengthVal.(int64) + n := newStrVal.(string) + + // MySQL uses 1-based indexing for position + // Handle negative position - return original string + if p < 1 { + return s, nil + } + + // Convert to 0-based indexing + startIdx := p - 1 + + // Handle case where position is beyond string length + if startIdx >= int64(len(s)) { + return s, nil + } + + // Calculate end index + 
// For negative length, replace from position to end of string + var endIdx int64 + if l < 0 { + endIdx = int64(len(s)) + } else { + endIdx = startIdx + l + if endIdx > int64(len(s)) { + endIdx = int64(len(s)) + } + } + + // Build the result string + result := s[:startIdx] + n + s[endIdx:] + return result, nil +} diff --git a/sql/expression/function/insert_test.go b/sql/expression/function/insert_test.go new file mode 100644 index 0000000000..8db924ef32 --- /dev/null +++ b/sql/expression/function/insert_test.go @@ -0,0 +1,78 @@ +// Copyright 2020-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package function + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +func TestInsert(t *testing.T) { + f := NewInsert( + expression.NewGetField(0, types.LongText, "", false), + expression.NewGetField(1, types.Int64, "", false), + expression.NewGetField(2, types.Int64, "", false), + expression.NewGetField(3, types.LongText, "", false), + ) + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null str", sql.NewRow(nil, 1, 2, "new"), nil, false}, + {"null pos", sql.NewRow("hello", nil, 2, "new"), nil, false}, + {"null length", sql.NewRow("hello", 1, nil, "new"), nil, false}, + {"null newStr", sql.NewRow("hello", 1, 2, nil), nil, false}, + {"empty string", sql.NewRow("", 1, 2, "new"), "", false}, + {"position is 0", sql.NewRow("hello", 0, 2, "new"), "hello", false}, + {"position is negative", sql.NewRow("hello", -1, 2, "new"), "hello", false}, + {"negative length", sql.NewRow("hello", 1, -1, "new"), "new", false}, + {"position beyond string length", sql.NewRow("hello", 10, 2, "new"), "hello", false}, + {"normal insertion", sql.NewRow("hello", 2, 2, "xyz"), "hxyzlo", false}, + {"insert at beginning", sql.NewRow("hello", 1, 2, "xyz"), "xyzllo", false}, + {"insert at end", sql.NewRow("hello", 5, 1, "xyz"), "hellxyz", false}, + {"replace entire string", sql.NewRow("hello", 1, 5, "world"), "world", false}, + {"length exceeds string", sql.NewRow("hello", 3, 10, "world"), "heworld", false}, + {"empty replacement", sql.NewRow("hello", 2, 2, ""), "hlo", false}, + {"zero length", sql.NewRow("hello", 3, 0, "xyz"), "hexyzllo", false}, + {"negative length from middle", sql.NewRow("hello", 3, -1, "xyz"), "hexyz", false}, + {"negative length from beginning", sql.NewRow("hello", 1, -5, "xyz"), "xyz", false}, + {"large positive length", sql.NewRow("hello", 2, 100, "xyz"), 
"hxyz", false}, + {"length exactly matches remaining", sql.NewRow("hello", 3, 3, "xyz"), "hexyz", false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} diff --git a/sql/expression/function/make_set.go b/sql/expression/function/make_set.go new file mode 100644 index 0000000000..8471706a46 --- /dev/null +++ b/sql/expression/function/make_set.go @@ -0,0 +1,152 @@ +// Copyright 2020-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package function + +import ( + "fmt" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// MakeSet implements the SQL function MAKE_SET() which returns a comma-separated set of strings +// where the corresponding bit in bits is set +type MakeSet struct { + bits sql.Expression + values []sql.Expression +} + +var _ sql.FunctionExpression = (*MakeSet)(nil) +var _ sql.CollationCoercible = (*MakeSet)(nil) + +// NewMakeSet creates a new MakeSet expression +func NewMakeSet(args ...sql.Expression) (sql.Expression, error) { + if len(args) < 2 { + return nil, sql.ErrInvalidArgumentNumber.New("MAKE_SET", "2 or more", len(args)) + } + + return &MakeSet{ + bits: args[0], + values: args[1:], + }, nil +} + +// FunctionName implements sql.FunctionExpression +func (m *MakeSet) FunctionName() string { + return "make_set" +} + +// Description implements sql.FunctionExpression +func (m *MakeSet) Description() string { + return "returns a set string (a string containing substrings separated by , characters) consisting of the strings that have the corresponding bit in bits set." +} + +// Children implements the Expression interface +func (m *MakeSet) Children() []sql.Expression { + children := []sql.Expression{m.bits} + children = append(children, m.values...) 
+ return children +} + +// Resolved implements the Expression interface +func (m *MakeSet) Resolved() bool { + for _, child := range m.Children() { + if !child.Resolved() { + return false + } + } + return true +} + +// IsNullable implements the Expression interface +func (m *MakeSet) IsNullable() bool { + return m.bits.IsNullable() +} + +// Type implements the Expression interface +func (m *MakeSet) Type() sql.Type { + return types.LongText +} + +// CollationCoercibility implements the interface sql.CollationCoercible +func (m *MakeSet) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + // Start with highest coercibility (most coercible) + collation = sql.Collation_Default + coercibility = 5 + + for _, value := range m.values { + valueCollation, valueCoercibility := sql.GetCoercibility(ctx, value) + collation, coercibility = sql.ResolveCoercibility(collation, coercibility, valueCollation, valueCoercibility) + } + + return collation, coercibility +} + +// String implements the Expression interface +func (m *MakeSet) String() string { + children := m.Children() + childStrs := make([]string, len(children)) + for i, child := range children { + childStrs[i] = child.String() + } + return fmt.Sprintf("make_set(%s)", strings.Join(childStrs, ", ")) +} + +// WithChildren implements the Expression interface +func (m *MakeSet) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewMakeSet(children...) 
+} + +// Eval implements the Expression interface +func (m *MakeSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + bitsVal, err := m.bits.Eval(ctx, row) + if err != nil { + return nil, err + } + if bitsVal == nil { + return nil, nil + } + + // Convert bits to uint64 + bitsInt, _, err := types.Uint64.Convert(ctx, bitsVal) + if err != nil { + return nil, err + } + bits := bitsInt.(uint64) + + var result []string + + // Check each value argument against the corresponding bit + for i, valueExpr := range m.values { + // Check if bit i is set + if (bits & (1 << uint(i))) != 0 { + val, err := valueExpr.Eval(ctx, row) + if err != nil { + return nil, err + } + // Skip NULL values + if val != nil { + valStr, _, err := types.LongText.Convert(ctx, val) + if err != nil { + return nil, err + } + result = append(result, valStr.(string)) + } + } + } + + return strings.Join(result, ","), nil +} diff --git a/sql/expression/function/make_set_test.go b/sql/expression/function/make_set_test.go new file mode 100644 index 0000000000..de8b742cf9 --- /dev/null +++ b/sql/expression/function/make_set_test.go @@ -0,0 +1,148 @@ +// Copyright 2020-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package function + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +func TestMakeSet(t *testing.T) { + testCases := []struct { + name string + args []interface{} + expected interface{} + err bool + }{ + // MySQL documentation examples + {"mysql example 1", []interface{}{1, "a", "b", "c"}, "a", false}, + {"mysql example 2", []interface{}{1 | 4, "hello", "nice", "world"}, "hello,world", false}, + {"mysql example 3", []interface{}{1 | 4, "hello", "nice", nil, "world"}, "hello", false}, + {"mysql example 4", []interface{}{0, "a", "b", "c"}, "", false}, + + // Basic functionality tests + {"single bit set - bit 0", []interface{}{1, "first", "second", "third"}, "first", false}, + {"single bit set - bit 1", []interface{}{2, "first", "second", "third"}, "second", false}, + {"single bit set - bit 2", []interface{}{4, "first", "second", "third"}, "third", false}, + {"no bits set", []interface{}{0, "first", "second", "third"}, "", false}, + + // Multiple bits set + {"bits 0 and 1", []interface{}{3, "a", "b", "c"}, "a,b", false}, + {"bits 0 and 2", []interface{}{5, "a", "b", "c"}, "a,c", false}, + {"bits 1 and 2", []interface{}{6, "a", "b", "c"}, "b,c", false}, + {"all bits set", []interface{}{7, "a", "b", "c"}, "a,b,c", false}, + + // Large bit numbers + {"bit 10 set", []interface{}{1024, "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"}, "k", false}, + {"bits 0 and 10", []interface{}{1025, "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"}, "a,k", false}, + + // NULL handling + {"null bits", []interface{}{nil, "a", "b", "c"}, nil, false}, + {"null in middle", []interface{}{7, "a", nil, "c"}, "a,c", false}, + {"null at start", []interface{}{7, nil, "b", "c"}, "b,c", false}, + {"null at end", []interface{}{7, "a", "b", nil}, "a,b", false}, + {"all nulls", []interface{}{7, nil, nil, nil}, "", false}, 
+ + // Type conversion + {"string bits", []interface{}{"5", "a", "b", "c"}, "a,c", false}, + {"float bits", []interface{}{5.7, "a", "b", "c"}, "b,c", false}, // 5.7 converts to 6 (binary 110) + {"negative bits", []interface{}{-1, "a", "b", "c"}, "a,b,c", false}, + + // Different value types + {"numeric strings", []interface{}{3, "1", "2", "3"}, "1,2", false}, + {"mixed types", []interface{}{3, 123, "hello", 456}, "123,hello", false}, + + // Edge cases + {"no strings provided", []interface{}{1}, "", true}, + {"bit beyond available strings", []interface{}{16, "a", "b", "c"}, "", false}, + {"bit partially beyond strings", []interface{}{9, "a", "b", "c"}, "a", false}, + + // Large numbers + {"max uint64 bits", []interface{}{^uint64(0), "a", "b", "c"}, "a,b,c", false}, + {"large positive number", []interface{}{4294967295, "a", "b", "c"}, "a,b,c", false}, + + // Empty strings + {"empty string values", []interface{}{3, "", "test", ""}, ",test", false}, + {"only empty strings", []interface{}{3, "", ""}, ",", false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + // Convert test args to expressions + args := make([]sql.Expression, len(tt.args)) + for i, arg := range tt.args { + if arg == nil { + args[i] = expression.NewLiteral(nil, types.Null) + } else { + switch v := arg.(type) { + case int: + args[i] = expression.NewLiteral(int64(v), types.Int64) + case uint64: + args[i] = expression.NewLiteral(v, types.Uint64) + case float64: + args[i] = expression.NewLiteral(v, types.Float64) + case string: + args[i] = expression.NewLiteral(v, types.LongText) + default: + args[i] = expression.NewLiteral(v, types.LongText) + } + } + } + + f, err := NewMakeSet(args...) 
+ if tt.err { + require.Error(err) + return + } + require.NoError(err) + + v, err := f.Eval(ctx, nil) + require.NoError(err) + require.Equal(tt.expected, v) + }) + } +} + +func TestMakeSetArguments(t *testing.T) { + require := require.New(t) + + // Test invalid number of arguments + _, err := NewMakeSet() + require.Error(err) + + _, err = NewMakeSet(expression.NewLiteral(1, types.Int64)) + require.Error(err) + + // Test valid argument counts + validArgs := [][]sql.Expression{ + {expression.NewLiteral(1, types.Int64), expression.NewLiteral("a", types.Text)}, + {expression.NewLiteral(1, types.Int64), expression.NewLiteral("a", types.Text), expression.NewLiteral("b", types.Text)}, + {expression.NewLiteral(1, types.Int64), expression.NewLiteral("a", types.Text), expression.NewLiteral("b", types.Text), expression.NewLiteral("c", types.Text)}, + } + + for _, args := range validArgs { + _, err := NewMakeSet(args...) + require.NoError(err) + } +} diff --git a/sql/expression/function/oct.go b/sql/expression/function/oct.go new file mode 100644 index 0000000000..f287de6281 --- /dev/null +++ b/sql/expression/function/oct.go @@ -0,0 +1,91 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package function + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// Oct function provides a string representation for the octal value of N, where N is a decimal (base 10) number. +type Oct struct { + n sql.Expression +} + +var _ sql.FunctionExpression = (*Oct)(nil) +var _ sql.CollationCoercible = (*Oct)(nil) + +// NewOct returns a new Oct expression. +func NewOct(n sql.Expression) sql.Expression { return &Oct{n} } + +// FunctionName implements sql.FunctionExpression. +func (o *Oct) FunctionName() string { + return "oct" +} + +// Description implements sql.FunctionExpression. +func (o *Oct) Description() string { + return "returns a string representation for octal value of N, where N is a decimal (base 10) number." +} + +// Type implements the Expression interface. +func (o *Oct) Type() sql.Type { + return types.LongText +} + +// IsNullable implements the Expression interface. +func (o *Oct) IsNullable() bool { + return o.n.IsNullable() +} + +// Eval implements the Expression interface. +func (o *Oct) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + // Convert a decimal (base 10) number to octal (base 8) + return NewConv( + o.n, + expression.NewLiteral(10, types.Int64), + expression.NewLiteral(8, types.Int64), + ).Eval(ctx, row) +} + +// Resolved implements the Expression interface. +func (o *Oct) Resolved() bool { + return o.n.Resolved() +} + +// Children implements the Expression interface. +func (o *Oct) Children() []sql.Expression { + return []sql.Expression{o.n} +} + +// WithChildren implements the Expression interface. 
+func (o *Oct) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1) + } + return NewOct(children[0]), nil +} + +func (o *Oct) String() string { + return fmt.Sprintf("%s(%s)", o.FunctionName(), o.n) +} + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (*Oct) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return ctx.GetCollation(), 4 // strings with collations +} diff --git a/sql/expression/function/oct_test.go b/sql/expression/function/oct_test.go new file mode 100644 index 0000000000..7cd978405e --- /dev/null +++ b/sql/expression/function/oct_test.go @@ -0,0 +1,80 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package function + +import ( + "math" + "testing" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +type test struct { + name string + nType sql.Type + row sql.Row + expected interface{} +} + +func TestOct(t *testing.T) { + tests := []test{ + // NULL input + {"n is nil", types.Int32, sql.NewRow(nil), nil}, + + // Positive numbers + {"positive small", types.Int32, sql.NewRow(8), "10"}, + {"positive medium", types.Int32, sql.NewRow(64), "100"}, + {"positive large", types.Int32, sql.NewRow(4095), "7777"}, + {"positive huge", types.Int64, sql.NewRow(123456789), "726746425"}, + + // Negative numbers + {"negative small", types.Int32, sql.NewRow(-8), "1777777777777777777770"}, + {"negative medium", types.Int32, sql.NewRow(-64), "1777777777777777777700"}, + {"negative large", types.Int32, sql.NewRow(-4095), "1777777777777777770001"}, + + // Zero + {"zero", types.Int32, sql.NewRow(0), "0"}, + + // String inputs + {"string number", types.LongText, sql.NewRow("15"), "17"}, + {"alpha string", types.LongText, sql.NewRow("abc"), "0"}, + {"mixed string", types.LongText, sql.NewRow("123abc"), "173"}, + + // Edge cases + {"max int32", types.Int32, sql.NewRow(math.MaxInt32), "17777777777"}, + {"min int32", types.Int32, sql.NewRow(math.MinInt32), "1777777777760000000000"}, + {"max int64", types.Int64, sql.NewRow(math.MaxInt64), "777777777777777777777"}, + {"min int64", types.Int64, sql.NewRow(math.MinInt64), "1000000000000000000000"}, + + // Decimal numbers + {"decimal", types.Float64, sql.NewRow(15.5), "17"}, + {"negative decimal", types.Float64, sql.NewRow(-15.5), "1777777777777777777761"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := NewOct(expression.NewGetField(0, tt.nType, "n", true)) + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if err != nil { + t.Fatal(err) + } + if result != tt.expected { + t.Errorf("got %v; expected %v", 
result, tt.expected) + } + }) + } +} diff --git a/sql/expression/function/regexp_instr.go b/sql/expression/function/regexp_instr.go index 3abee0f1d2..ba36757a8c 100644 --- a/sql/expression/function/regexp_instr.go +++ b/sql/expression/function/regexp_instr.go @@ -167,8 +167,8 @@ func (r *RegexpInstr) String() string { // compile handles compilation of the regex. func (r *RegexpInstr) compile(ctx *sql.Context, row sql.Row) { r.compileOnce.Do(func() { - r.cacheRegex = canBeCached(r.Text, r.Pattern, r.Flags) - r.cacheVal = canBeCached(r.Text, r.Pattern, r.Position, r.Occurrence, r.ReturnOption, r.Flags) + r.cacheRegex = canBeCached(r.Pattern, r.Flags) + r.cacheVal = r.cacheRegex && canBeCached(r.Text, r.Position, r.Occurrence, r.ReturnOption) if r.cacheRegex { r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), row) } diff --git a/sql/expression/function/regexp_instr_test.go b/sql/expression/function/regexp_instr_test.go new file mode 100644 index 0000000000..287837b6c2 --- /dev/null +++ b/sql/expression/function/regexp_instr_test.go @@ -0,0 +1,54 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package function + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// Last Run: 06/17/2025 +// BenchmarkRegexpInStr +// BenchmarkRegexpInStr-14 100 97313270 ns/op +// BenchmarkRegexpInStr-14 10000 1001064 ns/op +func BenchmarkRegexpInStr(b *testing.B) { + ctx := sql.NewEmptyContext() + data := make([]sql.Row, 100) + for i := range data { + data[i] = sql.Row{fmt.Sprintf("test%d", i)} + } + + for i := 0; i < b.N; i++ { + f, err := NewRegexpInstr( + expression.NewGetField(0, types.LongText, "text", false), + expression.NewLiteral("^test[0-9]$", types.LongText), + ) + require.NoError(b, err) + var total int + for _, row := range data { + res, err := f.Eval(ctx, row) + require.NoError(b, err) + total += int(res.(int32)) + } + require.Equal(b, 10, total) + f.(*RegexpInstr).Dispose() + } +} diff --git a/sql/expression/function/regexp_like.go b/sql/expression/function/regexp_like.go index 43a83eeeb9..ff6feefcc8 100644 --- a/sql/expression/function/regexp_like.go +++ b/sql/expression/function/regexp_like.go @@ -34,8 +34,9 @@ type RegexpLike struct { Pattern sql.Expression Flags sql.Expression + cacheVal bool cachedVal any - cacheable bool + cacheRegex bool re regex.Regex compileOnce sync.Once compileErr error @@ -136,12 +137,13 @@ func (r *RegexpLike) String() string { // compile handles compilation of the regex. 
func (r *RegexpLike) compile(ctx *sql.Context, row sql.Row) { r.compileOnce.Do(func() { - r.cacheable = canBeCached(r.Text, r.Pattern, r.Flags) - if r.cacheable { + r.cacheRegex = canBeCached(r.Pattern, r.Flags) + r.cacheVal = r.cacheRegex && canBeCached(r.Text) + if r.cacheRegex { r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), row) } }) - if !r.cacheable { + if !r.cacheRegex { if r.re != nil { if r.compileErr = r.re.Close(); r.compileErr != nil { return @@ -199,7 +201,7 @@ func (r *RegexpLike) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { outVal = int8(0) } - if r.cacheable { + if r.cacheVal { r.cachedVal = outVal } return outVal, nil diff --git a/sql/expression/function/regexp_like_test.go b/sql/expression/function/regexp_like_test.go index 2a23ee641e..c8b66a1f61 100644 --- a/sql/expression/function/regexp_like_test.go +++ b/sql/expression/function/regexp_like_test.go @@ -363,3 +363,31 @@ func TestRegexpLikeNilAndErrors(t *testing.T) { require.Equal(t, nil, res) f.(*RegexpLike).Dispose() } + +// Last Run: 06/17/2025 +// BenchmarkRegexpLike +// BenchmarkRegexpLike-14 100 98269522 ns/op +// BenchmarkRegexpLike-14 10000 958159 ns/op +func BenchmarkRegexpLike(b *testing.B) { + ctx := sql.NewEmptyContext() + data := make([]sql.Row, 100) + for i := range data { + data[i] = sql.Row{fmt.Sprintf("test%d", i)} + } + + for i := 0; i < b.N; i++ { + f, err := NewRegexpLike( + expression.NewGetField(0, types.LongText, "text", false), + expression.NewLiteral("^test[0-9]$", types.LongText), + ) + require.NoError(b, err) + var total int8 + for _, row := range data { + res, err := f.Eval(ctx, row) + require.NoError(b, err) + total += res.(int8) + } + require.Equal(b, int8(10), total) + f.(*RegexpLike).Dispose() + } +} diff --git a/sql/expression/function/regexp_replace.go b/sql/expression/function/regexp_replace.go index 9a639e7bc4..266cea9290 100644 --- a/sql/expression/function/regexp_replace.go +++ 
b/sql/expression/function/regexp_replace.go @@ -17,29 +17,79 @@ package function import ( "fmt" "strings" + "sync" + regex "github.com/dolthub/go-icu-regex" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" ) // RegexpReplace implements the REGEXP_REPLACE function. // https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-replace type RegexpReplace struct { - args []sql.Expression + Text sql.Expression + Pattern sql.Expression + RText sql.Expression + Position sql.Expression + Occurrence sql.Expression + Flags sql.Expression + + cacheVal bool + cachedVal any + cacheRegex bool + re regex.Regex + compileOnce sync.Once + compileErr error } var _ sql.FunctionExpression = (*RegexpReplace)(nil) var _ sql.CollationCoercible = (*RegexpReplace)(nil) +var _ sql.Disposable = (*RegexpReplace)(nil) // NewRegexpReplace creates a new RegexpReplace expression. func NewRegexpReplace(args ...sql.Expression) (sql.Expression, error) { - if len(args) < 3 || len(args) > 6 { + var r *RegexpReplace + switch len(args) { + case 6: + r = &RegexpReplace{ + Text: args[0], + Pattern: args[1], + RText: args[2], + Position: args[3], + Occurrence: args[4], + Flags: args[5], + } + case 5: + r = &RegexpReplace{ + Text: args[0], + Pattern: args[1], + RText: args[2], + Position: args[3], + Occurrence: args[4], + } + case 4: + r = &RegexpReplace{ + Text: args[0], + Pattern: args[1], + RText: args[2], + Position: args[3], + Occurrence: expression.NewLiteral(0, types.Int32), + } + case 3: + r = &RegexpReplace{ + Text: args[0], + Pattern: args[1], + RText: args[2], + Position: expression.NewLiteral(1, types.Int32), + Occurrence: expression.NewLiteral(0, types.Int32), + } + default: return nil, sql.ErrInvalidArgumentNumber.New("regexp_replace", "3,4,5 or 6", len(args)) } - - return &RegexpReplace{args: args}, nil + return r, nil } // FunctionName implements 
sql.FunctionExpression @@ -57,14 +107,11 @@ func (r *RegexpReplace) Type() sql.Type { return types.LongText } // CollationCoercibility implements the interface sql.CollationCoercible. func (r *RegexpReplace) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - if len(r.args) == 0 { - return sql.Collation_binary, 6 - } - collation, coercibility = sql.GetCoercibility(ctx, r.args[0]) - for i := 1; i < len(r.args) && i < 3; i++ { - nextCollation, nextCoercibility := sql.GetCoercibility(ctx, r.args[i]) - collation, coercibility = sql.ResolveCoercibility(collation, coercibility, nextCollation, nextCoercibility) - } + collation, coercibility = sql.GetCoercibility(ctx, r.Text) + nextCollation, nextCoercibility := sql.GetCoercibility(ctx, r.Pattern) + collation, coercibility = sql.ResolveCoercibility(collation, coercibility, nextCollation, nextCoercibility) + nextCollation, nextCoercibility = sql.GetCoercibility(ctx, r.RText) + collation, coercibility = sql.ResolveCoercibility(collation, coercibility, nextCollation, nextCoercibility) return collation, coercibility } @@ -73,152 +120,163 @@ func (r *RegexpReplace) IsNullable() bool { return true } // Children implements the sql.Expression interface. func (r *RegexpReplace) Children() []sql.Expression { - return r.args + var children = []sql.Expression{r.Text, r.Pattern, r.RText, r.Position, r.Occurrence} + if r.Flags != nil { + children = append(children, r.Flags) + } + return children } // Resolved implements the sql.Expression interface. func (r *RegexpReplace) Resolved() bool { - for _, arg := range r.args { - if !arg.Resolved() { - return false - } - } - return true + return r.Text.Resolved() && + r.Pattern.Resolved() && + r.RText.Resolved() && + r.Position.Resolved() && + r.Occurrence.Resolved() && + (r.Flags == nil || r.Flags.Resolved()) } // WithChildren implements the sql.Expression interface. 
func (r *RegexpReplace) WithChildren(children ...sql.Expression) (sql.Expression, error) { - if len(children) != len(r.args) { - return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), len(r.args)) + required := 5 + if r.Flags != nil { + required = 6 + } + if len(children) != required { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), required) + } + + // Copy over the regex instance, in case it has already been set to avoid leaking it. + replace, err := NewRegexpReplace(children...) + if err != nil { + if r.re != nil { + if err = r.re.Close(); err != nil { + return nil, err + } + } + return nil, err } - return NewRegexpReplace(children...) + if r.re != nil { + replace.(*RegexpReplace).re = r.re + } + return replace, nil } func (r *RegexpReplace) String() string { var args []string - for _, e := range r.args { + for _, e := range r.Children() { args = append(args, e.String()) } return fmt.Sprintf("%s(%s)", r.FunctionName(), strings.Join(args, ",")) } +func (r *RegexpReplace) compile(ctx *sql.Context, row sql.Row) { + r.compileOnce.Do(func() { + r.cacheRegex = canBeCached(r.Pattern, r.Flags) + r.cacheVal = r.cacheRegex && canBeCached(r.Text, r.RText, r.Position, r.Occurrence) + if r.cacheRegex { + r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), row) + } + }) + if !r.cacheRegex { + if r.re != nil { + if r.compileErr = r.re.Close(); r.compileErr != nil { + return + } + } + r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), row) + } +} + // Eval implements the sql.Expression interface. 
func (r *RegexpReplace) Eval(ctx *sql.Context, row sql.Row) (val interface{}, err error) { - // Evaluate string value - str, err := r.args[0].Eval(ctx, row) + span, ctx := ctx.Span("function.RegexpReplace") + defer span.End() + + if r.cachedVal != nil { + return r.cachedVal, nil + } + + r.compile(ctx, row) + if r.compileErr != nil { + return nil, r.compileErr + } + if r.re == nil { + return nil, nil + } + + text, err := r.Text.Eval(ctx, row) if err != nil { return nil, err } - if str == nil { + if text == nil { return nil, nil } - str, _, err = types.LongText.Convert(ctx, str) + text, _, err = types.LongText.Convert(ctx, text) if err != nil { return nil, err } - // Convert to string - _str := str.(string) - - // Handle flags - var flags sql.Expression = nil - if len(r.args) == 6 { - flags = r.args[5] - } - - // Create regex, should handle null pattern and null flags - re, compileErr := compileRegex(ctx, r.args[1], r.args[0], flags, r.FunctionName(), row) - if compileErr != nil { - return nil, compileErr + rText, err := r.RText.Eval(ctx, row) + if err != nil { + return nil, err } - if re == nil { + if rText == nil { return nil, nil } - defer func() { - if nErr := re.Close(); err == nil { - err = nErr - } - }() - if err = re.SetMatchString(ctx, _str); err != nil { + rText, _, err = types.LongText.Convert(ctx, rText) + if err != nil { return nil, err } - // Evaluate ReplaceStr - replaceStr, err := r.args[2].Eval(ctx, row) + pos, err := r.Position.Eval(ctx, row) if err != nil { return nil, err } - if replaceStr == nil { + if pos == nil { return nil, nil } - replaceStr, _, err = types.LongText.Convert(ctx, replaceStr) + pos, _, err = types.Int32.Convert(ctx, pos) if err != nil { return nil, err } - - // Convert to string - _replaceStr := replaceStr.(string) - - // Do nothing if str is empty - if len(_str) == 0 { - return _str, nil + if pos.(int32) <= 0 { + return nil, sql.ErrInvalidArgumentDetails.New(r.FunctionName(), fmt.Sprintf("%d", pos.(int32))) } - // Default 
position is 1 - _pos := 1 - - // Check if position argument was provided - if len(r.args) >= 4 { - // Evaluate position argument - pos, err := r.args[3].Eval(ctx, row) - if err != nil { - return nil, err - } - if pos == nil { - return nil, nil - } - - // Convert to int32 - pos, _, err = types.Int32.Convert(ctx, pos) - if err != nil { - return nil, err - } - // Convert to int - _pos = int(pos.(int32)) + if len(text.(string)) != 0 && int(pos.(int32)) > len(text.(string)) { + return nil, errors.NewKind("Index out of bounds for regular expression search.").New() } - // Non-positive position throws incorrect parameter - if _pos <= 0 { - return nil, sql.ErrInvalidArgumentDetails.New(r.FunctionName(), fmt.Sprintf("%d", _pos)) + occurrence, err := r.Occurrence.Eval(ctx, row) + if err != nil { + return nil, err } - - // Handle out of bounds - if _pos > len(_str) { - return nil, errors.NewKind("Index out of bounds for regular expression search.").New() + if occurrence == nil { + return nil, nil + } + occurrence, _, err = types.Int32.Convert(ctx, occurrence) + if err != nil { + return nil, err } - // Default occurrence is 0 (replace all occurrences) - _occ := 0 + err = r.re.SetMatchString(ctx, text.(string)) + if err != nil { + return nil, err + } - // Check if Occurrence argument was provided - if len(r.args) >= 5 { - occ, err := r.args[4].Eval(ctx, row) - if err != nil { - return nil, err - } - if occ == nil { - return nil, nil - } + result, err := r.re.Replace(ctx, rText.(string), int(pos.(int32)), int(occurrence.(int32))) + if err != nil { + return nil, err + } - // Convert occurrence to int32 - occ, _, err = types.Int32.Convert(ctx, occ) - if err != nil { - return nil, err - } + return result, nil +} - // Convert to int - _occ = int(occ.(int32)) +// Dispose implements the sql.Disposable interface. 
+func (r *RegexpReplace) Dispose() { + if r.re != nil { + _ = r.re.Close() } - - return re.Replace(ctx, _replaceStr, _pos, _occ) } diff --git a/sql/expression/function/regexp_replace_test.go b/sql/expression/function/regexp_replace_test.go index 88ad7bccfa..d00d413a26 100644 --- a/sql/expression/function/regexp_replace_test.go +++ b/sql/expression/function/regexp_replace_test.go @@ -1,4 +1,4 @@ -// Copyright 2021 Dolthub, Inc. +// Copyright 2021-2025 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ package function import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -376,3 +377,38 @@ func TestRegexpReplaceWithFlags(t *testing.T) { }) } } + +// Last Run: 06/17/2025 +// BenchmarkRegexpReplace +// BenchmarkRegexpReplace-14 100 97385769 ns/op +// BenchmarkRegexpReplace-14 10000 1012373 ns/op +func BenchmarkRegexpReplace(b *testing.B) { + ctx := sql.NewEmptyContext() + // TODO: for some reason large datasets cause this to hang + data := make([]sql.Row, 11) + for i := range data { + data[i] = sql.Row{fmt.Sprintf("test%d", i)} + } + + for i := 0; i < b.N; i++ { + f, err := NewRegexpReplace( + expression.NewGetField(0, types.LongText, "text", false), + expression.NewLiteral("^test[0-9]$", types.LongText), + expression.NewLiteral("abc", types.LongText), + ) + require.NoError(b, err) + var total int + for _, row := range data { + res, err := f.Eval(ctx, row) + if err != nil { + require.NoError(b, err) + } + require.NoError(b, err) + if res.(string)[:3] == "abc" { + total++ + } + } + require.Equal(b, 10, total) + f.(*RegexpReplace).Dispose() + } +} diff --git a/sql/expression/function/regexp_substr.go b/sql/expression/function/regexp_substr.go index 89d6fd2110..b3d2845d10 100644 --- a/sql/expression/function/regexp_substr.go +++ b/sql/expression/function/regexp_substr.go @@ -36,8 +36,8 @@ type RegexpSubstr struct { Flags sql.Expression 
cachedVal any - cacheRegex bool cacheVal bool + cacheRegex bool re regex.Regex compileOnce sync.Once compileErr error @@ -154,8 +154,8 @@ func (r *RegexpSubstr) String() string { // compile handles compilation of the regex. func (r *RegexpSubstr) compile(ctx *sql.Context, row sql.Row) { r.compileOnce.Do(func() { - r.cacheRegex = canBeCached(r.Text, r.Pattern, r.Flags) - r.cacheVal = canBeCached(r.Text, r.Pattern, r.Position, r.Occurrence, r.Flags) + r.cacheRegex = canBeCached(r.Pattern, r.Flags) + r.cacheVal = r.cacheRegex && canBeCached(r.Text, r.Position, r.Occurrence) if r.cacheRegex { r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), row) } diff --git a/sql/expression/function/regexp_substr_test.go b/sql/expression/function/regexp_substr_test.go new file mode 100644 index 0000000000..cabd937259 --- /dev/null +++ b/sql/expression/function/regexp_substr_test.go @@ -0,0 +1,56 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package function + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// Last Run: 06/17/2025 +// BenchmarkRegexpSubstr +// BenchmarkRegexpSubstr-14 100 95661410 ns/op +// BenchmarkRegexpSubstr-14 10000 999559 ns/op +func BenchmarkRegexpSubstr(b *testing.B) { + ctx := sql.NewEmptyContext() + data := make([]sql.Row, 100) + for i := range data { + data[i] = sql.Row{fmt.Sprintf("test%d", i)} + } + + for i := 0; i < b.N; i++ { + f, err := NewRegexpSubstr( + expression.NewGetField(0, types.LongText, "text", false), + expression.NewLiteral("^test[0-9]$", types.LongText), + ) + require.NoError(b, err) + var total int + for _, row := range data { + res, err := f.Eval(ctx, row) + require.NoError(b, err) + if res != nil && res.(string)[:4] == "test" { + total++ + } + } + require.Equal(b, 10, total) + f.(*RegexpSubstr).Dispose() + } +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index a6bccbc828..cbf18ccd60 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -89,6 +89,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "degrees", Fn: NewDegrees}, sql.FunctionN{Name: "elt", Fn: NewElt}, sql.Function1{Name: "exp", Fn: NewExp}, + sql.FunctionN{Name: "export_set", Fn: NewExportSet}, sql.Function2{Name: "extract", Fn: NewExtract}, sql.FunctionN{Name: "field", Fn: NewField}, sql.Function2{Name: "find_in_set", Fn: NewFindInSet}, @@ -111,6 +112,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "inet_ntoa", Fn: NewInetNtoa}, sql.Function1{Name: "inet6_aton", Fn: NewInet6Aton}, sql.Function1{Name: "inet6_ntoa", Fn: NewInet6Ntoa}, + sql.Function4{Name: "insert", Fn: NewInsert}, sql.Function2{Name: "instr", Fn: NewInstr}, sql.Function1{Name: "is_binary", Fn: NewIsBinary}, sql.Function1{Name: "is_ipv4", Fn: NewIsIPv4}, 
@@ -172,6 +174,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "lower", Fn: NewLower}, sql.FunctionN{Name: "lpad", Fn: NewLeftPad}, sql.Function1{Name: "ltrim", Fn: NewLeftTrim}, + sql.FunctionN{Name: "make_set", Fn: NewMakeSet}, sql.Function1{Name: "max", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMax(e) }}, sql.Function1{Name: "md5", Fn: NewMD5}, sql.Function1{Name: "microsecond", Fn: NewMicrosecond}, @@ -184,6 +187,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "ntile", Fn: window.NewNTile}, sql.FunctionN{Name: "now", Fn: NewNow}, sql.Function2{Name: "nullif", Fn: NewNullIf}, + sql.Function1{Name: "oct", Fn: NewOct}, sql.Function1{Name: "octet_length", Fn: NewLength}, sql.Function1{Name: "ord", Fn: NewOrd}, sql.Function0{Name: "pi", Fn: NewPi}, diff --git a/sql/expression/function/substring.go b/sql/expression/function/substring.go index 36189e10c8..19a51a46f0 100644 --- a/sql/expression/function/substring.go +++ b/sql/expression/function/substring.go @@ -349,8 +349,20 @@ func (l Left) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { switch str := str.(type) { case string: text = []rune(str) + case sql.StringWrapper: + s, err := str.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(s) case []byte: text = []rune(string(str)) + case sql.BytesWrapper: + b, err := str.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(string(b)) case nil: return nil, nil default: @@ -583,8 +595,20 @@ func (i Instr) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { switch str := str.(type) { case string: text = []rune(str) + case sql.StringWrapper: + s, err := str.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(s) case []byte: text = []rune(string(str)) + case sql.BytesWrapper: + s, err := str.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(string(s)) case nil: return nil, nil default: @@ -600,8 +624,20 @@ func (i Instr) Eval(ctx *sql.Context, row 
sql.Row) (interface{}, error) { switch substr := substr.(type) { case string: subtext = []rune(substr) + case sql.StringWrapper: + s, err := substr.Unwrap(ctx) + if err != nil { + return nil, err + } + subtext = []rune(s) + case []byte: - subtext = []rune(string(subtext)) + subtext = []rune(string(substr)) + case sql.BytesWrapper: + s, err := substr.Unwrap(ctx) + if err != nil { + return nil, err + } + subtext = []rune(string(s)) + case nil: return nil, nil default: diff --git a/sql/expression/function/time_math.go b/sql/expression/function/time_math.go index 9fd3b702ed..4e25fd847a 100644 --- a/sql/expression/function/time_math.go +++ b/sql/expression/function/time_math.go @@ -232,7 +232,7 @@ func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } var dateVal interface{} - dateVal, _, err = types.DatetimeMaxPrecision.Convert(ctx, date) + dateVal, _, err = types.DatetimeMaxRange.Convert(ctx, date) if err != nil { ctx.Warn(1292, err.Error()) return nil, nil @@ -380,7 +380,7 @@ func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } var dateVal interface{} - dateVal, _, err = types.DatetimeMaxPrecision.Convert(ctx, date) + dateVal, _, err = types.DatetimeMaxRange.Convert(ctx, date) if err != nil { ctx.Warn(1292, err.Error()) return nil, nil diff --git a/sql/expression/interval.go b/sql/expression/interval.go index 7e60d8f163..0cf8e23a80 100644 --- a/sql/expression/interval.go +++ b/sql/expression/interval.go @@ -237,72 +237,69 @@ const ( week = 7 * day ) -func (td TimeDelta) apply(t time.Time, sign int64) time.Time { - y := int64(t.Year()) - mo := int64(t.Month()) - d := t.Day() - h := t.Hour() - min := t.Minute() - s := t.Second() - ns := t.Nanosecond() +// isLeapYear determines if a given year is a leap year +func isLeapYear(year int) bool { + return daysInMonth(year, time.February) == 29 +} - if td.Years != 0 { - y += td.Years * sign - } +// daysInMonth returns the number of days in a given month/year combination +func
daysInMonth(year int, month time.Month) int { + return time.Date(year, month+1, 0, 0, 0, 0, 0, time.UTC).Day() +} - if td.Months != 0 { - m := mo + td.Months*sign - if m < 1 { - mo = 12 + (m % 12) - y += m/12 - 1 - } else if m > 12 { - mo = m % 12 - y += m / 12 +// apply applies the time delta to the given time, using the specified sign +func (td TimeDelta) apply(t time.Time, sign int64) time.Time { + if td.Years != 0 { + targetYear := t.Year() + int(td.Years*sign) + + // special handling for Feb 29 on leap years + if t.Month() == time.February && t.Day() == 29 && !isLeapYear(targetYear) { + // if we're on Feb 29 and target year is not a leap year, + // move to Feb 28 + t = time.Date(targetYear, time.February, 28, + t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), t.Location()) } else { - mo = m - } - - // Due to the operations done before, month may be zero, which means it's - // december. - if mo == 0 { - mo = 12 + t = time.Date(targetYear, t.Month(), t.Day(), + t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), t.Location()) } } - if days := daysInMonth(time.Month(mo), int(y)); days < d { - d = days - } - - date := time.Date(int(y), time.Month(mo), d, h, min, s, ns, t.Location()) + if td.Months != 0 { + totalMonths := int(t.Month()) - 1 + int(td.Months*sign) // convert to 0-based - if td.Days != 0 { - date = date.Add(time.Duration(td.Days) * day * time.Duration(sign)) - } + // calculate target year and month + yearOffset := totalMonths / 12 + if totalMonths < 0 { + yearOffset = (totalMonths - 11) / 12 // handle negative division correctly + } + targetYear := t.Year() + yearOffset + targetMonth := time.Month((totalMonths%12+12)%12 + 1) // ensure positive month - if td.Hours != 0 { - date = date.Add(time.Duration(td.Hours) * time.Hour * time.Duration(sign)) - } + // handle end-of-month edge cases + originalDay := t.Day() + maxDaysInTargetMonth := daysInMonth(targetYear, targetMonth) - if td.Minutes != 0 { - date = date.Add(time.Duration(td.Minutes) * time.Minute * 
time.Duration(sign)) - } + targetDay := originalDay + if originalDay > maxDaysInTargetMonth { + targetDay = maxDaysInTargetMonth + } - if td.Seconds != 0 { - date = date.Add(time.Duration(td.Seconds) * time.Second * time.Duration(sign)) + t = time.Date(targetYear, targetMonth, targetDay, + t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), t.Location()) } - if td.Microseconds != 0 { - date = date.Add(time.Duration(td.Microseconds) * time.Microsecond * time.Duration(sign)) + if td.Days != 0 { + t = t.AddDate(0, 0, int(td.Days*sign)) } - return date -} + duration := time.Duration(td.Hours*sign)*time.Hour + + time.Duration(td.Minutes*sign)*time.Minute + + time.Duration(td.Seconds*sign)*time.Second + + time.Duration(td.Microseconds*sign)*time.Microsecond -func daysInMonth(month time.Month, year int) int { - if month == time.December { - return 31 + if duration != 0 { + t = t.Add(duration) } - date := time.Date(year, month+time.Month(1), 1, 0, 0, 0, 0, time.Local) - return date.Add(-1 * day).Day() + return t } diff --git a/sql/expression/interval_test.go b/sql/expression/interval_test.go index 0a808a5ef0..235757cacb 100644 --- a/sql/expression/interval_test.go +++ b/sql/expression/interval_test.go @@ -51,10 +51,10 @@ func TestTimeDelta(t *testing.T) { date(2005, time.March, 29, 0, 0, 0, 0), }, { - "plus overflowing until december", + "plus overflowing until december", // #7300 mysql produced 2005-12-29 TimeDelta{Months: 22}, leapYear, - date(2006, time.December, 29, 0, 0, 0, 0), + date(2005, time.December, 29, 0, 0, 0, 0), }, { "minus overflowing months", diff --git a/sql/expression/isnull.go b/sql/expression/isnull.go index f0cf53e087..109e915b86 100644 --- a/sql/expression/isnull.go +++ b/sql/expression/isnull.go @@ -26,12 +26,19 @@ type IsNull struct { var _ sql.Expression = (*IsNull)(nil) var _ sql.CollationCoercible = (*IsNull)(nil) +var _ sql.IsNullExpression = (*IsNull)(nil) // NewIsNull creates a new IsNull expression. 
func NewIsNull(child sql.Expression) *IsNull { return &IsNull{UnaryExpression{child}} } +// IsNullExpression implements the sql.IsNullExpression interface. This function exists primarily +// to ensure the IsNullExpression interface has a unique signature. +func (e *IsNull) IsNullExpression() bool { + return true +} + // Type implements the Expression interface. func (e *IsNull) Type() sql.Type { return types.Boolean @@ -53,18 +60,6 @@ func (e *IsNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if err != nil { return nil, err } - - // Slices of typed values (e.g. Record and Composite types in Postgres) evaluate - // to NULL if all of their entries are NULL. - if tupleValue, ok := v.([]types.TupleValue); ok { - for _, typedValue := range tupleValue { - if typedValue.Value != nil { - return false, nil - } - } - return true, nil - } - return v == nil, nil } diff --git a/sql/expression/like.go b/sql/expression/like.go index cbfc56f582..bc17607883 100644 --- a/sql/expression/like.go +++ b/sql/expression/like.go @@ -86,10 +86,12 @@ func (l *Like) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } if _, ok := left.(string); !ok { - left, _, err = types.LongText.Convert(ctx, left) + // Use type-aware conversion for enum types + leftStr, _, err := types.ConvertToCollatedString(ctx, left, l.Left().Type()) if err != nil { return nil, err } + left = leftStr } var lm LikeMatcher @@ -120,9 +122,6 @@ func (l *Like) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if err != nil { return nil, err } - if lm.collation == sql.Collation_Unspecified { - return false, nil - } ok := lm.Match(left.(string)) if l.cached { @@ -141,10 +140,12 @@ func (l *Like) evalRight(ctx *sql.Context, row sql.Row) (right *string, escape r return nil, 0, err } if _, ok := rightVal.(string); !ok { - rightVal, _, err = types.LongText.Convert(ctx, rightVal) + // Use type-aware conversion for enum types + rightStr, _, err := types.ConvertToCollatedString(ctx, 
rightVal, l.Right().Type()) if err != nil { return nil, 0, err } + rightVal = rightStr } var escapeVal interface{} diff --git a/sql/hash/hash.go b/sql/hash/hash.go new file mode 100644 index 0000000000..94bcc64206 --- /dev/null +++ b/sql/hash/hash.go @@ -0,0 +1,99 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hash + +import ( + "fmt" + "sync" + + "github.com/cespare/xxhash/v2" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +var digestPool = sync.Pool{ + New: func() any { + return xxhash.New() + }, +} + +// ExprsToSchema converts a list of sql.Expression to a sql.Schema. +// This is used for functions that use HashOf, but don't already have a schema. +// The generated schema ONLY contains the types of the expressions without any column names or any other info. +func ExprsToSchema(exprs ...sql.Expression) sql.Schema { + var sch sql.Schema + for _, expr := range exprs { + sch = append(sch, &sql.Column{Type: expr.Type()}) + } + return sch +} + +// HashOf returns a hash of the given value to be used as key in a cache. 
+func HashOf(ctx *sql.Context, sch sql.Schema, row sql.Row) (uint64, error) { + hash := digestPool.Get().(*xxhash.Digest) + hash.Reset() + defer digestPool.Put(hash) + for i, v := range row { + if i > 0 { + // separate each value in the row with a nil byte + if _, err := hash.Write([]byte{0}); err != nil { + return 0, err + } + } + + v, err := sql.UnwrapAny(ctx, v) + if err != nil { + return 0, fmt.Errorf("error unwrapping value: %w", err) + } + + // TODO: we may not always have the type information available, so we check schema length. + // Then, defer to original behavior + if i >= len(sch) || v == nil { + _, err := fmt.Fprintf(hash, "%v", v) + if err != nil { + return 0, err + } + continue + } + + switch typ := sch[i].Type.(type) { + case types.ExtendedType: + // TODO: Doltgres follows Postgres conventions which don't align with the expectations of MySQL, + // so we're using the old (probably incorrect) behavior for now + _, err = fmt.Fprintf(hash, "%v", v) + if err != nil { + return 0, err + } + case types.StringType: + var strVal string + strVal, err = types.ConvertToString(ctx, v, typ, nil) + if err != nil { + return 0, err + } + err = typ.Collation().WriteWeightString(hash, strVal) + if err != nil { + return 0, err + } + default: + // TODO: probably much faster to do this with a type switch + _, err = fmt.Fprintf(hash, "%v", v) + if err != nil { + return 0, err + } + } + } + return hash.Sum64(), nil +} diff --git a/sql/hash/hash_test.go b/sql/hash/hash_test.go new file mode 100644 index 0000000000..30cd0a10dc --- /dev/null +++ b/sql/hash/hash_test.go @@ -0,0 +1,53 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hash + +import ( + "testing" + + "github.com/dolthub/go-mysql-server/sql" +) + +func BenchmarkHashOf(b *testing.B) { + ctx := sql.NewEmptyContext() + row := sql.NewRow(1, "1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + sum, err := HashOf(ctx, nil, row) + if err != nil { + b.Fatal(err) + } + if sum != 11268758894040352165 { + b.Fatalf("got %v", sum) + } + } +} + +func BenchmarkParallelHashOf(b *testing.B) { + ctx := sql.NewEmptyContext() + row := sql.NewRow(1, "1") + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + sum, err := HashOf(ctx, nil, row) + if err != nil { + b.Fatal(err) + } + if sum != 11268758894040352165 { + b.Fatalf("got %v", sum) + } + } + }) +} diff --git a/sql/information_schema/constants.go b/sql/information_schema/constants.go index b7183be9bb..57b12ebf3d 100644 --- a/sql/information_schema/constants.go +++ b/sql/information_schema/constants.go @@ -805,7 +805,6 @@ var keywordsArray = [747]Keyword{ {"WAIT", 0}, {"WARNINGS", 0}, {"WEEK", 0}, - {"WEIGHT_STRING", 0}, {"WHEN", 1}, {"WHERE", 1}, {"WHILE", 1}, diff --git a/sql/information_schema/tables_table.go b/sql/information_schema/tables_table.go index 1edd99d057..b3f80dc7b6 100644 --- a/sql/information_schema/tables_table.go +++ b/sql/information_schema/tables_table.go @@ -65,7 +65,7 @@ var tablesSchema = Schema{ func tablesRowIter(ctx *Context, cat Catalog) (RowIter, error) { var rows []Row var ( - tableType string + tableType uint16 tableRows uint64 avgRowLength uint64 dataLength uint64 @@ -82,9 +82,9 @@ func tablesRowIter(ctx *Context, 
cat Catalog) (RowIter, error) { for _, db := range databases { if db.Database.Name() == InformationSchemaDatabaseName { - tableType = "SYSTEM VIEW" + tableType = 3 // SYSTEM_VIEW } else { - tableType = "BASE TABLE" + tableType = 1 // BASE_TABLE engine = "InnoDB" rowFormat = "Dynamic" } diff --git a/sql/iters/rel_iters.go b/sql/iters/rel_iters.go index 6033891160..cf35b53a35 100644 --- a/sql/iters/rel_iters.go +++ b/sql/iters/rel_iters.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/hash" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -571,7 +572,7 @@ func (di *distinctIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } - hash, err := sql.HashOf(ctx, row) + hash, err := hash.HashOf(ctx, nil, row) if err != nil { return nil, err } @@ -643,11 +644,14 @@ func (ii *IntersectIter) Next(ctx *sql.Context) (sql.Row, error) { ii.cache = make(map[uint64]int) for { res, err := ii.RIter.Next(ctx) - if err != nil && err != io.EOF { + if err != nil { + if err == io.EOF { + break + } return nil, err } - hash, herr := sql.HashOf(ctx, res) + hash, herr := hash.HashOf(ctx, nil, res) if herr != nil { return nil, herr } @@ -655,10 +659,6 @@ func (ii *IntersectIter) Next(ctx *sql.Context) (sql.Row, error) { ii.cache[hash] = 0 } ii.cache[hash]++ - - if err == io.EOF { - break - } } ii.cached = true } @@ -669,7 +669,7 @@ func (ii *IntersectIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } - hash, herr := sql.HashOf(ctx, res) + hash, herr := hash.HashOf(ctx, nil, res) if herr != nil { return nil, herr } @@ -714,7 +714,7 @@ func (ei *ExceptIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } - hash, herr := sql.HashOf(ctx, res) + hash, herr := hash.HashOf(ctx, nil, res) if herr != nil { return nil, herr } @@ -736,7 +736,7 @@ func (ei *ExceptIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } - hash, herr := 
sql.HashOf(ctx, res) + hash, herr := hash.HashOf(ctx, nil, res) if herr != nil { return nil, herr } diff --git a/sql/memo/join_order_builder.go b/sql/memo/join_order_builder.go index 9c07e3ce54..2f5b4c8329 100644 --- a/sql/memo/join_order_builder.go +++ b/sql/memo/join_order_builder.go @@ -154,7 +154,7 @@ var ErrUnsupportedReorderNode = errors.New("unsupported join reorder node") // useFastReorder determines whether to skip the current brute force join planning and use an alternate // planning algorithm that analyzes the join tree to find a sequence that can be implemented purely as lookup joins. -// Currently we only use it for large joins (20+ tables) with no join hints. +// Currently, we only use it for large joins (15+ tables) with no join hints. func (j *joinOrderBuilder) useFastReorder() bool { if j.forceFastDFSLookupForTest { return true @@ -180,7 +180,7 @@ func (j *joinOrderBuilder) ReorderJoin(n sql.Node) { // from ensureClosure in buildSingleLookupPlan, but the equivalence sets could create multiple possible join orders // for the single-lookup plan, which would complicate things. 
j.ensureClosure(j.m.root) - j.dbSube() + j.dpEnumerateSubsets() return } @@ -627,10 +627,10 @@ func (j *joinOrderBuilder) checkSize() { } } -// dpSube iterates all disjoint combinations of table sets, +// dpEnumerateSubsets iterates all disjoint combinations of table sets, // adding plans to the tree when we find two sets that can // be joined -func (j *joinOrderBuilder) dbSube() { +func (j *joinOrderBuilder) dpEnumerateSubsets() { all := j.allVertices() for subset := vertexSet(1); subset <= all; subset++ { if subset.isSingleton() { diff --git a/sql/memo/memo.go b/sql/memo/memo.go index d77574a8cd..c51042ab16 100644 --- a/sql/memo/memo.go +++ b/sql/memo/memo.go @@ -82,10 +82,13 @@ func (m *Memo) StatsProvider() sql.StatsProvider { return m.statsProv } -func (m *Memo) SetDefaultHints() { +// SessionHints returns any hints that have been enabled in the session for join planning, +// such as the @@disable_merge_join SQL system variable. +func (m *Memo) SessionHints() (hints []Hint) { if val, _ := m.Ctx.GetSessionVariable(m.Ctx, sql.DisableMergeJoin); val.(int8) != 0 { - m.ApplyHint(Hint{Typ: HintTypeNoMergeJoin}) + hints = append(hints, Hint{Typ: HintTypeNoMergeJoin}) } + return hints } // newExprGroup creates a new logical expression group to encapsulate the @@ -465,11 +468,6 @@ func (m *Memo) optimizeMemoGroup(grp *ExprGroup) error { // rather than a local property. func (m *Memo) updateBest(grp *ExprGroup, n RelExpr, cost float64) { if !m.hints.isEmpty() { - for _, block := range m.hints.block { - if !block.isOk(n) { - return - } - } if m.hints.satisfiedBy(n) { if !grp.HintOk { grp.Best = n @@ -522,31 +520,20 @@ func getProjectColset(p *Project) sql.ColSet { return colset } +// ApplyHint applies |hint| to this memo, converting the parsed hint into an internal representation and updating +// the internal data to match the memo metadata. 
Note that this function MUST be called only after memo groups have +// been fully built out, otherwise the group information set in the internal join hint structures will be incomplete. func (m *Memo) ApplyHint(hint Hint) { switch hint.Typ { case HintTypeJoinOrder: m.SetJoinOrder(hint.Args) case HintTypeJoinFixedOrder: case HintTypeNoMergeJoin: - m.SetBlockOp(func(n RelExpr) bool { - switch n := n.(type) { - case JoinRel: - jp := n.JoinPrivate() - if !jp.Left.Best.Group().HintOk || !jp.Right.Best.Group().HintOk { - // equiv closures can generate child plans that bypass hints - return false - } - if jp.Op.IsMerge() { - return false - } - } - return true - }) + m.hints.disableMergeJoin = true case HintTypeInnerJoin, HintTypeMergeJoin, HintTypeLookupJoin, HintTypeHashJoin, HintTypeSemiJoin, HintTypeAntiJoin, HintTypeLeftOuterLookupJoin: m.SetJoinOp(hint.Typ, hint.Args[0], hint.Args[1]) case HintTypeLeftDeep: m.hints.leftDeep = true - default: } } @@ -568,10 +555,6 @@ func (m *Memo) SetJoinOrder(tables []string) { } } -func (m *Memo) SetBlockOp(cb func(n RelExpr) bool) { - m.hints.block = append(m.hints.block, joinBlockHint{cb: cb}) -} - func (m *Memo) SetJoinOp(op HintType, left, right string) { var lTab, rTab sql.TableId for _, n := range m.root.RelProps.TableIdNodes() { diff --git a/sql/memo/rel_props.go b/sql/memo/rel_props.go index 997384057d..d5697ae6e0 100644 --- a/sql/memo/rel_props.go +++ b/sql/memo/rel_props.go @@ -285,13 +285,18 @@ func (p *relProps) populateFds() { } } case *expression.Not: - child, ok := f.Child.(*expression.IsNull) + child, ok := f.Child.(sql.IsNullExpression) if ok { - col, ok := child.Child.(*expression.GetField) + col, ok := child.Children()[0].(*expression.GetField) if ok { notNull.Add(col.Id()) } } + case sql.IsNotNullExpression: + col, ok := f.Children()[0].(*expression.GetField) + if ok { + notNull.Add(col.Id()) + } } } fds = sql.NewFilterFDs(rel.Child.RelProps.FuncDeps(), notNull, constant, equiv) diff --git 
a/sql/memo/select_hints.go b/sql/memo/select_hints.go index 7c9a515e33..13b462b596 100644 --- a/sql/memo/select_hints.go +++ b/sql/memo/select_hints.go @@ -372,25 +372,18 @@ func (o joinOpHint) typeMatches(n RelExpr) bool { return true } -type joinBlockHint struct { - cb func(n RelExpr) bool -} - -func (o joinBlockHint) isOk(n RelExpr) bool { - return o.cb(n) -} - // joinHints wraps a collection of join hints. The memo // interfaces with this object during costing. type joinHints struct { - ops []joinOpHint - order *joinOrderHint - block []joinBlockHint - leftDeep bool + ops []joinOpHint + order *joinOrderHint + leftDeep bool + disableMergeJoin bool } +// isEmpty returns true if no hints that affect join planning have been set. func (h joinHints) isEmpty() bool { - return len(h.ops) == 0 && h.order == nil && !h.leftDeep && len(h.block) == 0 + return len(h.ops) == 0 && h.order == nil && !h.leftDeep && !h.disableMergeJoin } // satisfiedBy returns whether a RelExpr satisfies every join hint. This diff --git a/sql/parser.go b/sql/parser.go index ea4703f6fa..f23ae02dc1 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -133,6 +133,7 @@ func RemoveSpaceAndDelimiter(query string, d rune) string { }) } +// EscapeSpecialCharactersInComment escapes special characters in a comment string. 
func EscapeSpecialCharactersInComment(comment string) string { commentString := comment commentString = strings.ReplaceAll(commentString, "'", "''") diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index ed7ce5f406..d32b3cf0c1 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -157,7 +157,9 @@ func NewCheckDefinition(ctx *sql.Context, check *sql.CheckConstraint) (*sql.Chec unqualifiedCols, _, err := transform.Expr(check.Expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { gf, ok := e.(*expression.GetField) if ok { - return expression.NewGetField(gf.Index(), gf.Type(), gf.Name(), gf.IsNullable()), transform.NewTree, nil + newGf := expression.NewGetField(gf.Index(), gf.Type(), gf.Name(), gf.IsNullable()) + newGf = newGf.WithQuotedNames(sql.GlobalSchemaFormatter, true) + return newGf, transform.NewTree, nil } return e, transform.SameTree, nil }) @@ -167,7 +169,7 @@ func NewCheckDefinition(ctx *sql.Context, check *sql.CheckConstraint) (*sql.Chec return &sql.CheckDefinition{ Name: check.Name, - CheckExpression: fmt.Sprintf("%s", unqualifiedCols), + CheckExpression: unqualifiedCols.String(), Enforced: check.Enforced, }, nil } diff --git a/sql/plan/alter_foreign_key.go b/sql/plan/alter_foreign_key.go index 94e21638ec..c145f587c9 100644 --- a/sql/plan/alter_foreign_key.go +++ b/sql/plan/alter_foreign_key.go @@ -655,6 +655,10 @@ func foreignKeyComparableTypes(ctx *sql.Context, type1 sql.Type, type2 sql.Type) if type1String.Collation().CharacterSet() != type2String.Collation().CharacterSet() { return false } + case sqltypes.Enum: + // Enum types can reference each other in foreign keys regardless of their string values. + // MySQL allows enum foreign keys to match based on underlying numeric values. 
+ return true default: return false } diff --git a/sql/plan/hash_lookup.go b/sql/plan/hash_lookup.go index 0e6950b25c..7926d29255 100644 --- a/sql/plan/hash_lookup.go +++ b/sql/plan/hash_lookup.go @@ -18,9 +18,9 @@ import ( "fmt" "sync" - "github.com/dolthub/go-mysql-server/sql/types" - "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/hash" + "github.com/dolthub/go-mysql-server/sql/types" ) // NewHashLookup returns a node that performs an indexed hash lookup @@ -33,12 +33,14 @@ import ( // on the projected results. If cached results are not available, it // simply delegates to the child. func NewHashLookup(n sql.Node, rightEntryKey sql.Expression, leftProbeKey sql.Expression, joinType JoinType) *HashLookup { + leftKeySch := hash.ExprsToSchema(leftProbeKey) return &HashLookup{ UnaryNode: UnaryNode{n}, RightEntryKey: rightEntryKey, LeftProbeKey: leftProbeKey, Mutex: new(sync.Mutex), JoinType: joinType, + leftKeySch: leftKeySch, } } @@ -49,6 +51,7 @@ type HashLookup struct { Mutex *sync.Mutex Lookup *map[interface{}][]sql.Row JoinType JoinType + leftKeySch sql.Schema } var _ sql.Node = (*HashLookup)(nil) @@ -70,6 +73,7 @@ func (n *HashLookup) WithExpressions(exprs ...sql.Expression) (sql.Node, error) ret := *n ret.RightEntryKey = exprs[0] ret.LeftProbeKey = exprs[1] + ret.leftKeySch = hash.ExprsToSchema(ret.LeftProbeKey) return &ret, nil } @@ -127,7 +131,7 @@ func (n *HashLookup) GetHashKey(ctx *sql.Context, e sql.Expression, row sql.Row) return nil, err } if s, ok := key.([]interface{}); ok { - return sql.HashOf(ctx, s) + return hash.HashOf(ctx, n.leftKeySch, s) } // byte slices are not hashable if k, ok := key.([]byte); ok { diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 5c7a24da12..9c5dd6272e 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -72,7 +72,7 @@ type InsertInto struct { LiteralValueSource bool // Returning is a list of expressions to return after the insert operation. 
This feature is not supported - // in MySQL's syntax, but is exposed through PostgreSQL's syntax. + // in MySQL's syntax, but is exposed through PostgreSQL's and MariaDB's syntax. Returning []sql.Expression // FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id. diff --git a/sql/plan/insubquery.go b/sql/plan/insubquery.go index 179f05ba0e..7dcc46cc36 100644 --- a/sql/plan/insubquery.go +++ b/sql/plan/insubquery.go @@ -19,6 +19,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/hash" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -47,7 +48,7 @@ func NewInSubquery(left sql.Expression, right sql.Expression) *InSubquery { return &InSubquery{expression.BinaryExpressionStub{LeftChild: left, RightChild: right}} } -var nilKey, _ = sql.HashOf(nil, sql.NewRow(nil)) +var nilKey, _ = hash.HashOf(nil, nil, sql.NewRow(nil)) // Eval implements the Expression interface. 
func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { @@ -75,7 +76,7 @@ func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, sql.ErrInvalidOperandColumns.New(types.NumColumns(typ), types.NumColumns(right.Type())) } - typ := right.Type() + rTyp := right.Type() values, err := right.HashMultiple(ctx, row) if err != nil { @@ -91,12 +92,12 @@ func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } // convert left to right's type - nLeft, _, err := typ.Convert(ctx, left) + nLeft, _, err := rTyp.Convert(ctx, left) if err != nil { return false, nil } - key, err := sql.HashOf(ctx, sql.NewRow(nLeft)) + key, err := hash.HashOf(ctx, sql.Schema{&sql.Column{Type: rTyp}}, sql.NewRow(nLeft)) if err != nil { return nil, err } @@ -109,12 +110,12 @@ func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return false, nil } - val, _, err = typ.Convert(ctx, val) + val, _, err = rTyp.Convert(ctx, val) if err != nil { return false, nil } - cmp, err := typ.Compare(ctx, left, val) + cmp, err := rTyp.Compare(ctx, left, val) if err != nil { return nil, err } diff --git a/sql/plan/join.go b/sql/plan/join.go index 9e689c74e0..f5768df93a 100644 --- a/sql/plan/join.go +++ b/sql/plan/join.go @@ -531,9 +531,12 @@ func NewSemiJoin(left, right sql.Node, cond sql.Expression) *JoinNode { // IsNullRejecting returns whether the expression always returns false for // nil inputs. func IsNullRejecting(e sql.Expression) bool { + // Note that InspectExpr will stop inspecting expressions in the + // expression tree when true is returned, so we invert that return + // value from InspectExpr to return the correct null rejecting value. 
return !transform.InspectExpr(e, func(e sql.Expression) bool { switch e.(type) { - case *expression.NullSafeEquals, *expression.IsNull: + case sql.IsNullExpression, sql.IsNotNullExpression, *expression.NullSafeEquals: return true default: return false diff --git a/sql/plan/project.go b/sql/plan/project.go index 2f4c541fed..9e377794c2 100644 --- a/sql/plan/project.go +++ b/sql/plan/project.go @@ -26,9 +26,15 @@ import ( // Project is a projection of certain expression from the children node. type Project struct { UnaryNode + // Projections are the expressions to be projected on the row returned by the child node Projections []sql.Expression - CanDefer bool - deps sql.ColSet + // CanDefer is true when the projection evaluation can be deferred to row spooling, which allows us to avoid a + // separate iterator for the project node. + CanDefer bool + // IncludesNestedIters is true when the projection includes nested iterators because of expressions that return + // a RowIter. + IncludesNestedIters bool + deps sql.ColSet } var _ sql.Expressioner = (*Project)(nil) @@ -202,8 +208,17 @@ func (p *Project) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { return &np, nil } +// WithCanDefer returns a new Project with the CanDefer field set to the given value. func (p *Project) WithCanDefer(canDefer bool) *Project { np := *p np.CanDefer = canDefer return &np } + +// WithIncludesNestedIters returns a new Project with the IncludesNestedIters field set to the given value. +func (p *Project) WithIncludesNestedIters(includesNestedIters bool) *Project { + np := *p + np.IncludesNestedIters = includesNestedIters + np.CanDefer = false + return &np +} diff --git a/sql/plan/set.go b/sql/plan/set.go index 51e22d06cd..add34c3488 100644 --- a/sql/plan/set.go +++ b/sql/plan/set.go @@ -19,6 +19,7 @@ import ( "strings" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" ) // Set represents a set statement. 
This can be variables, but in some instances can also refer to row values. @@ -77,13 +78,9 @@ func (s *Set) Expressions() []sql.Expression { return s.Exprs } -// setSch is used to differentiate from the nil schema, -// because Set does return rows -var setSch = make(sql.Schema, 0) - // Schema implements the sql.Node interface. func (s *Set) Schema() sql.Schema { - return setSch + return types.OkResultSchema } func (s *Set) String() string { diff --git a/sql/plan/subquery.go b/sql/plan/subquery.go index a612c72ab2..061e82d4dd 100644 --- a/sql/plan/subquery.go +++ b/sql/plan/subquery.go @@ -19,10 +19,10 @@ import ( "io" "sync" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/hash" "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" - - "github.com/dolthub/go-mysql-server/sql" ) // Subquery is as an expression whose value is derived by executing a subquery. It must be executed for every row in @@ -313,7 +313,7 @@ func (m *Max1Row) CollationCoercibility(ctx *sql.Context) (collation sql.Collati } // EvalMultiple returns all rows returned by a subquery. -func (s *Subquery) EvalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, error) { +func (s *Subquery) EvalMultiple(ctx *sql.Context, row sql.Row) ([]any, error) { s.cacheMu.Lock() cached := s.resultsCached s.cacheMu.Unlock() @@ -341,7 +341,7 @@ func (s *Subquery) canCacheResults() bool { return s.correlated.Empty() && !s.volatile } -func (s *Subquery) evalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, error) { +func (s *Subquery) evalMultiple(ctx *sql.Context, row sql.Row) ([]any, error) { // Any source of rows, as well as any node that alters the schema of its children, needs to be wrapped so that its // result rows are prepended with the scope row. 
q, _, err := transform.Node(s.Query, PrependRowInPlan(row, false)) @@ -362,7 +362,7 @@ func (s *Subquery) evalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, e // Reduce the result row to the size of the expected schema. This means chopping off the first len(row) columns. col := len(row) - var result []interface{} + var result []any for { row, err := iter.Next(ctx) if err == io.EOF { @@ -407,7 +407,7 @@ func (s *Subquery) HashMultiple(ctx *sql.Context, row sql.Row) (sql.KeyValueCach defer s.cacheMu.Unlock() if !s.resultsCached || s.hashCache == nil { hashCache, disposeFn := ctx.Memory.NewHistoryCache() - err = putAllRows(ctx, hashCache, result) + err = putAllRows(ctx, hashCache, s.Query.Schema(), result) if err != nil { return nil, err } @@ -417,7 +417,11 @@ func (s *Subquery) HashMultiple(ctx *sql.Context, row sql.Row) (sql.KeyValueCach } cache := sql.NewMapCache() - return cache, putAllRows(ctx, cache, result) + err = putAllRows(ctx, cache, s.Query.Schema(), result) + if err != nil { + return nil, err + } + return cache, nil } // HasResultRow returns whether the subquery has a result set > 0. @@ -466,22 +470,25 @@ func (s *Subquery) HasResultRow(ctx *sql.Context, row sql.Row) (bool, error) { // normalizeValue returns a canonical version of a value for use in a sql.KeyValueCache. // Two values that compare equal should have the same canonical version. 
-// TODO: Fix https://github.com/dolthub/dolt/issues/9049 by making this function collation-aware func normalizeForKeyValueCache(ctx *sql.Context, val interface{}) (interface{}, error) { - return sql.UnwrapAny(ctx, val) + val, err := sql.UnwrapAny(ctx, val) + if err != nil { + return nil, err + } + return val, nil } -func putAllRows(ctx *sql.Context, cache sql.KeyValueCache, vals []interface{}) error { +func putAllRows(ctx *sql.Context, cache sql.KeyValueCache, sch sql.Schema, vals []interface{}) error { for _, val := range vals { - val, err := normalizeForKeyValueCache(ctx, val) + normVal, err := normalizeForKeyValueCache(ctx, val) if err != nil { return err } - rowKey, err := sql.HashOf(ctx, sql.NewRow(val)) + rowKey, err := hash.HashOf(ctx, sch, sql.NewRow(normVal)) if err != nil { return err } - err = cache.Put(rowKey, val) + err = cache.Put(rowKey, normVal) if err != nil { return err } diff --git a/sql/plan/update.go b/sql/plan/update.go index b023e2d68d..2aedd1174c 100644 --- a/sql/plan/update.go +++ b/sql/plan/update.go @@ -31,8 +31,10 @@ var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but ex // Update is a node for updating rows on tables. type Update struct { UnaryNode - checks sql.CheckConstraints - Ignore bool + checks sql.CheckConstraints + Ignore bool + // IsJoin is true only for explicit UPDATE JOIN queries. It's possible for Update.IsJoin to be false and + // Update.Child to be an UpdateJoin since subqueries are optimized as Joins IsJoin bool HasSingleRel bool IsProcNested bool diff --git a/sql/plan/update_join.go b/sql/plan/update_join.go index 814e953a26..da2ec1ca03 100644 --- a/sql/plan/update_join.go +++ b/sql/plan/update_join.go @@ -21,15 +21,15 @@ import ( ) type UpdateJoin struct { - Updaters map[string]sql.RowUpdater + UpdateTargets map[string]sql.Node UnaryNode } -// NewUpdateJoin returns an *UpdateJoin node. 
-func NewUpdateJoin(editorMap map[string]sql.RowUpdater, child sql.Node) *UpdateJoin { +// NewUpdateJoin returns a new *UpdateJoin node. +func NewUpdateJoin(updateTargets map[string]sql.Node, child sql.Node) *UpdateJoin { return &UpdateJoin{ - Updaters: editorMap, - UnaryNode: UnaryNode{Child: child}, + UpdateTargets: updateTargets, + UnaryNode: UnaryNode{Child: child}, } } @@ -55,8 +55,8 @@ func (u *UpdateJoin) DebugString() string { // GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable. func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable { return &updatableJoinTable{ - updaters: u.Updaters, - joinNode: u.Child.(*UpdateSource).Child, + updateTargets: u.UpdateTargets, + joinNode: u.Child.(*UpdateSource).Child, } } @@ -66,7 +66,7 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) { return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) } - return NewUpdateJoin(u.Updaters, children[0]), nil + return NewUpdateJoin(u.UpdateTargets, children[0]), nil } func (u *UpdateJoin) IsReadOnly() bool { @@ -78,10 +78,26 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll return sql.GetCoercibility(ctx, u.Child) } +func (u *UpdateJoin) GetUpdaters(ctx *sql.Context) (map[string]sql.RowUpdater, error) { + return getUpdaters(u.UpdateTargets, ctx) +} + +func getUpdaters(updateTargets map[string]sql.Node, ctx *sql.Context) (map[string]sql.RowUpdater, error) { + updaterMap := make(map[string]sql.RowUpdater) + for tableName, updateTarget := range updateTargets { + updatable, err := GetUpdatable(updateTarget) + if err != nil { + return nil, err + } + updaterMap[tableName] = updatable.Updater(ctx) + } + return updaterMap, nil +} + // updatableJoinTable manages the update of multiple tables. 
type updatableJoinTable struct { - updaters map[string]sql.RowUpdater - joinNode sql.Node + updateTargets map[string]sql.Node + joinNode sql.Node } var _ sql.UpdatableTable = (*updatableJoinTable)(nil) @@ -118,8 +134,9 @@ func (u *updatableJoinTable) Collation() sql.CollationID { // Updater implements the sql.UpdatableTable interface. func (u *updatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater { + updaters, _ := getUpdaters(u.updateTargets, ctx) return &updatableJoinUpdater{ - updaterMap: u.updaters, + updaterMap: updaters, schemaMap: RecreateTableSchemaFromJoinSchema(u.joinNode.Schema()), joinSchema: u.joinNode.Schema(), } diff --git a/sql/planbuilder/ddl.go b/sql/planbuilder/ddl.go index 3d3b65e0e3..fe2ee04140 100644 --- a/sql/planbuilder/ddl.go +++ b/sql/planbuilder/ddl.go @@ -1418,6 +1418,15 @@ func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, t } for i, def := range defaults { + // Early validation for enum default 0 to catch it before conversion + if def != nil && types.IsEnum(schema[i].Type) { + if lit, ok := def.(*ast.SQLVal); ok { + if lit.Type == ast.IntVal && string(lit.Val) == "0" { + b.handleErr(sql.ErrInvalidColumnDefaultValue.New(schema[i].Name)) + } + } + } + schema[i].Default = b.convertDefaultExpression(outScope, def, schema[i].Type, schema[i].Nullable) err := validateDefaultExprs(schema[i]) if err != nil { diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index 60b4ef9090..d7ec5053b6 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -76,10 +76,6 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) { schema := rt.Schema() columns = make([]string, len(schema)) for i, col := range schema { - // Tables with any generated column must always supply a column list, so this is always an error - if col.Generated != nil { - b.handleErr(sql.ErrGeneratedColumnValue.New(col.Name, rt.Name())) - } columns[i] = col.Name } } @@ -150,12 +146,10 @@ func (b *Builder) 
buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) { ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore) ins.LiteralValueSource = srcLiteralOnly - if i.Returning != nil { - returningExprs := make([]sql.Expression, len(i.Returning)) - for i, selectExpr := range i.Returning { - returningExprs[i] = b.selectExprToExpression(destScope, selectExpr) - } - ins.Returning = returningExprs + if len(i.Returning) > 0 { + // TODO: read returning results from outScope instead of ins.Returning so that there is no need to return list + // of expressions + ins.Returning = b.analyzeSelectList(destScope, destScope, i.Returning) } b.validateInsert(ins) @@ -297,16 +291,25 @@ func (b *Builder) assignmentExprsToExpressions(inScope *scope, e ast.AssignmentE colIdx := tableSch.IndexOfColName(gf.Name()) // TODO: during trigger parsing the table in the node is unresolved, so we need this additional bounds check // This means that trigger execution will be able to update generated columns - // Prevent update of generated columns - if colIdx >= 0 && tableSch[colIdx].Generated != nil { + + // Check if this is a DEFAULT expression for a generated column + _, isDefaultExpr := updateExpr.Expr.(*ast.Default) + + // Prevent update of generated columns, but allow DEFAULT + if colIdx >= 0 && tableSch[colIdx].Generated != nil && !isDefaultExpr { err := sql.ErrGeneratedColumnValue.New(tableSch[colIdx].Name, inScope.node.(sql.NameableNode).Name()) b.handleErr(err) } // Replace default with column default from resolved schema - if _, ok := updateExpr.Expr.(*ast.Default); ok { + if isDefaultExpr { if colIdx >= 0 { - innerExpr = expression.WrapExpression(tableSch[colIdx].Default) + // For generated columns, use the generated expression as the default + if tableSch[colIdx].Generated != nil { + innerExpr = expression.WrapExpression(tableSch[colIdx].Generated) + } else { + innerExpr = 
expression.WrapExpression(tableSch[colIdx].Default) + } } } } @@ -492,6 +495,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) { return } +// buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's +// children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We +// don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require +// analyzer processing that converts the subquery into a join, and then requires the same logic to +// create an UpdateJoin node under the original Update node. func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) { // TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is. // The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect. @@ -534,44 +542,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) { update.IsProcNested = b.ProcCtx().DbName != "" var checks []*sql.CheckConstraint - if join, ok := outScope.node.(*plan.JoinNode); ok { - // TODO this doesn't work, a lot of the time the top node - // is a filter. This would have to go before we build the - // filter/accessory nodes. But that errors for a lot of queries. 
- source := plan.NewUpdateSource( - join, - ignore, - updateExprs, - ) - updaters, err := rowUpdatersByTable(b.ctx, source, join) + if hasJoinNode(outScope.node) { + tablesToUpdate, err := getResolvedTablesToUpdate(b.ctx, update.Child, outScope.node) if err != nil { b.handleErr(err) } - updateJoin := plan.NewUpdateJoin(updaters, source) - update.Child = updateJoin - transform.Inspect(update, func(n sql.Node) bool { - // todo maybe this should be later stage - switch n := n.(type) { - case sql.NameableNode: - if _, ok := updaters[n.Name()]; ok { - rt := getResolvedTable(n) - tableScope := inScope.push() - for _, c := range rt.Schema() { - tableScope.addColumn(scopeColumn{ - db: rt.SqlDatabase.Name(), - table: strings.ToLower(n.Name()), - tableId: tableScope.tables[strings.ToLower(n.Name())], - col: strings.ToLower(c.Name), - typ: c.Type, - nullable: c.Nullable, - }) - } - checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...) - } - default: + + for _, rt := range tablesToUpdate { + tableScope := inScope.push() + for _, c := range rt.Schema() { + tableScope.addColumn(scopeColumn{ + db: rt.SqlDatabase.Name(), + table: strings.ToLower(rt.Name()), + tableId: tableScope.tables[strings.ToLower(rt.Name())], + col: strings.ToLower(c.Name), + typ: c.Type, + nullable: c.Nullable, + }) } - return true - }) + checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...) 
+ } } else { transform.Inspect(update, func(n sql.Node) bool { // todo maybe this should be later stage @@ -583,46 +573,39 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) { } if len(u.Returning) > 0 { - returningExprs := make([]sql.Expression, len(u.Returning)) - for i, selectExpr := range u.Returning { - returningExprs[i] = b.selectExprToExpression(outScope, selectExpr) - } - update.Returning = returningExprs + update.Returning = b.analyzeSelectList(outScope, outScope, u.Returning) } outScope.node = update.WithChecks(checks) return } -// rowUpdatersByTable maps a set of tables to their RowUpdater objects. -func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) { - namesOfTableToBeUpdated := getTablesToBeUpdated(node) - resolvedTables := getTablesByName(ij) - - rowUpdatersByTable := make(map[string]sql.RowUpdater) - for tableToBeUpdated, _ := range namesOfTableToBeUpdated { - resolvedTable, ok := resolvedTables[strings.ToLower(tableToBeUpdated)] - if !ok { - return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated) +// hasJoinNode returns true if |node| or any child is a JoinNode. 
+func hasJoinNode(node sql.Node) bool { + updateJoinFound := false + transform.Inspect(node, func(n sql.Node) bool { + if _, ok := n.(*plan.JoinNode); ok { + updateJoinFound = true } + return !updateJoinFound + }) + return updateJoinFound +} - var table = resolvedTable.UnderlyingTable() +func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) { + namesOfTablesToBeUpdated := getTablesToBeUpdated(node) + resolvedTablesMap := getTablesByName(ij) - // If there is no UpdatableTable for a table being updated, error out - updatable, ok := table.(sql.UpdatableTable) - if !ok && updatable == nil { + for tableToBeUpdated, _ := range namesOfTablesToBeUpdated { + resolvedTable, ok := resolvedTablesMap[strings.ToLower(tableToBeUpdated)] + if !ok { return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated) } - keyless := sql.IsKeyless(updatable.Schema()) - if keyless { - return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN") - } - - rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx) + resolvedTables = append(resolvedTables, resolvedTable) } - return rowUpdatersByTable, nil + return resolvedTables, nil } // getTablesByName takes a node and returns all found resolved tables in a map. 
diff --git a/sql/planbuilder/dml_validate.go b/sql/planbuilder/dml_validate.go index f7e3c04f44..3ae579535c 100644 --- a/sql/planbuilder/dml_validate.go +++ b/sql/planbuilder/dml_validate.go @@ -165,7 +165,12 @@ func validGeneratedColumnValue(idx int, source sql.Node) bool { if _, ok := val.Unwrap().(*sql.ColumnDefaultValue); ok { return true } + if _, ok := val.Unwrap().(*expression.DefaultColumn); ok { + return true + } return false + case *expression.DefaultColumn: // handle unwrapped DefaultColumn + return true default: return false } diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index 66075429e9..898273d714 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -29,8 +29,8 @@ func (b *Builder) analyzeProjectionList(inScope, outScope *scope, selectExprs as b.analyzeSelectList(inScope, outScope, selectExprs) } -func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.SelectExprs) { - // todo ideally we would not create new expressions here. +func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.SelectExprs) (expressions []sql.Expression) { + // TODO: ideally we would not create new expressions here. // we want to in-place identify aggregations, expand stars. // use inScope to construct projections for projScope @@ -160,6 +160,7 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se } inScope.parent = tempScope.parent + return exprs } // selectExprToExpression binds dependencies in a scalar expression in a SELECT clause. 
diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 60a94d94b8..88b3715f29 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -747,10 +747,10 @@ func (b *Builder) buildIsExprToExpression(inScope *scope, c *ast.IsExpr) sql.Exp e := b.buildScalar(inScope, c.Expr) switch strings.ToLower(c.Operator) { case ast.IsNullStr: - return expression.NewIsNull(e) + return expression.DefaultExpressionFactory.NewIsNull(e) case ast.IsNotNullStr: b.qFlags.Set(sql.QFlgNotExpr) - return expression.NewNot(expression.NewIsNull(e)) + return expression.DefaultExpressionFactory.NewIsNotNull(e) case ast.IsTrueStr: return expression.NewIsTrue(e) case ast.IsFalseStr: diff --git a/sql/planbuilder/set_op.go b/sql/planbuilder/set_op.go index 5b443c393d..0ba0e50681 100644 --- a/sql/planbuilder/set_op.go +++ b/sql/planbuilder/set_op.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" + "github.com/dolthub/go-mysql-server/sql/types" ) func hasRecursiveCte(node sql.Node) bool { @@ -144,10 +145,28 @@ func (b *Builder) buildSetOp(inScope *scope, u *ast.SetOp) (outScope *scope) { tabId := b.tabId ret := plan.NewSetOp(setOpType, leftScope.node, rightScope.node, distinct, limit, offset, sortFields).WithId(tabId).WithColumns(cols) outScope = leftScope + outScope.cols = b.mergeSetOpScopeColumns(leftScope.cols, rightScope.cols, tabId) outScope.node = b.mergeSetOpSchemas(ret.(*plan.SetOp)) return } +func (b *Builder) mergeSetOpScopeColumns(left, right []scopeColumn, tabId sql.TableId) []scopeColumn { + merged := make([]scopeColumn, len(left)) + for i := range left { + merged[i] = scopeColumn{ + tableId: tabId, + db: left[i].db, + table: left[i].table, + col: left[i].col, + originalCol: left[i].originalCol, + id: left[i].id, + typ: types.GeneralizeTypes(left[i].typ, right[i].typ), + nullable: left[i].nullable || right[i].nullable, + } + } + 
return merged +} + func (b *Builder) mergeSetOpSchemas(u *plan.SetOp) sql.Node { ls, rs := u.Left().Schema(), u.Right().Schema() if len(ls) != len(rs) { diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 8c4b7341ae..949bbe92ca 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -140,6 +140,12 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. return nil, err } e.Expr = newExpr.(ast.Expr) + case *ast.ExistsExpr: + newSubquery, err := replaceVariablesInExpr(ctx, stack, e.Subquery, asOf) + if err != nil { + return nil, err + } + e.Subquery = newSubquery.(*ast.Subquery) case *ast.FuncExpr: for i := range e.Exprs { newExpr, err := replaceVariablesInExpr(ctx, stack, e.Exprs[i], asOf) diff --git a/sql/rowexec/agg.go b/sql/rowexec/agg.go index e43911065b..384e8efe91 100644 --- a/sql/rowexec/agg.go +++ b/sql/rowexec/agg.go @@ -16,14 +16,13 @@ package rowexec import ( "errors" - "fmt" "io" - "github.com/cespare/xxhash/v2" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" - "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/go-mysql-server/sql/hash" ) type groupByIter struct { @@ -238,47 +237,29 @@ func (i *groupByGroupingIter) Dispose() { } } -func groupingKey( - ctx *sql.Context, - exprs []sql.Expression, - row sql.Row, -) (uint64, error) { - hash := xxhash.New() +func groupingKey(ctx *sql.Context, exprs []sql.Expression, row sql.Row) (uint64, error) { + var keyRow = make(sql.Row, len(exprs)) + var keySch = make(sql.Schema, len(exprs)) for i, expr := range exprs { v, err := expr.Eval(ctx, row) if err != nil { return 0, err } - if i > 0 { - // separate each expression in the grouping key with a nil byte - if _, err = hash.Write([]byte{0}); err != nil { - return 0, err + // TODO: this should be moved into hash.HashOf + typ := 
expr.Type() + if extTyp, isExtTyp := typ.(types.ExtendedType); isExtTyp { + val, vErr := extTyp.SerializeValue(ctx, v) + if vErr != nil { + return 0, vErr } + v = string(val) } - extendedType, isExtendedType := expr.Type().(types.ExtendedType) - stringType, isStringType := expr.Type().(sql.StringType) - - if isExtendedType && v != nil { - bytes, err := extendedType.SerializeValue(ctx, v) - if err == nil { - _, err = fmt.Fprint(hash, string(bytes)) - } - } else if isStringType && v != nil { - v, err = types.ConvertToString(ctx, v, stringType, nil) - if err == nil { - err = stringType.Collation().WriteWeightString(hash, v.(string)) - } - } else { - _, err = fmt.Fprintf(hash, "%v", v) - } - if err != nil { - return 0, err - } + keyRow[i] = v + keySch[i] = &sql.Column{Type: typ} } - - return hash.Sum64(), nil + return hash.HashOf(ctx, keySch, keyRow) } func newAggregationBuffer(expr sql.Expression) (sql.AggregationBuffer, error) { diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index c2c779a362..e4347c70a8 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -416,10 +416,14 @@ func (b *BaseBuilder) buildUpdateJoin(ctx *sql.Context, n *plan.UpdateJoin, row return nil, err } + updaters, err := n.GetUpdaters(ctx) + if err != nil { + return nil, err + } return &updateJoinIter{ updateSourceIter: ji, joinSchema: n.Child.(*plan.UpdateSource).Child.Schema(), - updaters: n.Updaters, + updaters: updaters, caches: make(map[string]sql.KeyValueCache), disposals: make(map[string]sql.DisposeFunc), joinNode: n.Child.(*plan.UpdateSource).Child, diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index c16d4b3b7d..aba643ef98 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -49,6 +49,7 @@ type insertIter struct { firstGeneratedAutoIncRowIdx int deferredDefaults sql.FastIntSet + rowNumber int64 } func getInsertExpressions(values sql.Node) []sql.Expression { @@ -74,6 +75,9 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr 
error) return nil, i.ignoreOrClose(ctx, row, err) } + // Increment row number for error reporting (MySQL starts at 1) + i.rowNumber++ + // Prune the row down to the size of the schema. It can be larger in the case of running with an outer scope, in which // case the additional scope variables are prepended to the row. if len(row) > len(i.schema) { @@ -87,7 +91,7 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) break } _, isColDefVal := i.insertExprs[idx].(*sql.ColumnDefaultValue) - if row[idx] == nil && types.IsEnum(col.Type) && isColDefVal { + if row[idx] == nil && !col.Nullable && types.IsEnum(col.Type) && isColDefVal { row[idx] = 1 } } @@ -140,6 +144,8 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) cErr = types.ErrLengthBeyondLimit.New(row[idx], col.Name) } else if sql.ErrNotMatchingSRID.Is(cErr) { cErr = sql.ErrNotMatchingSRIDWithColName.New(col.Name, cErr) + } else if types.ErrConvertingToEnum.Is(cErr) { + cErr = types.ErrDataTruncatedForColumnAtRow.New(col.Name, i.rowNumber) } return nil, sql.NewWrappedInsertError(origRow, cErr) } @@ -261,6 +267,7 @@ func getFieldIndexFromUpdateExpr(updateExpr sql.Expression) (int, bool) { // resolveValues resolves all VALUES functions. 
func (i *insertIter) resolveValues(ctx *sql.Context, insertRow sql.Row) error { + // if vals empty then no need to resolve for _, updateExpr := range i.updateExprs { var err error sql.Inspect(updateExpr, func(expr sql.Expression) bool { diff --git a/sql/rowexec/join_iters.go b/sql/rowexec/join_iters.go index d228bf32bc..a39c5f0ff3 100644 --- a/sql/rowexec/join_iters.go +++ b/sql/rowexec/join_iters.go @@ -25,6 +25,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/hash" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" ) @@ -462,7 +463,7 @@ func (i *fullJoinIter) Next(ctx *sql.Context) (sql.Row, error) { rightRow, err := i.r.Next(ctx) if err == io.EOF { - key, err := sql.HashOf(ctx, i.leftRow) + key, err := hash.HashOf(ctx, nil, i.leftRow) if err != nil { return nil, err } @@ -485,12 +486,12 @@ func (i *fullJoinIter) Next(ctx *sql.Context) (sql.Row, error) { if !sql.IsTrue(matches) { continue } - rkey, err := sql.HashOf(ctx, rightRow) + rkey, err := hash.HashOf(ctx, nil, rightRow) if err != nil { return nil, err } i.seenRight[rkey] = struct{}{} - lKey, err := sql.HashOf(ctx, i.leftRow) + lKey, err := hash.HashOf(ctx, nil, i.leftRow) if err != nil { return nil, err } @@ -517,7 +518,7 @@ func (i *fullJoinIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, io.EOF } - key, err := sql.HashOf(ctx, rightRow) + key, err := hash.HashOf(ctx, nil, rightRow) if err != nil { return nil, err } diff --git a/sql/rowexec/other_iters.go b/sql/rowexec/other_iters.go index b2f471a071..c4d87cb898 100644 --- a/sql/rowexec/other_iters.go +++ b/sql/rowexec/other_iters.go @@ -18,6 +18,8 @@ import ( "io" "sync" + "github.com/dolthub/go-mysql-server/sql/hash" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" ) @@ -334,7 +336,7 @@ func (ci *concatIter) Next(ctx *sql.Context) (sql.Row, error) { if err != 
nil { return nil, err } - hash, err := sql.HashOf(ctx, res) + hash, err := hash.HashOf(ctx, nil, res) if err != nil { return nil, err } diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go index 041ed8f525..391a9a61e1 100644 --- a/sql/rowexec/rel.go +++ b/sql/rowexec/rel.go @@ -312,9 +312,10 @@ func (b *BaseBuilder) buildProject(ctx *sql.Context, n *plan.Project, row sql.Ro } return sql.NewSpanIter(span, &ProjectIter{ - projs: n.Projections, - canDefer: n.CanDefer, - childIter: i, + projs: n.Projections, + canDefer: n.CanDefer, + hasNestedIters: n.IncludesNestedIters, + childIter: i, }), nil } @@ -386,9 +387,11 @@ func (b *BaseBuilder) buildSet(ctx *sql.Context, n *plan.Set, row sql.Row) (sql. } copy(resultRow, row) resultRow = row.Append(newRow) + return sql.RowsToRowIter(resultRow), nil } - return sql.RowsToRowIter(resultRow), nil + // For system and user variable SET statements, return OkResult like MySQL does + return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(0))), nil } func (b *BaseBuilder) buildGroupBy(ctx *sql.Context, n *plan.GroupBy, row sql.Row) (sql.RowIter, error) { diff --git a/sql/rowexec/rel_iters.go b/sql/rowexec/rel_iters.go index a3c372f0e1..a1317b6259 100644 --- a/sql/rowexec/rel_iters.go +++ b/sql/rowexec/rel_iters.go @@ -22,8 +22,10 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" + "github.com/dolthub/go-mysql-server/sql/hash" "github.com/dolthub/go-mysql-server/sql/iters" "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -125,16 +127,29 @@ func (i *offsetIter) Close(ctx *sql.Context) error { var _ sql.RowIter = &iters.JsonTableRowIter{} type ProjectIter struct { - projs []sql.Expression - canDefer bool - childIter sql.RowIter + projs []sql.Expression + canDefer bool + hasNestedIters bool + nestedState 
*nestedIterState + childIter sql.RowIter +} + +type nestedIterState struct { + projections []sql.Expression + sourceRow sql.Row + iterEvaluators []*RowIterEvaluator } func (i *ProjectIter) Next(ctx *sql.Context) (sql.Row, error) { + if i.hasNestedIters { + return i.ProjectRowWithNestedIters(ctx) + } + childRow, err := i.childIter.Next(ctx) if err != nil { return nil, err } + return ProjectRow(ctx, i.projs, childRow) } @@ -154,6 +169,136 @@ func (i *ProjectIter) GetChildIter() sql.RowIter { return i.childIter } +// ProjectRowWithNestedIters evaluates a set of projections, allowing for nested iterators in the expressions. +func (i *ProjectIter) ProjectRowWithNestedIters( + ctx *sql.Context, +) (sql.Row, error) { + + // For the set of iterators, we return one row each element in the longest of the iterators provided. + // Other iterator values will be NULL after they are depleted. All non-iterator fields for the row are returned + // identically for each row in the result set. + if i.nestedState != nil { + row, err := ProjectRow(ctx, i.nestedState.projections, i.nestedState.sourceRow) + if err != nil { + return nil, err + } + + nestedIterationFinished := true + for _, evaluator := range i.nestedState.iterEvaluators { + if !evaluator.finished && evaluator.iter != nil { + nestedIterationFinished = false + break + } + } + + if nestedIterationFinished { + i.nestedState = nil + return i.ProjectRowWithNestedIters(ctx) + } + + return row, nil + } + + row, err := i.childIter.Next(ctx) + if err != nil { + return nil, err + } + + i.nestedState = &nestedIterState{ + sourceRow: row, + } + + // We need a new set of projections, with any iterator-returning expressions replaced by new expressions that will + // return the result of the iteration on each call to Eval. We also need to keep a list of all such iterators, so + // that we can tell when they have all finished their iterations. 
+ var rowIterEvaluators []*RowIterEvaluator + newProjs := make([]sql.Expression, len(i.projs)) + for i, proj := range i.projs { + p, _, err := transform.Expr(proj, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + if rie, ok := e.(sql.RowIterExpression); ok && rie.ReturnsRowIter() { + ri, err := rie.EvalRowIter(ctx, row) + if err != nil { + return nil, false, err + } + + evaluator := &RowIterEvaluator{ + iter: ri, + typ: rie.Type(), + } + rowIterEvaluators = append(rowIterEvaluators, evaluator) + return evaluator, transform.NewTree, nil + } + + return e, transform.SameTree, nil + }) + + if err != nil { + return nil, err + } + + newProjs[i] = p + } + + i.nestedState.projections = newProjs + i.nestedState.iterEvaluators = rowIterEvaluators + + return i.ProjectRowWithNestedIters(ctx) +} + +// RowIterEvaluator is an expression that returns the next value from a sql.RowIter each time Eval is called. +type RowIterEvaluator struct { + iter sql.RowIter + typ sql.Type + finished bool +} + +var _ sql.Expression = (*RowIterEvaluator)(nil) + +func (r RowIterEvaluator) Resolved() bool { + return true +} + +func (r RowIterEvaluator) String() string { + return "RowIterEvaluator" +} + +func (r RowIterEvaluator) Type() sql.Type { + return r.typ +} + +func (r RowIterEvaluator) IsNullable() bool { + return true +} + +func (r *RowIterEvaluator) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + if r.finished || r.iter == nil { + return nil, nil + } + + nextRow, err := r.iter.Next(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + r.finished = true + return nil, nil + } + return nil, err + } + + // All of the set-returning functions return a single value per column + return nextRow[0], nil +} + +func (r RowIterEvaluator) Children() []sql.Expression { + return nil +} + +func (r RowIterEvaluator) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(r, 
len(children), 0) + } + return &r, nil +} + // ProjectRow evaluates a set of projections. func ProjectRow( ctx *sql.Context, @@ -446,7 +591,7 @@ func (r *recursiveCteIter) Next(ctx *sql.Context) (sql.Row, error) { var key uint64 if r.deduplicate { - key, _ = sql.HashOf(ctx, row) + key, _ = hash.HashOf(ctx, nil, row) if k, _ := r.cache.Get(key); k != nil { // skip duplicate continue diff --git a/sql/rowexec/subquery_test.go b/sql/rowexec/subquery_test.go index 3b9e5a4624..fcc6649fd3 100644 --- a/sql/rowexec/subquery_test.go +++ b/sql/rowexec/subquery_test.go @@ -92,5 +92,5 @@ func TestSubqueryMultipleRows(t *testing.T) { values, err := subquery.EvalMultiple(ctx, nil) require.NoError(err) - require.Equal(values, []interface{}{"one", "two", "three"}) + require.Equal([]any{"one", "two", "three"}, values) } diff --git a/sql/rowexec/update.go b/sql/rowexec/update.go index 2c4cf4eff1..c7ecba8abf 100644 --- a/sql/rowexec/update.go +++ b/sql/rowexec/update.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/hash" "github.com/dolthub/go-mysql-server/sql/plan" ) @@ -249,7 +250,7 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) { // Determine whether this row in the table has already been updated cache := u.getOrCreateCache(ctx, tableName) - hash, err := sql.HashOf(ctx, oldTableRow) + hash, err := hash.HashOf(ctx, nil, oldTableRow) if err != nil { return nil, err } @@ -258,8 +259,12 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) { if errors.Is(err, sql.ErrKeyNotFound) { cache.Put(hash, struct{}{}) - // updateJoin counts matched rows from join output - u.accumulator.handleRowMatched() + // updateJoin counts matched rows from join output, unless a RETURNING clause + // is in use, in which case there will not be an accumulator assigned, since we + // don't need to return the count of updated rows, just the RETURNING expressions. 
+ if u.accumulator != nil { + u.accumulator.handleRowMatched() + } continue } else if err != nil { diff --git a/sql/type.go b/sql/type.go index 738cccfdec..eca6a2bda2 100644 --- a/sql/type.go +++ b/sql/type.go @@ -104,6 +104,9 @@ type Type interface { // NullType represents the type of NULL values type NullType interface { Type + + // IsNullType is a marker interface for types that represent NULL values. + IsNullType() bool } // DeferredType is a placeholder for prepared statements diff --git a/sql/types/conversion.go b/sql/types/conversion.go index fc027f1f65..583dff4a13 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -16,6 +16,7 @@ package types import ( "fmt" + "reflect" "strconv" "strings" "time" @@ -554,3 +555,183 @@ func TypesEqual(a, b sql.Type) bool { return a.Equals(b) } } + +// generalizeNumberTypes assumes both inputs return true for IsNumber +func generalizeNumberTypes(a, b sql.Type) sql.Type { + if IsFloat(a) || IsFloat(b) { + // TODO: handle cases where MySQL returns Float32 + return Float64 + } + + if IsDecimal(a) || IsDecimal(b) { + // TODO: match precision and scale to that of the decimal type, check if defines column + return MustCreateDecimalType(DecimalTypeMaxPrecision, DecimalTypeMaxScale) + } + + aIsSigned := IsSigned(a) + bIsSigned := IsSigned(b) + + if a == Uint64 || b == Uint64 { + if aIsSigned || bIsSigned { + return MustCreateDecimalType(DecimalTypeMaxPrecision, 0) + } + return Uint64 + } + + if a == Int64 || b == Int64 { + return Int64 + } + + if a == Uint32 || b == Uint32 { + if aIsSigned || bIsSigned { + return Int64 + } + return Uint32 + } + + if a == Int32 || b == Int32 { + return Int32 + } + + if a == Uint24 || b == Uint24 { + if aIsSigned || bIsSigned { + return Int32 + } + return Uint24 + } + + if a == Int24 || b == Int24 { + return Int24 + } + + if a == Uint16 || b == Uint16 { + if aIsSigned || bIsSigned { + return Int24 + } + return Uint16 + } + + if a == Int16 || b == Int16 { + return Int16 + } + + if a == 
Uint8 || b == Uint8 { + if aIsSigned || bIsSigned { + return Int16 + } + return Uint8 + } + + if a == Int8 || b == Int8 { + return Int8 + } + + if IsBoolean(a) && IsBoolean(b) { + return Boolean + } + + return Int64 +} + +// GeneralizeTypes returns the more "general" of two types as defined by +// https://dev.mysql.com/doc/refman/8.4/en/flow-control-functions.html +// TODO: Create and handle "Illegal mix of collations" error +// TODO: Handle extended types, like DoltgresType +func GeneralizeTypes(a, b sql.Type) sql.Type { + if reflect.DeepEqual(a, b) { + return a + } + + if IsNullType(a) { + return b + } + if IsNullType(b) { + return a + } + + if svt, ok := a.(sql.SystemVariableType); ok { + a = svt.UnderlyingType() + } + if svt, ok := b.(sql.SystemVariableType); ok { + b = svt.UnderlyingType() + } + + if IsJSON(a) && IsJSON(b) { + return JSON + } + + if IsGeometry(a) && IsGeometry(b) { + return a + } + + if IsEnum(a) && IsEnum(b) { + return a + } + + if IsSet(a) && IsSet(b) { + return a + } + + aIsTimespan := IsTimespan(a) + bIsTimespan := IsTimespan(b) + if aIsTimespan && bIsTimespan { + return Time + } + if (IsTime(a) || aIsTimespan) && (IsTime(b) || bIsTimespan) { + if IsDateType(a) && IsDateType(b) { + return Date + } + if IsTimestampType(a) && IsTimestampType(b) { + // TODO: match precision to max precision of the two timestamps + return TimestampMaxPrecision + } + // TODO: match precision to max precision of the two time types + return DatetimeMaxPrecision + } + + if IsBlobType(a) || IsBlobType(b) { + // TODO: match blob length to max of the blob lengths + return LongBlob + } + + aIsBit := IsBit(a) + bIsBit := IsBit(b) + if aIsBit && bIsBit { + // TODO: match max bits to max of max bits between a and b + return a.Promote() + } + if aIsBit { + a = Int64 + } + if bIsBit { + b = Int64 + } + + aIsYear := IsYear(a) + bIsYear := IsYear(b) + if aIsYear && bIsYear { + return a + } + if aIsYear { + a = Int32 + } + if bIsYear { + b = Int32 + } + + if IsNumber(a) && 
IsNumber(b) { + return generalizeNumberTypes(a, b) + } + + if IsText(a) && IsText(b) { + sta := a.(sql.StringType) + stb := b.(sql.StringType) + if sta.Length() > stb.Length() { + return a + } + return b + } + + // TODO: decide if we want to make this VarChar to match MySQL, match VarChar length to max of two types + return LongText +} diff --git a/sql/types/conversion_test.go b/sql/types/conversion_test.go index 35cb5f03d2..07a719b782 100644 --- a/sql/types/conversion_test.go +++ b/sql/types/conversion_test.go @@ -119,7 +119,7 @@ func TestColumnTypeToType_Time(t *testing.T) { } func TestColumnCharTypes(t *testing.T) { - test := []struct { + tests := []struct { typ string len int64 exp sql.Type @@ -146,7 +146,7 @@ func TestColumnCharTypes(t *testing.T) { }, } - for _, test := range test { + for _, test := range tests { t.Run(fmt.Sprintf("%v %v", test.typ, test.exp), func(t *testing.T) { ct := &sqlparser.ColumnType{ Type: test.typ, @@ -158,3 +158,65 @@ func TestColumnCharTypes(t *testing.T) { }) } } + +func TestGeneralizeTypes(t *testing.T) { + decimalType := MustCreateDecimalType(DecimalTypeMaxPrecision, DecimalTypeMaxScale) + uint64DecimalType := MustCreateDecimalType(DecimalTypeMaxPrecision, 0) + + tests := []struct { + typeA sql.Type + typeB sql.Type + expected sql.Type + }{ + {Float64, Float32, Float64}, + {Float64, Int32, Float64}, + {Int24, Float32, Float64}, + {decimalType, Float64, Float64}, + {decimalType, Int32, decimalType}, + {Int64, decimalType, decimalType}, + {Uint64, Int32, uint64DecimalType}, + {Int24, Uint64, uint64DecimalType}, + {Uint64, Uint8, Uint64}, + {Uint24, Uint64, Uint64}, + {Int64, Uint32, Int64}, + {Int24, Int64, Int64}, + {Int8, Int64, Int64}, + {Uint32, Int24, Int64}, + {Uint24, Uint32, Uint32}, + {Int32, Int8, Int32}, + {Uint24, Int32, Int32}, + {Uint24, Int24, Int32}, + {Uint8, Uint24, Uint24}, + {Int24, Uint8, Int24}, + {Int8, Int24, Int24}, + {Int8, Uint16, Int24}, + {Uint16, Uint8, Uint16}, + {Int16, Int16, Int16}, + {Int8, 
Int16, Int16}, + {Uint8, Int8, Int16}, + {Uint8, Uint8, Uint8}, + {Int8, Int8, Int8}, + {Boolean, Int64, Int64}, + {Boolean, Boolean, Boolean}, + {Text, Text, Text}, + {Text, LongText, LongText}, + {Text, Float64, LongText}, + {Int64, Text, LongText}, + {Int8, Null, Int8}, + {Time, Time, Time}, + {Time, Date, DatetimeMaxPrecision}, + {Date, Date, Date}, + {Date, Timestamp, DatetimeMaxPrecision}, + {Timestamp, Timestamp, Timestamp}, + {Timestamp, TimestampMaxPrecision, TimestampMaxPrecision}, + {Timestamp, Datetime, DatetimeMaxPrecision}, + {Null, Int64, Int64}, + {Null, Null, Null}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%v %v %v", test.typeA, test.typeB, test.expected), func(t *testing.T) { + res := GeneralizeTypes(test.typeA, test.typeB) + assert.Equal(t, test.expected, res) + }) + } +} diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 19197ce0e7..f37c3963e6 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -39,18 +39,30 @@ var ( ErrConvertingToTimeOutOfRange = errors.NewKind("value %q is outside of %v range") - // datetimeTypeMaxDatetime is the maximum representable Datetime/Date value. - datetimeTypeMaxDatetime = time.Date(9999, 12, 31, 23, 59, 59, 999999000, time.UTC) + // datetimeTypeMaxDatetime is the maximum representable Datetime/Date value. MYSQL: 9999-12-31 23:59:59.499999 (microseconds) + datetimeTypeMaxDatetime = time.Date(9999, 12, 31, 23, 59, 59, 499999000, time.UTC) - // datetimeTypeMinDatetime is the minimum representable Datetime/Date value. - datetimeTypeMinDatetime = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) + // datetimeTypeMinDatetime is the minimum representable Datetime/Date value. MYSQL: 1000-01-01 00:00:00.000000 (microseconds) + datetimeTypeMinDatetime = time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC) - // datetimeTypeMaxTimestamp is the maximum representable Timestamp value, which is the maximum 32-bit integer as a Unix time. 
+ // datetimeTypeMaxTimestamp is the maximum representable Timestamp value, MYSQL: 2038-01-19 03:14:07.999999 (microseconds) datetimeTypeMaxTimestamp = time.Unix(math.MaxInt32, 999999000) - // datetimeTypeMinTimestamp is the minimum representable Timestamp value, which is one second past the epoch. + // datetimeTypeMinTimestamp is the minimum representable Timestamp value, MYSQL: 1970-01-01 00:00:01.000000 (microseconds) datetimeTypeMinTimestamp = time.Unix(1, 0) + datetimeTypeMaxDate = time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC) + + // datetimeTypeMinDate is the minimum representable Date value, MYSQL: 1000-01-01 00:00:00.000000 (microseconds) + datetimeTypeMinDate = time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC) + + // The MAX and MIN are extrapolated from commit ff05628a530 in the MySQL source code from my_time.cc + // datetimeMaxTime is the maximum representable time value, MYSQL: 9999-12-31 23:59:59.999999 (microseconds) + datetimeMaxTime = time.Date(9999, 12, 31, 23, 59, 59, 999999000, time.UTC) + + // datetimeMinTime is the minimum representable time value, MYSQL: 0000-01-01 00:00:00.000000 (microseconds) + datetimeMinTime = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) + DateOnlyLayouts = []string{ "20060102", "2006-1-2", @@ -71,8 +83,9 @@ var ( "2006-01-02 15:04:", "2006-01-02 15:04:.", "2006-01-02 15:04:05.", - "2006-01-02 15:04:05.999999", - "2006-1-2 15:4:5.999999", + "2006-01-02 15:04:05.999999999", + "2006-1-2 15:4:5.999999999", + "2006-1-2:15:4:5.999999999", "2006-01-02T15:04:05", "20060102150405", "2006-01-02 15:04:05.999999999 -0700 MST", // represents standard Time.time.UTC() @@ -91,6 +104,8 @@ var ( Timestamp = MustCreateDatetimeType(sqltypes.Timestamp, 0) // TimestampMaxPrecision is a UNIX timestamp with maximum precision TimestampMaxPrecision = MustCreateDatetimeType(sqltypes.Timestamp, 6) + // DatetimeMaxRange is a date and a time with maximum precision and maximum range. 
+ DatetimeMaxRange = MustCreateDatetimeType(sqltypes.Datetime, 6) datetimeValueType = reflect.TypeOf(time.Time{}) ) @@ -200,9 +215,20 @@ func ConvertToTime(ctx context.Context, v interface{}, t datetimeType) (time.Tim } // Round the date to the precision of this type - truncationDuration := time.Second - truncationDuration /= time.Duration(precisionConversion[t.precision]) - res = res.Round(truncationDuration) + if t.precision < 6 { + truncationDuration := time.Second / time.Duration(precisionConversion[t.precision]) + res = res.Round(truncationDuration) + } else { + res = res.Round(time.Microsecond) + } + + if t == DatetimeMaxRange { + validated := ValidateTime(res) + if validated == nil { + return time.Time{}, ErrConvertingToTimeOutOfRange.New(v, t) + } + return validated.(time.Time), nil + } switch t.baseType { case sqltypes.Date: @@ -214,10 +240,11 @@ func ConvertToTime(ctx context.Context, v interface{}, t datetimeType) (time.Tim return time.Time{}, ErrConvertingToTimeOutOfRange.New(res.Format(sql.TimestampDatetimeLayout), t.String()) } case sqltypes.Timestamp: - if res.Before(datetimeTypeMinTimestamp) || res.After(datetimeTypeMaxTimestamp) { + if ValidateTimestamp(res) == nil { return time.Time{}, ErrConvertingToTimeOutOfRange.New(res.Format(sql.TimestampDatetimeLayout), t.String()) } } + return res, nil } @@ -338,8 +365,8 @@ func (t datetimeType) ConvertWithoutRangeCheck(ctx context.Context, v interface{ } func parseDatetime(value string) (time.Time, bool) { - for _, fmt := range TimestampDatetimeLayouts { - if t, err := time.Parse(fmt, value); err == nil { + for _, layout := range TimestampDatetimeLayouts { + if t, err := time.Parse(layout, value); err == nil { return t.UTC(), true } } @@ -473,7 +500,16 @@ func (t datetimeType) MinimumTime() time.Time { // ValidateTime receives a time and returns either that time or nil if it's // not a valid time. 
func ValidateTime(t time.Time) interface{} { - if t.After(time.Date(9999, time.December, 31, 23, 59, 59, 999999999, time.UTC)) { + if t.Before(datetimeMinTime) || t.After(datetimeMaxTime) { + return nil + } + return t +} + +// ValidateTimestamp receives a time and returns either that time or nil if it's +// not a valid timestamp. +func ValidateTimestamp(t time.Time) interface{} { + if t.Before(datetimeTypeMinTimestamp) || t.After(datetimeTypeMaxTimestamp) { return nil } return t diff --git a/sql/types/datetime_test.go b/sql/types/datetime_test.go index 26edeb9945..6efc77af6c 100644 --- a/sql/types/datetime_test.go +++ b/sql/types/datetime_test.go @@ -405,3 +405,35 @@ func TestDatetimeZero(t *testing.T) { _, ok = MustCreateDatetimeType(sqltypes.Timestamp, 0).Zero().(time.Time) require.True(t, ok) } + +func TestDatetimeOverflowUnderflow(t *testing.T) { + ctx := sql.NewEmptyContext() + tests := []struct { + typ sql.DatetimeType + val interface{} + expectError bool + }{ + {Timestamp, "1969-12-31 23:59:59", true}, + {Timestamp, "2038-01-19 03:14:08", true}, + {Date, Date.MinimumTime().Format("2006-01-02"), false}, + {Date, Date.MaximumTime().Format("2006-01-02"), false}, + {Datetime, Datetime.MinimumTime().Format("2006-01-02 15:04:05"), false}, + {Datetime, Datetime.MaximumTime().Format("2006-01-02 15:04:05"), false}, + {Timestamp, Timestamp.MinimumTime().Format("2006-01-02 15:04:05"), false}, + {Timestamp, Timestamp.MaximumTime().Format("2006-01-02 15:04:05"), false}, + } + + for _, tt := range tests { + t.Run(tt.typ.String()+"_"+tt.val.(string), func(t *testing.T) { + _, inRange, err := tt.typ.Convert(ctx, tt.val) + + if tt.expectError { + require.True(t, err != nil || inRange == sql.OutOfRange, + "expected error or out-of-range but got neither; err: %v, inRange: %v", err, inRange) + } else { + require.NoError(t, err) + require.Equal(t, sql.InRange, inRange) + } + }) + } +} diff --git a/sql/types/enum.go b/sql/types/enum.go index c01b0de0da..067adc4540 100644 --- 
a/sql/types/enum.go +++ b/sql/types/enum.go @@ -42,7 +42,8 @@ const ( var ( ErrConvertingToEnum = errors.NewKind("value %v is not valid for this Enum") - ErrDataTruncatedForColumn = errors.NewKind("Data truncated for column '%s'") + ErrDataTruncatedForColumn = errors.NewKind("Data truncated for column '%s'") + ErrDataTruncatedForColumnAtRow = errors.NewKind("Data truncated for column '%s' at row %d") enumValueType = reflect.TypeOf(uint16(0)) ) @@ -164,6 +165,10 @@ func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql. switch value := v.(type) { case int: + // MySQL rejects 0 values in strict mode regardless of enum definition + if value == 0 && t.validateStrictMode(ctx) { + return nil, sql.OutOfRange, ErrConvertingToEnum.New(value) + } if _, ok := t.At(value); ok { return uint16(value), sql.InRange, nil } @@ -176,7 +181,10 @@ func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql. case int16: return t.Convert(ctx, int(value)) case uint16: - return t.Convert(ctx, int(value)) + // uint16 values are stored enum indices - allow them without strict mode validation + if _, ok := t.At(int(value)); ok { + return value, sql.InRange, nil + } case int32: return t.Convert(ctx, int(value)) case uint32: @@ -207,6 +215,15 @@ func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql. return nil, sql.InRange, ErrConvertingToEnum.New(v) } +// validateStrictMode checks if STRICT_TRANS_TABLES or STRICT_ALL_TABLES is enabled +func (t EnumType) validateStrictMode(ctx context.Context) bool { + if sqlCtx, ok := ctx.(*sql.Context); ok { + sqlMode := sql.LoadSqlMode(sqlCtx) + return sqlMode.ModeEnabled("STRICT_TRANS_TABLES") || sqlMode.ModeEnabled("STRICT_ALL_TABLES") + } + return false +} + // Equals implements the Type interface. 
func (t EnumType) Equals(otherType sql.Type) bool { if ot, ok := otherType.(EnumType); ok && t.collation.Equals(ot.collation) && len(t.idxToVal) == len(ot.idxToVal) { diff --git a/sql/types/json_encode.go b/sql/types/json_encode.go index 727365b38b..19818d516a 100644 --- a/sql/types/json_encode.go +++ b/sql/types/json_encode.go @@ -210,10 +210,12 @@ func writeMarshalledValue(writer io.Writer, val interface{}) error { writer.Write([]byte{'{'}) for i, k := range keys { - writer.Write([]byte{'"'}) - writer.Write([]byte(k)) - writer.Write([]byte(`": `)) - err := writeMarshalledValue(writer, val[k]) + err := writeMarshalledValue(writer, k) + if err != nil { + return err + } + writer.Write([]byte(`: `)) + err = writeMarshalledValue(writer, val[k]) if err != nil { return err } diff --git a/sql/types/json_encode_test.go b/sql/types/json_encode_test.go index e167dd82d3..4f5190361c 100644 --- a/sql/types/json_encode_test.go +++ b/sql/types/json_encode_test.go @@ -106,6 +106,14 @@ newlines val: decimal.New(123, -2), expected: "1.23", }, + { + name: "formatted key strings", + val: map[string]interface{}{ + "baz\n\\n": "qux", + "foo\"": "bar\t", + }, + expected: `{"foo\"": "bar\t", "baz\n\\n": "qux"}`, + }, } for _, test := range tests { diff --git a/sql/types/null.go b/sql/types/null.go index 5a06746d27..5702f0e1dd 100644 --- a/sql/types/null.go +++ b/sql/types/null.go @@ -34,6 +34,10 @@ var ( type nullType struct{} +func (t nullType) IsNullType() bool { + return true +} + // Compare implements Type interface. Note that while this returns 0 (equals) // for ordering purposes, in SQL NULL != NULL. 
func (t nullType) Compare(s context.Context, a interface{}, b interface{}) (int, error) { diff --git a/sql/types/strings.go b/sql/types/strings.go index be4119680a..55779a0cae 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -496,6 +496,29 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [ return val, nil } +// convertToLongTextString safely converts a value to string using LongText.Convert with nil checking +func convertToLongTextString(ctx context.Context, val interface{}) (string, error) { + converted, _, err := LongText.Convert(ctx, val) + if err != nil { + return "", err + } + if converted == nil { + return "", nil + } + return converted.(string), nil +} + +// convertEnumToString converts an enum value to its string representation +func convertEnumToString(ctx context.Context, val interface{}, enumType sql.EnumType) (string, error) { + if enumVal, ok := val.(uint16); ok { + if enumStr, exists := enumType.At(int(enumVal)); exists { + return enumStr, nil + } + return "", nil + } + return convertToLongTextString(ctx, val) +} + // ConvertToCollatedString returns the given interface as a string, along with its collation. If the Type possess a // collation, then that collation is returned. If the Type does not possess a collation (such as an integer), then the // value is converted to a string and the default collation is used. 
If the value is already a string then no additional @@ -516,20 +539,32 @@ func ConvertToCollatedString(ctx context.Context, val interface{}, typ sql.Type) content = strVal } else if byteVal, ok := val.([]byte); ok { content = encodings.BytesToString(byteVal) + } else if enumType, ok := typ.(sql.EnumType); ok { + // Handle enum types in string context - return the string value, not the index + content, err = convertEnumToString(ctx, val, enumType) + if err != nil { + return "", sql.Collation_Unspecified, err + } } else { - val, _, err = LongText.Convert(ctx, val) + content, err = convertToLongTextString(ctx, val) if err != nil { return "", sql.Collation_Unspecified, err } - content = val.(string) } } else { collation = sql.Collation_Default - val, _, err = LongText.Convert(ctx, val) + // Handle enum types in string context even without collation + if enumType, ok := typ.(sql.EnumType); ok { + content, err = convertEnumToString(ctx, val, enumType) + if err != nil { + return "", sql.Collation_Unspecified, err + } + return content, collation, nil + } + content, err = convertToLongTextString(ctx, val) if err != nil { return "", sql.Collation_Unspecified, err } - content = val.(string) } return content, collation, nil } diff --git a/sql/types/tuple_value.go b/sql/types/tuple_value.go deleted file mode 100644 index 7ec2eef818..0000000000 --- a/sql/types/tuple_value.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2025 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package types - -import "github.com/dolthub/go-mysql-server/sql" - -// TupleValue represents a value and its associated type information. TupleValue is used by collections of -// values where the type information is not consistent across all values (e.g. Records in Postgres). -type TupleValue struct { - Value any - Type sql.Type -} diff --git a/sql/types/typecheck.go b/sql/types/typecheck.go index 5c090d72af..26fd198907 100644 --- a/sql/types/typecheck.go +++ b/sql/types/typecheck.go @@ -106,6 +106,14 @@ func IsNumber(t sql.Type) bool { } } +func IsNullType(t sql.Type) bool { + nt, ok := t.(sql.NullType) + if !ok { + return false + } + return nt.IsNullType() +} + // IsSigned checks if t is a signed type. func IsSigned(t sql.Type) bool { if svt, ok := t.(sql.SystemVariableType); ok {