diff --git a/yt/data_objects/tests/test_data_containers.py b/yt/data_objects/tests/test_data_containers.py index d0f99c6a799..f958de06628 100644 --- a/yt/data_objects/tests/test_data_containers.py +++ b/yt/data_objects/tests/test_data_containers.py @@ -126,8 +126,8 @@ def test_to_frb(self): data_source=dd, ) frb = proj.to_frb((1.0, "unitary"), 64) - assert_equal(frb.radius, (1.0, "unitary")) - assert_equal(frb.buff_size, 64) + assert_equal(frb.radius, ds.quan(1.0, "unitary")) + assert_equal(frb.buff_size, (64, 64)) def test_extract_isocontours(self): # Test isocontour properties for AMRGridData diff --git a/yt/visualization/fixed_resolution.py b/yt/visualization/fixed_resolution.py index 7e1123cdf6f..8e1316f19fc 100644 --- a/yt/visualization/fixed_resolution.py +++ b/yt/visualization/fixed_resolution.py @@ -1,9 +1,11 @@ import sys import weakref from functools import partial +from numbers import Number from typing import TYPE_CHECKING, Optional import numpy as np +from unyt import unyt_quantity from yt._maintenance.deprecation import issue_deprecation_warning from yt._typing import FieldKey, MaskT @@ -584,16 +586,42 @@ class CylindricalFixedResolutionBuffer(FixedResolutionBuffer): def __init__(self, data_source, radius, buff_size, antialias=True, *, filters=None): self.data_source = data_source self.ds = data_source.ds - self.radius = radius - self.buff_size = buff_size + self._set_radius(radius) + if np.isscalar(buff_size): + self.buff_size = (buff_size, buff_size) + else: + self.buff_size = buff_size + self.bounds = self._get_bounds() self.antialias = antialias - self.data = {} + self.data: dict[str, ImageArray] = {} + self.mask: dict[str, MaskT] = {} self._filters = filters if filters is not None else [] ds = getattr(data_source, "ds", None) if ds is not None: ds.plots.append(weakref.proxy(self)) + def _set_radius(self, r) -> None: + if self.ds is None: + # not attempting to solve this case right now + raise RuntimeError("boom 1") # tmp + else: + if ( + isinstance(r, tuple) + and len(r) == 2 + and (isinstance(r[0], Number) and isinstance(r[1], str)) + ): + r = self.ds.quan(*r).to("code_length") + elif isinstance(r, unyt_quantity): + r = self.ds.quan(r).to("code_length") + else: + raise RuntimeError("boom 2") # tmp + self.radius = r + + def _get_bounds(self) -> tuple[float, float, float, float]: + dx = dy = self.radius.item() + return (0.0, dx, 0.0, dy) + @override def _generate_image_and_mask(self, item) -> None: buff = np.zeros(self.buff_size, dtype="f8") @@ -604,7 +632,7 @@ def _generate_image_and_mask(self, item) -> None: self.data_source["theta"], self.data_source["dtheta"], self.data_source[item].astype("float64"), - self.radius, + extents=self.bounds, return_mask=True, ) self.data[item] = ImageArray(