Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 67 additions & 23 deletions cmd/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/k0sproject/k0s/pkg/config"
"github.com/k0sproject/k0s/pkg/constant"
"github.com/k0sproject/k0s/pkg/etcd"
"github.com/k0sproject/k0s/pkg/k0scontext"
kubeutil "github.com/k0sproject/k0s/pkg/kubernetes"

apierrors "k8s.io/apimachinery/pkg/api/errors"
Expand All @@ -51,15 +52,54 @@ Reads the runtime configuration from standard input.`,
Args: cobra.NoArgs,
PersistentPreRun: debugFlags.Run,
RunE: func(cmd *cobra.Command, _ []string) error {
var run func() error
ctx := cmd.Context()
log := k0scontext.ValueOrElse(ctx, func() logrus.FieldLogger {
return logrus.StandardLogger()
})

if runtimeConfig, err := loadRuntimeConfig(cmd.InOrStdin()); err != nil {
var server *http.Server

if runtimeConfig, err := loadRuntimeConfig(log, cmd.InOrStdin()); err != nil {
return err
} else if run, err = buildServer(runtimeConfig.Spec.K0sVars, runtimeConfig.Spec.NodeConfig); err != nil {
} else if server, err = buildServer(log, runtimeConfig.Spec.K0sVars, runtimeConfig.Spec.NodeConfig); err != nil {
return err
}

return run()
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", server.Addr)
if err != nil {
return err
}
defer server.Close()

log.Info("Listening on ", server.Addr, ", start serving")

doneServing := make(chan struct{})
go func() {
defer close(doneServing)
err = server.ServeTLS(listener, "", "")
}()

select {
case <-doneServing:
return fmt.Errorf("unexpected server error: %w", err)

case <-ctx.Done():
log.Info("Shutting down server: ", context.Cause(ctx))

ctx, cancel := context.WithTimeout(context.TODO(), 3*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
return fmt.Errorf("while shutting down server: %w", err)
}

<-doneServing
if !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("unexpected error after server shutdown: %w", err)
}

log.Info("Good bye")
return nil
}
},
}

Expand All @@ -76,8 +116,8 @@ Reads the runtime configuration from standard input.`,
return cmd
}

func loadRuntimeConfig(stdin io.Reader) (*config.RuntimeConfig, error) {
logrus.Info("Reading runtime configuration from standard input ...")
func loadRuntimeConfig(log logrus.FieldLogger, stdin io.Reader) (*config.RuntimeConfig, error) {
log.Info("Reading runtime configuration from standard input")
bytes, err := io.ReadAll(stdin)
if err != nil {
return nil, fmt.Errorf("failed to read from standard input: %w", err)
Expand All @@ -91,7 +131,7 @@ func loadRuntimeConfig(stdin io.Reader) (*config.RuntimeConfig, error) {
return runtimeConfig, nil
}

func buildServer(k0sVars *config.CfgVars, nodeConfig *v1beta1.ClusterConfig) (func() error, error) {
func buildServer(log logrus.FieldLogger, k0sVars *config.CfgVars, nodeConfig *v1beta1.ClusterConfig) (*http.Server, error) {
// Single kube client for whole lifetime of the API
client, err := kubeutil.NewClientFromFile(k0sVars.AdminKubeConfigPath)
if err != nil {
Expand All @@ -107,37 +147,41 @@ func buildServer(k0sVars *config.CfgVars, nodeConfig *v1beta1.ClusterConfig) (fu
// Only mount the etcd handler if we're running on internal etcd storage
// by default the mux will return 404 back which the caller should handle
mux.Handle(prefix+"/etcd/members", mw.AllowMethods(http.MethodPost)(
authMiddleware(etcdHandler(k0sVars.CertRootDir, k0sVars.EtcdCertDir), secrets, "controller-join")))
authMiddleware(etcdHandler(log, k0sVars.CertRootDir, k0sVars.EtcdCertDir), log, secrets, "controller-join")))
}

if storage.IsJoinable() {
mux.Handle(prefix+"/ca", mw.AllowMethods(http.MethodGet)(
authMiddleware(caHandler(k0sVars.CertRootDir), secrets, "controller-join")))
authMiddleware(caHandler(k0sVars.CertRootDir), log, secrets, "controller-join")))
}

ipAddr, bindAddressSpecified := nodeConfig.Spec.API.ExtraArgs["bind-address"]
if !bindAddressSpecified && nodeConfig.Spec.API.OnlyBindToAddress {
ipAddr = nodeConfig.Spec.API.Address
}

srv := &http.Server{
cert, err := tls.LoadX509KeyPair(
filepath.Join(k0sVars.CertRootDir, "k0s-api.crt"),
filepath.Join(k0sVars.CertRootDir, "k0s-api.key"),
)
if err != nil {
return nil, err
}

return &http.Server{
Handler: mux,
Addr: net.JoinHostPort(ipAddr, strconv.Itoa(nodeConfig.Spec.API.K0sAPIPort)),
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
CipherSuites: constant.AllowedTLS12CipherSuiteIDs,
},
WriteTimeout: 15 * time.Second,
ReadTimeout: 15 * time.Second,
}

cert := filepath.Join(k0sVars.CertRootDir, "k0s-api.crt")
key := filepath.Join(k0sVars.CertRootDir, "k0s-api.key")

return func() error { return srv.ListenAndServeTLS(cert, key) }, nil
}, nil
}

