Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions integration/appaccess/pack.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/web"
"github.com/gravitational/teleport/lib/web/app"
websession "github.com/gravitational/teleport/lib/web/session"
)

// Pack contains identity as well as initialized Teleport clusters and instances.
Expand Down Expand Up @@ -243,7 +244,7 @@ func (p *Pack) initWebSession(t *testing.T) {
// Extract session cookie and bearer token.
require.Len(t, resp.Cookies(), 1)
cookie := resp.Cookies()[0]
require.Equal(t, cookie.Name, web.CookieName)
require.Equal(t, cookie.Name, websession.CookieName)

p.webCookie = cookie.Value
p.webToken = csResp.Token
Expand Down Expand Up @@ -347,7 +348,7 @@ func (p *Pack) makeWebapiRequest(method, endpoint string, payload []byte) (int,
}

req.AddCookie(&http.Cookie{
Name: web.CookieName,
Name: websession.CookieName,
Value: p.webCookie,
})
req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", p.webToken))
Expand Down
5 changes: 3 additions & 2 deletions integration/helpers/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import (
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/web"
websession "github.com/gravitational/teleport/lib/web/session"
)

const (
Expand Down Expand Up @@ -1413,8 +1414,8 @@ func (i *TeleInstance) NewWebClient(cfg ClientConfig) (*WebClient, error) {
return nil, trace.BadParameter("unexpected number of cookies returned; got %d, want %d", len(cookies), 1)
}
cookie := cookies[0]
if cookie.Name != web.CookieName {
return nil, trace.BadParameter("unexpected session cookies returned; got %s, want %s", cookie.Name, web.CookieName)
if cookie.Name != websession.CookieName {
return nil, trace.BadParameter("unexpected session cookies returned; got %s, want %s", cookie.Name, websession.CookieName)
}

tc, err := i.NewUnauthenticatedClient(cfg)
Expand Down
5 changes: 3 additions & 2 deletions integration/helpers/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/gravitational/teleport/lib/httplib/csrf"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/web"
websession "github.com/gravitational/teleport/lib/web/session"
"github.com/gravitational/teleport/lib/web/ui"
)

Expand Down Expand Up @@ -91,7 +92,7 @@ func LoginWebClient(t *testing.T, host, username, password string) *WebClientPac
// Extract session cookie and bearer token.
require.Len(t, resp.Cookies(), 1)
cookie := resp.Cookies()[0]
require.Equal(t, cookie.Name, web.CookieName)
require.Equal(t, cookie.Name, websession.CookieName)

webClient := &WebClientPack{
clt: client,
Expand Down Expand Up @@ -127,7 +128,7 @@ func (w *WebClientPack) DoRequest(t *testing.T, method, endpoint string, payload
require.NoError(t, err)

req.AddCookie(&http.Cookie{
Name: web.CookieName,
Name: websession.CookieName,
Value: w.webCookie,
})
req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", w.bearerToken))
Expand Down
236 changes: 236 additions & 0 deletions lib/benchmark/web.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
// Copyright 2023 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package benchmark

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/gorilla/websocket"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"

apiclient "github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/web"
)

// WebSSHBenchmark is a benchmark suite that connects to the configured
// target hosts via the web api and executes the provided command.
type WebSSHBenchmark struct {
// Command to execute on the host.
Command []string
// Random whether to connect to a random host or not
Random bool
// Duration of the test used to determine if renewing web sessions
// is necessary.
Duration time.Duration
}

// BenchBuilder returns a WorkloadFunc for the given benchmark suite.
func (s WebSSHBenchmark) BenchBuilder(ctx context.Context, tc *client.TeleportClient) (WorkloadFunc, error) {
clt, sess, err := tc.LoginWeb(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

webSess := &webSession{
webSession: sess,
clt: clt,
}

// The web session will expire before the duration of the test
// so launch the renewal loop.
if !time.Now().Add(s.Duration).Before(webSess.expires()) {
go webSess.renew(ctx)
}

// Add "exit" to ensure that the session terminates after running the command.
command := strings.Join(append(s.Command, "\r\nexit\r\n"), " ")

if s.Random {
if tc.Host != "all" {
return nil, trace.BadParameter("random ssh bench commands must use the format <user>@all <command>")
}

servers, err := s.getServers(ctx, tc)
if err != nil {
return nil, trace.Wrap(err)
}

return func(ctx context.Context) error {
return trace.Wrap(s.runCommand(ctx, tc, webSess, chooseRandomHost(servers), command))
}, nil
}

return func(ctx context.Context) error {
return trace.Wrap(s.runCommand(ctx, tc, webSess, tc.Host, command))
}, nil
}

// runCommand starts a non-interactive SSH session and executes the provided
// command before terminating the session.
func (s WebSSHBenchmark) runCommand(ctx context.Context, tc *client.TeleportClient, webSess *webSession, host, command string) error {
stream, err := s.connectToHost(ctx, tc, webSess, host)
if err != nil {
return trace.Wrap(err)
}
defer stream.Close()

if _, err := io.WriteString(stream, command); err != nil {
return trace.Wrap(err)
}

if _, err := io.Copy(tc.Stdout, stream); err != nil && !errors.Is(err, io.EOF) {
return trace.Wrap(err)
}

return nil
}

// getServers returns all [types.Server] that the authenticated user has
// access to.
func (s WebSSHBenchmark) getServers(ctx context.Context, tc *client.TeleportClient) ([]types.Server, error) {
clt, err := tc.ConnectToCluster(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
defer clt.Close()

resources, err := apiclient.GetAllResources[types.Server](ctx, clt.AuthClient, tc.ResourceFilter(types.KindNode))
if err != nil {
return nil, trace.Wrap(err)
}

if len(resources) == 0 {
return nil, trace.BadParameter("no target hosts available")
}

return resources, nil
}

// connectToHost opens an SSH session to the target host via the Proxy web api.
func (s WebSSHBenchmark) connectToHost(ctx context.Context, tc *client.TeleportClient, webSession *webSession, host string) (*web.TerminalStream, error) {
req := web.TerminalRequest{
Server: host,
Login: tc.HostLogin,
Term: session.TerminalParams{
W: 100,
H: 100,
},
}

data, err := json.Marshal(req)
if err != nil {
return nil, trace.Wrap(err)
}

u := url.URL{
Host: tc.WebProxyAddr,
Scheme: client.WSS,
Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", tc.SiteName),
RawQuery: url.Values{
"params": []string{string(data)},
roundtrip.AccessTokenQueryParam: []string{webSession.getToken()},
}.Encode(),
}

dialer := websocket.Dialer{
TLSClientConfig: &tls.Config{InsecureSkipVerify: tc.InsecureSkipVerify},
Jar: webSession.getCookieJar(),
}

ws, resp, err := dialer.DialContext(ctx, u.String(), http.Header{
"Origin": []string{"http://localhost"},
})
if err != nil {
return nil, trace.Wrap(err)
}
defer resp.Body.Close()

ty, _, err := ws.ReadMessage()
if err != nil {
return nil, trace.Wrap(err)
}

if ty != websocket.BinaryMessage {
return nil, trace.BadParameter("unexpected websocket message received %d", ty)
}

stream := web.NewTerminalStream(ctx, ws, utils.NewLogger())
return stream, trace.Wrap(err)
}

type webSession struct {
mu sync.Mutex
webSession types.WebSession
clt *client.WebClient
}

func (s *webSession) renew(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Until(s.expires().Add(-3 * time.Minute))):
resp, err := s.clt.PostJSON(ctx, s.clt.Endpoint("webapi", "sessions", "renew"), nil)
if err != nil {
continue
}

session, err := client.GetSessionFromResponse(resp)
if err != nil {
continue
}

s.mu.Lock()
s.webSession = session
s.mu.Unlock()
}
}
}

func (s *webSession) expires() time.Time {
s.mu.Lock()
defer s.mu.Unlock()

return s.webSession.GetBearerTokenExpiryTime()
}

func (s *webSession) getCookieJar() http.CookieJar {
s.mu.Lock()
defer s.mu.Unlock()

return s.clt.HTTPClient().Jar
}

func (s *webSession) getToken() string {
s.mu.Lock()
defer s.mu.Unlock()

return s.webSession.GetBearerToken()
}
Loading