Skip to content

Commit 5352e80

Browse files
committed
feat: show multiindex levels in repr & improve multiindexed coords repr
1 parent 78331b9 commit 5352e80

File tree

7 files changed

+204
-31
lines changed

7 files changed

+204
-31
lines changed

doc/release_notes.rst

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Upcoming Version
66

77
* When writing out an LP file, large variables and constraints are now chunked to avoid memory issues. This is especially useful for large models with constraints with many terms. The chunk size can be set with the `slice_size` argument in the `solve` function.
88
* Constraints of the form `<= infinity` and `>= -infinity` are now automatically filtered out when solving. The `solve` function now has a new argument `sanitize_infinities` to control this feature. Default is set to `True`.
9+
* The representation of linopy objects with multiindexed coordinates was improved to be more readable.
910

1011
Version 0.3.15
1112
--------------

linopy/common.py

+63-4
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,33 @@ def align_lines_by_delimiter(lines: list[str], delimiter: str | list[str]):
613613
return formatted_lines
614614

615615

616+
def get_dims_with_index_levels(
    ds: Dataset, dims: list[Hashable] | tuple[Hashable, ...] | None = None
) -> list[str]:
    """
    Get the dimensions of a Dataset with their index levels.

    Parameters
    ----------
    ds : Dataset
        Object whose ``dims`` and ``indexes`` are inspected.
    dims : list or tuple of Hashable, optional
        Dimensions to describe, in the order they should appear in the
        output. Defaults to all of ``ds.dims``.

    Returns
    -------
    list of str
        One entry per dimension. A dimension backed by a ``pd.MultiIndex``
        is rendered as ``"dim (level0, level1, ...)"``; any other dimension
        is rendered as ``str(dim)``.

    Example usage with a dataset that has:
    - regular dimension 'time'
    - multi-indexed dimension 'station' with levels ['country', 'city']
    The output would be: ['time', 'station (country, city)']
    """
    if dims is None:
        dims = list(ds.dims)

    dims_with_levels = []
    for dim in dims:
        # Use ``.get`` rather than ``[]``: a dimension that has no entry in
        # ``ds.indexes`` (e.g. an xarray dimension without a coordinate) is
        # then reported by its plain name instead of raising ``KeyError``.
        index = ds.indexes.get(dim)
        if isinstance(index, pd.MultiIndex):
            # Level names may be any hashable, so stringify before joining.
            levels = ", ".join(str(name) for name in index.names)
            dims_with_levels.append(f"{dim} ({levels})")
        else:
            dims_with_levels.append(str(dim))

    return dims_with_levels
