Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lib/srv/db/sqlserver/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (c *connector) keytabClient(session *common.Session) (*client.Client, error
}

// kinitClient returns a kerberos client using a kinit ccache
func (c *connector) kinitClient(ctx context.Context, session *common.Session, auth windows.AuthInterface, dataDir string) (*client.Client, error) {
func (c *connector) kinitClient(ctx context.Context, session *common.Session, auth windows.AuthInterface) (*client.Client, error) {
ldapPem, _ := pem.Decode([]byte(session.Database.GetAD().LDAPCert))

if ldapPem == nil {
Expand Down Expand Up @@ -98,7 +98,6 @@ func (c *connector) kinitClient(ctx context.Context, session *common.Session, au
Realm: realmName,
KDCHost: session.Database.GetAD().KDCHostName,
AdminServer: session.Database.GetAD().Domain,
DataDir: dataDir,
LDAPCA: cert,
LDAPCAPEM: session.Database.GetAD().LDAPCert,
Command: c.kinitCommandGenerator,
Expand Down
4 changes: 1 addition & 3 deletions lib/srv/db/sqlserver/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ type connector struct {
DBAuth common.Auth
// AuthClient is the teleport client
AuthClient windows.AuthInterface
// DataDir is the Teleport data directory
DataDir string

kinitCommandGenerator kinit.CommandGenerator
}
Expand All @@ -75,7 +73,7 @@ func (c *connector) getKerberosClient(ctx context.Context, sessionCtx *common.Se
}
return kt, nil
case sessionCtx.Database.GetAD().KDCHostName != "" && sessionCtx.Database.GetAD().LDAPCert != "":
kt, err := c.kinitClient(ctx, sessionCtx, c.AuthClient, c.DataDir)
kt, err := c.kinitClient(ctx, sessionCtx, c.AuthClient)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
24 changes: 22 additions & 2 deletions lib/srv/db/sqlserver/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,29 @@ type staticCache struct {
pass bool
}

func getCachePath(t *testing.T, args ...string) string {
if len(args) != 8 {
t.Fatalf("Unexpected args (%v): %v", len(args), args)
}
// example arguments:
// [-X X509_anchors=FILE:/tmp/kinit3779395068/userca.pem -X X509_user_identity=FILE:/tmp/kinit3779395068/cert.pem,/tmp/kinit3779395068/key.pem -c /tmp/kinit3779395068/krb5.cache -- alice]
if args[0] != "-X" {
t.Fatalf("Unexpected args (%v): %v", args[0], args)
}
if args[2] != "-X" {
t.Fatalf("Unexpected args (%v): %v", args[2], args)
}
if args[4] != "-c" {
t.Fatalf("Unexpected args (%v): %v", args[4], args)
}
if args[6] != "--" {
t.Fatalf("Unexpected args (%v): %v", args[6], args)
}
return args[5]
}

func (s *staticCache) CommandContext(ctx context.Context, name string, args ...string) *exec.Cmd {
cachePath := args[len(args)-1]
cachePath := getCachePath(s.t, args...)
require.NotEmpty(s.t, cachePath)
err := os.WriteFile(cachePath, cacheData, 0664)
require.NoError(s.t, err)
Expand Down Expand Up @@ -315,7 +336,6 @@ func TestConnectorKInitClient(t *testing.T) {

databaseUser := "alice"
databaseName := database.GetName()
connector.DataDir = dir

connectorCtx, cancel := context.WithCancel(ctx)
// we want to pass the canceled context
Expand Down
1 change: 0 additions & 1 deletion lib/srv/db/sqlserver/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ func NewEngine(ec common.EngineConfig) common.Engine {
Connector: &connector{
DBAuth: ec.Auth,
AuthClient: ec.AuthClient,
DataDir: ec.DataDir,
},
}
}
Expand Down
45 changes: 16 additions & 29 deletions lib/srv/db/sqlserver/kinit/kinit.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"os"
"os/exec"
"path/filepath"
"strings"
"text/template"
"time"

Expand Down Expand Up @@ -94,8 +95,6 @@ type CommandConfig struct {
KDCHost string
// AdminServer is the administration server hostname (usually AD server)
AdminServer string
// DataDir is the Teleport Data Directory
DataDir string
// LDAPCA is the Windows LDAP Certificate for client signing
LDAPCA *x509.Certificate
// LDAPCAPEM contains the same certificate as LDAPCA but in PEM format. It
Expand All @@ -114,13 +113,9 @@ func NewCommandLineInitializer(config CommandConfig) *CommandLineInitializer {
cmd := &CommandLineInitializer{
auth: config.AuthClient,
userName: config.User,
cacheName: fmt.Sprintf("%s@%s", config.User, config.Realm),
RealmName: config.Realm,
KDCHostName: config.KDCHost,
AdminServerName: config.AdminServer,
dataDir: config.DataDir,
certPath: fmt.Sprintf("%s.pem", config.User),
keyPath: fmt.Sprintf("%s-key.pem", config.User),
binary: kinitBinary,
command: config.Command,
certGetter: config.CertGetter,
Expand Down Expand Up @@ -160,12 +155,7 @@ type CommandLineInitializer struct {
// AdminServerName is the admin server Name (usually AD host)
AdminServerName string

dataDir string
userName string
cacheName string

certPath string
keyPath string
userName string
binary string

command CommandGenerator
Expand Down Expand Up @@ -242,41 +232,33 @@ func (k *CommandLineInitializer) UseOrCreateCredentials(ctx context.Context) (*c
}
}()

certPath := filepath.Join(tmp, fmt.Sprintf("%s.pem", k.userName))
keyPath := filepath.Join(tmp, fmt.Sprintf("%s-key.pem", k.userName))
certPath := filepath.Join(tmp, "cert.pem")
keyPath := filepath.Join(tmp, "key.pem")
userCAPath := filepath.Join(tmp, "userca.pem")

cacheDir := filepath.Join(k.dataDir, "krb5_cache")

err = os.MkdirAll(cacheDir, os.ModePerm)
if err != nil {
return nil, nil, trace.Wrap(err)
}

cachePath := filepath.Join(cacheDir, k.cacheName)
cachePath := filepath.Join(tmp, "krb5.cache")

wca, err := k.certGetter.GetCertificateBytes(ctx)
if err != nil {
return nil, nil, trace.Wrap(err)
}

// store files in temp dir
err = os.WriteFile(certPath, wca.certPEM, 0644)
err = os.WriteFile(certPath, wca.certPEM, 0600)
if err != nil {
return nil, nil, trace.Wrap(err)
}

err = os.WriteFile(keyPath, wca.keyPEM, 0644)
err = os.WriteFile(keyPath, wca.keyPEM, 0600)
if err != nil {
return nil, nil, trace.Wrap(err)
}

err = os.WriteFile(userCAPath, k.buildAnchorsFileContents(wca.caCert), 0644)
err = os.WriteFile(userCAPath, k.buildAnchorsFileContents(wca.caCert), 0600)
if err != nil {
return nil, nil, trace.Wrap(err)
}

krbConfPath := filepath.Join(tmp, fmt.Sprintf("krb_%s", k.userName))
krbConfPath := filepath.Join(tmp, "krb5.conf")
err = k.WriteKRB5Config(krbConfPath)
if err != nil {
return nil, nil, trace.Wrap(err)
Expand All @@ -290,14 +272,19 @@ func (k *CommandLineInitializer) UseOrCreateCredentials(ctx context.Context) (*c
cmd := k.command.CommandContext(ctx,
k.binary,
"-X", fmt.Sprintf("X509_anchors=FILE:%s", userCAPath),
"-X", fmt.Sprintf("X509_user_identity=FILE:%s,%s", certPath, keyPath), k.userName,
"-c", cachePath)
"-X", fmt.Sprintf("X509_user_identity=FILE:%s,%s", certPath, keyPath),
"-c", cachePath,
"--", k.userName,
)

if cmd.Err != nil {
return nil, nil, trace.Wrap(cmd.Err)
}

cmd.Env = append(cmd.Env, []string{fmt.Sprintf("%s=%s", krb5ConfigEnv, krbConfPath)}...)

k.log.Debugf("Running command: %v %v %s", strings.Join(cmd.Env, " "), cmd.Path, strings.Join(cmd.Args, " "))

kinitOutput, err := cmd.CombinedOutput()
if err != nil {
k.log.Errorf("Failed to authenticate with KDC: %s", kinitOutput)
Expand Down
32 changes: 23 additions & 9 deletions lib/srv/db/sqlserver/kinit/kinit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,29 @@ type badCache struct {
t *testing.T
}

func getCachePath(t *testing.T, args ...string) string {
if len(args) != 8 {
t.Fatalf("Unexpected args (%v): %v", len(args), args)
}
// example arguments:
// [-X X509_anchors=FILE:/tmp/kinit3779395068/userca.pem -X X509_user_identity=FILE:/tmp/kinit3779395068/cert.pem,/tmp/kinit3779395068/key.pem -c /tmp/kinit3779395068/krb5.cache -- alice]
if args[0] != "-X" {
t.Fatalf("Unexpected args (%v): %v", args[0], args)
}
if args[2] != "-X" {
t.Fatalf("Unexpected args (%v): %v", args[2], args)
}
if args[4] != "-c" {
t.Fatalf("Unexpected args (%v): %v", args[4], args)
}
if args[6] != "--" {
t.Fatalf("Unexpected args (%v): %v", args[6], args)
}
return args[5]
}

func (b *badCache) CommandContext(ctx context.Context, name string, args ...string) *exec.Cmd {
cachePath := args[len(args)-1]
cachePath := getCachePath(b.t, args...)
require.NotEmpty(b.t, cachePath)
err := os.WriteFile(cachePath, badCacheData, 0664)
require.NoError(b.t, err)
Expand All @@ -53,7 +74,7 @@ func (b *badCache) CommandContext(ctx context.Context, name string, args ...stri
}

func (s *staticCache) CommandContext(ctx context.Context, name string, args ...string) *exec.Cmd {
cachePath := args[len(args)-1]
cachePath := getCachePath(s.t, args...)
require.NotEmpty(s.t, cachePath)
err := os.WriteFile(cachePath, cacheData, 0664)
require.NoError(s.t, err)
Expand Down Expand Up @@ -90,11 +111,6 @@ type testCase struct {
func step(t *testing.T, name string, cg CommandGenerator, c *testCertGetter, expectErr require.ErrorAssertionFunc, expectNil require.ValueAssertionFunc) *testCase {
t.Helper()

dir := t.TempDir()
var err error
dir, err = os.MkdirTemp(dir, "krb5_cache")
require.NoError(t, err)

return &testCase{
name: name,
initializer: New(NewCommandLineInitializer(
Expand All @@ -103,7 +119,6 @@ func step(t *testing.T, name string, cg CommandGenerator, c *testCertGetter, exp
Realm: "example.com",
KDCHost: "host.example.com",
AdminServer: "host.example.com",
DataDir: dir,
Command: cg,
CertGetter: c,
})),
Expand Down Expand Up @@ -183,7 +198,6 @@ func TestKRBConfString(t *testing.T) {
AdminServer: "host.example.com",
Command: &staticCache{t: t, pass: true},
CertGetter: &testCertGetter{pass: true},
DataDir: t.TempDir(),
})

tmp := t.TempDir()
Expand Down
Loading