From bc16b5ff85b425a35a72a6978eb8ae69a8d89aa8 Mon Sep 17 00:00:00 2001 From: Terry Wilson Date: Mon, 27 Nov 2023 14:13:51 -0800 Subject: [PATCH] interop: support custom creds flag for stress test client (#6809) --- interop/stress/client/main.go | 58 ++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/interop/stress/client/main.go b/interop/stress/client/main.go index 3516ad7166bf..0055c561c557 100644 --- a/interop/stress/client/main.go +++ b/interop/stress/client/main.go @@ -25,6 +25,7 @@ import ( "fmt" "math/rand" "net" + "os" "strconv" "strings" "sync" @@ -34,27 +35,37 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/google" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/interop" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" "google.golang.org/grpc/testdata" + _ "google.golang.org/grpc/xds/googledirectpath" // Register xDS resolver required for c2p directpath. + testgrpc "google.golang.org/grpc/interop/grpc_testing" metricspb "google.golang.org/grpc/interop/stress/grpc_testing" ) +const ( + googleDefaultCredsName = "google_default_credentials" + computeEngineCredsName = "compute_engine_channel_creds" +) + var ( - serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses") - testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights") - testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds") - numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server") - numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server") - metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics") - useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") - testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") - tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") - caFile = flag.String("ca_file", "", "The file containing the CA root cert file") + serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses") + testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights") + testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds") + numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server") + numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server") + metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics") + useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") + testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") + tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") + caFile = flag.String("ca_file", "", "The file containing the CA root cert file") + customCredentialsType = flag.String("custom_credentials_type", "", "Custom credentials type to use") totalNumCalls int64 logger = grpclog.Component("stress") @@ -71,12 +82,13 @@ func parseTestCases(testCaseString string) []testCaseWithWeight { testCaseStrings := strings.Split(testCaseString, ",") testCases := make([]testCaseWithWeight, len(testCaseStrings)) for i, str := range testCaseStrings { - testCase := strings.Split(str, ":") - if len(testCase) != 2 { + testCaseNameAndWeight := strings.Split(str, ":") + if len(testCaseNameAndWeight) != 2 { panic(fmt.Sprintf("invalid test case with weight: %s", str)) } // Check if test case is supported. - switch testCase[0] { + testCaseName := strings.ToLower(testCaseNameAndWeight[0]) + switch testCaseName { case "empty_unary", "large_unary", @@ -90,10 +102,10 @@ func parseTestCases(testCaseString string) []testCaseWithWeight { "status_code_and_message", "custom_metadata": default: - panic(fmt.Sprintf("unknown test type: %s", testCase[0])) + panic(fmt.Sprintf("unknown test type: %s", testCaseNameAndWeight[0])) } - testCases[i].name = testCase[0] - w, err := strconv.Atoi(testCase[1]) + testCases[i].name = testCaseName + w, err := strconv.Atoi(testCaseNameAndWeight[1]) if err != nil { panic(fmt.Sprintf("%v", err)) } @@ -263,6 +275,7 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) { logger.Infof("use_tls: %t", *useTLS) logger.Infof("use_test_ca: %t", *testCA) logger.Infof("server_host_override: %s", *tlsServerName) + logger.Infof("custom_credentials_type: %s", *customCredentialsType) logger.Infoln("addresses:") for i, addr := range addresses { @@ -276,7 +289,15 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) { func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) { var opts []grpc.DialOption - if useTLS { + if *customCredentialsType != "" { + if *customCredentialsType == googleDefaultCredsName { + opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials())) + } else if *customCredentialsType == computeEngineCredsName { + opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials())) + } else { + logger.Fatalf("Unknown custom credentials: %v", *customCredentialsType) + } + } else if useTLS { var sn string if tlsServerName != "" { sn = tlsServerName @@ -303,6 +324,7 @@ func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.C func main() { flag.Parse() + resolver.SetDefaultScheme("dns") addresses := strings.Split(*serverAddresses, ",") tests := parseTestCases(*testCases) logParameterInfo(addresses, tests) @@ -337,6 +359,6 @@ func main() { close(stop) } wg.Wait() - logger.Infof("Total calls made: %v", totalNumCalls) + fmt.Fprintf(os.Stdout, "Total calls made: %v\n", totalNumCalls) logger.Infof(" ===== ALL DONE ===== ") }