diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 52591e01f45f..974e06c732a5 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -47,16 +47,29 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -func TestMain(m *testing.M) { - databaseIDs := []string{DefaultDatabaseID} - databasesStr, ok := os.LookupEnv(envDatabases) - if ok { - databaseIDs = append(databaseIDs, strings.Split(databasesStr, ",")...) - } +type firestoreEdition int + +const ( + editionStandard firestoreEdition = iota // 0 + editionEnterprise // 1 +) + +const ( + envProjID = "GCLOUD_TESTS_GOLANG_FIRESTORE_PROJECT_ID" + envPrivateKey = "GCLOUD_TESTS_GOLANG_FIRESTORE_KEY" + envDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_DATABASES" + envEnterpriseDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_ENTERPRISE_DATABASES" + envEmulator = "FIRESTORE_EMULATOR_HOST" + indexBuilding = "index is currently building" + databaseIDKey = "databaseID" + firestoreEditionKey = "edition" +) +func TestMain(m *testing.M) { testParams = make(map[string]interface{}) - for _, databaseID := range databaseIDs { - testParams["databaseID"] = databaseID + for databaseID, edition := range parseDatabases() { + testParams[databaseIDKey] = databaseID + testParams[firestoreEditionKey] = edition initIntegrationTest() status := m.Run() if status != 0 { @@ -68,13 +81,26 @@ func TestMain(m *testing.M) { os.Exit(0) } -const ( - envProjID = "GCLOUD_TESTS_GOLANG_FIRESTORE_PROJECT_ID" - envPrivateKey = "GCLOUD_TESTS_GOLANG_FIRESTORE_KEY" - envDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_DATABASES" - envEmulator = "FIRESTORE_EMULATOR_HOST" - indexBuilding = "index is currently building" -) +func parseDatabases() map[string]firestoreEdition { + databases := map[string]firestoreEdition{ + DefaultDatabaseID: editionStandard, + } + + databasesStr, ok := os.LookupEnv(envDatabases) + if ok { + for _, databaseID := range strings.Split(databasesStr, ",") { + databases[databaseID] = editionStandard + } + } + + databasesStr, ok = os.LookupEnv(envEnterpriseDatabases) + if ok { + for _, databaseID := range strings.Split(databasesStr, ",") { + databases[databaseID] = editionEnterprise + } + } + return databases +} var ( iClient *Client @@ -88,7 +114,7 @@ var ( ) func initIntegrationTest() { - databaseID := testParams["databaseID"].(string) + databaseID := testParams[databaseIDKey].(string) log.Printf("Setting up tests to run on databaseID: %q\n", databaseID) flag.Parse() // needed for testing.Short() if testing.Short() { @@ -2730,12 +2756,12 @@ func TestIntegration_NewClientWithDatabase(t *testing.T) { }{ { desc: "Success", - dbName: testParams["databaseID"].(string), + dbName: testParams[databaseIDKey].(string), wantErr: false, }, { desc: "Error from NewClient bubbled to NewClientWithDatabase", - dbName: testParams["databaseID"].(string), + dbName: testParams[databaseIDKey].(string), wantErr: true, opt: []option.ClientOption{option.WithCredentialsFile("non existent filepath")}, },