Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 10 additions & 32 deletions go/libraries/doltcore/sqle/dprocedures/dolt_pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/store/datas"
"github.com/dolthub/dolt/go/store/datas/pull"
)

Expand Down Expand Up @@ -125,6 +124,12 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) {
fmt.Errorf("branch %q not found on remote", pullSpec.Branch.GetPath())
}

mode := ref.UpdateMode{Force: true, Prune: false}
err = actions.FetchRefSpecs(ctx, dbData, srcDB, pullSpec.RefSpecs, pullSpec.Remote, mode, runProgFuncs, stopProgFuncs)
if err != nil {
return noConflictsOrViolations, threeWayMerge, fmt.Errorf("fetch failed: %w", err)
}

var conflicts int
var fastForward int
for _, refSpec := range pullSpec.RefSpecs {
Expand All @@ -136,46 +141,19 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) {
continue
}

rsSeen = true
tmpDir, err := dbData.Rsw.TempTableFilesDir()
if err != nil {
return noConflictsOrViolations, threeWayMerge, err
}
// todo: can we pass nil for either of the channels?
srcDBCommit, err := actions.FetchRemoteBranch(ctx, tmpDir, pullSpec.Remote, srcDB, dbData.Ddb, branchRef, runProgFuncs, stopProgFuncs)
if err != nil {
return noConflictsOrViolations, threeWayMerge, err
if branchRef != pullSpec.Branch {
continue
}

rsSeen = true

headRef, err := dbData.Rsr.CWBHeadRef()
if err != nil {
return noConflictsOrViolations, threeWayMerge, err
}

msg := fmt.Sprintf("Merge branch '%s' of %s into %s", pullSpec.Branch.GetPath(), pullSpec.Remote.Url, headRef.GetPath())

// TODO: this could be replaced with a canFF check to test for error
err = dbData.Ddb.FastForward(ctx, remoteTrackRef, srcDBCommit)
if errors.Is(err, datas.ErrMergeNeeded) {
// If the remote tracking branch has diverged from the local copy, we just overwrite it
// TODO: none of this is transactional
h, err := srcDBCommit.HashOf()
if err != nil {
return noConflictsOrViolations, threeWayMerge, err
}
err = dbData.Ddb.SetHead(ctx, remoteTrackRef, h)
if err != nil {
return noConflictsOrViolations, threeWayMerge, err
}
} else if err != nil {
return noConflictsOrViolations, threeWayMerge, fmt.Errorf("fetch failed; %w", err)
}

// Only merge iff branch is current branch and there is an upstream set (pullSpec.Branch is set to nil if there is no upstream)
if branchRef != pullSpec.Branch {
continue
}

roots, ok := sess.GetRoots(ctx, dbName)
if !ok {
return noConflictsOrViolations, threeWayMerge, sql.ErrDatabaseNotFound.New(dbName)
Expand Down