@@ -182,28 +182,26 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
     errCh := make(chan error)
     defer close(errCh)
 
-    for {
-        select {
-        case msg, ok := <-msgs:
-            if !ok {
-                return w.Close(ctx)
-            }
+    go func() {
+        for err := range errCh {
+            w.logger.Err(err).Msg("error from StreamingBatchWriter")
+        }
+    }()
 
-            msgType := writers.MsgID(msg)
-            if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
-                if err := w.Flush(ctx); err != nil {
-                    return err
-                }
-            }
-            w.lastMsgType = msgType
-            if err := w.startWorker(ctx, errCh, msg); err != nil {
+    for msg := range msgs {
+        msgType := writers.MsgID(msg)
+        if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
+            if err := w.Flush(ctx); err != nil {
                 return err
             }
-
-        case err := <-errCh:
+        }
+        w.lastMsgType = msgType
+        if err := w.startWorker(ctx, errCh, msg); err != nil {
             return err
         }
     }
+
+    return w.Close(ctx)
 }
 
 func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {
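The rewritten `Write` above replaces the `for`/`select` loop with two simpler pieces: a background goroutine that drains `errCh` (worker errors are now logged rather than returned mid-stream), and a plain `range` over `msgs` that closes the writer once the channel is closed. A minimal standalone sketch of that control flow, using hypothetical stand-ins for the SDK's message and client types:

```go
package main

import (
	"context"
	"log"
)

type msg struct{ kind string }

// write mirrors the new pattern: worker errors are drained in a background
// goroutine (so they are only logged, never returned mid-stream), and the
// main loop ranges over msgs, closing once the channel is closed.
func write(ctx context.Context, msgs <-chan msg) error {
	errCh := make(chan error)
	defer close(errCh)

	go func() {
		for err := range errCh {
			log.Printf("error from worker: %v", err)
		}
	}()

	for m := range msgs {
		if err := startWorker(ctx, errCh, m); err != nil { // hypothetical dispatcher
			return err
		}
	}
	return nil // stands in for w.Close(ctx)
}

func startWorker(_ context.Context, _ chan<- error, m msg) error {
	log.Printf("dispatching %s message", m.kind)
	return nil
}

func main() {
	msgs := make(chan msg, 2)
	msgs <- msg{kind: "insert"}
	msgs <- msg{kind: "migrate"}
	close(msgs)
	if err := write(context.Background(), msgs); err != nil {
		log.Fatal(err)
	}
}
```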
@@ -223,14 +221,13 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
     case *message.WriteMigrateTable:
         w.workersLock.Lock()
         defer w.workersLock.Unlock()
-
         if w.migrateWorker != nil {
             w.migrateWorker.ch <- m
             return nil
         }
-
+        ch := make(chan *message.WriteMigrateTable)
         w.migrateWorker = &streamingWorkerManager[*message.WriteMigrateTable]{
-            ch:        make(chan *message.WriteMigrateTable),
+            ch:        ch,
             writeFunc: w.client.MigrateTable,
 
             flush: make(chan chan bool),
@@ -244,19 +241,17 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
         w.workersWaitGroup.Add(1)
         go w.migrateWorker.run(ctx, &w.workersWaitGroup, tableName)
         w.migrateWorker.ch <- m
-
         return nil
     case *message.WriteDeleteStale:
         w.workersLock.Lock()
         defer w.workersLock.Unlock()
-
         if w.deleteStaleWorker != nil {
             w.deleteStaleWorker.ch <- m
             return nil
         }
-
+        ch := make(chan *message.WriteDeleteStale)
         w.deleteStaleWorker = &streamingWorkerManager[*message.WriteDeleteStale]{
-            ch:        make(chan *message.WriteDeleteStale),
+            ch:        ch,
             writeFunc: w.client.DeleteStale,
 
             flush: make(chan chan bool),
@@ -270,29 +265,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
         w.workersWaitGroup.Add(1)
         go w.deleteStaleWorker.run(ctx, &w.workersWaitGroup, tableName)
         w.deleteStaleWorker.ch <- m
-
         return nil
     case *message.WriteInsert:
         w.workersLock.RLock()
-        worker, ok := w.insertWorkers[tableName]
+        wr, ok := w.insertWorkers[tableName]
         w.workersLock.RUnlock()
         if ok {
-            worker.ch <- m
+            wr.ch <- m
             return nil
         }
 
-        w.workersLock.Lock()
-        activeWorker, ok := w.insertWorkers[tableName]
-        if ok {
-            w.workersLock.Unlock()
-            // some other goroutine could have already added the worker
-            // just send the message to it & discard our allocated worker
-            activeWorker.ch <- m
-            return nil
-        }
-
-        worker = &streamingWorkerManager[*message.WriteInsert]{
-            ch:        make(chan *message.WriteInsert),
+        ch := make(chan *message.WriteInsert)
+        wr = &streamingWorkerManager[*message.WriteInsert]{
+            ch:        ch,
             writeFunc: w.client.WriteTable,
 
             flush: make(chan chan bool),
@@ -302,27 +287,33 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
             batchTimeout: w.batchTimeout,
             tickerFn:     w.tickerFn,
         }
-
-        w.insertWorkers[tableName] = worker
+        w.workersLock.Lock()
+        wrOld, ok := w.insertWorkers[tableName]
+        if ok {
+            w.workersLock.Unlock()
+            // some other goroutine could have already added the worker
+            // just send the message to it & discard our allocated worker
+            wrOld.ch <- m
+            return nil
+        }
+        w.insertWorkers[tableName] = wr
         w.workersLock.Unlock()
 
         w.workersWaitGroup.Add(1)
-        go worker.run(ctx, &w.workersWaitGroup, tableName)
-        worker.ch <- m
-
+        go wr.run(ctx, &w.workersWaitGroup, tableName)
+        ch <- m
         return nil
     case *message.WriteDeleteRecord:
         w.workersLock.Lock()
         defer w.workersLock.Unlock()
-
         if w.deleteRecordWorker != nil {
             w.deleteRecordWorker.ch <- m
             return nil
         }
-
+        ch := make(chan *message.WriteDeleteRecord)
         // TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
         w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
-            ch:        make(chan *message.WriteDeleteRecord),
+            ch:        ch,
             writeFunc: w.client.DeleteRecords,
 
             flush: make(chan chan bool),
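The insert path above keeps the check/lock/recheck registration of per-table workers, but now allocates the worker before taking the write lock and rechecks the map afterwards, since another goroutine may have registered a worker for the same table in the meantime. A compact, self-contained sketch of that pattern under assumed names (registry and worker are illustrative, not the SDK's types):

```go
package main

import (
	"fmt"
	"sync"
)

type worker struct{ ch chan int }

type registry struct {
	mu      sync.RWMutex
	workers map[string]*worker
}

// send fast-paths with a read lock; on a miss it allocates a worker, then
// takes the write lock and rechecks the map before publishing it.
func (r *registry) send(table string, v int) {
	r.mu.RLock()
	w, ok := r.workers[table]
	r.mu.RUnlock()
	if ok {
		w.ch <- v
		return
	}

	ch := make(chan int)
	nw := &worker{ch: ch}

	r.mu.Lock()
	if existing, ok := r.workers[table]; ok {
		r.mu.Unlock()
		existing.ch <- v // someone beat us to it; discard our allocation
		return
	}
	r.workers[table] = nw
	r.mu.Unlock()

	go func() { // the per-table worker loop
		for got := range ch {
			fmt.Println(table, got)
		}
	}()
	ch <- v
}

func main() {
	// demo only; exits without waiting for the worker goroutine to drain
	r := &registry{workers: map[string]*worker{}}
	r.send("users", 1)
	r.send("users", 2)
}
```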
@@ -336,7 +327,6 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
         w.workersWaitGroup.Add(1)
         go w.deleteRecordWorker.run(ctx, &w.workersWaitGroup, tableName)
         w.deleteRecordWorker.ch <- m
-
         return nil
     default:
         return fmt.Errorf("unhandled message type: %T", msg)
@@ -358,40 +348,35 @@ type streamingWorkerManager[T message.WriteMessage] struct {
 func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
     defer wg.Done()
     var (
-        inputCh  chan T
-        outputCh chan error
-        open     bool
+        clientCh    chan T
+        clientErrCh chan error
+        open        bool
     )
 
     ensureOpened := func() {
         if open {
             return
         }
 
-        inputCh = make(chan T)
-        outputCh = make(chan error)
+        clientCh = make(chan T)
+        clientErrCh = make(chan error, 1)
         go func() {
-            defer close(outputCh)
+            defer close(clientErrCh)
             defer func() {
-                if msg := recover(); msg != nil {
-                    switch v := msg.(type) {
-                    case error:
-                        outputCh <- fmt.Errorf("panic: %w [recovered]", v)
-                    default:
-                        outputCh <- fmt.Errorf("panic: %v [recovered]", msg)
-                    }
+                if err := recover(); err != nil {
+                    clientErrCh <- fmt.Errorf("panic: %v", err)
                 }
             }()
-            result := s.writeFunc(ctx, inputCh)
-            outputCh <- result
+            clientErrCh <- s.writeFunc(ctx, clientCh)
         }()
-
         open = true
     }
-
     closeFlush := func() {
         if open {
-            close(inputCh)
+            close(clientCh)
+            if err := <-clientErrCh; err != nil {
+                s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
+            }
             s.limit.Reset()
         }
         open = false
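The `closeFlush` change above moves error collection to the point where a batch is closed: the client goroutine deposits its result into a one-slot buffered channel, so a return value or a recovered panic can be parked even with no receiver, and closing the input channel is immediately followed by reading that result. A minimal sketch of the same lifecycle under an assumed `writeFunc` signature (not the SDK's actual types):

```go
package main

import (
	"context"
	"fmt"
)

// runBatch feeds one batch to writeFunc and surfaces its error synchronously.
// clientErrCh has capacity 1 so the goroutine can deposit its result (or a
// recovered panic) even before anyone is reading it.
func runBatch(ctx context.Context, writeFunc func(context.Context, <-chan int) error) error {
	clientCh := make(chan int)
	clientErrCh := make(chan error, 1)

	go func() {
		defer close(clientErrCh)
		defer func() {
			if p := recover(); p != nil {
				clientErrCh <- fmt.Errorf("panic: %v", p)
			}
		}()
		clientErrCh <- writeFunc(ctx, clientCh)
	}()

	for i := 0; i < 3; i++ {
		clientCh <- i // feed the batch
	}
	close(clientCh)      // end of batch: the client should return now
	return <-clientErrCh // read the client's result, as closeFlush does
}

func main() {
	err := runBatch(context.Background(), func(_ context.Context, in <-chan int) error {
		sum := 0
		for v := range in {
			sum += v
		}
		fmt.Println("wrote batch, sum =", sum)
		return nil
	})
	fmt.Println("err:", err)
}
```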
@@ -415,7 +400,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
             if add != nil {
                 ensureOpened()
                 s.limit.AddSlice(add)
-                inputCh <- any(&message.WriteInsert{Record: add.Record}).(T)
+                clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
             }
             if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
                 // flush current batch
@@ -425,7 +410,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
             for _, sliceToFlush := range toFlush {
                 ensureOpened()
                 s.limit.AddRows(sliceToFlush.NumRows())
-                inputCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
+                clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
                 closeFlush()
                 ticker.Reset(s.batchTimeout)
             }
@@ -434,11 +419,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
             if rest != nil {
                 ensureOpened()
                 s.limit.AddSlice(rest)
-                inputCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
+                clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
             }
         } else {
             ensureOpened()
-            inputCh <- r
+            clientCh <- r
             s.limit.AddRows(1)
             if s.limit.ReachedLimit() {
                 closeFlush()
@@ -456,11 +441,6 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
                 ticker.Reset(s.batchTimeout)
             }
             done <- true
-        case err := <-outputCh:
-            if err != nil {
-                s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
-                return
-            }
         case <-ctxDone:
             // this means the request was cancelled
             return // after this NO other call will succeed