Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release/1.0.0b9 #512

Merged
merged 11 commits into from
Jul 18, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
exclude: ^.*fits
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ sampling_function = btk.sampling_functions.DefaultSampling(
# setup generator to create batches of blends
batch_size = 100
draw_generator = btk.draw_blends.CatsimGenerator(
catalog, sampling_function, survey, batch_size, stamp_size
catalog, sampling_function, survey, batch_size
)

# get batch of blends
Expand Down
20 changes: 15 additions & 5 deletions btk/blend_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _validate_catalog(self, catalog: Table):
)
return catalog

def _validate_segmentation(self, segmentation):
def _validate_segmentation(self, segmentation: Optional[np.ndarray]):
if segmentation is not None:
if self.image_size is None or self.n_bands is None:
raise ValueError("`image_size` must be specified if segmentation is provided")
Expand All @@ -233,14 +233,17 @@ def _validate_segmentation(self, segmentation):
"The predicted segmentation of at least one of your deblended images "
"has the wrong shape. It should be `(max_n_sources, image_size, image_size)`."
)
if segmentation.min() < 0 or segmentation.max() > 1:
cond1 = np.any(np.greater(segmentation, 0) & np.less(segmentation, 1))
cond2 = np.any(np.less(segmentation, 0) | np.greater(segmentation, 1))
if cond1 or cond2:
raise ValueError(
"The predicted segmentation of at least one of your deblended images "
"has values outside the range [0, 1]."
"has values different than 0 or 1."
)
return segmentation.astype(bool)
return segmentation

def _validate_deblended_images(self, deblended_images):
def _validate_deblended_images(self, deblended_images: Optional[np.ndarray]):
if deblended_images is not None:
if self.image_size is None or self.n_bands is None:
raise ValueError(
Expand Down Expand Up @@ -335,7 +338,14 @@ def _validate_segmentation(self, segmentation: Optional[np.ndarray] = None) -> n
self.image_size,
self.image_size,
)
assert segmentation.min() >= 0 and segmentation.max() <= 1
cond1 = np.any(np.greater(segmentation, 0) & np.less(segmentation, 1))
cond2 = np.any(np.less(segmentation, 0) | np.greater(segmentation, 1))
if cond1 or cond2:
raise ValueError(
"The predicted segmentation of at least one of your deblended images "
"has values different than 0 or 1."
)
return segmentation.astype(bool)
return segmentation

