Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions go/sqltypes/parse_rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sqltypes

import (
"fmt"
"io"
"reflect"
"strconv"
"strings"
"text/scanner"

querypb "vitess.io/vitess/go/vt/proto/query"
)

// ParseRows parses the output generated by fmt.Sprintf("#v", rows), and reifies the original []sqltypes.Row
// NOTE: This is not meant for production use!
func ParseRows(input string) ([]Row, error) {
type state int
const (
stInvalid state = iota
stInit
stBeginRow
stInRow
stInValue0
stInValue1
stInValue2
)

var (
scan scanner.Scanner
result []Row
row Row
vtype int32
st = stInit
)

scan.Init(strings.NewReader(input))

for tok := scan.Scan(); tok != scanner.EOF; tok = scan.Scan() {
var next state

switch st {
case stInit:
if tok == '[' {
next = stBeginRow
}
case stBeginRow:
switch tok {
case '[':
next = stInRow
case ']':
return result, nil
}
case stInRow:
switch tok {
case ']':
result = append(result, row)
row = nil
next = stBeginRow
case scanner.Ident:
ident := scan.TokenText()

if ident == "NULL" {
row = append(row, NULL)
continue
}

var ok bool
vtype, ok = querypb.Type_value[ident]
if !ok {
return nil, fmt.Errorf("unknown SQL type %q at %s", ident, scan.Position)
}
next = stInValue0
}
case stInValue0:
if tok == '(' {
next = stInValue1
}
case stInValue1:
literal := scan.TokenText()
switch tok {
case scanner.String:
var err error
literal, err = strconv.Unquote(literal)
if err != nil {
return nil, fmt.Errorf("failed to parse literal string at %s: %w", scan.Position, err)
}
fallthrough
case scanner.Int, scanner.Float:
row = append(row, MakeTrusted(Type(vtype), []byte(literal)))
next = stInValue2
}
case stInValue2:
if tok == ')' {
next = stInRow
}
}
if next == stInvalid {
return nil, fmt.Errorf("unexpected token '%s' at %s", scan.TokenText(), scan.Position)
}
st = next
}
return nil, io.ErrUnexpectedEOF
}

type RowMismatchError struct {
err error
want, got []Row
}

func (e *RowMismatchError) Error() string {
return fmt.Sprintf("results differ: %v\n\twant: %v\n\tgot: %v", e.err, e.want, e.got)
}

func RowsEquals(want, got []Row) error {
if len(want) != len(got) {
return &RowMismatchError{
err: fmt.Errorf("expected %d rows in result, got %d", len(want), len(got)),
want: want,
got: got,
}
}

var matched = make([]bool, len(want))
for _, aa := range want {
var ok bool
for i, bb := range got {
if matched[i] {
continue
}
if reflect.DeepEqual(aa, bb) {
matched[i] = true
ok = true
break
}
}
if !ok {
return &RowMismatchError{
err: fmt.Errorf("row %v is missing from result", aa),
want: want,
got: got,
}
}
}
for _, m := range matched {
if !m {
return fmt.Errorf("not all elements matched")
}
}
return nil
}

func RowsEqualsStr(wantStr string, got []Row) error {
want, err := ParseRows(wantStr)
if err != nil {
return fmt.Errorf("malformed row assertion: %w", err)
}
return RowsEquals(want, got)
}
187 changes: 187 additions & 0 deletions go/sqltypes/parse_rows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sqltypes

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"
)

