Skip to content

Commit de316f0

Browse files
authored
Merge pull request #1621 from brian-team/interrupt_hook
Interrupt hook (graceful stop with Ctrl+C)
2 parents 9d675f4 + d08fadf commit de316f0

File tree

6 files changed

+84
-4
lines changed

6 files changed

+84
-4
lines changed

brian2/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import logging
6+
import signal
67

78

89
def _check_dependencies():
@@ -225,3 +226,37 @@ def _check_caches():
225226

226227

227228
_check_caches()
229+
230+
231+
class _InterruptHandler:
232+
"""
233+
Class to turn a Ctrl+C interruption (SIGINT signal) into a `stop` signal for
234+
a running simulation (i.e., finish simulating the current time step and then
235+
stop). This handler is activated by default, but can be switched off by
236+
setting the `core.stop_on_keyboard_interrupt` preference to ``False``.
237+
Note that this will only handle interruptions during a `Network.run`,
238+
interrupting at any other time will raise a `KeyboardInterrupt` in the
239+
usual way. In case that finishing the current time step takes a long time
240+
(or hangs for some reason), interrupting with Ctrl+C a second time will
241+
force the usual interrupt, regardless of the preference setting.
242+
"""
243+
244+
def __init__(self, previous_handler):
245+
self.previous_handler = previous_handler
246+
247+
def __call__(self, signalnum, stack_frame):
248+
if (
249+
not prefs.core.stop_on_keyboard_interrupt
250+
or not Network._globally_running
251+
or Network._globally_stopped
252+
):
253+
self.previous_handler(signalnum, stack_frame)
254+
else:
255+
logging.getLogger("brian2").warning(
256+
"Simulation stop requested. Press Ctrl+C again to interrupt."
257+
)
258+
Network._globally_stopped = True
259+
260+
261+
_int_handler = _InterruptHandler(signal.getsignal(signal.SIGINT))
262+
signal.signal(signal.SIGINT, _int_handler)

