Skip to content

Commit

Permalink
[CT-950] safety heap methods (dydxprotocol#1821)
Browse files Browse the repository at this point in the history
  • Loading branch information
jayy04 authored Jul 3, 2024
1 parent 4035bae commit cf745f4
Show file tree
Hide file tree
Showing 2 changed files with 375 additions and 0 deletions.
177 changes: 177 additions & 0 deletions protocol/x/subaccounts/keeper/safety_heap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package keeper

import (
"cosmossdk.io/store/prefix"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
)

// RemoveSubaccountFromSafetyHeap removes a subaccount from the safety heap
// given a peretual and side.
func (k Keeper) RemoveSubaccountFromSafetyHeap(
ctx sdk.Context,
subaccountId types.SubaccountId,
perpetualId uint32,
side types.SafetyHeapPositionSide,
) {
store := k.GetSafetyHeapStore(ctx, perpetualId, side)
index := k.MustGetSubaccountHeapIndex(store, subaccountId)
k.MustRemoveElementAtIndex(ctx, store, index)
}

// AddSubaccountToSafetyHeap adds a subaccount to the safety heap
// given a perpetual and side.
func (k Keeper) AddSubaccountToSafetyHeap(
ctx sdk.Context,
subaccountId types.SubaccountId,
perpetualId uint32,
side types.SafetyHeapPositionSide,
) {
store := k.GetSafetyHeapStore(ctx, perpetualId, side)
k.Insert(ctx, store, subaccountId)
}

// Heap methods

// Insert inserts a subaccount into the safety heap.
func (k Keeper) Insert(
ctx sdk.Context,
store prefix.Store,
subaccountId types.SubaccountId,
) {
// Add the subaccount to the end of the heap.
length := k.GetSafetyHeapLength(store)
k.SetSubaccountAtIndex(store, subaccountId, length)

// Increment the size of the heap.
k.SetSafetyHeapLength(store, length+1)

// Heapify up the element at the end of the heap
// to restore the heap property.
k.HeapifyUp(ctx, store, length)
}

// MustRemoveElementAtIndex removes the element at the given index
// from the safety heap.
func (k Keeper) MustRemoveElementAtIndex(
ctx sdk.Context,
store prefix.Store,
index uint32,
) {
length := k.GetSafetyHeapLength(store)
if index >= length {
panic(types.ErrSafetyHeapSubaccountNotFoundAtIndex)
}

// Swap the element with the last element.
k.Swap(store, index, length-1)

// Remove the last element.
k.DeleteSubaccountAtIndex(store, length-1)
k.SetSafetyHeapLength(store, length-1)

// Heapify down and up the element at the given index
// to restore the heap property.
if index < length-1 {
k.HeapifyDown(ctx, store, index)
k.HeapifyUp(ctx, store, index)
}
}

// HeapifyUp moves the element at the given index up the heap
// until the heap property is restored.
func (k Keeper) HeapifyUp(
ctx sdk.Context,
store prefix.Store,
index uint32,
) {
if index == 0 {
return
}

parentIndex := (index - 1) / 2
if k.Less(ctx, store, index, parentIndex) {
k.Swap(store, index, parentIndex)
k.HeapifyUp(ctx, store, parentIndex)
}
}

// HeapifyDown moves the element at the given index down the heap
// until the heap property is restored.
func (k Keeper) HeapifyDown(
ctx sdk.Context,
store prefix.Store,
index uint32,
) {
leftIndex, rightIndex := 2*index+1, 2*index+2

length := k.GetSafetyHeapLength(store)
if rightIndex < length && k.Less(ctx, store, rightIndex, leftIndex) {
// Compare the current node with the right child
// if right child exists and is less than the left child.
if k.Less(ctx, store, rightIndex, index) {
k.Swap(store, index, rightIndex)
k.HeapifyDown(ctx, store, rightIndex)
}
} else if leftIndex < length {
// Compare the current node with the left child
// if left child exists.
if k.Less(ctx, store, leftIndex, index) {
k.Swap(store, index, leftIndex)
k.HeapifyDown(ctx, store, leftIndex)
}
}
}

// Swap swaps the elements at the given indices.
func (k Keeper) Swap(
store prefix.Store,
index1 uint32,
index2 uint32,
) {
// No-op case
if index1 == index2 {
return
}

first := k.MustGetSubaccountAtIndex(store, index1)
second := k.MustGetSubaccountAtIndex(store, index2)
k.SetSubaccountAtIndex(store, first, index2)
k.SetSubaccountAtIndex(store, second, index1)
}

// Less returns true if the element at the first index is less than
// the element at the second index.
func (k Keeper) Less(
ctx sdk.Context,
store prefix.Store,
first uint32,
second uint32,
) bool {
firstSubaccountId := k.MustGetSubaccountAtIndex(store, first)
secondSubaccountId := k.MustGetSubaccountAtIndex(store, second)

firstRisk, err := k.GetNetCollateralAndMarginRequirements(
ctx,
types.Update{
SubaccountId: firstSubaccountId,
},
)
if err != nil {
panic(err)
}

secondRisk, err := k.GetNetCollateralAndMarginRequirements(
ctx,
types.Update{
SubaccountId: secondSubaccountId,
},
)
if err != nil {
panic(err)
}

// Compare the risks of the two subaccounts and sort
// them in descending order.
return firstRisk.Cmp(secondRisk) > 0
}
198 changes: 198 additions & 0 deletions protocol/x/subaccounts/keeper/safety_heap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package keeper_test

import (
"math/big"
"math/rand"
"testing"

"cosmossdk.io/store/prefix"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/app/config"
"github.com/dydxprotocol/v4-chain/protocol/testutil/constants"
keepertest "github.com/dydxprotocol/v4-chain/protocol/testutil/keeper"
testutil "github.com/dydxprotocol/v4-chain/protocol/testutil/util"
"github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/keeper"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
"github.com/stretchr/testify/require"
"gopkg.in/typ.v4/slices"
)

func TestSafetyHeapInsertRemoveMin(t *testing.T) {
perpetualId := uint32(0)
side := satypes.Long
totalSubaccounts := 1000

// Create 1000 subaccounts with balances ranging from -500 to 500.
// The subaccounts should be sorted by balance.
allSubaccounts := make([]satypes.Subaccount, 0)
for i := 0; i < totalSubaccounts; i++ {
subaccount := satypes.Subaccount{
Id: &satypes.SubaccountId{
Owner: sdk.MustBech32ifyAddressBytes(
config.Bech32PrefixAccAddr,
constants.AliceAccAddress,
),
Number: uint32(i),
},
AssetPositions: testutil.CreateUsdcAssetPositions(
// Create asset positions with balances ranging from -500 to 500.
big.NewInt(int64(i - totalSubaccounts/2)),
),
}

// Handle special case.
if i-totalSubaccounts/2 == 0 {
subaccount.AssetPositions = nil
}

allSubaccounts = append(allSubaccounts, subaccount)
}

for iter := 0; iter < 100; iter++ {
// Setup keeper state and test parameters.
ctx, subaccountsKeeper, _, _, _, _, _, _, _, _ := keepertest.SubaccountsKeepers(t, false)

// Shuffle the subaccounts so that insertion order is random.
slices.Shuffle(allSubaccounts)

store := subaccountsKeeper.GetSafetyHeapStore(ctx, perpetualId, side)
for i, subaccount := range allSubaccounts {
subaccountsKeeper.SetSubaccount(ctx, subaccount)
subaccountsKeeper.AddSubaccountToSafetyHeap(
ctx,
*subaccount.Id,
perpetualId,
side,
)

require.Equal(
t,
uint32(i+1),
subaccountsKeeper.GetSafetyHeapLength(store),
)
}

// Make sure subaccounts are sorted correctly.
for i := 0; i < totalSubaccounts; i++ {
// Get the subaccount with the lowest safety score.
// In this case, the subaccount with the lowest USDC balance.
subaccountId := subaccountsKeeper.MustGetSubaccountAtIndex(store, uint32(0))
subaccount := subaccountsKeeper.GetSubaccount(ctx, subaccountId)

// Subaccounts should be sorted by asset position balance.
require.Equal(t, uint32(i), subaccountId.Number)
require.Equal(
t,
big.NewInt(int64(i-totalSubaccounts/2)),
subaccount.GetUsdcPosition(),
)

// Remove the subaccount from the heap.
subaccountsKeeper.RemoveSubaccountFromSafetyHeap(
ctx,
subaccountId,
perpetualId,
side,
)
require.Equal(
t,
uint32(totalSubaccounts-i-1),
subaccountsKeeper.GetSafetyHeapLength(store),
)
}
}
}

func TestSafetyHeapInsertRemoveIndex(t *testing.T) {
perpetualId := uint32(0)
side := satypes.Long
totalSubaccounts := 100

// Create 1000 subaccounts with balances ranging from -500 to 500.
// The subaccounts should be sorted by balance.
allSubaccounts := make([]satypes.Subaccount, 0)
for i := 0; i < totalSubaccounts; i++ {
subaccount := satypes.Subaccount{
Id: &satypes.SubaccountId{
Owner: sdk.MustBech32ifyAddressBytes(
config.Bech32PrefixAccAddr,
constants.AliceAccAddress,
),
Number: uint32(i),
},
AssetPositions: testutil.CreateUsdcAssetPositions(
// Create asset positions with balances ranging from -500 to 500.
big.NewInt(int64(i - totalSubaccounts/2)),
),
}

// Handle special case.
if i-totalSubaccounts/2 == 0 {
subaccount.AssetPositions = nil
}

allSubaccounts = append(allSubaccounts, subaccount)
}

for iter := 0; iter < 100; iter++ {
// Setup keeper state and test parameters.
ctx, subaccountsKeeper, _, _, _, _, _, _, _, _ := keepertest.SubaccountsKeepers(t, false)

// Shuffle the subaccounts so that insertion order is random.
slices.Shuffle(allSubaccounts)

store := subaccountsKeeper.GetSafetyHeapStore(ctx, perpetualId, side)
for i, subaccount := range allSubaccounts {
subaccountsKeeper.SetSubaccount(ctx, subaccount)
subaccountsKeeper.AddSubaccountToSafetyHeap(
ctx,
*subaccount.Id,
perpetualId,
side,
)

require.Equal(
t,
uint32(i+1),
subaccountsKeeper.GetSafetyHeapLength(store),
)
}

for i := totalSubaccounts; i > 0; i-- {
// Remove a random subaccount from the heap.
index := rand.Intn(i)

subaccountId := subaccountsKeeper.MustGetSubaccountAtIndex(store, uint32(index))
subaccountsKeeper.RemoveSubaccountFromSafetyHeap(
ctx,
subaccountId,
perpetualId,
side,
)

require.Equal(
t,
uint32(i-1),
subaccountsKeeper.GetSafetyHeapLength(store),
)

// Verify that the heap property is restored.
verifyHeapProperties(t, subaccountsKeeper, ctx, store, 0)
}
}
}

func verifyHeapProperties(t *testing.T, k *keeper.Keeper, ctx sdk.Context, store prefix.Store, index uint32) {
length := k.GetSafetyHeapLength(store)
leftChild, rightChild := 2*index+1, 2*index+2

if leftChild < length {
require.True(t, k.Less(ctx, store, index, leftChild))
verifyHeapProperties(t, k, ctx, store, leftChild)
}

if rightChild < length {
require.True(t, k.Less(ctx, store, index, rightChild))
verifyHeapProperties(t, k, ctx, store, rightChild)
}
}

0 comments on commit cf745f4

Please sign in to comment.