Skip to content

Commit

Permalink
Merge pull request #437 from matrix-org/kegan/device-data-table
Browse files Browse the repository at this point in the history
Refactor device data
  • Loading branch information
kegsay authored May 20, 2024
2 parents 693587e + 3fc49bd commit 1551ccd
Show file tree
Hide file tree
Showing 14 changed files with 832 additions and 347 deletions.
102 changes: 19 additions & 83 deletions internal/device_data.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package internal

import (
"sync"
)

const (
bitOTKCount int = iota
bitFallbackKeyTypes
Expand All @@ -18,105 +14,45 @@ func isBitSet(n int, bit int) bool {
return val > 0
}

// DeviceData contains useful data for this user's device. This list can be expanded without prompting
// schema changes. These values are upserted into the database and persisted forever.
// DeviceData contains useful data for this user's device.
type DeviceData struct {
DeviceListChanges
DeviceKeyData
UserID string
DeviceID string
}

// This is calculated from device_lists table
type DeviceListChanges struct {
DeviceListChanged []string
DeviceListLeft []string
}

// This gets serialised as CBOR in device_data table
type DeviceKeyData struct {
// Contains the latest device_one_time_keys_count values.
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
OTKCounts MapStringInt `json:"otk"`
// Contains the latest device_unused_fallback_key_types value
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
FallbackKeyTypes []string `json:"fallback"`

DeviceLists DeviceLists `json:"dl"`

// bitset for which device data changes are present. They accumulate until they get swapped over
// when they get reset
ChangedBits int `json:"c"`

UserID string
DeviceID string
}

func (dd *DeviceData) SetOTKCountChanged() {
func (dd *DeviceKeyData) SetOTKCountChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
}

func (dd *DeviceData) SetFallbackKeysChanged() {
func (dd *DeviceKeyData) SetFallbackKeysChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes)
}