var TestRows = []string{
"[]",
"[[INT64(1)]]",
"[[DECIMAL(6)]]",
"[[DECIMAL(5)]]",
"[[DECIMAL(6)]]",
"[[DECIMAL(8)]]",
"[[NULL]]",
"[[INT32(1) INT64(2) INT64(420)]]",
"[[INT32(1) INT64(2) INT64(420)]]",
"[[INT32(1) INT64(2) INT64(420)] [INT32(2) INT64(4) INT64(420)] [INT32(3) INT64(6) INT64(420)]]",
"[[INT64(3) INT64(420)]]",
"[[INT32(1) INT64(2) INT64(420)]]",
"[[INT32(1) INT64(2) INT64(420)]]",
"[[INT64(666) INT64(20) INT64(420)]]",
"[[INT64(4)]]",
"[[INT64(12) DECIMAL(7900)]]",
"[[INT64(3) INT64(4)]]",
"[[INT32(3)]]",
"[[INT32(2)]]",
"[[INT64(3) INT64(4)]]",
"[[INT32(100) INT64(1) INT64(2)] [INT32(200) INT64(1) INT64(1)] [INT32(300) INT64(1) INT64(1)]]",
"[[INT64(1) INT64(1)]]",
"[[INT64(0) INT64(0)]]",
"[[DECIMAL(2.0000)]]",
"[[INT32(100) DECIMAL(1.0000)] [INT32(200) DECIMAL(2.0000)] [INT32(300) DECIMAL(3.0000)]]",
"[[INT64(3) DECIMAL(2.0000)]]",
"[[INT64(3) INT64(4)]]",
"[[INT32(100) INT64(1) INT64(2)] [INT32(200) INT64(1) INT64(1)] [INT32(300) INT64(1) INT64(1)]]",
"[[INT64(1) INT64(1)]]",
"[[DECIMAL(6)]]",
"[[FLOAT64(6)]]",
"[[INT32(3)]]",
"[[FLOAT64(3)]]",
"[[INT32(1)]]",
"[[FLOAT64(1)]]",
"[[DECIMAL(6) FLOAT64(1)]]",
"[[INT32(2) DECIMAL(14)]]",
"[[INT32(3) INT32(9)] [INT32(2) INT32(4)] [INT32(1) INT32(1)]]",
"[[INT32(3) INT32(9)] [INT32(2) INT32(4)] [INT32(1) INT32(1)]]",
"[[INT32(1) INT64(20)] [INT32(1) INT64(10)] [INT32(4) INT64(20)] [INT32(2) INT64(10)] [INT32(9) INT64(20)] [INT32(3) INT64(10)]]",
"[[INT32(2) INT32(4)]]",
"[[INT32(5) INT32(4)] [INT32(3) INT32(9)] [INT32(2) INT32(4)] [INT32(1) INT32(1)]]",
"[[INT32(5) INT32(4)] [INT32(3) INT32(9)] [INT32(2) INT32(4)] [INT32(1) INT32(1)]]",
"[[INT32(2) INT32(4)] [INT32(5) INT32(4)]]",
"[[INT64(2) DECIMAL(2)] [INT64(1) DECIMAL(0)]]",
"[[INT64(1) INT64(2)]]",
"[[INT64(1) INT64(2)] [INT64(1) INT64(4)]]",
"[[INT64(1) INT64(2)] [INT64(1) INT64(4)]]",
"[[INT64(1) INT64(4)]]",
"[[INT64(1) INT64(4)]]",
"[[INT64(1) INT64(3)]]",
"[[INT64(1) INT64(3)]]",
"[[INT64(1) VARCHAR(\"Article 1\") INT64(10)]]",
"[[INT64(2) VARCHAR(\"Article 2\") INT64(10)]]",
"[[INT64(1) VARCHAR(\"Article 1\") INT64(10)]]",
"[[INT64(2) VARCHAR(\"Article 2\") INT64(10)]]",
"[[VARCHAR(\"albumQ\") INT32(4)] [VARCHAR(\"albumY\") INT32(1)] [VARCHAR(\"albumY\") INT32(2)] [VARCHAR(\"albumX\") INT32(2)] [VARCHAR(\"albumX\") INT32(3)] [VARCHAR(\"albumX\") INT32(1)]]",
"[[VARCHAR(\"albumQ\") INT32(4)] [VARCHAR(\"albumY\") INT32(1)] [VARCHAR(\"albumY\") INT32(2)] [VARCHAR(\"albumX\") INT32(2)] [VARCHAR(\"albumX\") INT32(3)] [VARCHAR(\"albumX\") INT32(1)]]",
"[[INT64(2)]]",
"[[INT32(1) INT32(100)]]",
"[[INT32(1) INT32(100)]]",
"[[INT64(2)]]",
"[[INT64(2)]]",
"[[INT64(2)]]",
"[[INT64(2)]]",
"[[UINT32(70)]]",
"[[INT64(1) VARCHAR(\"Article 1\") INT64(20)]]",
"[[INT64(2) VARCHAR(\"Article 2\") INT64(20)]]",
"[[INT64(1) VARCHAR(\"Article 1\") INT64(20)]]",
"[[INT64(2) VARCHAR(\"Article 2\") INT64(20)]]",
"[[INT64(1) VARCHAR(\"Article 1\") INT64(10)]]",
"[[INT64(2) VARCHAR(\"Article 2\") INT64(10)]]",
"[]",
"[]",
"[[INT64(1) NULL] [INT64(2) INT64(2)]]",
"[[INT64(1) INT64(1)] [INT64(2) NULL]]",
"[[INT64(1) INT64(1)]]",
"[[INT64(1) INT64(8)] [INT64(1) INT64(9)]]",
"[[INT64(1)] [INT64(2)]]",
"[[INT64(1)]]",
"[[INT64(4)] [INT64(8)] [INT64(12)]]",
"[[INT64(1)]]",
"[[INT64(1)]]",
"[[INT64(1)]]",
"[]",
"[]",
"[[INT64(1)]]",
"[[INT64(1)]]",
"[[INT64(1)]]",
"[]",
"[]",
"[[INT64(1)]]",
"[[INT64(1)]]",
"[[DECIMAL(2) INT64(1)]]",
"[[NULL INT64(0)]]",
"[[DECIMAL(420) INT64(1)]]",
"[[DECIMAL(420) INT64(1)]]",
"[[NULL INT64(0)]]",
"[]",
"[[NULL INT64(0)]]",
"[]",
"[[DECIMAL(3) INT64(3)]]",
"[[DECIMAL(2) INT64(1)] [DECIMAL(1) INT64(1)] [DECIMAL(0) INT64(1)]]",
"[[NULL INT64(0)]]",
"[]",
"[[DECIMAL(423) INT64(4)]]",
"[[DECIMAL(423) INT64(4)]]",
"[[DECIMAL(420) INT64(1)] [DECIMAL(2) INT64(1)] [DECIMAL(1) INT64(1)] [DECIMAL(0) INT64(1)]]",
"[[DECIMAL(420) INT64(1)]]",
"[[DECIMAL(420) INT64(1)]]",
"[[INT64(1) INT64(2)]]",
"[[INT64(1) INT64(2)] [INT64(1) INT64(4)]]",
"[[INT64(1) INT64(4)]]",
"[[INT64(1) INT64(2)]]",
"[[INT64(1) INT64(2)] [INT64(1) INT64(4)]]",
"[[INT64(1) INT64(4)]]",
"[[INT64(1) INT64(2)]]",
"[[INT64(1) INT64(2)] [INT64(1) INT64(4)]]",
"[[INT64(1) INT64(4)]]",
"[[INT64(1) INT64(2)]]",
"[[INT64(1) INT64(2)] [INT64(1) INT64(4)]]",
"[[INT64(1) INT64(4)]]",
"[[INT64(1) INT64(2)]]",
"[[INT64(1) INT64(2)] [INT64(1) INT64(4)]]",
"[[INT64(1) INT64(4)]]",
"[[INT64(1) INT64(2)]]",
"[[INT64(1) INT64(2)]]",
"[[INT64(2) INT64(4)]]",
"[[INT64(2) INT64(4)]]",
"[[INT64(1) INT64(1)] [INT64(1) INT64(2)] [INT64(1) INT64(3)]]",
"[[INT64(1) INT64(1)] [INT64(1) INT64(2)] [INT64(1) INT64(3)] [INT64(1) INT64(4)] [INT64(1) INT64(5)] [INT64(1) INT64(6)]]",
"[[INT64(1) INT64(1)] [INT64(1) INT64(2)] [INT64(1) INT64(3)]]",
}

