diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go b/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go index 38821bf2550..4cae3e4c62b 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go @@ -30,7 +30,6 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" - "github.com/dolthub/dolt/go/store/datas" "github.com/dolthub/dolt/go/store/datas/pull" ) @@ -125,6 +124,12 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) { fmt.Errorf("branch %q not found on remote", pullSpec.Branch.GetPath()) } + mode := ref.UpdateMode{Force: true, Prune: false} + err = actions.FetchRefSpecs(ctx, dbData, srcDB, pullSpec.RefSpecs, pullSpec.Remote, mode, runProgFuncs, stopProgFuncs) + if err != nil { + return noConflictsOrViolations, threeWayMerge, fmt.Errorf("fetch failed: %w", err) + } + var conflicts int var fastForward int for _, refSpec := range pullSpec.RefSpecs { @@ -136,17 +141,12 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) { continue } - rsSeen = true - tmpDir, err := dbData.Rsw.TempTableFilesDir() - if err != nil { - return noConflictsOrViolations, threeWayMerge, err - } - // todo: can we pass nil for either of the channels? - srcDBCommit, err := actions.FetchRemoteBranch(ctx, tmpDir, pullSpec.Remote, srcDB, dbData.Ddb, branchRef, runProgFuncs, stopProgFuncs) - if err != nil { - return noConflictsOrViolations, threeWayMerge, err + if branchRef != pullSpec.Branch { + continue } + rsSeen = true + headRef, err := dbData.Rsr.CWBHeadRef() if err != nil { return noConflictsOrViolations, threeWayMerge, err @@ -154,28 +154,6 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) { msg := fmt.Sprintf("Merge branch '%s' of %s into %s", pullSpec.Branch.GetPath(), pullSpec.Remote.Url, headRef.GetPath()) - // TODO: this could be replaced with a canFF check to test for error - err = dbData.Ddb.FastForward(ctx, remoteTrackRef, srcDBCommit) - if errors.Is(err, datas.ErrMergeNeeded) { - // If the remote tracking branch has diverged from the local copy, we just overwrite it - // TODO: none of this is transactional - h, err := srcDBCommit.HashOf() - if err != nil { - return noConflictsOrViolations, threeWayMerge, err - } - err = dbData.Ddb.SetHead(ctx, remoteTrackRef, h) - if err != nil { - return noConflictsOrViolations, threeWayMerge, err - } - } else if err != nil { - return noConflictsOrViolations, threeWayMerge, fmt.Errorf("fetch failed; %w", err) - } - - // Only merge iff branch is current branch and there is an upstream set (pullSpec.Branch is set to nil if there is no upstream) - if branchRef != pullSpec.Branch { - continue - } - roots, ok := sess.GetRoots(ctx, dbName) if !ok { return noConflictsOrViolations, threeWayMerge, sql.ErrDatabaseNotFound.New(dbName)