diff --git a/tpch/__init__.py b/tpch/__init__.py index e69de29bb2..abc43f5e05 100644 --- a/tpch/__init__.py +++ b/tpch/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from tpch import _typing + +__all__ = ["_typing"] diff --git a/tpch/_typing.py b/tpch/_typing.py new file mode 100644 index 0000000000..4d7c4ac3ad --- /dev/null +++ b/tpch/_typing.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from typing import Any +from typing import TypeVar + +from narwhals import DataFrame +from narwhals import LazyFrame + +FrameT = TypeVar("FrameT", DataFrame[Any], LazyFrame[Any]) diff --git a/tpch/execute.py b/tpch/execute.py index 25e368d305..0d1cdf4358 100644 --- a/tpch/execute.py +++ b/tpch/execute.py @@ -16,7 +16,7 @@ import narwhals as nw pd.options.mode.copy_on_write = True -pd.options.future.infer_string = True +pd.options.future.infer_string = True # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] pl.Config.set_fmt_float("full") DATA_DIR = Path("data") diff --git a/tpch/queries/q1.py b/tpch/queries/q1.py index d0acd5a85f..7ca7e305cb 100644 --- a/tpch/queries/q1.py +++ b/tpch/queries/q1.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(lineitem: FrameT) -> FrameT: diff --git a/tpch/queries/q10.py b/tpch/queries/q10.py index 5eb2a1eac9..c2b220ed82 100644 --- a/tpch/queries/q10.py +++ b/tpch/queries/q10.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( diff --git a/tpch/queries/q11.py b/tpch/queries/q11.py index 9d28bc0904..f985a7f99f 100644 --- a/tpch/queries/q11.py +++ b/tpch/queries/q11.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( @@ -30,7 +30,7 @@ def query( q1.with_columns((nw.col("ps_supplycost") * nw.col("ps_availqty")).alias("value")) .group_by("ps_partkey") .agg(nw.sum("value")) - .join(q2, how="cross") + .join(q2, how="cross") # pyright: ignore[reportArgumentType] .filter(nw.col("value") > nw.col("tmp")) .select("ps_partkey", "value") .sort("value", descending=True) diff --git a/tpch/queries/q12.py b/tpch/queries/q12.py index bb9b95604e..72632715e7 100644 --- a/tpch/queries/q12.py +++ b/tpch/queries/q12.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(line_item_ds: FrameT, orders_ds: FrameT) -> FrameT: diff --git a/tpch/queries/q13.py b/tpch/queries/q13.py index fbcb8e44fe..14d65d2a0e 100644 --- a/tpch/queries/q13.py +++ b/tpch/queries/q13.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: @@ -14,7 +14,7 @@ def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: orders = orders_ds.filter(~nw.col("o_comment").str.contains(f"{var1}.*{var2}")) return ( - customer_ds.join(orders, left_on="c_custkey", right_on="o_custkey", how="left") + customer_ds.join(orders, left_on="c_custkey", right_on="o_custkey", how="left") # pyright: ignore[reportArgumentType] .group_by("c_custkey") .agg(nw.col("o_orderkey").count().alias("c_count")) .group_by("c_count") diff --git a/tpch/queries/q14.py b/tpch/queries/q14.py index 24bf1f6f89..59dd222d60 100644 --- a/tpch/queries/q14.py +++ b/tpch/queries/q14.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(line_item_ds: FrameT, part_ds: FrameT) -> FrameT: diff --git a/tpch/queries/q15.py b/tpch/queries/q15.py index 7bf923e639..9fed9e7eb6 100644 --- a/tpch/queries/q15.py +++ b/tpch/queries/q15.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( @@ -29,7 +29,7 @@ def query( ) return ( - supplier_ds.join(revenue, left_on="s_suppkey", right_on="supplier_no") + supplier_ds.join(revenue, left_on="s_suppkey", right_on="supplier_no") # pyright: ignore[reportArgumentType] .filter(nw.col("total_revenue") == nw.col("total_revenue").max()) .with_columns(nw.col("total_revenue").round(2)) .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") diff --git a/tpch/queries/q16.py b/tpch/queries/q16.py index d574a8ecc0..debb42351d 100644 --- a/tpch/queries/q16.py +++ b/tpch/queries/q16.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(part_ds: FrameT, partsupp_ds: FrameT, supplier_ds: FrameT) -> FrameT: @@ -20,7 +20,7 @@ def query(part_ds: FrameT, partsupp_ds: FrameT, supplier_ds: FrameT) -> FrameT: .filter(nw.col("p_brand") != var1) .filter(~nw.col("p_type").str.contains("MEDIUM POLISHED*")) .filter(nw.col("p_size").is_in([49, 14, 23, 45, 19, 3, 36, 9])) - .join(supplier, left_on="ps_suppkey", right_on="s_suppkey", how="left") + .join(supplier, left_on="ps_suppkey", right_on="s_suppkey", how="left") # pyright: ignore[reportArgumentType] .filter(nw.col("ps_suppkey_right").is_null()) .group_by("p_brand", "p_type", "p_size") .agg(nw.col("ps_suppkey").n_unique().alias("supplier_cnt")) diff --git a/tpch/queries/q17.py b/tpch/queries/q17.py index 80796a3475..f373d8a96b 100644 --- a/tpch/queries/q17.py +++ b/tpch/queries/q17.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(lineitem_ds: FrameT, part_ds: FrameT) -> FrameT: @@ -23,7 +23,7 @@ def query(lineitem_ds: FrameT, part_ds: FrameT) -> FrameT: .group_by("p_partkey") .agg(nw.col("l_quantity_times_point_2").mean().alias("avg_quantity")) .select(nw.col("p_partkey").alias("key"), nw.col("avg_quantity")) - .join(query1, left_on="key", right_on="p_partkey") + .join(query1, left_on="key", right_on="p_partkey") # pyright: ignore[reportArgumentType] .filter(nw.col("l_quantity") < nw.col("avg_quantity")) .select((nw.col("l_extendedprice").sum() / 7.0).round(2).alias("avg_yearly")) ) diff --git a/tpch/queries/q18.py b/tpch/queries/q18.py index cc3aa6d6f0..42bf12adca 100644 --- a/tpch/queries/q18.py +++ b/tpch/queries/q18.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(customer_ds: FrameT, lineitem_ds: FrameT, orders_ds: FrameT) -> FrameT: @@ -18,7 +18,7 @@ def query(customer_ds: FrameT, lineitem_ds: FrameT, orders_ds: FrameT) -> FrameT ) return ( - orders_ds.join(query1, left_on="o_orderkey", right_on="l_orderkey", how="semi") + orders_ds.join(query1, left_on="o_orderkey", right_on="l_orderkey", how="semi") # pyright: ignore[reportArgumentType] .join(lineitem_ds, left_on="o_orderkey", right_on="l_orderkey") .join(customer_ds, left_on="o_custkey", right_on="c_custkey") .group_by("c_name", "o_custkey", "o_orderkey", "o_orderdate", "o_totalprice") diff --git a/tpch/queries/q19.py b/tpch/queries/q19.py index aeca219197..09aaeac553 100644 --- a/tpch/queries/q19.py +++ b/tpch/queries/q19.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(lineitem_ds: FrameT, part_ds: FrameT) -> FrameT: diff --git a/tpch/queries/q2.py b/tpch/queries/q2.py index 0c5601afe0..fcf26f05ab 100644 --- a/tpch/queries/q2.py +++ b/tpch/queries/q2.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( @@ -46,7 +46,7 @@ def query( result_q2.group_by("p_partkey") .agg(nw.col("ps_supplycost").min().alias("ps_supplycost")) .join( - result_q2, + result_q2, # pyright: ignore[reportArgumentType] left_on=["p_partkey", "ps_supplycost"], right_on=["p_partkey", "ps_supplycost"], ) diff --git a/tpch/queries/q20.py b/tpch/queries/q20.py index b2319e436d..e61cb333c1 100644 --- a/tpch/queries/q20.py +++ b/tpch/queries/q20.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( @@ -28,7 +28,7 @@ def query( .with_columns(sum_quantity=nw.col("sum_quantity") * 0.5) ) query2 = nation_ds.filter(nw.col("n_name") == var3) - query3 = supplier_ds.join(query2, left_on="s_nationkey", right_on="n_nationkey") + query3 = supplier_ds.join(query2, left_on="s_nationkey", right_on="n_nationkey") # pyright: ignore[reportArgumentType] return ( part_ds.filter(nw.col("p_name").str.starts_with(var4)) @@ -36,14 +36,14 @@ def query( .unique("p_partkey") .join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") .join( - query1, + query1, # pyright: ignore[reportArgumentType] left_on=["ps_suppkey", "p_partkey"], right_on=["l_suppkey", "l_partkey"], ) .filter(nw.col("ps_availqty") > nw.col("sum_quantity")) .select("ps_suppkey") .unique("ps_suppkey") - .join(query3, left_on="ps_suppkey", right_on="s_suppkey") + .join(query3, left_on="ps_suppkey", right_on="s_suppkey") # pyright: ignore[reportArgumentType] .select("s_name", "s_address") .sort("s_name") ) diff --git a/tpch/queries/q21.py b/tpch/queries/q21.py index 0babac08f3..99ea981dde 100644 --- a/tpch/queries/q21.py +++ b/tpch/queries/q21.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( @@ -21,7 +21,7 @@ def query( .agg(nw.len().alias("n_supp_by_order")) .filter(nw.col("n_supp_by_order") > 1) .join( - lineitem.filter(nw.col("l_receiptdate") > nw.col("l_commitdate")), + lineitem.filter(nw.col("l_receiptdate") > nw.col("l_commitdate")), # pyright: ignore[reportArgumentType] left_on="l_orderkey", right_on="l_orderkey", ) @@ -31,7 +31,7 @@ def query( q1.group_by("l_orderkey") .agg(nw.len().alias("n_supp_by_order")) .join( - q1, + q1, # pyright: ignore[reportArgumentType] left_on="l_orderkey", right_on="l_orderkey", ) diff --git a/tpch/queries/q22.py b/tpch/queries/q22.py index d1c670c788..f0b8ebe698 100644 --- a/tpch/queries/q22.py +++ b/tpch/queries/q22.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: @@ -26,9 +26,9 @@ def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: ) return ( - q1.join(q3, left_on="c_custkey", right_on="c_custkey", how="left") + q1.join(q3, left_on="c_custkey", right_on="c_custkey", how="left") # pyright: ignore[reportArgumentType] .filter(nw.col("o_custkey").is_null()) - .join(q2, how="cross") + .join(q2, how="cross") # pyright: ignore[reportArgumentType] .filter(nw.col("c_acctbal") > nw.col("avg_acctbal")) .group_by("cntrycode") .agg( diff --git a/tpch/queries/q3.py b/tpch/queries/q3.py index 6fb32b9c8f..3145a0b79c 100644 --- a/tpch/queries/q3.py +++ b/tpch/queries/q3.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( diff --git a/tpch/queries/q4.py b/tpch/queries/q4.py index 981f39690c..918d991740 100644 --- a/tpch/queries/q4.py +++ b/tpch/queries/q4.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(line_item_ds: FrameT, orders_ds: FrameT) -> FrameT: diff --git a/tpch/queries/q5.py b/tpch/queries/q5.py index 933b972f36..57a1cee58d 100644 --- a/tpch/queries/q5.py +++ b/tpch/queries/q5.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( diff --git a/tpch/queries/q6.py b/tpch/queries/q6.py index c5502219ff..50d20706bf 100644 --- a/tpch/queries/q6.py +++ b/tpch/queries/q6.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query(line_item_ds: FrameT) -> FrameT: diff --git a/tpch/queries/q7.py b/tpch/queries/q7.py index 27ebf475be..abf3f19b0a 100644 --- a/tpch/queries/q7.py +++ b/tpch/queries/q7.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( @@ -23,22 +23,22 @@ def query( var_2 = datetime(1996, 12, 31) df1 = ( - customer_ds.join(n1, left_on="c_nationkey", right_on="n_nationkey") + customer_ds.join(n1, left_on="c_nationkey", right_on="n_nationkey") # pyright: ignore[reportArgumentType] .join(orders_ds, left_on="c_custkey", right_on="o_custkey") .rename({"n_name": "cust_nation"}) .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") - .join(n2, left_on="s_nationkey", right_on="n_nationkey") + .join(n2, left_on="s_nationkey", right_on="n_nationkey") # pyright: ignore[reportArgumentType] .rename({"n_name": "supp_nation"}) ) df2 = ( - customer_ds.join(n2, left_on="c_nationkey", right_on="n_nationkey") + customer_ds.join(n2, left_on="c_nationkey", right_on="n_nationkey") # pyright: ignore[reportArgumentType] .join(orders_ds, left_on="c_custkey", right_on="o_custkey") .rename({"n_name": "cust_nation"}) .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") - .join(n1, left_on="s_nationkey", right_on="n_nationkey") + .join(n1, left_on="s_nationkey", right_on="n_nationkey") # pyright: ignore[reportArgumentType] .rename({"n_name": "supp_nation"}) ) diff --git a/tpch/queries/q8.py b/tpch/queries/q8.py index 8817cc9792..de082b3323 100644 --- a/tpch/queries/q8.py +++ b/tpch/queries/q8.py @@ -6,7 +6,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( @@ -32,10 +32,10 @@ def query( .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") .join(orders_ds, left_on="l_orderkey", right_on="o_orderkey") .join(customer_ds, left_on="o_custkey", right_on="c_custkey") - .join(n1, left_on="c_nationkey", right_on="n_nationkey") + .join(n1, left_on="c_nationkey", right_on="n_nationkey") # pyright: ignore[reportArgumentType] .join(region_ds, left_on="n_regionkey", right_on="r_regionkey") .filter(nw.col("r_name") == region) - .join(n2, left_on="s_nationkey", right_on="n_nationkey") + .join(n2, left_on="s_nationkey", right_on="n_nationkey") # pyright: ignore[reportArgumentType] .filter(nw.col("o_orderdate").is_between(date1, date2)) .filter(nw.col("p_type") == type) .select( diff --git a/tpch/queries/q9.py b/tpch/queries/q9.py index 73c1305187..d4b692776e 100644 --- a/tpch/queries/q9.py +++ b/tpch/queries/q9.py @@ -5,7 +5,7 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import FrameT + from tpch._typing import FrameT def query( diff --git a/utils/bump_version.py b/utils/bump_version.py index ac8c6973eb..fe4079d607 100644 --- a/utils/bump_version.py +++ b/utils/bump_version.py @@ -44,7 +44,7 @@ with open("pyproject.toml", encoding="utf-8") as f: content = f.read() -old_version = re.search(r'version = "(.*)"', content).group(1) +old_version = re.search(r'version = "(.*)"', content).group(1) # pyright: ignore[reportOptionalMemberAccess] version = old_version.split(".") if how == "patch": version = ".".join(version[:-1] + [str(int(version[-1]) + 1)]) diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index 7a699a2a87..5954abfd6b 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -1,14 +1,39 @@ from __future__ import annotations +import inspect +import string import sys +from inspect import isfunction from pathlib import Path +from typing import Any +from typing import Iterator import polars as pl import narwhals as nw -from narwhals._expression_parsing import ExprMetadata from narwhals.utils import remove_prefix +LOWERCASE = tuple(string.ascii_lowercase) + +if sys.version_info >= (3, 13): + + def _is_public_method_or_property(obj: Any) -> bool: + return (isfunction(obj) or isinstance(obj, property)) and obj.__name__.startswith( + LOWERCASE + ) +else: + + def _is_public_method_or_property(obj: Any) -> bool: + return (isfunction(obj) and obj.__name__.startswith(LOWERCASE)) or ( + isinstance(obj, property) and obj.fget.__name__.startswith(LOWERCASE) + ) + + +def iter_api_reference_names(tp: type[Any]) -> Iterator[str]: + for name, _ in inspect.getmembers(tp, _is_public_method_or_property): + yield name + + ret = 0 NAMESPACES = {"dt", "str", "cat", "name", "list", "struct"} @@ -51,7 +76,7 @@ # Top level functions top_level_functions = [ - i for i in nw.__dir__() if not i[0].isupper() and i[0] != "_" and i not in files + i for i in dir(nw) if not i[0].isupper() and i[0] != "_" and i not in files ] with open("docs/api-reference/narwhals.md") as fd: content = fd.read() @@ -70,11 +95,7 @@ ret = 1 # DataFrame methods -dataframe_methods = [ - i - for i in nw.from_native(pl.DataFrame()).__dir__() - if not i[0].isupper() and i[0] != "_" -] +dataframe_methods = list(iter_api_reference_names(nw.DataFrame)) with open("docs/api-reference/dataframe.md") as fd: content = fd.read() documented = [ @@ -92,11 +113,7 @@ ret = 1 # LazyFrame methods -lazyframe_methods = [ - i - for i in nw.from_native(pl.LazyFrame()).__dir__() - if not i[0].isupper() and i[0] != "_" -] +lazyframe_methods = list(iter_api_reference_names(nw.LazyFrame)) with open("docs/api-reference/lazyframe.md") as fd: content = fd.read() documented = [ @@ -114,11 +131,7 @@ ret = 1 # Series methods -series_methods = [ - i - for i in nw.from_native(pl.Series(), series_only=True).__dir__() - if not i[0].isupper() and i[0] != "_" -] +series_methods = list(iter_api_reference_names(nw.Series)) with open("docs/api-reference/series.md") as fd: content = fd.read() documented = [ @@ -137,11 +150,9 @@ # Series.{cat, dt, list, str} methods for namespace in NAMESPACES.difference({"name"}): - series_methods = [ + series_ns_methods = [ i - for i in getattr( - nw.from_native(pl.Series(), series_only=True), namespace - ).__dir__() + for i in dir(getattr(nw.from_native(pl.Series(), series_only=True), namespace)) if not i[0].isupper() and i[0] != "_" ] with open(f"docs/api-reference/series_{namespace}.md") as fd: @@ -151,21 +162,17 @@ for i in content.splitlines() if i.startswith(" - ") and not i.startswith(" - _") ] - if missing := set(series_methods).difference(documented): + if missing := set(series_ns_methods).difference(documented): print(f"Series.{namespace}: not documented") # noqa: T201 print(missing) # noqa: T201 ret = 1 - if extra := set(documented).difference(series_methods): + if extra := set(documented).difference(series_ns_methods): print(f"Series.{namespace}: outdated") # noqa: T201 print(extra) # noqa: T201 ret = 1 # Expr methods -expr_methods = [ - i - for i in nw.Expr(lambda: 0, ExprMetadata.selector_single()).__dir__() - if not i[0].isupper() and i[0] != "_" -] +expr_methods = list(iter_api_reference_names(nw.Expr)) with open("docs/api-reference/expr.md") as fd: content = fd.read() documented = [ @@ -184,12 +191,9 @@ # Expr.{cat, dt, list, name, str} methods for namespace in NAMESPACES: - expr_methods = [ + expr_ns_methods = [ i - for i in getattr( - nw.Expr(lambda: 0, ExprMetadata.selector_single()), - namespace, - ).__dir__() + for i in dir(getattr(nw.col("a"), namespace)) if not i[0].isupper() and i[0] != "_" ] with open(f"docs/api-reference/expr_{namespace}.md") as fd: @@ -199,19 +203,17 @@ for i in content.splitlines() if i.startswith(" - ") ] - if missing := set(expr_methods).difference(documented): + if missing := set(expr_ns_methods).difference(documented): print(f"Expr.{namespace}: not documented") # noqa: T201 print(missing) # noqa: T201 ret = 1 - if extra := set(documented).difference(expr_methods): + if extra := set(documented).difference(expr_ns_methods): print(f"Expr.{namespace}: outdated") # noqa: T201 print(extra) # noqa: T201 ret = 1 # DTypes -dtypes = [ - i for i in nw.dtypes.__dir__() if i[0].isupper() and not i.isupper() and i[0] != "_" -] +dtypes = [i for i in dir(nw.dtypes) if i[0].isupper() and not i.isupper() and i[0] != "_"] with open("docs/api-reference/dtypes.md") as fd: content = fd.read() documented = [ @@ -229,21 +231,11 @@ ret = 1 # Check Expr vs Series -expr = [ - i - for i in nw.Expr(lambda: 0, ExprMetadata.selector_single()).__dir__() - if not i[0].isupper() and i[0] != "_" -] -series = [ - i - for i in nw.from_native(pl.Series(), series_only=True).__dir__() - if not i[0].isupper() and i[0] != "_" -] -if missing := set(expr).difference(series).difference(EXPR_ONLY_METHODS): +if missing := set(expr_methods).difference(series_methods).difference(EXPR_ONLY_METHODS): print("In Expr but not in Series") # noqa: T201 print(missing) # noqa: T201 ret = 1 -if extra := set(series).difference(expr).difference(SERIES_ONLY_METHODS): +if extra := set(series_methods).difference(expr_methods).difference(SERIES_ONLY_METHODS): print("In Series but not in Expr") # noqa: T201 print(extra) # noqa: T201 ret = 1 @@ -252,17 +244,12 @@ for namespace in NAMESPACES.difference({"name"}): expr_internal = [ i - for i in getattr( - nw.Expr(lambda: 0, ExprMetadata.selector_single()), - namespace, - ).__dir__() + for i in dir(getattr(nw.col("a"), namespace)) if not i[0].isupper() and i[0] != "_" ] series_internal = [ i - for i in getattr( - nw.from_native(pl.Series(), series_only=True), namespace - ).__dir__() + for i in dir(getattr(nw.from_native(pl.Series(), series_only=True), namespace)) if not i[0].isupper() and i[0] != "_" ] if missing := set(expr_internal).difference(series_internal): diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index bd89294be9..80a1b86543 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -10,13 +10,13 @@ from pathlib import Path -def extract_docstring_examples(files: list[Path]) -> list[tuple[Path, str, str]]: +def extract_docstring_examples(files: list[str]) -> list[tuple[Path, str, str]]: """Extract examples from docstrings in Python files.""" examples: list[tuple[Path, str, str]] = [] for file in files: - with open(file, encoding="utf-8") as f: - tree = ast.parse(f.read()) + fp = Path(file) + tree = ast.parse(fp.read_text("utf-8")) for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.ClassDef)): @@ -27,7 +27,7 @@ def extract_docstring_examples(files: list[Path]) -> list[tuple[Path, str, str]] example.source for example in parsed_examples ) if example_code.strip(): - examples.append((file, node.name, example_code)) + examples.append((fp, node.name, example_code)) return examples @@ -41,14 +41,14 @@ def create_temp_files(examples: list[tuple[Path, str, str]]) -> list[tuple[Path, temp_file.write(example) temp_file_path = temp_file.name temp_file.close() - temp_files.append((Path(temp_file_path), f"{file}:{name}")) + temp_files.append((Path(temp_file_path), f"{file!s}:{name}")) return temp_files def run_ruff_on_temp_files(temp_files: list[tuple[Path, str]]) -> list[str]: """Run ruff on all temporary files and collect error messages.""" - temp_file_paths = [str(temp_file[0]) for temp_file in temp_files] + temp_file_paths = [temp_file[0] for temp_file in temp_files] result = subprocess.run( # noqa: S603 [ # noqa: S607