diff --git a/go/store/blobstore/git_blobstore.go b/go/store/blobstore/git_blobstore.go index 018eb90ea08..a7ba33f27ff 100644 --- a/go/store/blobstore/git_blobstore.go +++ b/go/store/blobstore/git_blobstore.go @@ -247,19 +247,28 @@ func (gbs *GitBlobstore) Put(ctx context.Context, key string, totalSize int64, r } func (gbs *GitBlobstore) buildPutCommit(ctx context.Context, parent git.OID, hasParent bool, key string, blobOID git.OID) (git.OID, string, error) { - _, indexFile, cleanup, err := newTempIndex() + msg := fmt.Sprintf("gitblobstore: put %s", key) + commitOID, err := gbs.buildCommitWithMessage(ctx, parent, hasParent, key, blobOID, msg) if err != nil { return "", "", err } + return commitOID, msg, nil +} + +func (gbs *GitBlobstore) buildCommitWithMessage(ctx context.Context, parent git.OID, hasParent bool, key string, blobOID git.OID, msg string) (git.OID, error) { + _, indexFile, cleanup, err := newTempIndex() + if err != nil { + return "", err + } defer cleanup() if hasParent { if err := gbs.api.ReadTree(ctx, parent, indexFile); err != nil { - return "", "", err + return "", err } } else { if err := gbs.api.ReadTreeEmpty(ctx, indexFile); err != nil { - return "", "", err + return "", err } } @@ -270,12 +279,12 @@ func (gbs *GitBlobstore) buildPutCommit(ctx context.Context, parent git.OID, has // namespace keys into directories, consider proactively removing conflicting paths from the index // before UpdateIndexCacheInfo so Put/CheckAndPut remain robust. if err := gbs.api.UpdateIndexCacheInfo(ctx, indexFile, "100644", blobOID, key); err != nil { - return "", "", err + return "", err } treeOID, err := gbs.api.WriteTree(ctx, indexFile) if err != nil { - return "", "", err + return "", err } var parentPtr *git.OID @@ -283,7 +292,6 @@ func (gbs *GitBlobstore) buildPutCommit(ctx context.Context, parent git.OID, has p := parent parentPtr = &p } - msg := fmt.Sprintf("gitblobstore: put %s", key) // Prefer git's default identity from env/config when not explicitly configured. commitOID, err := gbs.api.CommitTree(ctx, treeOID, parentPtr, msg, gbs.identity) @@ -291,10 +299,10 @@ func (gbs *GitBlobstore) buildPutCommit(ctx context.Context, parent git.OID, has commitOID, err = gbs.api.CommitTree(ctx, treeOID, parentPtr, msg, defaultGitBlobstoreIdentity()) } if err != nil { - return "", "", err + return "", err } - return commitOID, msg, nil + return commitOID, nil } func defaultGitBlobstoreIdentity() *git.Identity { @@ -346,10 +354,57 @@ func (gbs *GitBlobstore) refAdvanced(ctx context.Context, old git.OID) bool { } func (gbs *GitBlobstore) CheckAndPut(ctx context.Context, expectedVersion, key string, totalSize int64, reader io.Reader) (string, error) { - if _, err := normalizeGitTreePath(key); err != nil { + key, err := normalizeGitTreePath(key) + if err != nil { + return "", err + } + + // Resolve current head and validate expectedVersion before consuming |reader|. + parent, ok, err := gbs.api.TryResolveRefCommit(ctx, gbs.ref) + if err != nil { + return "", err + } + actualVersion := "" + if ok { + actualVersion = parent.String() + } + if expectedVersion != actualVersion { + return "", CheckAndPutError{Key: key, ExpectedVersion: expectedVersion, ActualVersion: actualVersion} + } + + blobOID, err := gbs.api.HashObject(ctx, reader) + if err != nil { + return "", err + } + + msg := fmt.Sprintf("gitblobstore: checkandput %s", key) + newCommit, err := gbs.buildCommitWithMessage(ctx, parent, ok, key, blobOID, msg) + if err != nil { + return "", err + } + + if ok { + if err := gbs.api.UpdateRefCAS(ctx, gbs.ref, newCommit, parent, msg); err != nil { + // If the ref changed, surface as a standard mismatch error. + cur, ok2, err2 := gbs.api.TryResolveRefCommit(ctx, gbs.ref) + if err2 == nil && ok2 && cur != parent { + return "", CheckAndPutError{Key: key, ExpectedVersion: expectedVersion, ActualVersion: cur.String()} + } + return "", err + } + return newCommit.String(), nil + } + + // Create-only CAS: oldOID=all-zero requires the ref to not exist. + const zeroOID = git.OID("0000000000000000000000000000000000000000") + if err := gbs.api.UpdateRefCAS(ctx, gbs.ref, newCommit, zeroOID, msg); err != nil { + cur, ok2, err2 := gbs.api.TryResolveRefCommit(ctx, gbs.ref) + if err2 == nil && ok2 { + return "", CheckAndPutError{Key: key, ExpectedVersion: expectedVersion, ActualVersion: cur.String()} + } return "", err } - return "", fmt.Errorf("%w: GitBlobstore.CheckAndPut", git.ErrUnimplemented) + return newCommit.String(), nil } func (gbs *GitBlobstore) Concatenate(ctx context.Context, key string, sources []string) (string, error) { diff --git a/go/store/blobstore/git_blobstore_test.go b/go/store/blobstore/git_blobstore_test.go index cd37ec8d960..a5438085720 100644 --- a/go/store/blobstore/git_blobstore_test.go +++ b/go/store/blobstore/git_blobstore_test.go @@ -385,3 +385,120 @@ func TestGitBlobstore_Put_ContentionRetryPreservesOtherKey(t *testing.T) { _, _ = io.ReadAll(rc) _ = rc.Close() } + +type failReader struct { + called atomic.Bool +} + +func (r *failReader) Read(_ []byte) (int, error) { + r.called.Store(true) + return 0, io.EOF +} + +func TestGitBlobstore_CheckAndPut_CreateOnly(t *testing.T) { + requireGitOnPath(t) + + ctx := context.Background() + repo, err := gitrepo.InitBare(ctx, t.TempDir()+"/repo.git") + require.NoError(t, err) + + bs, err := NewGitBlobstoreWithIdentity(repo.GitDir, DoltDataRef, testIdentity()) + require.NoError(t, err) + + want := []byte("created\n") + ver, err := bs.CheckAndPut(ctx, "", "k", int64(len(want)), bytes.NewReader(want)) + require.NoError(t, err) + require.NotEmpty(t, ver) + + got, ver2, err := GetBytes(ctx, bs, "k", AllRange) + require.NoError(t, err) + require.Equal(t, ver, ver2) + require.Equal(t, want, got) +} + +func TestGitBlobstore_CheckAndPut_MismatchDoesNotRead(t *testing.T) { + requireGitOnPath(t) + + ctx := context.Background() + repo, err := gitrepo.InitBare(ctx, t.TempDir()+"/repo.git") + require.NoError(t, err) + + commit, err := repo.SetRefToTree(ctx, DoltDataRef, map[string][]byte{ + "k": []byte("base\n"), + }, "seed") + require.NoError(t, err) + + bs, err := NewGitBlobstoreWithIdentity(repo.GitDir, DoltDataRef, testIdentity()) + require.NoError(t, err) + + r := &failReader{} + _, err = bs.CheckAndPut(ctx, commit+"-wrong", "k", 1, r) + require.Error(t, err) + require.True(t, IsCheckAndPutError(err)) + require.False(t, r.called.Load(), "expected reader not to be consumed on version mismatch") +} + +func TestGitBlobstore_CheckAndPut_UpdateSuccess(t *testing.T) { + requireGitOnPath(t) + + ctx := context.Background() + repo, err := gitrepo.InitBare(ctx, t.TempDir()+"/repo.git") + require.NoError(t, err) + + commit, err := repo.SetRefToTree(ctx, DoltDataRef, map[string][]byte{ + "k": []byte("base\n"), + "keep": []byte("keep\n"), + }, "seed") + require.NoError(t, err) + + bs, err := NewGitBlobstoreWithIdentity(repo.GitDir, DoltDataRef, testIdentity()) + require.NoError(t, err) + + want := []byte("updated\n") + ver2, err := bs.CheckAndPut(ctx, commit, "k", int64(len(want)), bytes.NewReader(want)) + require.NoError(t, err) + require.NotEmpty(t, ver2) + require.NotEqual(t, commit, ver2) + + got, ver3, err := GetBytes(ctx, bs, "k", AllRange) + require.NoError(t, err) + require.Equal(t, ver2, ver3) + require.Equal(t, want, got) + + got, _, err = GetBytes(ctx, bs, "keep", AllRange) + require.NoError(t, err) + require.Equal(t, []byte("keep\n"), got) +} + +func TestGitBlobstore_CheckAndPut_ConcurrentUpdateReturnsMismatch(t *testing.T) { + requireGitOnPath(t) + + ctx := context.Background() + repo, err := gitrepo.InitBare(ctx, t.TempDir()+"/repo.git") + require.NoError(t, err) + + commit, err := repo.SetRefToTree(ctx, DoltDataRef, map[string][]byte{ + "k": []byte("base\n"), + }, "seed") + require.NoError(t, err) + + bs, err := NewGitBlobstoreWithIdentity(repo.GitDir, DoltDataRef, testIdentity()) + require.NoError(t, err) + + origAPI := bs.api + h := &hookGitAPI{GitAPI: origAPI, ref: DoltDataRef} + h.onFirstCAS = func(ctx context.Context, old git.OID) { + // Advance the ref (without touching "k") to make UpdateRefCAS fail. + _, _ = writeKeyToRef(ctx, origAPI, DoltDataRef, "external", []byte("external\n"), testIdentity()) + } + bs.api = h + + _, err = bs.CheckAndPut(ctx, commit, "k", 0, bytes.NewReader([]byte("mine\n"))) + require.Error(t, err) + require.True(t, IsCheckAndPutError(err)) + + // Verify key did not change, since our CAS should have failed. + got, _, err := GetBytes(ctx, bs, "k", AllRange) + require.NoError(t, err) + require.Equal(t, []byte("base\n"), got) +}