Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore[python]: Enforce strict type equality in mypy check #4058

Merged
merged 1 commit into from
Aug 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ disallow_untyped_calls = true
warn_redundant_casts = true
# warn_return_any = true
no_implicit_reexport = true
# strict_equality = true
strict_equality = true
# TODO: When all flags are enabled, replace by strict = true
enable_error_code = [
"redundant-expr",
Expand Down
5 changes: 3 additions & 2 deletions py-polars/tests/db-benchmark/various.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# may contain many things that seemed to go wrong at scale

import time
from typing import cast

import numpy as np

Expand Down Expand Up @@ -50,8 +51,8 @@
computed = permuted.select(
[pl.col("id").min().alias("min"), pl.col("id").max().alias("max")]
)
assert computed[0, "min"] == minimum
assert computed[0, "max"] == maximum
assert cast(int, computed[0, "min"]) == minimum
assert cast(float, computed[0, "max"]) == maximum


def test_windows_not_cached() -> None:
Expand Down
10 changes: 7 additions & 3 deletions py-polars/tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import zlib
from datetime import date
from pathlib import Path
from typing import cast

import pytest

Expand Down Expand Up @@ -402,8 +403,8 @@ def test_csv_globbing(examples_dir: str) -> None:

df = pl.read_csv(path, columns=["category", "sugars_g"])
assert df.shape == (135, 2)
assert df.row(-1) == ("seafood", 1)
assert df.row(0) == ("vegetables", 2)
assert df.row(-1) == ("seafood", 1) # type: ignore[comparison-overlap]
assert df.row(0) == ("vegetables", 2) # type: ignore[comparison-overlap]

with pytest.raises(ValueError):
_ = pl.read_csv(path, dtypes=[pl.Utf8, pl.Int64, pl.Int64, pl.Int64])
Expand Down Expand Up @@ -509,7 +510,10 @@ def fallback_chrono_parser() -> None:
2021-10-10,2021-10-10
"""
)
assert pl.read_csv(data.encode(), parse_dates=True).null_count().row(0) == (0, 0)
assert cast(
tuple[int, int],
pl.read_csv(data.encode(), parse_dates=True).null_count().row(0),
) == (0, 0)


def test_csv_string_escaping() -> None:
Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def test_init_dict() -> None:

# List of empty list/tuple
df = pl.DataFrame({"a": [[]], "b": [()]})
assert df.schema == {"a": pl.List(pl.Float64), "b": pl.List(pl.Float64)}
expected = {"a": pl.List(pl.Float64), "b": pl.List(pl.Float64)}
assert df.schema == expected # type: ignore[comparison-overlap]
assert df.rows() == [([], [])]

# Mixed dtypes
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_dtype_init_equivalence() -> None:
if inspect.isclass(dtype) and issubclass(dtype, datatypes.DataType)
}
for dtype in all_datatypes:
assert dtype == dtype()
assert dtype == dtype() # type: ignore[comparison-overlap]


def test_dtype_temporal_units() -> None:
Expand Down
10 changes: 5 additions & 5 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

import io
import typing
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, no_type_check

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -154,7 +153,7 @@ def test_datetime_consistency() -> None:
pl.lit(dt).cast(pl.Datetime("ns")).alias("dt_ns"),
]
)
assert ddf.schema == {
assert ddf.schema == { # type: ignore[comparison-overlap]
"date": pl.Datetime("us"),
"dt": pl.Datetime("us"),
"dt_ms": pl.Datetime("ms"),
Expand Down Expand Up @@ -886,7 +885,7 @@ def test_agg_logical() -> None:
assert s.min() == dates[0]


@typing.no_type_check
@no_type_check
def test_from_time_arrow() -> None:
times = pa.array([10, 20, 30], type=pa.time32("s"))
times_table = pa.table([times], names=["times"])
Expand Down Expand Up @@ -1027,7 +1026,8 @@ def test_datetime_instance_selection() -> None:
],
)
for tu in DTYPE_TEMPORAL_UNITS:
assert df.select(pl.col([pl.Datetime(tu)])).dtypes == [pl.Datetime(tu)]
res = df.select(pl.col([pl.Datetime(tu)])).dtypes
assert res == [pl.Datetime(tu)] # type: ignore[comparison-overlap]
assert len(df.filter(pl.col(tu) == test_data[tu][0])) == 1


Expand Down
26 changes: 13 additions & 13 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_selection() -> None:

# select columns by mask
assert df[:2, :1].shape == (2, 1)
assert df[:2, "a"].shape == (2, 1)
assert df[:2, "a"].shape == (2, 1) # type: ignore[comparison-overlap]

# column selection by string(s) in first dimension
assert df["a"].to_list() == [1, 2, 3]
Expand All @@ -117,11 +117,11 @@ def test_selection() -> None:
assert df[[1, 2], [1, 2]].frame_equal(
pl.DataFrame({"b": [2.0, 3.0], "c": ["b", "c"]})
)
assert df[1, 2] == "b"
assert df[1, 1] == 2.0
assert df[2, 0] == 3
assert typing.cast(str, df[1, 2]) == "b"
assert typing.cast(float, df[1, 1]) == 2.0
assert typing.cast(int, df[2, 0]) == 3

assert df[[0, 1], "b"].shape == (2, 1)
assert df[[0, 1], "b"].shape == (2, 1) # type: ignore[comparison-overlap]
assert df[[2], ["a", "b"]].shape == (1, 2)
assert df.to_series(0).name == "a"
assert (df["a"] == df["a"]).sum() == 3
Expand All @@ -132,10 +132,10 @@ def test_selection() -> None:
assert df[1, [2]].frame_equal(expect)
expect = pl.DataFrame({"b": [1.0, 3.0]})
assert df[[0, 2], [1]].frame_equal(expect)
assert df[0, "c"] == "a"
assert df[1, "c"] == "b"
assert df[2, "c"] == "c"
assert df[0, "a"] == 1
assert typing.cast(str, df[0, "c"]) == "a"
assert typing.cast(str, df[1, "c"]) == "b"
assert typing.cast(str, df[2, "c"]) == "c"
assert typing.cast(int, df[0, "a"]) == 1

# more slicing
expect = pl.DataFrame({"a": [3, 2, 1], "b": [3.0, 2.0, 1.0], "c": ["c", "b", "a"]})
Expand Down Expand Up @@ -766,9 +766,9 @@ def test_df_fold() -> None:

def test_row_tuple() -> None:
df = pl.DataFrame({"a": ["foo", "bar", "2"], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]})
assert df.row(0) == ("foo", 1, 1.0)
assert df.row(1) == ("bar", 2, 2.0)
assert df.row(-1) == ("2", 3, 3.0)
assert df.row(0) == ("foo", 1, 1.0) # type: ignore[comparison-overlap]
assert df.row(1) == ("bar", 2, 2.0) # type: ignore[comparison-overlap]
assert df.row(-1) == ("2", 3, 3.0) # type: ignore[comparison-overlap]


def test_df_apply() -> None:
Expand Down Expand Up @@ -1058,7 +1058,7 @@ def dot_product() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 2, 2]})

assert df["a"].dot(df["b"]) == 20
assert df.select([pl.col("a").dot("b")])[0, "a"] == 20
assert typing.cast(int, df.select([pl.col("a").dot("b")])[0, "a"]) == 20


def test_hash_rows() -> None:
Expand Down
10 changes: 7 additions & 3 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import cast

import polars as pl
from polars.testing import assert_series_equal, verify_series_and_expr_api

Expand Down Expand Up @@ -82,7 +84,7 @@ def test_count_expr() -> None:

out = df.select(pl.count())
assert out.shape == (1, 1)
assert out[0, 0] == 5
assert cast(int, out[0, 0]) == 5

out = df.groupby("b", maintain_order=True).agg(pl.count())
assert out["b"].to_list() == ["a", "b"]
Expand Down Expand Up @@ -274,9 +276,11 @@ def test_regex_in_filter() -> None:
}
)

assert df.filter(
res = df.filter(
pl.fold(acc=False, f=lambda acc, s: acc | s, exprs=(pl.col("^nrs|flt*$") < 3))
).row(0) == (1, "foo", 1.0)
).row(0)
expected = (1, "foo", 1.0)
assert res == expected # type: ignore[comparison-overlap]


def test_arr_contains() -> None:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def test_from_pandas_ns_resolution() -> None:
[pd.Timestamp(year=2021, month=1, day=1, hour=1, second=1, nanosecond=1)],
columns=["date"],
)
assert pl.from_pandas(df)[0, 0] == datetime(2021, 1, 1, 1, 0, 1)
assert cast(datetime, pl.from_pandas(df)[0, 0]) == datetime(2021, 1, 1, 1, 0, 1)


@no_type_check
Expand Down
20 changes: 11 additions & 9 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def test_is_finite_is_infinite() -> None:

def test_len() -> None:
df = pl.DataFrame({"nrs": [1, 2, 3]})
assert df.select(col("nrs").len())[0, 0] == 3
assert cast(int, df.select(col("nrs").len())[0, 0]) == 3


def test_cum_agg() -> None:
Expand Down Expand Up @@ -505,7 +505,7 @@ def test_round() -> None:

def test_dot() -> None:
df = pl.DataFrame({"a": [1.8, 1.2, 3.0], "b": [3.2, 1, 2]})
assert df.select(pl.col("a").dot(pl.col("b")))[0, 0] == 12.96
assert cast(float, df.select(pl.col("a").dot(pl.col("b")))[0, 0]) == 12.96


def test_sort() -> None:
Expand Down Expand Up @@ -696,8 +696,8 @@ def test_rolling(fruits_cars: pl.DataFrame) -> None:
]
)

assert out_single_val_variance[0, "std"] == 0.0
assert out_single_val_variance[0, "var"] == 0.0
assert cast(float, out_single_val_variance[0, "std"]) == 0.0
assert cast(float, out_single_val_variance[0, "var"]) == 0.0


def test_rolling_apply() -> None:
Expand Down Expand Up @@ -993,16 +993,16 @@ def test_join_suffix() -> None:
def test_str_concat() -> None:
df = pl.DataFrame({"foo": [1, None, 2]})
df = df.select(pl.col("foo").str.concat("-"))
assert df[0, 0] == "1-null-2"
assert cast(str, df[0, 0]) == "1-null-2"


@pytest.mark.parametrize("no_optimization", [False, True])
def test_collect_all(df: pl.DataFrame, no_optimization: bool) -> None:
lf1 = df.lazy().select(pl.col("int").sum())
lf2 = df.lazy().select((pl.col("floats") * 2).sum())
out = pl.collect_all([lf1, lf2], no_optimization=no_optimization)
assert out[0][0, 0] == 6
assert out[1][0, 0] == 12.0
assert cast(int, out[0][0, 0]) == 6
assert cast(float, out[1][0, 0]) == 12.0


def test_spearman_corr() -> None:
Expand Down Expand Up @@ -1058,8 +1058,10 @@ def test_pearson_corr() -> None:


def test_cov(fruits_cars: pl.DataFrame) -> None:
assert fruits_cars.select(pl.cov("A", "B"))[0, 0] == -2.5
assert fruits_cars.select(pl.cov(pl.col("A"), pl.col("B")))[0, 0] == -2.5
assert cast(float, fruits_cars.select(pl.cov("A", "B"))[0, 0]) == -2.5
assert (
cast(float, fruits_cars.select(pl.cov(pl.col("A"), pl.col("B")))[0, 0]) == -2.5
)


def test_std(fruits_cars: pl.DataFrame) -> None:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_dtype() -> None:
("dtm", pl.List(pl.Datetime)),
],
)
assert df.schema == {
assert df.schema == { # type: ignore[comparison-overlap]
"i": pl.List(pl.Int8),
"tm": pl.List(pl.Time),
"dt": pl.List(pl.Date),
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_when_then_edge_cases_3994() -> None:
.groupby(["id"])
.agg(pl.col("type"))
.with_column(
pl.when(pl.col("type").arr.lengths == 0)
pl.when(pl.col("type").arr.lengths() == 0)
.then(pl.lit(None))
.otherwise(pl.col("type"))
.keep_name()
Expand Down