diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index b0ee9cdd543..435915690af 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -76,40 +76,42 @@ var ExternalDisableUsers bool = false var ErrCouldNotLockDatabase = goerrors.NewKind("database \"%s\" is locked by another dolt process; either clone the database to run a second server, or stop the dolt process which currently holds an exclusive write lock on the database") +type Config struct { + ServerConfig servercfg.ServerConfig + DoltEnv *env.DoltEnv + SkipRootUserInit bool + Version string + Controller *svcs.Controller + ProtocolListenerFactory server.ProtocolListenerFunc +} + // Serve starts a MySQL-compatible server. Returns any errors that were encountered. func Serve( ctx context.Context, - version string, - serverConfig servercfg.ServerConfig, - controller *svcs.Controller, - dEnv *env.DoltEnv, - skipRootUserInitialization bool, + cfg *Config, ) (startError error, closeError error) { // Code is easier to work through if we assume that serverController is never nil - if controller == nil { - controller = svcs.NewController() + if cfg.Controller == nil { + cfg.Controller = svcs.NewController() } - ConfigureServices(serverConfig, controller, version, dEnv, skipRootUserInitialization) + ConfigureServices(cfg) - go controller.Start(ctx) - err := controller.WaitForStart() + go cfg.Controller.Start(ctx) + err := cfg.Controller.WaitForStart() if err != nil { return err, nil } - return nil, controller.WaitForStop() + return nil, cfg.Controller.WaitForStop() } func ConfigureServices( - serverConfig servercfg.ServerConfig, - controller *svcs.Controller, - version string, - dEnv *env.DoltEnv, - skipRootUserInitialization bool, + cfg *Config, ) { + controller := cfg.Controller ValidateConfigStep := &svcs.AnonService{ InitF: func(context.Context) error { - return servercfg.ValidateConfig(serverConfig) + return servercfg.ValidateConfig(cfg.ServerConfig) }, } controller.Register(ValidateConfigStep) @@ -118,18 +120,18 @@ func ConfigureServices( lgr.SetOutput(cli.CliErr) InitLogging := &svcs.AnonService{ InitF: func(context.Context) error { - level, err := logrus.ParseLevel(serverConfig.LogLevel().String()) + level, err := logrus.ParseLevel(cfg.ServerConfig.LogLevel().String()) if err != nil { return err } logrus.SetLevel(level) - switch strings.ToLower(string(serverConfig.LogFormat())) { + switch strings.ToLower(string(cfg.ServerConfig.LogFormat())) { case string(servercfg.LogFormat_JSON): logrus.SetFormatter(&logrus.JSONFormatter{}) case string(servercfg.LogFormat_Text): logrus.SetFormatter(&logrus.TextFormatter{}) default: - return fmt.Errorf("unknown log format: %s", serverConfig.LogFormat()) + return fmt.Errorf("unknown log format: %s", cfg.ServerConfig.LogFormat()) } sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{ @@ -164,12 +166,12 @@ func ConfigureServices( } controller.Register(InitLogging) - controller.Register(newHeartbeatService(version, dEnv)) + controller.Register(newHeartbeatService(cfg.Version, cfg.DoltEnv)) - fs := dEnv.FS + fs := cfg.DoltEnv.FS InitFailsafes := &svcs.AnonService{ InitF: func(ctx context.Context) (err error) { - dEnv.Config.SetFailsafes(env.DefaultFailsafeConfig) + cfg.DoltEnv.Config.SetFailsafes(env.DefaultFailsafeConfig) return nil }, } @@ -178,7 +180,7 @@ func ConfigureServices( var mrEnv *env.MultiRepoEnv InitMultiEnv := &svcs.AnonService{ InitF: func(ctx context.Context) (err error) { - mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv) + mrEnv, err = env.MultiEnvForDirectory(ctx, cfg.DoltEnv.Config.WriteableConfig(), fs, cfg.DoltEnv.Version, cfg.DoltEnv) return err }, } @@ -199,11 +201,11 @@ func ConfigureServices( var localCreds *LocalCreds InitServerLocalCreds := &svcs.AnonService{ InitF: func(context.Context) (err error) { - localCreds, err = persistServerLocalCreds(serverConfig.Port(), dEnv) + localCreds, err = persistServerLocalCreds(cfg.ServerConfig.Port(), cfg.DoltEnv) return err }, StopF: func() error { - RemoveLocalCreds(dEnv.FS) + RemoveLocalCreds(cfg.DoltEnv.FS) return nil }, } @@ -212,7 +214,7 @@ func ConfigureServices( var clusterController *cluster.Controller InitClusterController := &svcs.AnonService{ InitF: func(context.Context) (err error) { - clusterController, err = cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config()) + clusterController, err = cluster.NewController(lgr, cfg.ServerConfig.ClusterConfig(), mrEnv.Config()) return err }, } @@ -221,7 +223,7 @@ func ConfigureServices( var serverConf server.Config LoadServerConfig := &svcs.AnonService{ InitF: func(context.Context) (err error) { - serverConf, err = getConfigFromServerConfig(serverConfig) + serverConf, err = getConfigFromServerConfig(cfg.ServerConfig, cfg.ProtocolListenerFactory) return err }, } @@ -232,20 +234,20 @@ func ConfigureServices( InitSqlEngineConfig := &svcs.AnonService{ InitF: func(context.Context) error { config = &engine.SqlEngineConfig{ - IsReadOnly: serverConfig.ReadOnly(), - PrivFilePath: serverConfig.PrivilegeFilePath(), - BranchCtrlFilePath: serverConfig.BranchControlFilePath(), - DoltCfgDirPath: serverConfig.CfgDir(), - ServerUser: serverConfig.User(), - ServerPass: serverConfig.Password(), - ServerHost: serverConfig.Host(), - Autocommit: serverConfig.AutoCommit(), - DoltTransactionCommit: serverConfig.DoltTransactionCommit(), - JwksConfig: serverConfig.JwksConfig(), - SystemVariables: serverConfig.SystemVars(), + IsReadOnly: cfg.ServerConfig.ReadOnly(), + PrivFilePath: cfg.ServerConfig.PrivilegeFilePath(), + BranchCtrlFilePath: cfg.ServerConfig.BranchControlFilePath(), + DoltCfgDirPath: cfg.ServerConfig.CfgDir(), + ServerUser: cfg.ServerConfig.User(), + ServerPass: cfg.ServerConfig.Password(), + ServerHost: cfg.ServerConfig.Host(), + Autocommit: cfg.ServerConfig.AutoCommit(), + DoltTransactionCommit: cfg.ServerConfig.DoltTransactionCommit(), + JwksConfig: cfg.ServerConfig.JwksConfig(), + SystemVariables: cfg.ServerConfig.SystemVars(), ClusterController: clusterController, BinlogReplicaController: binlogreplication.DoltBinlogReplicaController, - SkipRootUserInitialization: skipRootUserInitialization, + SkipRootUserInitialization: cfg.SkipRootUserInit, } return nil }, @@ -255,7 +257,7 @@ func ConfigureServices( var esStatus eventscheduler.SchedulerStatus InitEventSchedulerStatus := &svcs.AnonService{ InitF: func(context.Context) (err error) { - esStatus, err = getEventSchedulerStatus(serverConfig.EventSchedulerStatus()) + esStatus, err = getEventSchedulerStatus(cfg.ServerConfig.EventSchedulerStatus()) if err != nil { return err } @@ -267,8 +269,8 @@ func ConfigureServices( InitAutoGCController := &svcs.AnonService{ InitF: func(context.Context) error { - if serverConfig.AutoGCBehavior() != nil && - serverConfig.AutoGCBehavior().Enable() { + if cfg.ServerConfig.AutoGCBehavior() != nil && + cfg.ServerConfig.AutoGCBehavior().Enable() { config.AutoGCController = sqle.NewAutoGCController(lgr) } return nil @@ -352,7 +354,7 @@ func ConfigureServices( // in the configuration files for a sql-server, and not global for the whole host. PersistNondeterministicSystemVarDefaults := &svcs.AnonService{ InitF: func(ctx context.Context) error { - err := dsess.PersistSystemVarDefaults(dEnv) + err := dsess.PersistSystemVarDefaults(cfg.DoltEnv) if err != nil { logrus.Errorf("unable to persist system variable defaults: %v", err) } @@ -404,7 +406,7 @@ func ConfigureServices( if logBin == 1 { logrus.Infof("Enabling binary logging for branch %s", logBinBranch) - binlogProducer, err := binlogreplication.NewBinlogProducer(dEnv.FS) + binlogProducer, err := binlogreplication.NewBinlogProducer(cfg.DoltEnv.FS) if err != nil { return err } @@ -441,7 +443,7 @@ func ConfigureServices( InitF: func(ctx context.Context) error { // If privileges.db has already been initialized, indicating that this is NOT the // first time sql-server has been launched, then don't initialize the root superuser. - if permissionDbExists, err := doesPrivilegesDbExist(dEnv, serverConfig.PrivilegeFilePath()); err != nil { + if permissionDbExists, err := doesPrivilegesDbExist(cfg.DoltEnv, cfg.ServerConfig.PrivilegeFilePath()); err != nil { return err } else if permissionDbExists { logrus.Debug("privileges.db already exists, not creating root superuser") @@ -480,7 +482,7 @@ func ConfigureServices( // for persisting the privileges database. The filesys API // is in the Dolt layer, so when the file path is passed to // GMS, it expects it to be a path on disk, and errors out. - if _, isInMemFs := dEnv.FS.(*filesys.InMemFS); isInMemFs { + if _, isInMemFs := cfg.DoltEnv.FS.(*filesys.InMemFS); isInMemFs { return nil } else { sqlCtx, err := sqlEngine.NewDefaultContext(context.Background()) @@ -496,8 +498,8 @@ func ConfigureServices( var metListener *metricsListener InitMetricsListener := &svcs.AnonService{ InitF: func(context.Context) (err error) { - labels := serverConfig.MetricsLabels() - metListener, err = newMetricsListener(labels, version, clusterController) + labels := cfg.ServerConfig.MetricsLabels() + metListener, err = newMetricsListener(labels, cfg.Version, clusterController) return err }, StopF: func() error { @@ -539,10 +541,10 @@ func ConfigureServices( RunMetricsServer := &svcs.AnonService{ InitF: func(context.Context) (err error) { - if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 { + if cfg.ServerConfig.MetricsHost() != "" && cfg.ServerConfig.MetricsPort() > 0 { metSrv.state.Swap(svcs.ServiceState_Init) - addr := fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort()) + addr := fmt.Sprintf("%s:%d", cfg.ServerConfig.MetricsHost(), cfg.ServerConfig.MetricsPort()) metSrv.lis, err = net.Listen("tcp", addr) if err != nil { return err @@ -583,16 +585,16 @@ func ConfigureServices( var remoteSrv RemoteSrvService RunRemoteSrv := &svcs.AnonService{ InitF: func(ctx context.Context) error { - if serverConfig.RemotesapiPort() == nil { + if cfg.ServerConfig.RemotesapiPort() == nil { return nil } remoteSrv.state.Swap(svcs.ServiceState_Init) - port := *serverConfig.RemotesapiPort() + port := *cfg.ServerConfig.RemotesapiPort() apiReadOnly := false - if serverConfig.RemotesapiReadOnly() != nil { - apiReadOnly = *serverConfig.RemotesapiReadOnly() + if cfg.ServerConfig.RemotesapiReadOnly() != nil { + apiReadOnly = *cfg.ServerConfig.RemotesapiReadOnly() } listenaddr := fmt.Sprintf(":%d", port) @@ -601,7 +603,7 @@ func ConfigureServices( } args := remotesrv.ServerArgs{ Logger: logrus.NewEntry(lgr), - ReadOnly: apiReadOnly || serverConfig.ReadOnly(), + ReadOnly: apiReadOnly || cfg.ServerConfig.ReadOnly(), HttpListenAddr: listenaddr, GrpcListenAddr: listenaddr, ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET, @@ -666,7 +668,7 @@ func ConfigureServices( } args.FS = sqlEngine.FileSystem() - clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig()) + clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(cfg.ServerConfig.ClusterConfig()) if err != nil { lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err) return err @@ -675,7 +677,7 @@ func ConfigureServices( clusterRemoteSrv.srv, err = remotesrv.NewServer(args) if err != nil { - lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) + lgr.Errorf("error creating remotesapi server on port %d: %v", *cfg.ServerConfig.RemotesapiPort(), err) return err } clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer()) @@ -715,13 +717,13 @@ func ConfigureServices( var sqlServerClosed bool InitSQLServer := &svcs.AnonService{ InitF: func(context.Context) (err error) { - v, ok := serverConfig.(servercfg.ValidatingServerConfig) + v, ok := cfg.ServerConfig.(servercfg.ValidatingServerConfig) if ok && v.GoldenMysqlConnectionString() != "" { mySQLServer, err = server.NewServerWithHandler( serverConf, sqlEngine.GetUnderlyingEngine(), sql.NewContext, - newSessionBuilder(sqlEngine, serverConfig), + newSessionBuilder(sqlEngine, cfg.ServerConfig), metListener, func(h mysql.Handler) (mysql.Handler, error) { return golden.NewValidatingHandler(h, v.GoldenMysqlConnectionString(), logrus.StandardLogger()) @@ -732,7 +734,7 @@ func ConfigureServices( serverConf, sqlEngine.GetUnderlyingEngine(), sql.NewContext, - newSessionBuilder(sqlEngine, serverConfig), + newSessionBuilder(sqlEngine, cfg.ServerConfig), metListener, ) } @@ -1065,7 +1067,7 @@ func newSessionBuilder(se *engine.SqlEngine, config servercfg.ServerConfig) serv } // getConfigFromServerConfig processes ServerConfig and returns server.Config for sql-server. -func getConfigFromServerConfig(serverConfig servercfg.ServerConfig) (server.Config, error) { +func getConfigFromServerConfig(serverConfig servercfg.ServerConfig, plf server.ProtocolListenerFunc) (server.Config, error) { serverConf, err := handleProtocolAndAddress(serverConfig) if err != nil { return server.Config{}, err @@ -1095,6 +1097,7 @@ func getConfigFromServerConfig(serverConfig servercfg.ServerConfig) (server.Conf serverConf.RequireSecureTransport = serverConfig.RequireSecureTransport() serverConf.MaxLoggedQueryLen = serverConfig.MaxLoggedQueryLen() serverConf.EncodeLoggedQuery = serverConfig.ShouldEncodeLoggedQuery() + serverConf.ProtocolListenerFactory = plf return serverConf, nil } diff --git a/go/cmd/dolt/commands/sqlserver/server_test.go b/go/cmd/dolt/commands/sqlserver/server_test.go index b6fe44f65d6..762d6acefa1 100644 --- a/go/cmd/dolt/commands/sqlserver/server_test.go +++ b/go/cmd/dolt/commands/sqlserver/server_test.go @@ -212,7 +212,12 @@ func TestServerGoodParams(t *testing.T) { t.Run(servercfg.ConfigInfo(test), func(t *testing.T) { sc := svcs.NewController() go func(config servercfg.ServerConfig, sc *svcs.Controller) { - _, _ = Serve(context.Background(), "0.0.0", config, sc, env, false) + _, _ = Serve(context.Background(), &Config{ + Version: "0.0.0", + ServerConfig: config, + Controller: sc, + DoltEnv: env, + }) }(test, sc) err := sc.WaitForStart() require.NoError(t, err) @@ -240,7 +245,12 @@ func TestServerSelect(t *testing.T) { sc := svcs.NewController() defer sc.Stop() go func() { - _, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, env, false) + _, _ = Serve(context.Background(), &Config{ + Version: "0.0.0", + ServerConfig: serverConfig, + Controller: sc, + DoltEnv: env, + }) }() err = sc.WaitForStart() require.NoError(t, err) @@ -339,7 +349,12 @@ func TestServerSetDefaultBranch(t *testing.T) { sc := svcs.NewController() defer sc.Stop() go func() { - _, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv, false) + _, _ = Serve(context.Background(), &Config{ + Version: "0.0.0", + ServerConfig: serverConfig, + Controller: sc, + DoltEnv: dEnv, + }) }() err = sc.WaitForStart() require.NoError(t, err) @@ -503,7 +518,12 @@ func TestReadReplica(t *testing.T) { os.Chdir(multiSetup.DbPaths[readReplicaDbName]) go func() { - err, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, multiSetup.GetEnv(readReplicaDbName), false) + err, _ = Serve(context.Background(), &Config{ + Version: "0.0.0", + ServerConfig: serverConfig, + Controller: sc, + DoltEnv: multiSetup.GetEnv(readReplicaDbName), + }) require.NoError(t, err) }() require.NoError(t, sc.WaitForStart()) diff --git a/go/cmd/dolt/commands/sqlserver/sqlserver.go b/go/cmd/dolt/commands/sqlserver/sqlserver.go index 2dff0326cfb..c09a97393ab 100644 --- a/go/cmd/dolt/commands/sqlserver/sqlserver.go +++ b/go/cmd/dolt/commands/sqlserver/sqlserver.go @@ -267,7 +267,13 @@ func StartServer(ctx context.Context, versionStr, commandStr string, args []stri cli.Printf("Starting server with Config %v\n", servercfg.ConfigInfo(serverConfig)) skipRootUserInitialization := apr.Contains(skipRootUserInitialization) - startError, closeError := Serve(ctx, versionStr, serverConfig, controller, dEnv, skipRootUserInitialization) + startError, closeError := Serve(ctx, &Config{ + Version: versionStr, + ServerConfig: serverConfig, + Controller: controller, + DoltEnv: dEnv, + SkipRootUserInit: skipRootUserInitialization, + }) if startError != nil { return startError } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_server_tests.go b/go/libraries/doltcore/sqle/enginetest/dolt_server_tests.go index 497737b89f8..ab1be919651 100755 --- a/go/libraries/doltcore/sqle/enginetest/dolt_server_tests.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_server_tests.go @@ -551,7 +551,12 @@ func makeDestinationSlice(t *testing.T, columnTypes []*gosql.ColumnType) []inter func startServerOnEnv(t *testing.T, serverConfig servercfg.ServerConfig, dEnv *env.DoltEnv) (*svcs.Controller, servercfg.ServerConfig) { sc := svcs.NewController() go func() { - _, _ = sqlserver.Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv, false) + _, _ = sqlserver.Serve(context.Background(), &sqlserver.Config{ + Version: "0.0.0", + ServerConfig: serverConfig, + Controller: sc, + DoltEnv: dEnv, + }) }() err := sc.WaitForStart() require.NoError(t, err) diff --git a/go/performance/replicationbench/replica_test.go b/go/performance/replicationbench/replica_test.go index ece31b30f16..2365956b647 100644 --- a/go/performance/replicationbench/replica_test.go +++ b/go/performance/replicationbench/replica_test.go @@ -171,7 +171,12 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, //b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg)) eg.Go(func() (err error) { - startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv, false) + startErr, closeErr := srv.Serve(ctx, &srv.Config{ + Version: "", + ServerConfig: cfg, + Controller: sc, + DoltEnv: dEnv, + }) if startErr != nil { return startErr } diff --git a/go/performance/serverbench/bench_test.go b/go/performance/serverbench/bench_test.go index ade78d30927..5747d76265a 100644 --- a/go/performance/serverbench/bench_test.go +++ b/go/performance/serverbench/bench_test.go @@ -183,7 +183,12 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, //b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg)) eg.Go(func() (err error) { - startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv, false) + startErr, closeErr := srv.Serve(ctx, &srv.Config{ + Version: "", + ServerConfig: cfg, + Controller: sc, + DoltEnv: dEnv, + }) if startErr != nil { return startErr }