Skip to content
This repository has been archived by the owner on Nov 14, 2024. It is now read-only.

Query rooms with ACLs instead of all rooms #3338

Merged
merged 6 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions roomserver/acls/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ import (
const MRoomServerACL = "m.room.server_acl"

type ServerACLDatabase interface {
// GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error)
// RoomsWithACLs returns all room IDs for rooms with ACLs
RoomsWithACLs(ctx context.Context) ([]string, error)

// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
Expand All @@ -57,7 +57,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
}

// Look up all of the rooms that the current state server knows about.
rooms, err := db.GetKnownRooms(ctx)
rooms, err := db.RoomsWithACLs(ctx)
if err != nil {
logrus.WithError(err).Fatalf("Failed to get known rooms")
}
Expand Down
2 changes: 1 addition & 1 deletion roomserver/acls/acls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ var (

type dummyACLDB struct{}

func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) {
func (d dummyACLDB) RoomsWithACLs(ctx context.Context) ([]string, error) {
return []string{"1", "2"}, nil
}

Expand Down
3 changes: 3 additions & 0 deletions roomserver/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ type RoomserverInternalAPI interface {
req *QueryAuthChainRequest,
res *QueryAuthChainResponse,
) error

// RoomsWithACLs returns all room IDs for rooms with ACLs
RoomsWithACLs(ctx context.Context) ([]string, error)
}

type UserRoomPrivateKeyCreator interface {
Expand Down
5 changes: 5 additions & 0 deletions roomserver/internal/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,8 @@ func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID,

return nil, nil
}

// RoomsWithACLs returns all room IDs for rooms with ACLs
func (r *Queryer) RoomsWithACLs(ctx context.Context) ([]string, error) {
return r.DB.RoomsWithACLs(ctx)
}
35 changes: 35 additions & 0 deletions roomserver/roomserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1284,3 +1284,38 @@ func TestRoomConsumerRecreation(t *testing.T) {
wantAckWait := input.MaximumMissingProcessingTime + (time.Second * 10)
assert.Equal(t, wantAckWait, info.Config.AckWait)
}

func TestRoomsWithACLs(t *testing.T) {
ctx := context.Background()
alice := test.NewUser(t)
noACLRoom := test.NewRoom(t, alice)
aclRoom := test.NewRoom(t, alice)

aclRoom.CreateAndInsert(t, alice, "m.room.server_acl", map[string]any{
"deny": []string{"evilhost.test"},
"allow": []string{"*"},
}, test.WithStateKey(""))

test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType)
defer closeDB()

cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
natsInstance := &jetstream.NATSInstance{}
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
// start JetStream listeners
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)

for _, room := range []*test.Room{noACLRoom, aclRoom} {
// Create the rooms
err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false)
assert.NoError(t, err)
}

// Validate that we only have one ACLd room.
roomsWithACLs, err := rsAPI.RoomsWithACLs(ctx)
assert.NoError(t, err)
assert.Equal(t, []string{aclRoom.ID}, roomsWithACLs)
})
}
5 changes: 3 additions & 2 deletions roomserver/storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ type Database interface {
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
// GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error

Expand All @@ -193,6 +191,9 @@ type Database interface {
MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)

// RoomsWithACLs returns all room IDs for rooms with ACLs
RoomsWithACLs(ctx context.Context) ([]string, error)
}

type UserRoomKeys interface {
Expand Down
26 changes: 26 additions & 0 deletions roomserver/storage/postgres/events_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ const selectRoomNIDsForEventNIDsSQL = "" +
const selectEventRejectedSQL = "" +
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"

const selectRoomsWithACLsSQL = `select distinct room_nid from roomserver_events where event_type_nid = $1`
S7evinK marked this conversation as resolved.
Show resolved Hide resolved

type eventStatements struct {
insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt
Expand All @@ -166,6 +168,7 @@ type eventStatements struct {
selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDsForEventNIDsStmt *sql.Stmt
selectEventRejectedStmt *sql.Stmt
selectRoomsWithACLsStmt *sql.Stmt
}

func CreateEventsTable(db *sql.DB) error {
Expand Down Expand Up @@ -206,6 +209,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
{&s.selectEventRejectedStmt, selectEventRejectedSQL},
{&s.selectRoomsWithACLsStmt, selectRoomsWithACLsSQL},
}.Prepare(db)
}

Expand Down Expand Up @@ -582,3 +586,25 @@ func (s *eventStatements) SelectEventRejected(
err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected)
return
}

func (s *eventStatements) SelectRoomsWithEventTypeNID(
ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID,
) ([]types.RoomNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithACLsStmt)
rows, err := stmt.QueryContext(ctx, eventTypeNID)
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithEventTypeNID: rows.close() failed")
if err != nil {
return nil, err
}

var roomNIDs []types.RoomNID
var roomNID types.RoomNID
for rows.Next() {
if err := rows.Scan(&roomNID); err != nil {
return nil, err
}
roomNIDs = append(roomNIDs, roomNID)
}

return roomNIDs, rows.Err()
}
22 changes: 0 additions & 22 deletions roomserver/storage/postgres/rooms_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ const selectRoomVersionsForRoomNIDsSQL = "" +
const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"

const selectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms WHERE array_length(latest_event_nids, 1) > 0"

const bulkSelectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms WHERE room_nid = ANY($1)"

Expand All @@ -94,7 +91,6 @@ type roomStatements struct {
updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionsForRoomNIDsStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt
bulkSelectRoomIDsStmt *sql.Stmt
bulkSelectRoomNIDsStmt *sql.Stmt
}
Expand All @@ -116,29 +112,11 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionsForRoomNIDsStmt, selectRoomVersionsForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
{&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL},
}.Prepare(db)
}

