diff --git a/datastore/datastore.go b/datastore/datastore.go index 61020bf175cd..2ec5cb52917c 100644 --- a/datastore/datastore.go +++ b/datastore/datastore.go @@ -74,6 +74,7 @@ type Client struct { dataset string // Called dataset by the datastore API, synonym for project ID. databaseID string // Default value is empty string readSettings *readSettings + config *datastoreConfig } // NewClient creates a new Client for a given dataset. If the project ID is @@ -152,12 +153,15 @@ func NewClientWithDatabase(ctx context.Context, projectID, databaseID string, op if err != nil { return nil, fmt.Errorf("dialing: %w", err) } + + config := newDatastoreConfig(o...) return &Client{ connPool: connPool, client: newDatastoreClient(connPool, projectID, databaseID), dataset: projectID, readSettings: &readSettings{}, databaseID: databaseID, + config: &config, }, nil } @@ -362,6 +366,48 @@ func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) { return multiArgTypeInvalid, nil } +// processFieldMismatchError ignore FieldMismatchErr if WithIgnoreFieldMismatch client option is provided by user +func (c *Client) processFieldMismatchError(err error) error { + if c.config == nil || !c.config.ignoreFieldMismatchErrors { + return err + } + return ignoreFieldMismatchErrs(err) +} + +func ignoreFieldMismatchErrs(err error) error { + if err == nil { + return err + } + + multiErr, isMultiErr := err.(MultiError) + if isMultiErr { + foundErr := false + for i, e := range multiErr { + multiErr[i] = ignoreFieldMismatchErr(e) + if multiErr[i] != nil { + foundErr = true + } + } + if !foundErr { + return nil + } + return multiErr + } + + return ignoreFieldMismatchErr(err) +} + +func ignoreFieldMismatchErr(err error) error { + if err == nil { + return err + } + _, isFieldMismatchErr := err.(*ErrFieldMismatch) + if isFieldMismatchErr { + return nil + } + return err +} + // Close closes the Client. Call Close to clean up resources when done with the // Client. func (c *Client) Close() error { @@ -402,9 +448,9 @@ func (c *Client) Get(ctx context.Context, key *Key, dst interface{}) (err error) // as transaction id which can be ignored _, err = c.get(ctx, []*Key{key}, []interface{}{dst}, opts) if me, ok := err.(MultiError); ok { - return me[0] + return c.processFieldMismatchError(me[0]) } - return err + return c.processFieldMismatchError(err) } // GetMulti is a batch version of Get. @@ -436,7 +482,7 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst interface{}) (er // Since opts does not contain Transaction option, 'get' call below will return nil // as transaction id which can be ignored _, err = c.get(ctx, keys, dst, opts) - return err + return c.processFieldMismatchError(err) } func (c *Client) get(ctx context.Context, keys []*Key, dst interface{}, opts *pb.ReadOptions) ([]byte, error) { diff --git a/datastore/integration_test.go b/datastore/integration_test.go index 4d39ea5b3513..1d7cf3f99180 100644 --- a/datastore/integration_test.go +++ b/datastore/integration_test.go @@ -66,8 +66,8 @@ type replayInfo struct { var ( record = flag.Bool("record", false, "record RPCs") - newTestClient = func(ctx context.Context, t *testing.T) *Client { - return newClient(ctx, t, nil) + newTestClient = func(ctx context.Context, t *testing.T, opts ...option.ClientOption) *Client { + return newClient(ctx, t, nil, opts...) } testParams map[string]interface{} @@ -109,8 +109,8 @@ func testMain(m *testing.M) int { log.Fatalf("closing recorder: %v", err) } }() - newTestClient = func(ctx context.Context, t *testing.T) *Client { - return newClient(ctx, t, rec.DialOptions()) + newTestClient = func(ctx context.Context, t *testing.T, opts ...option.ClientOption) *Client { + return newClient(ctx, t, rec.DialOptions(), opts...) } log.Printf("recording to %s", replayFilename) } @@ -172,7 +172,7 @@ func initReplay() { log.Fatal(err) } - newTestClient = func(ctx context.Context, t *testing.T) *Client { + newTestClient = func(ctx context.Context, t *testing.T, opts ...option.ClientOption) *Client { grpcHeadersEnforcer := &testutil.HeadersEnforcer{ OnFailure: t.Fatalf, Checkers: []*testutil.HeaderChecker{ @@ -181,7 +181,8 @@ func initReplay() { }, } - opts := append(grpcHeadersEnforcer.CallOptions(), option.WithGRPCConn(conn)) + opts = append(opts, grpcHeadersEnforcer.CallOptions()...) + opts = append(opts, option.WithGRPCConn(conn)) client, err := NewClientWithDatabase(ctx, ri.ProjectID, testParams["databaseID"].(string), opts...) if err != nil { t.Fatalf("NewClientWithDatabase: %v", err) @@ -191,7 +192,7 @@ func initReplay() { log.Printf("replaying from %s", replayFilename) } -func newClient(ctx context.Context, t *testing.T, dialOpts []grpc.DialOption) *Client { +func newClient(ctx context.Context, t *testing.T, dialOpts []grpc.DialOption, opts ...option.ClientOption) *Client { if testing.Short() { t.Skip("Integration tests skipped in short mode") } @@ -207,7 +208,8 @@ func newClient(ctx context.Context, t *testing.T, dialOpts []grpc.DialOption) *C xGoogReqParamsHeaderChecker, }, } - opts := append(grpcHeadersEnforcer.CallOptions(), option.WithTokenSource(ts)) + opts = append(opts, grpcHeadersEnforcer.CallOptions()...) + opts = append(opts, option.WithTokenSource(ts)) for _, opt := range dialOpts { opts = append(opts, option.WithGRPCDialOption(opt)) } @@ -264,6 +266,130 @@ func TestIntegration_Basics(t *testing.T) { } } +type OldX struct { + I int + J int +} +type NewX struct { + I int + j int +} + +func TestIntegration_IgnoreFieldMismatch(t *testing.T) { + ctx := context.Background() + client := newTestClient(ctx, t, WithIgnoreFieldMismatch()) + t.Cleanup(func() { + client.Close() + }) + + // Save entities with an extra field + keys := []*Key{ + NameKey("X", "x1", nil), + NameKey("X", "x2", nil), + } + entitiesOld := []OldX{ + {I: 10, J: 20}, + {I: 30, J: 40}, + } + _, gotErr := client.PutMulti(ctx, keys, entitiesOld) + if gotErr != nil { + t.Fatalf("Failed to save: %v\n", gotErr) + } + + var wants []NewX + for _, oldX := range entitiesOld { + wants = append(wants, []NewX{{I: oldX.I}}...) + } + + t.Cleanup(func() { + client.DeleteMulti(ctx, keys) + }) + + tests := []struct { + desc string + client *Client + wantErr error + }{ + { + desc: "Without IgnoreFieldMismatch option", + client: newTestClient(ctx, t), + wantErr: &ErrFieldMismatch{ + StructType: reflect.TypeOf(NewX{}), + FieldName: "J", + Reason: "no such struct field", + }, + }, + { + desc: "With IgnoreFieldMismatch option", + client: newTestClient(ctx, t, WithIgnoreFieldMismatch()), + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + defer test.client.Close() + // FieldMismatch error in Next + query := NewQuery("X").FilterField("I", ">=", 10) + it := test.client.Run(ctx, query) + resIndex := 0 + for { + var newX NewX + _, err := it.Next(&newX) + if err == iterator.Done { + break + } + + compareIgnoreFieldMismatchResults(t, []NewX{wants[resIndex]}, []NewX{newX}, test.wantErr, err, "Next") + resIndex++ + } + + // FieldMismatch error in Get + var getX NewX + gotErr = test.client.Get(ctx, keys[0], &getX) + compareIgnoreFieldMismatchResults(t, []NewX{wants[0]}, []NewX{getX}, test.wantErr, gotErr, "Get") + + // FieldMismatch error in GetAll + var getAllX []NewX + _, gotErr = test.client.GetAll(ctx, query, &getAllX) + compareIgnoreFieldMismatchResults(t, wants, getAllX, test.wantErr, gotErr, "GetAll") + + // FieldMismatch error in GetMulti + getMultiX := make([]NewX, len(keys)) + gotErr = test.client.GetMulti(ctx, keys, getMultiX) + compareIgnoreFieldMismatchResults(t, wants, getMultiX, test.wantErr, gotErr, "GetMulti") + + tx, err := test.client.NewTransaction(ctx) + if err != nil { + t.Fatalf("tx.GetMulti got: %v, want: nil\n", err) + } + + // FieldMismatch error in tx.Get + var txGetX NewX + err = tx.Get(keys[0], &txGetX) + compareIgnoreFieldMismatchResults(t, []NewX{wants[0]}, []NewX{txGetX}, test.wantErr, err, "tx.Get") + + // FieldMismatch error in tx.GetMulti + txGetMultiX := make([]NewX, len(keys)) + err = tx.GetMulti(keys, txGetMultiX) + compareIgnoreFieldMismatchResults(t, wants, txGetMultiX, test.wantErr, err, "tx.GetMulti") + + tx.Commit() + + }) + } + +} + +func compareIgnoreFieldMismatchResults(t *testing.T, wantX []NewX, gotX []NewX, wantErr error, gotErr error, errPrefix string) { + if !equalErrs(gotErr, wantErr) { + t.Errorf("%v: error got: %v, want: %v", errPrefix, gotErr, wantErr) + } + for resIndex := 0; resIndex < len(wantX) && gotErr == nil; resIndex++ { + if wantX[resIndex].I != gotX[resIndex].I { + t.Fatalf("%v %v: got: %v, want: %v\n", errPrefix, resIndex, wantX[resIndex].I, gotX[resIndex].I) + } + } +} + func TestIntegration_GetWithReadTime(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) client := newTestClient(ctx, t) diff --git a/datastore/option.go b/datastore/option.go new file mode 100644 index 000000000000..2f5c8c89eac1 --- /dev/null +++ b/datastore/option.go @@ -0,0 +1,65 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datastore + +import ( + "google.golang.org/api/option" + "google.golang.org/api/option/internaloption" +) + +// datastoreConfig contains the Datastore client option configuration that can be +// set through datastoreClientOptions. +type datastoreConfig struct { + ignoreFieldMismatchErrors bool +} + +// newDatastoreConfig generates a new datastoreConfig with all the given +// datastoreClientOptions applied. +func newDatastoreConfig(opts ...option.ClientOption) datastoreConfig { + var conf datastoreConfig + for _, opt := range opts { + if datastoreOpt, ok := opt.(datastoreClientOption); ok { + datastoreOpt.applyDatastoreOpt(&conf) + } + } + return conf +} + +// A datastoreClientOption is an option for a Google Datastore client. +type datastoreClientOption interface { + option.ClientOption + applyDatastoreOpt(*datastoreConfig) +} + +// WithIgnoreFieldMismatch allows ignoring ErrFieldMismatch error while +// reading or querying data. +// WARNING: Ignoring ErrFieldMismatch can cause data loss while writing +// back to Datastore. E.g. +// if entity written to Datastore is {X: 1, Y:2} and it is read into +// type NewStruct struct{X int}, then {X:1} is returned. +// Now, if this is written back to Datastore, there will be no Y field +// left for this entity in Datastore +func WithIgnoreFieldMismatch() option.ClientOption { + return &withIgnoreFieldMismatch{ignoreFieldMismatchErrors: true} +} + +type withIgnoreFieldMismatch struct { + internaloption.EmbeddableAdapter + ignoreFieldMismatchErrors bool +} + +func (w *withIgnoreFieldMismatch) applyDatastoreOpt(c *datastoreConfig) { + c.ignoreFieldMismatchErrors = true +} diff --git a/datastore/query.go b/datastore/query.go index d4dec92a8159..8815bffe999b 100644 --- a/datastore/query.go +++ b/datastore/query.go @@ -825,7 +825,7 @@ func (c *Client) GetAllWithOptions(ctx context.Context, q *Query, dst interface{ } res.Keys = append(res.Keys, k) } - return res, errFieldMismatch + return res, c.processFieldMismatchError(errFieldMismatch) } // Run runs the given query in the given context @@ -1061,7 +1061,7 @@ func (t *Iterator) Next(dst interface{}) (k *Key, err error) { if dst != nil && !t.keysOnly { err = loadEntityProto(dst, e) } - return k, err + return k, t.client.processFieldMismatchError(err) } func (t *Iterator) next() (*Key, *pb.Entity, error) { diff --git a/datastore/transaction.go b/datastore/transaction.go index f91874d8aa77..8f533dec32f3 100644 --- a/datastore/transaction.go +++ b/datastore/transaction.go @@ -571,7 +571,7 @@ func (t *Transaction) get(spanName string, keys []*Key, dst interface{}) (err er if txnID != nil && err == nil { t.setToInProgress(txnID) } - return err + return t.client.processFieldMismatchError(err) } // Get is the transaction-specific version of the package function Get. @@ -582,9 +582,9 @@ func (t *Transaction) get(spanName string, keys []*Key, dst interface{}) (err er func (t *Transaction) Get(key *Key, dst interface{}) (err error) { err = t.get("cloud.google.com/go/datastore.Transaction.Get", []*Key{key}, []interface{}{dst}) if me, ok := err.(MultiError); ok { - return me[0] + return t.client.processFieldMismatchError(me[0]) } - return err + return t.client.processFieldMismatchError(err) } // GetMulti is a batch version of Get.