Skip to content
This repository was archived by the owner on Jul 16, 2024. It is now read-only.

Commit 67268fb

Browse files
authored
Allow pass column to plcontainer_apply (#229)
This PR allows to pass specific columns when using plcontainer_apply.
1 parent b5cda8e commit 67268fb

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

greenplumpython/func.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,18 @@ def apply(
117117
):
118118
return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType
119119
_serialize_to_type_name(return_annotation, db=db, for_return=True)
120+
input_args = self._args
121+
if len(input_args) == 0:
122+
raise Exception("No input data specified, please specify a DataFrame or Columns")
123+
input_clause = (
124+
"*"
125+
if (len(input_args) == 1 and isinstance(input_args[0], DataFrame))
126+
else ",".join([arg._serialize(db=db) for arg in input_args])
127+
)
120128
return DataFrame(
121129
f"""
122130
SELECT * FROM plcontainer_apply(TABLE(
123-
SELECT * {from_clause}), '{self._function._qualified_name_str}', 4096) AS
131+
SELECT {input_clause} {from_clause}), '{self._function._qualified_name_str}', 4096) AS
124132
{_defined_types[return_annotation.__args__[0]]._serialize(db=db)}
125133
""",
126134
db=db,
@@ -370,6 +378,7 @@ def _serialize(self, db: Database) -> str:
370378
f" import sys as {sys_lib_name}\n"
371379
f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n"
372380
f" raise ModuleNotFoundError\n"
381+
f" {sys_lib_name}.modules['plpy']=plpy\n"
373382
f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n"
374383
f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n"
375384
f" except ModuleNotFoundError:\n"

tests/test_plcontainer.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,56 @@
11
from dataclasses import dataclass
22

3+
import pytest
4+
35
import greenplumpython as gp
46
from tests import db
57

68

7-
def test_simple_func(db: gp.Database):
8-
@dataclass
9-
class Int:
10-
i: int
9+
@dataclass
10+
class Int:
11+
i: int
12+
13+
14+
@dataclass
15+
class Pair:
16+
i: int
17+
j: int
18+
19+
20+
@pytest.fixture
21+
def t(db: gp.Database):
22+
rows = [(i, i) for i in range(10)]
23+
return db.create_dataframe(rows=rows, column_names=["a", "b"])
24+
25+
26+
@gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
27+
def add_one(x: list[Int]) -> list[Int]:
28+
return [{"i": arg["i"] + 1} for arg in x]
1129

12-
@gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
13-
def add_one(x: list[Int]) -> list[Int]:
14-
return [{"i": arg["i"] + 1} for arg in x]
1530

31+
def test_simple_func(db: gp.Database):
1632
assert (
1733
len(
1834
list(
1935
db.create_dataframe(columns={"i": range(10)}).apply(
20-
lambda _: add_one(), expand=True
36+
lambda t: add_one(t), expand=True
2137
)
2238
)
2339
)
2440
== 10
2541
)
42+
43+
44+
def test_func_no_input(db: gp.Database):
45+
46+
with pytest.raises(Exception) as exc_info: # no input data for func raises Exception
47+
db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True)
48+
assert "No input data specified, please specify a DataFrame or Columns" in str(exc_info.value)
49+
50+
51+
def test_func_column(db: gp.Database, t: gp.DataFrame):
52+
@gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
53+
def add(x: list[Pair]) -> list[Int]:
54+
return [{"i": arg["i"] + arg["j"]} for arg in x]
55+
56+
assert len(list(t.apply(lambda t: add(t["a"], t["b"]), expand=True))) == 10

0 commit comments

Comments
 (0)