diff --git a/protocol/x/affiliates/keeper/msg_server.go b/protocol/x/affiliates/keeper/msg_server.go index 72d736c826..00780ed7fb 100644 --- a/protocol/x/affiliates/keeper/msg_server.go +++ b/protocol/x/affiliates/keeper/msg_server.go @@ -2,6 +2,7 @@ package keeper import ( "context" + "errors" errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" @@ -13,15 +14,27 @@ type msgServer struct { } // RegisterAffiliate implements types.MsgServer. +// This is only valid if a referee signs the message. +// For example, if Alice is the referee and Bob is the affiliate, +// then only Alice can register Bob as an affiliate. Any +// other signer that sends this message will be rejected. func (k msgServer) RegisterAffiliate(ctx context.Context, msg *types.MsgRegisterAffiliate) (*types.MsgRegisterAffiliateResponse, error) { - return nil, nil + sdkCtx := sdk.UnwrapSDKContext(ctx) + err := k.Keeper.RegisterAffiliate(sdkCtx, msg.Referee, msg.Affiliate) + if err != nil { + return nil, err + } + return &types.MsgRegisterAffiliateResponse{}, nil } func (k msgServer) UpdateAffiliateTiers(ctx context.Context, msg *types.MsgUpdateAffiliateTiers) (*types.MsgUpdateAffiliateTiersResponse, error) { - sdkCtx := sdk.UnwrapSDKContext(ctx) + if !k.Keeper.HasAuthority(msg.Authority) { + return nil, errors.New("invalid authority") + } + sdkCtx := sdk.UnwrapSDKContext(ctx) unconditionalRevShareConfig, err := k.revShareKeeper.GetUnconditionalRevShareConfigParams(sdkCtx) if err != nil { return nil, err @@ -35,7 +48,9 @@ func (k msgServer) UpdateAffiliateTiers(ctx context.Context, ) } - return nil, nil + k.Keeper.UpdateAffiliateTiers(sdkCtx, msg.Tiers) + + return &types.MsgUpdateAffiliateTiersResponse{}, nil } // NewMsgServerImpl returns an implementation of the MsgServer interface diff --git a/protocol/x/affiliates/keeper/msg_server_test.go b/protocol/x/affiliates/keeper/msg_server_test.go index f4a1429abc..404b554391 100644 --- a/protocol/x/affiliates/keeper/msg_server_test.go +++ b/protocol/x/affiliates/keeper/msg_server_test.go @@ -4,7 +4,10 @@ import ( "context" "testing" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/dydxprotocol/v4-chain/protocol/lib" testapp "github.com/dydxprotocol/v4-chain/protocol/testutil/app" + constants "github.com/dydxprotocol/v4-chain/protocol/testutil/constants" "github.com/dydxprotocol/v4-chain/protocol/x/affiliates/keeper" "github.com/dydxprotocol/v4-chain/protocol/x/affiliates/types" "github.com/stretchr/testify/require" @@ -24,3 +27,108 @@ func TestMsgServer(t *testing.T) { require.NotNil(t, ms) require.NotNil(t, ctx) } + +func TestMsgServer_RegisterAffiliate(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgRegisterAffiliate + expectErr error + setup func(ctx sdk.Context, k keeper.Keeper) + }{ + { + name: "valid registration", + msg: &types.MsgRegisterAffiliate{ + Referee: constants.BobAccAddress.String(), + Affiliate: constants.AliceAccAddress.String(), + }, + expectErr: nil, + }, + { + name: "invalid referee address", + msg: &types.MsgRegisterAffiliate{ + Referee: "invalid_address", + Affiliate: constants.AliceAccAddress.String(), + }, + expectErr: types.ErrInvalidAddress, + }, + { + name: "invalid affiliate address", + msg: &types.MsgRegisterAffiliate{ + Referee: constants.BobAccAddress.String(), + Affiliate: "invalid_address", + }, + expectErr: types.ErrInvalidAddress, + }, + { + name: "referee already has an affiliate", + msg: &types.MsgRegisterAffiliate{ + Referee: constants.BobAccAddress.String(), + Affiliate: constants.AliceAccAddress.String(), + }, + expectErr: types.ErrAffiliateAlreadyExistsForReferee, + setup: func(ctx sdk.Context, k keeper.Keeper) { + err := k.RegisterAffiliate(ctx, constants.BobAccAddress.String(), constants.AliceAccAddress.String()) + require.NoError(t, err) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + if tc.setup != nil { + tc.setup(sdkCtx, k) + } + _, err := ms.RegisterAffiliate(ctx, tc.msg) + if tc.expectErr != nil { + require.ErrorIs(t, err, tc.expectErr) + } else { + require.NoError(t, err) + affiliate, found := k.GetReferredBy(sdkCtx, tc.msg.Referee) + require.True(t, found) + require.Equal(t, tc.msg.Affiliate, affiliate) + } + }) + } +} + +func TestMsgServer_UpdateAffiliateTiers(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgUpdateAffiliateTiers + expectErr bool + }{ + { + name: "Gov module updates tiers", + msg: &types.MsgUpdateAffiliateTiers{ + Authority: lib.GovModuleAddress.String(), + Tiers: types.DefaultAffiliateTiers, + }, + }, + { + name: "non-gov module updates tiers", + msg: &types.MsgUpdateAffiliateTiers{ + Authority: constants.BobAccAddress.String(), + Tiers: types.DefaultAffiliateTiers, + }, + expectErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + _, err := ms.UpdateAffiliateTiers(ctx, tc.msg) + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + tiers, err := k.GetAllAffiliateTiers(sdkCtx) + require.NoError(t, err) + require.Equal(t, tc.msg.Tiers, tiers) + } + }) + } +}