Skip to content

Commit 5b9bd55

Browse files
Merge pull request #135 from runpod/main
updating branch
2 parents 2d09dca + a9e3ed8 commit 5b9bd55

File tree

4 files changed

+63
-8
lines changed

4 files changed

+63
-8
lines changed

runpod/serverless/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sys
99
import json
1010
import time
11+
import signal
1112
import argparse
1213
from typing import Dict, Any
1314

@@ -82,6 +83,13 @@ def _get_realtime_concurrency() -> int:
8283
"""
8384
return int(os.environ.get("RUNPOD_REALTIME_CONCURRENCY", "1"))
8485

86+
def _signal_handler(sig, frame):
87+
"""
88+
Handles the SIGINT signal.
89+
"""
90+
del sig, frame
91+
log.info("SIGINT received. Shutting down.")
92+
sys.exit(0)
8593

8694
# ---------------------------------------------------------------------------- #
8795
# Start Serverless Worker #
@@ -100,6 +108,8 @@ def start(config: Dict[str, Any]):
100108
from runpod import __version__ as runpod_version # pylint: disable=import-outside-toplevel,cyclic-import
101109
print(f"--- Starting Serverless Worker | Version {runpod_version} ---")
102110

111+
signal.signal(signal.SIGINT, _signal_handler)
112+
103113
config["reference_counter_start"] = time.perf_counter()
104114
config = _set_config_args(config)
105115

runpod/serverless/utils/rp_upload.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import threading
1111
import multiprocessing
1212
from io import BytesIO
13+
from urllib.parse import urlparse
1314
from typing import Optional, Tuple
1415

1516
import boto3
@@ -18,24 +19,27 @@
1819
from boto3.s3.transfer import TransferConfig
1920
from botocore.config import Config
2021
from tqdm_loggable.auto import tqdm
21-
from urllib.parse import urlparse
22+
2223

2324
logger = logging.getLogger("runpod upload utility")
2425
FMT = "%(filename)-20s:%(lineno)-4d %(asctime)s %(message)s"
2526
logging.basicConfig(level=logging.INFO, format=FMT, handlers=[logging.StreamHandler()])
2627

2728
def extract_region_from_url(endpoint_url):
29+
"""
30+
Extracts the region from the endpoint URL.
31+
"""
2832
parsed_url = urlparse(endpoint_url)
2933
# AWS/backblaze S3-like URL
3034
if '.s3.' in endpoint_url:
3135
return endpoint_url.split('.s3.')[1].split('.')[0]
36+
3237
# DigitalOcean Spaces-like URL
33-
elif parsed_url.netloc.endswith('.digitaloceanspaces.com'):
38+
if parsed_url.netloc.endswith('.digitaloceanspaces.com'):
3439
return endpoint_url.split('.')[1].split('.digitaloceanspaces.com')[0]
35-
else:
36-
# Additional cases can be added here
37-
return None
38-
40+
41+
return None
42+
3943

4044
# --------------------------- S3 Bucket Connection --------------------------- #
4145
def get_boto_client(
@@ -72,7 +76,7 @@ def get_boto_client(
7276
if endpoint_url and access_key_id and secret_access_key:
7377
# Extract region from the endpoint URL
7478
region = extract_region_from_url(endpoint_url)
75-
79+
7680
boto_client = bucket_session.client(
7781
's3',
7882
endpoint_url=endpoint_url,

tests/test_serverless/test_utils/test_upload.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,36 @@ def test_get_boto_client(self):
6262
region_name=None
6363
)
6464

65+
creds_s3 = bucket_creds.copy()
66+
creds_s3['endpointUrl'] = "https://bucket-name.s3.region-code.amazonaws.com/key-name"
67+
68+
boto_client, transfer_config = get_boto_client(creds_s3)
69+
70+
mock_session.return_value.client.assert_called_with(
71+
's3',
72+
endpoint_url=creds_s3['endpointUrl'],
73+
aws_access_key_id=bucket_creds['accessId'],
74+
aws_secret_access_key=bucket_creds['accessSecret'],
75+
config=unittest.mock.ANY,
76+
region_name="region-code"
77+
)
78+
79+
creds_do = bucket_creds.copy()
80+
creds_do['endpointUrl'] = "https://name.region-code.digitaloceanspaces.com/key-name"
81+
82+
boto_client, transfer_config = get_boto_client(creds_do)
83+
84+
mock_session.return_value.client.assert_called_with(
85+
's3',
86+
endpoint_url=creds_do['endpointUrl'],
87+
aws_access_key_id=bucket_creds['accessId'],
88+
aws_secret_access_key=bucket_creds['accessSecret'],
89+
config=unittest.mock.ANY,
90+
region_name="region-code"
91+
)
92+
93+
94+
6595
def test_get_boto_client_environ(self):
6696
'''
6797
Tests get_boto_client with environment variables

tests/test_serverless/test_worker.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import runpod
1111
from runpod.serverless.modules.rp_logger import RunPodLogger
12-
12+
from runpod.serverless import _signal_handler
1313

1414
nest_asyncio.apply()
1515

@@ -93,6 +93,17 @@ def test_local_api(self):
9393

9494
assert mock_fastapi.WorkerAPI.called
9595

96+
@patch('runpod.serverless.log')
97+
@patch('runpod.serverless.sys.exit')
98+
def test_signal_handler(self, mock_exit, mock_logger):
99+
'''
100+
Test signal handler.
101+
'''
102+
103+
_signal_handler(None, None)
104+
105+
assert mock_exit.called
106+
assert mock_logger.info.called
96107

97108
class TestWorkerTestInput(IsolatedAsyncioTestCase):
98109
""" Tests for runpod | serverless| worker """

0 commit comments

Comments
 (0)