diff --git a/lib/srv/db/sqlserver/auth.go b/lib/srv/db/sqlserver/auth.go index 661a4a77cdabe..8f1e9ed4953f1 100644 --- a/lib/srv/db/sqlserver/auth.go +++ b/lib/srv/db/sqlserver/auth.go @@ -69,7 +69,7 @@ func (c *connector) keytabClient(session *common.Session) (*client.Client, error } // kinitClient returns a kerberos client using a kinit ccache -func (c *connector) kinitClient(ctx context.Context, session *common.Session, auth windows.AuthInterface, dataDir string) (*client.Client, error) { +func (c *connector) kinitClient(ctx context.Context, session *common.Session, auth windows.AuthInterface) (*client.Client, error) { ldapPem, _ := pem.Decode([]byte(session.Database.GetAD().LDAPCert)) if ldapPem == nil { @@ -98,7 +98,6 @@ func (c *connector) kinitClient(ctx context.Context, session *common.Session, au Realm: realmName, KDCHost: session.Database.GetAD().KDCHostName, AdminServer: session.Database.GetAD().Domain, - DataDir: dataDir, LDAPCA: cert, LDAPCAPEM: session.Database.GetAD().LDAPCert, Command: c.kinitCommandGenerator, diff --git a/lib/srv/db/sqlserver/connect.go b/lib/srv/db/sqlserver/connect.go index e1ec06a31862f..38e0e5b9702f6 100644 --- a/lib/srv/db/sqlserver/connect.go +++ b/lib/srv/db/sqlserver/connect.go @@ -58,8 +58,6 @@ type connector struct { DBAuth common.Auth // AuthClient is the teleport client AuthClient windows.AuthInterface - // DataDir is the Teleport data directory - DataDir string kinitCommandGenerator kinit.CommandGenerator } @@ -75,7 +73,7 @@ func (c *connector) getKerberosClient(ctx context.Context, sessionCtx *common.Se } return kt, nil case sessionCtx.Database.GetAD().KDCHostName != "" && sessionCtx.Database.GetAD().LDAPCert != "": - kt, err := c.kinitClient(ctx, sessionCtx, c.AuthClient, c.DataDir) + kt, err := c.kinitClient(ctx, sessionCtx, c.AuthClient) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/sqlserver/connect_test.go b/lib/srv/db/sqlserver/connect_test.go index 5453c382f04b1..11ef58966d6fe 100644 --- a/lib/srv/db/sqlserver/connect_test.go +++ b/lib/srv/db/sqlserver/connect_test.go @@ -154,8 +154,29 @@ type staticCache struct { pass bool } +func getCachePath(t *testing.T, args ...string) string { + if len(args) != 8 { + t.Fatalf("Unexpected args (%v): %v", len(args), args) + } + // example arguments: + // [-X X509_anchors=FILE:/tmp/kinit3779395068/userca.pem -X X509_user_identity=FILE:/tmp/kinit3779395068/cert.pem,/tmp/kinit3779395068/key.pem -c /tmp/kinit3779395068/krb5.cache -- alice] + if args[0] != "-X" { + t.Fatalf("Unexpected args (%v): %v", args[0], args) + } + if args[2] != "-X" { + t.Fatalf("Unexpected args (%v): %v", args[2], args) + } + if args[4] != "-c" { + t.Fatalf("Unexpected args (%v): %v", args[4], args) + } + if args[6] != "--" { + t.Fatalf("Unexpected args (%v): %v", args[6], args) + } + return args[5] +} + func (s *staticCache) CommandContext(ctx context.Context, name string, args ...string) *exec.Cmd { - cachePath := args[len(args)-1] + cachePath := getCachePath(s.t, args...) require.NotEmpty(s.t, cachePath) err := os.WriteFile(cachePath, cacheData, 0664) require.NoError(s.t, err) @@ -315,7 +336,6 @@ func TestConnectorKInitClient(t *testing.T) { databaseUser := "alice" databaseName := database.GetName() - connector.DataDir = dir connectorCtx, cancel := context.WithCancel(ctx) // we want to pass the canceled context diff --git a/lib/srv/db/sqlserver/engine.go b/lib/srv/db/sqlserver/engine.go index 24edbbd9b2bef..598029c8afad4 100644 --- a/lib/srv/db/sqlserver/engine.go +++ b/lib/srv/db/sqlserver/engine.go @@ -41,7 +41,6 @@ func NewEngine(ec common.EngineConfig) common.Engine { Connector: &connector{ DBAuth: ec.Auth, AuthClient: ec.AuthClient, - DataDir: ec.DataDir, }, } } diff --git a/lib/srv/db/sqlserver/kinit/kinit.go b/lib/srv/db/sqlserver/kinit/kinit.go index 2aee2adecf3ee..98953e0178818 100644 --- a/lib/srv/db/sqlserver/kinit/kinit.go +++ b/lib/srv/db/sqlserver/kinit/kinit.go @@ -27,6 +27,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "text/template" "time" @@ -94,8 +95,6 @@ type CommandConfig struct { KDCHost string // AdminServer is the administration server hostname (usually AD server) AdminServer string - // DataDir is the Teleport Data Directory - DataDir string // LDAPCA is the Windows LDAP Certificate for client signing LDAPCA *x509.Certificate // LDAPCAPEM contains the same certificate as LDAPCA but in PEM format. It @@ -114,13 +113,9 @@ func NewCommandLineInitializer(config CommandConfig) *CommandLineInitializer { cmd := &CommandLineInitializer{ auth: config.AuthClient, userName: config.User, - cacheName: fmt.Sprintf("%s@%s", config.User, config.Realm), RealmName: config.Realm, KDCHostName: config.KDCHost, AdminServerName: config.AdminServer, - dataDir: config.DataDir, - certPath: fmt.Sprintf("%s.pem", config.User), - keyPath: fmt.Sprintf("%s-key.pem", config.User), binary: kinitBinary, command: config.Command, certGetter: config.CertGetter, @@ -160,12 +155,7 @@ type CommandLineInitializer struct { // AdminServerName is the admin server Name (usually AD host) AdminServerName string - dataDir string - userName string - cacheName string - - certPath string - keyPath string + userName string binary string command CommandGenerator @@ -242,18 +232,10 @@ func (k *CommandLineInitializer) UseOrCreateCredentials(ctx context.Context) (*c } }() - certPath := filepath.Join(tmp, fmt.Sprintf("%s.pem", k.userName)) - keyPath := filepath.Join(tmp, fmt.Sprintf("%s-key.pem", k.userName)) + certPath := filepath.Join(tmp, "cert.pem") + keyPath := filepath.Join(tmp, "key.pem") userCAPath := filepath.Join(tmp, "userca.pem") - - cacheDir := filepath.Join(k.dataDir, "krb5_cache") - - err = os.MkdirAll(cacheDir, os.ModePerm) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - cachePath := filepath.Join(cacheDir, k.cacheName) + cachePath := filepath.Join(tmp, "krb5.cache") wca, err := k.certGetter.GetCertificateBytes(ctx) if err != nil { @@ -261,22 +243,22 @@ func (k *CommandLineInitializer) UseOrCreateCredentials(ctx context.Context) (*c } // store files in temp dir - err = os.WriteFile(certPath, wca.certPEM, 0644) + err = os.WriteFile(certPath, wca.certPEM, 0600) if err != nil { return nil, nil, trace.Wrap(err) } - err = os.WriteFile(keyPath, wca.keyPEM, 0644) + err = os.WriteFile(keyPath, wca.keyPEM, 0600) if err != nil { return nil, nil, trace.Wrap(err) } - err = os.WriteFile(userCAPath, k.buildAnchorsFileContents(wca.caCert), 0644) + err = os.WriteFile(userCAPath, k.buildAnchorsFileContents(wca.caCert), 0600) if err != nil { return nil, nil, trace.Wrap(err) } - krbConfPath := filepath.Join(tmp, fmt.Sprintf("krb_%s", k.userName)) + krbConfPath := filepath.Join(tmp, "krb5.conf") err = k.WriteKRB5Config(krbConfPath) if err != nil { return nil, nil, trace.Wrap(err) @@ -290,14 +272,19 @@ func (k *CommandLineInitializer) UseOrCreateCredentials(ctx context.Context) (*c cmd := k.command.CommandContext(ctx, k.binary, "-X", fmt.Sprintf("X509_anchors=FILE:%s", userCAPath), - "-X", fmt.Sprintf("X509_user_identity=FILE:%s,%s", certPath, keyPath), k.userName, - "-c", cachePath) + "-X", fmt.Sprintf("X509_user_identity=FILE:%s,%s", certPath, keyPath), + "-c", cachePath, + "--", k.userName, + ) if cmd.Err != nil { return nil, nil, trace.Wrap(cmd.Err) } cmd.Env = append(cmd.Env, []string{fmt.Sprintf("%s=%s", krb5ConfigEnv, krbConfPath)}...) + + k.log.Debugf("Running command: %v %v %s", strings.Join(cmd.Env, " "), cmd.Path, strings.Join(cmd.Args, " ")) + kinitOutput, err := cmd.CombinedOutput() if err != nil { k.log.Errorf("Failed to authenticate with KDC: %s", kinitOutput) diff --git a/lib/srv/db/sqlserver/kinit/kinit_test.go b/lib/srv/db/sqlserver/kinit/kinit_test.go index 9c0930345d762..95370eeacf38f 100644 --- a/lib/srv/db/sqlserver/kinit/kinit_test.go +++ b/lib/srv/db/sqlserver/kinit/kinit_test.go @@ -43,8 +43,29 @@ type badCache struct { t *testing.T } +func getCachePath(t *testing.T, args ...string) string { + if len(args) != 8 { + t.Fatalf("Unexpected args (%v): %v", len(args), args) + } + // example arguments: + // [-X X509_anchors=FILE:/tmp/kinit3779395068/userca.pem -X X509_user_identity=FILE:/tmp/kinit3779395068/cert.pem,/tmp/kinit3779395068/key.pem -c /tmp/kinit3779395068/krb5.cache -- alice] + if args[0] != "-X" { + t.Fatalf("Unexpected args (%v): %v", args[0], args) + } + if args[2] != "-X" { + t.Fatalf("Unexpected args (%v): %v", args[2], args) + } + if args[4] != "-c" { + t.Fatalf("Unexpected args (%v): %v", args[4], args) + } + if args[6] != "--" { + t.Fatalf("Unexpected args (%v): %v", args[6], args) + } + return args[5] +} + func (b *badCache) CommandContext(ctx context.Context, name string, args ...string) *exec.Cmd { - cachePath := args[len(args)-1] + cachePath := getCachePath(b.t, args...) require.NotEmpty(b.t, cachePath) err := os.WriteFile(cachePath, badCacheData, 0664) require.NoError(b.t, err) @@ -53,7 +74,7 @@ func (b *badCache) CommandContext(ctx context.Context, name string, args ...stri } func (s *staticCache) CommandContext(ctx context.Context, name string, args ...string) *exec.Cmd { - cachePath := args[len(args)-1] + cachePath := getCachePath(s.t, args...) require.NotEmpty(s.t, cachePath) err := os.WriteFile(cachePath, cacheData, 0664) require.NoError(s.t, err) @@ -90,11 +111,6 @@ type testCase struct { func step(t *testing.T, name string, cg CommandGenerator, c *testCertGetter, expectErr require.ErrorAssertionFunc, expectNil require.ValueAssertionFunc) *testCase { t.Helper() - dir := t.TempDir() - var err error - dir, err = os.MkdirTemp(dir, "krb5_cache") - require.NoError(t, err) - return &testCase{ name: name, initializer: New(NewCommandLineInitializer( @@ -103,7 +119,6 @@ func step(t *testing.T, name string, cg CommandGenerator, c *testCertGetter, exp Realm: "example.com", KDCHost: "host.example.com", AdminServer: "host.example.com", - DataDir: dir, Command: cg, CertGetter: c, })), @@ -183,7 +198,6 @@ func TestKRBConfString(t *testing.T) { AdminServer: "host.example.com", Command: &staticCache{t: t, pass: true}, CertGetter: &testCertGetter{pass: true}, - DataDir: t.TempDir(), }) tmp := t.TempDir()