Skip to content

Commit

Permalink
Stop mocking global func() variables
Browse files Browse the repository at this point in the history
Doing that may cause erratic test failures when we run them in parallel, so
move the functions the tests need to mock as struct fields that are not
shared across tests.
  • Loading branch information
gsalgado committed May 14, 2015
1 parent 97e84fe commit fe0f609
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 110 deletions.
17 changes: 0 additions & 17 deletions votingpool/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"testing"

"github.com/btcsuite/btclog"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcwallet/waddrmgr"
)

Expand Down Expand Up @@ -99,19 +98,3 @@ func TstCheckWithdrawalStatusMatches(t *testing.T, s1, s2 WithdrawalStatus) {
t.Fatalf("Wrong WithdrawalStatus; got %v, want %v", s1, s2)
}
}

// replaceCalculateTxFee replaces the calculateTxFee func with the given one
// and returns a function that restores it to the original one.
func replaceCalculateTxFee(f func(*withdrawalTx) btcutil.Amount) func() {
orig := calculateTxFee
calculateTxFee = f
return func() { calculateTxFee = orig }
}

// replaceCalculateTxSize replaces the calculateTxSize func with the given one
// and returns a function that restores it to the original one.
func replaceCalculateTxSize(f func(*withdrawalTx) int) func() {
orig := calculateTxSize
calculateTxSize = f
return func() { calculateTxSize = orig }
}
6 changes: 3 additions & 3 deletions votingpool/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func getUniqueID() uint32 {
// createWithdrawalTx creates a withdrawalTx with the given input and output amounts.
func createWithdrawalTx(t *testing.T, pool *Pool, inputAmounts []int64, outputAmounts []int64) *withdrawalTx {
net := pool.Manager().ChainParams()
tx := newWithdrawalTx()
tx := newWithdrawalTx(defaultTxOptions)
_, credits := TstCreateCreditsOnNewSeries(t, pool, inputAmounts)
for _, c := range credits {
tx.addInput(c)
Expand Down Expand Up @@ -418,8 +418,8 @@ func TstNewChangeAddress(t *testing.T, p *Pool, seriesID uint32, idx Index) (add
return addr
}

func TstConstantFee(fee btcutil.Amount) func(tx *withdrawalTx) btcutil.Amount {
return func(tx *withdrawalTx) btcutil.Amount { return fee }
func TstConstantFee(fee btcutil.Amount) func() btcutil.Amount {
return func() btcutil.Amount { return fee }
}

func createAndFulfillWithdrawalRequests(t *testing.T, pool *Pool, roundID uint32) withdrawalInfo {
Expand Down
56 changes: 36 additions & 20 deletions votingpool/withdrawal.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ type withdrawal struct {
pendingRequests []OutputRequest
eligibleInputs []credit
current *withdrawalTx
// txOptions is a function called for every new withdrawalTx created as
// part of this withdrawal. It is defined as a function field because it
// exists mainly so that tests can mock withdrawalTx fields.
txOptions func(tx *withdrawalTx)
}

// withdrawalTxOut wraps an OutputRequest and provides a separate amount field.
Expand Down Expand Up @@ -301,10 +305,26 @@ type withdrawalTx struct {

// changeOutput holds information about the change for this transaction.
changeOutput *wire.TxOut

// calculateSize returns the estimated serialized size (in bytes) of this
// tx. See calculateTxSize() for details on how that's done. We use a
// struct field instead of a method so that it can be replaced in tests.
calculateSize func() int
// calculateFee calculates the expected network fees for this tx. We use a
// struct field instead of a method so that it can be replaced in tests.
calculateFee func() btcutil.Amount
}

func newWithdrawalTx() *withdrawalTx {
return &withdrawalTx{}
// newWithdrawalTx creates a new withdrawalTx and calls setOptions()
// passing the newly created tx.
func newWithdrawalTx(setOptions func(tx *withdrawalTx)) *withdrawalTx {
tx := &withdrawalTx{}
tx.calculateSize = func() int { return calculateTxSize(tx) }
tx.calculateFee = func() btcutil.Amount {
return btcutil.Amount(1+tx.calculateSize()/1000) * feeIncrement
}
setOptions(tx)
return tx
}

// ntxid returns the unique ID for this transaction.
Expand All @@ -323,7 +343,7 @@ func (tx *withdrawalTx) isTooBig() bool {
// In bitcoind a tx is considered standard only if smaller than
// MAX_STANDARD_TX_SIZE; that's why we consider anything >= txMaxSize to
// be too big.
return calculateTxSize(tx) >= txMaxSize
return tx.calculateSize() >= txMaxSize
}

// inputTotal returns the sum amount of all inputs in this tx.
Expand Down Expand Up @@ -401,7 +421,7 @@ func (tx *withdrawalTx) removeInput() credit {
// added after it's called. Also, callsites must make sure adding a change
// output won't cause the tx to exceed the size limit.
func (tx *withdrawalTx) addChange(pkScript []byte) bool {
tx.fee = calculateTxFee(tx)
tx.fee = tx.calculateFee()
change := tx.inputTotal() - tx.outputTotal() - tx.fee
log.Debugf("addChange: input total %v, output total %v, fee %v", tx.inputTotal(),
tx.outputTotal(), tx.fee)
Expand Down Expand Up @@ -430,7 +450,7 @@ func (tx *withdrawalTx) rollBackLastOutput() ([]credit, *withdrawalTxOut, error)

var removedInputs []credit
// Continue until sum(in) < sum(out) + fee
for tx.inputTotal() >= tx.outputTotal()+calculateTxFee(tx) {
for tx.inputTotal() >= tx.outputTotal()+tx.calculateFee() {
removedInputs = append(removedInputs, tx.removeInput())
}

Expand All @@ -440,6 +460,8 @@ func (tx *withdrawalTx) rollBackLastOutput() ([]credit, *withdrawalTxOut, error)
return removedInputs, removedOutput, nil
}

func defaultTxOptions(tx *withdrawalTx) {}

func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit,
changeStart ChangeAddress) *withdrawal {
outputs := make(map[OutBailmentID]*WithdrawalOutput, len(requests))
Expand All @@ -452,10 +474,10 @@ func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit,
}
return &withdrawal{
roundID: roundID,
current: newWithdrawalTx(),
pendingRequests: requests,
eligibleInputs: inputs,
status: status,
txOptions: defaultTxOptions,
}
}

Expand Down Expand Up @@ -553,7 +575,7 @@ func (w *withdrawal) fulfillNextRequest() error {
return w.handleOversizeTx()
}

fee := calculateTxFee(w.current)
fee := w.current.calculateFee()
for w.current.inputTotal() < w.current.outputTotal()+fee {
if len(w.eligibleInputs) == 0 {
log.Debug("Splitting last output because we don't have enough inputs")
Expand All @@ -563,7 +585,7 @@ func (w *withdrawal) fulfillNextRequest() error {
break
}
w.current.addInput(w.popInput())
fee = calculateTxFee(w.current)
fee = w.current.calculateFee()

if w.current.isTooBig() {
return w.handleOversizeTx()
Expand Down Expand Up @@ -647,7 +669,7 @@ func (w *withdrawal) finalizeCurrentTx() error {
}

w.transactions = append(w.transactions, tx)
w.current = newWithdrawalTx()
w.current = newWithdrawalTx(w.txOptions)
return nil
}

Expand Down Expand Up @@ -683,12 +705,13 @@ func (w *withdrawal) fulfillRequests() error {
// Sort outputs by outBailmentID (hash(server ID, tx #))
sort.Sort(byOutBailmentID(w.pendingRequests))

w.current = newWithdrawalTx(w.txOptions)
for len(w.pendingRequests) > 0 {
if err := w.fulfillNextRequest(); err != nil {
return err
}
tx := w.current
if len(w.eligibleInputs) == 0 && tx.inputTotal() <= tx.outputTotal()+calculateTxFee(tx) {
if len(w.eligibleInputs) == 0 && tx.inputTotal() <= tx.outputTotal()+tx.calculateFee() {
// We don't have more eligible inputs and all the inputs in the
// current tx have been spent.
break
Expand Down Expand Up @@ -731,7 +754,7 @@ func (w *withdrawal) splitLastOutput() error {
output := tx.outputs[len(tx.outputs)-1]
log.Debugf("Splitting tx output for %s", output.request)
origAmount := output.amount
spentAmount := tx.outputTotal() + calculateTxFee(tx) - output.amount
spentAmount := tx.outputTotal() + tx.calculateFee() - output.amount
// This is how much we have left after satisfying all outputs except the last
// one. IOW, all we have left for the last output, so we set that as the
// amount of the tx's last output.
Expand Down Expand Up @@ -993,16 +1016,9 @@ func validateSigScript(msgtx *wire.MsgTx, idx int, pkScript []byte) error {
return nil
}

// calculateTxFee calculates the expected network fees for a given tx. We use
// a variable instead of a function so that it can be replaced in tests.
var calculateTxFee = func(tx *withdrawalTx) btcutil.Amount {
return btcutil.Amount(1+calculateTxSize(tx)/1000) * feeIncrement
}

// calculateTxSize returns an estimate of the serialized size (in bytes) of the
// given transaction. It assumes all tx inputs are P2SH multi-sig. We use a
// variable instead of a function so that it can be replaced in tests.
var calculateTxSize = func(tx *withdrawalTx) int {
// given transaction. It assumes all tx inputs are P2SH multi-sig.
func calculateTxSize(tx *withdrawalTx) int {
msgtx := tx.toMsgTx()
// Assume that there will always be a change output, for simplicity. We
// simulate that by simply copying the first output as all we care about is
Expand Down
Loading

0 comments on commit fe0f609

Please sign in to comment.