diff --git a/contrib/drivers/mssql/mssql_do_filter.go b/contrib/drivers/mssql/mssql_do_filter.go index 8cf790e4faa..96dd3709c9c 100644 --- a/contrib/drivers/mssql/mssql_do_filter.go +++ b/contrib/drivers/mssql/mssql_do_filter.go @@ -18,7 +18,8 @@ import ( ) var ( - selectSqlTmp = `SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ` + orderBySqlTmp = `SELECT %s %s OFFSET %d ROWS FETCH NEXT %d ROWS ONLY` + withoutOrderBySqlTmp = `SELECT %s OFFSET %d ROWS FETCH NEXT %d ROWS ONLY` selectWithOrderSqlTmp = ` SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d @@ -78,89 +79,55 @@ func (d *Driver) parseSql(toBeCommittedSql string) (string, error) { func (d *Driver) handleSelectSqlReplacement(toBeCommittedSql string) (newSql string, err error) { // SELECT * FROM USER WHERE ID=1 LIMIT 1 - match, err := gregex.MatchString(`^SELECT(.+)LIMIT 1$`, toBeCommittedSql) + match, err := gregex.MatchString(`^SELECT(.+?)LIMIT\s+1$`, toBeCommittedSql) if err != nil { return "", err } if len(match) > 1 { - return fmt.Sprintf(`SELECT TOP 1 %s`, match[1]), nil + return fmt.Sprintf(`SELECT TOP 1 %s`, strings.TrimSpace(match[1])), nil } // SELECT * FROM USER WHERE AGE>18 ORDER BY ID DESC LIMIT 100, 200 - patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))` - if gregex.IsMatchString(patten, toBeCommittedSql) == false { + pattern := `(?i)SELECT(.+?)(ORDER BY.+?)?\s*LIMIT\s*(\d+)(?:\s*,\s*(\d+))?` + if !gregex.IsMatchString(pattern, toBeCommittedSql) { return toBeCommittedSql, nil } - allMatch, err := gregex.MatchAllString(patten, toBeCommittedSql) + + allMatch, err := gregex.MatchString(pattern, toBeCommittedSql) if err != nil { return "", err } - var index = 1 - // LIMIT statement checks. - if len(allMatch) < 2 || - (strings.HasPrefix(allMatch[index][0], "LIMIT") == false && - strings.HasPrefix(allMatch[index][0], "limit") == false) { - return toBeCommittedSql, nil - } - if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", toBeCommittedSql) == false { - return toBeCommittedSql, nil + + // SELECT and ORDER BY + selectStr := strings.TrimSpace(allMatch[1]) + orderStr := "" + if len(allMatch[2]) > 0 { + orderStr = strings.TrimSpace(allMatch[2]) } - // ORDER BY statement checks. - var ( - selectStr = "" - orderStr = "" - haveOrder = gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", toBeCommittedSql) - ) - if haveOrder { - queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", toBeCommittedSql) - if len(queryExpr) != 4 || - strings.EqualFold(queryExpr[1], "SELECT") == false || - strings.EqualFold(queryExpr[3], "ORDER BY") == false { - return toBeCommittedSql, nil - } - selectStr = queryExpr[2] - orderExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", toBeCommittedSql) - if len(orderExpr) != 4 || - strings.EqualFold(orderExpr[1], "ORDER BY") == false || - strings.EqualFold(orderExpr[3], "LIMIT") == false { - return toBeCommittedSql, nil - } - orderStr = orderExpr[2] + + // LIMIT and OFFSET value + first, _ := strconv.Atoi(allMatch[3]) // LIMIT first parameter + limit := 0 + if len(allMatch) > 4 && allMatch[4] != "" { + limit, _ = strconv.Atoi(allMatch[4]) // LIMIT second parameter } else { - queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", toBeCommittedSql) - if len(queryExpr) != 4 || - strings.EqualFold(queryExpr[1], "SELECT") == false || - strings.EqualFold(queryExpr[3], "LIMIT") == false { - return toBeCommittedSql, nil - } - selectStr = queryExpr[2] + limit = first + first = 0 } - first, limit := 0, 0 - for i := 1; i < len(allMatch[index]); i++ { - if len(strings.TrimSpace(allMatch[index][i])) == 0 { - continue - } - if strings.HasPrefix(allMatch[index][i], "LIMIT") || - strings.HasPrefix(allMatch[index][i], "limit") { - first, _ = strconv.Atoi(allMatch[index][i+1]) - limit, _ = strconv.Atoi(allMatch[index][i+2]) - break - } - } - if haveOrder { - toBeCommittedSql = fmt.Sprintf( - selectWithOrderSqlTmp, - orderStr, selectStr, first, first+limit, + + if orderStr != "" { + // have ORDER BY clause + newSql = fmt.Sprintf( + orderBySqlTmp, + selectStr, orderStr, first, limit, + ) + } else { + // without ORDER BY clause + newSql = fmt.Sprintf( + withoutOrderBySqlTmp, + selectStr, first, limit, ) - return toBeCommittedSql, nil } - if first == 0 { - first = limit - } - toBeCommittedSql = fmt.Sprintf( - selectSqlTmp, - limit, first+limit, selectStr, - ) - return toBeCommittedSql, nil + return newSql, nil } diff --git a/contrib/drivers/mssql/mssql_do_filter_test.go b/contrib/drivers/mssql/mssql_do_filter_test.go new file mode 100644 index 00000000000..91ded572382 --- /dev/null +++ b/contrib/drivers/mssql/mssql_do_filter_test.go @@ -0,0 +1,132 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package mssql + +import ( + "context" + "reflect" + "testing" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/test/gtest" +) + +func TestDriver_DoFilter(t *testing.T) { + type fields struct { + Core *gdb.Core + } + type args struct { + ctx context.Context + link gdb.Link + sql string + args []interface{} + } + var tests []struct { + name string + fields fields + args args + wantNewSql string + wantNewArgs []interface{} + wantErr bool + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &Driver{ + Core: tt.fields.Core, + } + gotNewSql, gotNewArgs, err := d.DoFilter(tt.args.ctx, tt.args.link, tt.args.sql, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("DoFilter() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotNewSql != tt.wantNewSql { + t.Errorf("DoFilter() gotNewSql = %v, want %v", gotNewSql, tt.wantNewSql) + } + if !reflect.DeepEqual(gotNewArgs, tt.wantNewArgs) { + t.Errorf("DoFilter() gotNewArgs = %v, want %v", gotNewArgs, tt.wantNewArgs) + } + }) + } +} + +func TestDriver_handleSelectSqlReplacement(t *testing.T) { + + gtest.C(t, func(t *gtest.T) { + d := &Driver{} + + // LIMIT 1 + inputSql := "SELECT * FROM User WHERE ID = 1 LIMIT 1" + expectedSql := "SELECT TOP 1 * FROM User WHERE ID = 1" + resultSql, err := d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // LIMIT query with offset and number of rows + inputSql = "SELECT * FROM User ORDER BY ID DESC LIMIT 100, 200" + expectedSql = "SELECT * FROM User ORDER BY ID DESC OFFSET 100 ROWS FETCH NEXT 200 ROWS ONLY" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // Simple query with no LIMIT + inputSql = "SELECT * FROM User WHERE age > 18" + expectedSql = "SELECT * FROM User WHERE age > 18" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // without LIMIT + inputSql = "SELECT * FROM User ORDER BY ID DESC" + expectedSql = "SELECT * FROM User ORDER BY ID DESC" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // LIMIT query with only rows + inputSql = "SELECT * FROM User LIMIT 50" + expectedSql = "SELECT * FROM User OFFSET 0 ROWS FETCH NEXT 50 ROWS ONLY" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // LIMIT query without ORDER BY + inputSql = "SELECT * FROM User LIMIT 30" + expectedSql = "SELECT * FROM User OFFSET 0 ROWS FETCH NEXT 30 ROWS ONLY" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // Complex query with ORDER BY and LIMIT + inputSql = "SELECT name, age FROM User WHERE age > 18 ORDER BY age ASC LIMIT 10, 5" + expectedSql = "SELECT name, age FROM User WHERE age > 18 ORDER BY age ASC OFFSET 10 ROWS FETCH NEXT 5 ROWS ONLY" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // Complex conditional queries have limits + inputSql = "SELECT * FROM User WHERE age > 18 AND status = 'active' LIMIT 100, 50" + expectedSql = "SELECT * FROM User WHERE age > 18 AND status = 'active' OFFSET 100 ROWS FETCH NEXT 50 ROWS ONLY" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // A LIMIT query that contains subquery + inputSql = "SELECT * FROM (SELECT * FROM User WHERE age > 18) AS subquery LIMIT 10" + expectedSql = "SELECT * FROM (SELECT * FROM User WHERE age > 18) AS subquery OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + // Queries with complex ORDER BY and LIMIT + inputSql = "SELECT name, age FROM User WHERE age > 18 ORDER BY age DESC, name ASC LIMIT 20, 10" + expectedSql = "SELECT name, age FROM User WHERE age > 18 ORDER BY age DESC, name ASC OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY" + resultSql, err = d.handleSelectSqlReplacement(inputSql) + t.AssertNil(err) + t.Assert(resultSql, expectedSql) + + }) +}