diff --git a/daemon/algod/api/server/v2/handlers.go b/daemon/algod/api/server/v2/handlers.go index 065337f32d..0282594f4f 100644 --- a/daemon/algod/api/server/v2/handlers.go +++ b/daemon/algod/api/server/v2/handlers.go @@ -22,7 +22,6 @@ import ( "encoding/base64" "errors" "fmt" - "golang.org/x/sync/semaphore" "io" "math" "net/http" @@ -31,6 +30,7 @@ import ( "time" "github.com/labstack/echo/v4" + "golang.org/x/sync/semaphore" "github.com/algorand/avm-abi/apps" "github.com/algorand/go-codec/codec" @@ -96,6 +96,7 @@ type LedgerForAPI interface { LatestTotals() (basics.Round, ledgercore.AccountTotals, error) BlockHdr(rnd basics.Round) (blk bookkeeping.BlockHeader, err error) Wait(r basics.Round) chan struct{} + WaitWithCancel(r basics.Round) (chan struct{}, func()) GetCreator(cidx basics.CreatableIndex, ctype basics.CreatableType) (basics.Address, bool, error) EncodedBlockCert(rnd basics.Round) (blk []byte, cert []byte, err error) Block(rnd basics.Round) (blk bookkeeping.Block, err error) @@ -940,11 +941,15 @@ func (v2 *Handlers) WaitForBlock(ctx echo.Context, round uint64) error { } // Wait + ledgerWaitCh, cancelLedgerWait := ledger.WaitWithCancel(basics.Round(round + 1)) + defer cancelLedgerWait() select { case <-v2.Shutdown: return internalError(ctx, err, errServiceShuttingDown, v2.Log) + case <-ctx.Request().Context().Done(): + return ctx.NoContent(http.StatusRequestTimeout) case <-time.After(WaitForBlockTimeout): - case <-ledger.Wait(basics.Round(round + 1)): + case <-ledgerWaitCh: } // Return status after the wait diff --git a/daemon/algod/api/server/v2/test/handlers_resources_test.go b/daemon/algod/api/server/v2/test/handlers_resources_test.go index 1de86ddc19..adf187053a 100644 --- a/daemon/algod/api/server/v2/test/handlers_resources_test.go +++ b/daemon/algod/api/server/v2/test/handlers_resources_test.go @@ -19,11 +19,12 @@ package test import ( "encoding/json" "fmt" - "github.com/algorand/go-algorand/data/transactions/logic" "net/http" "net/http/httptest" 
"testing" + "github.com/algorand/go-algorand/data/transactions/logic" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -135,6 +136,9 @@ func (l *mockLedger) BlockHdr(rnd basics.Round) (bookkeeping.BlockHeader, error) func (l *mockLedger) Wait(r basics.Round) chan struct{} { panic("not implemented") } +func (l *mockLedger) WaitWithCancel(r basics.Round) (chan struct{}, func()) { + panic("not implemented") +} func (l *mockLedger) GetCreator(cidx basics.CreatableIndex, ctype basics.CreatableType) (c basics.Address, ok bool, err error) { panic("not implemented") } diff --git a/ledger/bulletin.go b/ledger/bulletin.go index 039fb3376e..5dc2b99aaa 100644 --- a/ledger/bulletin.go +++ b/ledger/bulletin.go @@ -18,7 +18,6 @@ package ledger import ( "context" - "sync/atomic" "github.com/algorand/go-deadlock" @@ -28,29 +27,17 @@ import ( "github.com/algorand/go-algorand/ledger/store/trackerdb" ) -// notifier is a struct that encapsulates a single-shot channel; it will only be signaled once. +// notifier is a struct that encapsulates a single-shot channel; it should only be signaled once. type notifier struct { - signal chan struct{} - notified *atomic.Uint32 -} - -// makeNotifier constructs a notifier that has not been signaled. -func makeNotifier() notifier { - return notifier{signal: make(chan struct{}), notified: &atomic.Uint32{}} -} - -// notify signals the channel if it hasn't already done so -func (notifier *notifier) notify() { - if notifier.notified.CompareAndSwap(0, 1) { - close(notifier.signal) - } + signal chan struct{} + count int } // bulletin provides an easy way to wait on a round to be written to the ledger. // To use it, call <-Wait(round). 
type bulletin struct { mu deadlock.Mutex - pendingNotificationRequests map[basics.Round]notifier + pendingNotificationRequests map[basics.Round]*notifier latestRound basics.Round } @@ -62,7 +49,7 @@ type bulletinMem struct { func makeBulletin() *bulletin { b := new(bulletin) - b.pendingNotificationRequests = make(map[basics.Round]notifier) + b.pendingNotificationRequests = make(map[basics.Round]*notifier) return b } @@ -80,14 +67,32 @@ func (b *bulletin) Wait(round basics.Round) chan struct{} { signal, exists := b.pendingNotificationRequests[round] if !exists { - signal = makeNotifier() + signal = &notifier{signal: make(chan struct{})} b.pendingNotificationRequests[round] = signal } + // Increment count of waiters, to support canceling. + signal.count++ + return signal.signal } +// CancelWait removes a wait for a particular round. If no one else is waiting, the +// notifier channel for that round is removed. +func (b *bulletin) CancelWait(round basics.Round) { + b.mu.Lock() + defer b.mu.Unlock() + + signal, exists := b.pendingNotificationRequests[round] + if exists { + signal.count-- + if signal.count <= 0 { + delete(b.pendingNotificationRequests, round) + } + } +} + func (b *bulletin) loadFromDisk(l ledgerForTracker, _ basics.Round) error { - b.pendingNotificationRequests = make(map[basics.Round]notifier) + b.pendingNotificationRequests = make(map[basics.Round]*notifier) b.latestRound = l.Latest() return nil } @@ -105,7 +110,8 @@ func (b *bulletin) notifyRound(rnd basics.Round) { } delete(b.pendingNotificationRequests, pending) - signal.notify() + // signal the channel by closing it; this is under lock and will only happen once + close(signal.signal) } b.latestRound = rnd diff --git a/ledger/bulletin_test.go b/ledger/bulletin_test.go index 5a6f6bb833..88d3784470 100644 --- a/ledger/bulletin_test.go +++ b/ledger/bulletin_test.go @@ -20,7 +20,9 @@ import ( "testing" "time" + "github.com/algorand/go-algorand/data/basics" 
"github.com/algorand/go-algorand/test/partitiontest" + "github.com/stretchr/testify/require" ) const epsilon = 5 * time.Millisecond @@ -100,3 +102,109 @@ func TestBulletin(t *testing.T) { t.Errorf("<-Wait(10) finished late") } } + +func TestCancelWait(t *testing.T) { + partitiontest.PartitionTest(t) + + bul := makeBulletin() + + // Calling Wait before CancelWait + waitCh := bul.Wait(5) + bul.CancelWait(5) + bul.committedUpTo(5) + select { + case <-waitCh: + t.Errorf("<-Wait(5) should have been cancelled") + case <-time.After(epsilon): + // Correct + } + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(5)) + + // Calling CancelWait before Wait + bul.CancelWait(6) + select { + case <-bul.Wait(6): + t.Errorf("<-Wait(6) should have been cancelled") + case <-time.After(epsilon): + // Correct + } + require.Contains(t, bul.pendingNotificationRequests, basics.Round(6)) + require.Equal(t, bul.pendingNotificationRequests[basics.Round(6)].count, 1) + bul.CancelWait(6) + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(6)) + + // Two Waits, one cancelled + waitCh1 := bul.Wait(7) + waitCh2 := bul.Wait(7) + require.Equal(t, waitCh1, waitCh2) + bul.CancelWait(7) + select { + case <-waitCh1: + t.Errorf("<-Wait(7) should not be notified yet") + case <-time.After(epsilon): + // Correct + } + // Still one waiter + require.Contains(t, bul.pendingNotificationRequests, basics.Round(7)) + require.Equal(t, bul.pendingNotificationRequests[basics.Round(7)].count, 1) + + bul.committedUpTo(7) + select { + case <-waitCh1: + // Correct + case <-time.After(epsilon): + t.Errorf("<-Wait(7) should have been notified") + } + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(7)) + + // Wait followed by Cancel for a round that already completed + waitCh = bul.Wait(5) + bul.CancelWait(5) + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(5)) + select { + case <-waitCh: + // Correct + case <-time.After(epsilon): + 
t.Errorf("<-Wait(5) should have been notified right away") + } + + // Cancel Wait after Wait triggered + waitCh = bul.Wait(8) + require.Contains(t, bul.pendingNotificationRequests, basics.Round(8)) + require.Equal(t, bul.pendingNotificationRequests[basics.Round(8)].count, 1) + bul.committedUpTo(8) + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(8)) + select { + case <-waitCh: + // Correct + case <-time.After(epsilon): + t.Errorf("<-Wait(8) should have been notified") + } + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(8)) + bul.CancelWait(8) // should do nothing + + // Cancel Wait after Wait triggered, but before Wait returned + waitCh = bul.Wait(9) + require.Contains(t, bul.pendingNotificationRequests, basics.Round(9)) + require.Equal(t, bul.pendingNotificationRequests[basics.Round(9)].count, 1) + bul.committedUpTo(9) + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(9)) + bul.CancelWait(9) // should do nothing + select { + case <-waitCh: + // Correct + case <-time.After(epsilon): + t.Errorf("<-Wait(9) should have been notified") + } + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(9)) + + // Two waits, both cancelled + waitCh1 = bul.Wait(10) + waitCh2 = bul.Wait(10) + require.Equal(t, waitCh1, waitCh2) + bul.CancelWait(10) + require.Contains(t, bul.pendingNotificationRequests, basics.Round(10)) + require.Equal(t, bul.pendingNotificationRequests[basics.Round(10)].count, 1) + bul.CancelWait(10) + require.NotContains(t, bul.pendingNotificationRequests, basics.Round(10)) +} diff --git a/ledger/ledger.go b/ledger/ledger.go index dc3baaf766..7dad8bbb64 100644 --- a/ledger/ledger.go +++ b/ledger/ledger.go @@ -769,6 +769,16 @@ func (l *Ledger) Wait(r basics.Round) chan struct{} { return l.bulletinDisk.Wait(r) } +// WaitWithCancel returns a channel that closes once a given round is +// stored durably in the ledger. 
The returned function can be used to +// cancel the wait, which cleans up resources if no other Wait call is +// active for the same round. +func (l *Ledger) WaitWithCancel(r basics.Round) (chan struct{}, func()) { + l.trackerMu.RLock() + defer l.trackerMu.RUnlock() + return l.bulletinDisk.Wait(r), func() { l.bulletinDisk.CancelWait(r) } +} + // WaitMem returns a channel that closes once a given round is // available in memory in the ledger, but might not be stored // durably on disk yet.