Skip to content

Commit

Permalink
interop: support custom creds flag for stress test client (#6809)
Browse files Browse the repository at this point in the history
  • Loading branch information
temawi authored Nov 27, 2023
1 parent 02ea031 commit bc16b5f
Showing 1 changed file with 40 additions and 18 deletions.
58 changes: 40 additions & 18 deletions interop/stress/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"math/rand"
"net"
"os"
"strconv"
"strings"
"sync"
Expand All @@ -34,27 +35,37 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/google"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/interop"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"

_ "google.golang.org/grpc/xds/googledirectpath" // Register xDS resolver required for c2p directpath.

testgrpc "google.golang.org/grpc/interop/grpc_testing"
metricspb "google.golang.org/grpc/interop/stress/grpc_testing"
)

const (
googleDefaultCredsName = "google_default_credentials"
computeEngineCredsName = "compute_engine_channel_creds"
)

var (
serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
customCredentialsType = flag.String("custom_credentials_type", "", "Custom credentials type to use")

totalNumCalls int64
logger = grpclog.Component("stress")
Expand All @@ -71,12 +82,13 @@ func parseTestCases(testCaseString string) []testCaseWithWeight {
testCaseStrings := strings.Split(testCaseString, ",")
testCases := make([]testCaseWithWeight, len(testCaseStrings))
for i, str := range testCaseStrings {
testCase := strings.Split(str, ":")
if len(testCase) != 2 {
testCaseNameAndWeight := strings.Split(str, ":")
if len(testCaseNameAndWeight) != 2 {
panic(fmt.Sprintf("invalid test case with weight: %s", str))
}
// Check if test case is supported.
switch testCase[0] {
testCaseName := strings.ToLower(testCaseNameAndWeight[0])
switch testCaseName {
case
"empty_unary",
"large_unary",
Expand All @@ -90,10 +102,10 @@ func parseTestCases(testCaseString string) []testCaseWithWeight {
"status_code_and_message",
"custom_metadata":
default:
panic(fmt.Sprintf("unknown test type: %s", testCase[0]))
panic(fmt.Sprintf("unknown test type: %s", testCaseNameAndWeight[0]))
}
testCases[i].name = testCase[0]
w, err := strconv.Atoi(testCase[1])
testCases[i].name = testCaseName
w, err := strconv.Atoi(testCaseNameAndWeight[1])
if err != nil {
panic(fmt.Sprintf("%v", err))
}
Expand Down Expand Up @@ -263,6 +275,7 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
logger.Infof("use_tls: %t", *useTLS)
logger.Infof("use_test_ca: %t", *testCA)
logger.Infof("server_host_override: %s", *tlsServerName)
logger.Infof("custom_credentials_type: %s", *customCredentialsType)

logger.Infoln("addresses:")
for i, addr := range addresses {
Expand All @@ -276,7 +289,15 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) {

func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
var opts []grpc.DialOption
if useTLS {
if *customCredentialsType != "" {
if *customCredentialsType == googleDefaultCredsName {
opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials()))
} else if *customCredentialsType == computeEngineCredsName {
opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()))
} else {
logger.Fatalf("Unknown custom credentials: %v", *customCredentialsType)
}
} else if useTLS {
var sn string
if tlsServerName != "" {
sn = tlsServerName
Expand All @@ -303,6 +324,7 @@ func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.C

func main() {
flag.Parse()
resolver.SetDefaultScheme("dns")
addresses := strings.Split(*serverAddresses, ",")
tests := parseTestCases(*testCases)
logParameterInfo(addresses, tests)
Expand Down Expand Up @@ -337,6 +359,6 @@ func main() {
close(stop)
}
wg.Wait()
logger.Infof("Total calls made: %v", totalNumCalls)
fmt.Fprintf(os.Stdout, "Total calls made: %v\n", totalNumCalls)
logger.Infof(" ===== ALL DONE ===== ")
}

0 comments on commit bc16b5f

Please sign in to comment.