Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions data/test/vtexplain/insertsharded-output.json
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,14 @@
}
},
{
"Original": "insert ignore into user(id, name, nickname) values (:vtg1, :vtg2, :vtg2)",
"Original": "insert ignore into user(id, name, nickname) values (:vtg1, :vtg2, :vtg3)",
"Instructions": {
"Opcode": "InsertShardedIgnore",
"Keyspace": {
"Name": "ks_sharded",
"Sharded": true
},
"Query": "insert ignore into user(id, name, nickname) values (:_id0, :_name0, :vtg2)",
"Query": "insert ignore into user(id, name, nickname) values (:_id0, :_name0, :vtg3)",
"Values": [
[
":vtg1"
Expand All @@ -339,7 +339,7 @@
"Table": "user",
"Prefix": "insert ignore into user(id, name, nickname) values ",
"Mid": [
"(:_id0, :_name0, :vtg2)"
"(:_id0, :_name0, :vtg3)"
]
}
}
Expand All @@ -348,12 +348,13 @@
"ks_sharded/-80": {
"TabletQueries": [
{
"SQL": "insert ignore into user(id, name, nickname) values (:_id0, :_name0, :vtg2) /* vtgate:: keyspace_id:06e7ea22ce92708f */",
"SQL": "insert ignore into user(id, name, nickname) values (:_id0, :_name0, :vtg3) /* vtgate:: keyspace_id:06e7ea22ce92708f */",
"BindVars": {
"_id0": "2",
"_name0": "'bob'",
"vtg1": "2",
"vtg2": "'bob'"
"vtg2": "'bob'",
"vtg3": "'bob'"
}
}
],
Expand Down Expand Up @@ -562,14 +563,14 @@
}
},
{
"Original": "insert into user(id, name, nickname) values (:vtg1, :vtg2, :vtg3) on duplicate key update nickname = :vtg3",
"Original": "insert into user(id, name, nickname) values (:vtg1, :vtg2, :vtg3) on duplicate key update nickname = :vtg4",
"Instructions": {
"Opcode": "InsertShardedIgnore",
"Keyspace": {
"Name": "ks_sharded",
"Sharded": true
},
"Query": "insert into user(id, name, nickname) values (:_id0, :_name0, :vtg3) on duplicate key update nickname = :vtg3",
"Query": "insert into user(id, name, nickname) values (:_id0, :_name0, :vtg3) on duplicate key update nickname = :vtg4",
"Values": [
[
":vtg1"
Expand All @@ -583,21 +584,22 @@
"Mid": [
"(:_id0, :_name0, :vtg3)"
],
"Suffix": " on duplicate key update nickname = :vtg3"
"Suffix": " on duplicate key update nickname = :vtg4"
}
}
],
"TabletActions": {
"ks_sharded/-80": {
"TabletQueries": [
{
"SQL": "insert into user(id, name, nickname) values (:_id0, :_name0, :vtg3) on duplicate key update nickname = :vtg3 /* vtgate:: keyspace_id:06e7ea22ce92708f */",
"SQL": "insert into user(id, name, nickname) values (:_id0, :_name0, :vtg3) on duplicate key update nickname = :vtg4 /* vtgate:: keyspace_id:06e7ea22ce92708f */",
"BindVars": {
"_id0": "2",
"_name0": "'bob'",
"vtg1": "2",
"vtg2": "'bob'",
"vtg3": "'bobby'"
"vtg3": "'bobby'",
"vtg4": "'bobby'"
}
}
],
Expand Down
220 changes: 146 additions & 74 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,78 +28,150 @@ import (
// updates the bind vars to those values. The supplied prefix
// is used to generate the bind var names. The function ensures
// that there are no collisions with existing bind vars.
// Within Select constructs, bind vars are deduped. This allows
// us to identify vindex equality. Otherwise, every value is
// treated as distinct.
func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) {
reserved := GetBindvars(stmt)
// vals allows us to reuse bindvars for
// identical values.
counter := 1
vals := make(map[string]string)
_ = Walk(func(node SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *SQLVal:
// Make the bindvar
bval := sqlToBindvar(node)
if bval == nil {
// If unsuccessful continue.
return true, nil
}
// Check if there's a bindvar for that value already.
var key string
if bval.Type == sqltypes.VarBinary {
// Prefixing strings with "'" ensures that a string
// and number that have the same representation don't
// collide.
key = "'" + string(node.Val)
} else {
key = string(node.Val)
}
bvname, ok := vals[key]
if !ok {
// If there's no such bindvar, make a new one.
bvname, counter = newName(prefix, counter, reserved)
vals[key] = bvname
bindVars[bvname] = bval
}
// Modify the AST node to a bindvar.
node.Type = ValArg
node.Val = append([]byte(":"), bvname...)
case *ComparisonExpr:
switch node.Operator {
case InStr, NotInStr:
default:
return true, nil
}
// It's either IN or NOT IN.
tupleVals, ok := node.Right.(ValTuple)
if !ok {
return true, nil
}
// The RHS is a tuple of values.
// Make a list bindvar.
bvals := &querypb.BindVariable{
Type: querypb.Type_TUPLE,
}
for _, val := range tupleVals {
bval := sqlToBindvar(val)
if bval == nil {
return true, nil
}
bvals.Values = append(bvals.Values, &querypb.Value{
Type: bval.Type,
Value: bval.Value,
})
}
var bvname string
bvname, counter = newName(prefix, counter, reserved)
bindVars[bvname] = bvals
// Modify RHS to be a list bindvar.
node.Right = ListArg(append([]byte("::"), bvname...))
nz := newNormalizer(stmt, bindVars, prefix)
_ = Walk(nz.WalkStatement, stmt)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this particular change, but it seems like we should be propagating errors in the Normalize traversal out instead of silently swallowing them?

For PII reasons we actually do depend on normalize working properly, so it would be better for us to fail the query with an error than let it go through unnormalized.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at the error conditions. They are all for invalid inputs, which will be caught downstream. So, it's better that normalizer ignores them and completes everything else. That way, any other PII would have been converted to bind vars.

}

type normalizer struct {
stmt Statement
bindVars map[string]*querypb.BindVariable
prefix string
reserved map[string]struct{}
counter int
vals map[string]string
}

func newNormalizer(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) *normalizer {
return &normalizer{
stmt: stmt,
bindVars: bindVars,
prefix: prefix,
reserved: GetBindvars(stmt),
counter: 1,
vals: make(map[string]string),
}
}

// WalkStatement is the top level walk function.
// If it encounters a Select, it switches to a mode
// where variables are deduped.
func (nz *normalizer) WalkStatement(node SQLNode) (bool, error) {
switch node := node.(type) {
case *Select:
_ = Walk(nz.WalkSelect, node)
// Don't continue
return false, nil
case *SQLVal:
nz.convertSQLVal(node)
case *ComparisonExpr:
nz.convertComparison(node)
}
return true, nil
}

// WalkSelect normalizes the AST in Select mode.
func (nz *normalizer) WalkSelect(node SQLNode) (bool, error) {
switch node := node.(type) {
case *SQLVal:
nz.convertSQLValDedup(node)
case *ComparisonExpr:
nz.convertComparison(node)
}
return true, nil
}

func (nz *normalizer) convertSQLValDedup(node *SQLVal) {
// If value is too long, don't dedup.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the reason for this because the comparison is likely to be expensive CPU wise? If so we should add an additional comment indicating why.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

// Such values are most likely not for vindexes.
// We save a lot of CPU because we avoid building
// the key for them.
if len(node.Val) > 256 {
nz.convertSQLVal(node)
return
}

// Make the bindvar
bval := nz.sqlToBindvar(node)
if bval == nil {
return
}

// Check if there's a bindvar for that value already.
var key string
if bval.Type == sqltypes.VarBinary {
// Prefixing strings with "'" ensures that a string
// and number that have the same representation don't
// collide.
key = "'" + string(node.Val)
} else {
key = string(node.Val)
}
bvname, ok := nz.vals[key]
if !ok {
// If there's no such bindvar, make a new one.
bvname = nz.newName()
nz.vals[key] = bvname
nz.bindVars[bvname] = bval
}

// Modify the AST node to a bindvar.
node.Type = ValArg
node.Val = append([]byte(":"), bvname...)
}

// convertSQLVal converts an SQLVal without the dedup.
func (nz *normalizer) convertSQLVal(node *SQLVal) {
bval := nz.sqlToBindvar(node)
if bval == nil {
return
}

bvname := nz.newName()
nz.bindVars[bvname] = bval

node.Type = ValArg
node.Val = append([]byte(":"), bvname...)
}

// convertComparison attempts to convert IN clauses to
// use the list bind var construct. If it fails, it returns
// with no change made. The walk function will then continue
// and iterate on converting each individual value into separate
// bind vars.
func (nz *normalizer) convertComparison(node *ComparisonExpr) {
if node.Operator != InStr && node.Operator != NotInStr {
return
}
tupleVals, ok := node.Right.(ValTuple)
if !ok {
return
}
// The RHS is a tuple of values.
// Make a list bindvar.
bvals := &querypb.BindVariable{
Type: querypb.Type_TUPLE,
}
for _, val := range tupleVals {
bval := nz.sqlToBindvar(val)
if bval == nil {
return
}
return true, nil
}, stmt)
bvals.Values = append(bvals.Values, &querypb.Value{
Type: bval.Type,
Value: bval.Value,
})
}
bvname := nz.newName()
nz.bindVars[bvname] = bvals
// Modify RHS to be a list bindvar.
node.Right = ListArg(append([]byte("::"), bvname...))
}

func sqlToBindvar(node SQLNode) *querypb.BindVariable {
func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following up on the above comment re error propagation, this should also return an error if the value should have been converted to a bind var but failed for some reason.

if node, ok := node.(*SQLVal); ok {
var v sqltypes.Value
var err error
Expand All @@ -121,14 +193,14 @@ func sqlToBindvar(node SQLNode) *querypb.BindVariable {
return nil
}

func newName(prefix string, counter int, reserved map[string]struct{}) (string, int) {
func (nz *normalizer) newName() string {
for {
newName := fmt.Sprintf("%s%d", prefix, counter)
if _, ok := reserved[newName]; !ok {
reserved[newName] = struct{}{}
return newName, counter + 1
newName := fmt.Sprintf("%s%d", nz.prefix, nz.counter)
if _, ok := nz.reserved[newName]; !ok {
nz.reserved[newName] = struct{}{}
return newName
}
counter++
nz.counter++
}
}

Expand Down
51 changes: 51 additions & 0 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package sqlparser

import (
"fmt"
"reflect"
"testing"

Expand Down Expand Up @@ -81,6 +82,56 @@ func TestNormalize(t *testing.T) {
"bv1": sqltypes.Int64BindVariable(1),
"bv2": sqltypes.BytesBindVariable([]byte("1")),
},
}, {
// val should not be reused for non-select statements
in: "insert into a values(1, 1)",
outstmt: "insert into a values (:bv1, :bv2)",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(1),
"bv2": sqltypes.Int64BindVariable(1),
},
}, {
// val should be reused only in subqueries of DMLs
in: "update a set v1=(select 5 from t), v2=5, v3=(select 5 from t), v4=5",
outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv2, v3 = (select :bv1 from t), v4 = :bv3",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(5),
"bv2": sqltypes.Int64BindVariable(5),
"bv3": sqltypes.Int64BindVariable(5),
},
}, {
// list vars should work for DMLs also
in: "update a set v1=5 where v2 in (1, 4, 5)",
outstmt: "update a set v1 = :bv1 where v2 in ::bv2",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(5),
"bv2": sqltypes.TestBindVariable([]interface{}{1, 4, 5}),
},
}, {
// Hex value does not convert
in: "select * from t where v1 = 0x1234",
outstmt: "select * from t where v1 = 0x1234",
outbv: map[string]*querypb.BindVariable{},
}, {
// Hex value does not convert for DMLs
in: "update a set v1 = 0x1234",
outstmt: "update a set v1 = 0x1234",
outbv: map[string]*querypb.BindVariable{},
}, {
// Values up to len 256 will reuse.
in: fmt.Sprintf("select * from t where v1 = '%256s' and v2 = '%256s'", "a", "a"),
outstmt: "select * from t where v1 = :bv1 and v2 = :bv1",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%256s", "a"))),
},
}, {
// Values greater than len 256 will not reuse.
in: fmt.Sprintf("select * from t where v1 = '%257s' and v2 = '%257s'", "b", "b"),
outstmt: "select * from t where v1 = :bv1 and v2 = :bv2",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%257s", "b"))),
"bv2": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%257s", "b"))),
},
}, {
// bad int
in: "select * from t where v1 = 12345678901234567890",
Expand Down