From afcbe2823e0546a623dcf40d34783141cf4e9060 Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Tue, 27 Aug 2024 14:29:04 -0700 Subject: [PATCH] Respond to PR feedback. --- enginetest/queries/generated_columns.go | 40 ++++++++++++++----------- memory/table_editor.go | 26 ++++++++-------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/enginetest/queries/generated_columns.go b/enginetest/queries/generated_columns.go index 72ce1fc124..d15c823f2c 100644 --- a/enginetest/queries/generated_columns.go +++ b/enginetest/queries/generated_columns.go @@ -275,8 +275,8 @@ var GeneratedColumnTests = []ScriptTest{ { Name: "creating unique index on stored generated column", SetUpScript: []string{ - "create table t1 (a int primary key, b int as (a + 1) stored)", - "insert into t1(a) values (1), (2)", + "create table t1 (a int primary key, b int as (a * a) stored)", + "insert into t1(a) values (-1), (-2)", }, Assertions: []ScriptTestAssertion{ { @@ -288,22 +288,26 @@ var GeneratedColumnTests = []ScriptTest{ Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n" + " `a` int NOT NULL,\n" + - " `b` int GENERATED ALWAYS AS ((`a` + 1)) STORED,\n" + + " `b` int GENERATED ALWAYS AS ((`a` * `a`)) STORED,\n" + " PRIMARY KEY (`a`),\n" + " UNIQUE KEY `i1` (`b`)\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, }, { - Query: "select * from t1 where b = 2 order by a", - Expected: []sql.Row{{1, 2}}, + Query: "select * from t1 where b = 4 order by a", + Expected: []sql.Row{{-2, 4}}, }, { Query: "select * from t1 order by a", - Expected: []sql.Row{{1, 2}, {2, 3}}, + Expected: []sql.Row{{-2, 4}, {-1, 1}}, }, { Query: "select * from t1 order by b", - Expected: []sql.Row{{1, 2}, {2, 3}}, + Expected: []sql.Row{{-1, 1}, {-2, 4}}, + }, + { + Query: "insert into t1(a) values (2)", + ExpectedErr: sql.ErrUniqueKeyViolation, }, }, }, @@ -332,7 +336,6 @@ var GeneratedColumnTests = []ScriptTest{ { Query: "select * from t1 where b = 2 order by a", Expected: []sql.Row{{1, 2}}, - Skip: true, // https://github.com/dolthub/dolt/issues/8276 }, { Query: "select * from t1 order by a", @@ -404,7 +407,6 @@ var GeneratedColumnTests = []ScriptTest{ { Query: "select * from t1 where b = 2 order by a", Expected: []sql.Row{{1, float64(2)}}, - Skip: true, // https://github.com/dolthub/dolt/issues/8276 }, { Query: "select * from t1 order by a", @@ -1086,7 +1088,6 @@ var GeneratedColumnTests = []ScriptTest{ { Query: "select * from t1 where b = 2 order by a", Expected: []sql.Row{{1, float64(2)}}, - Skip: true, // https://github.com/dolthub/dolt/issues/8276 }, { Query: "select * from t1 order by a", @@ -1139,8 +1140,8 @@ var GeneratedColumnTests = []ScriptTest{ { Name: "creating unique index on virtual generated column", SetUpScript: []string{ - "create table t1 (a int primary key, b int as (a + 1) virtual)", - "insert into t1(a) values (1), (2)", + "create table t1 (a int primary key, b int as (a * a) virtual)", + "insert into t1(a) values (-1), (-2)", }, Assertions: []ScriptTestAssertion{ { @@ -1152,23 +1153,28 @@ var GeneratedColumnTests = []ScriptTest{ Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n" + " `a` int NOT NULL,\n" + - " `b` int GENERATED ALWAYS AS ((`a` + 1)),\n" + + " `b` int GENERATED ALWAYS AS ((`a` * `a`)),\n" + " PRIMARY KEY (`a`),\n" + " KEY `i1` (`b`)\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, Skip: true, // https://github.com/dolthub/dolt/issues/8275 }, { - Query: "select * from t1 where b = 2 order by a", - Expected: []sql.Row{{1, 2}}, + Query: "select * from t1 where b = 4 order by a", + Expected: []sql.Row{{-2, 4}}, }, { Query: "select * from t1 order by a", - Expected: []sql.Row{{1, 2}, {2, 3}}, + Expected: []sql.Row{{-2, 4}, {-1, 1}}, }, { Query: "select * from t1 order by b", - Expected: []sql.Row{{1, 2}, {2, 3}}, + Expected: []sql.Row{{-1, 1}, {-2, 4}}, + }, + { + Query: "insert into t1(a) values (2)", + ExpectedErr: sql.ErrUniqueKeyViolation, + Skip: true, }, }, }, diff --git a/memory/table_editor.go b/memory/table_editor.go index 6267ddfb8b..ff61ed4db4 100644 --- a/memory/table_editor.go +++ b/memory/table_editor.go @@ -28,7 +28,6 @@ import ( type tableEditor struct { editedTable *Table initialTable *Table - schema sql.Schema discardChanges bool ea tableEditAccumulator @@ -314,12 +313,15 @@ func (t *tableEditor) pkColumnIndexes() []int { func (t *tableEditor) pkColsDiffer(row, row2 sql.Row) bool { pkColIdxes := t.pkColumnIndexes() - return !columnsMatch(pkColIdxes, nil, row, row2) + return !columnsMatch(pkColIdxes, nil, row, row2, t.Schema()) } // Returns whether the values for the columns given match in the two rows provided -func columnsMatch(colIndexes []int, prefixLengths []uint16, row sql.Row, row2 sql.Row) bool { +func columnsMatch(colIndexes []int, prefixLengths []uint16, row sql.Row, row2 sql.Row, schema sql.Schema) bool { for i, idx := range colIndexes { + if schema[idx].Virtual { + continue + } v1 := row[idx] v2 := row2[idx] if len(prefixLengths) > i && prefixLengths[i] > 0 { @@ -446,7 +448,7 @@ func (pke *pkTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) { pkColIdxes := pke.pkColumnIndexes() for _, partition := range pke.tableData.partitions { for _, partitionRow := range partition { - if columnsMatch(pkColIdxes, nil, partitionRow, value) { + if columnsMatch(pkColIdxes, nil, partitionRow, value, pke.tableData.schema.Schema) { return partitionRow, true, nil } } @@ -459,20 +461,20 @@ func (pke *pkTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) { func (pke *pkTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefixLengths []uint16) (sql.Row, bool, error) { // If we have this row in any delete, bail. if _, _, exists := pke.deletes.FindForeach(func(key string, r sql.Row) bool { - return columnsMatch(cols, prefixLengths, r, value) + return columnsMatch(cols, prefixLengths, r, value, pke.tableData.schema.Schema) }); exists { return nil, false, nil } if _, r, exists := pke.adds.FindForeach(func(key string, r sql.Row) bool { - return columnsMatch(cols, prefixLengths, r, value) + return columnsMatch(cols, prefixLengths, r, value, pke.tableData.schema.Schema) }); exists { return r, true, nil } for _, partition := range pke.tableData.partitions { for _, partitionRow := range partition { - if columnsMatch(cols, prefixLengths, partitionRow, value) { + if columnsMatch(cols, prefixLengths, partitionRow, value, pke.tableData.schema.Schema) { return partitionRow, true, nil } } @@ -541,7 +543,7 @@ func (pke *pkTableEditAccumulator) deleteHelper(table *TableData, row sql.Row) e // have the row to be replaced, so we need to consider primary key information. pkColIdxes := pke.pkColumnIndexes() if len(pkColIdxes) > 0 { - if columnsMatch(pkColIdxes, nil, partitionRow, row) { + if columnsMatch(pkColIdxes, nil, partitionRow, row, pke.tableData.schema.Schema) { table.partitions[partName] = append(partition[:partitionRowIndex], partition[partitionRowIndex+1:]...) partKey = partName rowIdx = partitionRowIndex @@ -608,7 +610,7 @@ func (pke *pkTableEditAccumulator) insertHelper(table *TableData, row sql.Row) e if len(pkColIdxes) > 0 { for partitionIndex, partition := range table.partitions { for partitionRowIndex, partitionRow := range partition { - if columnsMatch(pkColIdxes, nil, partitionRow, row) { + if columnsMatch(pkColIdxes, nil, partitionRow, row, pke.tableData.schema.Schema) { // Instead of throwing a unique key error, we perform an update operation to essentially represent // map semantics for the keyed table. savedPartitionIndex = partitionIndex @@ -714,14 +716,14 @@ func (k *keylessTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) func (k *keylessTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefixLengths []uint16) (sql.Row, bool, error) { deleteCount := 0 for _, r := range k.deletes { - if columnsMatch(cols, prefixLengths, r, value) { + if columnsMatch(cols, prefixLengths, r, value, k.tableData.schema.Schema) { deleteCount++ } } for _, partition := range k.tableData.partitions { for _, partitionRow := range partition { - if columnsMatch(cols, prefixLengths, partitionRow, value) { + if columnsMatch(cols, prefixLengths, partitionRow, value, k.tableData.schema.Schema) { if deleteCount == 0 { return partitionRow, true, nil } @@ -731,7 +733,7 @@ func (k *keylessTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefi } for _, r := range k.adds { - if columnsMatch(cols, prefixLengths, r, value) { + if columnsMatch(cols, prefixLengths, r, value, nil) { if deleteCount == 0 { return r, true, nil }