1
1
package promise
2
2
3
- import "reflect"
4
- import "sync"
5
- import "sync/atomic"
6
- import "github.com/pkg/errors"
3
+ import (
4
+ "fmt"
5
+ "reflect"
6
+ "sync"
7
+ "sync/atomic"
8
+
9
+ "github.com/pkg/errors"
10
+ )
7
11
8
12
type promiseType int
9
13
10
14
const (
11
- legacyCall promiseType = iota
12
- simpleCall
15
+ simpleCall promiseType = iota
13
16
thenCall
14
17
allCall
18
+ raceCall
15
19
anyCall
16
20
)
17
21
@@ -23,10 +27,12 @@ type Promise struct {
23
27
functionRv reflect.Value
24
28
results []reflect.Value
25
29
resultType []reflect.Type
30
+ anyErrs []error
26
31
// returnsError is true if the last value returns an error
27
32
returnsError bool
28
33
cond sync.Cond
29
34
counter int64
35
+ errCounter int64
30
36
noCopy
31
37
}
32
38
@@ -36,7 +42,7 @@ type noCopy struct{}
36
42
func (* noCopy ) Lock () {}
37
43
func (* noCopy ) Unlock () {}
38
44
39
- func (p * Promise ) anyCall (priors []* Promise , index int ) (results []reflect.Value ) {
45
+ func (p * Promise ) raceCall (priors []* Promise , index int ) (results []reflect.Value ) {
40
46
prior := priors [index ]
41
47
prior .cond .L .Lock ()
42
48
for ! prior .complete {
@@ -78,6 +84,40 @@ func (p *Promise) allCall(priors []*Promise, index int) (results []reflect.Value
78
84
return nil
79
85
}
80
86
87
+ // AnyErr returns when all promises passed to Any fail
88
+ type AnyErr struct {
89
+ // Errs contains the error of all passed promises
90
+ Errs []error
91
+ // LastErr contains the error of the last promise to fail.
92
+ LastErr error
93
+ }
94
+
95
+ func (err * AnyErr ) Error () string {
96
+ return fmt .Sprintf ("all %d promises failed. last err=%v" , len (err .Errs ), err .LastErr )
97
+ }
98
+
99
+ func (p * Promise ) anyCall (priors []* Promise , index int ) (results []reflect.Value ) {
100
+ prior := priors [index ]
101
+ prior .cond .L .Lock ()
102
+ for ! prior .complete {
103
+ prior .cond .Wait ()
104
+ }
105
+ prior .cond .L .Unlock ()
106
+ if prior .err != nil {
107
+ remaining := atomic .AddInt64 (& p .errCounter , - 1 )
108
+ p .anyErrs [index ] = prior .err
109
+ if remaining != 0 {
110
+ return nil
111
+ }
112
+ panic (AnyErr {Errs : p .anyErrs [:], LastErr : prior .err })
113
+ }
114
+ remaining := atomic .AddInt64 (& p .counter , - 1 )
115
+ if remaining == 0 {
116
+ return prior .results [:]
117
+ }
118
+ return nil
119
+ }
120
+
81
121
func empty () {}
82
122
83
123
// All returns a promise that resolves if all of the passed promises
@@ -107,10 +147,10 @@ func All(promises ...*Promise) *Promise {
107
147
108
148
const anyErrorFormat = "promise %d has an unexpected return type, expected all promises passed to Any to return the same type"
109
149
110
- // Any returns a promise that resolves if any of the passed promises
150
+ // Race returns a promise that resolves if any of the passed promises
111
151
// succeed or fails if any of the passed promises panics.
112
152
// All of the supplied promises must be of the same type.
113
- func Any (promises ... * Promise ) * Promise {
153
+ func Race (promises ... * Promise ) * Promise {
114
154
if len (promises ) == 0 {
115
155
return New (empty )
116
156
}
@@ -135,16 +175,57 @@ func Any(promises ...*Promise) *Promise {
135
175
136
176
p := & Promise {
137
177
cond : sync.Cond {L : & sync.Mutex {}},
138
- t : anyCall ,
178
+ t : raceCall ,
139
179
}
140
180
141
181
// Extract the type
142
- p .resultType = []reflect.Type {}
143
- for _ , prior := range promises {
144
- p .resultType = append (p .resultType , prior .resultType ... )
182
+ p .resultType = firstResultType [:]
183
+
184
+ p .counter = int64 (1 )
185
+
186
+ for i := range promises {
187
+ go p .run (reflect.Value {}, nil , promises , i , nil )
188
+ }
189
+ return p
190
+ }
191
+
192
+ // Any returns a promise that resolves if any of the passed promises
193
+ // succeed or fails if all of the passed promises panics.
194
+ // All of the supplied promises must be of the same type.
195
+ func Any (promises ... * Promise ) * Promise {
196
+ if len (promises ) == 0 {
197
+ return New (empty )
198
+ }
199
+
200
+ if len (promises ) == 1 {
201
+ return promises [0 ]
202
+ }
203
+
204
+ // Check that all the promises have the same return type
205
+ firstResultType := promises [0 ].resultType
206
+ for promiseIdx , promise := range promises [1 :] {
207
+ newResultType := promise .resultType
208
+ if len (firstResultType ) != len (newResultType ) {
209
+ panic (errors .Errorf (anyErrorFormat , promiseIdx ))
210
+ }
211
+ for index := range firstResultType {
212
+ if firstResultType [index ] != newResultType [index ] {
213
+ panic (errors .Errorf (anyErrorFormat , promiseIdx ))
214
+ }
215
+ }
216
+ }
217
+
218
+ p := & Promise {
219
+ cond : sync.Cond {L : & sync.Mutex {}},
220
+ t : anyCall ,
221
+ anyErrs : make ([]error , len (promises )),
145
222
}
146
223
224
+ // Extract the type
225
+ p .resultType = firstResultType [:]
226
+
147
227
p .counter = int64 (1 )
228
+ p .errCounter = int64 (len (promises ))
148
229
149
230
for i := range promises {
150
231
go p .run (reflect.Value {}, nil , promises , i , nil )
@@ -225,10 +306,10 @@ func (p *Promise) thenCall(prior *Promise, functionRv reflect.Value) []reflect.V
225
306
if p .err != nil {
226
307
panic (errors .Wrap (p .err , "error in previous promise" ))
227
308
}
228
- results := functionRv .Call (prior .results )
229
- if prior .returnsError && prior .err != nil {
309
+ if prior .err != nil {
230
310
panic (prior .err )
231
311
}
312
+ results := functionRv .Call (prior .results )
232
313
return results
233
314
}
234
315
@@ -322,6 +403,11 @@ func (p *Promise) run(functionRv reflect.Value, prior *Promise, priors []*Promis
322
403
}
323
404
case anyCall :
324
405
results = p .anyCall (priors , index )
406
+ if results == nil {
407
+ return
408
+ }
409
+ case raceCall :
410
+ results = p .raceCall (priors , index )
325
411
default :
326
412
panic ("unexpected call type" )
327
413
}
@@ -417,7 +503,7 @@ func (p *Promise) Wait(out ...interface{}) error {
417
503
p .cond .L .Unlock ()
418
504
419
505
if p .err != nil {
420
- return errors .Wrap (p .err , "panic() during promise execution" )
506
+ return errors .Wrap (p .err , "error during promise execution" )
421
507
}
422
508
423
509
var outRvs []reflect.Value
0 commit comments