Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions lib/kube/proxy/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,9 @@ func newSession(ctx authContext, forwarder *Forwarder, req *http.Request, params
id := uuid.New()
log := forwarder.log.WithField("session", id.String())
log.Debug("Creating session")
roles, err := getRolesByName(forwarder, ctx.Context.Identity.GetIdentity().Groups)
if err != nil {
return nil, trace.Wrap(err)
}

var policySets []*types.SessionTrackerPolicySet
roles := ctx.Checker.Roles()
for _, role := range roles {
policySet := role.GetSessionPolicySet()
policySets = append(policySets, &policySet)
Expand Down Expand Up @@ -1192,16 +1189,39 @@ func (s *session) trackSession(p *party, policySet []*types.SessionTrackerPolicy
HostUser: p.Ctx.User.GetName(),
HostPolicies: policySet,
Login: "root",
Created: time.Now(),
Created: s.forwarder.cfg.Clock.Now(),
Reason: s.reason,
Invited: s.invitedUsers,
}

s.log.Debug("Creating session tracker")
var err error
s.tracker, err = srv.NewSessionTracker(s.forwarder.ctx, trackerSpec, s.forwarder.cfg.AuthClient)
if err != nil {
sessionTrackerService := s.forwarder.cfg.AuthClient

ctx := s.req.Context()

tracker, err := srv.NewSessionTracker(ctx, trackerSpec, sessionTrackerService)
switch {
// there was an error creating the tracker for a moderated session - terminate the session
case err != nil && s.accessEvaluator.IsModerated():
s.log.WithError(err).Warn("Failed to create session tracker, unable to proceed for moderated session")
return trace.Wrap(err)
// there was an error creating the tracker for a non-moderated session - permit the session with a local tracker
case err != nil && !s.accessEvaluator.IsModerated():
s.log.Warn("Failed to create session tracker, proceeding with local session tracker for non-moderated session")

localTracker, err := srv.NewSessionTracker(ctx, trackerSpec, nil)
// this error means there are problems with the trackerSpec, we need to return it
if err != nil {
return trace.Wrap(err)
}

s.tracker = localTracker
// there was an error even though the tracker wasn't being propagated - return it
case err != nil:
return trace.Wrap(err)
// the tracker was created successfully
case err == nil:
s.tracker = tracker
}

go func() {
Expand Down
158 changes: 158 additions & 0 deletions lib/kube/proxy/sess_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
Copyright 2021 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package proxy

import (
"context"
"fmt"
"net/http"
"testing"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/authz"
)

func Test_session_trackSession(t *testing.T) {
t.Parallel()
moderatedPolicy := &types.SessionTrackerPolicySet{
Version: types.V3,
Name: "name",
RequireSessionJoin: []*types.SessionRequirePolicy{
{
Name: "Auditor oversight",
Filter: fmt.Sprintf("contains(user.spec.roles, %q)", "test"),
Kinds: []string{"k8s"},
Modes: []string{string(types.SessionModeratorMode)},
Count: 1,
},
},
}
nonModeratedPolicy := &types.SessionTrackerPolicySet{
Version: types.V3,
Name: "name",
}
type args struct {
authClient auth.ClientI
policies []*types.SessionTrackerPolicySet
}
tests := []struct {
name string
args args
assertErr require.ErrorAssertionFunc
}{
{
name: "ok with moderated session and healthy auth service",
args: args{
authClient: &mockSessionTrackerService{},
policies: []*types.SessionTrackerPolicySet{
moderatedPolicy,
},
},
assertErr: require.NoError,
},
{
name: "ok with non-moderated session session and healthy auth service",
args: args{
authClient: &mockSessionTrackerService{},
policies: []*types.SessionTrackerPolicySet{
nonModeratedPolicy,
},
},
assertErr: require.NoError,
},
{
name: "fail with moderated session and unhealthy auth service",
args: args{
authClient: &mockSessionTrackerService{
returnErr: true,
},
policies: []*types.SessionTrackerPolicySet{
moderatedPolicy,
},
},
assertErr: require.Error,
},
{
name: "ok with non-moderated session session and unhealthy auth service",
args: args{
authClient: &mockSessionTrackerService{
returnErr: true,
},
policies: []*types.SessionTrackerPolicySet{
nonModeratedPolicy,
},
},
assertErr: require.NoError,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sess := &session{
log: logrus.New().WithField(trace.Component, "test"),
id: uuid.New(),
req: &http.Request{},
podName: "podName",
accessEvaluator: auth.NewSessionAccessEvaluator(tt.args.policies, types.KubernetesSessionKind, "username"),
ctx: authContext{
Context: authz.Context{
User: &types.UserV2{
Metadata: types.Metadata{
Name: "username",
},
},
},
teleportCluster: teleportClusterClient{
name: "name",
},
kubeClusterName: "kubeClusterName",
},
forwarder: &Forwarder{
cfg: ForwarderConfig{
Clock: clockwork.NewFakeClock(),
AuthClient: tt.args.authClient,
},
ctx: context.Background(),
},
}
p := &party{
Ctx: sess.ctx,
}
err := sess.trackSession(p, tt.args.policies)
tt.assertErr(t, err)
})
}
}

type mockSessionTrackerService struct {
auth.ClientI
returnErr bool
}

func (m *mockSessionTrackerService) CreateSessionTracker(ctx context.Context, tracker types.SessionTracker) (types.SessionTracker, error) {
if m.returnErr {
return nil, trace.ConnectionProblem(nil, "mock error")
}
return tracker, nil
}