diff --git a/applicationset/services/scm_provider/bitbucket_cloud_test.go b/applicationset/services/scm_provider/bitbucket_cloud_test.go index 359eac17e3f11..fca03e1693ade 100644 --- a/applicationset/services/scm_provider/bitbucket_cloud_test.go +++ b/applicationset/services/scm_provider/bitbucket_cloud_test.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/assert" @@ -62,7 +61,7 @@ func TestBitbucketHasRepo(t *testing.T) { })) defer func() { testServer.Close() }() - os.Setenv("BITBUCKET_API_BASE_URL", testServer.URL) + t.Setenv("BITBUCKET_API_BASE_URL", testServer.URL) cases := []struct { name, path, repo, owner, sha string status int @@ -449,7 +448,7 @@ func TestBitbucketListRepos(t *testing.T) { })) defer func() { testServer.Close() }() - os.Setenv("BITBUCKET_API_BASE_URL", testServer.URL) + t.Setenv("BITBUCKET_API_BASE_URL", testServer.URL) cases := []struct { name, proto, owner string hasError, allBranches bool diff --git a/controller/metrics/metrics_test.go b/controller/metrics/metrics_test.go index c5583cb478bff..00309fb0944a7 100644 --- a/controller/metrics/metrics_test.go +++ b/controller/metrics/metrics_test.go @@ -5,7 +5,6 @@ import ( "log" "net/http" "net/http/httptest" - "os" "strings" "testing" "time" @@ -292,8 +291,7 @@ argocd_app_labels{label_non_existing="",name="my-app-3",namespace="argocd",proje } func TestLegacyMetrics(t *testing.T) { - os.Setenv(EnvVarLegacyControllerMetrics, "true") - defer os.Unsetenv(EnvVarLegacyControllerMetrics) + t.Setenv(EnvVarLegacyControllerMetrics, "true") expectedResponse := ` # HELP argocd_app_created_time Creation time in unix timestamp for an application. diff --git a/controller/sharding/sharding_test.go b/controller/sharding/sharding_test.go index ca44bf32e2d6b..629c023c4a054 100644 --- a/controller/sharding/sharding_test.go +++ b/controller/sharding/sharding_test.go @@ -14,7 +14,7 @@ import ( ) func TestGetShardByID_NotEmptyID(t *testing.T) { - os.Setenv(common.EnvControllerReplicas, "1") + t.Setenv(common.EnvControllerReplicas, "1") assert.Equal(t, 0, LegacyDistributionFunction()(&v1alpha1.Cluster{ID: "1"})) assert.Equal(t, 0, LegacyDistributionFunction()(&v1alpha1.Cluster{ID: "2"})) assert.Equal(t, 0, LegacyDistributionFunction()(&v1alpha1.Cluster{ID: "3"})) @@ -22,21 +22,21 @@ func TestGetShardByID_NotEmptyID(t *testing.T) { } func TestGetShardByID_EmptyID(t *testing.T) { - os.Setenv(common.EnvControllerReplicas, "1") + t.Setenv(common.EnvControllerReplicas, "1") distributionFunction := LegacyDistributionFunction shard := distributionFunction()(&v1alpha1.Cluster{}) assert.Equal(t, 0, shard) } func TestGetShardByID_NoReplicas(t *testing.T) { - os.Setenv(common.EnvControllerReplicas, "0") + t.Setenv(common.EnvControllerReplicas, "0") distributionFunction := LegacyDistributionFunction shard := distributionFunction()(&v1alpha1.Cluster{}) assert.Equal(t, -1, shard) } func TestGetShardByID_NoReplicasUsingHashDistributionFunction(t *testing.T) { - os.Setenv(common.EnvControllerReplicas, "0") + t.Setenv(common.EnvControllerReplicas, "0") distributionFunction := LegacyDistributionFunction shard := distributionFunction()(&v1alpha1.Cluster{}) assert.Equal(t, -1, shard) @@ -45,8 +45,8 @@ func TestGetShardByID_NoReplicasUsingHashDistributionFunction(t *testing.T) { func TestGetShardByID_NoReplicasUsingHashDistributionFunctionWithClusters(t *testing.T) { db, cluster1, cluster2, cluster3, cluster4, cluster5 := createTestClusters() // Test with replicas set to 0 - os.Setenv(common.EnvControllerReplicas, "0") - os.Setenv(common.EnvControllerShardingAlgorithm, common.RoundRobinShardingAlgorithm) + t.Setenv(common.EnvControllerReplicas, "0") + t.Setenv(common.EnvControllerShardingAlgorithm, common.RoundRobinShardingAlgorithm) distributionFunction := RoundRobinDistributionFunction(db) assert.Equal(t, -1, distributionFunction(nil)) assert.Equal(t, -1, distributionFunction(&cluster1)) @@ -54,13 +54,12 @@ func TestGetShardByID_NoReplicasUsingHashDistributionFunctionWithClusters(t *tes assert.Equal(t, -1, distributionFunction(&cluster3)) assert.Equal(t, -1, distributionFunction(&cluster4)) assert.Equal(t, -1, distributionFunction(&cluster5)) - } func TestGetClusterFilterDefault(t *testing.T) { shardIndex := 1 // ensuring that a shard with index 1 will process all the clusters with an "even" id (2,4,6,...) os.Unsetenv(common.EnvControllerShardingAlgorithm) - os.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerReplicas, "2") filter := GetClusterFilter(GetDistributionFunction(nil, common.DefaultShardingAlgorithm), shardIndex) assert.False(t, filter(&v1alpha1.Cluster{ID: "1"})) assert.True(t, filter(&v1alpha1.Cluster{ID: "2"})) @@ -70,8 +69,8 @@ func TestGetClusterFilterDefault(t *testing.T) { func TestGetClusterFilterLegacy(t *testing.T) { shardIndex := 1 // ensuring that a shard with index 1 will process all the clusters with an "even" id (2,4,6,...) - os.Setenv(common.EnvControllerReplicas, "2") - os.Setenv(common.EnvControllerShardingAlgorithm, common.LegacyShardingAlgorithm) + t.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerShardingAlgorithm, common.LegacyShardingAlgorithm) filter := GetClusterFilter(GetDistributionFunction(nil, common.LegacyShardingAlgorithm), shardIndex) assert.False(t, filter(&v1alpha1.Cluster{ID: "1"})) assert.True(t, filter(&v1alpha1.Cluster{ID: "2"})) @@ -81,8 +80,8 @@ func TestGetClusterFilterLegacy(t *testing.T) { func TestGetClusterFilterUnknown(t *testing.T) { shardIndex := 1 // ensuring that a shard with index 1 will process all the clusters with an "even" id (2,4,6,...) - os.Setenv(common.EnvControllerReplicas, "2") - os.Setenv(common.EnvControllerShardingAlgorithm, "unknown") + t.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerShardingAlgorithm, "unknown") filter := GetClusterFilter(GetDistributionFunction(nil, "unknown"), shardIndex) assert.False(t, filter(&v1alpha1.Cluster{ID: "1"})) assert.True(t, filter(&v1alpha1.Cluster{ID: "2"})) @@ -92,7 +91,7 @@ func TestGetClusterFilterUnknown(t *testing.T) { func TestLegacyGetClusterFilterWithFixedShard(t *testing.T) { shardIndex := 1 // ensuring that a shard with index 1 will process all the clusters with an "even" id (2,4,6,...) - os.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerReplicas, "2") filter := GetClusterFilter(GetDistributionFunction(nil, common.DefaultShardingAlgorithm), shardIndex) assert.False(t, filter(nil)) assert.False(t, filter(&v1alpha1.Cluster{ID: "1"})) @@ -107,12 +106,11 @@ func TestLegacyGetClusterFilterWithFixedShard(t *testing.T) { fixedShard = 1 filter = GetClusterFilter(GetDistributionFunction(nil, common.DefaultShardingAlgorithm), int(fixedShard)) assert.True(t, filter(&v1alpha1.Cluster{Name: "cluster4", ID: "4", Shard: &fixedShard})) - } func TestRoundRobinGetClusterFilterWithFixedShard(t *testing.T) { shardIndex := 1 // ensuring that a shard with index 1 will process all the clusters with an "even" id (2,4,6,...) - os.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerReplicas, "2") db, cluster1, cluster2, cluster3, cluster4, _ := createTestClusters() filter := GetClusterFilter(GetDistributionFunction(db, common.RoundRobinShardingAlgorithm), shardIndex) @@ -135,8 +133,8 @@ func TestRoundRobinGetClusterFilterWithFixedShard(t *testing.T) { func TestGetClusterFilterLegacyHash(t *testing.T) { shardIndex := 1 // ensuring that a shard with index 1 will process all the clusters with an "even" id (2,4,6,...) - os.Setenv(common.EnvControllerReplicas, "2") - os.Setenv(common.EnvControllerShardingAlgorithm, "hash") + t.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerShardingAlgorithm, "hash") db, cluster1, cluster2, cluster3, cluster4, _ := createTestClusters() filter := GetClusterFilter(GetDistributionFunction(db, common.LegacyShardingAlgorithm), shardIndex) assert.False(t, filter(&cluster1)) @@ -158,55 +156,64 @@ func TestGetClusterFilterLegacyHash(t *testing.T) { func TestGetClusterFilterWithEnvControllerShardingAlgorithms(t *testing.T) { db, cluster1, cluster2, cluster3, cluster4, _ := createTestClusters() shardIndex := 1 - os.Setenv(common.EnvControllerReplicas, "2") - os.Setenv(common.EnvControllerShardingAlgorithm, common.LegacyShardingAlgorithm) - shardShouldProcessCluster := GetClusterFilter(GetDistributionFunction(db, common.LegacyShardingAlgorithm), shardIndex) - assert.False(t, shardShouldProcessCluster(&cluster1)) - assert.True(t, shardShouldProcessCluster(&cluster2)) - assert.False(t, shardShouldProcessCluster(&cluster3)) - assert.True(t, shardShouldProcessCluster(&cluster4)) - assert.False(t, shardShouldProcessCluster(nil)) - - os.Setenv(common.EnvControllerShardingAlgorithm, common.RoundRobinShardingAlgorithm) - shardShouldProcessCluster = GetClusterFilter(GetDistributionFunction(db, common.LegacyShardingAlgorithm), shardIndex) - assert.False(t, shardShouldProcessCluster(&cluster1)) - assert.True(t, shardShouldProcessCluster(&cluster2)) - assert.False(t, shardShouldProcessCluster(&cluster3)) - assert.True(t, shardShouldProcessCluster(&cluster4)) - assert.False(t, shardShouldProcessCluster(nil)) + t.Setenv(common.EnvControllerReplicas, "2") + + t.Run("legacy", func(t *testing.T) { + t.Setenv(common.EnvControllerShardingAlgorithm, common.LegacyShardingAlgorithm) + shardShouldProcessCluster := GetClusterFilter(GetDistributionFunction(db, common.LegacyShardingAlgorithm), shardIndex) + assert.False(t, shardShouldProcessCluster(&cluster1)) + assert.True(t, shardShouldProcessCluster(&cluster2)) + assert.False(t, shardShouldProcessCluster(&cluster3)) + assert.True(t, shardShouldProcessCluster(&cluster4)) + assert.False(t, shardShouldProcessCluster(nil)) + }) + + t.Run("roundrobin", func(t *testing.T) { + t.Setenv(common.EnvControllerShardingAlgorithm, common.RoundRobinShardingAlgorithm) + shardShouldProcessCluster := GetClusterFilter(GetDistributionFunction(db, common.LegacyShardingAlgorithm), shardIndex) + assert.False(t, shardShouldProcessCluster(&cluster1)) + assert.True(t, shardShouldProcessCluster(&cluster2)) + assert.False(t, shardShouldProcessCluster(&cluster3)) + assert.True(t, shardShouldProcessCluster(&cluster4)) + assert.False(t, shardShouldProcessCluster(nil)) + }) } func TestGetShardByIndexModuloReplicasCountDistributionFunction2(t *testing.T) { db, cluster1, cluster2, cluster3, cluster4, cluster5 := createTestClusters() - // Test with replicas set to 1 - os.Setenv(common.EnvControllerReplicas, "1") - distributionFunction := RoundRobinDistributionFunction(db) - assert.Equal(t, 0, distributionFunction(nil)) - assert.Equal(t, 0, distributionFunction(&cluster1)) - assert.Equal(t, 0, distributionFunction(&cluster2)) - assert.Equal(t, 0, distributionFunction(&cluster3)) - assert.Equal(t, 0, distributionFunction(&cluster4)) - assert.Equal(t, 0, distributionFunction(&cluster5)) - - // Test with replicas set to 2 - os.Setenv(common.EnvControllerReplicas, "2") - distributionFunction = RoundRobinDistributionFunction(db) - assert.Equal(t, 0, distributionFunction(nil)) - assert.Equal(t, 0, distributionFunction(&cluster1)) - assert.Equal(t, 1, distributionFunction(&cluster2)) - assert.Equal(t, 0, distributionFunction(&cluster3)) - assert.Equal(t, 1, distributionFunction(&cluster4)) - assert.Equal(t, 0, distributionFunction(&cluster5)) - // // Test with replicas set to 3 - os.Setenv(common.EnvControllerReplicas, "3") - distributionFunction = RoundRobinDistributionFunction(db) - assert.Equal(t, 0, distributionFunction(nil)) - assert.Equal(t, 0, distributionFunction(&cluster1)) - assert.Equal(t, 1, distributionFunction(&cluster2)) - assert.Equal(t, 2, distributionFunction(&cluster3)) - assert.Equal(t, 0, distributionFunction(&cluster4)) - assert.Equal(t, 1, distributionFunction(&cluster5)) + t.Run("replicas set to 1", func(t *testing.T) { + t.Setenv(common.EnvControllerReplicas, "1") + distributionFunction := RoundRobinDistributionFunction(db) + assert.Equal(t, 0, distributionFunction(nil)) + assert.Equal(t, 0, distributionFunction(&cluster1)) + assert.Equal(t, 0, distributionFunction(&cluster2)) + assert.Equal(t, 0, distributionFunction(&cluster3)) + assert.Equal(t, 0, distributionFunction(&cluster4)) + assert.Equal(t, 0, distributionFunction(&cluster5)) + }) + + t.Run("replicas set to 2", func(t *testing.T) { + t.Setenv(common.EnvControllerReplicas, "2") + distributionFunction := RoundRobinDistributionFunction(db) + assert.Equal(t, 0, distributionFunction(nil)) + assert.Equal(t, 0, distributionFunction(&cluster1)) + assert.Equal(t, 1, distributionFunction(&cluster2)) + assert.Equal(t, 0, distributionFunction(&cluster3)) + assert.Equal(t, 1, distributionFunction(&cluster4)) + assert.Equal(t, 0, distributionFunction(&cluster5)) + }) + + t.Run("replicas set to 3", func(t *testing.T) { + t.Setenv(common.EnvControllerReplicas, "3") + distributionFunction := RoundRobinDistributionFunction(db) + assert.Equal(t, 0, distributionFunction(nil)) + assert.Equal(t, 0, distributionFunction(&cluster1)) + assert.Equal(t, 1, distributionFunction(&cluster2)) + assert.Equal(t, 2, distributionFunction(&cluster3)) + assert.Equal(t, 0, distributionFunction(&cluster4)) + assert.Equal(t, 1, distributionFunction(&cluster5)) + }) } func TestGetShardByIndexModuloReplicasCountDistributionFunctionWhenClusterNumberIsHigh(t *testing.T) { @@ -222,7 +229,7 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunctionWhenClusterNumber clusterList.Items = append(clusterList.Items, cluster) } db.On("ListClusters", mock.Anything).Return(clusterList, nil) - os.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerReplicas, "2") distributionFunction := RoundRobinDistributionFunction(&db) for i, c := range clusterList.Items { assert.Equal(t, i%2, distributionFunction(&c)) @@ -242,7 +249,7 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunctionWhenClusterIsAdde db.On("ListClusters", mock.Anything).Return(clusterList, nil) // Test with replicas set to 2 - os.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerReplicas, "2") distributionFunction := RoundRobinDistributionFunction(&db) assert.Equal(t, 0, distributionFunction(nil)) assert.Equal(t, 0, distributionFunction(&cluster1)) @@ -259,12 +266,11 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunctionWhenClusterIsAdde // Now, we remove the last added cluster, it should be unassigned as well clusterList.Items = clusterList.Items[:len(clusterList.Items)-1] assert.Equal(t, -1, distributionFunction(&cluster6)) - } func TestGetShardByIndexModuloReplicasCountDistributionFunction(t *testing.T) { db, cluster1, cluster2, _, _, _ := createTestClusters() - os.Setenv(common.EnvControllerReplicas, "2") + t.Setenv(common.EnvControllerReplicas, "2") distributionFunction := RoundRobinDistributionFunction(db) // Test that the function returns the correct shard for cluster1 and cluster2 @@ -303,7 +309,6 @@ func TestInferShard(t *testing.T) { osHostnameFunction = func() (string, error) { return "example-shard", nil } _, err = InferShard() assert.NotNil(t, err) - } func createTestClusters() (*dbmocks.ArgoDB, v1alpha1.Cluster, v1alpha1.Cluster, v1alpha1.Cluster, v1alpha1.Cluster, v1alpha1.Cluster) { diff --git a/controller/sharding/shuffle_test.go b/controller/sharding/shuffle_test.go index 2baaa6a758ca9..9e089e31bad0f 100644 --- a/controller/sharding/shuffle_test.go +++ b/controller/sharding/shuffle_test.go @@ -3,7 +3,6 @@ package sharding import ( "fmt" "math" - "os" "testing" "github.com/argoproj/argo-cd/v2/common" @@ -24,7 +23,7 @@ func TestLargeShuffle(t *testing.T) { } db.On("ListClusters", mock.Anything).Return(clusterList, nil) // Test with replicas set to 256 - os.Setenv(common.EnvControllerReplicas, "256") + t.Setenv(common.EnvControllerReplicas, "256") distributionFunction := RoundRobinDistributionFunction(&db) for i, c := range clusterList.Items { assert.Equal(t, i%2567, distributionFunction(&c)) @@ -47,7 +46,7 @@ func TestShuffle(t *testing.T) { db.On("ListClusters", mock.Anything).Return(clusterList, nil) // Test with replicas set to 3 - os.Setenv(common.EnvControllerReplicas, "3") + t.Setenv(common.EnvControllerReplicas, "3") distributionFunction := RoundRobinDistributionFunction(&db) assert.Equal(t, 0, distributionFunction(nil)) assert.Equal(t, 0, distributionFunction(&cluster1)) diff --git a/controller/state_test.go b/controller/state_test.go index 537c0208e734b..ab004af591807 100644 --- a/controller/state_test.go +++ b/controller/state_test.go @@ -341,7 +341,6 @@ func TestAppRevisionsSingleSource(t *testing.T) { assert.NotNil(t, compRes.syncStatus) assert.NotEmpty(t, compRes.syncStatus.Revision) assert.Len(t, compRes.syncStatus.Revisions, 0) - } // TestAppRevisions tests that revisions are properly propagated for a multi source app @@ -708,9 +707,8 @@ var signedProj = argoappv1.AppProject{ } func TestSignedResponseNoSignatureRequired(t *testing.T) { - oldval := os.Getenv("ARGOCD_GPG_ENABLED") - os.Setenv("ARGOCD_GPG_ENABLED", "true") - defer os.Setenv("ARGOCD_GPG_ENABLED", oldval) + t.Setenv("ARGOCD_GPG_ENABLED", "true") + // We have a good signature response, but project does not require signed commits { app := newFakeApp() @@ -766,9 +764,7 @@ func TestSignedResponseNoSignatureRequired(t *testing.T) { } func TestSignedResponseSignatureRequired(t *testing.T) { - oldval := os.Getenv("ARGOCD_GPG_ENABLED") - os.Setenv("ARGOCD_GPG_ENABLED", "true") - defer os.Setenv("ARGOCD_GPG_ENABLED", oldval) + t.Setenv("ARGOCD_GPG_ENABLED", "true") // We have a good signature response, valid key, and signing is required - sync! { @@ -934,7 +930,7 @@ func TestSignedResponseSignatureRequired(t *testing.T) { assert.Contains(t, app.Status.Conditions[0].Message, "Cannot use local manifests") } - os.Setenv("ARGOCD_GPG_ENABLED", "false") + t.Setenv("ARGOCD_GPG_ENABLED", "false") // We have a bad signature response and signing would be required, but GPG subsystem is disabled - sync { app := newFakeApp() @@ -990,7 +986,6 @@ func TestSignedResponseSignatureRequired(t *testing.T) { assert.Len(t, compRes.managedResources, 0) assert.Len(t, app.Status.Conditions, 0) } - } func TestComparisonResult_GetHealthStatus(t *testing.T) { diff --git a/controller/sync_test.go b/controller/sync_test.go index a1a8161386436..da68e5d9a3dfe 100644 --- a/controller/sync_test.go +++ b/controller/sync_test.go @@ -2,7 +2,6 @@ package controller import ( "context" - "os" "testing" "github.com/argoproj/gitops-engine/pkg/sync" @@ -179,8 +178,7 @@ func TestSyncComparisonError(t *testing.T) { opState := &v1alpha1.OperationState{Operation: v1alpha1.Operation{ Sync: &v1alpha1.SyncOperation{}, }} - os.Setenv("ARGOCD_GPG_ENABLED", "true") - defer os.Setenv("ARGOCD_GPG_ENABLED", "false") + t.Setenv("ARGOCD_GPG_ENABLED", "true") ctrl.appStateManager.SyncAppState(app, opState) conditions := app.Status.GetConditions(map[v1alpha1.ApplicationConditionType]bool{v1alpha1.ApplicationConditionComparisonError: true}) diff --git a/pkg/apis/application/v1alpha1/types_test.go b/pkg/apis/application/v1alpha1/types_test.go index aa629529a25e9..fdabb9b009571 100644 --- a/pkg/apis/application/v1alpha1/types_test.go +++ b/pkg/apis/application/v1alpha1/types_test.go @@ -3156,7 +3156,7 @@ func TestGetCAPath(t *testing.T) { if err != nil { panic(err) } - os.Setenv(argocdcommon.EnvVarTLSDataPath, temppath) + t.Setenv(argocdcommon.EnvVarTLSDataPath, temppath) validcert := []string{ "https://foo.example.com", "oci://foo.example.com", diff --git a/util/cert/cert_test.go b/util/cert/cert_test.go index ea70adc60ea4e..3e5713102d735 100644 --- a/util/cert/cert_test.go +++ b/util/cert/cert_test.go @@ -441,31 +441,29 @@ func Test_EscapeBracketPattern(t *testing.T) { func TestGetTLSCertificateDataPath(t *testing.T) { t.Run("Get default path", func(t *testing.T) { - os.Setenv(common.EnvVarTLSDataPath, "") + t.Setenv(common.EnvVarTLSDataPath, "") path := GetTLSCertificateDataPath() assert.Equal(t, common.DefaultPathTLSConfig, path) }) t.Run("Get custom path", func(t *testing.T) { - os.Setenv(common.EnvVarTLSDataPath, "/some/where") + t.Setenv(common.EnvVarTLSDataPath, "/some/where") path := GetTLSCertificateDataPath() assert.Equal(t, "/some/where", path) - os.Setenv(common.EnvVarTLSDataPath, "") }) } func TestGetSSHKnownHostsDataPath(t *testing.T) { t.Run("Get default path", func(t *testing.T) { - os.Setenv(common.EnvVarSSHDataPath, "") + t.Setenv(common.EnvVarSSHDataPath, "") p := GetSSHKnownHostsDataPath() assert.Equal(t, path.Join(common.DefaultPathSSHConfig, "ssh_known_hosts"), p) }) t.Run("Get custom path", func(t *testing.T) { - os.Setenv(common.EnvVarSSHDataPath, "/some/where") + t.Setenv(common.EnvVarSSHDataPath, "/some/where") path := GetSSHKnownHostsDataPath() assert.Equal(t, "/some/where/ssh_known_hosts", path) - os.Setenv(common.EnvVarSSHDataPath, "") }) } @@ -480,7 +478,7 @@ func TestGetCertificateForConnect(t *testing.T) { if err != nil { panic(err) } - os.Setenv(common.EnvVarTLSDataPath, temppath) + t.Setenv(common.EnvVarTLSDataPath, temppath) certs, err := GetCertificateForConnect("127.0.0.1") assert.NoError(t, err) assert.Len(t, certs, 1) @@ -488,7 +486,7 @@ func TestGetCertificateForConnect(t *testing.T) { t.Run("No cert found", func(t *testing.T) { temppath := t.TempDir() - os.Setenv(common.EnvVarTLSDataPath, temppath) + t.Setenv(common.EnvVarTLSDataPath, temppath) certs, err := GetCertificateForConnect("127.0.0.1") assert.NoError(t, err) assert.Len(t, certs, 0) @@ -500,7 +498,7 @@ func TestGetCertificateForConnect(t *testing.T) { if err != nil { panic(err) } - os.Setenv(common.EnvVarTLSDataPath, temppath) + t.Setenv(common.EnvVarTLSDataPath, temppath) certs, err := GetCertificateForConnect("127.0.0.1") assert.Error(t, err) assert.Len(t, certs, 0) @@ -520,7 +518,7 @@ func TestGetCertBundlePathForRepository(t *testing.T) { if err != nil { panic(err) } - os.Setenv(common.EnvVarTLSDataPath, temppath) + t.Setenv(common.EnvVarTLSDataPath, temppath) certpath, err := GetCertBundlePathForRepository("127.0.0.1") assert.NoError(t, err) assert.Equal(t, certpath, path.Join(temppath, "127.0.0.1")) @@ -528,7 +526,7 @@ func TestGetCertBundlePathForRepository(t *testing.T) { t.Run("No cert found", func(t *testing.T) { temppath := t.TempDir() - os.Setenv(common.EnvVarTLSDataPath, temppath) + t.Setenv(common.EnvVarTLSDataPath, temppath) certpath, err := GetCertBundlePathForRepository("127.0.0.1") assert.NoError(t, err) assert.Empty(t, certpath) @@ -540,7 +538,7 @@ func TestGetCertBundlePathForRepository(t *testing.T) { if err != nil { panic(err) } - os.Setenv(common.EnvVarTLSDataPath, temppath) + t.Setenv(common.EnvVarTLSDataPath, temppath) certpath, err := GetCertBundlePathForRepository("127.0.0.1") assert.NoError(t, err) assert.Empty(t, certpath) diff --git a/util/config/env_test.go b/util/config/env_test.go index 2456f8b06cde4..c19961813a457 100644 --- a/util/config/env_test.go +++ b/util/config/env_test.go @@ -1,18 +1,18 @@ package config import ( - "os" "testing" "github.com/stretchr/testify/assert" ) func loadOpts(t *testing.T, opts string) { - assert.Nil(t, os.Setenv("ARGOCD_OPTS", opts)) + t.Setenv("ARGOCD_OPTS", opts) assert.Nil(t, loadFlags()) } + func loadInvalidOpts(t *testing.T, opts string) { - assert.Nil(t, os.Setenv("ARGOCD_OPTS", opts)) + t.Setenv("ARGOCD_OPTS", opts) assert.Error(t, loadFlags()) } @@ -41,6 +41,7 @@ func TestBoolFlagAtStart(t *testing.T) { assert.True(t, GetBoolFlag("foo")) } + func TestBoolFlagInMiddle(t *testing.T) { loadOpts(t, "--bar baz --foo --qux") @@ -76,6 +77,7 @@ func TestFlagWithSingleQuotes(t *testing.T) { assert.Equal(t, "bar baz", GetFlag("foo", "")) } + func TestFlagWithDoubleQuotes(t *testing.T) { loadOpts(t, "--foo \"bar baz\"") diff --git a/util/db/gpgkeys_test.go b/util/db/gpgkeys_test.go index df33ee1aecd0e..c6377c75124ee 100644 --- a/util/db/gpgkeys_test.go +++ b/util/db/gpgkeys_test.go @@ -258,8 +258,8 @@ func Test_AddGPGPublicKey(t *testing.T) { func Test_DeleteGPGPublicKey(t *testing.T) { defer os.Setenv("GNUPGHOME", "") - // Good case - { + + t.Run("good case", func(t *testing.T) { clientset := getGPGKeysClientset(gpgCMMultiGoodPubkey) settings := settings.NewSettingsManager(context.Background(), clientset, testNamespace) db := NewDB(testNamespace, settings, clientset) @@ -289,10 +289,9 @@ func Test_DeleteGPGPublicKey(t *testing.T) { n, err = db.ListConfiguredGPGPublicKeys(context.Background()) assert.NoError(t, err) assert.Len(t, n, 0) + }) - } - // Bad case - empty ConfigMap - { + t.Run("bad case - empty ConfigMap", func(t *testing.T) { clientset := getGPGKeysClientset(gpgCMEmpty) settings := settings.NewSettingsManager(context.Background(), clientset, testNamespace) db := NewDB(testNamespace, settings, clientset) @@ -300,5 +299,5 @@ func Test_DeleteGPGPublicKey(t *testing.T) { // Key should be removed err := db.DeleteGPGPublicKey(context.Background(), "F7842A5CEAA9C0B1") assert.Error(t, err) - } + }) } diff --git a/util/exec/exec_test.go b/util/exec/exec_test.go index 4740f6a8e2e54..0347abf0955e0 100644 --- a/util/exec/exec_test.go +++ b/util/exec/exec_test.go @@ -1,7 +1,6 @@ package exec import ( - "os" "os/exec" "regexp" "syscall" @@ -13,13 +12,12 @@ import ( ) func Test_timeout(t *testing.T) { - defer func() { _ = os.Unsetenv("ARGOCD_EXEC_TIMEOUT") }() t.Run("Default", func(t *testing.T) { initTimeout() assert.Equal(t, 90*time.Second, timeout) }) t.Run("Default", func(t *testing.T) { - _ = os.Setenv("ARGOCD_EXEC_TIMEOUT", "1s") + t.Setenv("ARGOCD_EXEC_TIMEOUT", "1s") initTimeout() assert.Equal(t, 1*time.Second, timeout) }) @@ -35,7 +33,7 @@ func TestHideUsernamePassword(t *testing.T) { _, err := RunWithRedactor(exec.Command("helm registry login https://charts.bitnami.com/bitnami", "--username", "foo", "--password", "bar"), nil) assert.NotEmpty(t, err) - var redactor = func(text string) string { + redactor := func(text string) string { return regexp.MustCompile("(--username|--password) [^ ]*").ReplaceAllString(text, "$1 ******") } _, err = RunWithRedactor(exec.Command("helm registry login https://charts.bitnami.com/bitnami", "--username", "foo", "--password", "bar"), redactor) diff --git a/util/git/git_test.go b/util/git/git_test.go index 5cc13a9fdc74c..0eebe354dcb13 100644 --- a/util/git/git_test.go +++ b/util/git/git_test.go @@ -160,10 +160,7 @@ func TestCustomHTTPClient(t *testing.T) { assert.Equal(t, "http://proxy:5000", proxy.String()) } - os.Setenv("http_proxy", "http://proxy-from-env:7878") - defer func() { - assert.Nil(t, os.Unsetenv("http_proxy")) - }() + t.Setenv("http_proxy", "http://proxy-from-env:7878") // Get HTTPSCreds without client cert creds, but insecure connection creds = NewHTTPSCreds("test", "test", "", "", true, "", &NoopCredsStore{}, false) @@ -199,7 +196,7 @@ func TestCustomHTTPClient(t *testing.T) { defer os.RemoveAll(temppath) err = os.WriteFile(filepath.Join(temppath, "127.0.0.1"), cert, 0666) assert.NoError(t, err) - os.Setenv(common.EnvVarTLSDataPath, temppath) + t.Setenv(common.EnvVarTLSDataPath, temppath) client = GetRepoHTTPClient("https://127.0.0.1", false, creds, "") assert.NotNil(t, client) assert.NotNil(t, client.Transport) diff --git a/util/gpg/gpg_test.go b/util/gpg/gpg_test.go index 2a88c22c217ae..97f4976aa0b66 100644 --- a/util/gpg/gpg_test.go +++ b/util/gpg/gpg_test.go @@ -35,10 +35,7 @@ func initTempDir(t *testing.T) string { panic(err) } fmt.Printf("-> Using %s as GNUPGHOME\n", p) - err = os.Setenv(common.EnvGnuPGHome, p) - if err != nil { - panic(err) - } + t.Setenv(common.EnvGnuPGHome, p) t.Cleanup(func() { err := os.RemoveAll(p) if err != nil { @@ -49,12 +46,20 @@ func initTempDir(t *testing.T) string { } func Test_IsGPGEnabled(t *testing.T) { - os.Setenv("ARGOCD_GPG_ENABLED", "true") - assert.True(t, IsGPGEnabled()) - os.Setenv("ARGOCD_GPG_ENABLED", "false") - assert.False(t, IsGPGEnabled()) - os.Setenv("ARGOCD_GPG_ENABLED", "") - assert.True(t, IsGPGEnabled()) + t.Run("true", func(t *testing.T) { + t.Setenv("ARGOCD_GPG_ENABLED", "true") + assert.True(t, IsGPGEnabled()) + }) + + t.Run("false", func(t *testing.T) { + t.Setenv("ARGOCD_GPG_ENABLED", "false") + assert.False(t, IsGPGEnabled()) + }) + + t.Run("empty", func(t *testing.T) { + t.Setenv("ARGOCD_GPG_ENABLED", "") + assert.True(t, IsGPGEnabled()) + }) } func Test_GPG_InitializeGnuPG(t *testing.T) { @@ -85,46 +90,50 @@ func Test_GPG_InitializeGnuPG(t *testing.T) { assert.Len(t, keys, 1) assert.Equal(t, keys[0].Trust, "ultimate") - // GNUPGHOME is a file - we need to error out - f, err := os.CreateTemp("", "gpg-test") - assert.NoError(t, err) - defer os.Remove(f.Name()) + t.Run("GNUPGHOME is a file", func(t *testing.T) { + f, err := os.CreateTemp("", "gpg-test") + assert.NoError(t, err) + defer os.Remove(f.Name()) - os.Setenv(common.EnvGnuPGHome, f.Name()) - err = InitializeGnuPG() - assert.Error(t, err) - assert.Contains(t, err.Error(), "does not point to a directory") + // we need to error out + t.Setenv(common.EnvGnuPGHome, f.Name()) + err = InitializeGnuPG() + assert.Error(t, err) + assert.Contains(t, err.Error(), "does not point to a directory") + }) - // Unaccessible GNUPGHOME - p = initTempDir(t) - fp := fmt.Sprintf("%s/gpg", p) - err = os.Mkdir(fp, 0000) - if err != nil { - panic(err.Error()) - } - if err != nil { - panic(err.Error()) - } - os.Setenv(common.EnvGnuPGHome, fp) - err = InitializeGnuPG() - assert.Error(t, err) - // Restore permissions so path can be deleted - err = os.Chmod(fp, 0700) - if err != nil { - panic(err.Error()) - } + t.Run("Unaccessible GNUPGHOME", func(t *testing.T) { + p := initTempDir(t) + fp := fmt.Sprintf("%s/gpg", p) + err = os.Mkdir(fp, 0o000) + if err != nil { + panic(err.Error()) + } + if err != nil { + panic(err.Error()) + } + t.Setenv(common.EnvGnuPGHome, fp) + err := InitializeGnuPG() + assert.Error(t, err) + // Restore permissions so path can be deleted + err = os.Chmod(fp, 0o700) + if err != nil { + panic(err.Error()) + } + }) - // GNUPGHOME with too wide permissions - // We do not expect an error here, because of openshift's random UIDs that - // forced us to use an emptyDir mount (#4127) - p = initTempDir(t) - err = os.Chmod(p, 0777) - if err != nil { - panic(err.Error()) - } - os.Setenv(common.EnvGnuPGHome, p) - err = InitializeGnuPG() - assert.NoError(t, err) + t.Run("GNUPGHOME with too wide permissions", func(t *testing.T) { + // We do not expect an error here, because of openshift's random UIDs that + // forced us to use an emptyDir mount (#4127) + p := initTempDir(t) + err := os.Chmod(p, 0o777) + if err != nil { + panic(err.Error()) + } + t.Setenv(common.EnvGnuPGHome, p) + err = InitializeGnuPG() + assert.NoError(t, err) + }) } func Test_GPG_KeyManagement(t *testing.T) { @@ -219,7 +228,6 @@ func Test_GPG_KeyManagement(t *testing.T) { assert.NoError(t, err) assert.Len(t, keys, 3) } - } func Test_ImportPGPKeysFromString(t *testing.T) { @@ -236,7 +244,6 @@ func Test_ImportPGPKeysFromString(t *testing.T) { assert.Contains(t, keys[0].Owner, "noreply@github.com") assert.Equal(t, "unknown", keys[0].Trust) assert.Equal(t, "unknown", keys[0].SubType) - } func Test_ValidateGPGKeysFromString(t *testing.T) { @@ -258,7 +265,6 @@ func Test_ValidateGPGKeysFromString(t *testing.T) { assert.NoError(t, err) assert.Len(t, keys, 2) } - } func Test_ValidateGPGKeys(t *testing.T) { @@ -446,16 +452,17 @@ func Test_GPG_ParseGitCommitVerification(t *testing.T) { } func Test_GetGnuPGHomePath(t *testing.T) { - { - os.Setenv(common.EnvGnuPGHome, "") + t.Run("empty", func(t *testing.T) { + t.Setenv(common.EnvGnuPGHome, "") p := common.GetGnuPGHomePath() assert.Equal(t, common.DefaultGnuPgHomePath, p) - } - { - os.Setenv(common.EnvGnuPGHome, "/tmp/gpghome") + }) + + t.Run("tempdir", func(t *testing.T) { + t.Setenv(common.EnvGnuPGHome, "/tmp/gpghome") p := common.GetGnuPGHomePath() assert.Equal(t, "/tmp/gpghome", p) - } + }) } func Test_KeyID(t *testing.T) { @@ -494,6 +501,7 @@ func Test_IsShortKeyID(t *testing.T) { assert.False(t, IsShortKeyID(longKeyID)) assert.False(t, IsShortKeyID("ab")) } + func Test_IsLongKeyID(t *testing.T) { assert.True(t, IsLongKeyID(longKeyID)) assert.False(t, IsLongKeyID(shortKeyID)) @@ -530,7 +538,6 @@ func Test_IsSecretKey(t *testing.T) { assert.NoError(t, err) assert.False(t, secret) } - } func Test_SyncKeyRingFromDirectory(t *testing.T) { diff --git a/util/log/logrus_test.go b/util/log/logrus_test.go index 109473fff3c30..06cf71fd952b0 100644 --- a/util/log/logrus_test.go +++ b/util/log/logrus_test.go @@ -1,7 +1,6 @@ package log import ( - "os" "testing" "github.com/sirupsen/logrus" @@ -15,12 +14,12 @@ func TestCreateFormatter(t *testing.T) { }) t.Run("log format is text", func(t *testing.T) { t.Run("FORCE_LOG_COLORS == 1", func(t *testing.T) { - os.Setenv("FORCE_LOG_COLORS", "1") + t.Setenv("FORCE_LOG_COLORS", "1") result := CreateFormatter("text") assert.Equal(t, &logrus.TextFormatter{ForceColors: true}, result) }) t.Run("FORCE_LOG_COLORS != 1", func(t *testing.T) { - os.Setenv("FORCE_LOG_COLORS", "0") + t.Setenv("FORCE_LOG_COLORS", "0") result := CreateFormatter("text") assert.Equal(t, &logrus.TextFormatter{}, result) }) diff --git a/util/proxy/proxy_test.go b/util/proxy/proxy_test.go index 8f50203a8d4a4..e2bee322250bd 100644 --- a/util/proxy/proxy_test.go +++ b/util/proxy/proxy_test.go @@ -3,7 +3,6 @@ package proxy import ( "net/http" "net/http/httptest" - "os" "os/exec" "testing" @@ -36,8 +35,7 @@ func TestGetCallBack(t *testing.T) { }) t.Run("custom proxy absent", func(t *testing.T) { proxyEnv := "http://proxy:8888" - os.Setenv("http_proxy", "http://proxy:8888") - defer os.Unsetenv("http_proxy") + t.Setenv("http_proxy", "http://proxy:8888") url, err := GetCallback("")(httptest.NewRequest(http.MethodGet, proxyEnv, nil)) assert.Nil(t, err) assert.Equal(t, proxyEnv, url.String()) diff --git a/util/session/sessionmanager_test.go b/util/session/sessionmanager_test.go index 52b7c76aa755d..d01ba3ef5f32d 100644 --- a/util/session/sessionmanager_test.go +++ b/util/session/sessionmanager_test.go @@ -9,7 +9,6 @@ import ( "math" "net/http" "net/http/httptest" - "os" "strconv" "strings" "testing" @@ -449,59 +448,47 @@ func TestCacheValueGetters(t *testing.T) { }) t.Run("Valid environment overrides", func(t *testing.T) { - os.Setenv(envLoginMaxFailCount, "5") - os.Setenv(envLoginMaxCacheSize, "5") + t.Setenv(envLoginMaxFailCount, "5") + t.Setenv(envLoginMaxCacheSize, "5") mlf := getMaxLoginFailures() assert.Equal(t, 5, mlf) mcs := getMaximumCacheSize() assert.Equal(t, 5, mcs) - - os.Setenv(envLoginMaxFailCount, "") - os.Setenv(envLoginMaxCacheSize, "") }) t.Run("Invalid environment overrides", func(t *testing.T) { - os.Setenv(envLoginMaxFailCount, "invalid") - os.Setenv(envLoginMaxCacheSize, "invalid") + t.Setenv(envLoginMaxFailCount, "invalid") + t.Setenv(envLoginMaxCacheSize, "invalid") mlf := getMaxLoginFailures() assert.Equal(t, defaultMaxLoginFailures, mlf) mcs := getMaximumCacheSize() assert.Equal(t, defaultMaxCacheSize, mcs) - - os.Setenv(envLoginMaxFailCount, "") - os.Setenv(envLoginMaxCacheSize, "") }) t.Run("Less than allowed in environment overrides", func(t *testing.T) { - os.Setenv(envLoginMaxFailCount, "-1") - os.Setenv(envLoginMaxCacheSize, "-1") + t.Setenv(envLoginMaxFailCount, "-1") + t.Setenv(envLoginMaxCacheSize, "-1") mlf := getMaxLoginFailures() assert.Equal(t, defaultMaxLoginFailures, mlf) mcs := getMaximumCacheSize() assert.Equal(t, defaultMaxCacheSize, mcs) - - os.Setenv(envLoginMaxFailCount, "") - os.Setenv(envLoginMaxCacheSize, "") }) t.Run("Greater than allowed in environment overrides", func(t *testing.T) { - os.Setenv(envLoginMaxFailCount, fmt.Sprintf("%d", math.MaxInt32+1)) - os.Setenv(envLoginMaxCacheSize, fmt.Sprintf("%d", math.MaxInt32+1)) + t.Setenv(envLoginMaxFailCount, fmt.Sprintf("%d", math.MaxInt32+1)) + t.Setenv(envLoginMaxCacheSize, fmt.Sprintf("%d", math.MaxInt32+1)) mlf := getMaxLoginFailures() assert.Equal(t, defaultMaxLoginFailures, mlf) mcs := getMaximumCacheSize() assert.Equal(t, defaultMaxCacheSize, mcs) - - os.Setenv(envLoginMaxFailCount, "") - os.Setenv(envLoginMaxCacheSize, "") }) } @@ -561,7 +548,7 @@ func TestMaxCacheSize(t *testing.T) { invalidUsers := []string{"invalid1", "invalid2", "invalid3", "invalid4", "invalid5", "invalid6", "invalid7"} // Temporarily decrease max cache size - os.Setenv(envLoginMaxCacheSize, "5") + t.Setenv(envLoginMaxCacheSize, "5") for _, user := range invalidUsers { err := mgr.VerifyUsernamePassword(user, "password") @@ -577,7 +564,7 @@ func TestFailedAttemptsExpiry(t *testing.T) { invalidUsers := []string{"invalid1", "invalid2", "invalid3", "invalid4", "invalid5", "invalid6", "invalid7"} - os.Setenv(envLoginFailureWindowSeconds, "1") + t.Setenv(envLoginFailureWindowSeconds, "1") for _, user := range invalidUsers { err := mgr.VerifyUsernamePassword(user, "password") @@ -589,8 +576,6 @@ func TestFailedAttemptsExpiry(t *testing.T) { err := mgr.VerifyUsernamePassword("invalid8", "password") assert.Error(t, err) assert.Len(t, mgr.GetLoginFailures(), 1) - - os.Setenv(envLoginFailureWindowSeconds, "") } func getKubeClientWithConfig(config map[string]string, secretConfig map[string][]byte) *fake.Clientset {