Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions lib/ai/testutils/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package testutils

import (
"encoding/json"
"net/http"
"strconv"
"testing"
"time"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)

// GetTestHandlerFn returns a handler function that can be used to OpenAI API used by
// the chat API. It takes a list of responses that will be returned in order.
func GetTestHandlerFn(t *testing.T, responses []string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost || !(r.URL.Path == "/chat/completions") {
http.Error(w, "Unexpected request", http.StatusBadRequest)
return
}

switch r.Header.Get("Accept") {
case "application/json; charset=utf-8", "application/json":
responses = messageResponse(w, r, t, responses)
case "text/event-stream":
responses = streamResponse(w, t, responses)
default:
http.Error(w, "Unexpected request", http.StatusBadRequest)
}
}
}

func streamResponse(w http.ResponseWriter, t *testing.T, responses []string) []string {
w.Header().Set("Content-Type", "text/event-stream")

if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") {
http.Error(w, "Unexpected request", http.StatusBadRequest)
return responses
}

resp := &openai.ChatCompletionStreamResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "completion",
Created: time.Now().Unix(),
Model: openai.GPT4,
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: responses[0],
Role: openai.ChatMessageRoleAssistant,
},
FinishReason: "",
},
},
}

respBytes, err := json.Marshal(resp)
assert.NoError(t, err, "Marshal error")

_, err = w.Write([]byte("data: "))
assert.NoError(t, err, "Write error")
_, err = w.Write(respBytes)
assert.NoError(t, err, "Write error")
_, err = w.Write([]byte("\n\nevent: done\ndata: [DONE]\n\n"))
assert.NoError(t, err, "Write error")

return responses[1:]
}

func messageResponse(w http.ResponseWriter, r *http.Request, t *testing.T, responses []string) []string {
w.Header().Set("Content-Type", "application/json")

req := &openai.ChatCompletionRequest{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
}

// Use assert as require doesn't work when called from a goroutine
if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") {
http.Error(w, "Unexpected request", http.StatusBadRequest)
return responses
}

dataBytes := responses[0]

resp := openai.ChatCompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
Model: req.Model,
Choices: []openai.ChatCompletionChoice{
{
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: dataBytes,
Name: "",
},
},
},
Usage: openai.Usage{},
}

respBytes, err := json.Marshal(resp)
assert.NoError(t, err, "Marshal error")

_, err = w.Write(respBytes)
assert.NoError(t, err, "Write error")

return responses[1:]
}
101 changes: 95 additions & 6 deletions lib/auth/assist/assistv1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,46 +22,78 @@ import (
"context"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/services"
)

// ServiceConfig holds configuration options for
// the assist gRPC service.
type ServiceConfig struct {
Backend services.Assistant
Backend services.Assistant
Authorizer authz.Authorizer
Logger *logrus.Entry
}

// Service implements the teleport.assist.v1.AssistService RPC service.
type Service struct {
assist.UnimplementedAssistServiceServer

backend services.Assistant
backend services.Assistant
authorizer authz.Authorizer
log *logrus.Entry
}

// NewService returns a new assist gRPC service.
func NewService(cfg *ServiceConfig) (*Service, error) {
switch {
case cfg.Backend == nil:
return nil, trace.BadParameter("backend is required")
case cfg.Authorizer == nil:
return nil, trace.BadParameter("authorizer is required")
case cfg.Logger == nil:
cfg.Logger = logrus.WithField(trace.Component, "assist.service")
}

return &Service{
backend: cfg.Backend,
backend: cfg.Backend,
authorizer: cfg.Authorizer,
log: cfg.Logger,
}, nil
}

// CreateAssistantConversation creates a new conversation entry in the backend.
func (a *Service) CreateAssistantConversation(ctx context.Context, req *assist.CreateAssistantConversationRequest) (*assist.CreateAssistantConversationResponse, error) {
authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbCreate)
if err != nil {
return nil, authz.ConvertAuthorizerError(ctx, a.log, err)
}

if userHasAccess(authCtx, req) {
return nil, trace.AccessDenied("user %q is not allowed to create conversation for user %q", authCtx.User.GetName(), req.Username)
}

resp, err := a.backend.CreateAssistantConversation(ctx, req)
return resp, trace.Wrap(err)
}

