diff --git a/goccm.go b/goccm.go index 737e700..09ddbd2 100644 --- a/goccm.go +++ b/goccm.go @@ -34,8 +34,8 @@ type ( // This channel indicates when all goroutines have finished their job. allDoneCh chan bool - // The close flag allows we know when we can close the manager - closed bool + // The closed channel is closed which controller should close + closed chan bool // The running count allows we know the number of goroutines are running runningCount int32 @@ -50,6 +50,7 @@ func New(maxGoRoutines int) *concurrencyManager { managerCh: make(chan interface{}, maxGoRoutines), doneCh: make(chan bool), allDoneCh: make(chan bool), + closed: make(chan bool), } // Fill the manager channel by placeholder values @@ -75,7 +76,7 @@ func (c *concurrencyManager) controller() { // When the closed flag is set, // we need to close the manager if it doesn't have any running goroutine - if c.closed && c.runningCount == 0 { + if c.IsClosed() && c.RunningCount() == 0 { break } } @@ -105,8 +106,24 @@ func (c *concurrencyManager) Done() { } // Close the manager manually +// terminate if channel is already closed func (c *concurrencyManager) Close() { - c.closed = true + // terminate if channel is already closed + select { + case <-c.closed: + return + default: + close(c.closed) + } +} + +func (c *concurrencyManager) IsClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } } // WaitAllDone Wait for all goroutines are done @@ -120,5 +137,5 @@ func (c *concurrencyManager) WaitAllDone() { // RunningCount Returns the number of goroutines which are running func (c *concurrencyManager) RunningCount() int32 { - return c.runningCount + return atomic.AddInt32(&c.runningCount, 0) } diff --git a/goccm_test.go b/goccm_test.go index baebda1..6a44c38 100644 --- a/goccm_test.go +++ b/goccm_test.go @@ -2,6 +2,7 @@ package goccm import ( "fmt" + "log" "testing" "time" ) @@ -12,25 +13,31 @@ func TestExample(t *testing.T) { c.Wait() go func(i int) { fmt.Printf("Job %d is running\n", i) - time.Sleep(2 * time.Second) + time.Sleep(20 * time.Millisecond) c.Done() }(i) } c.WaitAllDone() } +// TestManuallyClose will close after 5 jobs, others should not run func TestManuallyClose(t *testing.T) { - executedJobs := 0 + executedJobs := make(chan int, 1000) + c := New(3) for i := 1; i <= 1000; i++ { + jobId := i + c.Wait() go func() { - executedJobs++ - fmt.Printf("Executed jobs %d\n", executedJobs) - time.Sleep(2 * time.Second) + executedJobs <- jobId + fmt.Printf("Executed job id %d\n", jobId) + time.Sleep(20 * time.Millisecond) c.Done() }() + if i == 5 { + log.Printf("Closing manager") c.Close() break } @@ -40,21 +47,26 @@ func TestManuallyClose(t *testing.T) { func TestConcurrency(t *testing.T) { var maxRunningJobs = 3 - var testMaxRunningJobs int32 + testMaxRunningJobs := make(chan int32, 100) c := New(maxRunningJobs) + for i := 1; i <= 10; i++ { c.Wait() go func(i int) { fmt.Printf("Current running jobs %d\n", c.RunningCount()) - if c.RunningCount() > testMaxRunningJobs { - testMaxRunningJobs = c.RunningCount() - } - time.Sleep(2 * time.Second) + testMaxRunningJobs <- c.RunningCount() + time.Sleep(20 * time.Millisecond) c.Done() }(i) } + c.WaitAllDone() - if testMaxRunningJobs > int32(maxRunningJobs) { - t.Errorf("The number of concurrency jobs has exceeded %d. Real result %d.", maxRunningJobs, testMaxRunningJobs) + + for i := 1; i <= 10; i++ { + observed := <-testMaxRunningJobs + + if observed > int32(maxRunningJobs) { + t.Errorf("The number of concurrency jobs has exceeded %d. Real result %d.", maxRunningJobs, testMaxRunningJobs) + } } }