Skip to content

Commit 01e19bd

Browse files
committed
Drafting infer_fieldnames_from_function_return_type
1 parent 21fcd41 commit 01e19bd

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

minet/scrape/classes/function.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,40 @@
1-
from typing import Optional, Callable, Any, cast, Dict
1+
from typing import Union, Optional, Callable, Any, cast, Dict, List
22

33
import inspect
4+
45
from casanova import RowWrapper
56
from bs4 import SoupStrainer
67

8+
from minet.types import get_type_hints, get_origin, get_args
79
from minet.scrape.classes.base import ScraperBase
810
from minet.scrape.soup import WonderfulSoup
911
from minet.scrape.straining import strainer_from_css
1012
from minet.scrape.utils import ensure_soup
1113
from minet.scrape.types import AnyScrapableTarget
1214

1315

16+
def infer_fieldnames_from_function_return_type(fn: Callable) -> Optional[List[str]]:
17+
if not callable(fn):
18+
raise TypeError
19+
20+
return_type = get_type_hints(fn)["return"]
21+
22+
origin = get_origin(return_type)
23+
24+
if origin is Union:
25+
args = get_args(return_type)
26+
27+
# Optionals
28+
if len(args) == 2:
29+
if args[1] is type(None):
30+
return_type = args[0]
31+
32+
if return_type in (str, int, float, bool, type(None)):
33+
return ["value"]
34+
35+
return None
36+
37+
1438
class FunctionScraper(ScraperBase):
1539
fn: Callable[[RowWrapper, WonderfulSoup], Any]
1640
fieldnames = None

test/scraper_test.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# =============================================================================
22
# Minet Scrape Unit Tests
33
# =============================================================================
4+
from typing import Optional
5+
46
import pytest
57
from bs4 import BeautifulSoup, Tag, SoupStrainer
68
from textwrap import dedent
@@ -31,6 +33,7 @@
3133
ScraperValidationMixedConcernError,
3234
ScraperValidationUnknownKeyError,
3335
)
36+
from minet.scrape.classes.function import infer_fieldnames_from_function_return_type
3437

3538
BASIC_HTML = """
3639
<ul>
@@ -160,7 +163,7 @@
160163
"""
161164

162165

163-
class TestDefinitionScraper(object):
166+
class TestDefinitionScraper:
164167
def test_basics(self):
165168
result = scrape({"iterator": "li"}, BASIC_HTML)
166169

@@ -1055,3 +1058,31 @@ def clean(t):
10551058
text = get_display_text(elements)
10561059

10571060
assert text == "L'internationale."
1061+
1062+
1063+
class TestFunctionScraper:
1064+
def test_infer_fieldnames_from_function_return_type(self):
1065+
def basic_string() -> str:
1066+
return "ok"
1067+
1068+
def basic_int() -> int:
1069+
return 4
1070+
1071+
def basic_float() -> float:
1072+
return 4.0
1073+
1074+
def basic_bool() -> bool:
1075+
return True
1076+
1077+
def basic_void() -> None:
1078+
return
1079+
1080+
def basic_optional_scalar() -> Optional[str]:
1081+
return
1082+
1083+
assert infer_fieldnames_from_function_return_type(basic_string) == ["value"]
1084+
assert infer_fieldnames_from_function_return_type(basic_int) == ["value"]
1085+
assert infer_fieldnames_from_function_return_type(basic_float) == ["value"]
1086+
assert infer_fieldnames_from_function_return_type(basic_bool) == ["value"]
1087+
assert infer_fieldnames_from_function_return_type(basic_void) == ["value"]
1088+
assert infer_fieldnames_from_function_return_type(basic_optional_scalar) == ["value"]

0 commit comments

Comments
 (0)