func etcdHandler(certRootDir, etcdCertDir string) http.Handler {
func etcdHandler(log logrus.FieldLogger, certRootDir, etcdCertDir string) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
ctx := req.Context()
var etcdReq v1beta1.EtcdRequest
Expand All @@ -146,7 +190,7 @@ func etcdHandler(certRootDir, etcdCertDir string) http.Handler {
sendError(err, resp)
return
}
logrus.Infof("etcd API, adding new member: %s", etcdReq.PeerAddress)
log.Infof("etcd API, adding new member: %s", etcdReq.PeerAddress)
err = etcdReq.Validate()
if err != nil {
sendError(err, resp)
Expand Down Expand Up @@ -238,7 +282,7 @@ func caHandler(certRootDir string) http.Handler {
// We need to validate:
// - that we find a secret with the ID
// - that the token matches whats inside the secret
func isValidToken(ctx context.Context, secrets clientcorev1.SecretInterface, rawTokenString, usage string) bool {
func isValidToken(ctx context.Context, log logrus.FieldLogger, secrets clientcorev1.SecretInterface, rawTokenString, usage string) bool {
tokenString, err := bootstraptokenv1.NewBootstrapTokenString(rawTokenString)
if err != nil {
return false
Expand All @@ -248,14 +292,14 @@ func isValidToken(ctx context.Context, secrets clientcorev1.SecretInterface, raw
secret, err := secrets.Get(ctx, secretName, metav1.GetOptions{})
if err != nil {
if !apierrors.IsNotFound(err) {
logrus.WithError(err).Error("Failed to get bootstrap token with ID ", tokenString.ID)
log.WithError(err).Error("Failed to get bootstrap token with ID ", tokenString.ID)
}
return false
}

token, err := bootstraptokenv1.BootstrapTokenFromSecret(secret)
if err != nil {
logrus.WithError(err).Errorf("Bootstrap token with ID %s is malformed", tokenString.ID)
log.WithError(err).Errorf("Bootstrap token with ID %s is malformed", tokenString.ID)
return false
}

Expand All @@ -277,12 +321,12 @@ func isValidToken(ctx context.Context, secrets clientcorev1.SecretInterface, raw
}
}

func authMiddleware(next http.Handler, secrets clientcorev1.SecretInterface, usage string) http.Handler {
func authMiddleware(next http.Handler, log logrus.FieldLogger, secrets clientcorev1.SecretInterface, usage string) http.Handler {
unauthorizedErr := errors.New("go away")

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
if ok && isValidToken(r.Context(), secrets, token, usage) {
if ok && isValidToken(r.Context(), log, secrets, token, usage) {
next.ServeHTTP(w, r)
} else {
sendError(unauthorizedErr, w, http.StatusUnauthorized)
Expand Down
166 changes: 166 additions & 0 deletions cmd/api/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// SPDX-FileCopyrightText: 2025 k0s authors
// SPDX-License-Identifier: Apache-2.0

package api_test

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"testing"
"testing/iotest"
"time"

"github.com/cloudflare/cfssl/csr"
"github.com/cloudflare/cfssl/initca"
"github.com/k0sproject/k0s/cmd"
"github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1"
"github.com/k0sproject/k0s/pkg/config"
"github.com/k0sproject/k0s/pkg/k0scontext"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/yaml"

"k8s.io/client-go/tools/clientcmd"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
)

func TestAPI(t *testing.T) {
t.Run("MissingRuntimeConfig", func(t *testing.T) {
underTest := cmd.NewRootCmd()
underTest.SetArgs([]string{"api"})
underTest.SetIn(iotest.ErrReader(io.EOF))
err := underTest.ExecuteContext(t.Context())
assert.ErrorContains(t, err, `failed to load runtime configuration: invalid runtime configuration: invalid api version: ""`)
})

dataDir := t.TempDir()
rtc := config.RuntimeConfig{
TypeMeta: v1.TypeMeta{APIVersion: v1beta1.ClusterConfigAPIVersion, Kind: config.RuntimeConfigKind},
Spec: &config.RuntimeConfigSpec{
NodeConfig: &v1beta1.ClusterConfig{Spec: &v1beta1.ClusterSpec{
API: &v1beta1.APISpec{
Address: "127.0.0.1",
OnlyBindToAddress: true,
},
Storage: &v1beta1.StorageSpec{},
}},
K0sVars: &config.CfgVars{
AdminKubeConfigPath: filepath.Join(dataDir, "kubeconfig"),
CertRootDir: dataDir,
},
},
}
// Find a free port. We cannot pass zero to the API since this will fallback to 9443.
if l, err := net.Listen("tcp", "127.0.0.1:0"); assert.NoError(t, err) {
// Extract the port number
addr := l.Addr().(*net.TCPAddr)
rtc.Spec.NodeConfig.Spec.API.K0sAPIPort = addr.Port
require.NoError(t, l.Close())
} else {
rtc.Spec.NodeConfig.Spec.API.K0sAPIPort = 9443
}

configData, err := yaml.Marshal(&rtc)
require.NoError(t, err)

t.Run("MissingKubeconfig", func(t *testing.T) {
underTest := cmd.NewRootCmd()
underTest.SetArgs([]string{"api"})
underTest.SetIn(bytes.NewReader(configData))
err := underTest.ExecuteContext(t.Context())
var pathErr *os.PathError
if assert.ErrorAs(t, err, &pathErr) {
assert.Equal(t, pathErr.Path, rtc.Spec.K0sVars.AdminKubeConfigPath)
assert.ErrorIs(t, pathErr.Err, os.ErrNotExist)
}
})

kubeconfig := clientcmdapi.Config{
Clusters: map[string]*clientcmdapi.Cluster{t.Name(): {Server: "blackhole.example.com"}},
Contexts: map[string]*clientcmdapi.Context{t.Name(): {Cluster: t.Name()}},
CurrentContext: t.Name(),
}
require.NoError(t, clientcmd.WriteToFile(kubeconfig, rtc.Spec.K0sVars.AdminKubeConfigPath))

t.Run("MissingCertificate", func(t *testing.T) {
underTest := cmd.NewRootCmd()
underTest.SetArgs([]string{"api"})
underTest.SetIn(bytes.NewReader(configData))
err := underTest.ExecuteContext(t.Context())
var pathErr *os.PathError
if assert.ErrorAs(t, err, &pathErr) {
assert.Equal(t, pathErr.Path, filepath.Join(rtc.Spec.K0sVars.CertRootDir, "k0s-api.crt"))
assert.ErrorIs(t, pathErr.Err, os.ErrNotExist)
}
})

certData, _, keyData, err := initca.New(&csr.CertificateRequest{
KeyRequest: csr.NewKeyRequest(),
CN: "blackhole.example.com",
})
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(rtc.Spec.K0sVars.CertRootDir, "k0s-api.crt"), certData, 0644))
require.NoError(t, os.WriteFile(filepath.Join(rtc.Spec.K0sVars.CertRootDir, "k0s-api.key"), keyData, 0600))

t.Run("StartsAndStops", func(t *testing.T) {
ctx, cancel := context.WithCancelCause(t.Context())
defer cancel(errors.New("test function exited"))

var logsConsumed uint
log, allLogs := test.NewNullLogger()
ctx = k0scontext.WithValue[logrus.FieldLogger](ctx, log)

underTest := cmd.NewRootCmd()
underTest.SetArgs([]string{"api"})
underTest.SetIn(bytes.NewReader(configData))

errCh := make(chan error, 1)
go func() { errCh <- underTest.ExecuteContext(ctx) }()

startup:
for {
select {
case err := <-errCh:
require.Failf(t, "API terminated unexpectedly", "%v", err)

case <-time.After(100 * time.Millisecond):
for _, entry := range allLogs.AllEntries()[logsConsumed:] {
t.Log(entry.Message)
logsConsumed++
if entry.Message == fmt.Sprintf(
"Listening on %s:%d, start serving",
rtc.Spec.NodeConfig.Spec.API.Address,
rtc.Spec.NodeConfig.Spec.API.K0sAPIPort,
) {
cancel(errors.New(t.Name() + " succeeded"))
break startup
}
}
}
}

assert.NoError(t, <-errCh, "API didn't terminate successfully")
var shutdownReasonFound bool
for _, entry := range allLogs.AllEntries()[logsConsumed:] {
t.Log(entry.Message)
if !shutdownReasonFound {
if reason, found := strings.CutPrefix(entry.Message, "Shutting down server: "); found {
shutdownReasonFound = true
assert.Equal(t, t.Name()+" succeeded", reason, "Unexpected shutdown reason")
}
}
}

assert.True(t, shutdownReasonFound, "No shutdown reason found in API logs")
})
}
Loading