def _validate_deblended_images(
Expand Down
2 changes: 1 addition & 1 deletion btk/deblend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from astropy import units
from astropy.coordinates import SkyCoord
from astropy.table import Table
from galcheat.utilities import mean_sky_level
from numpy.linalg import LinAlgError
from skimage.feature import peak_local_max
from surveycodex.utilities import mean_sky_level

from btk.blend_batch import (
BlendBatch,
Expand Down
9 changes: 2 additions & 7 deletions btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from astropy.table import Column, Table
from astropy.wcs import WCS
from galcheat.utilities import mag2counts, mean_sky_level
from surveycodex.utilities import mag2counts, mean_sky_level
from tqdm.auto import tqdm

from btk.blend_batch import BlendBatch, MultiResolutionBlendBatch
Expand Down Expand Up @@ -148,7 +148,6 @@ def __init__(
sampling_function: SamplingFunction,
surveys: Union[List[Survey], Survey],
batch_size: int = 8,
stamp_size: float = 24.0,
njobs: int = 1,
verbose: bool = False,
use_bar: bool = False,
Expand All @@ -165,7 +164,6 @@ def __init__(
surveys: List of BTK Survey objects or
single BTK Survey object.
batch_size: Number of blends generated per batch
stamp_size: Size of the stamps, in arcseconds
njobs: Number of njobs to use; defines the number of minibatches
verbose: Indicates whether additionnal information should be printed
use_bar: Whether to use progress bar (default: False)
Expand All @@ -187,7 +185,7 @@ def __init__(
self.max_number = self.blend_generator.max_number
self.apply_shear = apply_shear
self.augment_data = augment_data
self.stamp_size = stamp_size
self.stamp_size = sampling_function.stamp_size
self.use_bar = use_bar
self._set_surveys(surveys)

Expand Down Expand Up @@ -523,7 +521,6 @@ def __init__(
sampling_function: SamplingFunction,
surveys: List[Survey],
batch_size: int = 8,
stamp_size: float = 24.0,
njobs: int = 1,
verbose: bool = False,
add_noise: str = "all",
Expand All @@ -541,7 +538,6 @@ def __init__(
sampling_function: See parent class.
surveys: See parent class.
batch_size: See parent class.
stamp_size: See parent class.
njobs: See parent class.
verbose: See parent class.
add_noise: See parent class.
Expand All @@ -563,7 +559,6 @@ def __init__(
sampling_function,
surveys,
batch_size,
stamp_size,
njobs,
verbose,
use_bar,
Expand Down
52 changes: 39 additions & 13 deletions btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@


def _get_single_ksb_ellipticity(
image: np.ndarray, psf: GSObject, pixel_scale: float, verbose=False
image: np.ndarray, centroid: np.ndarray, psf: GSObject, pixel_scale: float, verbose=False
) -> Tuple[float, float]:
"""Utility function to measure ellipticity using the KSB method.

Args:
image: Image of a single, isolated galaxy with shape (H, W).
image: Image of a single, isolated galaxy with shape (h, w).
centroid: The centroid of the galaxy in the image, with shape (2,). Following the GalSim
convention for offsets.
psf: A galsim object containing the PSF of the single, isolated galaxy.
pixel_scale: The pixel scale of the galaxy image.
verbose: Whether to print errors if they happen when estimating ellipticity.
Expand All @@ -23,33 +25,57 @@ def _get_single_ksb_ellipticity(
Tuple of (g1, g2) containing measured shapes.
"""
psf_image = galsim.Image(image.shape[0], image.shape[1], scale=pixel_scale)
psf_image = psf.drawImage(psf_image)
gal_image = galsim.Image(image, scale=pixel_scale)
psf_image = psf.drawImage(psf_image)
pos = galsim.PositionD(centroid)

res = galsim.hsm.EstimateShear(gal_image, psf_image, shear_est="KSB", strict=False)
res = galsim.hsm.EstimateShear(
gal_image, psf_image, shear_est="KSB", strict=False, guess_centroid=pos
)
output = (res.corrected_g1, res.corrected_g2)
if res.error_message != "" and verbose:
print(
f"Shear measurement error: '{res.error_message }'. \
This error may happen for faint galaxies or inaccurate detections."
)
if res.error_message != "": # absorbs all (10, -10) and makes them np.nan
output = (np.nan, np.nan)
if verbose:
print(
f"Shear measurement error: '{res.error_message }'. \
This error may happen for faint galaxies or inaccurate detections."
)
return output


def get_ksb_ellipticity(
images: np.ndarray, psf: GSObject, pixel_scale: float, verbose=False
images: np.ndarray, centroids: np.ndarray, psf: GSObject, pixel_scale: float, verbose=False
) -> np.ndarray:
"""Return ellipticities of both true and detected galaxies, assuming they are matched."""
"""Calculate the KSB ellipticities of a batched array of isolated galaxy images.

The galaxy images are assumed to all correspond to single band, and the input PSF is assumed
to be the same for all images.

If the shear measurement fails or the image is empty (no flux), then `np.nan` is returned for
the corresponding ellipticity.

Args:
images: Array of batch isolated images with shape (batch_size, max_n_sources, h, w)
centroids: An array of centers for each galaxy using the GalSim convention where the
center of the lower-left pixel is (image.xmin, image.ymin). The shape of this array is
(batch_size, max_n_sources, 2).
psf: a GalSim GSObject containing the PSF common to all galaxies.
pixel_scale: The pixel scale of the galaxy images.
verbose: Whether an error message should be printed if the ellipticity measurement fails
for any one of the galaxies.

Returns:
An array containing the measured ellipticities of shape (batch_size, max_n_sources, 2)
"""
# psf is assumed to be the same for the entire batch and correspond to selected band.
assert images.ndim == 4 # (batch_size, max_n_sources, H, W)
assert images.ndim == 4
batch_size, max_n_sources, _, _ = images.shape
ellipticities = np.zeros((batch_size, max_n_sources, 2))
for ii in range(batch_size):
for jj in range(max_n_sources):
if np.sum(images[ii, jj]) > 0:
ellipticities[ii, jj] = _get_single_ksb_ellipticity(
images[ii, jj], psf, pixel_scale, verbose=verbose
images[ii, jj], centroids[ii, jj], psf, pixel_scale, verbose=verbose
)
else:
ellipticities[ii, jj] = (np.nan, np.nan)
Expand Down
6 changes: 5 additions & 1 deletion btk/metrics/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def _get_data(self, iso_images1: np.ndarray, iso_images2: np.ndarray) -> Dict[st


class MSE(ReconstructionMetric):
"""MSE class metric."""
"""MSE class metric.

Note that this metric can become diluted as the postage stamp size grows, as it does not
exclude pixels with a common value of zero in the images it compares.
"""

def _get_data(self, iso_images1: np.ndarray, iso_images2: np.ndarray) -> Dict[str, np.ndarray]:
return {"mse": self._get_recon_metric(iso_images1, iso_images2, mse)}
Expand Down
10 changes: 9 additions & 1 deletion btk/metrics/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ def _get_data(


class IoU(SegmentationMetric):
"""Intersection-over-Union class metric."""
"""Intersection-over-Union class metric.

Note that this metric assumes that the input segmentation arrays are booleans. An error is
raised if that condition is not met.
"""

def _get_data(
self,
Expand All @@ -31,6 +35,10 @@ def _get_data(
) -> Dict[str, np.ndarray]:
assert seg1.shape == seg2.shape
assert seg1.ndim == 4 # batch, max_n_sources, x, y

if not (seg1.dtype == "bool" and seg2.dtype == "bool"):
raise TypeError("Segmentation arrays for the IoU metric should be of boolean type.")

ious = np.full((self.batch_size, seg1.shape[1]), fill_value=np.nan)
for ii in range(self.batch_size):
n_sources1 = np.sum(np.sum(seg1[ii], axis=(-1, -2)) > 0)
Expand Down
3 changes: 2 additions & 1 deletion btk/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def mse(images1: np.ndarray, images2: np.ndarray) -> np.ndarray:
def iou(seg1: np.ndarray, seg2: np.ndarray) -> np.ndarray:
"""Calculates intersection-over-union (IoU) given two semgentation arrays.

The segmentation arrays should each have values of 1 or 0s only.
This metric assumes that the input arrays are boolean. Otherwise the arrays are
casted to boolean arrays before the computation.

Args:
seg1: Array of shape `NHW` containing `N` segmentation maps each of
Expand Down
Loading
Loading