Skip to content

Commit bf234ed

Browse files
authored
Merge pull request #137 from e10v/dev
feat: improve type annotations
2 parents 78aa76d + 43d2e03 commit bf234ed

18 files changed

+309
-223
lines changed

src/_strip_doctest_artifacts/__init__.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,17 @@
33
from __future__ import annotations
44

55
import re
6-
from typing import TYPE_CHECKING
76

87
import markdown
98
import markdown.extensions
109
import markdown.preprocessors
1110

1211

13-
if TYPE_CHECKING:
14-
from typing import Any
15-
16-
1712
class StripDoctestArtifactsPreprocessor(markdown.preprocessors.Preprocessor):
1813
"""A preprocessor that removes doctest artifacts."""
1914

2015
def run(self, lines: list[str]) -> list[str]:
21-
"""Run th preprocessor."""
16+
"""Run the preprocessor."""
2217
return [_strip(line) for line in lines]
2318

2419

@@ -39,6 +34,6 @@ def extendMarkdown(self, md: markdown.Markdown) -> None:
3934
)
4035

4136

42-
def makeExtension(**kwargs: dict[str, Any]) -> StripDoctestArtifactsExtension:
37+
def makeExtension(**kwargs: dict[str, object]) -> StripDoctestArtifactsExtension:
4338
"""A factory function for the extension, required by Python-Markdown."""
4439
return StripDoctestArtifactsExtension(**kwargs)

src/tea_tasting/aggr.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
if TYPE_CHECKING:
1616
from collections.abc import Sequence
17-
from typing import Any
1817

1918
import narwhals.typing # noqa: TC004
2019

@@ -238,7 +237,7 @@ def read_aggregates(
238237
mean_cols: Sequence[str],
239238
var_cols: Sequence[str],
240239
cov_cols: Sequence[tuple[str, str]],
241-
) -> dict[Any, Aggregates]:
240+
) -> dict[object, Aggregates]:
242241
...
243242

244243
@overload
@@ -261,7 +260,7 @@ def read_aggregates(
261260
mean_cols: Sequence[str],
262261
var_cols: Sequence[str],
263262
cov_cols: Sequence[tuple[str, str]],
264-
) -> dict[Any, Aggregates] | Aggregates:
263+
) -> dict[object, Aggregates] | Aggregates:
265264
"""Extract aggregated statistics.
266265
267266
Args:
@@ -346,7 +345,7 @@ def _read_aggr_ibis(
346345
mean_cols: Sequence[str],
347346
var_cols: Sequence[str],
348347
cov_cols: Sequence[tuple[str, str]],
349-
) -> list[dict[str, Any]]:
348+
) -> list[dict[str, int | float]]:
350349
covar_cols = tuple({*var_cols, *itertools.chain(*cov_cols)})
351350
backend = ibis.get_backend(data)
352351
var_op = ibis.expr.operations.Variance
@@ -402,7 +401,7 @@ def _read_aggr_narwhals(
402401
mean_cols: Sequence[str],
403402
var_cols: Sequence[str],
404403
cov_cols: Sequence[tuple[str, str]],
405-
) -> list[dict[str, Any]]:
404+
) -> list[dict[str, int | float]]:
406405
data = nw.from_native(data)
407406
if not isinstance(data, nw.LazyFrame):
408407
data = data.lazy()
@@ -464,15 +463,15 @@ def _demean_nw_col(col: str, group_col: str | None) -> nw.Expr:
464463

465464

466465
def _get_aggregates(
467-
data: dict[str, Any],
466+
data: dict[str, float | int],
468467
*,
469468
has_count: bool,
470469
mean_cols: Sequence[str],
471470
var_cols: Sequence[str],
472471
cov_cols: Sequence[tuple[str, str]],
473472
) -> Aggregates:
474473
return Aggregates(
475-
count_=data[_COUNT] if has_count else None,
474+
count_=data[_COUNT] if has_count else None, # type: ignore
476475
mean_={col: data[_MEAN.format(col)] for col in mean_cols},
477476
var_={col: data[_VAR.format(col)] for col in var_cols},
478477
cov_={cols: data[_COV.format(*cols)] for cols in cov_cols},

src/tea_tasting/config.py

+51-7
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
from __future__ import annotations
55

66
import contextlib
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, overload
88

99
import tea_tasting.utils
1010

1111

1212
if TYPE_CHECKING:
1313
from collections.abc import Iterator, Sequence
14-
from typing import Any, Literal
14+
from typing import Literal
1515

1616

17-
_global_config = {
17+
_global_config: dict[str, object] = {
1818
"alpha": 0.05,
1919
"alternative": "two-sided",
2020
"confidence_level": 0.95,
@@ -27,7 +27,51 @@
2727
}
2828

2929

30-
def get_config(option: str | None = None) -> Any:
30+
@overload
31+
def get_config(option: Literal["alpha"]) -> float:
32+
...
33+
34+
@overload
35+
def get_config(option: Literal["alternative"]) -> str:
36+
...
37+
38+
@overload
39+
def get_config(option: Literal["confidence_level"]) -> float:
40+
...
41+
42+
@overload
43+
def get_config(option: Literal["equal_var"]) -> bool:
44+
...
45+
46+
@overload
47+
def get_config(option: Literal["n_obs"]) -> int | Sequence[int] | None:
48+
...
49+
50+
@overload
51+
def get_config(option: Literal["n_resamples"]) -> str:
52+
...
53+
54+
@overload
55+
def get_config(option: Literal["power"]) -> float:
56+
...
57+
58+
@overload
59+
def get_config(option: Literal["ratio"]) -> float | int:
60+
...
61+
62+
@overload
63+
def get_config(option: Literal["use_t"]) -> bool:
64+
...
65+
66+
@overload
67+
def get_config(option: str) -> object:
68+
...
69+
70+
@overload
71+
def get_config(option: None = None) -> dict[str, object]:
72+
...
73+
74+
def get_config(option: str | None = None) -> object:
3175
"""Retrieve the current settings of the global configuration.
3276
3377
Args:
@@ -62,7 +106,7 @@ def set_config(
62106
power: float | None = None,
63107
ratio: float | int | None = None,
64108
use_t: bool | None = None,
65-
**kwargs: Any,
109+
**kwargs: object,
66110
) -> None:
67111
"""Update the global configuration with specified settings.
68112
@@ -129,8 +173,8 @@ def config_context(
129173
power: float | None = None,
130174
ratio: float | int | None = None,
131175
use_t: bool | None = None,
132-
**kwargs: Any,
133-
) -> Iterator[Any]:
176+
**kwargs: object,
177+
) -> Iterator[object]:
134178
"""A context manager that temporarily modifies the global configuration.
135179
136180
Args:

