Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions ccflow/exttypes/polars.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import math
from io import StringIO
from typing import Any
from typing import Annotated, Any

import numpy as np
import orjson
import polars as pl
from packaging import version
from pydantic import TypeAdapter
from typing_extensions import Self

__all__ = ("PolarsExpression",)


class PolarsExpression(pl.Expr):
class _PolarsExprPydanticAnnotation:
"""Provides a polars expressions from a string"""

@classmethod
Expand Down Expand Up @@ -68,10 +67,6 @@ def _validate(cls, value: Any) -> Self:

raise ValueError(f"Supplied value '{value}' cannot be converted to a Polars expression")

@classmethod
def validate(cls, value: Any) -> Self:
    """Convert or validate an arbitrary value into a PolarsExpression.

    The work is delegated to the module-level ``_TYPE_ADAPTER``; whatever its
    ``validate_python`` call returns is handed back to the caller as-is.
    """
    validated = _TYPE_ADAPTER.validate_python(value)
    return validated


_TYPE_ADAPTER = TypeAdapter(PolarsExpression)
# Public annotated type for Polars expressions
PolarsExpression = Annotated[pl.Expr, _PolarsExprPydanticAnnotation]
102 changes: 65 additions & 37 deletions ccflow/tests/exttypes/test_polars.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,84 @@
import math
from unittest import TestCase

import numpy as np
import polars as pl
import pytest
import scipy
from packaging import version
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError

from ccflow import BaseModel
from ccflow.exttypes.polars import PolarsExpression


class TestPolarsExpression(TestCase):
def test_expression(self):
    """Validating an existing ``pl.Expr`` yields the same serialized form."""
    original = pl.col("Col1") + pl.col("Col2")
    validated = PolarsExpression.validate(original)
    self.assertEqual(validated.meta.serialize(), original.meta.serialize())
def test_expression_passthrough():
    """A ready-made ``pl.Expr`` validates to an expression with the same serialized form."""
    expr_adapter = TypeAdapter(PolarsExpression)
    source_expr = pl.col("Col1") + pl.col("Col2")
    validated = expr_adapter.validate_python(source_expr)
    # Compare serialized forms, as the original test does.
    actual = validated.meta.serialize()
    wanted = source_expr.meta.serialize()
    assert actual == wanted

def test_expression_deserialization(self):
    """A string of polars code validates to the expression it spells out."""
    parsed = PolarsExpression.validate("pl.col('Col1') + pl.col('Col2')")
    wanted = pl.col("Col1") + pl.col("Col2")
    self.assertEqual(parsed.meta.serialize(), wanted.meta.serialize())
def test_expression_from_string():
    """A string of polars code validates to the expression it spells out."""
    expr_adapter = TypeAdapter(PolarsExpression)
    wanted = pl.col("Col1") + pl.col("Col2")
    parsed = expr_adapter.validate_python("pl.col('Col1') + pl.col('Col2')")
    actual_repr = parsed.meta.serialize()
    wanted_repr = wanted.meta.serialize()
    assert actual_repr == wanted_repr

def test_expression_complex(self):
    """A string mixing ``col``/``c``, ``sp``, ``numpy``, ``math`` and ``polars`` names validates correctly.

    The parsed expression is validated a second time before the comparison,
    exactly as the original test does.
    """
    parsed = PolarsExpression.validate(
        "col('Col1') + (sp.linalg.det(numpy.eye(2, dtype=int)) - 1 ) * math.pi * c('Col2') + polars.col('Col2')"
    )
    # Scalar factor that the string is expected to evaluate to in front of Col2.
    scale = (scipy.linalg.det(np.eye(2, dtype=int)) - 1) * math.pi
    wanted = pl.col("Col1") + scale * pl.col("Col2") + pl.col("Col2")

    revalidated = PolarsExpression.validate(parsed)
    self.assertEqual(revalidated.meta.serialize(), wanted.meta.serialize())
def test_expression_complex():
    """A string mixing ``col``/``c``, ``sp``, ``numpy``, ``math`` and ``polars`` names validates correctly."""
    expr_adapter = TypeAdapter(PolarsExpression)
    # Scalar factor that the string is expected to evaluate to in front of Col2.
    scale = (scipy.linalg.det(np.eye(2, dtype=int)) - 1) * math.pi
    wanted = pl.col("Col1") + scale * pl.col("Col2") + pl.col("Col2")
    parsed = expr_adapter.validate_python(
        "col('Col1') + (sp.linalg.det(numpy.eye(2, dtype=int)) - 1 ) * math.pi * c('Col2') + polars.col('Col2')"
    )
    assert parsed.meta.serialize() == wanted.meta.serialize()

