Skip to content

Commit 38b4bfd

Browse files
authored
fix: Revert "fix: Error handling in StreamingBatchWriter" (#1918)
Reverts #1913. This broke some stuff, so reverting it to unblock SDK changes cloudquery/cloudquery#19312 (comment)
1 parent 00b9d9a commit 38b4bfd

File tree

2 files changed

+76
-103
lines changed

2 files changed

+76
-103
lines changed

writers/streamingbatchwriter/streamingbatchwriter.go

+54-74
Original file line numberDiff line numberDiff line change
@@ -182,28 +182,26 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
182182
errCh := make(chan error)
183183
defer close(errCh)
184184

185-
for {
186-
select {
187-
case msg, ok := <-msgs:
188-
if !ok {
189-
return w.Close(ctx)
190-
}
185+
go func() {
186+
for err := range errCh {
187+
w.logger.Err(err).Msg("error from StreamingBatchWriter")
188+
}
189+
}()
191190

192-
msgType := writers.MsgID(msg)
193-
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
194-
if err := w.Flush(ctx); err != nil {
195-
return err
196-
}
197-
}
198-
w.lastMsgType = msgType
199-
if err := w.startWorker(ctx, errCh, msg); err != nil {
191+
for msg := range msgs {
192+
msgType := writers.MsgID(msg)
193+
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
194+
if err := w.Flush(ctx); err != nil {
200195
return err
201196
}
202-
203-
case err := <-errCh:
197+
}
198+
w.lastMsgType = msgType
199+
if err := w.startWorker(ctx, errCh, msg); err != nil {
204200
return err
205201
}
206202
}
203+
204+
return w.Close(ctx)
207205
}
208206

209207
func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {
@@ -223,14 +221,13 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
223221
case *message.WriteMigrateTable:
224222
w.workersLock.Lock()
225223
defer w.workersLock.Unlock()
226-
227224
if w.migrateWorker != nil {
228225
w.migrateWorker.ch <- m
229226
return nil
230227
}
231-
228+
ch := make(chan *message.WriteMigrateTable)
232229
w.migrateWorker = &streamingWorkerManager[*message.WriteMigrateTable]{
233-
ch: make(chan *message.WriteMigrateTable),
230+
ch: ch,
234231
writeFunc: w.client.MigrateTable,
235232

236233
flush: make(chan chan bool),
@@ -244,19 +241,17 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
244241
w.workersWaitGroup.Add(1)
245242
go w.migrateWorker.run(ctx, &w.workersWaitGroup, tableName)
246243
w.migrateWorker.ch <- m
247-
248244
return nil
249245
case *message.WriteDeleteStale:
250246
w.workersLock.Lock()
251247
defer w.workersLock.Unlock()
252-
253248
if w.deleteStaleWorker != nil {
254249
w.deleteStaleWorker.ch <- m
255250
return nil
256251
}
257-
252+
ch := make(chan *message.WriteDeleteStale)
258253
w.deleteStaleWorker = &streamingWorkerManager[*message.WriteDeleteStale]{
259-
ch: make(chan *message.WriteDeleteStale),
254+
ch: ch,
260255
writeFunc: w.client.DeleteStale,
261256

262257
flush: make(chan chan bool),
@@ -270,29 +265,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
270265
w.workersWaitGroup.Add(1)
271266
go w.deleteStaleWorker.run(ctx, &w.workersWaitGroup, tableName)
272267
w.deleteStaleWorker.ch <- m
273-
274268
return nil
275269
case *message.WriteInsert:
276270
w.workersLock.RLock()
277-
worker, ok := w.insertWorkers[tableName]
271+
wr, ok := w.insertWorkers[tableName]
278272
w.workersLock.RUnlock()
279273
if ok {
280-
worker.ch <- m
274+
wr.ch <- m
281275
return nil
282276
}
283277

284-
w.workersLock.Lock()
285-
activeWorker, ok := w.insertWorkers[tableName]
286-
if ok {
287-
w.workersLock.Unlock()
288-
// some other goroutine could have already added the worker
289-
// just send the message to it & discard our allocated worker
290-
activeWorker.ch <- m
291-
return nil
292-
}
293-
294-
worker = &streamingWorkerManager[*message.WriteInsert]{
295-
ch: make(chan *message.WriteInsert),
278+
ch := make(chan *message.WriteInsert)
279+
wr = &streamingWorkerManager[*message.WriteInsert]{
280+
ch: ch,
296281
writeFunc: w.client.WriteTable,
297282

298283
flush: make(chan chan bool),
@@ -302,27 +287,33 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
302287
batchTimeout: w.batchTimeout,
303288
tickerFn: w.tickerFn,
304289
}
305-
306-
w.insertWorkers[tableName] = worker
290+
w.workersLock.Lock()
291+
wrOld, ok := w.insertWorkers[tableName]
292+
if ok {
293+
w.workersLock.Unlock()
294+
// some other goroutine could have already added the worker
295+
// just send the message to it & discard our allocated worker
296+
wrOld.ch <- m
297+
return nil
298+
}
299+
w.insertWorkers[tableName] = wr
307300
w.workersLock.Unlock()
308301

309302
w.workersWaitGroup.Add(1)
310-
go worker.run(ctx, &w.workersWaitGroup, tableName)
311-
worker.ch <- m
312-
303+
go wr.run(ctx, &w.workersWaitGroup, tableName)
304+
ch <- m
313305
return nil
314306
case *message.WriteDeleteRecord:
315307
w.workersLock.Lock()
316308
defer w.workersLock.Unlock()
317-
318309
if w.deleteRecordWorker != nil {
319310
w.deleteRecordWorker.ch <- m
320311
return nil
321312
}
322-
313+
ch := make(chan *message.WriteDeleteRecord)
323314
// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
324315
w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
325-
ch: make(chan *message.WriteDeleteRecord),
316+
ch: ch,
326317
writeFunc: w.client.DeleteRecords,
327318

328319
flush: make(chan chan bool),
@@ -336,7 +327,6 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
336327
w.workersWaitGroup.Add(1)
337328
go w.deleteRecordWorker.run(ctx, &w.workersWaitGroup, tableName)
338329
w.deleteRecordWorker.ch <- m
339-
340330
return nil
341331
default:
342332
return fmt.Errorf("unhandled message type: %T", msg)
@@ -358,40 +348,35 @@ type streamingWorkerManager[T message.WriteMessage] struct {
358348
func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
359349
defer wg.Done()
360350
var (
361-
inputCh chan T
362-
outputCh chan error
363-
open bool
351+
clientCh chan T
352+
clientErrCh chan error
353+
open bool
364354
)
365355

366356
ensureOpened := func() {
367357
if open {
368358
return
369359
}
370360

371-
inputCh = make(chan T)
372-
outputCh = make(chan error)
361+
clientCh = make(chan T)
362+
clientErrCh = make(chan error, 1)
373363
go func() {
374-
defer close(outputCh)
364+
defer close(clientErrCh)
375365
defer func() {
376-
if msg := recover(); msg != nil {
377-
switch v := msg.(type) {
378-
case error:
379-
outputCh <- fmt.Errorf("panic: %w [recovered]", v)
380-
default:
381-
outputCh <- fmt.Errorf("panic: %v [recovered]", msg)
382-
}
366+
if err := recover(); err != nil {
367+
clientErrCh <- fmt.Errorf("panic: %v", err)
383368
}
384369
}()
385-
result := s.writeFunc(ctx, inputCh)
386-
outputCh <- result
370+
clientErrCh <- s.writeFunc(ctx, clientCh)
387371
}()
388-
389372
open = true
390373
}
391-
392374
closeFlush := func() {
393375
if open {
394-
close(inputCh)
376+
close(clientCh)
377+
if err := <-clientErrCh; err != nil {
378+
s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
379+
}
395380
s.limit.Reset()
396381
}
397382
open = false
@@ -415,7 +400,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
415400
if add != nil {
416401
ensureOpened()
417402
s.limit.AddSlice(add)
418-
inputCh <- any(&message.WriteInsert{Record: add.Record}).(T)
403+
clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
419404
}
420405
if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
421406
// flush current batch
@@ -425,7 +410,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
425410
for _, sliceToFlush := range toFlush {
426411
ensureOpened()
427412
s.limit.AddRows(sliceToFlush.NumRows())
428-
inputCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
413+
clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
429414
closeFlush()
430415
ticker.Reset(s.batchTimeout)
431416
}
@@ -434,11 +419,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
434419
if rest != nil {
435420
ensureOpened()
436421
s.limit.AddSlice(rest)
437-
inputCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
422+
clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
438423
}
439424
} else {
440425
ensureOpened()
441-
inputCh <- r
426+
clientCh <- r
442427
s.limit.AddRows(1)
443428
if s.limit.ReachedLimit() {
444429
closeFlush()
@@ -456,11 +441,6 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
456441
ticker.Reset(s.batchTimeout)
457442
}
458443
done <- true
459-
case err := <-outputCh:
460-
if err != nil {
461-
s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
462-
return
463-
}
464444
case <-ctxDone:
465445
// this means the request was cancelled
466446
return // after this NO other call will succeed

writers/streamingbatchwriter/streamingbatchwriter_test.go

+22-29
Original file line numberDiff line numberDiff line change
@@ -201,30 +201,20 @@ func TestStreamingBatchSizeRows(t *testing.T) {
201201
ch <- &message.WriteInsert{
202202
Record: record,
203203
}
204+
time.Sleep(50 * time.Millisecond)
204205

205-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
206-
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
207-
208-
ch <- &message.WriteInsert{
209-
Record: record,
206+
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
207+
t.Fatalf("expected 0 insert messages, got %d", l)
210208
}
211209

212-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
213-
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
214-
215210
ch <- &message.WriteInsert{
216211
Record: record,
217212
}
218-
219-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
220-
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
221-
222-
ch <- &message.WriteInsert{
213+
ch <- &message.WriteInsert{ // third message, because we flush before exceeding the limit and then save the third one
223214
Record: record,
224215
}
225216

226-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 4)
227-
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
217+
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
228218

229219
close(ch)
230220
if err := <-errCh; err != nil {
@@ -235,7 +225,7 @@ func TestStreamingBatchSizeRows(t *testing.T) {
235225
t.Fatalf("expected 0 open tables, got %d", l)
236226
}
237227

238-
if l := testClient.MessageLen(messageTypeInsert); l != 4 {
228+
if l := testClient.MessageLen(messageTypeInsert); l != 3 {
239229
t.Fatalf("expected 3 insert messages, got %d", l)
240230
}
241231
}
@@ -263,12 +253,18 @@ func TestStreamingBatchTimeout(t *testing.T) {
263253
ch <- &message.WriteInsert{
264254
Record: record,
265255
}
256+
time.Sleep(50 * time.Millisecond)
266257

267-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
258+
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
259+
t.Fatalf("expected 0 insert messages, got %d", l)
260+
}
268261

269-
time.Sleep(time.Millisecond * 50) // we need to wait for the batch to be flushed
262+
// we need to wait for the batch to be flushed
263+
time.Sleep(time.Millisecond * 50)
270264

271-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
265+
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
266+
t.Fatalf("expected 0 insert messages, got %d", l)
267+
}
272268

273269
// flush
274270
tickFn()
@@ -305,35 +301,32 @@ func TestStreamingBatchNoTimeout(t *testing.T) {
305301
ch <- &message.WriteInsert{
306302
Record: record,
307303
}
304+
time.Sleep(50 * time.Millisecond)
308305

309-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
310-
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
306+
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
307+
t.Fatalf("expected 0 insert messages, got %d", l)
308+
}
311309

312310
time.Sleep(2 * time.Second)
313311

314-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
315-
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
312+
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
313+
t.Fatalf("expected 0 insert messages, got %d", l)
314+
}
316315

317316
ch <- &message.WriteInsert{
318317
Record: record,
319318
}
320-
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
321-
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
322-
323319
ch <- &message.WriteInsert{
324320
Record: record,
325321
}
326322

327323
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
328-
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
329324

330325
close(ch)
331326
if err := <-errCh; err != nil {
332327
t.Fatal(err)
333328
}
334329

335-
time.Sleep(50 * time.Millisecond)
336-
337330
if l := testClient.OpenLen(messageTypeInsert); l != 0 {
338331
t.Fatalf("expected 0 open tables, got %d", l)
339332
}

0 commit comments

Comments
 (0)