diff --git a/lib/services/local/assistant.go b/lib/services/local/assistant.go index ce2133152fcee..542cbe71107ae 100644 --- a/lib/services/local/assistant.go +++ b/lib/services/local/assistant.go @@ -219,6 +219,15 @@ func (s *AssistService) CreateAssistantMessage(ctx context.Context, req *assist. return trace.BadParameter("missing conversation ID") } + // Check if the conversation exists. + conversationKey := backend.Key(assistantConversationPrefix, req.Username, req.ConversationId) + if _, err := s.Get(ctx, conversationKey); err != nil { + if trace.IsNotFound(err) { + return trace.NotFound("conversation %q not found", req.ConversationId) + } + return trace.Wrap(err) + } + msg := req.GetMessage() value, err := json.Marshal(msg) if err != nil { diff --git a/lib/services/local/assistant_test.go b/lib/services/local/assistant_test.go index 5a5853bd8ec93..90fc4642d11ea 100644 --- a/lib/services/local/assistant_test.go +++ b/lib/services/local/assistant_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" @@ -140,4 +141,18 @@ func TestAssistantCRUD(t *testing.T) { require.Equal(t, conversationID, conversations.Conversations[0].Id) require.Equal(t, conversationResp.Id, conversations.Conversations[1].Id) }) + + t.Run("refuse to add messages if conversion does not exist", func(t *testing.T) { + msg := &assist.CreateAssistantMessageRequest{ + Username: username, + ConversationId: uuid.New().String(), + Message: &assist.AssistantMessage{ + CreatedTime: timestamppb.New(time.Now()), + Payload: "foo", + Type: "USER_MSG", + }, + } + err := identity.CreateAssistantMessage(ctx, msg) + require.Error(t, err) + }) } diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 8f15b6a82bf51..87969a39f8621 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -17,6 +17,7 @@ package web import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -26,7 +27,6 @@ import ( "net/url" "testing" - "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" @@ -168,7 +168,13 @@ func Test_runAssistant(t *testing.T) { tc.setup(t, s) } - ws, err := s.makeAssistant(t, s.authPack(t, "foo")) + ctx := context.Background() + authPack := s.authPack(t, "foo") + // Create the conversation + conversationID := s.makeAssistConversation(t, ctx, authPack) + + // Make WS client and start the conversation + ws, err := s.makeAssistant(t, authPack, conversationID) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -186,10 +192,26 @@ func Test_runAssistant(t *testing.T) { tc.act(t, ws) }) } +} + +// makeAssistConversation creates a new assist conversation and returns its ID +func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string { + clt := authPack.clt + + resp, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "assistant", "conversations"), nil) + require.NoError(t, err) + + convResp := struct { + ConversationID string `json:"id"` + }{} + err = json.Unmarshal(resp.Bytes(), &convResp) + require.NoError(t, err) + return convResp.ConversationID } -func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) { +// makeAssistant creates a new assistant websocket connection. +func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack, conversationID string) (*websocket.Conn, error) { u := url.URL{ Host: s.url().Host, Scheme: client.WSS, @@ -197,7 +219,7 @@ func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, } q := u.Query() - q.Set("conversation_id", uuid.New().String()) + q.Set("conversation_id", conversationID) q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode()