Skip to content

Commit 8b70975

Browse files
committed
refactor: update url join method implementation and function signature
1 parent e8bd322 commit 8b70975

File tree

3 files changed

+18
-17
lines changed

3 files changed

+18
-17
lines changed

python/pydantic_core/_pydantic_core.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,13 +606,13 @@ class Url(SupportsAllComparisons):
606606
An instance of URL
607607
"""
608608

609-
def join(self, path: str, trailing_slash: bool = True) -> Self:
609+
def join(self, path: str, append_trailing_slash: bool = False) -> Self:
610610
"""
611611
Parse a string `path` as an URL, using this URL as the base.
612612
613613
Args:
614614
path: The string (typically a relative URL) to parse and join with the base URL.
615-
trailing_slash: Whether to append a trailing slash at the end of the URL.
615+
append_trailing_slash: Whether to append a trailing slash at the end of the URL.
616616
617617
Returns:
618618
A new `Url` instance

src/url.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -199,23 +199,24 @@ impl PyUrl {
199199
cls.call1((url,))
200200
}
201201

202-
#[pyo3(signature=(path, trailing_slash=true))]
203-
pub fn join(&self, path: &str, trailing_slash: bool) -> PyResult<Self> {
202+
#[pyo3(signature=(path, append_trailing_slash=false))]
203+
pub fn join(&self, path: &str, append_trailing_slash: bool) -> PyResult<Self> {
204204
let mut new_url = self
205205
.lib_url
206206
.join(path)
207207
.map_err(|err| PyValueError::new_err(err.to_string()))?;
208208

209-
if !trailing_slash || new_url.query().is_some() || new_url.fragment().is_some() || new_url.cannot_be_a_base() {
210-
return Ok(PyUrl::new(new_url));
211-
}
212-
213-
new_url
214-
.path_segments_mut()
215-
.map_err(|()| PyValueError::new_err("Url cannot be a base"))?
216-
.pop_if_empty()
217-
.push("");
209+
if append_trailing_slash && !(new_url.query().is_some() || new_url.fragment().is_some()) {
210+
let path_segments_result = new_url.path_segments_mut().map(|mut segments| {
211+
segments.pop_if_empty().push("");
212+
});
218213

214+
if path_segments_result.is_err() {
215+
let mut new_path = new_url.path().to_string();
216+
new_path.push('/');
217+
new_url.set_path(&new_path);
218+
}
219+
}
219220
Ok(PyUrl::new(new_url))
220221
}
221222
}

tests/validators/test_url.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,7 +1324,7 @@ def test_url_build() -> None:
13241324
('http://a/b/c/d/e', '/../../f/g/', 'http://a/f/g/', 'http://a/f/g/'),
13251325
('http://a/b/c/d/e/', '../../f/g', 'http://a/b/c/f/g/', 'http://a/b/c/f/g'),
13261326
('http://a/b/', '../../f/g/', 'http://a/f/g/', 'http://a/f/g/'),
1327-
(SIMPLE_BASE, 'g:h', 'g:h', 'g:h'),
1327+
(SIMPLE_BASE, 'g:h', 'g:h/', 'g:h'),
13281328
(SIMPLE_BASE, 'g', 'http://a/b/c/g/', 'http://a/b/c/g'),
13291329
(SIMPLE_BASE, './g', 'http://a/b/c/g/', 'http://a/b/c/g'),
13301330
(SIMPLE_BASE, 'g/', 'http://a/b/c/g/', 'http://a/b/c/g/'),
@@ -1350,7 +1350,7 @@ def test_url_build() -> None:
13501350
(SIMPLE_BASE + '/', 'foo', SIMPLE_BASE + '/foo/', SIMPLE_BASE + '/foo'),
13511351
(QUERY_BASE, '?y', 'http://a/b/c/d;p?y', 'http://a/b/c/d;p?y'),
13521352
(QUERY_BASE, ';x', 'http://a/b/c/;x/', 'http://a/b/c/;x'),
1353-
(QUERY_BASE, 'g:h', 'g:h', 'g:h'),
1353+
(QUERY_BASE, 'g:h', 'g:h/', 'g:h'),
13541354
(QUERY_BASE, 'g', 'http://a/b/c/g/', 'http://a/b/c/g'),
13551355
(QUERY_BASE, './g', 'http://a/b/c/g/', 'http://a/b/c/g'),
13561356
(QUERY_BASE, 'g/', 'http://a/b/c/g/', 'http://a/b/c/g/'),
@@ -1436,8 +1436,8 @@ def test_url_join(base_url, join_path, expected_with_slash, expected_without_sla
14361436
and the URL specification from https://url.spec.whatwg.org/
14371437
"""
14381438
url = Url(base_url)
1439-
assert str(url.join(join_path, trailing_slash=True)) == expected_with_slash
1440-
assert str(url.join(join_path, trailing_slash=False)) == expected_without_slash
1439+
assert str(url.join(join_path, append_trailing_slash=True)) == expected_with_slash
1440+
assert str(url.join(join_path, append_trailing_slash=False)) == expected_without_slash
14411441

14421442

14431443
def test_url_join_operators() -> None:

0 commit comments

Comments
 (0)