diff --git a/requirements.txt b/requirements.txt index 002352d6..1de74620 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,4 @@ fastapi[all] == 0.103.1 urllib3 >= 1.26.6 # Testing Requirements -nest_asyncio == 1.5.7 +nest_asyncio == 1.5.8 diff --git a/runpod/serverless/utils/rp_upload.py b/runpod/serverless/utils/rp_upload.py index af6ce4cc..f56484ae 100644 --- a/runpod/serverless/utils/rp_upload.py +++ b/runpod/serverless/utils/rp_upload.py @@ -18,11 +18,24 @@ from boto3.s3.transfer import TransferConfig from botocore.config import Config from tqdm_loggable.auto import tqdm +from urllib.parse import urlparse logger = logging.getLogger("runpod upload utility") FMT = "%(filename)-20s:%(lineno)-4d %(asctime)s %(message)s" logging.basicConfig(level=logging.INFO, format=FMT, handlers=[logging.StreamHandler()]) +def extract_region_from_url(endpoint_url): + parsed_url = urlparse(endpoint_url) + # AWS/backblaze S3-like URL + if '.s3.' in endpoint_url: + return endpoint_url.split('.s3.')[1].split('.')[0] + # DigitalOcean Spaces-like URL + elif parsed_url.netloc.endswith('.digitaloceanspaces.com'): + return endpoint_url.split('.')[1].split('.digitaloceanspaces.com')[0] + else: + # Additional cases can be added here + return None + # --------------------------- S3 Bucket Connection --------------------------- # def get_boto_client( @@ -57,12 +70,16 @@ def get_boto_client( secret_access_key = os.environ.get('BUCKET_SECRET_ACCESS_KEY', None) if endpoint_url and access_key_id and secret_access_key: + # Extract region from the endpoint URL + region = extract_region_from_url(endpoint_url) + boto_client = bucket_session.client( 's3', endpoint_url=endpoint_url, aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key, - config=boto_config + config=boto_config, + region_name=region ) else: boto_client = None diff --git a/tests/test_serverless/test_utils/test_upload.py b/tests/test_serverless/test_utils/test_upload.py index e36f67e6..ef3b0e28 100644 --- a/tests/test_serverless/test_utils/test_upload.py +++ b/tests/test_serverless/test_utils/test_upload.py @@ -58,7 +58,8 @@ def test_get_boto_client(self): endpoint_url=bucket_creds['endpointUrl'], aws_access_key_id=bucket_creds['accessId'], aws_secret_access_key=bucket_creds['accessSecret'], - config=unittest.mock.ANY + config=unittest.mock.ANY, + region_name=None ) def test_get_boto_client_environ(self):