Skip to content

Commit

Permalink
Implement message server for affiliates
Browse files Browse the repository at this point in the history
  • Loading branch information
affanv14 committed Sep 5, 2024
1 parent 3faadf9 commit 9401504
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 2 deletions.
17 changes: 15 additions & 2 deletions protocol/x/affiliates/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package keeper

import (
"context"
"errors"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/x/affiliates/types"
)

Expand All @@ -13,12 +15,23 @@ type msgServer struct {
// RegisterAffiliate implements types.MsgServer.
func (k msgServer) RegisterAffiliate(ctx context.Context,
msg *types.MsgRegisterAffiliate) (*types.MsgRegisterAffiliateResponse, error) {
return nil, nil
sdkCtx := sdk.UnwrapSDKContext(ctx)
err := k.Keeper.RegisterAffiliate(sdkCtx, msg.Referee, msg.Affiliate)
if err != nil {
return nil, err
}
return &types.MsgRegisterAffiliateResponse{}, nil
}

func (k msgServer) UpdateAffiliateTiers(ctx context.Context,
msg *types.MsgUpdateAffiliateTiers) (*types.MsgUpdateAffiliateTiersResponse, error) {
return nil, nil
sdkCtx := sdk.UnwrapSDKContext(ctx)
if !k.Keeper.HasAuthority(msg.Authority) {
return nil, errors.New("invalid authority")
}
k.Keeper.UpdateAffiliateTiers(sdkCtx, *msg.Tiers)

return &types.MsgUpdateAffiliateTiersResponse{}, nil
}

// NewMsgServerImpl returns an implementation of the MsgServer interface
Expand Down
108 changes: 108 additions & 0 deletions protocol/x/affiliates/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import (
"context"
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/lib"
testapp "github.com/dydxprotocol/v4-chain/protocol/testutil/app"
constants "github.com/dydxprotocol/v4-chain/protocol/testutil/constants"
"github.com/dydxprotocol/v4-chain/protocol/x/affiliates/keeper"
"github.com/dydxprotocol/v4-chain/protocol/x/affiliates/types"
"github.com/stretchr/testify/require"
Expand All @@ -24,3 +27,108 @@ func TestMsgServer(t *testing.T) {
require.NotNil(t, ms)
require.NotNil(t, ctx)
}

func TestMsgServer_RegisterAffiliate(t *testing.T) {
testCases := []struct {
name string
msg *types.MsgRegisterAffiliate
expectErr error
setup func(ctx sdk.Context, k keeper.Keeper)
}{
{
name: "valid registration",
msg: &types.MsgRegisterAffiliate{
Referee: constants.BobAccAddress.String(),
Affiliate: constants.AliceAccAddress.String(),
},
expectErr: nil,
},
{
name: "invalid referee address",
msg: &types.MsgRegisterAffiliate{
Referee: "invalid_address",
Affiliate: constants.AliceAccAddress.String(),
},
expectErr: types.ErrInvalidAddress,
},
{
name: "invalid affiliate address",
msg: &types.MsgRegisterAffiliate{
Referee: constants.BobAccAddress.String(),
Affiliate: "invalid_address",
},
expectErr: types.ErrInvalidAddress,
},
{
name: "referee already has an affiliate",
msg: &types.MsgRegisterAffiliate{
Referee: constants.BobAccAddress.String(),
Affiliate: constants.AliceAccAddress.String(),
},
expectErr: types.ErrAffiliateAlreadyExistsForReferee,
setup: func(ctx sdk.Context, k keeper.Keeper) {
err := k.RegisterAffiliate(ctx, constants.BobAccAddress.String(), constants.AliceAccAddress.String())
require.NoError(t, err)
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
k, ms, ctx := setupMsgServer(t)
sdkCtx := sdk.UnwrapSDKContext(ctx)
if tc.setup != nil {
tc.setup(sdkCtx, k)
}
_, err := ms.RegisterAffiliate(ctx, tc.msg)
if tc.expectErr != nil {
require.ErrorIs(t, err, tc.expectErr)
} else {
require.NoError(t, err)
affiliate, found := k.GetReferredBy(sdkCtx, tc.msg.Referee)
require.True(t, found)
require.Equal(t, tc.msg.Affiliate, affiliate)
}
})
}
}

func TestMsgServer_UpdateAffiliateTiers(t *testing.T) {
testCases := []struct {
name string
msg *types.MsgUpdateAffiliateTiers
expectErr bool
}{
{
name: "Gov module updates tiers",
msg: &types.MsgUpdateAffiliateTiers{
Authority: lib.GovModuleAddress.String(),
Tiers: &types.DefaultAffiliateTiers,
},
},
{
name: "non-gov module updates tiers",
msg: &types.MsgUpdateAffiliateTiers{
Authority: constants.BobAccAddress.String(),
Tiers: &types.DefaultAffiliateTiers,
},
expectErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
k, ms, ctx := setupMsgServer(t)
sdkCtx := sdk.UnwrapSDKContext(ctx)
_, err := ms.UpdateAffiliateTiers(ctx, tc.msg)
if tc.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
tiers, err := k.GetAllAffiliateTiers(sdkCtx)
require.NoError(t, err)
require.Equal(t, tc.msg.Tiers, tiers)
}
})
}
}

0 comments on commit 9401504

Please sign in to comment.