// UpdateAssistantConversationInfo updates the conversation info for a conversation.
func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, request *assist.UpdateAssistantConversationInfoRequest) (*emptypb.Empty, error) {
err := a.backend.UpdateAssistantConversationInfo(ctx, request)
func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, req *assist.UpdateAssistantConversationInfoRequest) (*emptypb.Empty, error) {
authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbUpdate)
if err != nil {
return nil, authz.ConvertAuthorizerError(ctx, a.log, err)
}

if userHasAccess(authCtx, req) {
return nil, trace.AccessDenied("user %q is not allowed to update conversation for user %q", authCtx.User.GetName(), req.Username)
}

err = a.backend.UpdateAssistantConversationInfo(ctx, req)
if err != nil {
return &emptypb.Empty{}, trace.Wrap(err)
}
Expand All @@ -71,22 +103,79 @@ func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, request *

// GetAssistantConversations returns all conversations started by a user.
func (a *Service) GetAssistantConversations(ctx context.Context, req *assist.GetAssistantConversationsRequest) (*assist.GetAssistantConversationsResponse, error) {
authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbList)
if err != nil {
return nil, authz.ConvertAuthorizerError(ctx, a.log, err)
}

if userHasAccess(authCtx, req) {
return nil, trace.AccessDenied("user %q is not allowed to list conversations for user %q", authCtx.User.GetName(), req.GetUsername())
}

resp, err := a.backend.GetAssistantConversations(ctx, req)
return resp, trace.Wrap(err)
}

// GetAssistantMessages returns all messages with given conversation ID.
func (a *Service) GetAssistantMessages(ctx context.Context, req *assist.GetAssistantMessagesRequest) (*assist.GetAssistantMessagesResponse, error) {
authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbRead)
if err != nil {
return nil, authz.ConvertAuthorizerError(ctx, a.log, err)
}

if userHasAccess(authCtx, req) {
return nil, trace.AccessDenied("user %q is not allowed to get messages for user %q", authCtx.User.GetName(), req.GetUsername())
}

resp, err := a.backend.GetAssistantMessages(ctx, req)
return resp, trace.Wrap(err)
}

// CreateAssistantMessage adds the message to the backend.
func (a *Service) CreateAssistantMessage(ctx context.Context, req *assist.CreateAssistantMessageRequest) (*emptypb.Empty, error) {
authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbCreate)
if err != nil {
return nil, authz.ConvertAuthorizerError(ctx, a.log, err)
}

if userHasAccess(authCtx, req) {
return nil, trace.AccessDenied("user %q is not allowed to create message for user %q", authCtx.User.GetName(), req.GetUsername())
}

return &emptypb.Empty{}, trace.Wrap(a.backend.CreateAssistantMessage(ctx, req))
}

// IsAssistEnabled returns true if the assist is enabled or not on the auth level.
func (a *Service) IsAssistEnabled(ctx context.Context, req *assist.IsAssistEnabledRequest) (*assist.IsAssistEnabledResponse, error) {
func (a *Service) IsAssistEnabled(ctx context.Context, _ *assist.IsAssistEnabledRequest) (*assist.IsAssistEnabledResponse, error) {
authCtx, err := a.authorizer.Authorize(ctx)
if err != nil {
return nil, authz.ConvertAuthorizerError(ctx, a.log, err)
}

// Check if this endpoint is called by a user or Proxy.
if authz.IsLocalUser(*authCtx) {
checkErr := authCtx.Checker.CheckAccessToRule(
&services.Context{User: authCtx.User},
defaults.Namespace, types.KindAssistant, types.VerbRead,
false, /* silent */
)
if checkErr != nil {
return nil, authz.ConvertAuthorizerError(ctx, a.log, err)
}
} else {
// This endpoint is called from Proxy to check if the assist is enabled.
// Proxy credentials are used instead of the user credentials.
requestedByProxy := authz.HasBuiltinRole(*authCtx, string(types.RoleProxy))
if !requestedByProxy {
return nil, trace.AccessDenied("only proxy is allowed to call IsAssistEnabled endpoint")
}
}

// Check if assist can use the backend.
return a.backend.IsAssistEnabled(ctx)
}

// userHasAccess returns true if the user should have access to the resource.
func userHasAccess(authCtx *authz.Context, req interface{ GetUsername() string }) bool {
return !authz.IsCurrentUser(*authCtx, req.GetUsername()) && !authz.HasBuiltinRole(*authCtx, string(types.RoleAdmin))
}
Loading