Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions integrations/access/mattermost/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,16 @@

package mattermost

import (
"context"
"sync/atomic"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/common"
)

type MattermostPostSlice []Post
type MattermostDataPostSet map[MattermostDataPost]struct{}
type MattermostDataPostSet map[common.MessageData]struct{}

func (slice MattermostPostSlice) Len() int {
return len(slice)
Expand All @@ -34,11 +42,28 @@ func (slice MattermostPostSlice) Swap(i, j int) {
slice[i], slice[j] = slice[j], slice[i]
}

func (set MattermostDataPostSet) Add(msg MattermostDataPost) {
func (set MattermostDataPostSet) Add(msg common.MessageData) {
set[msg] = struct{}{}
}

func (set MattermostDataPostSet) Contains(msg MattermostDataPost) bool {
func (set MattermostDataPostSet) Contains(msg common.MessageData) bool {
_, ok := set[msg]
return ok
}

type fakeStatusSink struct {
status atomic.Pointer[types.PluginStatus]
}

func (s *fakeStatusSink) Emit(_ context.Context, status types.PluginStatus) error {
s.status.Store(&status)
return nil
}

func (s *fakeStatusSink) Get() types.PluginStatus {
status := s.status.Load()
if status == nil {
panic("expected status to be set, but it has not been")
}
return *status
}
51 changes: 34 additions & 17 deletions integrations/access/mattermost/mattermost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/lib"
"github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/teleport/integrations/lib/testing/integration"
Expand All @@ -57,6 +58,7 @@ type MattermostSuite struct {
}
raceNumber int
fakeMattermost *FakeMattermost
fakeStatusSink *fakeStatusSink
mmUser User

clients map[string]*integration.Client
Expand Down Expand Up @@ -197,10 +199,13 @@ func (s *MattermostSuite) SetupTest() {
Email: s.userNames.requestor,
})

s.fakeStatusSink = &fakeStatusSink{}

var conf Config
conf.Teleport = s.teleportConfig
conf.Mattermost.Token = "000000"
conf.Mattermost.URL = s.fakeMattermost.URL()
conf.StatusSink = s.fakeStatusSink

s.appConfig = &conf
s.SetContextTimeout(5 * time.Second)
Expand Down Expand Up @@ -256,14 +261,16 @@ func (s *MattermostSuite) createAccessRequest(reviewers []User) types.AccessRequ
return req
}

func (s *MattermostSuite) checkPluginData(reqID string, cond func(PluginData) bool) PluginData {
func (s *MattermostSuite) checkPluginData(reqID string, cond func(common.GenericPluginData) bool) common.GenericPluginData {
t := s.T()
t.Helper()

for {
rawData, err := s.ruler().PollAccessRequestPluginData(s.Context(), "mattermost", reqID)
require.NoError(t, err)
if data := DecodePluginData(rawData); cond(data) {
data, err := common.DecodePluginData(rawData)
require.NoError(t, err)
if cond(data) {
return data
}
}
Expand All @@ -280,23 +287,23 @@ func (s *MattermostSuite) TestMattermostMessagePosting() {
s.startApp()
request := s.createAccessRequest([]User{reviewer2, reviewer1})

pluginData := s.checkPluginData(request.GetName(), func(data PluginData) bool {
return len(data.MattermostData) > 0
pluginData := s.checkPluginData(request.GetName(), func(data common.GenericPluginData) bool {
return len(data.SentMessages) > 0
})
assert.Len(t, pluginData.MattermostData, 2)
assert.Len(t, pluginData.SentMessages, 2)

var posts []Post
postSet := make(MattermostDataPostSet)
for i := 0; i < 2; i++ {
post, err := s.fakeMattermost.CheckNewPost(s.Context())
require.NoError(t, err, "no new messages posted")
postSet.Add(MattermostDataPost{ChannelID: post.ChannelID, PostID: post.ID})
postSet.Add(common.MessageData{ChannelID: post.ChannelID, MessageID: post.ID})
posts = append(posts, post)
}

assert.Len(t, postSet, 2)
assert.Contains(t, postSet, pluginData.MattermostData[0])
assert.Contains(t, postSet, pluginData.MattermostData[1])
assert.Contains(t, postSet, pluginData.SentMessages[0])
assert.Contains(t, postSet, pluginData.SentMessages[1])

sort.Sort(MattermostPostSlice(posts))

Expand All @@ -321,6 +328,7 @@ func (s *MattermostSuite) TestMattermostMessagePosting() {
statusLine, err := parsePostField(post, "Status")
require.NoError(t, err)
assert.Equal(t, "⏳ PENDING", statusLine)
assert.Equal(t, types.PluginStatusCode_RUNNING, s.fakeStatusSink.Get().GetCode())
}

func (s *MattermostSuite) TestApproval() {
Expand Down Expand Up @@ -401,8 +409,8 @@ func (s *MattermostSuite) TestReviewComments() {
s.startApp()

req := s.createAccessRequest([]User{reviewer})
s.checkPluginData(req.GetName(), func(data PluginData) bool {
return len(data.MattermostData) > 0
s.checkPluginData(req.GetName(), func(data common.GenericPluginData) bool {
return len(data.SentMessages) > 0
})

post, err := s.fakeMattermost.CheckNewPost(s.Context())
Expand Down Expand Up @@ -487,6 +495,10 @@ func (s *MattermostSuite) TestApprovalByReview() {
assert.Equal(t, post.ID, comment.RootID)
assert.Contains(t, comment.Message, s.userNames.reviewer2+" reviewed the request", "comment must contain a review author")

// When posting a review, the bot also updates the message to add the amount of reviewers.
// This update is soon superseded by the "access allowed" update
_, _ = s.fakeMattermost.CheckPostUpdate(s.Context())

postUpdate, err := s.fakeMattermost.CheckPostUpdate(s.Context())
require.NoError(t, err, "no messages updated")
assert.Equal(t, post.ID, postUpdate.ID)
Expand Down Expand Up @@ -542,6 +554,10 @@ func (s *MattermostSuite) TestDenialByReview() {
})
require.NoError(t, err)

// When posting a review, the bot also updates the message to add the amount of reviewers.
// This update is soon superseded by the "access allowed" update
_, _ = s.fakeMattermost.CheckPostUpdate(s.Context())

comment, err = s.fakeMattermost.CheckNewPost(s.Context())
require.NoError(t, err)
assert.Equal(t, post.ChannelID, comment.ChannelID)
Expand Down Expand Up @@ -577,8 +593,8 @@ func (s *MattermostSuite) TestExpiration() {
directChannelID := s.fakeMattermost.GetDirectChannelFor(s.fakeMattermost.GetBotUser(), reviewer).ID
assert.Equal(t, directChannelID, post.ChannelID)

s.checkPluginData(request.GetName(), func(data PluginData) bool {
return len(data.MattermostData) > 0
s.checkPluginData(request.GetName(), func(data common.GenericPluginData) bool {
return len(data.SentMessages) > 0
})

err = s.ruler().DeleteAccessRequest(s.Context(), request.GetName()) // simulate expiration
Expand Down Expand Up @@ -657,7 +673,7 @@ func (s *MattermostSuite) TestRace() {
if post.RootID == "" {
// Handle "root" notifications.

postKey := MattermostDataPost{ChannelID: post.ChannelID, PostID: post.ID}
postKey := common.MessageData{ChannelID: post.ChannelID, MessageID: post.ID}
if _, loaded := postIDs.LoadOrStore(postKey, struct{}{}); loaded {
return setRaceErr(trace.Errorf("post %v already stored", postKey))
}
Expand Down Expand Up @@ -695,7 +711,7 @@ func (s *MattermostSuite) TestRace() {
} else {
// Handle review comments.

postKey := MattermostDataPost{ChannelID: post.ChannelID, PostID: post.RootID}
postKey := common.MessageData{ChannelID: post.ChannelID, MessageID: post.RootID}
var newCounter int32
val, _ := reviewCommentCounters.LoadOrStore(postKey, &newCounter)
counterPtr := val.(*int32)
Expand All @@ -707,14 +723,14 @@ func (s *MattermostSuite) TestRace() {
}

// Multiplier TWO means that we handle updates for each of the two messages posted to reviewers.
for i := 0; i < 2*s.raceNumber; i++ {
for i := 0; i < 2*2*s.raceNumber; i++ {
process.SpawnCritical(func(ctx context.Context) error {
post, err := s.fakeMattermost.CheckPostUpdate(ctx)
if err != nil {
return setRaceErr(trace.Wrap(err))
}

postKey := MattermostDataPost{ChannelID: post.ChannelID, PostID: post.ID}
postKey := common.MessageData{ChannelID: post.ChannelID, MessageID: post.ID}
var newCounter int32
val, _ := postUpdateCounters.LoadOrStore(postKey, &newCounter)
counterPtr := val.(*int32)
Expand All @@ -740,7 +756,8 @@ func (s *MattermostSuite) TestRace() {
val, loaded = postUpdateCounters.LoadAndDelete(key)
next = next && assert.True(t, loaded)
counterPtr = val.(*int32)
next = next && assert.Equal(t, int32(1), *counterPtr)
// Each message should be updated 2 times
next = next && assert.Equal(t, int32(2), *counterPtr)

return next
})
Expand Down
102 changes: 0 additions & 102 deletions integrations/access/mattermost/plugindata.go

This file was deleted.

Loading