Skip to content

Commit

Permalink
Add idempotency checks for OTK uploads (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
kegsay authored Jan 9, 2024
1 parent 61e494b commit eeb88a5
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions tests/csapi/upload_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,103 @@ func TestUploadKey(t *testing.T) {
})
})
}

// Tests idempotency of the /keys/upload endpoint.
// Tests that if you upload 4 OTKs then upload the same 4, no error is returned.
func TestUploadKeyIdempotency(t *testing.T) {
deployment := complement.Deploy(t, 1)
defer deployment.Destroy(t)
alice := deployment.Register(t, "hs1", helpers.RegistrationOpts{})
deviceKeys, oneTimeKeys := alice.MustGenerateOneTimeKeys(t, 4)
requests := []client.RequestOpt{
client.WithJSONBody(t, map[string]interface{}{
"device_keys": deviceKeys,
"one_time_keys": oneTimeKeys,
}),
client.WithJSONBody(t, map[string]interface{}{
"one_time_keys": oneTimeKeys,
}),
client.WithJSONBody(t, map[string]interface{}{
"one_time_keys": oneTimeKeys,
}),
}
for _, reqBody := range requests {
resp := alice.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "upload"}, reqBody)
must.MatchResponse(t, resp, match.HTTPResponse{
StatusCode: http.StatusOK,
JSON: []match.JSON{
match.JSONMapEach("one_time_key_counts", func(k, v gjson.Result) error {
keyCount := 0
for key := range oneTimeKeys {
// check that the returned algorithms -> key count matches those we uploaded
if strings.HasPrefix(key, k.Str) {
keyCount++
}
}
if int(v.Float()) != keyCount {
return fmt.Errorf("expected %d one time keys, got %d", keyCount, int(v.Float()))
}
return nil
}),
},
})
}
}

// Tests idempotency of the /keys/upload endpoint.
// Tests that if you upload OTKs A,B,C then upload OTKs B,C,D, no error is returned and the OTK count says 4 (A,B,C,D).
func TestUploadKeyIdempotencyOverlap(t *testing.T) {
deployment := complement.Deploy(t, 1)
defer deployment.Destroy(t)
alice := deployment.Register(t, "hs1", helpers.RegistrationOpts{})
deviceKeys, oneTimeKeys := alice.MustGenerateOneTimeKeys(t, 4)
i := 0
keysABC := map[string]interface{}{}
keysBCD := map[string]interface{}{}
for keyID, otk := range oneTimeKeys {
i++
if i == 1 {
keysABC[keyID] = otk
continue
}
if i == 4 {
keysBCD[keyID] = otk
continue
}
keysABC[keyID] = otk
keysBCD[keyID] = otk
}
t.Logf("OTKs ABC %v", keysABC)
t.Logf("OTKs BCD %v", keysBCD)
requests := []client.RequestOpt{
client.WithJSONBody(t, map[string]interface{}{
"device_keys": deviceKeys,
}),
client.WithJSONBody(t, map[string]interface{}{
"one_time_keys": keysABC,
}),
client.WithJSONBody(t, map[string]interface{}{
"one_time_keys": keysBCD,
}),
}
for i, reqBody := range requests {
expectedOTKCount := 0
if i == 1 {
expectedOTKCount = 3
} else if i == 2 {
expectedOTKCount = 4
}
resp := alice.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "upload"}, reqBody)
must.MatchResponse(t, resp, match.HTTPResponse{
StatusCode: http.StatusOK,
JSON: []match.JSON{
match.JSONMapEach("one_time_key_counts", func(k, v gjson.Result) error {
if int(v.Float()) != expectedOTKCount {
return fmt.Errorf("expected %d one time keys, got %d", expectedOTKCount, int(v.Float()))
}
return nil
}),
},
})
}
}

0 comments on commit eeb88a5

Please sign in to comment.