Skip to content

Commit

Permalink
Respond to PR feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicktobey committed Aug 27, 2024
1 parent 5bea118 commit afcbe28
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
40 changes: 23 additions & 17 deletions enginetest/queries/generated_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ var GeneratedColumnTests = []ScriptTest{
{
Name: "creating unique index on stored generated column",
SetUpScript: []string{
"create table t1 (a int primary key, b int as (a + 1) stored)",
"insert into t1(a) values (1), (2)",
"create table t1 (a int primary key, b int as (a * a) stored)",
"insert into t1(a) values (-1), (-2)",
},
Assertions: []ScriptTestAssertion{
{
Expand All @@ -288,22 +288,26 @@ var GeneratedColumnTests = []ScriptTest{
Expected: []sql.Row{{"t1",
"CREATE TABLE `t1` (\n" +
" `a` int NOT NULL,\n" +
" `b` int GENERATED ALWAYS AS ((`a` + 1)) STORED,\n" +
" `b` int GENERATED ALWAYS AS ((`a` * `a`)) STORED,\n" +
" PRIMARY KEY (`a`),\n" +
" UNIQUE KEY `i1` (`b`)\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "select * from t1 where b = 2 order by a",
Expected: []sql.Row{{1, 2}},
Query: "select * from t1 where b = 4 order by a",
Expected: []sql.Row{{-2, 4}},
},
{
Query: "select * from t1 order by a",
Expected: []sql.Row{{1, 2}, {2, 3}},
Expected: []sql.Row{{-2, 4}, {-1, 1}},
},
{
Query: "select * from t1 order by b",
Expected: []sql.Row{{1, 2}, {2, 3}},
Expected: []sql.Row{{-1, 1}, {-2, 4}},
},
{
Query: "insert into t1(a) values (2)",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
},
},
Expand Down Expand Up @@ -332,7 +336,6 @@ var GeneratedColumnTests = []ScriptTest{
{
Query: "select * from t1 where b = 2 order by a",
Expected: []sql.Row{{1, 2}},
Skip: true, // https://github.com/dolthub/dolt/issues/8276
},
{
Query: "select * from t1 order by a",
Expand Down Expand Up @@ -404,7 +407,6 @@ var GeneratedColumnTests = []ScriptTest{
{
Query: "select * from t1 where b = 2 order by a",
Expected: []sql.Row{{1, float64(2)}},
Skip: true, // https://github.com/dolthub/dolt/issues/8276
},
{
Query: "select * from t1 order by a",
Expand Down Expand Up @@ -1086,7 +1088,6 @@ var GeneratedColumnTests = []ScriptTest{
{
Query: "select * from t1 where b = 2 order by a",
Expected: []sql.Row{{1, float64(2)}},
Skip: true, // https://github.com/dolthub/dolt/issues/8276
},
{
Query: "select * from t1 order by a",
Expand Down Expand Up @@ -1139,8 +1140,8 @@ var GeneratedColumnTests = []ScriptTest{
{
Name: "creating unique index on virtual generated column",
SetUpScript: []string{
"create table t1 (a int primary key, b int as (a + 1) virtual)",
"insert into t1(a) values (1), (2)",
"create table t1 (a int primary key, b int as (a * a) virtual)",
"insert into t1(a) values (-1), (-2)",
},
Assertions: []ScriptTestAssertion{
{
Expand All @@ -1152,23 +1153,28 @@ var GeneratedColumnTests = []ScriptTest{
Expected: []sql.Row{{"t1",
"CREATE TABLE `t1` (\n" +
" `a` int NOT NULL,\n" +
" `b` int GENERATED ALWAYS AS ((`a` + 1)),\n" +
" `b` int GENERATED ALWAYS AS ((`a` * `a`)),\n" +
" PRIMARY KEY (`a`),\n" +
" KEY `i1` (`b`)\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
Skip: true, // https://github.com/dolthub/dolt/issues/8275
},
{
Query: "select * from t1 where b = 2 order by a",
Expected: []sql.Row{{1, 2}},
Query: "select * from t1 where b = 4 order by a",
Expected: []sql.Row{{-2, 4}},
},
{
Query: "select * from t1 order by a",
Expected: []sql.Row{{1, 2}, {2, 3}},
Expected: []sql.Row{{-2, 4}, {-1, 1}},
},
{
Query: "select * from t1 order by b",
Expected: []sql.Row{{1, 2}, {2, 3}},
Expected: []sql.Row{{-1, 1}, {-2, 4}},
},
{
Query: "insert into t1(a) values (2)",
ExpectedErr: sql.ErrUniqueKeyViolation,
Skip: true,
},
},
},
Expand Down
26 changes: 14 additions & 12 deletions memory/table_editor.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
type tableEditor struct {
editedTable *Table
initialTable *Table
schema sql.Schema

discardChanges bool
ea tableEditAccumulator
Expand Down Expand Up @@ -314,12 +313,15 @@ func (t *tableEditor) pkColumnIndexes() []int {

func (t *tableEditor) pkColsDiffer(row, row2 sql.Row) bool {
pkColIdxes := t.pkColumnIndexes()
return !columnsMatch(pkColIdxes, nil, row, row2)
return !columnsMatch(pkColIdxes, nil, row, row2, t.Schema())
}

// Returns whether the values for the columns given match in the two rows provided
func columnsMatch(colIndexes []int, prefixLengths []uint16, row sql.Row, row2 sql.Row) bool {
func columnsMatch(colIndexes []int, prefixLengths []uint16, row sql.Row, row2 sql.Row, schema sql.Schema) bool {
for i, idx := range colIndexes {
if schema[idx].Virtual {
continue
}
v1 := row[idx]
v2 := row2[idx]
if len(prefixLengths) > i && prefixLengths[i] > 0 {
Expand Down Expand Up @@ -446,7 +448,7 @@ func (pke *pkTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) {
pkColIdxes := pke.pkColumnIndexes()
for _, partition := range pke.tableData.partitions {
for _, partitionRow := range partition {
if columnsMatch(pkColIdxes, nil, partitionRow, value) {
if columnsMatch(pkColIdxes, nil, partitionRow, value, pke.tableData.schema.Schema) {
return partitionRow, true, nil
}
}
Expand All @@ -459,20 +461,20 @@ func (pke *pkTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) {
func (pke *pkTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefixLengths []uint16) (sql.Row, bool, error) {
// If we have this row in any delete, bail.
if _, _, exists := pke.deletes.FindForeach(func(key string, r sql.Row) bool {
return columnsMatch(cols, prefixLengths, r, value)
return columnsMatch(cols, prefixLengths, r, value, pke.tableData.schema.Schema)
}); exists {
return nil, false, nil
}

if _, r, exists := pke.adds.FindForeach(func(key string, r sql.Row) bool {
return columnsMatch(cols, prefixLengths, r, value)
return columnsMatch(cols, prefixLengths, r, value, pke.tableData.schema.Schema)
}); exists {
return r, true, nil
}

for _, partition := range pke.tableData.partitions {
for _, partitionRow := range partition {
if columnsMatch(cols, prefixLengths, partitionRow, value) {
if columnsMatch(cols, prefixLengths, partitionRow, value, pke.tableData.schema.Schema) {
return partitionRow, true, nil
}
}
Expand Down Expand Up @@ -541,7 +543,7 @@ func (pke *pkTableEditAccumulator) deleteHelper(table *TableData, row sql.Row) e
// have the row to be replaced, so we need to consider primary key information.
pkColIdxes := pke.pkColumnIndexes()
if len(pkColIdxes) > 0 {
if columnsMatch(pkColIdxes, nil, partitionRow, row) {
if columnsMatch(pkColIdxes, nil, partitionRow, row, pke.tableData.schema.Schema) {
table.partitions[partName] = append(partition[:partitionRowIndex], partition[partitionRowIndex+1:]...)
partKey = partName
rowIdx = partitionRowIndex
Expand Down Expand Up @@ -608,7 +610,7 @@ func (pke *pkTableEditAccumulator) insertHelper(table *TableData, row sql.Row) e
if len(pkColIdxes) > 0 {
for partitionIndex, partition := range table.partitions {
for partitionRowIndex, partitionRow := range partition {
if columnsMatch(pkColIdxes, nil, partitionRow, row) {
if columnsMatch(pkColIdxes, nil, partitionRow, row, pke.tableData.schema.Schema) {
// Instead of throwing a unique key error, we perform an update operation to essentially represent
// map semantics for the keyed table.
savedPartitionIndex = partitionIndex
Expand Down Expand Up @@ -714,14 +716,14 @@ func (k *keylessTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error)
func (k *keylessTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefixLengths []uint16) (sql.Row, bool, error) {
deleteCount := 0
for _, r := range k.deletes {
if columnsMatch(cols, prefixLengths, r, value) {
if columnsMatch(cols, prefixLengths, r, value, k.tableData.schema.Schema) {
deleteCount++
}
}

for _, partition := range k.tableData.partitions {
for _, partitionRow := range partition {
if columnsMatch(cols, prefixLengths, partitionRow, value) {
if columnsMatch(cols, prefixLengths, partitionRow, value, k.tableData.schema.Schema) {
if deleteCount == 0 {
return partitionRow, true, nil
}
Expand All @@ -731,7 +733,7 @@ func (k *keylessTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefi
}

for _, r := range k.adds {
if columnsMatch(cols, prefixLengths, r, value) {
if columnsMatch(cols, prefixLengths, r, value, nil) {
if deleteCount == 0 {
return r, true, nil
}
Expand Down

0 comments on commit afcbe28

Please sign in to comment.