diff --git a/go/cmd/vtbench/vtbench.go b/go/cmd/vtbench/vtbench.go index deab7f4e6be..663f45568e7 100644 --- a/go/cmd/vtbench/vtbench.go +++ b/go/cmd/vtbench/vtbench.go @@ -18,16 +18,16 @@ package main import ( "context" - "flag" "fmt" "strings" "time" "github.com/spf13/pflag" + "vitess.io/vitess/go/vt/grpccommon" + "vitess.io/vitess/go/exit" "vitess.io/vitess/go/vt/dbconfigs" - "vitess.io/vitess/go/vt/grpccommon" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/logutil" "vitess.io/vitess/go/vt/servenv" @@ -37,9 +37,6 @@ import ( _ "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" // Import and register the gRPC tabletconn client _ "vitess.io/vitess/go/vt/vttablet/grpctabletconn" - - // Include deprecation warnings for soon-to-be-unsupported flag invocations. - _flag "vitess.io/vitess/go/internal/flag" ) /* @@ -86,37 +83,48 @@ import ( */ var ( - // connection flags - host = flag.String("host", "", "vtgate host(s) in the form 'host1,host2,...'") - port = flag.Int("port", 0, "vtgate port") - unixSocket = flag.String("unix_socket", "", "vtgate unix socket") - protocol = flag.String("protocol", "mysql", "client protocol, either mysql (default), grpc-vtgate, or grpc-vttablet") - user = flag.String("user", "", "username to connect using mysql (password comes from the db-credentials-file)") - db = flag.String("db", "", "db name to use when connecting / running the queries (e.g. @replica, keyspace, keyspace/shard etc)") - - // test flags - deadline = flag.Duration("deadline", 5*time.Minute, "maximum duration for the test run (default 5 minutes)") - sql = flag.String("sql", "", "sql statement to execute") - threads = flag.Int("threads", 2, "number of parallel threads to run") - count = flag.Int("count", 1000, "number of queries per thread") + host, unixSocket, user, db, sql string + port int + protocol = "mysql" + deadline = 5 * time.Minute + threads = 2 + count = 1000 ) -func main() { - logger := logutil.NewConsoleLogger() - flag.CommandLine.SetOutput(logutil.NewLoggerWriter(logger)) +func initFlags(fs *pflag.FlagSet) { + fs.StringVar(&host, "host", host, "VTGate host(s) in the form 'host1,host2,...'") + fs.IntVar(&port, "port", port, "VTGate port") + fs.StringVar(&unixSocket, "unix_socket", unixSocket, "VTGate unix socket") + fs.StringVar(&protocol, "protocol", protocol, "Client protocol, either mysql (default), grpc-vtgate, or grpc-vttablet") + fs.StringVar(&user, "user", user, "Username to connect using mysql (password comes from the db-credentials-file)") + fs.StringVar(&db, "db", db, "Database name to use when connecting / running the queries (e.g. @replica, keyspace, keyspace/shard etc)") - defer exit.Recover() + fs.DurationVar(&deadline, "deadline", deadline, "Maximum duration for the test run (default 5 minutes)") + fs.StringVar(&sql, "sql", sql, "SQL statement to execute") + fs.IntVar(&threads, "threads", threads, "Number of parallel threads to run") + fs.IntVar(&count, "count", count, "Number of queries per thread") - flag.Lookup("logtostderr").Value.Set("true") - fs := pflag.NewFlagSet("vtbench", pflag.ExitOnError) grpccommon.RegisterFlags(fs) log.RegisterFlags(fs) logutil.RegisterFlags(fs) servenv.RegisterMySQLServerFlags(fs) - _flag.Parse(fs) +} + +func main() { + servenv.OnParseFor("vtbench", func(fs *pflag.FlagSet) { + logger := logutil.NewConsoleLogger() + fs.SetOutput(logutil.NewLoggerWriter(logger)) + + initFlags(fs) + _ = fs.Set("logtostderr", "true") + }) + + servenv.ParseFlags("vtbench") + + defer exit.Recover() clientProto := vtbench.MySQL - switch *protocol { + switch protocol { case "", "mysql": clientProto = vtbench.MySQL case "grpc-vtgate": @@ -124,51 +132,51 @@ func main() { case "grpc-vttablet": clientProto = vtbench.GRPCVttablet default: - log.Exitf("invalid client protocol %s", *protocol) + log.Exitf("invalid client protocol %s", protocol) } - if (*host != "" || *port != 0) && *unixSocket != "" { + if (host != "" || port != 0) && unixSocket != "" { log.Exitf("can't specify both host:port and unix_socket") } - if *host != "" && *port == 0 { + if host != "" && port == 0 { log.Exitf("must specify port when using host") } - if *host == "" && *port != 0 { + if host == "" && port != 0 { log.Exitf("must specify host when using port") } - if *host == "" && *port == 0 && *unixSocket == "" { + if host == "" && port == 0 && unixSocket == "" { log.Exitf("vtbench requires either host/port or unix_socket") } - if *sql == "" { + if sql == "" { log.Exitf("must specify sql") } var password string if clientProto == vtbench.MySQL { var err error - _, password, err = dbconfigs.GetCredentialsServer().GetUserAndPassword(*user) + _, password, err = dbconfigs.GetCredentialsServer().GetUserAndPassword(user) if err != nil { - log.Exitf("error reading password for user %v from file: %v", *user, err) + log.Exitf("error reading password for user %v from file: %v", user, err) } } connParams := vtbench.ConnParams{ - Hosts: strings.Split(*host, ","), - Port: *port, - UnixSocket: *unixSocket, + Hosts: strings.Split(host, ","), + Port: port, + UnixSocket: unixSocket, Protocol: clientProto, - DB: *db, - Username: *user, + DB: db, + Username: user, Password: password, } - b := vtbench.NewBench(*threads, *count, connParams, *sql) + b := vtbench.NewBench(threads, count, connParams, sql) - ctx, cancel := context.WithTimeout(context.Background(), *deadline) + ctx, cancel := context.WithTimeout(context.Background(), deadline) defer cancel() fmt.Printf("Initializing test with %s protocol / %d threads / %d iterations\n",