Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions cpp/csp/engine/RootEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,24 @@
namespace csp
{

static volatile bool g_SIGNALED = false;
static volatile int g_SIGNAL_COUNT = 0;
/*
The signal count variable is maintained to ensure that multiple engine threads shutdown properly.

An interrupt should cause all running engines to stop, but should not affect future runs in the same process.
Thus, each root engine keeps track of the signal count when its created. When an interrupt occurs, one engine thread
handles the interrupt by incrementing the count. Then, all other root engines detect the signal by comparing their
initial count to the current count.

Future runs after the interrupt remain unaffected since they are initialized with the updated signal count, and will
only consider themselves "interupted" if another signal is received during their execution.
*/

static struct sigaction g_prevSIGTERMaction;

static void handle_SIGTERM( int signum )
{
g_SIGNALED = true;
g_SIGNAL_COUNT++;
if( g_prevSIGTERMaction.sa_handler )
(*g_prevSIGTERMaction.sa_handler)( signum );
}
Expand Down Expand Up @@ -58,6 +70,7 @@ RootEngine::RootEngine( const Dictionary & settings ) : Engine( m_cycleStepTable
m_cycleCount( 0 ),
m_settings( settings ),
m_inRealtime( false ),
m_initSignalCount( g_SIGNAL_COUNT ),
m_pushEventQueue( m_settings.queueWaitTime > TimeDelta::ZERO() )
{
if( settings.get<bool>( "profile", false ) )
Expand All @@ -78,7 +91,7 @@ RootEngine::~RootEngine()

bool RootEngine::interrupted() const
{
return g_SIGNALED;
return g_SIGNAL_COUNT != m_initSignalCount;
}

void RootEngine::preRun( DateTime start, DateTime end )
Expand Down Expand Up @@ -131,7 +144,7 @@ void RootEngine::processEndCycle()
void RootEngine::runSim( DateTime end )
{
m_inRealtime = false;
while( m_scheduler.hasEvents() && m_state == State::RUNNING && !g_SIGNALED )
while( m_scheduler.hasEvents() && m_state == State::RUNNING && !interrupted() )
{
m_now = m_scheduler.nextTime();
if( m_now > end )
Expand Down Expand Up @@ -161,7 +174,7 @@ void RootEngine::runRealtime( DateTime end )

m_inRealtime = true;
bool haveEvents = false;
while( m_state == State::RUNNING && !g_SIGNALED )
while( m_state == State::RUNNING && !interrupted() )
{
TimeDelta waitTime;
if( !m_pendingPushEvents.hasEvents() )
Expand Down
1 change: 1 addition & 0 deletions cpp/csp/engine/RootEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class RootEngine : public Engine
PendingPushEvents m_pendingPushEvents;
Settings m_settings;
bool m_inRealtime;
int m_initSignalCount;

PushEventQueue m_pushEventQueue;

Expand Down
19 changes: 9 additions & 10 deletions cpp/csp/python/PyNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,17 @@ void PyNode::start()

void PyNode::stop()
{
PyObjectPtr rv = PyObjectPtr::own( PyObject_CallMethod( m_gen.ptr(), "close", nullptr ) );
if( !rv.ptr() )
if( this -> rootEngine() -> interrupted() && PyErr_CheckSignals() == -1 )
{
if( PyErr_Occurred() == PyExc_KeyboardInterrupt )
{
PyErr_Clear();
rv = PyObjectPtr::own( PyObject_CallMethod( m_gen.ptr(), "close", nullptr ) );
}

if( !rv.ptr() )
CSP_THROW( PythonPassthrough, "" );
// When an interrupt occurs a KeyboardInterrupt exception is raised in Python, which we need to clear
// before calling "close" on the generator. Else, the close method will fail due to the unhandled
// exception, and we lose the state of the generator before the "finally" block that calls stop() is executed.
PyErr_Clear();
}

PyObjectPtr rv = PyObjectPtr::own( PyObject_CallMethod( m_gen.ptr(), "close", nullptr ) );
if( !rv.ptr() )
CSP_THROW( PythonPassthrough, "" );
}

PyNode * PyNode::create( PyEngine * pyengine, PyObject * inputs, PyObject * outputs, PyObject * gen )
Expand Down
32 changes: 32 additions & 0 deletions cpp/csp/python/csptestlibimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,45 @@ EXPORT_CPPNODE( start_n2_throw );

}

namespace interrupt_stop_test
{

using namespace csp::python;

void setStatus( const DialectGenericType & obj_, int64_t idx )
{
PyObjectPtr obj = PyObjectPtr::own( toPython( obj_ ) );
PyObjectPtr list = PyObjectPtr::own( PyObject_GetAttrString( obj.get(), "stopped" ) );
PyList_SET_ITEM( list.get(), idx, Py_True );
}

DECLARE_CPPNODE( set_stop_index )
{
INIT_CPPNODE( set_stop_index ) {}

SCALAR_INPUT( DialectGenericType, obj_ );
SCALAR_INPUT( int64_t, idx );

START() {}
INVOKE() {}

STOP()
{
setStatus( obj_, idx );
}
};
EXPORT_CPPNODE( set_stop_index );

}

}

}

// Test nodes
REGISTER_CPPNODE( csp::cppnodes::testing::stop_start_test, start_n1_set_value );
REGISTER_CPPNODE( csp::cppnodes::testing::stop_start_test, start_n2_throw );
REGISTER_CPPNODE( csp::cppnodes::testing::interrupt_stop_test, set_stop_index );

static PyModuleDef _csptestlibimpl_module = {
PyModuleDef_HEAD_INIT,
Expand Down
54 changes: 54 additions & 0 deletions csp/tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,60 @@ def g() -> ts[int]:
csp.run(g, starttime=datetime(2020, 1, 1), endtime=timedelta())
self.assertTrue(status["started"] and status["stopped"])

def test_interrupt_stops_all_nodes(self):
@csp.node
def n(l: list, idx: int):
with csp.stop():
l[idx] = True

@csp.node
def raise_interrupt():
with csp.alarms():
a = csp.alarm(bool)
with csp.start():
csp.schedule_alarm(a, timedelta(seconds=1), True)
if csp.ticked(a):
import signal
os.kill(os.getpid(), signal.SIGINT)

# Python nodes
@csp.graph
def g(l: list):
n(l, 0)
n(l, 1)
n(l, 2)
raise_interrupt()

stopped = [False, False, False]
with self.assertRaises(KeyboardInterrupt):
csp.run(g, stopped, starttime=datetime.utcnow(), endtime=timedelta(seconds=60), realtime=True)

for element in stopped:
self.assertTrue(element)

# C++ nodes
class RTI:
def __init__(self):
self.stopped = [False, False, False]

@csp.node(cppimpl=_csptestlibimpl.set_stop_index)
def n2(obj_: object, idx: int):
return

@csp.graph
def g2(rti: RTI):
n2(rti, 0)
n2(rti, 1)
n2(rti, 2)
raise_interrupt()

rti = RTI()
with self.assertRaises(KeyboardInterrupt):
csp.run(g2, rti, starttime=datetime.utcnow(), endtime=timedelta(seconds=60), realtime=True)

for element in rti.stopped:
self.assertTrue(element)


if __name__ == "__main__":
unittest.main()