Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DataTree repr to not repeat inherited coordinates #9532

Merged
merged 2 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,8 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:
summary.append(f"{dims_start}({dims_values})")

if node._node_coord_variables:
summary.append(coords_repr(node.coords, col_width=col_width, max_rows=max_rows))
node_coords = node.to_dataset(inherited=False).coords
summary.append(coords_repr(node_coords, col_width=col_width, max_rows=max_rows))

if show_inherited and inherited_coords:
summary.append(
Expand Down
107 changes: 106 additions & 1 deletion xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import sys
import typing
from copy import copy, deepcopy
from textwrap import dedent
Expand All @@ -15,6 +16,8 @@
from xarray.testing import assert_equal, assert_identical
from xarray.tests import assert_array_equal, create_test_data, source_ndarray

ON_WINDOWS = sys.platform == "win32"


class TestTreeCreation:
def test_empty(self):
Expand Down Expand Up @@ -1052,7 +1055,7 @@ def test_repr_two_children(self):
{
"/": Dataset(coords={"x": [1.0]}),
"/first_child": None,
"/second_child": Dataset({"foo": ("x", [0.0])}),
"/second_child": Dataset({"foo": ("x", [0.0])}, coords={"z": 1.0}),
}
)

Expand All @@ -1067,6 +1070,8 @@ def test_repr_two_children(self):
├── Group: /first_child
└── Group: /second_child
Dimensions: (x: 1)
Coordinates:
z float64 8B 1.0
Data variables:
foo (x) float64 8B 0.0
"""
Expand All @@ -1091,6 +1096,8 @@ def test_repr_two_children(self):
<xarray.DataTree 'second_child'>
Group: /second_child
Dimensions: (x: 1)
Coordinates:
z float64 8B 1.0
Inherited coordinates:
* x (x) float64 8B 1.0
Data variables:
Expand Down Expand Up @@ -1138,6 +1145,104 @@ def test_repr_inherited_dims(self):
).strip()
assert result == expected

@pytest.mark.skipif(
ON_WINDOWS, reason="windows (pre NumPy2) uses int32 instead of int64"
)
def test_doc_example(self):
# regression test for https://github.com/pydata/xarray/issues/9499
time = xr.DataArray(data=["2022-01", "2023-01"], dims="time")
stations = xr.DataArray(data=list("abcdef"), dims="station")
lon = [-100, -80, -60]
lat = [10, 20, 30]
# Set up fake data
wind_speed = xr.DataArray(np.ones((2, 6)) * 2, dims=("time", "station"))
pressure = xr.DataArray(np.ones((2, 6)) * 3, dims=("time", "station"))
air_temperature = xr.DataArray(np.ones((2, 6)) * 4, dims=("time", "station"))
dewpoint = xr.DataArray(np.ones((2, 6)) * 5, dims=("time", "station"))
infrared = xr.DataArray(np.ones((2, 3, 3)) * 6, dims=("time", "lon", "lat"))
true_color = xr.DataArray(np.ones((2, 3, 3)) * 7, dims=("time", "lon", "lat"))
tree = xr.DataTree.from_dict(
{
"/": xr.Dataset(
coords={"time": time},
),
"/weather": xr.Dataset(
coords={"station": stations},
data_vars={
"wind_speed": wind_speed,
"pressure": pressure,
},
),
"/weather/temperature": xr.Dataset(
data_vars={
"air_temperature": air_temperature,
"dewpoint": dewpoint,
},
),
"/satellite": xr.Dataset(
coords={"lat": lat, "lon": lon},
data_vars={
"infrared": infrared,
"true_color": true_color,
},
),
},
)

result = repr(tree)
expected = dedent(
"""
<xarray.DataTree>
Group: /
│ Dimensions: (time: 2)
│ Coordinates:
│ * time (time) <U7 56B '2022-01' '2023-01'
├── Group: /weather
│ │ Dimensions: (station: 6, time: 2)
│ │ Coordinates:
│ │ * station (station) <U1 24B 'a' 'b' 'c' 'd' 'e' 'f'
│ │ Data variables:
│ │ wind_speed (time, station) float64 96B 2.0 2.0 2.0 2.0 ... 2.0 2.0 2.0 2.0
│ │ pressure (time, station) float64 96B 3.0 3.0 3.0 3.0 ... 3.0 3.0 3.0 3.0
│ └── Group: /weather/temperature
│ Dimensions: (time: 2, station: 6)
│ Data variables:
│ air_temperature (time, station) float64 96B 4.0 4.0 4.0 4.0 ... 4.0 4.0 4.0
│ dewpoint (time, station) float64 96B 5.0 5.0 5.0 5.0 ... 5.0 5.0 5.0
└── Group: /satellite
Dimensions: (lat: 3, lon: 3, time: 2)
Coordinates:
* lat (lat) int64 24B 10 20 30
* lon (lon) int64 24B -100 -80 -60
Data variables:
infrared (time, lon, lat) float64 144B 6.0 6.0 6.0 6.0 ... 6.0 6.0 6.0
true_color (time, lon, lat) float64 144B 7.0 7.0 7.0 7.0 ... 7.0 7.0 7.0
"""
).strip()
assert result == expected

result = repr(tree["weather"])
expected = dedent(
"""
<xarray.DataTree 'weather'>
Group: /weather
│ Dimensions: (time: 2, station: 6)
│ Coordinates:
│ * station (station) <U1 24B 'a' 'b' 'c' 'd' 'e' 'f'
│ Inherited coordinates:
│ * time (time) <U7 56B '2022-01' '2023-01'
│ Data variables:
│ wind_speed (time, station) float64 96B 2.0 2.0 2.0 2.0 ... 2.0 2.0 2.0 2.0
│ pressure (time, station) float64 96B 3.0 3.0 3.0 3.0 ... 3.0 3.0 3.0 3.0
└── Group: /weather/temperature
Dimensions: (time: 2, station: 6)
Data variables:
air_temperature (time, station) float64 96B 4.0 4.0 4.0 4.0 ... 4.0 4.0 4.0
dewpoint (time, station) float64 96B 5.0 5.0 5.0 5.0 ... 5.0 5.0 5.0
"""
).strip()
assert result == expected


def _exact_match(message: str) -> str:
return re.escape(dedent(message).strip())
Expand Down
Loading