Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 65 additions & 62 deletions go/cmd/dolt/commands/sqlserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,40 +76,42 @@ var ExternalDisableUsers bool = false

var ErrCouldNotLockDatabase = goerrors.NewKind("database \"%s\" is locked by another dolt process; either clone the database to run a second server, or stop the dolt process which currently holds an exclusive write lock on the database")

type Config struct {
ServerConfig servercfg.ServerConfig
DoltEnv *env.DoltEnv
SkipRootUserInit bool
Version string
Controller *svcs.Controller
ProtocolListenerFactory server.ProtocolListenerFunc
}

// Serve starts a MySQL-compatible server. Returns any errors that were encountered.
func Serve(
ctx context.Context,
version string,
serverConfig servercfg.ServerConfig,
controller *svcs.Controller,
dEnv *env.DoltEnv,
skipRootUserInitialization bool,
cfg *Config,
) (startError error, closeError error) {
// Code is easier to work through if we assume that serverController is never nil
if controller == nil {
controller = svcs.NewController()
if cfg.Controller == nil {
cfg.Controller = svcs.NewController()
}

ConfigureServices(serverConfig, controller, version, dEnv, skipRootUserInitialization)
ConfigureServices(cfg)

go controller.Start(ctx)
err := controller.WaitForStart()
go cfg.Controller.Start(ctx)
err := cfg.Controller.WaitForStart()
if err != nil {
return err, nil
}
return nil, controller.WaitForStop()
return nil, cfg.Controller.WaitForStop()
}

func ConfigureServices(
serverConfig servercfg.ServerConfig,
controller *svcs.Controller,
version string,
dEnv *env.DoltEnv,
skipRootUserInitialization bool,
cfg *Config,
) {
controller := cfg.Controller
ValidateConfigStep := &svcs.AnonService{
InitF: func(context.Context) error {
return servercfg.ValidateConfig(serverConfig)
return servercfg.ValidateConfig(cfg.ServerConfig)
},
}
controller.Register(ValidateConfigStep)
Expand All @@ -118,18 +120,18 @@ func ConfigureServices(
lgr.SetOutput(cli.CliErr)
InitLogging := &svcs.AnonService{
InitF: func(context.Context) error {
level, err := logrus.ParseLevel(serverConfig.LogLevel().String())
level, err := logrus.ParseLevel(cfg.ServerConfig.LogLevel().String())
if err != nil {
return err
}
logrus.SetLevel(level)
switch strings.ToLower(string(serverConfig.LogFormat())) {
switch strings.ToLower(string(cfg.ServerConfig.LogFormat())) {
case string(servercfg.LogFormat_JSON):
logrus.SetFormatter(&logrus.JSONFormatter{})
case string(servercfg.LogFormat_Text):
logrus.SetFormatter(&logrus.TextFormatter{})
default:
return fmt.Errorf("unknown log format: %s", serverConfig.LogFormat())
return fmt.Errorf("unknown log format: %s", cfg.ServerConfig.LogFormat())
}

sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{
Expand Down Expand Up @@ -164,12 +166,12 @@ func ConfigureServices(
}
controller.Register(InitLogging)

controller.Register(newHeartbeatService(version, dEnv))
controller.Register(newHeartbeatService(cfg.Version, cfg.DoltEnv))

fs := dEnv.FS
fs := cfg.DoltEnv.FS
InitFailsafes := &svcs.AnonService{
InitF: func(ctx context.Context) (err error) {
dEnv.Config.SetFailsafes(env.DefaultFailsafeConfig)
cfg.DoltEnv.Config.SetFailsafes(env.DefaultFailsafeConfig)
return nil
},
}
Expand All @@ -178,7 +180,7 @@ func ConfigureServices(
var mrEnv *env.MultiRepoEnv
InitMultiEnv := &svcs.AnonService{
InitF: func(ctx context.Context) (err error) {
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv)
mrEnv, err = env.MultiEnvForDirectory(ctx, cfg.DoltEnv.Config.WriteableConfig(), fs, cfg.DoltEnv.Version, cfg.DoltEnv)
return err
},
}
Expand All @@ -199,11 +201,11 @@ func ConfigureServices(
var localCreds *LocalCreds
InitServerLocalCreds := &svcs.AnonService{
InitF: func(context.Context) (err error) {
localCreds, err = persistServerLocalCreds(serverConfig.Port(), dEnv)
localCreds, err = persistServerLocalCreds(cfg.ServerConfig.Port(), cfg.DoltEnv)
return err
},
StopF: func() error {
RemoveLocalCreds(dEnv.FS)
RemoveLocalCreds(cfg.DoltEnv.FS)
return nil
},
}
Expand All @@ -212,7 +214,7 @@ func ConfigureServices(
var clusterController *cluster.Controller
InitClusterController := &svcs.AnonService{
InitF: func(context.Context) (err error) {
clusterController, err = cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config())
clusterController, err = cluster.NewController(lgr, cfg.ServerConfig.ClusterConfig(), mrEnv.Config())
return err
},
}
Expand All @@ -221,7 +223,7 @@ func ConfigureServices(
var serverConf server.Config
LoadServerConfig := &svcs.AnonService{
InitF: func(context.Context) (err error) {
serverConf, err = getConfigFromServerConfig(serverConfig)
serverConf, err = getConfigFromServerConfig(cfg.ServerConfig, cfg.ProtocolListenerFactory)
return err
},
}
Expand All @@ -232,20 +234,20 @@ func ConfigureServices(
InitSqlEngineConfig := &svcs.AnonService{
InitF: func(context.Context) error {
config = &engine.SqlEngineConfig{
IsReadOnly: serverConfig.ReadOnly(),
PrivFilePath: serverConfig.PrivilegeFilePath(),
BranchCtrlFilePath: serverConfig.BranchControlFilePath(),
DoltCfgDirPath: serverConfig.CfgDir(),
ServerUser: serverConfig.User(),
ServerPass: serverConfig.Password(),
ServerHost: serverConfig.Host(),
Autocommit: serverConfig.AutoCommit(),
DoltTransactionCommit: serverConfig.DoltTransactionCommit(),
JwksConfig: serverConfig.JwksConfig(),
SystemVariables: serverConfig.SystemVars(),
IsReadOnly: cfg.ServerConfig.ReadOnly(),
PrivFilePath: cfg.ServerConfig.PrivilegeFilePath(),
BranchCtrlFilePath: cfg.ServerConfig.BranchControlFilePath(),
DoltCfgDirPath: cfg.ServerConfig.CfgDir(),
ServerUser: cfg.ServerConfig.User(),
ServerPass: cfg.ServerConfig.Password(),
ServerHost: cfg.ServerConfig.Host(),
Autocommit: cfg.ServerConfig.AutoCommit(),
DoltTransactionCommit: cfg.ServerConfig.DoltTransactionCommit(),
JwksConfig: cfg.ServerConfig.JwksConfig(),
SystemVariables: cfg.ServerConfig.SystemVars(),
ClusterController: clusterController,
BinlogReplicaController: binlogreplication.DoltBinlogReplicaController,
SkipRootUserInitialization: skipRootUserInitialization,
SkipRootUserInitialization: cfg.SkipRootUserInit,
}
return nil
},
Expand All @@ -255,7 +257,7 @@ func ConfigureServices(
var esStatus eventscheduler.SchedulerStatus
InitEventSchedulerStatus := &svcs.AnonService{
InitF: func(context.Context) (err error) {
esStatus, err = getEventSchedulerStatus(serverConfig.EventSchedulerStatus())
esStatus, err = getEventSchedulerStatus(cfg.ServerConfig.EventSchedulerStatus())
if err != nil {
return err
}
Expand All @@ -267,8 +269,8 @@ func ConfigureServices(

InitAutoGCController := &svcs.AnonService{
InitF: func(context.Context) error {
if serverConfig.AutoGCBehavior() != nil &&
serverConfig.AutoGCBehavior().Enable() {
if cfg.ServerConfig.AutoGCBehavior() != nil &&
cfg.ServerConfig.AutoGCBehavior().Enable() {
config.AutoGCController = sqle.NewAutoGCController(lgr)
}
return nil
Expand Down Expand Up @@ -352,7 +354,7 @@ func ConfigureServices(
// in the configuration files for a sql-server, and not global for the whole host.
PersistNondeterministicSystemVarDefaults := &svcs.AnonService{
InitF: func(ctx context.Context) error {
err := dsess.PersistSystemVarDefaults(dEnv)
err := dsess.PersistSystemVarDefaults(cfg.DoltEnv)
if err != nil {
logrus.Errorf("unable to persist system variable defaults: %v", err)
}
Expand Down Expand Up @@ -404,7 +406,7 @@ func ConfigureServices(

if logBin == 1 {
logrus.Infof("Enabling binary logging for branch %s", logBinBranch)
binlogProducer, err := binlogreplication.NewBinlogProducer(dEnv.FS)
binlogProducer, err := binlogreplication.NewBinlogProducer(cfg.DoltEnv.FS)
if err != nil {
return err
}
Expand Down Expand Up @@ -441,7 +443,7 @@ func ConfigureServices(
InitF: func(ctx context.Context) error {
// If privileges.db has already been initialized, indicating that this is NOT the
// first time sql-server has been launched, then don't initialize the root superuser.
if permissionDbExists, err := doesPrivilegesDbExist(dEnv, serverConfig.PrivilegeFilePath()); err != nil {
if permissionDbExists, err := doesPrivilegesDbExist(cfg.DoltEnv, cfg.ServerConfig.PrivilegeFilePath()); err != nil {
return err
} else if permissionDbExists {
logrus.Debug("privileges.db already exists, not creating root superuser")
Expand Down Expand Up @@ -480,7 +482,7 @@ func ConfigureServices(
// for persisting the privileges database. The filesys API
// is in the Dolt layer, so when the file path is passed to
// GMS, it expects it to be a path on disk, and errors out.
if _, isInMemFs := dEnv.FS.(*filesys.InMemFS); isInMemFs {
if _, isInMemFs := cfg.DoltEnv.FS.(*filesys.InMemFS); isInMemFs {
return nil
} else {
sqlCtx, err := sqlEngine.NewDefaultContext(context.Background())
Expand All @@ -496,8 +498,8 @@ func ConfigureServices(
var metListener *metricsListener
InitMetricsListener := &svcs.AnonService{
InitF: func(context.Context) (err error) {
labels := serverConfig.MetricsLabels()
metListener, err = newMetricsListener(labels, version, clusterController)
labels := cfg.ServerConfig.MetricsLabels()
metListener, err = newMetricsListener(labels, cfg.Version, clusterController)
return err
},
StopF: func() error {
Expand Down Expand Up @@ -539,10 +541,10 @@ func ConfigureServices(

RunMetricsServer := &svcs.AnonService{
InitF: func(context.Context) (err error) {
if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 {
if cfg.ServerConfig.MetricsHost() != "" && cfg.ServerConfig.MetricsPort() > 0 {
metSrv.state.Swap(svcs.ServiceState_Init)

addr := fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort())
addr := fmt.Sprintf("%s:%d", cfg.ServerConfig.MetricsHost(), cfg.ServerConfig.MetricsPort())
metSrv.lis, err = net.Listen("tcp", addr)
if err != nil {
return err
Expand Down Expand Up @@ -583,16 +585,16 @@ func ConfigureServices(
var remoteSrv RemoteSrvService
RunRemoteSrv := &svcs.AnonService{
InitF: func(ctx context.Context) error {
if serverConfig.RemotesapiPort() == nil {
if cfg.ServerConfig.RemotesapiPort() == nil {
return nil
}
remoteSrv.state.Swap(svcs.ServiceState_Init)

port := *serverConfig.RemotesapiPort()
port := *cfg.ServerConfig.RemotesapiPort()

apiReadOnly := false
if serverConfig.RemotesapiReadOnly() != nil {
apiReadOnly = *serverConfig.RemotesapiReadOnly()
if cfg.ServerConfig.RemotesapiReadOnly() != nil {
apiReadOnly = *cfg.ServerConfig.RemotesapiReadOnly()
}

listenaddr := fmt.Sprintf(":%d", port)
Expand All @@ -601,7 +603,7 @@ func ConfigureServices(
}
args := remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
ReadOnly: apiReadOnly || serverConfig.ReadOnly(),
ReadOnly: apiReadOnly || cfg.ServerConfig.ReadOnly(),
HttpListenAddr: listenaddr,
GrpcListenAddr: listenaddr,
ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET,
Expand Down Expand Up @@ -666,7 +668,7 @@ func ConfigureServices(
}
args.FS = sqlEngine.FileSystem()

clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig())
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(cfg.ServerConfig.ClusterConfig())
if err != nil {
lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err)
return err
Expand All @@ -675,7 +677,7 @@ func ConfigureServices(

clusterRemoteSrv.srv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
lgr.Errorf("error creating remotesapi server on port %d: %v", *cfg.ServerConfig.RemotesapiPort(), err)
return err
}
clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer())
Expand Down Expand Up @@ -715,13 +717,13 @@ func ConfigureServices(
var sqlServerClosed bool
InitSQLServer := &svcs.AnonService{
InitF: func(context.Context) (err error) {
v, ok := serverConfig.(servercfg.ValidatingServerConfig)
v, ok := cfg.ServerConfig.(servercfg.ValidatingServerConfig)
if ok && v.GoldenMysqlConnectionString() != "" {
mySQLServer, err = server.NewServerWithHandler(
serverConf,
sqlEngine.GetUnderlyingEngine(),
sql.NewContext,
newSessionBuilder(sqlEngine, serverConfig),
newSessionBuilder(sqlEngine, cfg.ServerConfig),
metListener,
func(h mysql.Handler) (mysql.Handler, error) {
return golden.NewValidatingHandler(h, v.GoldenMysqlConnectionString(), logrus.StandardLogger())
Expand All @@ -732,7 +734,7 @@ func ConfigureServices(
serverConf,
sqlEngine.GetUnderlyingEngine(),
sql.NewContext,
newSessionBuilder(sqlEngine, serverConfig),
newSessionBuilder(sqlEngine, cfg.ServerConfig),
metListener,
)
}
Expand Down Expand Up @@ -1065,7 +1067,7 @@ func newSessionBuilder(se *engine.SqlEngine, config servercfg.ServerConfig) serv
}

// getConfigFromServerConfig processes ServerConfig and returns server.Config for sql-server.
func getConfigFromServerConfig(serverConfig servercfg.ServerConfig) (server.Config, error) {
func getConfigFromServerConfig(serverConfig servercfg.ServerConfig, plf server.ProtocolListenerFunc) (server.Config, error) {
serverConf, err := handleProtocolAndAddress(serverConfig)
if err != nil {
return server.Config{}, err
Expand Down Expand Up @@ -1095,6 +1097,7 @@ func getConfigFromServerConfig(serverConfig servercfg.ServerConfig) (server.Conf
serverConf.RequireSecureTransport = serverConfig.RequireSecureTransport()
serverConf.MaxLoggedQueryLen = serverConfig.MaxLoggedQueryLen()
serverConf.EncodeLoggedQuery = serverConfig.ShouldEncodeLoggedQuery()
serverConf.ProtocolListenerFactory = plf

return serverConf, nil
}
Expand Down
28 changes: 24 additions & 4 deletions go/cmd/dolt/commands/sqlserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,12 @@ func TestServerGoodParams(t *testing.T) {
t.Run(servercfg.ConfigInfo(test), func(t *testing.T) {
sc := svcs.NewController()
go func(config servercfg.ServerConfig, sc *svcs.Controller) {
_, _ = Serve(context.Background(), "0.0.0", config, sc, env, false)
_, _ = Serve(context.Background(), &Config{
Version: "0.0.0",
ServerConfig: config,
Controller: sc,
DoltEnv: env,
})
}(test, sc)
err := sc.WaitForStart()
require.NoError(t, err)
Expand Down Expand Up @@ -240,7 +245,12 @@ func TestServerSelect(t *testing.T) {
sc := svcs.NewController()
defer sc.Stop()
go func() {
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, env, false)
_, _ = Serve(context.Background(), &Config{
Version: "0.0.0",
ServerConfig: serverConfig,
Controller: sc,
DoltEnv: env,
})
}()
err = sc.WaitForStart()
require.NoError(t, err)
Expand Down Expand Up @@ -339,7 +349,12 @@ func TestServerSetDefaultBranch(t *testing.T) {
sc := svcs.NewController()
defer sc.Stop()
go func() {
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv, false)
_, _ = Serve(context.Background(), &Config{
Version: "0.0.0",
ServerConfig: serverConfig,
Controller: sc,
DoltEnv: dEnv,
})
}()
err = sc.WaitForStart()
require.NoError(t, err)
Expand Down Expand Up @@ -503,7 +518,12 @@ func TestReadReplica(t *testing.T) {

os.Chdir(multiSetup.DbPaths[readReplicaDbName])
go func() {
err, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, multiSetup.GetEnv(readReplicaDbName), false)
err, _ = Serve(context.Background(), &Config{
Version: "0.0.0",
ServerConfig: serverConfig,
Controller: sc,
DoltEnv: multiSetup.GetEnv(readReplicaDbName),
})
require.NoError(t, err)
}()
require.NoError(t, sc.WaitForStart())
Expand Down
Loading
Loading