diff --git a/firestore/pipeline_result.go b/firestore/pipeline_result.go index fc66342d1c83..24c77a6b1ced 100644 --- a/firestore/pipeline_result.go +++ b/firestore/pipeline_result.go @@ -88,25 +88,23 @@ func newPipelineResult(ref *DocumentRef, proto *pb.Document, c *Client, executio // // var m map[string]interface{} // p.DataTo(&m) -// -// except that it returns nil if the document does not exist. -func (p *PipelineResult) Data() map[string]interface{} { - if !p.Exists() { - return nil +func (p *PipelineResult) Data() (map[string]interface{}, error) { + if p == nil { + return nil, status.Errorf(codes.NotFound, "result does not exist") } m, err := createMapFromValueMap(p.proto.Fields, p.c) // Any error here is a bug in the client. if err != nil { panic(fmt.Sprintf("firestore: %v", err)) } - return m + return m, nil } // DataTo uses the PipelineResult's fields to populate v, which can be a pointer to a // map[string]interface{} or a pointer to a struct. // This is similar to [DocumentSnapshot.DataTo] func (p *PipelineResult) DataTo(v interface{}) error { - if !p.Exists() { + if p == nil { return status.Errorf(codes.NotFound, "document does not exist") } return setFromProtoValue(v, &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: p.proto.Fields}}}, p.c) diff --git a/firestore/pipeline_result_test.go b/firestore/pipeline_result_test.go index 143059e2c085..3cbe4d4efea4 100644 --- a/firestore/pipeline_result_test.go +++ b/firestore/pipeline_result_test.go @@ -151,7 +151,11 @@ func TestStreamPipelineResultIterator_Next(t *testing.T) { t.Fatalf("Result count mismatch for data check: expected %d, got %d", len(tc.wantData), len(results)) } for i, pr := range results { - if diff := cmp.Diff(tc.wantData[i], pr.Data()); diff != "" { + data, err := pr.Data() + if err != nil { + t.Fatalf("Data: %v", err) + } + if diff := cmp.Diff(tc.wantData[i], data); diff != "" { t.Errorf("Data mismatch for result %d (-want +got):\n%s", i, diff) } } @@ -233,11 +237,16 @@ func TestPipelineResultIterator_GetAll(t *testing.T) { if len(allResults) != 2 { t.Errorf("results from GetAll(): got %d, want: 2", len(allResults)) } - if allResults[0].Data()["id"].(int64) != 1 { - t.Errorf("first result id: got %v, want: 1", allResults[0].Data()["id"]) + + data, err := allResults[0].Data() + if err != nil { + t.Fatalf("Data: %v", err) + } + if data["id"].(int64) != 1 { + t.Errorf("first result id: got %v, want: 1", data["id"]) } - if allResults[1].Data()["id"].(int64) != 2 { - t.Errorf("second result id: got %v, want: 2", allResults[1].Data()["id"]) + if data["id"].(int64) != 2 { + t.Errorf("second result id: got %v, want: 2", data["id"]) } // After GetAll, Next should return iterator.Done @@ -275,12 +284,12 @@ func TestPipelineResult_DataExtraction(t *testing.T) { t.Fatalf("newPipelineResult: %v", err) } - if !pr.Exists() { - t.Error("pr.Exists: got false, want true") + // Test Data() + dataMap, err := pr.Data() + if err != nil { + t.Fatalf("Data: %+v", err) } - // Test Data() - dataMap := pr.Data() if dataMap["stringProp"].(string) != "hello" { t.Errorf("stringProp: got %v, want 'hello'", dataMap["stringProp"]) } @@ -347,11 +356,12 @@ func TestPipelineResult_NoResults(t *testing.T) { t.Fatalf("newPipelineResult: %v", err) } - if pr.Exists() { - t.Error("pr.Exists() for non-existent result: got true, want false") + data, err := pr.Data() + if err == nil { + t.Errorf("pr.Data() for non-existent result err: got nil, want %v", err) } - if data := pr.Data(); data != nil { - t.Errorf("pr.Data() for non-existent result: got %v, want nil", data) + if data != nil { + t.Errorf("pr.Data() for non-existent result: got %v, want nil. Err: got", data) } type MyStruct struct{ Foo string } diff --git a/firestore/pipeline_source.go b/firestore/pipeline_source.go index b6c277f0e0fa..4fc499caff75 100644 --- a/firestore/pipeline_source.go +++ b/firestore/pipeline_source.go @@ -20,7 +20,38 @@ type PipelineSource struct { client *Client } -// Collection returns all documents from the entire collection. +// Collection creates a new [Pipeline] that operates on the specified Firestore collection. func (ps *PipelineSource) Collection(path string) *Pipeline { return newPipeline(ps.client, newInputStageCollection(path)) } + +// CollectionGroup creates a new [Pipeline] that operates on all documents in a group +// of collections that include the given ID, regardless of parent document. +// +// For example, consider: +// France/Cities/Paris = {population: 100} +// Canada/Cities/Montreal = {population: 90} +// +// CollectionGroup can be used to query across all "Cities" regardless of +// its parent "Countries". +func (ps *PipelineSource) CollectionGroup(collectionID string) *Pipeline { + return newPipeline(ps.client, newInputStageCollectionGroup("", collectionID)) +} + +// CollectionGroupWithAncestor creates a new [Pipeline] that operates on all documents in a group +// of collections that include the given ID, that are underneath a given document. +// +// For example, consider: +// /continents/Europe/Germany/Cities/Paris = {population: 100} +// /continents/Europe/France/Cities/Paris = {population: 100} +// /continents/NorthAmerica/Canada/Cities/Montreal = {population: 90} +// +// CollectionGroupWithAncestor can be used to query across all "Cities" in "/continents/Europe". +func (ps *PipelineSource) CollectionGroupWithAncestor(ancestor, collectionID string) *Pipeline { + return newPipeline(ps.client, newInputStageCollectionGroup(ancestor, collectionID)) +} + +// Database creates a new [Pipeline] that operates on all documents in the Firestore database. +func (ps *PipelineSource) Database() *Pipeline { + return newPipeline(ps.client, newInputStageDatabase()) +} diff --git a/firestore/pipeline_source_test.go b/firestore/pipeline_source_test.go index e6c41f70d8a5..c3d979721f54 100644 --- a/firestore/pipeline_source_test.go +++ b/firestore/pipeline_source_test.go @@ -50,3 +50,99 @@ func TestPipelineSource_Collection(t *testing.T) { t.Errorf("toExecutePipelineRequest mismatch for collection stage (-want +got):\n%s", diff) } } + +func TestPipelineSource_CollectionGroup(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.CollectionGroup("cities") + + if p.err != nil { + t.Fatalf("CollectionGroup: %v", p.err) + } + if len(p.stages) != 1 { + t.Fatalf("initial stages: got %d, want 1", len(p.stages)) + } + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("toExecutePipelineRequest: %v", err) + } + + wantStage := &pb.Pipeline_Stage{ + Name: "collection_group", + Args: []*pb.Value{ + {ValueType: &pb.Value_ReferenceValue{ReferenceValue: ""}}, + {ValueType: &pb.Value_StringValue{StringValue: "cities"}}, + }, + } + + if len(req.GetStructuredPipeline().GetPipeline().GetStages()) != 1 { + t.Fatalf("stage in proto: got %d, want 1", len(req.GetStructuredPipeline().GetPipeline().GetStages())) + } + if diff := testutil.Diff(wantStage, req.GetStructuredPipeline().GetPipeline().GetStages()[0]); diff != "" { + t.Errorf("toExecutePipelineRequest mismatch for collectionGroup stage (-want +got):\n%s", diff) + } +} + +func TestPipelineSource_CollectionGroupWithAncestor(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.CollectionGroupWithAncestor("ancestor/path", "items") + + if p.err != nil { + t.Fatalf("CollectionGroupWithAncestor: %v", p.err) + } + if len(p.stages) != 1 { + t.Fatalf("initial stages: got %d, want 1", len(p.stages)) + } + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("toExecutePipelineRequest: %v", err) + } + + wantStage := &pb.Pipeline_Stage{ + Name: "collection_group", + Args: []*pb.Value{ + {ValueType: &pb.Value_ReferenceValue{ReferenceValue: "ancestor/path"}}, + {ValueType: &pb.Value_StringValue{StringValue: "items"}}, + }, + } + + if len(req.GetStructuredPipeline().GetPipeline().GetStages()) != 1 { + t.Fatalf("stage in proto: got %d, want 1", len(req.GetStructuredPipeline().GetPipeline().GetStages())) + } + if diff := testutil.Diff(wantStage, req.GetStructuredPipeline().GetPipeline().GetStages()[0]); diff != "" { + t.Errorf("toExecutePipelineRequest mismatch for collectionGroupWithAncestor stage (-want +got):\n%s", diff) + } +} + +func TestPipelineSource_Database(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Database() + + if p.err != nil { + t.Fatalf("Database: %v", p.err) + } + if len(p.stages) != 1 { + t.Fatalf("initial stages: got %d, want 1", len(p.stages)) + } + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("toExecutePipelineRequest: %v", err) + } + + wantStage := &pb.Pipeline_Stage{ + Name: "database", + Args: nil, + } + + if len(req.GetStructuredPipeline().GetPipeline().GetStages()) != 1 { + t.Fatalf("stage in proto: got %d, want 1", len(req.GetStructuredPipeline().GetPipeline().GetStages())) + } + if diff := testutil.Diff(wantStage, req.GetStructuredPipeline().GetPipeline().GetStages()[0]); diff != "" { + t.Errorf("toExecutePipelineRequest mismatch for database stage (-want +got):\n%s", diff) + } +} diff --git a/firestore/pipeline_stage.go b/firestore/pipeline_stage.go index 1a5a95d5eb42..058a796915c6 100644 --- a/firestore/pipeline_stage.go +++ b/firestore/pipeline_stage.go @@ -46,6 +46,40 @@ func (s *inputStageCollection) toProto() (*pb.Pipeline_Stage, error) { }, nil } +// inputStageCollection returns all documents from the entire collection. +type inputStageCollectionGroup struct { + collectionID string + ancestor string +} + +func newInputStageCollectionGroup(ancestor, collectionID string) *inputStageCollectionGroup { + return &inputStageCollectionGroup{ancestor: ancestor, collectionID: collectionID} +} +func (s *inputStageCollectionGroup) name() string { return "collection_group" } +func (s *inputStageCollectionGroup) toProto() (*pb.Pipeline_Stage, error) { + ancestor := &pb.Value{ValueType: &pb.Value_ReferenceValue{ReferenceValue: s.ancestor}} + collectionID := &pb.Value{ValueType: &pb.Value_StringValue{StringValue: s.collectionID}} + return &pb.Pipeline_Stage{ + Name: s.name(), + Args: []*pb.Value{ancestor, collectionID}, + }, nil +} + +// inputStageDatabase returns all documents from the entire database. +type inputStageDatabase struct { + path string +} + +func newInputStageDatabase() *inputStageDatabase { + return &inputStageDatabase{} +} +func (s *inputStageDatabase) name() string { return "database" } +func (s *inputStageDatabase) toProto() (*pb.Pipeline_Stage, error) { + return &pb.Pipeline_Stage{ + Name: s.name(), + }, nil +} + type limitStage struct { limit int }