Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions ccflow/exttypes/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import polars as pl
from packaging import version
from pydantic import TypeAdapter
from typing_extensions import Self
from typing_extensions import Annotated, Self

__all__ = ("PolarsExpression",)
__all__ = ("PolarsExpression", "PolarsExpr")


class PolarsExpression(pl.Expr):
class _PolarsExprPydanticAnnotation:
"""Provides a polars expressions from a string"""

@classmethod
Expand Down Expand Up @@ -68,10 +68,18 @@ def _validate(cls, value: Any) -> Self:

raise ValueError(f"Supplied value '{value}' cannot be converted to a Polars expression")


class PolarsExpression(_PolarsExprPydanticAnnotation, pl.Expr):
    """A ``pl.Expr`` subclass that can be built from a string or an existing expression."""

    @classmethod
    def validate(cls, value: Any) -> Self:
        """Convert or validate an arbitrary value into a PolarsExpression.

        The work is delegated to the module-level ``_TYPE_ADAPTER``; whatever
        error that adapter raises for an unsupported value propagates unchanged.
        """
        validated = _TYPE_ADAPTER.validate_python(value)
        return validated


# Module-level adapter used by PolarsExpression.validate. It is created after the
# class definition because it takes the class itself as its argument.
_TYPE_ADAPTER = TypeAdapter(PolarsExpression)


# Public alias for use as a Pydantic v2 ``Annotated`` field type: the declared
# value type is a plain ``pl.Expr`` and _PolarsExprPydanticAnnotation is attached
# as the annotation metadata.
PolarsExpr = Annotated[pl.Expr, _PolarsExprPydanticAnnotation]
119 changes: 81 additions & 38 deletions ccflow/tests/exttypes/test_polars.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,99 @@
import math
from unittest import TestCase

import numpy as np
import polars as pl
import pytest
import scipy
from packaging import version
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError

from ccflow.exttypes.polars import PolarsExpression
from ccflow import BaseModel
from ccflow.exttypes.polars import PolarsExpr, PolarsExpression


class TestPolarsExpression(TestCase):
def test_expression(self):
expression = pl.col("Col1") + pl.col("Col2")
self.assertEqual(PolarsExpression.validate(expression).meta.serialize(), expression.meta.serialize())
@pytest.mark.parametrize("typ", [PolarsExpression, PolarsExpr])
def test_expression_passthrough(typ):
    """An existing polars expression validates to an equivalent expression."""
    original = pl.col("Col1") + pl.col("Col2")
    validated = TypeAdapter(typ).validate_python(original)
    assert validated.meta.serialize() == original.meta.serialize()

def test_expression_deserialization(self):
expression = PolarsExpression.validate("pl.col('Col1') + pl.col('Col2')")
expected_result = pl.col("Col1") + pl.col("Col2")

self.assertEqual(expression.meta.serialize(), expected_result.meta.serialize())
@pytest.mark.parametrize("typ", [PolarsExpression, PolarsExpr])
def test_expression_from_string(typ):
    """A string of polars code validates to the expression it spells out."""
    parsed = TypeAdapter(typ).validate_python("pl.col('Col1') + pl.col('Col2')")
    reference = pl.col("Col1") + pl.col("Col2")
    assert parsed.meta.serialize() == reference.meta.serialize()

def test_expression_complex(self):
expression = PolarsExpression.validate(
"col('Col1') + (sp.linalg.det(numpy.eye(2, dtype=int)) - 1 ) * math.pi * c('Col2') + polars.col('Col2')"
)
expected_result = pl.col("Col1") + (scipy.linalg.det(np.eye(2, dtype=int)) - 1) * math.pi * pl.col("Col2") + pl.col("Col2")

self.assertEqual(
PolarsExpression.validate(expression).meta.serialize(),
expected_result.meta.serialize(),
)
@pytest.mark.parametrize("typ", [PolarsExpression, PolarsExpr])
def test_expression_complex(typ):
    """A string mixing polars, scipy, numpy and math calls validates correctly.

    The source string reaches the libraries through several different names
    (``col``, ``c``, ``polars``, ``sp``, ``numpy``, ``math``).
    """
    source = "col('Col1') + (sp.linalg.det(numpy.eye(2, dtype=int)) - 1 ) * math.pi * c('Col2') + polars.col('Col2')"
    parsed = TypeAdapter(typ).validate_python(source)

    determinant = scipy.linalg.det(np.eye(2, dtype=int))
    reference = pl.col("Col1") + (determinant - 1) * math.pi * pl.col("Col2") + pl.col("Col2")
    assert parsed.meta.serialize() == reference.meta.serialize()

