diff --git a/integrations/access/mattermost/helper_test.go b/integrations/access/mattermost/helper_test.go index dc11169808cdf..a248dd3816373 100644 --- a/integrations/access/mattermost/helper_test.go +++ b/integrations/access/mattermost/helper_test.go @@ -16,8 +16,16 @@ package mattermost +import ( + "context" + "sync/atomic" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/common" +) + type MattermostPostSlice []Post -type MattermostDataPostSet map[MattermostDataPost]struct{} +type MattermostDataPostSet map[common.MessageData]struct{} func (slice MattermostPostSlice) Len() int { return len(slice) @@ -34,11 +42,28 @@ func (slice MattermostPostSlice) Swap(i, j int) { slice[i], slice[j] = slice[j], slice[i] } -func (set MattermostDataPostSet) Add(msg MattermostDataPost) { +func (set MattermostDataPostSet) Add(msg common.MessageData) { set[msg] = struct{}{} } -func (set MattermostDataPostSet) Contains(msg MattermostDataPost) bool { +func (set MattermostDataPostSet) Contains(msg common.MessageData) bool { _, ok := set[msg] return ok } + +type fakeStatusSink struct { + status atomic.Pointer[types.PluginStatus] +} + +func (s *fakeStatusSink) Emit(_ context.Context, status types.PluginStatus) error { + s.status.Store(&status) + return nil +} + +func (s *fakeStatusSink) Get() types.PluginStatus { + status := s.status.Load() + if status == nil { + panic("expected status to be set, but it has not been") + } + return *status +} diff --git a/integrations/access/mattermost/mattermost_test.go b/integrations/access/mattermost/mattermost_test.go index ca181ef2cf8de..a48c4b3ea5e5c 100644 --- a/integrations/access/mattermost/mattermost_test.go +++ b/integrations/access/mattermost/mattermost_test.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/lib" "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/testing/integration" @@ -57,6 +58,7 @@ type MattermostSuite struct { } raceNumber int fakeMattermost *FakeMattermost + fakeStatusSink *fakeStatusSink mmUser User clients map[string]*integration.Client @@ -197,10 +199,13 @@ func (s *MattermostSuite) SetupTest() { Email: s.userNames.requestor, }) + s.fakeStatusSink = &fakeStatusSink{} + var conf Config conf.Teleport = s.teleportConfig conf.Mattermost.Token = "000000" conf.Mattermost.URL = s.fakeMattermost.URL() + conf.StatusSink = s.fakeStatusSink s.appConfig = &conf s.SetContextTimeout(5 * time.Second) @@ -256,14 +261,16 @@ func (s *MattermostSuite) createAccessRequest(reviewers []User) types.AccessRequ return req } -func (s *MattermostSuite) checkPluginData(reqID string, cond func(PluginData) bool) PluginData { +func (s *MattermostSuite) checkPluginData(reqID string, cond func(common.GenericPluginData) bool) common.GenericPluginData { t := s.T() t.Helper() for { rawData, err := s.ruler().PollAccessRequestPluginData(s.Context(), "mattermost", reqID) require.NoError(t, err) - if data := DecodePluginData(rawData); cond(data) { + data, err := common.DecodePluginData(rawData) + require.NoError(t, err) + if cond(data) { return data } } @@ -280,23 +287,23 @@ func (s *MattermostSuite) TestMattermostMessagePosting() { s.startApp() request := s.createAccessRequest([]User{reviewer2, reviewer1}) - pluginData := s.checkPluginData(request.GetName(), func(data PluginData) bool { - return len(data.MattermostData) > 0 + pluginData := s.checkPluginData(request.GetName(), func(data common.GenericPluginData) bool { + return len(data.SentMessages) > 0 }) - assert.Len(t, pluginData.MattermostData, 2) + assert.Len(t, pluginData.SentMessages, 2) var posts []Post postSet := make(MattermostDataPostSet) for i := 0; i < 2; i++ { post, err := s.fakeMattermost.CheckNewPost(s.Context()) require.NoError(t, err, "no new messages posted") - postSet.Add(MattermostDataPost{ChannelID: post.ChannelID, PostID: post.ID}) + postSet.Add(common.MessageData{ChannelID: post.ChannelID, MessageID: post.ID}) posts = append(posts, post) } assert.Len(t, postSet, 2) - assert.Contains(t, postSet, pluginData.MattermostData[0]) - assert.Contains(t, postSet, pluginData.MattermostData[1]) + assert.Contains(t, postSet, pluginData.SentMessages[0]) + assert.Contains(t, postSet, pluginData.SentMessages[1]) sort.Sort(MattermostPostSlice(posts)) @@ -321,6 +328,7 @@ func (s *MattermostSuite) TestMattermostMessagePosting() { statusLine, err := parsePostField(post, "Status") require.NoError(t, err) assert.Equal(t, "⏳ PENDING", statusLine) + assert.Equal(t, types.PluginStatusCode_RUNNING, s.fakeStatusSink.Get().GetCode()) } func (s *MattermostSuite) TestApproval() { @@ -401,8 +409,8 @@ func (s *MattermostSuite) TestReviewComments() { s.startApp() req := s.createAccessRequest([]User{reviewer}) - s.checkPluginData(req.GetName(), func(data PluginData) bool { - return len(data.MattermostData) > 0 + s.checkPluginData(req.GetName(), func(data common.GenericPluginData) bool { + return len(data.SentMessages) > 0 }) post, err := s.fakeMattermost.CheckNewPost(s.Context()) @@ -487,6 +495,10 @@ func (s *MattermostSuite) TestApprovalByReview() { assert.Equal(t, post.ID, comment.RootID) assert.Contains(t, comment.Message, s.userNames.reviewer2+" reviewed the request", "comment must contain a review author") + // When posting a review, the bot also updates the message to add the amount of reviewers. + // This update is soon superseded by the "access allowed" update + _, _ = s.fakeMattermost.CheckPostUpdate(s.Context()) + postUpdate, err := s.fakeMattermost.CheckPostUpdate(s.Context()) require.NoError(t, err, "no messages updated") assert.Equal(t, post.ID, postUpdate.ID) @@ -542,6 +554,10 @@ func (s *MattermostSuite) TestDenialByReview() { }) require.NoError(t, err) + // When posting a review, the bot also updates the message to add the amount of reviewers. + // This update is soon superseded by the "access allowed" update + _, _ = s.fakeMattermost.CheckPostUpdate(s.Context()) + comment, err = s.fakeMattermost.CheckNewPost(s.Context()) require.NoError(t, err) assert.Equal(t, post.ChannelID, comment.ChannelID) @@ -577,8 +593,8 @@ func (s *MattermostSuite) TestExpiration() { directChannelID := s.fakeMattermost.GetDirectChannelFor(s.fakeMattermost.GetBotUser(), reviewer).ID assert.Equal(t, directChannelID, post.ChannelID) - s.checkPluginData(request.GetName(), func(data PluginData) bool { - return len(data.MattermostData) > 0 + s.checkPluginData(request.GetName(), func(data common.GenericPluginData) bool { + return len(data.SentMessages) > 0 }) err = s.ruler().DeleteAccessRequest(s.Context(), request.GetName()) // simulate expiration @@ -657,7 +673,7 @@ func (s *MattermostSuite) TestRace() { if post.RootID == "" { // Handle "root" notifications. - postKey := MattermostDataPost{ChannelID: post.ChannelID, PostID: post.ID} + postKey := common.MessageData{ChannelID: post.ChannelID, MessageID: post.ID} if _, loaded := postIDs.LoadOrStore(postKey, struct{}{}); loaded { return setRaceErr(trace.Errorf("post %v already stored", postKey)) } @@ -695,7 +711,7 @@ func (s *MattermostSuite) TestRace() { } else { // Handle review comments. - postKey := MattermostDataPost{ChannelID: post.ChannelID, PostID: post.RootID} + postKey := common.MessageData{ChannelID: post.ChannelID, MessageID: post.RootID} var newCounter int32 val, _ := reviewCommentCounters.LoadOrStore(postKey, &newCounter) counterPtr := val.(*int32) @@ -707,14 +723,14 @@ func (s *MattermostSuite) TestRace() { } // Multiplier TWO means that we handle updates for each of the two messages posted to reviewers. - for i := 0; i < 2*s.raceNumber; i++ { + for i := 0; i < 2*2*s.raceNumber; i++ { process.SpawnCritical(func(ctx context.Context) error { post, err := s.fakeMattermost.CheckPostUpdate(ctx) if err != nil { return setRaceErr(trace.Wrap(err)) } - postKey := MattermostDataPost{ChannelID: post.ChannelID, PostID: post.ID} + postKey := common.MessageData{ChannelID: post.ChannelID, MessageID: post.ID} var newCounter int32 val, _ := postUpdateCounters.LoadOrStore(postKey, &newCounter) counterPtr := val.(*int32) @@ -740,7 +756,8 @@ func (s *MattermostSuite) TestRace() { val, loaded = postUpdateCounters.LoadAndDelete(key) next = next && assert.True(t, loaded) counterPtr = val.(*int32) - next = next && assert.Equal(t, int32(1), *counterPtr) + // Each message should be updated 2 times + next = next && assert.Equal(t, int32(2), *counterPtr) return next }) diff --git a/integrations/access/mattermost/plugindata.go b/integrations/access/mattermost/plugindata.go deleted file mode 100644 index 2d592d551fcfd..0000000000000 --- a/integrations/access/mattermost/plugindata.go +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2023 Gravitational, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mattermost - -import ( - "fmt" - "strings" -) - -// PluginData is a data associated with access request that we store in Teleport using UpdatePluginData API. -type PluginData struct { - RequestData - MattermostData -} - -type Resolution struct { - Tag ResolutionTag - Reason string -} -type ResolutionTag string - -const Unresolved = ResolutionTag("") -const ResolvedApproved = ResolutionTag("APPROVED") -const ResolvedDenied = ResolutionTag("DENIED") -const ResolvedExpired = ResolutionTag("EXPIRED") - -type RequestData struct { - User string - Roles []string - RequestReason string - ReviewsCount int - Resolution Resolution -} - -type MattermostDataPost struct { - PostID string - ChannelID string -} - -type MattermostData = []MattermostDataPost - -// DecodePluginData deserializes a string map to PluginData struct. -func DecodePluginData(dataMap map[string]string) (data PluginData) { - data.User = dataMap["user"] - if str := dataMap["roles"]; str != "" { - data.Roles = strings.Split(str, ",") - } - data.RequestReason = dataMap["request_reason"] - if str := dataMap["reviews_count"]; str != "" { - fmt.Sscanf(str, "%d", &data.ReviewsCount) - } - data.Resolution.Tag = ResolutionTag(dataMap["resolution"]) - data.Resolution.Reason = dataMap["resolve_reason"] - if channelID, postID := dataMap["channel_id"], dataMap["postID"]; channelID != "" && postID != "" { - data.MattermostData = append(data.MattermostData, MattermostDataPost{ChannelID: channelID, PostID: postID}) - } - if str := dataMap["messages"]; str != "" { - for _, encodedMsg := range strings.Split(str, ",") { - if parts := strings.Split(encodedMsg, "/"); len(parts) == 2 { - data.MattermostData = append(data.MattermostData, MattermostDataPost{ChannelID: parts[0], PostID: parts[1]}) - } - } - } - return -} - -// EncodePluginData serializes a PluginData struct into a string map. -func EncodePluginData(data PluginData) map[string]string { - result := make(map[string]string) - - result["user"] = data.User - result["roles"] = strings.Join(data.Roles, ",") - result["request_reason"] = data.RequestReason - var reviewsCountStr string - if data.ReviewsCount > 0 { - reviewsCountStr = fmt.Sprintf("%d", data.ReviewsCount) - } - result["reviews_count"] = reviewsCountStr - result["resolution"] = string(data.Resolution.Tag) - result["resolve_reason"] = data.Resolution.Reason - var encodedMessages []string - for _, msg := range data.MattermostData { - encodedMessages = append(encodedMessages, fmt.Sprintf("%s/%s", msg.ChannelID, msg.PostID)) - } - result["messages"] = strings.Join(encodedMessages, ",") - - return result -} diff --git a/integrations/access/mattermost/plugindata_test.go b/integrations/access/mattermost/plugindata_test.go deleted file mode 100644 index 5d2661a0bdb1e..0000000000000 --- a/integrations/access/mattermost/plugindata_test.go +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2023 Gravitational, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mattermost - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -var samplePluginData = PluginData{ - RequestData: RequestData{ - User: "user-foo", - Roles: []string{"role-foo", "role-bar"}, - RequestReason: "foo reason", - ReviewsCount: 3, - Resolution: Resolution{Tag: ResolvedApproved, Reason: "foo ok"}, - }, - MattermostData: MattermostData{ - {ChannelID: "CHANNEL1", PostID: "POST01"}, - {ChannelID: "CHANNEL2", PostID: "POST02"}, - }, -} - -func TestEncodePluginData(t *testing.T) { - dataMap := EncodePluginData(samplePluginData) - assert.Len(t, dataMap, 7) - assert.Equal(t, "user-foo", dataMap["user"]) - assert.Equal(t, "role-foo,role-bar", dataMap["roles"]) - assert.Equal(t, "foo reason", dataMap["request_reason"]) - assert.Equal(t, "3", dataMap["reviews_count"]) - assert.Equal(t, "APPROVED", dataMap["resolution"]) - assert.Equal(t, "foo ok", dataMap["resolve_reason"]) - assert.Equal(t, "CHANNEL1/POST01,CHANNEL2/POST02", dataMap["messages"]) -} - -func TestDecodePluginData(t *testing.T) { - pluginData := DecodePluginData(map[string]string{ - "user": "user-foo", - "roles": "role-foo,role-bar", - "request_reason": "foo reason", - "reviews_count": "3", - "resolution": "APPROVED", - "resolve_reason": "foo ok", - "messages": "CHANNEL1/POST01,CHANNEL2/POST02", - }) - assert.Equal(t, samplePluginData, pluginData) -} - -func TestEncodeEmptyPluginData(t *testing.T) { - dataMap := EncodePluginData(PluginData{}) - assert.Len(t, dataMap, 7) - for key, value := range dataMap { - assert.Emptyf(t, value, "value at key %q must be empty", key) - } -} - -func TestDecodeEmptyPluginData(t *testing.T) { - assert.Empty(t, DecodePluginData(nil)) - assert.Empty(t, DecodePluginData(make(map[string]string))) -}