func TestRowParsing(t *testing.T) {
for _, r := range TestRows {
output, err := ParseRows(r)
require.NoError(t, err)
outputstr := fmt.Sprintf("%v", output)
require.Equal(t, r, outputstr, "did not roundtrip")
}
}

func TestRowsEquals(t *testing.T) {
var cases = []struct {
left, right string
}{
{"[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]", "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]"},
}

for _, tc := range cases {
left, err := ParseRows(tc.left)
require.NoError(t, err)

right, err := ParseRows(tc.right)
require.NoError(t, err)

err = RowsEquals(left, right)
require.NoError(t, err)
}
}
11 changes: 6 additions & 5 deletions go/test/endtoend/utils/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/utils"
)

type MySQLCompare struct {
Expand Down Expand Up @@ -132,8 +131,9 @@ func (mcmp *MySQLCompare) AssertContainsError(query, expected string) {
func (mcmp *MySQLCompare) AssertMatchesNoOrder(query, expected string) {
mcmp.t.Helper()
qr := mcmp.Exec(query)
actual := fmt.Sprintf("%v", qr.Rows)
assert.Equal(mcmp.t, utils.SortString(expected), utils.SortString(actual), "for query: [%s] expected \n%s \nbut actual \n%s", query, expected, actual)
if err := sqltypes.RowsEqualsStr(expected, qr.Rows); err != nil {
mcmp.t.Errorf("for query [%s] %v", query, err)
}
}

// AssertMatchesNoOrderInclColumnNames executes the given query against both Vitess and MySQL.
Expand All @@ -142,8 +142,9 @@ func (mcmp *MySQLCompare) AssertMatchesNoOrder(query, expected string) {
func (mcmp *MySQLCompare) AssertMatchesNoOrderInclColumnNames(query, expected string) {
mcmp.t.Helper()
qr := mcmp.ExecWithColumnCompare(query)
actual := fmt.Sprintf("%v", qr.Rows)
assert.Equal(mcmp.t, utils.SortString(expected), utils.SortString(actual), "for query: [%s] expected \n%s \nbut actual \n%s", query, expected, actual)
if err := sqltypes.RowsEqualsStr(expected, qr.Rows); err != nil {
mcmp.t.Errorf("for query [%s] %v", query, err)
}
}

// AssertIsEmpty executes the given query against both Vitess and MySQL and ensures
Expand Down
Loading