func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
var roomID string
for rows.Next() {
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
}
return roomIDs, rows.Err()
}
func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
Expand Down
21 changes: 18 additions & 3 deletions roomserver/storage/shared/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -1625,9 +1625,24 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
}

// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
func (d *Database) RoomsWithACLs(ctx context.Context) ([]string, error) {

eventTypeNID, err := d.GetOrCreateEventTypeNID(ctx, "m.room.server_acl")
if err != nil {
return nil, err
}

roomNIDs, err := d.EventsTable.SelectRoomsWithEventTypeNID(ctx, nil, eventTypeNID)
if err != nil {
return nil, err
}

roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil {
return nil, err
}

return roomIDs, nil
}

// ForgetRoom sets a users room to forgotten
Expand Down
26 changes: 26 additions & 0 deletions roomserver/storage/sqlite3/events_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ const selectRoomNIDsForEventNIDsSQL = "" +
const selectEventRejectedSQL = "" +
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"

const selectRoomsWithACLsSQL = `select distinct room_nid from roomserver_events where event_type_nid = $1`

type eventStatements struct {
db *sql.DB
insertEventStmt *sql.Stmt
Expand All @@ -135,6 +137,7 @@ type eventStatements struct {
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
selectEventRejectedStmt *sql.Stmt
selectRoomsWithACLsStmt *sql.Stmt
//bulkSelectEventNIDStmt *sql.Stmt
//bulkSelectUnsentEventNIDStmt *sql.Stmt
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
Expand Down Expand Up @@ -192,6 +195,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
//{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
{&s.selectEventRejectedStmt, selectEventRejectedSQL},
{&s.selectRoomsWithACLsStmt, selectRoomsWithACLsSQL},
}.Prepare(db)
}

Expand Down Expand Up @@ -682,3 +686,25 @@ func (s *eventStatements) SelectEventRejected(
err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected)
return
}

func (s *eventStatements) SelectRoomsWithEventTypeNID(
ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID,
) ([]types.RoomNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithACLsStmt)
rows, err := stmt.QueryContext(ctx, eventTypeNID)
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithEventTypeNID: rows.close() failed")
if err != nil {
return nil, err
}

var roomNIDs []types.RoomNID
var roomNID types.RoomNID
for rows.Next() {
if err := rows.Scan(&roomNID); err != nil {
return nil, err
}
roomNIDs = append(roomNIDs, roomNID)
}

return roomNIDs, rows.Err()
}
23 changes: 0 additions & 23 deletions roomserver/storage/sqlite3/rooms_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ const selectRoomVersionsForRoomNIDsSQL = "" +
const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"

const selectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms WHERE latest_event_nids != '[]'"

const bulkSelectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"

Expand All @@ -87,7 +84,6 @@ type roomStatements struct {
updateLatestEventNIDsStmt *sql.Stmt
//selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt
}

func CreateRoomsTable(db *sql.DB) error {
Expand All @@ -108,29 +104,10 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
{&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL},
}.Prepare(db)
}

func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
var roomID string
for rows.Next() {
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
}
return roomIDs, rows.Err()
}

func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDsJSON string
Expand Down
36 changes: 36 additions & 0 deletions roomserver/storage/tables/events_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tables_test

import (
"context"
"fmt"
"testing"

"github.com/matrix-org/dendrite/internal/sqlutil"
Expand Down Expand Up @@ -147,3 +148,38 @@ func Test_EventsTable(t *testing.T) {
assert.Equal(t, int64(len(room.Events())+1), maxDepth)
})
}

func TestRoomsWithACL(t *testing.T) {

test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
eventStateKeys, closeEventStateKeys := mustCreateEventTypesTable(t, dbType)
defer closeEventStateKeys()

eventsTable, closeEventsTable := mustCreateEventsTable(t, dbType)
defer closeEventsTable()

ctx := context.Background()

// insert the m.room.server_acl event type
eventTypeNID, err := eventStateKeys.InsertEventTypeNID(ctx, nil, "m.room.server_acl")
assert.Nil(t, err)

// Create ACL'd rooms
var wantRoomNIDs []types.RoomNID
for i := 0; i < 10; i++ {
_, _, err = eventsTable.InsertEvent(ctx, nil, types.RoomNID(i), eventTypeNID, types.EmptyStateKeyNID, fmt.Sprintf("$1337+%d", i), nil, 0, false)
assert.Nil(t, err)
wantRoomNIDs = append(wantRoomNIDs, types.RoomNID(i))
}

// Create non-ACL'd rooms (eventTypeNID+1)
for i := 10; i < 20; i++ {
_, _, err = eventsTable.InsertEvent(ctx, nil, types.RoomNID(i), eventTypeNID+1, types.EmptyStateKeyNID, fmt.Sprintf("$1337+%d", i), nil, 0, false)
assert.Nil(t, err)
}

gotRoomNIDs, err := eventsTable.SelectRoomsWithEventTypeNID(ctx, nil, eventTypeNID)
assert.Nil(t, err)
assert.Equal(t, wantRoomNIDs, gotRoomNIDs)
})
}
3 changes: 2 additions & 1 deletion roomserver/storage/tables/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ type Events interface {
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error)

SelectRoomsWithEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID) ([]types.RoomNID, error)
}

type Rooms interface {
Expand All @@ -80,7 +82,6 @@ type Rooms interface {
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
}
Expand Down
Loading
Loading