Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ click >= 8.1.7
colorama >= 0.4.6
fastapi[all] >= 0.94.0
paramiko >= 3.3.1
pillow >= 9.5.0
prettytable >= 3.8.0
prettytable >= 3.9.0
py-cpuinfo >= 9.0.0
python-dotenv >= 1.0.0
inquirerpy == 0.3.4
requests >= 2.31.0
tomli >= 2.0.1
tomlkit >= 0.12.1
tomlkit >= 0.12.2
tqdm-loggable >= 0.1.4
urllib3 >= 1.26.6
setuptools_scm >= 8.0.4
Expand Down
43 changes: 14 additions & 29 deletions runpod/serverless/utils/rp_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
import logging
import threading
import multiprocessing
from io import BytesIO
from urllib.parse import urlparse
from typing import Optional, Tuple

import boto3
from PIL import Image, UnidentifiedImageError
from boto3 import session
from boto3.s3.transfer import TransferConfig
from botocore.config import Config
Expand All @@ -25,6 +23,7 @@
FMT = "%(filename)-20s:%(lineno)-4d %(asctime)s %(message)s"
logging.basicConfig(level=logging.INFO, format=FMT, handlers=[logging.StreamHandler()])


def extract_region_from_url(endpoint_url):
"""
Extracts the region from the endpoint URL.
Expand All @@ -43,7 +42,7 @@ def extract_region_from_url(endpoint_url):

# --------------------------- S3 Bucket Connection --------------------------- #
def get_boto_client(
bucket_creds: Optional[dict] = None) -> Tuple[boto3.client, TransferConfig]: # pragma: no cover
bucket_creds: Optional[dict] = None) -> Tuple[boto3.client, TransferConfig]: # pragma: no cover # pylint: disable=line-too-long
'''
Returns a boto3 client and transfer config for the bucket.
'''
Expand Down Expand Up @@ -94,49 +93,35 @@ def get_boto_client(
# ---------------------------------------------------------------------------- #
# Upload Image #
# ---------------------------------------------------------------------------- #
def upload_image(job_id, image_location, result_index=0, results_list=None): # pragma: no cover
def upload_image(job_id, image_location, result_index=0, results_list=None): # pragma: no cover
'''
Upload a single file to bucket storage.
'''
image_name = str(uuid.uuid4())[:8]
boto_client, _ = get_boto_client()
file_extension = os.path.splitext(image_location)[1]
content_type = "image/" + file_extension.lstrip(".")

with open(image_location, "rb") as input_file:
output = input_file.read()

if boto_client is None:
# Save the output to a file
print("No bucket endpoint set, saving to disk folder 'simulated_uploaded'")
print("If this is a live endpoint, please reference the following:")
print("https://github.com/runpod/runpod-python/blob/main/docs/serverless/utils/rp_upload.md") # pylint: disable=line-too-long
print("https://github.com/runpod/runpod-python/blob/main/docs/serverless/utils/rp_upload.md") # pylint: disable=line-too-long

os.makedirs("simulated_uploaded", exist_ok=True)
sim_upload_location = f"simulated_uploaded/{image_name}{file_extension}"
try:
with Image.open(image_location) as img, open(sim_upload_location, "wb") as file_output:
img.save(file_output, format=img.format)

except UnidentifiedImageError:
# If the file is not an image, save it directly
shutil.copy(image_location, sim_upload_location)
with open(sim_upload_location, "wb") as file_output:
file_output.write(output)

if results_list is not None:
results_list[result_index] = sim_upload_location

return sim_upload_location

try:
with Image.open(image_location) as img:
output = BytesIO()
img.save(output, format=img.format)
output.seek(0)
content_type = "image/" + file_extension.lstrip(".")

except UnidentifiedImageError:
# If the file is not an image, read it directly
with open(image_location, "rb") as f:
output = f.read()
content_type = "application/octet-stream"


bucket = time.strftime('%m-%y')
boto_client.put_object(
Bucket=f'{bucket}',
Expand All @@ -161,7 +146,7 @@ def upload_image(job_id, image_location, result_index=0, results_list=None): # p
# ---------------------------------------------------------------------------- #
# Files To Upload #
# ---------------------------------------------------------------------------- #
def files(job_id, file_list): # pragma: no cover
def files(job_id, file_list): # pragma: no cover
'''
Uploads a list of files in parallel.
Once all files are uploaded, the function returns the presigned URLs list.
Expand All @@ -186,7 +171,7 @@ def files(job_id, file_list): # pragma: no cover


# --------------------------- Custom Bucket Upload --------------------------- #
def bucket_upload(job_id, file_list, bucket_creds): # pragma: no cover
def bucket_upload(job_id, file_list, bucket_creds): # pragma: no cover
'''
Uploads files to bucket storage.
'''
Expand Down Expand Up @@ -231,7 +216,7 @@ def upload_file_to_bucket(
bucket_name: Optional[str] = None,
prefix: Optional[str] = None,
extra_args: Optional[dict] = None
) -> str: # pragma: no cover
) -> str: # pragma: no cover
'''
Uploads a single file to bucket storage and returns a presigned URL.
'''
Expand Down Expand Up @@ -283,7 +268,7 @@ def upload_in_memory_object(
file_name: str, file_data: bytes,
bucket_creds: Optional[dict] = None,
bucket_name: Optional[str] = None,
prefix: Optional[str] = None) -> str: # pragma: no cover
prefix: Optional[str] = None) -> str: # pragma: no cover
'''
Uploads an in-memory object (bytes) to bucket storage and returns a presigned URL.
'''
Expand Down
39 changes: 13 additions & 26 deletions tests/test_serverless/test_utils/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_get_boto_client(self):

# Mock boto3.session.Session
with patch('boto3.session.Session') as mock_session, \
patch('runpod.serverless.utils.rp_upload.TransferConfig') as mock_transfer_config:
patch('runpod.serverless.utils.rp_upload.TransferConfig') as mock_transfer_config:
mock_session.return_value.client.return_value = self.mock_boto_client
mock_transfer_config.return_value = self.mock_transfer_config

Expand Down Expand Up @@ -90,8 +90,6 @@ def test_get_boto_client(self):
region_name="region-code"
)



def test_get_boto_client_environ(self):
'''
Tests get_boto_client with environment variables
Expand All @@ -105,7 +103,7 @@ def test_get_boto_client_environ(self):
importlib.reload(rp_upload)

with patch('boto3.session.Session') as mock_session, \
patch('runpod.serverless.utils.rp_upload.TransferConfig') as mock_transfer_config:
patch('runpod.serverless.utils.rp_upload.TransferConfig') as mock_transfer_config:
mock_session.return_value.client.return_value = self.mock_boto_client
mock_transfer_config.return_value = self.mock_transfer_config

Expand All @@ -117,39 +115,35 @@ def test_get_boto_client_environ(self):
# ---------------------------------------------------------------------------- #
# Upload Image #
# ---------------------------------------------------------------------------- #


class TestUploadImage(unittest.TestCase):
''' Tests for upload_image '''

@patch("runpod.serverless.utils.rp_upload.get_boto_client")
@patch("runpod.serverless.utils.rp_upload.Image.open")
@patch("builtins.open")
@patch("runpod.serverless.utils.rp_upload.os.makedirs")
def test_upload_image_local(self, mock_makedirs, mock_img_open, mock_get_boto_client):
def test_upload_image_local(self, mock_makedirs, mock_open, mock_get_boto_client):
'''
Test upload_image function when there is no boto client
'''
# Mocking get_boto_client to return None
mock_get_boto_client.return_value = (None, None)

# Mocking the context manager of Image.open
mock_image = Mock()
mock_image.format = "PNG"
mock_img_open.return_value.__enter__.return_value = mock_image
mock_file = mock_open.return_value.__enter__.return_value
mock_file.read.return_value = b"simulated_uploaded"
mock_file.__exit__.return_value = False

with patch("builtins.open") as mock_open:
mock_open.return_value = io.BytesIO(b"simulated_uploaded")
result = rp_upload.upload_image("job_id", "image_location")
result = rp_upload.upload_image("job_id", "image_location")

# Assert that image is saved locally
assert "simulated_uploaded" in result
mock_makedirs.assert_called_once()
mock_img_open.assert_called_once()
mock_open.assert_called_once()
mock_image.save.assert_called_once()

@patch("runpod.serverless.utils.rp_upload.get_boto_client")
@patch("runpod.serverless.utils.rp_upload.Image.open")
@patch("runpod.serverless.utils.rp_upload.BytesIO")
def test_upload_image_s3(self, mock_bytes_io, mock_open, mock_get_boto_client):
@patch("builtins.open")
def test_upload_image_s3(self, mock_open, mock_get_boto_client):
'''
Test upload_image function when there is a boto client
'''
Expand All @@ -165,22 +159,15 @@ def test_upload_image_s3(self, mock_bytes_io, mock_open, mock_get_boto_client):
mock_image.format = "PNG"
mock_open.return_value.__enter__.return_value = mock_image

# Mocking BytesIO
mock_bytes_io_instance = Mock()
mock_bytes_io_instance.getvalue = Mock(return_value="image_bytes")
mock_bytes_io.return_value = mock_bytes_io_instance

result = rp_upload.upload_image("job_id", "image_location")

# Assert the image is uploaded to S3
assert result == "presigned_url"
mock_open.assert_called_once_with("image_location")
mock_open.assert_called_once_with("image_location", "rb")
mock_boto_client.put_object.assert_called_once()
mock_boto_client.generate_presigned_url.assert_called_once()




class TestUploadUtility(unittest.TestCase):
''' Tests for upload utility '''

Expand Down