Skip to content

Commit cac44fc

Browse files
feat(python): Enable additional cases of loading numpy.float16 values (as Float32)
1 parent d05b942 commit cac44fc

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

py-polars/polars/datatypes/convert.py

+1
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def NUMPY_KIND_AND_ITEMSIZE_TO_DTYPE(self) -> dict[tuple[str, int], PolarsDataTy
206206
# (np.dtype().kind, np.dtype().itemsize)
207207
("M", 8): Datetime,
208208
("b", 1): Boolean,
209+
("f", 2): Float32,
209210
("f", 4): Float32,
210211
("f", 8): Float64,
211212
("i", 1): Int8,

py-polars/tests/unit/interop/numpy/test_numpy.py

+23
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
("uint16", [1, 3, 2], pl.UInt16, np.uint16),
1818
("uint32", [1, 3, 2], pl.UInt32, np.uint32),
1919
("uint64", [1, 3, 2], pl.UInt64, np.uint64),
20+
("float16", [-123.0, 0.0, 456.0], pl.Float32, np.float16),
2021
("float32", [21.7, 21.8, 21], pl.Float32, np.float32),
2122
("float64", [21.7, 21.8, 21], pl.Float64, np.float64),
2223
("bool", [True, False, False], pl.Boolean, np.bool_),
@@ -77,3 +78,25 @@ def test_numpy_disambiguation() -> None:
7778

7879
def test_respect_dtype_with_series_from_numpy() -> None:
7980
assert pl.Series("foo", np.array([1, 2, 3]), dtype=pl.UInt32).dtype == pl.UInt32
81+
82+
83+
@pytest.mark.parametrize(
84+
("np_dtype_cls", "expected_pl_dtype"),
85+
[
86+
(np.int8, pl.Int8),
87+
(np.int16, pl.Int16),
88+
(np.int32, pl.Int32),
89+
(np.int64, pl.Int64),
90+
(np.uint8, pl.UInt8),
91+
(np.uint16, pl.UInt16),
92+
(np.uint32, pl.UInt32),
93+
(np.uint64, pl.UInt64),
94+
(np.float16, pl.Float32), # << note: we don't currently have a native f16
95+
(np.float32, pl.Float32),
96+
(np.float64, pl.Float64),
97+
],
98+
)
99+
def test_init_from_numpy_values(np_dtype_cls: Any, expected_pl_dtype: Any) -> None:
100+
# test init from raw numpy values (vs arrays)
101+
s = pl.Series("n", [np_dtype_cls(0), np_dtype_cls(4), np_dtype_cls(8)])
102+
assert s.dtype == expected_pl_dtype

0 commit comments

Comments
 (0)