From 9401504fbd4b14072882e8a9ddfb0ecc61390c38 Mon Sep 17 00:00:00 2001 From: affan Date: Thu, 5 Sep 2024 15:55:41 -0400 Subject: [PATCH] Implement message server for affiliates --- protocol/x/affiliates/keeper/msg_server.go | 17 ++- .../x/affiliates/keeper/msg_server_test.go | 108 ++++++++++++++++++ 2 files changed, 123 insertions(+), 2 deletions(-) diff --git a/protocol/x/affiliates/keeper/msg_server.go b/protocol/x/affiliates/keeper/msg_server.go index bda65b27eb7..d2fc8363f6e 100644 --- a/protocol/x/affiliates/keeper/msg_server.go +++ b/protocol/x/affiliates/keeper/msg_server.go @@ -2,7 +2,9 @@ package keeper import ( "context" + "errors" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/dydxprotocol/v4-chain/protocol/x/affiliates/types" ) @@ -13,12 +15,23 @@ type msgServer struct { // RegisterAffiliate implements types.MsgServer. func (k msgServer) RegisterAffiliate(ctx context.Context, msg *types.MsgRegisterAffiliate) (*types.MsgRegisterAffiliateResponse, error) { - return nil, nil + sdkCtx := sdk.UnwrapSDKContext(ctx) + err := k.Keeper.RegisterAffiliate(sdkCtx, msg.Referee, msg.Affiliate) + if err != nil { + return nil, err + } + return &types.MsgRegisterAffiliateResponse{}, nil } func (k msgServer) UpdateAffiliateTiers(ctx context.Context, msg *types.MsgUpdateAffiliateTiers) (*types.MsgUpdateAffiliateTiersResponse, error) { - return nil, nil + sdkCtx := sdk.UnwrapSDKContext(ctx) + if !k.Keeper.HasAuthority(msg.Authority) { + return nil, errors.New("invalid authority") + } + k.Keeper.UpdateAffiliateTiers(sdkCtx, *msg.Tiers) + + return &types.MsgUpdateAffiliateTiersResponse{}, nil } // NewMsgServerImpl returns an implementation of the MsgServer interface diff --git a/protocol/x/affiliates/keeper/msg_server_test.go b/protocol/x/affiliates/keeper/msg_server_test.go index f4a1429abc2..4956e97ff42 100644 --- a/protocol/x/affiliates/keeper/msg_server_test.go +++ b/protocol/x/affiliates/keeper/msg_server_test.go @@ -4,7 +4,10 @@ import ( "context" "testing" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/dydxprotocol/v4-chain/protocol/lib" testapp "github.com/dydxprotocol/v4-chain/protocol/testutil/app" + constants "github.com/dydxprotocol/v4-chain/protocol/testutil/constants" "github.com/dydxprotocol/v4-chain/protocol/x/affiliates/keeper" "github.com/dydxprotocol/v4-chain/protocol/x/affiliates/types" "github.com/stretchr/testify/require" @@ -24,3 +27,108 @@ func TestMsgServer(t *testing.T) { require.NotNil(t, ms) require.NotNil(t, ctx) } + +func TestMsgServer_RegisterAffiliate(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgRegisterAffiliate + expectErr error + setup func(ctx sdk.Context, k keeper.Keeper) + }{ + { + name: "valid registration", + msg: &types.MsgRegisterAffiliate{ + Referee: constants.BobAccAddress.String(), + Affiliate: constants.AliceAccAddress.String(), + }, + expectErr: nil, + }, + { + name: "invalid referee address", + msg: &types.MsgRegisterAffiliate{ + Referee: "invalid_address", + Affiliate: constants.AliceAccAddress.String(), + }, + expectErr: types.ErrInvalidAddress, + }, + { + name: "invalid affiliate address", + msg: &types.MsgRegisterAffiliate{ + Referee: constants.BobAccAddress.String(), + Affiliate: "invalid_address", + }, + expectErr: types.ErrInvalidAddress, + }, + { + name: "referee already has an affiliate", + msg: &types.MsgRegisterAffiliate{ + Referee: constants.BobAccAddress.String(), + Affiliate: constants.AliceAccAddress.String(), + }, + expectErr: types.ErrAffiliateAlreadyExistsForReferee, + setup: func(ctx sdk.Context, k keeper.Keeper) { + err := k.RegisterAffiliate(ctx, constants.BobAccAddress.String(), constants.AliceAccAddress.String()) + require.NoError(t, err) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + if tc.setup != nil { + tc.setup(sdkCtx, k) + } + _, err := ms.RegisterAffiliate(ctx, tc.msg) + if tc.expectErr != nil { + require.ErrorIs(t, err, tc.expectErr) + } else { + require.NoError(t, err) + affiliate, found := k.GetReferredBy(sdkCtx, tc.msg.Referee) + require.True(t, found) + require.Equal(t, tc.msg.Affiliate, affiliate) + } + }) + } +} + +func TestMsgServer_UpdateAffiliateTiers(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgUpdateAffiliateTiers + expectErr bool + }{ + { + name: "Gov module updates tiers", + msg: &types.MsgUpdateAffiliateTiers{ + Authority: lib.GovModuleAddress.String(), + Tiers: &types.DefaultAffiliateTiers, + }, + }, + { + name: "non-gov module updates tiers", + msg: &types.MsgUpdateAffiliateTiers{ + Authority: constants.BobAccAddress.String(), + Tiers: &types.DefaultAffiliateTiers, + }, + expectErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + _, err := ms.UpdateAffiliateTiers(ctx, tc.msg) + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + tiers, err := k.GetAllAffiliateTiers(sdkCtx) + require.NoError(t, err) + require.Equal(t, tc.msg.Tiers, tiers) + } + }) + } +}