def test_validation_failure(self):
    """``None`` and a string that builds a DataFrame are both rejected with ``ValueError``."""
    for rejected in (None, "pl.DataFrame()"):
        with self.assertRaises(ValueError):
            PolarsExpression.validate(rejected)
def test_validation_failure():
    """``None`` and a string that builds a DataFrame are both rejected with ``ValidationError``."""
    expr_adapter = TypeAdapter(PolarsExpression)
    for rejected in (None, "pl.DataFrame()"):
        with pytest.raises(ValidationError):
            expr_adapter.validate_python(rejected)

def test_validation_eval_failure(self):
    """The string ``"invalid_statement"`` is rejected with ``ValueError``."""
    self.assertRaises(ValueError, PolarsExpression.validate, "invalid_statement")

def test_json_serialization(self):
expression = pl.col("Col1") + pl.col("Col2")
json_result = TypeAdapter(PolarsExpression).dump_json(expression)
if version.parse(pl.__version__) < version.parse("1.0.0"):
self.assertEqual(json_result.decode("utf-8"), expression.meta.serialize())
else:
# polars serializes into a binary format by default.
self.assertEqual(json_result.decode("utf-8"), expression.meta.serialize(format="json"))
def test_validation_eval_failure():
    """The string ``"invalid_statement"`` is rejected with ``ValidationError``."""
    expr_adapter = TypeAdapter(PolarsExpression)
    bad_source = "invalid_statement"
    with pytest.raises(ValidationError):
        expr_adapter.validate_python(bad_source)

expected_result = TypeAdapter(PolarsExpression).validate_json(json_result)
self.assertEqual(expected_result.meta.serialize(), expression.meta.serialize())

def test_json_serialization_roundtrip():
    """Dump an expression to JSON, check the payload, and load it back.

    The expected payload depends on the installed polars version: before
    1.0.0 it is compared with ``meta.serialize()``, otherwise with
    ``meta.serialize(format="json")``.
    """
    expr_adapter = TypeAdapter(PolarsExpression)
    source_expr = pl.col("Col1") + pl.col("Col2")
    payload = expr_adapter.dump_json(source_expr)

    decoded = payload.decode("utf-8")
    is_pre_1_0 = version.parse(pl.__version__) < version.parse("1.0.0")
    reference = source_expr.meta.serialize() if is_pre_1_0 else source_expr.meta.serialize(format="json")
    assert decoded == reference

    # Round trip: the dumped JSON must validate back to an equivalent expression.
    restored = expr_adapter.validate_json(payload)
    assert restored.meta.serialize() == source_expr.meta.serialize()


def test_model_field_and_dataframe_filter():
    """A model field typed ``PolarsExpression`` accepts a string and works in ``DataFrame.filter``."""

    class DummyExprModel(BaseModel):
        expr: PolarsExpression

    model = DummyExprModel(expr="pl.col('x') > 10")
    assert isinstance(model.expr, pl.Expr)

    frame = pl.DataFrame({"x": [5, 10, 11, 20], "y": [1, 2, 3, 4]})
    kept_rows = frame.filter(model.expr)
    # Only the rows with x strictly greater than 10 should remain.
    kept_x = kept_rows.select("x").to_series().to_list()
    assert kept_x == [11, 20]


def test_model_field_and_dataframe_with_columns():
    """A model field typed ``PolarsExpression`` accepts a rolling-max string and works in ``DataFrame.select``."""

    class DummyExprModel(BaseModel):
        expr: PolarsExpression

    model = DummyExprModel(expr='pl.col("x").rolling_max(window_size=2)')
    assert isinstance(model.expr, pl.Expr)

    frame = pl.DataFrame({"x": [5, 19, 17, 13, 8, 20], "y": [1, 2, 3, 4, 5, 6]})
    rolled = frame.select(model.expr)
    # With window_size=2 the test expects a null in the first slot.
    rolled_values = rolled.to_series().to_list()
    assert rolled_values == [None, 19, 19, 17, 13, 20]