def test_validation_failure(self):
with self.assertRaises(ValueError):
PolarsExpression.validate(None)

with self.assertRaises(ValueError):
PolarsExpression.validate("pl.DataFrame()")
@pytest.mark.parametrize("typ", [PolarsExpression, PolarsExpr])
def test_validation_failure(typ):
    """``None`` and a string evaluating to a DataFrame are both rejected."""
    validator = TypeAdapter(typ)
    for bad_value in (None, "pl.DataFrame()"):
        with pytest.raises(ValidationError):
            validator.validate_python(bad_value)

def test_validation_eval_failure(self):
with self.assertRaises(ValueError):
PolarsExpression.validate("invalid_statement")

def test_json_serialization(self):
expression = pl.col("Col1") + pl.col("Col2")
json_result = TypeAdapter(PolarsExpression).dump_json(expression)
if version.parse(pl.__version__) < version.parse("1.0.0"):
self.assertEqual(json_result.decode("utf-8"), expression.meta.serialize())
else:
# polars serializes into a binary format by default.
self.assertEqual(json_result.decode("utf-8"), expression.meta.serialize(format="json"))
@pytest.mark.parametrize("typ", [PolarsExpression, PolarsExpr])
def test_validation_eval_failure(typ):
    """A string that cannot be evaluated surfaces as a validation error."""
    bad_input = "invalid_statement"
    validator = TypeAdapter(typ)
    with pytest.raises(ValidationError):
        validator.validate_python(bad_input)

expected_result = TypeAdapter(PolarsExpression).validate_json(json_result)
self.assertEqual(expected_result.meta.serialize(), expression.meta.serialize())

@pytest.mark.parametrize("typ", [PolarsExpression, PolarsExpr])
def test_json_serialization_roundtrip(typ):
    """Dumping an expression to JSON and validating it back yields the same expression."""
    adapter = TypeAdapter(typ)
    original = pl.col("Col1") + pl.col("Col2")
    payload = adapter.dump_json(original)

    # From polars 1.0 on, the JSON form has to be requested explicitly via format="json".
    legacy_polars = version.parse(pl.__version__) < version.parse("1.0.0")
    expected_text = original.meta.serialize() if legacy_polars else original.meta.serialize(format="json")
    assert payload.decode("utf-8") == expected_text

    restored = adapter.validate_json(payload)
    assert restored.meta.serialize() == original.meta.serialize()


def test_model_field_and_dataframe_filter():
    """A ``PolarsExpr`` model field accepts a string and is usable as a DataFrame filter."""

    class DummyExprModel(BaseModel):
        expr: PolarsExpr

    model = DummyExprModel(expr="pl.col('x') > 10")
    assert isinstance(model.expr, pl.Expr)

    frame = pl.DataFrame({"x": [5, 10, 11, 20], "y": [1, 2, 3, 4]})
    kept_x = frame.filter(model.expr).select("x").to_series().to_list()
    assert kept_x == [11, 20]


# Explicitly test the legacy classmethod validator for backwards compatibility
def test_polars_expression_validate_passthrough():
    """``PolarsExpression.validate`` returns an equivalent expression for an expression input."""
    original = pl.col("Col1") + pl.col("Col2")
    validated = PolarsExpression.validate(original)
    assert validated.meta.serialize() == original.meta.serialize()


def test_polars_expression_validate_from_string():
    """``PolarsExpression.validate`` turns a string of polars code into the matching expression."""
    reference = pl.col("Col1") + pl.col("Col2")
    parsed = PolarsExpression.validate("pl.col('Col1') + pl.col('Col2')")
    assert parsed.meta.serialize() == reference.meta.serialize()


def test_polars_expression_validate_errors():
    """``PolarsExpression.validate`` raises ``ValueError`` for each unsupported input."""
    for bad_value in (None, "pl.DataFrame()", "invalid_statement"):
        with pytest.raises(ValueError):
            PolarsExpression.validate(bad_value)