Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
shrimalmadhur committed Sep 27, 2024
1 parent 2e5c79a commit 500b8dc
Showing 1 changed file with 62 additions and 37 deletions.
99 changes: 62 additions & 37 deletions pkg/rewards/show.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rewards

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -29,10 +30,18 @@ import (

type ClaimType string

type ELReader interface {
GetCumulativeClaimed(opts *bind.CallOpts, earnerAddress, tokenAddress gethcommon.Address) (*big.Int, error)
}

const (
All ClaimType = "all"
Unclaimed ClaimType = "unclaimed"
Claimed ClaimType = "claimed"

AllRewards = "All Rewards"
UnclaimedRewards = "Unclaimed Rewards"
ClaimedRewards = "Claimed Rewards"
)

func ShowCmd(p utils.Prompter) *cli.Command {
Expand Down Expand Up @@ -120,56 +129,72 @@ func ShowRewards(cCtx *cli.Context) error {
}

allRewards := make(map[gethcommon.Address]*big.Int)
msg := AllRewards
for pair := tokenAddressesMap.Oldest(); pair != nil; pair = pair.Next() {
amt, _ := new(big.Int).SetString(pair.Value.String(), 10)
allRewards[pair.Key] = amt
}

if config.ClaimType == All {
fmt.Println(strings.Repeat("-", 30), "All Rewards", strings.Repeat("-", 30))
err := handleRewardsOutput(config.Output, config.OutputType, allRewards)
if config.ClaimType != All {
claimedRewards, err := getClaimedRewards(ctx, elReader, config.EarnerAddress, allRewards)
if err != nil {
return err
return eigenSdkUtils.WrapError("failed to get claimed rewards", err)
}
} else {
claimedRewards := make(map[gethcommon.Address]*big.Int)
for address, _ := range allRewards {
claimed, err := elReader.GetCumulativeClaimed(&bind.CallOpts{Context: ctx}, config.EarnerAddress, address)
if err != nil {
return err
}
if claimed == nil {
claimed = big.NewInt(0)
}
claimedRewards[address] = claimed
switch config.ClaimType {
case Claimed:
allRewards = claimedRewards
msg = ClaimedRewards
case Unclaimed:
allRewards = calculateUnclaimedRewards(allRewards, claimedRewards)
msg = UnclaimedRewards
}
if config.ClaimType == Claimed {
fmt.Println(strings.Repeat("-", 30), "Claimed Rewards", strings.Repeat("-", 30))
err := handleRewardsOutput(config.Output, config.OutputType, claimedRewards)
if err != nil {
return err
}
} else if config.ClaimType == Unclaimed {
unclaimedRewards := make(map[gethcommon.Address]*big.Int)
for address, _ := range allRewards {
total := allRewards[address]
claimed := claimedRewards[address]
unclaimedRewards[address] = new(big.Int).Sub(total, claimed)
}
fmt.Println(strings.Repeat("-", 30), "Unclaimed Rewards", strings.Repeat("-", 30))
err := handleRewardsOutput(config.Output, config.OutputType, unclaimedRewards)
if err != nil {
return err
}
} else {
return fmt.Errorf("claim type %s not supported", config.ClaimType)
}
err = handleRewardsOutput(config.Output, config.OutputType, allRewards, msg)
if err != nil {
return err
}
return nil
}

func getClaimedRewards(
ctx context.Context,
elReader ELReader,
earnerAddress gethcommon.Address,
allRewards map[gethcommon.Address]*big.Int,
) (map[gethcommon.Address]*big.Int, error) {
claimedRewards := make(map[gethcommon.Address]*big.Int)
for address := range allRewards {
claimed, err := elReader.GetCumulativeClaimed(&bind.CallOpts{Context: ctx}, earnerAddress, address)
if err != nil {
return nil, err
}
if claimed == nil {
claimed = big.NewInt(0)
}
claimedRewards[address] = claimed
}
return claimedRewards, nil
}

return nil
func calculateUnclaimedRewards(
allRewards,
claimedRewards map[gethcommon.Address]*big.Int,
) map[gethcommon.Address]*big.Int {
unclaimedRewards := make(map[gethcommon.Address]*big.Int)
for address, total := range allRewards {
claimed := claimedRewards[address]
unclaimedRewards[address] = new(big.Int).Sub(total, claimed)
}
return unclaimedRewards
}

func handleRewardsOutput(outputFile string, outputType string, rewards map[gethcommon.Address]*big.Int) error {
func handleRewardsOutput(
outputFile string,
outputType string,
rewards map[gethcommon.Address]*big.Int,
msg string,
) error {
fmt.Println(strings.Repeat("-", 30), msg, strings.Repeat("-", 30))
if outputType == "json" {
allRewards := make(allRewardsJson, 0)
for address, amount := range rewards {
Expand Down

0 comments on commit 500b8dc

Please sign in to comment.