diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 52591e01f45f..974e06c732a5 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -47,16 +47,29 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -func TestMain(m *testing.M) { - databaseIDs := []string{DefaultDatabaseID} - databasesStr, ok := os.LookupEnv(envDatabases) - if ok { - databaseIDs = append(databaseIDs, strings.Split(databasesStr, ",")...) - } +type firestoreEdition int + +const ( + editionStandard firestoreEdition = iota // 0 + editionEnterprise // 1 +) + +const ( + envProjID = "GCLOUD_TESTS_GOLANG_FIRESTORE_PROJECT_ID" + envPrivateKey = "GCLOUD_TESTS_GOLANG_FIRESTORE_KEY" + envDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_DATABASES" + envEnterpriseDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_ENTERPRISE_DATABASES" + envEmulator = "FIRESTORE_EMULATOR_HOST" + indexBuilding = "index is currently building" + databaseIDKey = "databaseID" + firestoreEditionKey = "edition" +) +func TestMain(m *testing.M) { testParams = make(map[string]interface{}) - for _, databaseID := range databaseIDs { - testParams["databaseID"] = databaseID + for databaseID, edition := range parseDatabases() { + testParams[databaseIDKey] = databaseID + testParams[firestoreEditionKey] = edition initIntegrationTest() status := m.Run() if status != 0 { @@ -68,13 +81,26 @@ func TestMain(m *testing.M) { os.Exit(0) } -const ( - envProjID = "GCLOUD_TESTS_GOLANG_FIRESTORE_PROJECT_ID" - envPrivateKey = "GCLOUD_TESTS_GOLANG_FIRESTORE_KEY" - envDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_DATABASES" - envEmulator = "FIRESTORE_EMULATOR_HOST" - indexBuilding = "index is currently building" -) +func parseDatabases() map[string]firestoreEdition { + databases := map[string]firestoreEdition{ + DefaultDatabaseID: editionStandard, + } + + databasesStr, ok := os.LookupEnv(envDatabases) + if ok { + for _, databaseID := range strings.Split(databasesStr, ",") { + databases[databaseID] = editionStandard + } + } + + databasesStr, ok = os.LookupEnv(envEnterpriseDatabases) + if ok { + for _, databaseID := range strings.Split(databasesStr, ",") { + databases[databaseID] = editionEnterprise + } + } + return databases +} var ( iClient *Client @@ -88,7 +114,7 @@ var ( ) func initIntegrationTest() { - databaseID := testParams["databaseID"].(string) + databaseID := testParams[databaseIDKey].(string) log.Printf("Setting up tests to run on databaseID: %q\n", databaseID) flag.Parse() // needed for testing.Short() if testing.Short() { @@ -2730,12 +2756,12 @@ func TestIntegration_NewClientWithDatabase(t *testing.T) { }{ { desc: "Success", - dbName: testParams["databaseID"].(string), + dbName: testParams[databaseIDKey].(string), wantErr: false, }, { desc: "Error from NewClient bubbled to NewClientWithDatabase", - dbName: testParams["databaseID"].(string), + dbName: testParams[databaseIDKey].(string), wantErr: true, opt: []option.ClientOption{option.WithCredentialsFile("non existent filepath")}, }, diff --git a/firestore/pipeline_result.go b/firestore/pipeline_result.go index a6d17917ba93..c06da71bd82f 100644 --- a/firestore/pipeline_result.go +++ b/firestore/pipeline_result.go @@ -92,7 +92,11 @@ func (p *PipelineResult) Data() (map[string]any, error) { if p == nil { return nil, status.Errorf(codes.NotFound, "result does not exist") } - m, err := createMapFromValueMap(p.proto.Fields, p.c) + var fields map[string]*pb.Value + if p.proto != nil { + fields = p.proto.Fields + } + m, err := createMapFromValueMap(fields, p.c) // Any error here is a bug in the client. if err != nil { panic(fmt.Sprintf("firestore: %v", err)) @@ -107,7 +111,11 @@ func (p *PipelineResult) DataTo(v any) error { if p == nil { return status.Errorf(codes.NotFound, "document does not exist") } - return setFromProtoValue(v, &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: p.proto.Fields}}}, p.c) + var fields map[string]*pb.Value + if p.proto != nil { + fields = p.proto.Fields + } + return setFromProtoValue(v, &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: fields}}}, p.c) } // PipelineResultIterator is an iterator over PipelineResults from a pipeline execution. diff --git a/firestore/pipeline_result_test.go b/firestore/pipeline_result_test.go index d457033c0b9f..fb8c3ba5358c 100644 --- a/firestore/pipeline_result_test.go +++ b/firestore/pipeline_result_test.go @@ -245,6 +245,11 @@ func TestPipelineResultIterator_GetAll(t *testing.T) { if data["id"].(int64) != 1 { t.Errorf("first result id: got %v, want: 1", data["id"]) } + + data, err = allResults[1].Data() + if err != nil { + t.Fatalf("Data: %v", err) + } if data["id"].(int64) != 2 { t.Errorf("second result id: got %v, want: 2", data["id"]) } @@ -357,11 +362,14 @@ func TestPipelineResult_NoResults(t *testing.T) { } data, err := pr.Data() - if err == nil { - t.Errorf("pr.Data() for non-existent result err: got nil, want %v", err) + if err != nil { + t.Errorf("pr.Data() for non-existent result err: got %v, want nil", err) + } + if data == nil { + t.Errorf("pr.Data() for non-existent result: got nil, want non-nil empty map") } - if data != nil { - t.Errorf("pr.Data() for non-existent result: got %v, want nil. Err: got", data) + if len(data) != 0 { + t.Errorf("pr.Data() for non-existent result: got map with %d elements, want empty map", len(data)) } type MyStruct struct{ Foo string }