Skip to content

Commit 5bd5a23

Browse files
committed
Fix tests for add_data
1 parent 510287a commit 5bd5a23

File tree

4 files changed

+39
-68
lines changed

4 files changed

+39
-68
lines changed

ixmp4/data/db/base.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -488,9 +488,7 @@ def collect_indexsets_to_check(self) -> dict[str, list[Any]]:
488488
IndexSet.elements."""
489489
return {column.name: column.indexset.elements for column in self.columns}
490490

491-
@validates("data")
492-
def validate_data(self, key, data: dict[str, Any]):
493-
data_frame: pd.DataFrame = pd.DataFrame.from_dict(data)
491+
def _validate_data(self, data_frame: pd.DataFrame, data: dict[str, Any]) -> None:
494492
# TODO for all of the following, we might want to create unique exceptions
495493
# Could we make both more specific by specifiying missing/extra columns?
496494
if len(data_frame.columns) < len(self.columns):
@@ -524,4 +522,8 @@ def validate_data(self, key, data: dict[str, Any]):
524522
"and Columns it is constrained to!"
525523
)
526524

525+
@validates("data")
526+
def validate_data(self, key, data: dict[str, Any]):
527+
data_frame: pd.DataFrame = pd.DataFrame.from_dict(data)
528+
self._validate_data(data_frame=data_frame, data=data)
527529
return data_frame.to_dict(orient="list")
+11-25
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,14 @@
1-
from typing import ClassVar
1+
from typing import Any, ClassVar
22

3-
from sqlalchemy import Column as sqlaColumn
4-
from sqlalchemy import Table
3+
import pandas as pd
4+
from sqlalchemy.orm import validates
55

66
from ixmp4 import db
77
from ixmp4.data import types
88
from ixmp4.data.abstract import optimization as abstract
9-
from ixmp4.data.db.unit import Unit
109

1110
from .. import Column, base
1211

13-
# Many Parameters can refer to many Units
14-
# note for a Core table, we use the sqlalchemy.Column construct,
15-
# not sqlalchemy.orm.mapped_column
16-
17-
# TODO Is this enough/correct? This follows many-to-many currently with units:
18-
# But does that work?
19-
20-
parameter_unit_association_table = Table(
21-
"optimization_parameter_unit_association_table",
22-
base.BaseModel.metadata,
23-
sqlaColumn("parameter__id", db.ForeignKey("optimization_parameter.id")),
24-
sqlaColumn("unit__id", db.ForeignKey("unit.id")),
25-
)
26-
2712

2813
class Parameter(base.BaseModel, base.OptimizationDataMixin, base.UniqueNameRunIDMixin):
2914
# NOTE: These might be mixin-able, but would require some abstraction
@@ -33,12 +18,13 @@ class Parameter(base.BaseModel, base.OptimizationDataMixin, base.UniqueNameRunID
3318

3419
# constrained_to_indexsets: ClassVar[list[str] | None] = None
3520

36-
values: types.JsonList = db.Column(db.JsonType, nullable=False, default=[])
37-
units: types.Mapped[list["Unit"]] = db.relationship(
38-
secondary=parameter_unit_association_table
39-
)
40-
# TODO: need some kind of primaryjoin adaption and unit_ids so that each unit_id is
41-
# foreignkeyed to Unit.id correctly
42-
4321
# TODO Same as in table/model.py
4422
columns: types.Mapped[list["Column"]] = db.relationship() # type: ignore
23+
24+
@validates("data")
25+
def validate_data(self, key, data: dict[str, Any]):
26+
data_frame: pd.DataFrame = pd.DataFrame.from_dict(data)
27+
data_frame_to_validate = data_frame.drop(columns=["values", "units"])
28+
29+
self._validate_data(data_frame=data_frame_to_validate, data=data)
30+
return data_frame.to_dict(orient="list")

ixmp4/data/db/optimization/parameter/repository.py

+13-14
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ixmp4 import db
66
from ixmp4.data.abstract import optimization as abstract
77
from ixmp4.data.auth.decorators import guard
8+
from ixmp4.data.db.unit import Unit
89

910
from .. import ColumnRepository, base
1011
from .docs import ParameterDocsRepository
@@ -150,26 +151,24 @@ def add_data(self, parameter_id: int, data: dict[str, Any] | pd.DataFrame) -> No
150151
data = pd.DataFrame.from_dict(data=data)
151152
parameter = self.get_by_id(id=parameter_id)
152153

153-
try:
154-
values = data.pop(item="values").to_list()
155-
except KeyError as e:
156-
raise KeyError("Parameter.data must include a 'values' column!") from e
154+
missing_columns = set(["values", "units"]) - set(data.columns)
155+
assert (
156+
not missing_columns
157+
), f"Parameter.data must include the column(s): {' ,'.join(missing_columns)}!"
157158

158-
try:
159-
units = [
159+
# Can use a set for now, need full column if we care about order
160+
for unit_name in set(data["units"]):
161+
try:
160162
self.backend.units.get(name=unit_name)
161-
for unit_name in data.pop(item="units")
162-
]
163-
except KeyError as e:
164-
raise KeyError("Parameter.data must include a 'units' column!") from e
163+
except Unit.NotFound as e:
164+
# TODO Add a helpful hint on how to check defined Units
165+
raise Unit.NotFound(
166+
message=f"'{unit_name}' is not defined for this Platform!"
167+
) from e
165168

166169
parameter.data = pd.concat(
167170
[pd.DataFrame.from_dict(parameter.data), data]
168171
).to_dict(orient="list")
169-
parameter.values = parameter.values + values
170-
171-
# TODO does this actually work? Do we set the relationships correctly here?
172-
parameter.units = parameter.units + units
173172

174173
self.session.add(parameter)
175174
self.session.commit()

tests/data/test_optimization_parameter.py

+10-26
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ def test_create_parameter(self, test_mp, request):
4949
assert parameter.run__id == run.id
5050
assert parameter.name == "Parameter"
5151
assert parameter.data == {} # JsonDict type currently requires a dict, not None
52-
assert parameter.values == []
53-
assert parameter.units == []
5452
assert parameter.columns[0].name == "Indexset"
5553
assert parameter.columns[0].constrained_to_indexset == indexset_1.id
5654

@@ -161,8 +159,6 @@ def test_parameter_add_data(self, test_mp, request):
161159
parameter = test_mp.backend.optimization.parameters.get(
162160
run_id=run.id, name="Parameter"
163161
)
164-
assert parameter.values == test_data_1.pop("values")
165-
assert [unit.name for unit in parameter.units] == test_data_1.pop("units")
166162
assert parameter.data == test_data_1
167163

168164
parameter_2 = test_mp.backend.optimization.parameters.create(
@@ -171,7 +167,9 @@ def test_parameter_add_data(self, test_mp, request):
171167
constrained_to_indexsets=[indexset_1.name, indexset_2.name],
172168
)
173169

174-
with pytest.raises(KeyError, match="must include a 'values' column!"):
170+
with pytest.raises(
171+
AssertionError, match=r"must include the column\(s\): values!"
172+
):
175173
test_mp.backend.optimization.parameters.add_data(
176174
parameter_id=parameter_2.id,
177175
data=pd.DataFrame(
@@ -183,7 +181,9 @@ def test_parameter_add_data(self, test_mp, request):
183181
),
184182
)
185183

186-
with pytest.raises(KeyError, match="must include a 'units' column!"):
184+
with pytest.raises(
185+
AssertionError, match=r"must include the column\(s\): units!"
186+
):
187187
test_mp.backend.optimization.parameters.add_data(
188188
parameter_id=parameter_2.id,
189189
data=pd.DataFrame(
@@ -232,8 +232,6 @@ def test_parameter_add_data(self, test_mp, request):
232232
parameter_2 = test_mp.backend.optimization.parameters.get(
233233
run_id=run.id, name="Parameter 2"
234234
)
235-
assert parameter_2.values == test_data_2.pop("values")
236-
assert [unit.name for unit in parameter_2.units] == test_data_2.pop("units")
237235
assert parameter_2.data == test_data_2
238236

239237
# Test order is conserved with varying types and upon later addition of data
@@ -258,17 +256,8 @@ def test_parameter_add_data(self, test_mp, request):
258256
parameter_3 = test_mp.backend.optimization.parameters.get(
259257
run_id=run.id, name="Parameter 3"
260258
)
261-
assert parameter_3.values == test_data_3.pop("values")
262-
assert [unit.name for unit in parameter_3.units] == test_data_3.pop("units")
263259
assert parameter_3.data == test_data_3
264260

265-
# Repopulate test_data after pop()
266-
test_data_3 = {
267-
"Column 1": ["bar", "foo", ""],
268-
"Column 2": [2, 3, 1],
269-
"values": ["3", 2.0, 1],
270-
"units": [unit_3.name, unit_2.name, unit.name],
271-
}
272261
test_data_4 = {
273262
"Column 1": ["foo", "", "bar"],
274263
"Column 2": [2, 3, 1],
@@ -281,15 +270,10 @@ def test_parameter_add_data(self, test_mp, request):
281270
parameter_3 = test_mp.backend.optimization.parameters.get(
282271
run_id=run.id, name="Parameter 3"
283272
)
284-
assert parameter_3.values == test_data_3.pop("values") + test_data_4.pop(
285-
"values"
286-
)
287-
assert [unit.name for unit in parameter_3.units] == test_data_3.pop(
288-
"units"
289-
) + test_data_4.pop("units")
290-
assert parameter_3.data == pd.DataFrame([test_data_3, test_data_4]).to_dict(
291-
orient="list"
292-
)
273+
test_data_5 = test_data_3.copy()
274+
for key, value in test_data_4.items():
275+
test_data_5[key].extend(value)
276+
assert parameter_3.data == test_data_5
293277

294278
# def test_list_parameter(self, test_mp, request):
295279
# test_mp = request.getfixturevalue(test_mp)

0 commit comments

Comments
 (0)