Skip to content

Commit 5ddc4a5

Browse files
committed
fix: parse Content-Disposition properly, closes #414
1 parent 1a0976f commit 5ddc4a5

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

runpod/serverless/utils/rp_download.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
import uuid
1212
import zipfile
1313
from concurrent.futures import ThreadPoolExecutor
14-
from email import message_from_string
15-
from typing import List, Union
14+
from typing import List, Union, Dict
1615
from urllib.parse import urlparse
1716

1817
import backoff
@@ -34,6 +33,18 @@ def calculate_chunk_size(file_size: int) -> int:
3433

3534
return 1024 * 1024 * 10 # 10 MB
3635

36+
def _split_disposition_parts(header: str) -> List[str]:
    """Split a header value on semicolons that are outside double quotes.

    A backslash inside a quoted section escapes the next character, so an
    escaped quote does not end the quoted section. If the quotes turn out to
    be unbalanced, the header is split on every semicolon instead, which is
    the best-effort behaviour for malformed input.
    """
    parts: List[str] = []
    current: List[str] = []
    in_quotes = False
    escaped = False

    for char in header:
        if escaped:
            # Character following a backslash inside quotes: keep verbatim.
            current.append(char)
            escaped = False
        elif in_quotes and char == "\\":
            current.append(char)
            escaped = True
        elif char == '"':
            in_quotes = not in_quotes
            current.append(char)
        elif char == ";" and not in_quotes:
            parts.append("".join(current))
            current = []
        else:
            current.append(char)

    if in_quotes:
        # Unbalanced quotes: do not let one stray quote swallow the rest of
        # the header; fall back to a plain split.
        return header.split(";")

    parts.append("".join(current))
    return parts


def extract_disposition_params(content_disposition: str) -> Dict[str, str]:
    """Parse the ``key=value`` parameters of a Content-Disposition header.

    Args:
        content_disposition: Raw header value, for example
            ``'attachment; filename="picture.jpg"'``.

    Returns:
        A dict mapping lower-cased parameter names to their values with
        surrounding whitespace and double quotes removed. Segments without
        an ``=`` (such as the leading ``attachment``) are ignored. When a
        parameter name is repeated, the last occurrence wins.
    """
    params: Dict[str, str] = {}

    for part in _split_disposition_parts(content_disposition):
        # Only the first "=" separates name from value; later ones belong
        # to the value itself.
        key, separator, value = part.strip().partition("=")
        if not separator:
            continue
        params[key.strip().lower()] = value.strip().strip('"')

    return params
47+
3748

3849
def download_files_from_urls(job_id: str, urls: Union[str, List[str]]) -> List[str]:
3950
"""
@@ -55,8 +66,7 @@ def download_file(url: str, path_to_save: str) -> str:
5566
content_disposition = response.headers.get("Content-Disposition")
5667
file_extension = ""
5768
if content_disposition:
58-
msg = message_from_string(f"Content-Disposition: {content_disposition}")
59-
params = dict(msg.items())
69+
params = extract_disposition_params(content_disposition)
6070
file_extension = os.path.splitext(params.get("filename", ""))[1]
6171

6272
# If no extension could be determined from 'Content-Disposition', get it from the URL
@@ -113,15 +123,15 @@ def file(file_url: str) -> dict:
113123

114124
download_response = SyncClientSession().get(file_url, headers=HEADERS, timeout=30)
115125

116-
original_file_name = []
117-
if "Content-Disposition" in download_response.headers.keys():
118-
original_file_name = re.findall(
119-
"filename=(.+)", download_response.headers["Content-Disposition"]
120-
)
126+
content_disposition = download_response.headers.get("Content-Disposition")
121127

122-
if len(original_file_name) > 0:
123-
original_file_name = original_file_name[0]
124-
else:
128+
original_file_name = ""
129+
if content_disposition:
130+
params = extract_disposition_params(content_disposition)
131+
132+
original_file_name = params.get("filename", "")
133+
134+
if not original_file_name:
125135
download_path = urlparse(file_url).path
126136
original_file_name = os.path.basename(download_path)
127137

tests/test_serverless/test_utils/test_download.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
URL_LIST = [
1818
"https://example.com/picture.jpg",
1919
"https://example.com/picture.jpg?X-Amz-Signature=123",
20+
"https://example.com/file_without_extension"
2021
]
2122

2223
JOB_ID = "job_123"
@@ -86,22 +87,27 @@ def test_download_files_from_urls(self, mock_open_file, mock_get, mock_makedirs)
8687
"""
8788
Tests download_files_from_urls
8889
"""
90+
urls = [
91+
"https://example.com/picture.jpg",
92+
"https://example.com/file_without_extension"
93+
]
8994
downloaded_files = download_files_from_urls(
9095
JOB_ID,
91-
[
92-
"https://example.com/picture.jpg",
93-
],
96+
urls,
9497
)
9598

96-
self.assertEqual(len(downloaded_files), 1)
99+
self.assertEqual(len(downloaded_files), len(urls))
97100

98-
# Check that the url was called with SyncClientSession.get
99-
self.assertIn("https://example.com/picture.jpg", mock_get.call_args_list[0][0])
101+
for index, url in enumerate(urls):
102+
# Check that the url was called with SyncClientSession.get
103+
self.assertIn(url, mock_get.call_args_list[index][0])
100104

101-
# Check that the file has the correct extension
102-
self.assertTrue(downloaded_files[0].endswith(".jpg"))
105+
# Check that the file has the correct extension
106+
self.assertTrue(downloaded_files[index].endswith(".jpg"))
107+
108+
mock_open_file.assert_any_call(downloaded_files[index], "wb")
109+
103110

104-
mock_open_file.assert_called_once_with(downloaded_files[0], "wb")
105111
mock_makedirs.assert_called_once_with(
106112
os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True
107113
)

0 commit comments

Comments
 (0)