Skip to content

Commit

Permalink
Don't register signal in thread (#10610)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Nov 19, 2021
1 parent 5788789 commit 7d3ad5b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed


-
- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610))


-
Expand Down
10 changes: 8 additions & 2 deletions pytorch_lightning/trainer/connectors/signal_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import signal
import sys
import threading
from signal import Signals
from subprocess import call
from types import FrameType, FunctionType
Expand Down Expand Up @@ -46,10 +47,10 @@ def register_signal_handlers(self) -> None:
# signal.SIGUSR1 doesn't seem available on windows
if not self._is_on_windows():
if sigusr1_handlers and not self._has_already_handler(signal.SIGUSR1):
signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))
self._register_signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))

if sigterm_handlers and not self._has_already_handler(signal.SIGTERM):
signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))

def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
if self.trainer.is_global_zero:
Expand Down Expand Up @@ -96,3 +97,8 @@ def _has_already_handler(self, signum: Signals) -> bool:
return isinstance(signal.getsignal(signum), FunctionType)
except AttributeError:
return False

@staticmethod
def _register_signal(signum: Signals, handlers: HandlersCompose) -> None:
if threading.current_thread() is threading.main_thread():
signal.signal(signum, handlers)
14 changes: 14 additions & 0 deletions tests/trainer/connectors/test_signal_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import os
import signal
from time import sleep
Expand Down Expand Up @@ -87,3 +88,16 @@ def test_auto_requeue_flag(auto_requeue):
# TODO: should this be done in SignalConnector teardown?
signal.signal(signal.SIGTERM, sigterm_handler_default)
signal.signal(signal.SIGUSR1, sigusr1_handler_default)


def _registering_signals():
trainer = Trainer()
trainer.signal_connector.register_signal_handlers()


@RunIf(skip_windows=True)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_signal_connector_in_thread():
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed([executor.submit(_registering_signals)]):
assert future.exception() is None

0 comments on commit 7d3ad5b

Please sign in to comment.