Skip to content

Commit

Permalink
service side replay cache thread safety
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmturner committed Oct 21, 2017
1 parent 0f10b62 commit 21f62e2
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 27 deletions.
57 changes: 42 additions & 15 deletions service/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ them following an event that caused the server to lose track of
recently seen authenticators.*/

// Cache for tickets received from clients keyed by fully qualified client name. Used to track replay of tickets.
type Cache map[string]clientEntries
type Cache struct {
Entries map[string]clientEntries
mux sync.RWMutex
}

// clientEntries holds entries of client details sent to the service.
type clientEntries struct {
Expand All @@ -46,6 +49,24 @@ type replayCacheEntry struct {
CTime time.Time // This combines the ticket's CTime and Cusec
}

func (c *Cache) getClientEntries(cname types.PrincipalName) (clientEntries, bool) {
c.mux.RLock()
defer c.mux.RUnlock()
ce, ok := c.Entries[cname.GetPrincipalNameString()]
return ce, ok
}

func (c *Cache) getClientEntry(cname types.PrincipalName, t time.Time) (replayCacheEntry, bool) {
if ce, ok := c.getClientEntries(cname); ok {
c.mux.RLock()
defer c.mux.RUnlock()
if e, ok := ce.ReplayMap[t]; ok {
return e, true
}
}
return replayCacheEntry{}, false
}

// Instance of the ServiceCache. This needs to be a singleton.
var replayCache Cache
var once sync.Once
Expand All @@ -54,7 +75,9 @@ var once sync.Once
func GetReplayCache(d time.Duration) *Cache {
// Create a singleton of the ReplayCache and start a background thread to regularly clean out old entries
once.Do(func() {
replayCache = make(Cache)
replayCache = Cache{
Entries: make(map[string]clientEntries),
}
go func() {
for {
// TODO consider using a context here.
Expand All @@ -69,7 +92,9 @@ func GetReplayCache(d time.Duration) *Cache {
// AddEntry adds an entry to the Cache.
func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) {
ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond)
if ce, ok := (*c)[a.CName.GetPrincipalNameString()]; ok {
if ce, ok := c.getClientEntries(a.CName); ok {
c.mux.Lock()
defer c.mux.Unlock()
ce.ReplayMap[ct] = replayCacheEntry{
PresentedTime: time.Now().UTC(),
SName: sname,
Expand All @@ -78,7 +103,9 @@ func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) {
ce.SeqNumber = a.SeqNumber
ce.SubKey = a.SubKey
} else {
(*c)[a.CName.GetPrincipalNameString()] = clientEntries{
c.mux.Lock()
defer c.mux.Unlock()
c.Entries[a.CName.GetPrincipalNameString()] = clientEntries{
ReplayMap: map[time.Time]replayCacheEntry{
ct: {
PresentedTime: time.Now().UTC(),
Expand All @@ -94,26 +121,26 @@ func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) {

// ClearOldEntries clears entries from the Cache that are older than the duration provided.
func (c *Cache) ClearOldEntries(d time.Duration) {
for ck := range *c {
for ct, e := range (*c)[ck].ReplayMap {
c.mux.Lock()
defer c.mux.Unlock()
for ke, ce := range c.Entries {
for k, e := range ce.ReplayMap {
if time.Now().UTC().Sub(e.PresentedTime) > d {
delete((*c)[ck].ReplayMap, ct)
delete(ce.ReplayMap, k)
}
}
if len((*c)[ck].ReplayMap) == 0 {
delete((*c), ck)
if len(ce.ReplayMap) == 0 {
delete(c.Entries, ke)
}
}
}

// IsReplay tests if the Authenticator provided is a replay within the duration defined. If this is not a replay add the entry to the cache for tracking.
func (c *Cache) IsReplay(sname types.PrincipalName, a types.Authenticator) bool {
if ck, ok := (*c)[a.CName.GetPrincipalNameString()]; ok {
ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond)
if e, ok := ck.ReplayMap[ct]; ok {
if e.SName.Equal(sname) {
return true
}
ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond)
if e, ok := c.getClientEntry(a.CName, ct); ok {
if e.SName.Equal(sname) {
return true
}
}
c.AddEntry(sname, a)
Expand Down
91 changes: 79 additions & 12 deletions service/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,12 @@ func TestService_SPNEGOKRB_Replay(t *testing.T) {
}
assert.Equal(t, http.StatusOK, httpResp.StatusCode, "Status code in response to client SPNEGO request not as expected")

// A number of concurrent requests with the same ticket should be rejected due to replay
var wg sync.WaitGroup
noReq := 10
wg.Add(noReq)
for i := 0; i < noReq; i++ {
go httpGetReplay(t, r1, &wg)
// Use ticket again should be rejected
httpResp, err = http.DefaultClient.Do(r1)
if err != nil {
t.Fatalf("Request error: %v\n", err)
}
wg.Wait()
assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.")

// Form a 2nd ticket
st = time.Now().UTC()
Expand Down Expand Up @@ -164,13 +162,82 @@ func TestService_SPNEGOKRB_Replay(t *testing.T) {
assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.")
}

func httpGetReplay(t *testing.T, r *http.Request, wg *sync.WaitGroup) {
defer wg.Done()
httpResp, err := http.DefaultClient.Do(r)
func TestService_SPNEGOKRB_ReplayCache_Concurrency(t *testing.T) {
s := httpServer()
defer s.Close()

cl := getClient()
sname := types.PrincipalName{
NameType: nametype.KRB_NT_PRINCIPAL,
NameString: []string{"HTTP", "host.test.gokrb5"},
}
b, _ := hex.DecodeString(testdata.HTTP_KEYTAB)
kt, _ := keytab.Parse(b)
st := time.Now().UTC()
tkt, sessionKey, err := messages.NewTicket(cl.Credentials.CName, cl.Credentials.Realm,
sname, "TEST.GOKRB5",
types.NewKrbFlags(),
kt,
18,
1,
st,
st,
st.Add(time.Duration(24)*time.Hour),
st.Add(time.Duration(48)*time.Hour),
)
if err != nil {
t.Fatalf("Request error: %v\n", err)
t.Fatalf("Error getting test ticket: %v", err)
}
assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.")

r1, _ := http.NewRequest("GET", s.URL, nil)
err = client.SetSPNEGOHeader(*cl.Credentials, tkt, sessionKey, r1)
if err != nil {
t.Fatalf("Error setting client SPNEGO header: %v", err)
}

// Form a 2nd ticket
st = time.Now().UTC()
tkt2, sessionKey2, err := messages.NewTicket(cl.Credentials.CName, cl.Credentials.Realm,
sname, "TEST.GOKRB5",
types.NewKrbFlags(),
kt,
18,
1,
st,
st,
st.Add(time.Duration(24)*time.Hour),
st.Add(time.Duration(48)*time.Hour),
)
if err != nil {
t.Fatalf("Error getting test ticket: %v", err)
}
r2, _ := http.NewRequest("GET", s.URL, nil)
err = client.SetSPNEGOHeader(*cl.Credentials, tkt2, sessionKey2, r2)
if err != nil {
t.Fatalf("Error setting client SPNEGO header: %v", err)
}

// Concurrent 1st requests should be OK
var wg sync.WaitGroup
wg.Add(2)
go httpGet(r1, &wg)
go httpGet(r2, &wg)
wg.Wait()

// A number of concurrent requests with the same ticket should be rejected due to replay
var wg2 sync.WaitGroup
noReq := 10
wg2.Add(noReq * 2)
for i := 0; i < noReq; i++ {
go httpGet(r1, &wg2)
go httpGet(r2, &wg2)
}
wg2.Wait()
}

func httpGet(r *http.Request, wg *sync.WaitGroup) {
defer wg.Done()
http.DefaultClient.Do(r)
}

func httpServer() *httptest.Server {
Expand Down

0 comments on commit 21f62e2

Please sign in to comment.