Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion go/libraries/doltcore/doltdb/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,14 @@ type PendingCommit struct {
// commit, once written.
// |headRef| is the ref of the HEAD the commit will update
// |mergeParentCommits| are any merge parents for this commit
// |amend| is a flag which indicates that additional parents should not be added to the provided |mergeParentCommits|.
// |cm| is the metadata for the commit
// The current branch head will be automatically filled in as the first parent at commit time.
func (ddb *DoltDB) NewPendingCommit(
ctx context.Context,
roots Roots,
mergeParentCommits []*Commit,
amend bool,
cm *datas.CommitMeta,
) (*PendingCommit, error) {
newstaged, val, err := ddb.writeRootValue(ctx, roots.Staged)
Expand All @@ -322,7 +324,7 @@ func (ddb *DoltDB) NewPendingCommit(
parents = append(parents, pc.dCommit.Addr())
}

commitOpts := datas.CommitOptions{Parents: parents, Meta: cm}
commitOpts := datas.CommitOptions{Parents: parents, Meta: cm, Amend: amend}
return &PendingCommit{
Roots: roots,
Val: val,
Expand Down
4 changes: 1 addition & 3 deletions go/libraries/doltcore/env/actions/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ func GetCommitStaged(
return nil, NewTblSchemaConflictError(schConflicts)
}
}
}

if !props.Force {
roots.Staged, err = doltdb.ValidateForeignKeysOnSchemas(ctx, roots.Staged)
if err != nil {
return nil, err
Expand All @@ -112,5 +110,5 @@ func GetCommitStaged(
return nil, err
}

return db.NewPendingCommit(ctx, roots, mergeParents, meta)
return db.NewPendingCommit(ctx, roots, mergeParents, props.Amend, meta)
}
2 changes: 1 addition & 1 deletion go/libraries/doltcore/migrate/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func commitRoot(
return err
}

pcm, err := ddb.NewPendingCommit(ctx, roots, parents, meta)
pcm, err := ddb.NewPendingCommit(ctx, roots, parents, false, meta)
if err != nil {
return err
}
Expand Down
41 changes: 3 additions & 38 deletions go/libraries/doltcore/sqle/dsess/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,6 @@ func (d *DoltSession) NewPendingCommit(
// See NewPendingCommit
func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchState, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) {
headCommit := branchState.headCommit
headHash, _ := headCommit.HashOf()

if branchState.WorkingSet() == nil {
return nil, doltdb.ErrOperationNotSupportedInDetachedHead
Expand All @@ -755,52 +754,18 @@ func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchStat
// If the commit message isn't set and we're amending the previous commit,
// go ahead and set the commit message from the current HEAD
if props.Message == "" && props.Amend {
cs, err := doltdb.NewCommitSpec("HEAD")
if err != nil {
return nil, err
}

headRef, err := branchState.dbData.Rsr.CWBHeadRef()
if err != nil {
return nil, err
}
optCmt, err := branchState.dbData.Ddb.Resolve(ctx, cs, headRef)
commit, ok := optCmt.ToCommit()
if !ok {
return nil, doltdb.ErrGhostCommitEncountered
}

meta, err := commit.GetCommitMeta(ctx)
meta, err := headCommit.GetCommitMeta(ctx)
if err != nil {
return nil, err
}
props.Message = meta.Description
}

// TODO: This is not the correct way to write this commit as an amend. While this commit is running
// the branch head moves backwards and concurrency control here is not principled.
newRoots, err := actions.ResetSoftToRef(ctx, branchState.dbData, "HEAD~1")
if err != nil {
return nil, err
}

err = d.SetWorkingSet(ctx, ctx.GetCurrentDatabase(), branchState.WorkingSet().WithStagedRoot(newRoots.Staged))
if err != nil {
return nil, err
}

roots.Head = newRoots.Head
}

pendingCommit, err := actions.GetCommitStaged(ctx, roots, branchState.WorkingSet(), mergeParentCommits, branchState.dbData.Ddb, props)
if err != nil {
if props.Amend {
_, err = actions.ResetSoftToRef(ctx, branchState.dbData, headHash.String())
if err != nil {
return nil, err
}
}
if _, ok := err.(actions.NothingStaged); err != nil && !ok {
// Special case for nothing staged, which is not an error
if _, ok := err.(actions.NothingStaged); !ok {
return nil, err
}
}
Expand Down
5 changes: 5 additions & 0 deletions go/store/datas/commit_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,10 @@ type CommitOptions struct {
// parent.
Parents []hash.Hash

// Amend flag indicates that the commit being build it to amend an existing commit. Generally we add the branch HEAD
// as a parent, in addition to the parent set provided here. When we amend, we want to strictly use the commits
// provided in |Parents|, and no others.
Amend bool

Meta *CommitMeta
}
4 changes: 2 additions & 2 deletions go/store/datas/database_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ func (db *database) BuildNewCommit(ctx context.Context, ds Dataset, v types.Valu
if ok {
opts.Parents = []hash.Hash{headAddr}
}
} else {
} else if !opts.Amend {
curr, ok := ds.MaybeHeadAddr()
if ok {
if !hasParentHash(opts, curr) {
Expand Down Expand Up @@ -901,7 +901,7 @@ func (db *database) CommitWithWorkingSet(

// Prepend the current head hash to the list of parents if one was provided. This is only necessary if parents were
// provided because we fill it in automatically in buildNewCommit otherwise.
if len(opts.Parents) > 0 {
if len(opts.Parents) > 0 && !opts.Amend {
headHash, ok := commitDS.MaybeHeadAddr()
if ok {
if !hasParentHash(opts, headHash) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"context"
"fmt"
"math/rand"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"

driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver"
)

func TestCommitConcurrency(t *testing.T) {
t.Parallel()
t.Run("SQL transaction with amend commit", testSQLTransactionWithAmendCommit)
t.Run("SQL racing amend", testSQLRacingAmend)
}

// testSQLTransactionWithAmendCommit verifies that two transactions started at the same state will not both be able
// to commit using --amend. The first transaction will be able to commit, but the second should get an error.
func testSQLTransactionWithAmendCommit(t *testing.T) {
ctx := context.Background()
u, err := driver.NewDoltUser()
require.NoError(t, err)
t.Cleanup(func() {
u.Cleanup()
})

rs, err := u.MakeRepoStore()
require.NoError(t, err)
repo, err := rs.MakeRepo("commit_concurrency_test")
require.NoError(t, err)

srvSettings := &driver.Server{
Args: []string{"--port", `{{get_port "server"}}`},
DynamicPort: "server",
}
var ports DynamicPorts
ports.global = &GlobalPorts
ports.t = t
server := MakeServer(t, repo, srvSettings, &ports)
server.DBName = "commit_concurrency_test"

// Connect to the database
db, err := server.DB(driver.Connection{User: "root"})
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
})

_, err = db.ExecContext(ctx, `
CREATE TABLE test_table (
id INT AUTO_INCREMENT PRIMARY KEY,
value VARCHAR(20)
);`)
require.NoError(t, err)

_, err = db.ExecContext(ctx, "INSERT INTO test_table (value) VALUES ('initial')")
require.NoError(t, err)

_, err = db.ExecContext(ctx, "CALL DOLT_COMMIT('-A','-m', 'initial commit')")
require.NoError(t, err)

// Create a new context for the first (failing) transaction
ctx1, cancel1 := context.WithCancel(ctx)
defer cancel1()
tx1, err := db.BeginTx(ctx1, nil)
require.NoError(t, err)
_, err = tx1.ExecContext(ctx1, "UPDATE test_table SET value = 'amended by tx1' WHERE id = 1")
require.NoError(t, err)

// Create a new context for the second (succeeding) transaction
ctx2, cancel2 := context.WithCancel(ctx)
defer cancel2()
tx2, err := db.BeginTx(ctx2, nil)
require.NoError(t, err)

// Update data within the second transaction
_, err = tx2.ExecContext(ctx2, "UPDATE test_table SET value = 'amended by tx2' WHERE id = 1")
require.NoError(t, err)

_, err = tx2.ExecContext(ctx2, "CALL DOLT_COMMIT('--amend', '-m', 'tx2 amended commit')")
require.NoError(t, err)

// Commit --amend will result in tx2 being committed. You can still make updates on tx1, but any commit should fail
_, err = tx1.ExecContext(ctx1, "INSERT INTO test_table (value) VALUES ('new row by tx1')")
require.NoError(t, err)

_, err = tx1.ExecContext(ctx1, "CALL DOLT_COMMIT('--amend', '-m', 'should fail')")
require.Error(t, err)
require.Contains(t, err.Error(), "this transaction conflicts with a committed transaction from another client, try restarting transaction")

// Verify that the data in the head is what we would expect
row := db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id = 1")
var value string
err = row.Scan(&value)
require.NoError(t, err)
require.Equal(t, "amended by tx2", value)

// Verify the commit message
row = db.QueryRowContext(ctx, "SELECT message FROM dolt_log ORDER BY date DESC LIMIT 1")
var commitMessage string
err = row.Scan(&commitMessage)
require.NoError(t, err)
require.Equal(t, "tx2 amended commit", commitMessage)

}

func testSQLRacingAmend(t *testing.T) {
ctx := context.Background()
u, err := driver.NewDoltUser()
require.NoError(t, err)
t.Cleanup(func() {
u.Cleanup()
})

rs, err := u.MakeRepoStore()
require.NoError(t, err)
repo, err := rs.MakeRepo("racing_amend_test")
require.NoError(t, err)

srvSettings := &driver.Server{
Args: []string{"--port", `{{get_port "server"}}`},
DynamicPort: "server",
}
var ports DynamicPorts
ports.global = &GlobalPorts
ports.t = t
server := MakeServer(t, repo, srvSettings, &ports)
server.DBName = "racing_amend_test"

db, err := server.DB(driver.Connection{User: "root"})
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
})

_, err = db.ExecContext(ctx, `
CREATE TABLE test_table (
id INT AUTO_INCREMENT PRIMARY KEY,
value VARCHAR(20)
);`)
require.NoError(t, err)

_, err = db.ExecContext(ctx, "INSERT INTO test_table VALUES (1, 'initial')")
require.NoError(t, err)

_, err = db.ExecContext(ctx, "CALL DOLT_COMMIT('-A','-m', 'initial commit')")
require.NoError(t, err)

type txIdFunc struct {
txNum int
txFunc func() error
}

var transactions []txIdFunc
for txNum := 1; txNum <= 200; txNum++ {
txCtx, cancel := context.WithTimeout(ctx, 15*time.Second) // Should never hit, but just in case
tx, err := db.BeginTx(txCtx, nil)
require.NoError(t, err)
f := func() error {
defer cancel()
// update required to get a transaction conflict.
_, e2 := tx.ExecContext(txCtx, "UPDATE test_table SET value = ? WHERE id = 1", fmt.Sprintf("tx%d value", txNum))
require.NoError(t, e2)
_, e2 = tx.ExecContext(txCtx, "INSERT INTO test_table (value) VALUES (?)", fmt.Sprintf("tx%d new row", txNum))
require.NoError(t, e2)

// We want other transactions to have the chance to mess with the db. We'll verify that what's committed is what we expect.
time.Sleep(time.Duration(rand.Intn(1000)+500) * time.Millisecond)

// This will commit the transaction, or error.
_, e2 = tx.ExecContext(txCtx, "CALL DOLT_COMMIT('--amend','-a', '-m', ?)", fmt.Sprintf("tx%d amend", txNum))
if e2 != nil {
tx.Rollback()
}
return e2
}
transactions = append(transactions, txIdFunc{txNum: txNum, txFunc: f})
}

rand.Shuffle(len(transactions), func(i, j int) {
transactions[i], transactions[j] = transactions[j], transactions[i]
})

var atomicInt atomic.Int32
atomicInt.Store(-1)

var wg sync.WaitGroup
for _, txn := range transactions {
wg.Add(1)
go func() {
defer wg.Done()
// radomly sleep .5 - 1.5 seconds
time.Sleep(time.Duration(rand.Intn(1000)+500) * time.Millisecond)
err := txn.txFunc()

if err == nil {
// If there are multiple updates, something went wrong.
require.True(t, atomicInt.CompareAndSwap(-1, int32(txn.txNum)))
}
// Errors are expected.
}()
}
wg.Wait()

winner := atomicInt.Load()
require.NotEqual(t, -1, winner)

// Verify there are only 2 rows in the table
rows, err := db.QueryContext(ctx, "SELECT COUNT(*) FROM test_table")
require.NoError(t, err)
defer rows.Close()
rows.Next()
var count int
err = rows.Scan(&count)
require.NoError(t, err)
require.Equal(t, 2, count)

// Verify the final state
row := db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id = 1")
var value string
err = row.Scan(&value)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("tx%d value", winner), value)

row = db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id != 1")
err = row.Scan(&value)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("tx%d new row", winner), value)

// Verify the commit message
row = db.QueryRowContext(ctx, "SELECT message FROM dolt_log ORDER BY date DESC LIMIT 1")
var commitMessage string
err = row.Scan(&commitMessage)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("tx%d amend", winner), commitMessage)
}
Loading