diff --git a/go/libraries/doltcore/doltdb/commit.go b/go/libraries/doltcore/doltdb/commit.go index 6ff527070e7..83829d6918c 100644 --- a/go/libraries/doltcore/doltdb/commit.go +++ b/go/libraries/doltcore/doltdb/commit.go @@ -303,12 +303,14 @@ type PendingCommit struct { // commit, once written. // |headRef| is the ref of the HEAD the commit will update // |mergeParentCommits| are any merge parents for this commit +// |amend| is a flag which indicates that additional parents should not be added to the provided |mergeParentCommits|. // |cm| is the metadata for the commit // The current branch head will be automatically filled in as the first parent at commit time. func (ddb *DoltDB) NewPendingCommit( ctx context.Context, roots Roots, mergeParentCommits []*Commit, + amend bool, cm *datas.CommitMeta, ) (*PendingCommit, error) { newstaged, val, err := ddb.writeRootValue(ctx, roots.Staged) @@ -322,7 +324,7 @@ func (ddb *DoltDB) NewPendingCommit( parents = append(parents, pc.dCommit.Addr()) } - commitOpts := datas.CommitOptions{Parents: parents, Meta: cm} + commitOpts := datas.CommitOptions{Parents: parents, Meta: cm, Amend: amend} return &PendingCommit{ Roots: roots, Val: val, diff --git a/go/libraries/doltcore/env/actions/commit.go b/go/libraries/doltcore/env/actions/commit.go index 0bebcf21f59..95ea8fbdeed 100644 --- a/go/libraries/doltcore/env/actions/commit.go +++ b/go/libraries/doltcore/env/actions/commit.go @@ -98,9 +98,7 @@ func GetCommitStaged( return nil, NewTblSchemaConflictError(schConflicts) } } - } - if !props.Force { roots.Staged, err = doltdb.ValidateForeignKeysOnSchemas(ctx, roots.Staged) if err != nil { return nil, err @@ -112,5 +110,5 @@ func GetCommitStaged( return nil, err } - return db.NewPendingCommit(ctx, roots, mergeParents, meta) + return db.NewPendingCommit(ctx, roots, mergeParents, props.Amend, meta) } diff --git a/go/libraries/doltcore/migrate/progress.go b/go/libraries/doltcore/migrate/progress.go index 5764be23a3c..fa48bef7747 100644 --- a/go/libraries/doltcore/migrate/progress.go +++ b/go/libraries/doltcore/migrate/progress.go @@ -273,7 +273,7 @@ func commitRoot( return err } - pcm, err := ddb.NewPendingCommit(ctx, roots, parents, meta) + pcm, err := ddb.NewPendingCommit(ctx, roots, parents, false, meta) if err != nil { return err } diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index cdaf1691004..e8dd927a93f 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -728,7 +728,6 @@ func (d *DoltSession) NewPendingCommit( // See NewPendingCommit func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchState, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) { headCommit := branchState.headCommit - headHash, _ := headCommit.HashOf() if branchState.WorkingSet() == nil { return nil, doltdb.ErrOperationNotSupportedInDetachedHead @@ -755,52 +754,18 @@ func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchStat // If the commit message isn't set and we're amending the previous commit, // go ahead and set the commit message from the current HEAD if props.Message == "" && props.Amend { - cs, err := doltdb.NewCommitSpec("HEAD") - if err != nil { - return nil, err - } - - headRef, err := branchState.dbData.Rsr.CWBHeadRef() - if err != nil { - return nil, err - } - optCmt, err := branchState.dbData.Ddb.Resolve(ctx, cs, headRef) - commit, ok := optCmt.ToCommit() - if !ok { - return nil, doltdb.ErrGhostCommitEncountered - } - - meta, err := commit.GetCommitMeta(ctx) + meta, err := headCommit.GetCommitMeta(ctx) if err != nil { return nil, err } props.Message = meta.Description } - - // TODO: This is not the correct way to write this commit as an amend. While this commit is running - // the branch head moves backwards and concurrency control here is not principled. - newRoots, err := actions.ResetSoftToRef(ctx, branchState.dbData, "HEAD~1") - if err != nil { - return nil, err - } - - err = d.SetWorkingSet(ctx, ctx.GetCurrentDatabase(), branchState.WorkingSet().WithStagedRoot(newRoots.Staged)) - if err != nil { - return nil, err - } - - roots.Head = newRoots.Head } pendingCommit, err := actions.GetCommitStaged(ctx, roots, branchState.WorkingSet(), mergeParentCommits, branchState.dbData.Ddb, props) if err != nil { - if props.Amend { - _, err = actions.ResetSoftToRef(ctx, branchState.dbData, headHash.String()) - if err != nil { - return nil, err - } - } - if _, ok := err.(actions.NothingStaged); err != nil && !ok { + // Special case for nothing staged, which is not an error + if _, ok := err.(actions.NothingStaged); !ok { return nil, err } } diff --git a/go/store/datas/commit_options.go b/go/store/datas/commit_options.go index 825a8f37377..ae5ba59b80e 100644 --- a/go/store/datas/commit_options.go +++ b/go/store/datas/commit_options.go @@ -32,5 +32,10 @@ type CommitOptions struct { // parent. Parents []hash.Hash + // Amend flag indicates that the commit being build it to amend an existing commit. Generally we add the branch HEAD + // as a parent, in addition to the parent set provided here. When we amend, we want to strictly use the commits + // provided in |Parents|, and no others. + Amend bool + Meta *CommitMeta } diff --git a/go/store/datas/database_common.go b/go/store/datas/database_common.go index 8019ce5943f..9502cfa06fd 100644 --- a/go/store/datas/database_common.go +++ b/go/store/datas/database_common.go @@ -586,7 +586,7 @@ func (db *database) BuildNewCommit(ctx context.Context, ds Dataset, v types.Valu if ok { opts.Parents = []hash.Hash{headAddr} } - } else { + } else if !opts.Amend { curr, ok := ds.MaybeHeadAddr() if ok { if !hasParentHash(opts, curr) { @@ -901,7 +901,7 @@ func (db *database) CommitWithWorkingSet( // Prepend the current head hash to the list of parents if one was provided. This is only necessary if parents were // provided because we fill it in automatically in buildNewCommit otherwise. - if len(opts.Parents) > 0 { + if len(opts.Parents) > 0 && !opts.Amend { headHash, ok := commitDS.MaybeHeadAddr() if ok { if !hasParentHash(opts, headHash) { diff --git a/integration-tests/go-sql-server-driver/sql_server_commit_concurrency_test.go b/integration-tests/go-sql-server-driver/sql_server_commit_concurrency_test.go new file mode 100644 index 00000000000..3edcaf3e4e6 --- /dev/null +++ b/integration-tests/go-sql-server-driver/sql_server_commit_concurrency_test.go @@ -0,0 +1,256 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver" +) + +func TestCommitConcurrency(t *testing.T) { + t.Parallel() + t.Run("SQL transaction with amend commit", testSQLTransactionWithAmendCommit) + t.Run("SQL racing amend", testSQLRacingAmend) +} + +// testSQLTransactionWithAmendCommit verifies that two transactions started at the same state will not both be able +// to commit using --amend. The first transaction will be able to commit, but the second should get an error. +func testSQLTransactionWithAmendCommit(t *testing.T) { + ctx := context.Background() + u, err := driver.NewDoltUser() + require.NoError(t, err) + t.Cleanup(func() { + u.Cleanup() + }) + + rs, err := u.MakeRepoStore() + require.NoError(t, err) + repo, err := rs.MakeRepo("commit_concurrency_test") + require.NoError(t, err) + + srvSettings := &driver.Server{ + Args: []string{"--port", `{{get_port "server"}}`}, + DynamicPort: "server", + } + var ports DynamicPorts + ports.global = &GlobalPorts + ports.t = t + server := MakeServer(t, repo, srvSettings, &ports) + server.DBName = "commit_concurrency_test" + + // Connect to the database + db, err := server.DB(driver.Connection{User: "root"}) + require.NoError(t, err) + t.Cleanup(func() { + db.Close() + }) + + _, err = db.ExecContext(ctx, ` +CREATE TABLE test_table ( + id INT AUTO_INCREMENT PRIMARY KEY, + value VARCHAR(20) +);`) + require.NoError(t, err) + + _, err = db.ExecContext(ctx, "INSERT INTO test_table (value) VALUES ('initial')") + require.NoError(t, err) + + _, err = db.ExecContext(ctx, "CALL DOLT_COMMIT('-A','-m', 'initial commit')") + require.NoError(t, err) + + // Create a new context for the first (failing) transaction + ctx1, cancel1 := context.WithCancel(ctx) + defer cancel1() + tx1, err := db.BeginTx(ctx1, nil) + require.NoError(t, err) + _, err = tx1.ExecContext(ctx1, "UPDATE test_table SET value = 'amended by tx1' WHERE id = 1") + require.NoError(t, err) + + // Create a new context for the second (succeeding) transaction + ctx2, cancel2 := context.WithCancel(ctx) + defer cancel2() + tx2, err := db.BeginTx(ctx2, nil) + require.NoError(t, err) + + // Update data within the second transaction + _, err = tx2.ExecContext(ctx2, "UPDATE test_table SET value = 'amended by tx2' WHERE id = 1") + require.NoError(t, err) + + _, err = tx2.ExecContext(ctx2, "CALL DOLT_COMMIT('--amend', '-m', 'tx2 amended commit')") + require.NoError(t, err) + + // Commit --amend will result in tx2 being committed. You can still make updates on tx1, but any commit should fail + _, err = tx1.ExecContext(ctx1, "INSERT INTO test_table (value) VALUES ('new row by tx1')") + require.NoError(t, err) + + _, err = tx1.ExecContext(ctx1, "CALL DOLT_COMMIT('--amend', '-m', 'should fail')") + require.Error(t, err) + require.Contains(t, err.Error(), "this transaction conflicts with a committed transaction from another client, try restarting transaction") + + // Verify that the data in the head is what we would expect + row := db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id = 1") + var value string + err = row.Scan(&value) + require.NoError(t, err) + require.Equal(t, "amended by tx2", value) + + // Verify the commit message + row = db.QueryRowContext(ctx, "SELECT message FROM dolt_log ORDER BY date DESC LIMIT 1") + var commitMessage string + err = row.Scan(&commitMessage) + require.NoError(t, err) + require.Equal(t, "tx2 amended commit", commitMessage) + +} + +func testSQLRacingAmend(t *testing.T) { + ctx := context.Background() + u, err := driver.NewDoltUser() + require.NoError(t, err) + t.Cleanup(func() { + u.Cleanup() + }) + + rs, err := u.MakeRepoStore() + require.NoError(t, err) + repo, err := rs.MakeRepo("racing_amend_test") + require.NoError(t, err) + + srvSettings := &driver.Server{ + Args: []string{"--port", `{{get_port "server"}}`}, + DynamicPort: "server", + } + var ports DynamicPorts + ports.global = &GlobalPorts + ports.t = t + server := MakeServer(t, repo, srvSettings, &ports) + server.DBName = "racing_amend_test" + + db, err := server.DB(driver.Connection{User: "root"}) + require.NoError(t, err) + t.Cleanup(func() { + db.Close() + }) + + _, err = db.ExecContext(ctx, ` + CREATE TABLE test_table ( + id INT AUTO_INCREMENT PRIMARY KEY, + value VARCHAR(20) + );`) + require.NoError(t, err) + + _, err = db.ExecContext(ctx, "INSERT INTO test_table VALUES (1, 'initial')") + require.NoError(t, err) + + _, err = db.ExecContext(ctx, "CALL DOLT_COMMIT('-A','-m', 'initial commit')") + require.NoError(t, err) + + type txIdFunc struct { + txNum int + txFunc func() error + } + + var transactions []txIdFunc + for txNum := 1; txNum <= 200; txNum++ { + txCtx, cancel := context.WithTimeout(ctx, 15*time.Second) // Should never hit, but just in case + tx, err := db.BeginTx(txCtx, nil) + require.NoError(t, err) + f := func() error { + defer cancel() + // update required to get a transaction conflict. + _, e2 := tx.ExecContext(txCtx, "UPDATE test_table SET value = ? WHERE id = 1", fmt.Sprintf("tx%d value", txNum)) + require.NoError(t, e2) + _, e2 = tx.ExecContext(txCtx, "INSERT INTO test_table (value) VALUES (?)", fmt.Sprintf("tx%d new row", txNum)) + require.NoError(t, e2) + + // We want other transactions to have the chance to mess with the db. We'll verify that what's committed is what we expect. + time.Sleep(time.Duration(rand.Intn(1000)+500) * time.Millisecond) + + // This will commit the transaction, or error. + _, e2 = tx.ExecContext(txCtx, "CALL DOLT_COMMIT('--amend','-a', '-m', ?)", fmt.Sprintf("tx%d amend", txNum)) + if e2 != nil { + tx.Rollback() + } + return e2 + } + transactions = append(transactions, txIdFunc{txNum: txNum, txFunc: f}) + } + + rand.Shuffle(len(transactions), func(i, j int) { + transactions[i], transactions[j] = transactions[j], transactions[i] + }) + + var atomicInt atomic.Int32 + atomicInt.Store(-1) + + var wg sync.WaitGroup + for _, txn := range transactions { + wg.Add(1) + go func() { + defer wg.Done() + // radomly sleep .5 - 1.5 seconds + time.Sleep(time.Duration(rand.Intn(1000)+500) * time.Millisecond) + err := txn.txFunc() + + if err == nil { + // If there are multiple updates, something went wrong. + require.True(t, atomicInt.CompareAndSwap(-1, int32(txn.txNum))) + } + // Errors are expected. + }() + } + wg.Wait() + + winner := atomicInt.Load() + require.NotEqual(t, -1, winner) + + // Verify there are only 2 rows in the table + rows, err := db.QueryContext(ctx, "SELECT COUNT(*) FROM test_table") + require.NoError(t, err) + defer rows.Close() + rows.Next() + var count int + err = rows.Scan(&count) + require.NoError(t, err) + require.Equal(t, 2, count) + + // Verify the final state + row := db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id = 1") + var value string + err = row.Scan(&value) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("tx%d value", winner), value) + + row = db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id != 1") + err = row.Scan(&value) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("tx%d new row", winner), value) + + // Verify the commit message + row = db.QueryRowContext(ctx, "SELECT message FROM dolt_log ORDER BY date DESC LIMIT 1") + var commitMessage string + err = row.Scan(&commitMessage) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("tx%d amend", winner), commitMessage) +}