return dims_with_levels
641+
642+
616643
def get_label_position(
617644
obj, values: int | np.ndarray
618645
) -> (
@@ -658,10 +685,42 @@ def find_single(value: int) -> tuple[str, dict] | tuple[None, None]:
658685
raise ValueError("Array's with more than two dimensions is not supported")
659686

660687

661-
def print_coord(coord):
662-
if isinstance(coord, dict):
663-
coord = coord.values()
664-
return "[" + ", ".join([str(c) for c in coord]) + "]" if len(coord) else ""
688+
def print_coord(coord: dict | Iterable) -> str:
    """
    Format coordinates into a string representation.

    Args:
        coord: Dictionary or iterable containing coordinate values.
            Values can be numbers, strings, or nested iterables.
            For a dictionary only the values are used.

    Returns:
        Formatted string representation of coordinates in brackets,
        with list/tuple coordinates grouped in parentheses. An empty
        input (or ``None``) gives an empty string.

    Examples:
        >>> print_coord({"x": 1, "y": 2})
        '[1, 2]'
        >>> print_coord([1, 2, 3])
        '[1, 2, 3]'
        >>> print_coord([(1, 2), (3, 4)])
        '[(1, 2), (3, 4)]'
        >>> print_coord([])
        ''
    """
    if coord is None:
        return ""

    # Materialize before testing for emptiness: truth-testing the input
    # directly raises for array-like inputs with more than one element and
    # is always true for generators, even exhausted ones.
    values = list(coord.values() if isinstance(coord, dict) else coord)
    if not values:
        return ""

    formatted = [
        f"({', '.join(str(x) for x in value)})"
        if isinstance(value, (list, tuple))
        else str(value)
        for value in values
    ]
    return f"[{', '.join(formatted)}]"
665724

666725

667726
def print_single_variable(model, label):

linopy/constraints.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import functools
88
import warnings
9-
from collections.abc import ItemsView, Iterator
9+
from collections.abc import Hashable, ItemsView, Iterator
1010
from dataclasses import dataclass
1111
from itertools import product
1212
from typing import (
@@ -35,6 +35,7 @@
3535
filter_nulls_polars,
3636
format_string_as_variable_name,
3737
generate_indices_for_printout,
38+
get_dims_with_index_levels,
3839
get_label_position,
3940
group_terms_polars,
4041
has_optimized_model,
@@ -255,13 +256,20 @@ def name(self):
255256
return self.attrs["name"]
256257

257258
@property
258-
def coord_dims(self):
259+
def coord_dims(self) -> tuple[Hashable, ...]:
259260
return tuple(k for k in self.dims if k not in HELPER_DIMS)
260261

261262
@property
262-
def coord_sizes(self):
263+
def coord_sizes(self) -> dict[Hashable, int]:
263264
return {k: v for k, v in self.sizes.items() if k not in HELPER_DIMS}
264265

266+
@property
def coord_names(self) -> list[str]:
    """
    Get the names of the coordinates.

    Returns the result of ``get_dims_with_index_levels`` for
    ``self.data`` restricted to ``self.coord_dims``: one string per
    coordinate dimension, with multi-indexed dimensions rendered as
    ``"dim (level0, level1, ...)"``.
    """
    return get_dims_with_index_levels(self.data, self.coord_dims)
272+
265273
@property
266274
def is_assigned(self):
267275
return self._assigned
@@ -273,6 +281,7 @@ def __repr__(self):
273281
max_lines = options["display_max_rows"]
274282
dims = list(self.coord_sizes.keys())
275283
ndim = len(dims)
284+
dim_names = self.coord_names
276285
dim_sizes = list(self.coord_sizes.values())
277286
size = np.prod(dim_sizes) # that the number of theoretical printouts
278287
masked_entries = (~self.mask).sum().values if self.mask is not None else 0
@@ -297,16 +306,16 @@ def __repr__(self):
297306
)
298307
sign = SIGNS_pretty[self.sign.values[indices]]
299308
rhs = self.rhs.values[indices]
300-
line = f"{print_coord(coord)}: {expr} {sign} {rhs}"
309+
line = print_coord(coord) + f": {expr} {sign} {rhs}"
301310
else:
302-
line = f"{print_coord(coord)}: None"
311+
line = print_coord(coord) + ": None"
303312
lines.append(line)
304313
lines = align_lines_by_delimiter(lines, list(SIGNS_pretty.values()))
305314

306-
shape_str = ", ".join(f"{d}: {s}" for d, s in zip(dims, dim_sizes))
315+
shape_str = ", ".join(f"{d}: {s}" for d, s in zip(dim_names, dim_sizes))
307316
mask_str = f" - {masked_entries} masked entries" if masked_entries else ""
308317
underscore = "-" * (len(shape_str) + len(mask_str) + len(header_string) + 4)
309-
lines.insert(0, f"{header_string} ({shape_str}){mask_str}:\n{underscore}")
318+
lines.insert(0, f"{header_string} [{shape_str}]{mask_str}:\n{underscore}")
310319
elif size == 1:
311320
expr = print_single_expression(self.coeffs, self.vars, 0, self.model)
312321
lines.append(

linopy/expressions.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,12 @@
4949
filter_nulls_polars,
5050
forward_as_properties,
5151
generate_indices_for_printout,
52+
get_dims_with_index_levels,
5253
get_index_map,
5354
group_terms_polars,
5455
has_optimized_model,
5556
iterate_slices,
57+
print_coord,
5658
print_single_expression,
5759
to_dataframe,
5860
to_polars,
@@ -404,7 +406,8 @@ def __repr__(self) -> str:
404406
Print the expression arrays.
405407
"""
406408
max_lines = options["display_max_rows"]
407-
dims = list(self.coord_sizes.keys())
409+
dims = list(self.coord_sizes)
410+
dim_names = self.coord_names
408411
ndim = len(dims)
409412
dim_sizes = list(self.coord_sizes.values())
410413
size = np.prod(dim_sizes) # that the number of theoretical printouts
@@ -418,26 +421,26 @@ def __repr__(self) -> str:
418421
if indices is None:
419422
lines.append("\t\t...")
420423
else:
421-
coord_values = ", ".join(
422-
str(self.data.indexes[dims[i]][ind])
423-
for i, ind in enumerate(indices)
424-
)
424+
coord = [
425+
self.data.indexes[dims[i]][ind] for i, ind in enumerate(indices)
426+
]
425427
if self.mask is None or self.mask.values[indices]:
426428
expr = print_single_expression(
427429
self.coeffs.values[indices],
428430
self.vars.values[indices],
429431
self.const.values[indices],
430432
self.model,
431433
)
432-
line = f"[{coord_values}]: {expr}"
434+
435+
line = print_coord(coord) + f": {expr}"
433436
else:
434-
line = f"[{coord_values}]: None"
437+
line = print_coord(coord) + ": None"
435438
lines.append(line)
436439

437-
shape_str = ", ".join(f"{d}: {s}" for d, s in zip(dims, dim_sizes))
440+
shape_str = ", ".join(f"{d}: {s}" for d, s in zip(dim_names, dim_sizes))
438441
mask_str = f" - {masked_entries} masked entries" if masked_entries else ""
439442
underscore = "-" * (len(shape_str) + len(mask_str) + len(header_string) + 4)
440-
lines.insert(0, f"{header_string} ({shape_str}){mask_str}:\n{underscore}")
443+
lines.insert(0, f"{header_string} [{shape_str}]{mask_str}:\n{underscore}")
441444
elif size == 1:
442445
expr = print_single_expression(
443446
self.coeffs, self.vars, self.const, self.model
@@ -735,9 +738,13 @@ def coord_dims(self) -> tuple[Hashable, ...]:
735738
return tuple(k for k in self.dims if k not in HELPER_DIMS)
736739

737740
@property
738-
def coord_sizes(self) -> dict[str, int]:
741+
def coord_sizes(self) -> dict[Hashable, int]:
739742
return {k: v for k, v in self.sizes.items() if k not in HELPER_DIMS}
740743

744+
@property
def coord_names(self) -> list[str]:
    """
    Get the names of the coordinates.

    Returns the result of ``get_dims_with_index_levels`` for
    ``self.data`` restricted to ``self.coord_dims``: one string per
    coordinate dimension, with multi-indexed dimensions rendered as
    ``"dim (level0, level1, ...)"``.
    """
    return get_dims_with_index_levels(self.data, self.coord_dims)
747+
741748
@property
742749
def vars(self):
743750
return self.data.vars

linopy/variables.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
filter_nulls_polars,
4040
format_string_as_variable_name,
4141
generate_indices_for_printout,
42+
get_dims_with_index_levels,
4243
get_label_position,
4344
has_optimized_model,
4445
is_constant,
@@ -50,7 +51,7 @@
5051
to_polars,
5152
)
5253
from linopy.config import options
53-
from linopy.constants import TERM_DIM
54+
from linopy.constants import HELPER_DIMS, TERM_DIM
5455
from linopy.solvers import set_int_index
5556
from linopy.types import NotImplementedType
5657

@@ -321,7 +322,8 @@ def __repr__(self) -> str:
321322
Print the variable arrays.
322323
"""
323324
max_lines = options["display_max_rows"]
324-
dims = list(self.sizes.keys())
325+
dims = list(self.sizes)
326+
dim_names = self.coord_names
325327
dim_sizes = list(self.sizes.values())
326328
masked_entries = (~self.mask).sum().values
327329
lines = []
@@ -343,7 +345,7 @@ def __repr__(self) -> str:
343345
lines.append(line)
344346
# lines = align_lines_by_delimiter(lines, "∈")
345347

346-
shape_str = ", ".join(f"{d}: {s}" for d, s in zip(dims, dim_sizes))
348+
shape_str = ", ".join(f"{d}: {s}" for d, s in zip(dim_names, dim_sizes))
347349
mask_str = f" - {masked_entries} masked entries" if masked_entries else ""
348350
lines.insert(
349351
0,
@@ -663,6 +665,21 @@ def type(self):
663665
else:
664666
return "Continuous Variable"
665667

668+
@property
def coord_dims(self) -> tuple[Hashable, ...]:
    """
    Get the dimensions of ``self.dims`` that are not listed in
    ``HELPER_DIMS``, keeping their original order.
    """
    kept = [dim for dim in self.dims if dim not in HELPER_DIMS]
    return tuple(kept)
671+
672+
@property
def coord_sizes(self) -> dict[Hashable, int]:
    """
    Get the entries of ``self.sizes`` whose dimension is not listed in
    ``HELPER_DIMS``, keeping their original order.
    """
    sizes = {}
    for dim, length in self.sizes.items():
        if dim in HELPER_DIMS:
            continue
        sizes[dim] = length
    return sizes
675+
676+
@property
def coord_names(self) -> list[str]:
    """
    Get the names of the coordinates.

    Returns the result of ``get_dims_with_index_levels`` for
    ``self.data`` restricted to ``self.coord_dims``: one string per
    coordinate dimension, with multi-indexed dimensions rendered as
    ``"dim (level0, level1, ...)"``.
    """
    return get_dims_with_index_levels(self.data, self.coord_dims)
682+
666683
@property
667684
def range(self) -> tuple[int, int]:
668685
"""
@@ -1089,8 +1106,6 @@ def __getitem__(self, keys: Any) -> ScalarVariable:
10891106
keys = keys if isinstance(keys, tuple) else (keys,)
10901107
object = self.object
10911108

1092-
if not all(map(pd.api.types.is_scalar, keys)):
1093-
raise ValueError("Only scalar keys are allowed.")
10941109
# return single scalar
10951110
if not object.labels.ndim:
10961111
return ScalarVariable(object.labels.item(), object.model)

test/test_common.py

+75
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
as_dataarray,
1616
assign_multiindex_safe,
1717
best_int,
18+
get_dims_with_index_levels,
1819
iterate_slices,
1920
)
2021

@@ -529,3 +530,77 @@ def test_iterate_slices_no_slice_dims():
529530
for s in slices:
530531
assert isinstance(s, xr.Dataset)
531532
assert set(s.dims) == set(ds.dims)
533+
534+
535+
def test_get_dims_with_index_levels():
    """
    Check ``get_dims_with_index_levels`` on regular, multi-indexed and empty
    datasets, both with the default dimensions and with an explicit ``dims``
    selection.
    """
    # Data values are never inspected, so deterministic zeros are used
    # instead of random numbers.

    # Case 1: Simple dataset with regular dimensions
    ds1 = xr.Dataset(
        {"temp": (("time", "lat"), np.zeros((3, 2)))},
        coords={"time": pd.date_range("2024-01-01", periods=3), "lat": [0, 1]},
    )

    # Case 2: Dataset with a multi-index dimension
    stations_index = pd.MultiIndex.from_product(
        [["USA", "Canada"], ["NYC", "Toronto"]], names=["country", "city"]
    )
    stations_coords = xr.Coordinates.from_pandas_multiindex(stations_index, "station")
    ds2 = xr.Dataset(
        {"temp": (("time", "station"), np.zeros((3, 4)))},
        coords={"time": pd.date_range("2024-01-01", periods=3), **stations_coords},
    )

    # Case 3: Dataset with unnamed multi-index levels
    unnamed_stations_index = pd.MultiIndex.from_product(
        [["USA", "Canada"], ["NYC", "Toronto"]]
    )
    unnamed_stations_coords = xr.Coordinates.from_pandas_multiindex(
        unnamed_stations_index, "station"
    )
    ds3 = xr.Dataset(
        {"temp": (("time", "station"), np.zeros((3, 4)))},
        coords={
            "time": pd.date_range("2024-01-01", periods=3),
            **unnamed_stations_coords,
        },
    )

    # Case 4: Dataset with multiple multi-indexed dimensions
    locations_index = pd.MultiIndex.from_product(
        [["North", "South"], ["A", "B"]], names=["region", "site"]
    )
    locations_coords = xr.Coordinates.from_pandas_multiindex(
        locations_index, "location"
    )

    ds4 = xr.Dataset(
        {"temp": (("time", "station", "location"), np.zeros((2, 4, 4)))},
        coords={
            "time": pd.date_range("2024-01-01", periods=2),
            **stations_coords,
            **locations_coords,
        },
    )

    # Run tests

    # Test case 1: Regular dimensions
    assert get_dims_with_index_levels(ds1) == ["time", "lat"]

    # Test case 2: Named multi-index
    assert get_dims_with_index_levels(ds2) == ["time", "station (country, city)"]

    # Test case 3: Unnamed multi-index
    assert get_dims_with_index_levels(ds3) == [
        "time",
        "station (station_level_0, station_level_1)",
    ]

    # Test case 4: Multiple multi-indices
    expected = ["time", "station (country, city)", "location (region, site)"]
    assert get_dims_with_index_levels(ds4) == expected

    # Test case 5: Empty dataset
    ds5 = xr.Dataset()
    assert get_dims_with_index_levels(ds5) == []

    # Test case 6: Explicit ``dims`` selects a subset and fixes the order
    assert get_dims_with_index_levels(ds4, ["location", "time"]) == [
        "location (region, site)",
        "time",
    ]
    assert get_dims_with_index_levels(ds4, ("station",)) == [
        "station (country, city)"
    ]
    assert get_dims_with_index_levels(ds4, []) == []

0 commit comments

Comments
 (0)