diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index ff10ec4ac28..57984f3871f 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -230,6 +230,17 @@ func ConfigureServices( } controller.Register(LoadServerConfig) + // Ensure @@port reflects the configured server port instead of the MySQL default (3306) + SetPortSystemVariable := &svcs.AnonService{ + InitF: func(context.Context) error { + if err := sql.SystemVariables.AssignValues(map[string]interface{}{"port": cfg.ServerConfig.Port()}); err != nil { + logrus.Warnf("unable to set @@port system variable: %v", err) + } + return nil + }, + } + controller.Register(SetPortSystemVariable) + // Create SQL Engine with users var config *engine.SqlEngineConfig InitSqlEngineConfig := &svcs.AnonService{ diff --git a/go/cmd/dolt/commands/sqlserver/server_test.go b/go/cmd/dolt/commands/sqlserver/server_test.go index 6aaeb0bc014..8b40b71f2e2 100644 --- a/go/cmd/dolt/commands/sqlserver/server_test.go +++ b/go/cmd/dolt/commands/sqlserver/server_test.go @@ -16,6 +16,7 @@ package sqlserver import ( "fmt" + "net" "net/http" "os" "path/filepath" @@ -514,6 +515,66 @@ listener: assert.Equal(t, 1, readOnlyValue[0].Value) } +func TestPortSystemVariable(t *testing.T) { + ctx := context.Background() + + dEnv, err := sqle.CreateEnvWithSeedData() + require.NoError(t, err) + defer func() { + assert.NoError(t, dEnv.DoltDB(ctx).Close()) + }() + + // Pick an ephemeral free port for this test + listenPort, err := findEmptyPort() + require.NoError(t, err) + serverConfig := DefaultCommandLineServerConfig().WithPort(listenPort) + sc := svcs.NewController() + defer sc.Stop() + go func() { + _, _ = Serve(context.Background(), &Config{ + Version: "0.0.0", + ServerConfig: serverConfig, + Controller: sc, + DoltEnv: dEnv, + }) + }() + err = sc.WaitForStart() + require.NoError(t, err) + + conn, err := dbr.Open("mysql", servercfg.ConnectionString(serverConfig, "dolt"), nil) + require.NoError(t, err) + defer conn.Close() + sess := conn.NewSession(nil) + + // Verify @@port + var portVal []struct { + Value int `db:"@@port"` + } + _, err = sess.SelectBySql("SELECT @@port").LoadContext(ctx, &portVal) + require.NoError(t, err) + require.Len(t, portVal, 1) + assert.Equal(t, listenPort, portVal[0].Value) + + // Verify @@global.port + var globalPortVal []struct { + Value int `db:"@@global.port"` + } + _, err = sess.SelectBySql("SELECT @@global.port").LoadContext(ctx, &globalPortVal) + require.NoError(t, err) + require.Len(t, globalPortVal, 1) + assert.Equal(t, listenPort, globalPortVal[0].Value) +} + +// findEmptyPort finds an available TCP port by asking the OS for an ephemeral port +func findEmptyPort() (int, error) { + l, err := net.Listen("tcp", ":0") + if err != nil { + return -1, err + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil +} + func TestReadOnlyEnforcement(t *testing.T) { ctx := context.Background() diff --git a/integration-tests/bats/sql-server.bats b/integration-tests/bats/sql-server.bats index c83805868c8..a3cb45afe44 100644 --- a/integration-tests/bats/sql-server.bats +++ b/integration-tests/bats/sql-server.bats @@ -340,6 +340,23 @@ EOF [[ "$output" =~ "read only mode" ]] || false } +@test "sql-server: @@port reflects configured listener port" { + skiponwindows "Missing dependencies" + + cd repo1 + + start_sql_server + + # Verify that @@port matches the dynamically selected $PORT + run dolt sql -q "SELECT @@port;" + [ "$status" -eq 0 ] + [[ "$output" =~ " $PORT " ]] || false + + run dolt sql -q "SELECT @@global.port;" + [ "$status" -eq 0 ] + [[ "$output" =~ " $PORT " ]] || false +} + @test "sql-server: inspect sql-server using CLI" { skiponwindows "Missing dependencies"