func (dd *DeviceData) OTKCountChanged() bool {
func (dd *DeviceKeyData) OTKCountChanged() bool {
return isBitSet(dd.ChangedBits, bitOTKCount)
}
func (dd *DeviceData) FallbackKeysChanged() bool {
func (dd *DeviceKeyData) FallbackKeysChanged() bool {
return isBitSet(dd.ChangedBits, bitFallbackKeyTypes)
}

type UserDeviceKey struct {
UserID string
DeviceID string
}

type DeviceDataMap struct {
deviceDataMu *sync.Mutex
deviceDataMap map[UserDeviceKey]*DeviceData
Pos int64
}

func NewDeviceDataMap(startPos int64, devices []DeviceData) *DeviceDataMap {
ddm := &DeviceDataMap{
deviceDataMu: &sync.Mutex{},
deviceDataMap: make(map[UserDeviceKey]*DeviceData),
Pos: startPos,
}
for i, dd := range devices {
ddm.deviceDataMap[UserDeviceKey{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}] = &devices[i]
}
return ddm
}

func (d *DeviceDataMap) Get(userID, deviceID string) *DeviceData {
key := UserDeviceKey{
UserID: userID,
DeviceID: deviceID,
}
d.deviceDataMu.Lock()
defer d.deviceDataMu.Unlock()
dd, ok := d.deviceDataMap[key]
if !ok {
return nil
}
return dd
}

func (d *DeviceDataMap) Update(dd DeviceData) DeviceData {
key := UserDeviceKey{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}
d.deviceDataMu.Lock()
defer d.deviceDataMu.Unlock()
existing, ok := d.deviceDataMap[key]
if !ok {
existing = &DeviceData{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}
}
if dd.OTKCounts != nil {
existing.OTKCounts = dd.OTKCounts
}
if dd.FallbackKeyTypes != nil {
existing.FallbackKeyTypes = dd.FallbackKeyTypes
}
existing.DeviceLists = existing.DeviceLists.Combine(dd.DeviceLists)

d.deviceDataMap[key] = existing

return *existing
}
88 changes: 44 additions & 44 deletions state/device_data_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ type DeviceDataRow struct {
ID int64 `db:"id"`
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
// This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't
// This will contain internal.DeviceKeyData serialised as JSON. It's stored in a single column as we don't
// need to perform searches on this data.
Data []byte `db:"data"`
KeyData []byte `db:"data"`
}

type DeviceDataTable struct {
db *sqlx.DB
db *sqlx.DB
deviceListTable *DeviceListTable
}

func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
Expand All @@ -37,14 +38,16 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
ALTER TABLE syncv3_device_data SET (fillfactor = 90);
`)
return &DeviceDataTable{
db: db,
db: db,
deviceListTable: NewDeviceListTable(db),
}
}

// Atomically select the device data for this user|device and then swap DeviceLists around if set.
// This should only be called by the v3 HTTP APIs when servicing an E2EE extension request.
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// grab otk counts and fallback key types
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil {
Expand All @@ -54,32 +57,38 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
}
return err
}
result = &internal.DeviceData{}
var keyData *internal.DeviceKeyData
// unmarshal to swap
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
return err
}
decMode, err := opts.DecMode()
result.UserID = userID
result.DeviceID = deviceID
if keyData != nil {
result.DeviceKeyData = *keyData
}

deviceListChanges, err := t.deviceListTable.SelectTx(txn, userID, deviceID, swap)
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &result); err != nil {
return err
for targetUserID, targetState := range deviceListChanges {
switch targetState {
case internal.DeviceListChanged:
result.DeviceListChanged = append(result.DeviceListChanged, targetUserID)
case internal.DeviceListLeft:
result.DeviceListLeft = append(result.DeviceListLeft, targetUserID)
}
}
result.UserID = userID
result.DeviceID = deviceID
if !swap {
return nil // don't swap
}
// the caller will only look at sent, so make sure what is new is now in sent
result.DeviceLists.Sent = result.DeviceLists.New

// swap over the fields
writeBack := *result
writeBack.DeviceLists.Sent = result.DeviceLists.New
writeBack.DeviceLists.New = make(map[string]int)
writeBack := *keyData
writeBack.ChangedBits = 0

if reflect.DeepEqual(result, &writeBack) {
if reflect.DeepEqual(keyData, &writeBack) {
// The update to the DB would be a no-op; don't bother with it.
// This helps reduce write usage and the contention on the unique index for
// the device_data table.
Expand All @@ -97,52 +106,43 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
return
}

func (t *DeviceDataTable) DeleteDevice(userID, deviceID string) error {
_, err := t.db.Exec(`DELETE FROM syncv3_device_data WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
return err
}

// Upsert combines what is in the database for this user|device with the partial entry `dd`
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
func (t *DeviceDataTable) Upsert(userID, deviceID string, keys internal.DeviceKeyData, deviceListChanges map[string]int) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// Update device lists
if err = t.deviceListTable.UpsertTx(txn, userID, deviceID, deviceListChanges); err != nil {
return err
}
// select what already exists
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil && err != sql.ErrNoRows {
return err
}
// unmarshal and combine
var tempDD internal.DeviceData
if len(row.Data) > 0 {
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
}
decMode, err := opts.DecMode()
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &tempDD); err != nil {
var keyData internal.DeviceKeyData
if len(row.KeyData) > 0 {
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
return err
}
}
if dd.FallbackKeyTypes != nil {
tempDD.FallbackKeyTypes = dd.FallbackKeyTypes
tempDD.SetFallbackKeysChanged()
if keys.FallbackKeyTypes != nil {
keyData.FallbackKeyTypes = keys.FallbackKeyTypes
keyData.SetFallbackKeysChanged()
}
if dd.OTKCounts != nil {
tempDD.OTKCounts = dd.OTKCounts
tempDD.SetOTKCountChanged()
if keys.OTKCounts != nil {
keyData.OTKCounts = keys.OTKCounts
keyData.SetOTKCountChanged()
}
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)

data, err := cbor.Marshal(tempDD)
data, err := cbor.Marshal(keyData)
if err != nil {
return err
}
_, err = txn.Exec(
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
dd.UserID, dd.DeviceID, data,
userID, deviceID, data,
)
return err
})
Expand Down
Loading

0 comments on commit 1551ccd

Please sign in to comment.