src/tea_tasting/datasets.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,22 @@
1212

1313

1414
if TYPE_CHECKING:
15-
from typing import Any, Literal
15+
from typing import Literal, TypeAlias
1616

1717
import numpy.typing as npt
1818

1919
try:
20-
from pandas import DataFrame as PandasDataFrame
20+
from pandas import DataFrame as _PandasDataFrame
2121
except ImportError:
22-
from typing import Any as PandasDataFrame
22+
_PandasDataFrame = object
2323

2424
try:
25-
from polars import DataFrame as PolarsDataFrame
25+
from polars import DataFrame as _PolarsDataFrame
2626
except ImportError:
27-
from typing import Any as PolarsDataFrame
27+
_PolarsDataFrame = object
28+
29+
PandasDataFrame: TypeAlias = _PandasDataFrame # type: ignore
30+
PolarsDataFrame: TypeAlias = _PolarsDataFrame # type: ignore
2831

2932

3033
@overload
@@ -610,9 +613,9 @@ def _check_params(
610613

611614

612615
def _avg_by_groups(
613-
values: npt.NDArray[np.number[Any]],
614-
groups: npt.NDArray[np.number[Any]],
615-
) -> npt.NDArray[np.number[Any]]:
616+
values: npt.NDArray[np.number],
617+
groups: npt.NDArray[np.number],
618+
) -> npt.NDArray[np.number]:
616619
return np.concatenate([
617620
np.full(v.shape, v.mean())
618621
for v in np.split(values, np.unique(groups, return_index=True)[1])

src/tea_tasting/experiment.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class ExperimentResult(
3434
"pvalue",
3535
)
3636

37-
def to_dicts(self) -> tuple[dict[str, Any], ...]:
37+
def to_dicts(self) -> tuple[dict[str, object], ...]:
3838
"""Convert the result to a sequence of dictionaries.
3939
4040
Examples:
@@ -77,11 +77,11 @@ def to_dicts(self) -> tuple[dict[str, Any], ...]:
7777
return tuple(
7878
{"metric": k} | (v if isinstance(v, dict) else v._asdict())
7979
for k, v in self.items()
80-
)
80+
) # type: ignore
8181

8282

8383
class ExperimentResults(
84-
UserDict[tuple[Any, Any], ExperimentResult],
84+
UserDict[tuple[object, object], ExperimentResult],
8585
tea_tasting.utils.DictsReprMixin,
8686
):
8787
"""Experiment results for multiple pairs of variants."""
@@ -95,7 +95,7 @@ class ExperimentResults(
9595
"pvalue",
9696
)
9797

98-
def to_dicts(self) -> tuple[dict[str, Any], ...]:
98+
def to_dicts(self) -> tuple[dict[str, object], ...]:
9999
"""Convert the result to a sequence of dictionaries."""
100100
return tuple(
101101
{"variants": str(variants)} | metric_result
@@ -111,7 +111,7 @@ class ExperimentPowerResult(
111111
"""Result of the analysis of power in a experiment."""
112112
default_keys = ("metric", "power", "effect_size", "rel_effect_size", "n_obs")
113113

114-
def to_dicts(self) -> tuple[dict[str, Any], ...]:
114+
def to_dicts(self) -> tuple[dict[str, object], ...]:
115115
"""Convert the result to a sequence of dictionaries."""
116116
dicts = ()
117117
for metric, results in self.items():
@@ -230,7 +230,7 @@ def __init__(
230230
def analyze(
231231
self,
232232
data: narwhals.typing.IntoFrame | ibis.expr.types.Table,
233-
control: Any = None,
233+
control: object = None,
234234
*,
235235
all_variants: Literal[False] = False,
236236
) -> ExperimentResult:
@@ -240,7 +240,7 @@ def analyze(
240240
def analyze(
241241
self,
242242
data: narwhals.typing.IntoFrame | ibis.expr.types.Table,
243-
control: Any = None,
243+
control: object = None,
244244
*,
245245
all_variants: Literal[True] = True,
246246
) -> ExperimentResults:
@@ -249,7 +249,7 @@ def analyze(
249249
def analyze(
250250
self,
251251
data: narwhals.typing.IntoFrame | ibis.expr.types.Table,
252-
control: Any = None,
252+
control: object = None,
253253
*,
254254
all_variants: bool = False,
255255
) -> ExperimentResult | ExperimentResults:
@@ -273,7 +273,7 @@ def analyze(
273273
variants = granular_data.keys()
274274
else:
275275
variants = self._read_variants(data)
276-
variants = sorted(variants)
276+
variants = sorted(variants) # type: ignore
277277

278278
if control is not None:
279279
variant_pairs = tuple(
@@ -318,10 +318,10 @@ def _analyze_metric(
318318
self,
319319
metric: tea_tasting.metrics.MetricBase[Any],
320320
data: narwhals.typing.IntoFrame | ibis.expr.types.Table,
321-
aggregated_data: dict[Any, tea_tasting.aggr.Aggregates] | None,
322-
granular_data: dict[Any, pa.Table] | None,
323-
control: Any,
324-
treatment: Any,
321+
aggregated_data: dict[object, tea_tasting.aggr.Aggregates] | None,
322+
granular_data: dict[object, pa.Table] | None,
323+
control: object,
324+
treatment: object,
325325
) -> tea_tasting.metrics.MetricResult:
326326
if (
327327
isinstance(metric, tea_tasting.metrics.MetricBaseAggregated)
@@ -342,8 +342,8 @@ def _read_data(
342342
self,
343343
data: narwhals.typing.IntoFrame | ibis.expr.types.Table,
344344
) -> tuple[
345-
dict[Any, tea_tasting.aggr.Aggregates] | None,
346-
dict[Any, pa.Table] | None,
345+
dict[object, tea_tasting.aggr.Aggregates] | None,
346+
dict[object, pa.Table] | None,
347347
]:
348348
aggr_cols = tea_tasting.metrics.AggrCols()
349349
gran_cols = set()
@@ -371,7 +371,7 @@ def _read_data(
371371
def _read_variants(
372372
self,
373373
data: narwhals.typing.IntoFrame | ibis.expr.types.Table,
374-
) -> list[Any]:
374+
) -> list[object]:
375375
if isinstance(data, ibis.expr.types.Table):
376376
return (
377377
data.select(self.variant)

0 commit comments

Comments
 (0)