1111from tempfile import TemporaryFile
1212from typing import Optional
1313from typing import TextIO
14+ from typing import Tuple
1415
1516import pytest
1617from _pytest .compat import TYPE_CHECKING
@@ -245,7 +246,6 @@ class NoCapture:
245246class SysCaptureBinary :
246247
247248 EMPTY_BUFFER = b""
248- _state = None
249249
250250 def __init__ (self , fd , tmpfile = None , * , tee = False ):
251251 name = patchsysdict [fd ]
@@ -257,6 +257,7 @@ def __init__(self, fd, tmpfile=None, *, tee=False):
257257 else :
258258 tmpfile = CaptureIO () if not tee else TeeCaptureIO (self ._old )
259259 self .tmpfile = tmpfile
260+ self ._state = "initialized"
260261
261262 def repr (self , class_name : str ) -> str :
262263 return "<{} {} _old={} _state={!r} tmpfile={!r}>" .format (
@@ -276,32 +277,49 @@ def __repr__(self) -> str:
276277 self .tmpfile ,
277278 )
278279
280+ def _assert_state (self , op : str , states : Tuple [str , ...]) -> None :
281+ assert (
282+ self ._state in states
283+ ), "cannot {} in state {!r}: expected one of {}" .format (
284+ op , self ._state , ", " .join (states )
285+ )
286+
279287 def start (self ):
288+ self ._assert_state ("start" , ("initialized" ,))
280289 setattr (sys , self .name , self .tmpfile )
281290 self ._state = "started"
282291
283292 def snap (self ):
293+ self ._assert_state ("snap" , ("started" , "suspended" ))
284294 self .tmpfile .seek (0 )
285295 res = self .tmpfile .buffer .read ()
286296 self .tmpfile .seek (0 )
287297 self .tmpfile .truncate ()
288298 return res
289299
290300 def done (self ):
301+ self ._assert_state ("done" , ("initialized" , "started" , "suspended" , "done" ))
302+ if self ._state == "done" :
303+ return
291304 setattr (sys , self .name , self ._old )
292305 del self ._old
293306 self .tmpfile .close ()
294307 self ._state = "done"
295308
296309 def suspend (self ):
310+ self ._assert_state ("suspend" , ("started" , "suspended" ))
297311 setattr (sys , self .name , self ._old )
298312 self ._state = "suspended"
299313
300314 def resume (self ):
315+ self ._assert_state ("resume" , ("started" , "suspended" ))
316+ if self ._state == "started" :
317+ return
301318 setattr (sys , self .name , self .tmpfile )
302- self ._state = "resumed "
319+ self ._state = "started "
303320
304321 def writeorg (self , data ):
322+ self ._assert_state ("writeorg" , ("started" , "suspended" ))
305323 self ._old .flush ()
306324 self ._old .buffer .write (data )
307325 self ._old .buffer .flush ()
@@ -317,6 +335,7 @@ def snap(self):
317335 return res
318336
319337 def writeorg (self , data ):
338+ self ._assert_state ("writeorg" , ("started" , "suspended" ))
320339 self ._old .write (data )
321340 self ._old .flush ()
322341
@@ -328,7 +347,6 @@ class FDCaptureBinary:
328347 """
329348
330349 EMPTY_BUFFER = b""
331- _state = None
332350
333351 def __init__ (self , targetfd ):
334352 self .targetfd = targetfd
@@ -368,6 +386,8 @@ def __init__(self, targetfd):
368386 else :
369387 self .syscapture = NoCapture ()
370388
389+ self ._state = "initialized"
390+
371391 def __repr__ (self ):
372392 return "<{} {} oldfd={} _state={!r} tmpfile={!r}>" .format (
373393 self .__class__ .__name__ ,
@@ -377,13 +397,22 @@ def __repr__(self):
377397 self .tmpfile ,
378398 )
379399
400+ def _assert_state (self , op : str , states : Tuple [str , ...]) -> None :
401+ assert (
402+ self ._state in states
403+ ), "cannot {} in state {!r}: expected one of {}" .format (
404+ op , self ._state , ", " .join (states )
405+ )
406+
380407 def start (self ):
381408 """ Start capturing on targetfd using memorized tmpfile. """
409+ self ._assert_state ("start" , ("initialized" ,))
382410 os .dup2 (self .tmpfile .fileno (), self .targetfd )
383411 self .syscapture .start ()
384412 self ._state = "started"
385413
386414 def snap (self ):
415+ self ._assert_state ("snap" , ("started" , "suspended" ))
387416 self .tmpfile .seek (0 )
388417 res = self .tmpfile .buffer .read ()
389418 self .tmpfile .seek (0 )
@@ -393,6 +422,9 @@ def snap(self):
393422 def done (self ):
394423 """ stop capturing, restore streams, return original capture file,
395424 seeked to position zero. """
425+ self ._assert_state ("done" , ("initialized" , "started" , "suspended" , "done" ))
426+ if self ._state == "done" :
427+ return
396428 os .dup2 (self .targetfd_save , self .targetfd )
397429 os .close (self .targetfd_save )
398430 if self .targetfd_invalid is not None :
@@ -404,17 +436,24 @@ def done(self):
404436 self ._state = "done"
405437
406438 def suspend (self ):
439+ self ._assert_state ("suspend" , ("started" , "suspended" ))
440+ if self ._state == "suspended" :
441+ return
407442 self .syscapture .suspend ()
408443 os .dup2 (self .targetfd_save , self .targetfd )
409444 self ._state = "suspended"
410445
411446 def resume (self ):
447+ self ._assert_state ("resume" , ("started" , "suspended" ))
448+ if self ._state == "started" :
449+ return
412450 self .syscapture .resume ()
413451 os .dup2 (self .tmpfile .fileno (), self .targetfd )
414- self ._state = "resumed "
452+ self ._state = "started "
415453
416454 def writeorg (self , data ):
417455 """ write to original file descriptor. """
456+ self ._assert_state ("writeorg" , ("started" , "suspended" ))
418457 os .write (self .targetfd_save , data )
419458
420459
@@ -428,6 +467,7 @@ class FDCapture(FDCaptureBinary):
428467 EMPTY_BUFFER = "" # type: ignore
429468
430469 def snap (self ):
470+ self ._assert_state ("snap" , ("started" , "suspended" ))
431471 self .tmpfile .seek (0 )
432472 res = self .tmpfile .read ()
433473 self .tmpfile .seek (0 )
0 commit comments