diff --git a/service/cache.go b/service/cache.go index 22a5db47..b968c001 100644 --- a/service/cache.go +++ b/service/cache.go @@ -30,7 +30,10 @@ them following an event that caused the server to lose track of recently seen authenticators.*/ // Cache for tickets received from clients keyed by fully qualified client name. Used to track replay of tickets. -type Cache map[string]clientEntries +type Cache struct { + Entries map[string]clientEntries + mux sync.RWMutex +} // clientEntries holds entries of client details sent to the service. type clientEntries struct { @@ -46,6 +49,24 @@ type replayCacheEntry struct { CTime time.Time // This combines the ticket's CTime and Cusec } +func (c *Cache) getClientEntries(cname types.PrincipalName) (clientEntries, bool) { + c.mux.RLock() + defer c.mux.RUnlock() + ce, ok := c.Entries[cname.GetPrincipalNameString()] + return ce, ok +} + +func (c *Cache) getClientEntry(cname types.PrincipalName, t time.Time) (replayCacheEntry, bool) { + if ce, ok := c.getClientEntries(cname); ok { + c.mux.RLock() + defer c.mux.RUnlock() + if e, ok := ce.ReplayMap[t]; ok { + return e, true + } + } + return replayCacheEntry{}, false +} + // Instance of the ServiceCache. This needs to be a singleton. var replayCache Cache var once sync.Once @@ -54,7 +75,9 @@ var once sync.Once func GetReplayCache(d time.Duration) *Cache { // Create a singleton of the ReplayCache and start a background thread to regularly clean out old entries once.Do(func() { - replayCache = make(Cache) + replayCache = Cache{ + Entries: make(map[string]clientEntries), + } go func() { for { // TODO consider using a context here. @@ -69,7 +92,9 @@ func GetReplayCache(d time.Duration) *Cache { // AddEntry adds an entry to the Cache. func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) { ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond) - if ce, ok := (*c)[a.CName.GetPrincipalNameString()]; ok { + if ce, ok := c.getClientEntries(a.CName); ok { + c.mux.Lock() + defer c.mux.Unlock() ce.ReplayMap[ct] = replayCacheEntry{ PresentedTime: time.Now().UTC(), SName: sname, @@ -78,7 +103,9 @@ func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) { ce.SeqNumber = a.SeqNumber ce.SubKey = a.SubKey } else { - (*c)[a.CName.GetPrincipalNameString()] = clientEntries{ + c.mux.Lock() + defer c.mux.Unlock() + c.Entries[a.CName.GetPrincipalNameString()] = clientEntries{ ReplayMap: map[time.Time]replayCacheEntry{ ct: { PresentedTime: time.Now().UTC(), @@ -94,26 +121,26 @@ func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) { // ClearOldEntries clears entries from the Cache that are older than the duration provided. func (c *Cache) ClearOldEntries(d time.Duration) { - for ck := range *c { - for ct, e := range (*c)[ck].ReplayMap { + c.mux.Lock() + defer c.mux.Unlock() + for ke, ce := range c.Entries { + for k, e := range ce.ReplayMap { if time.Now().UTC().Sub(e.PresentedTime) > d { - delete((*c)[ck].ReplayMap, ct) + delete(ce.ReplayMap, k) } } - if len((*c)[ck].ReplayMap) == 0 { - delete((*c), ck) + if len(ce.ReplayMap) == 0 { + delete(c.Entries, ke) } } } // IsReplay tests if the Authenticator provided is a replay within the duration defined. If this is not a replay add the entry to the cache for tracking. func (c *Cache) IsReplay(sname types.PrincipalName, a types.Authenticator) bool { - if ck, ok := (*c)[a.CName.GetPrincipalNameString()]; ok { - ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond) - if e, ok := ck.ReplayMap[ct]; ok { - if e.SName.Equal(sname) { - return true - } + ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond) + if e, ok := c.getClientEntry(a.CName, ct); ok { + if e.SName.Equal(sname) { + return true } } c.AddEntry(sname, a) diff --git a/service/http_test.go b/service/http_test.go index 292e7c99..4427aa01 100644 --- a/service/http_test.go +++ b/service/http_test.go @@ -111,14 +111,12 @@ func TestService_SPNEGOKRB_Replay(t *testing.T) { } assert.Equal(t, http.StatusOK, httpResp.StatusCode, "Status code in response to client SPNEGO request not as expected") - // A number of concurrent requests with the same ticket should be rejected due to replay - var wg sync.WaitGroup - noReq := 10 - wg.Add(noReq) - for i := 0; i < noReq; i++ { - go httpGetReplay(t, r1, &wg) + // Use ticket again should be rejected + httpResp, err = http.DefaultClient.Do(r1) + if err != nil { + t.Fatalf("Request error: %v\n", err) } - wg.Wait() + assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.") // Form a 2nd ticket st = time.Now().UTC() @@ -164,13 +162,82 @@ func TestService_SPNEGOKRB_Replay(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.") } -func httpGetReplay(t *testing.T, r *http.Request, wg *sync.WaitGroup) { - defer wg.Done() - httpResp, err := http.DefaultClient.Do(r) +func TestService_SPNEGOKRB_ReplayCache_Concurrency(t *testing.T) { + s := httpServer() + defer s.Close() + + cl := getClient() + sname := types.PrincipalName{ + NameType: nametype.KRB_NT_PRINCIPAL, + NameString: []string{"HTTP", "host.test.gokrb5"}, + } + b, _ := hex.DecodeString(testdata.HTTP_KEYTAB) + kt, _ := keytab.Parse(b) + st := time.Now().UTC() + tkt, sessionKey, err := messages.NewTicket(cl.Credentials.CName, cl.Credentials.Realm, + sname, "TEST.GOKRB5", + types.NewKrbFlags(), + kt, + 18, + 1, + st, + st, + st.Add(time.Duration(24)*time.Hour), + st.Add(time.Duration(48)*time.Hour), + ) if err != nil { - t.Fatalf("Request error: %v\n", err) + t.Fatalf("Error getting test ticket: %v", err) } - assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.") + + r1, _ := http.NewRequest("GET", s.URL, nil) + err = client.SetSPNEGOHeader(*cl.Credentials, tkt, sessionKey, r1) + if err != nil { + t.Fatalf("Error setting client SPNEGO header: %v", err) + } + + // Form a 2nd ticket + st = time.Now().UTC() + tkt2, sessionKey2, err := messages.NewTicket(cl.Credentials.CName, cl.Credentials.Realm, + sname, "TEST.GOKRB5", + types.NewKrbFlags(), + kt, + 18, + 1, + st, + st, + st.Add(time.Duration(24)*time.Hour), + st.Add(time.Duration(48)*time.Hour), + ) + if err != nil { + t.Fatalf("Error getting test ticket: %v", err) + } + r2, _ := http.NewRequest("GET", s.URL, nil) + err = client.SetSPNEGOHeader(*cl.Credentials, tkt2, sessionKey2, r2) + if err != nil { + t.Fatalf("Error setting client SPNEGO header: %v", err) + } + + // Concurrent 1st requests should be OK + var wg sync.WaitGroup + wg.Add(2) + go httpGet(r1, &wg) + go httpGet(r2, &wg) + wg.Wait() + + // A number of concurrent requests with the same ticket should be rejected due to replay + var wg2 sync.WaitGroup + noReq := 10 + wg2.Add(noReq * 2) + for i := 0; i < noReq; i++ { + go httpGet(r1, &wg2) + go httpGet(r2, &wg2) + } + wg2.Wait() +} + +func httpGet(r *http.Request, wg *sync.WaitGroup) { + defer wg.Done() + http.DefaultClient.Do(r) } func httpServer() *httptest.Server {