brian2/core/core_preferences.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def default_float_dtype_validator(dtype):
4343
a warning (``False``).
4444
""",
4545
),
46+
stop_on_keyboard_interrupt=BrianPreference(
47+
default=True,
48+
docs="""
49+
Whether to "gracefully" stop a simulation after pressing Ctrl+C (defaults to
50+
``True``). Note that pressing Ctrl+C a second time will force the usual
51+
interruption mechanism.
52+
""",
53+
),
4654
)
4755

4856
prefs.register_preferences(

brian2/core/network.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ def profiling_info(self):
484484
return self.get_profiling_info()
485485

486486
_globally_stopped = False
487+
_globally_running = False
487488

488489
def __getitem__(self, item):
489490
if not isinstance(item, str):
@@ -1197,6 +1198,7 @@ def run(
11971198

11981199
active_objects = [obj for obj in all_objects if obj.active]
11991200

1201+
Network._globally_running = True
12001202
while running and not self._stopped and not Network._globally_stopped:
12011203
if not single_clock:
12021204
timestep, t, dt = self._clock_variables[clock]
@@ -1258,6 +1260,7 @@ def run(
12581260
running = timestep[0] < clock._i_end
12591261

12601262
end_time = time.time()
1263+
Network._globally_running = False
12611264
if self._stopped or Network._globally_stopped:
12621265
self.t_ = clock.t_
12631266
else:
@@ -1275,7 +1278,12 @@ def run(
12751278
obj._check_for_invalid_states()
12761279

12771280
if report is not None:
1278-
report_callback((end_time - start_time) * second, 1.0, t_start, duration)
1281+
report_callback(
1282+
(end_time - start_time) * second,
1283+
device._last_run_completed_fraction,
1284+
t_start,
1285+
duration,
1286+
)
12791287
self.after_run()
12801288

12811289
logger.debug(

brian2/devices/cpp_standalone/device.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from brian2.codegen.generators.cpp_generator import c_data_type
2525
from brian2.core.functions import Function
2626
from brian2.core.namespace import get_local_namespace
27+
from brian2.core.network import Network
2728
from brian2.core.preferences import BrianPreference, prefs
2829
from brian2.core.variables import (
2930
ArrayVariable,
@@ -1287,15 +1288,19 @@ def run(
12871288
stdout = None
12881289
if os.name == "nt":
12891290
start_time = time.time()
1291+
Network._globally_running = True
12901292
x = subprocess.call(["main"] + run_args, stdout=stdout)
12911293
self.timers["run_binary"] = time.time() - start_time
1294+
Network._globally_running = False
12921295
else:
12931296
run_cmd = prefs.devices.cpp_standalone.run_cmd_unix
12941297
if isinstance(run_cmd, str):
12951298
run_cmd = [run_cmd]
12961299
start_time = time.time()
1300+
Network._globally_running = True
12971301
x = subprocess.call(run_cmd + run_args, stdout=stdout)
12981302
self.timers["run_binary"] = time.time() - start_time
1303+
Network._globally_running = False
12991304
if stdout is not None:
13001305
stdout.close()
13011306
if x:

brian2/devices/cpp_standalone/templates/main.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <stdlib.h>
22
#include "objects.h"
3+
#include <csignal>
34
#include <ctime>
45
#include <time.h>
56
{{ openmp_pragma('include') }}
@@ -33,8 +34,21 @@ void set_from_command_line(const std::vector<std::string> args)
3334
brian::set_variable_by_name(name, value);
3435
}
3536
}
37+
38+
void _int_handler(int signal_num) {
39+
if (Network::_globally_running and !Network::_globally_stopped) {
40+
Network::_globally_stopped = true;
41+
} else {
42+
std::signal(signal_num, SIG_DFL);
43+
std::raise(signal_num);
44+
}
45+
}
46+
3647
int main(int argc, char **argv)
3748
{
49+
{% if prefs.core.stop_on_keyboard_interrupt %}
50+
std::signal(SIGINT, _int_handler);
51+
{% endif %}
3852
std::random_device _rd;
3953
std::vector<std::string> args(argv + 1, argv + argc);
4054
if (args.size() >=2 && args[0] == "--results_dir")

brian2/devices/cpp_standalone/templates/network.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
double Network::_last_run_time = 0.0;
1414
double Network::_last_run_completed_fraction = 0.0;
15+
bool Network::_globally_stopped = false;
16+
bool Network::_globally_running = false;
1517

1618
Network::Network()
1719
{
@@ -63,7 +65,9 @@ void Network::run(const double duration, void (*report_func)(const double, const
6365
double elapsed_realtime;
6466
bool did_break_early = false;
6567

66-
while(clock && clock->running())
68+
Network::_globally_running = true;
69+
Network::_globally_stopped = false;
70+
while(clock && clock->running() && !Network::_globally_stopped)
6771
{
6872
t = clock->t[0];
6973

@@ -112,8 +116,12 @@ void Network::run(const double duration, void (*report_func)(const double, const
112116
{% endif %}
113117

114118
}
119+
Network::_globally_running = false;
115120

116-
if(!did_break_early) t = t_end;
121+
if(!did_break_early && !Network::_globally_stopped)
122+
t = t_end;
123+
else
124+
t = clock->t[0];
117125

118126
_last_run_time = elapsed_realtime;
119127
if(duration>0)
@@ -124,7 +132,7 @@ void Network::run(const double duration, void (*report_func)(const double, const
124132
}
125133
if (report_func)
126134
{
127-
report_func(elapsed_realtime, 1.0, t_start, duration);
135+
report_func(elapsed_realtime, _last_run_completed_fraction, t_start, duration);
128136
}
129137
}
130138

@@ -189,6 +197,8 @@ class Network
189197
double t;
190198
static double _last_run_time;
191199
static double _last_run_completed_fraction;
200+
static bool _globally_stopped;
201+
static bool _globally_running;
192202

193203
Network();
194204
void clear();

0 commit comments

Comments
 (0)