diff --git a/cpp/csp/engine/RootEngine.cpp b/cpp/csp/engine/RootEngine.cpp index a65f3a5d1..2ddc0eaca 100644 --- a/cpp/csp/engine/RootEngine.cpp +++ b/cpp/csp/engine/RootEngine.cpp @@ -11,12 +11,24 @@ namespace csp { -static volatile bool g_SIGNALED = false; +static volatile int g_SIGNAL_COUNT = 0; +/* +The signal count variable is maintained to ensure that multiple engine threads shutdown properly. + +An interrupt should cause all running engines to stop, but should not affect future runs in the same process. +Thus, each root engine keeps track of the signal count when its created. When an interrupt occurs, one engine thread +handles the interrupt by incrementing the count. Then, all other root engines detect the signal by comparing their +initial count to the current count. + +Future runs after the interrupt remain unaffected since they are initialized with the updated signal count, and will +only consider themselves "interupted" if another signal is received during their execution. +*/ + static struct sigaction g_prevSIGTERMaction; static void handle_SIGTERM( int signum ) { - g_SIGNALED = true; + g_SIGNAL_COUNT++; if( g_prevSIGTERMaction.sa_handler ) (*g_prevSIGTERMaction.sa_handler)( signum ); } @@ -58,6 +70,7 @@ RootEngine::RootEngine( const Dictionary & settings ) : Engine( m_cycleStepTable m_cycleCount( 0 ), m_settings( settings ), m_inRealtime( false ), + m_initSignalCount( g_SIGNAL_COUNT ), m_pushEventQueue( m_settings.queueWaitTime > TimeDelta::ZERO() ) { if( settings.get( "profile", false ) ) @@ -78,7 +91,7 @@ RootEngine::~RootEngine() bool RootEngine::interrupted() const { - return g_SIGNALED; + return g_SIGNAL_COUNT != m_initSignalCount; } void RootEngine::preRun( DateTime start, DateTime end ) @@ -131,7 +144,7 @@ void RootEngine::processEndCycle() void RootEngine::runSim( DateTime end ) { m_inRealtime = false; - while( m_scheduler.hasEvents() && m_state == State::RUNNING && !g_SIGNALED ) + while( m_scheduler.hasEvents() && m_state == State::RUNNING && !interrupted() ) { m_now = m_scheduler.nextTime(); if( m_now > end ) @@ -161,7 +174,7 @@ void RootEngine::runRealtime( DateTime end ) m_inRealtime = true; bool haveEvents = false; - while( m_state == State::RUNNING && !g_SIGNALED ) + while( m_state == State::RUNNING && !interrupted() ) { TimeDelta waitTime; if( !m_pendingPushEvents.hasEvents() ) diff --git a/cpp/csp/engine/RootEngine.h b/cpp/csp/engine/RootEngine.h index e3999d572..746704580 100644 --- a/cpp/csp/engine/RootEngine.h +++ b/cpp/csp/engine/RootEngine.h @@ -127,6 +127,7 @@ class RootEngine : public Engine PendingPushEvents m_pendingPushEvents; Settings m_settings; bool m_inRealtime; + int m_initSignalCount; PushEventQueue m_pushEventQueue; diff --git a/cpp/csp/python/PyNode.cpp b/cpp/csp/python/PyNode.cpp index ccad4e715..ba4ebbd7d 100644 --- a/cpp/csp/python/PyNode.cpp +++ b/cpp/csp/python/PyNode.cpp @@ -212,18 +212,17 @@ void PyNode::start() void PyNode::stop() { - PyObjectPtr rv = PyObjectPtr::own( PyObject_CallMethod( m_gen.ptr(), "close", nullptr ) ); - if( !rv.ptr() ) + if( this -> rootEngine() -> interrupted() && PyErr_CheckSignals() == -1 ) { - if( PyErr_Occurred() == PyExc_KeyboardInterrupt ) - { - PyErr_Clear(); - rv = PyObjectPtr::own( PyObject_CallMethod( m_gen.ptr(), "close", nullptr ) ); - } - - if( !rv.ptr() ) - CSP_THROW( PythonPassthrough, "" ); + // When an interrupt occurs a KeyboardInterrupt exception is raised in Python, which we need to clear + // before calling "close" on the generator. Else, the close method will fail due to the unhandled + // exception, and we lose the state of the generator before the "finally" block that calls stop() is executed. + PyErr_Clear(); } + + PyObjectPtr rv = PyObjectPtr::own( PyObject_CallMethod( m_gen.ptr(), "close", nullptr ) ); + if( !rv.ptr() ) + CSP_THROW( PythonPassthrough, "" ); } PyNode * PyNode::create( PyEngine * pyengine, PyObject * inputs, PyObject * outputs, PyObject * gen ) diff --git a/cpp/csp/python/csptestlibimpl.cpp b/cpp/csp/python/csptestlibimpl.cpp index 23bd299f5..f5a750bf8 100644 --- a/cpp/csp/python/csptestlibimpl.cpp +++ b/cpp/csp/python/csptestlibimpl.cpp @@ -66,6 +66,37 @@ EXPORT_CPPNODE( start_n2_throw ); } +namespace interrupt_stop_test +{ + +using namespace csp::python; + +void setStatus( const DialectGenericType & obj_, int64_t idx ) +{ + PyObjectPtr obj = PyObjectPtr::own( toPython( obj_ ) ); + PyObjectPtr list = PyObjectPtr::own( PyObject_GetAttrString( obj.get(), "stopped" ) ); + PyList_SET_ITEM( list.get(), idx, Py_True ); +} + +DECLARE_CPPNODE( set_stop_index ) +{ + INIT_CPPNODE( set_stop_index ) {} + + SCALAR_INPUT( DialectGenericType, obj_ ); + SCALAR_INPUT( int64_t, idx ); + + START() {} + INVOKE() {} + + STOP() + { + setStatus( obj_, idx ); + } +}; +EXPORT_CPPNODE( set_stop_index ); + +} + } } @@ -73,6 +104,7 @@ EXPORT_CPPNODE( start_n2_throw ); // Test nodes REGISTER_CPPNODE( csp::cppnodes::testing::stop_start_test, start_n1_set_value ); REGISTER_CPPNODE( csp::cppnodes::testing::stop_start_test, start_n2_throw ); +REGISTER_CPPNODE( csp::cppnodes::testing::interrupt_stop_test, set_stop_index ); static PyModuleDef _csptestlibimpl_module = { PyModuleDef_HEAD_INIT, diff --git a/csp/tests/test_engine.py b/csp/tests/test_engine.py index 3ae54afbb..028b5d6f9 100644 --- a/csp/tests/test_engine.py +++ b/csp/tests/test_engine.py @@ -2064,6 +2064,60 @@ def g() -> ts[int]: csp.run(g, starttime=datetime(2020, 1, 1), endtime=timedelta()) self.assertTrue(status["started"] and status["stopped"]) + def test_interrupt_stops_all_nodes(self): + @csp.node + def n(l: list, idx: int): + with csp.stop(): + l[idx] = True + + @csp.node + def raise_interrupt(): + with csp.alarms(): + a = csp.alarm(bool) + with csp.start(): + csp.schedule_alarm(a, timedelta(seconds=1), True) + if csp.ticked(a): + import signal + os.kill(os.getpid(), signal.SIGINT) + + # Python nodes + @csp.graph + def g(l: list): + n(l, 0) + n(l, 1) + n(l, 2) + raise_interrupt() + + stopped = [False, False, False] + with self.assertRaises(KeyboardInterrupt): + csp.run(g, stopped, starttime=datetime.utcnow(), endtime=timedelta(seconds=60), realtime=True) + + for element in stopped: + self.assertTrue(element) + + # C++ nodes + class RTI: + def __init__(self): + self.stopped = [False, False, False] + + @csp.node(cppimpl=_csptestlibimpl.set_stop_index) + def n2(obj_: object, idx: int): + return + + @csp.graph + def g2(rti: RTI): + n2(rti, 0) + n2(rti, 1) + n2(rti, 2) + raise_interrupt() + + rti = RTI() + with self.assertRaises(KeyboardInterrupt): + csp.run(g2, rti, starttime=datetime.utcnow(), endtime=timedelta(seconds=60), realtime=True) + + for element in rti.stopped: + self.assertTrue(element) + if __name__ == "__main__": unittest.main()