Skip to content

Commit

Permalink
Merge pull request #25 from ahuang314/main
Browse files Browse the repository at this point in the history
Added a WeightedCatalog class with corresponding test functions
  • Loading branch information
swagnercarena authored May 29, 2024
2 parents fc64581 + 47fcf35 commit 7996c83
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 5 deletions.
13 changes: 10 additions & 3 deletions paltax/image_simulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,22 @@ def _prepare_all_source_models():
all_source_models = []
cosmos_path = str(pathlib.Path(__file__).parent)
cosmos_path += '/test_files/cosmos_galaxies_testing.npz'
catalog_weights = jnp.array([2.0, 5.0])
for model in source_models.__all__:
# CosmosCatalog model required initialization parameters.
if model != 'CosmosCatalog':
if model == 'CosmosCatalog':
all_source_models.append(
source_models.__getattribute__(model)()
source_models.__getattribute__(model)(cosmos_path)
)
elif model == 'WeightedCatalog':
all_source_models.append(
source_models.__getattribute__(model)(
cosmos_path, catalog_weights
)
)
else:
all_source_models.append(
source_models.__getattribute__(model)(cosmos_path)
source_models.__getattribute__(model)()
)
return tuple(all_source_models)

Expand Down
105 changes: 104 additions & 1 deletion paltax/source_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from paltax import cosmology_utils
from paltax import utils

__all__ = ['Interpol', 'SersicElliptic', 'CosmosCatalog']
__all__ = [
'Interpol', 'SersicElliptic', 'CosmosCatalog', 'WeightedCatalog'
]


class _SourceModelBase():
Expand Down Expand Up @@ -326,6 +328,31 @@ def convert_to_angular(
all_kwargs['galaxy_index'] * cosmology_params['cosmos_n_images']
).astype(int)

return CosmosCatalog._convert_to_angular(
all_kwargs, cosmology_params, galaxy_index
)

@staticmethod
def _convert_to_angular(
all_kwargs: Dict[str, jnp.ndarray],
cosmology_params: Dict[str, Union[float, int, jnp.ndarray]],
galaxy_index: int
) -> Dict[str, jnp.ndarray]:
"""Convert any parameters in physical units to angular units.
Args:
all_kwargs: All of the arguments, possibly including some in
physical units.
cosmology_params: Cosmological parameters that define the universe's
expansion.
Returns:
Arguments with any physical units parameters converted to angular
units.
Notes:
Galaxy index must have already been converted to an index.
"""
# Read the catalog values directly from the stored arrays.
z_catalog = cosmology_params['cosmos_redshifts'][galaxy_index]
pixel_scale_catalog = (
Expand Down Expand Up @@ -395,3 +422,79 @@ def k_correct_image(z_old: float, z_new: float) -> float:
mag_k_correction = utils.get_k_correction(z_new)
mag_k_correction -= utils.get_k_correction(z_old)
return 10 ** (-mag_k_correction / 2.5)


class WeightedCatalog(CosmosCatalog):
"""Light profiles from catalog with custom weights
"""

def __init__(self, cosmos_path: str, catalog_weights: jnp.ndarray):
"""Initialize the path to the catalog galaxies and catalog weights.
Args:
cosmos_path: Path to the npz file containing the cosmos images,
redshift array, and pixel sizes.
catalog_weights: Weights for the sources in the catalog. Do not
need to be normalized.
"""
# Save the cosmos image path.
super().__init__(cosmos_path=cosmos_path)

# Turns the catalog_weights pdf into a normalized cdf
catalog_weights_cdf = (
jnp.cumsum(catalog_weights) / jnp.sum(catalog_weights)
)
self.catalog_weights_cdf = catalog_weights_cdf

def modify_cosmology_params(
self,
cosmology_params: Dict[str, Union[float, int, jnp.ndarray]]
) -> Dict[str, Union[float, int, jnp.ndarray]]:
"""Modify cosmology params to include information required by model.
Args:
cosmology_params: Cosmological parameters that define the universe's
expansion. Must be mutable.
Returns:
Modified cosmology parameters.
"""
cosmology_params = super().modify_cosmology_params(
cosmology_params=cosmology_params
)
cosmology_params['catalog_weights_cdf'] = self.catalog_weights_cdf

n_weights = len(cosmology_params['catalog_weights_cdf'])
if cosmology_params['cosmos_n_images'] != n_weights:
raise ValueError(
f'Number of weights {n_weights} should be equal to the ' +
f'number of sources {cosmology_params["cosmos_n_images"]}'
)

return cosmology_params

@staticmethod
def convert_to_angular(
all_kwargs: Dict[str, jnp.ndarray],
cosmology_params: Dict[str, Union[float, int, jnp.ndarray]]
) -> Dict[str, jnp.ndarray]:
"""Convert any parameters in physical units to angular units.
Args:
all_kwargs: All of the arguments, possibly including some in
physical units.
cosmology_params: Cosmological parameters that define the universe's
expansion.
Returns:
Arguments with any physical units parameters converted to angular
units.
"""
# Select the galaxy index using the weighted distribution
galaxy_index = jnp.searchsorted(
cosmology_params['catalog_weights_cdf'], all_kwargs['galaxy_index']
)

return CosmosCatalog._convert_to_angular(
all_kwargs, cosmology_params, galaxy_index
)
93 changes: 92 additions & 1 deletion paltax/source_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@
})


def _prepare_catalog_weights():
# Generate a fixed set of catalog weights.
catalog_weights = jnp.array([0.2, 10.0])
catalog_weights_cdf = jnp.cumsum(catalog_weights) / jnp.sum(catalog_weights)
return catalog_weights, catalog_weights_cdf


def _prepare_cosmology_params(
cosmology_params_init, z_lookup_max, dz, r_min=1e-4, r_max=1e3,
n_r_bins=2):
Expand Down Expand Up @@ -288,7 +295,7 @@ def test__b_n(self, n_sersic, expected):


class CosmosCatalogTest(chex.TestCase, parameterized.TestCase):
"""Runs tests of elliptical Sersic brightness functions."""
"""Runs tests of CosmosCatalog functions."""

def test__init__(self):
# Test that the intialization saves the path.
Expand Down Expand Up @@ -448,5 +455,89 @@ def test_k_correct_image(self, z_old, z_new):
expected, k_correct_image(z_old, z_new), places=6
)


class WeightedCatalogTest(chex.TestCase):
"""Runs tests of WeightedCatalog functions."""

def test__init__(self):
# Test that the intialization saves the path and weights.
catalog_weights, catalog_weights_cdf = _prepare_catalog_weights()
weighted_catalog = source_models.WeightedCatalog(
COSMOS_TEST_PATH, catalog_weights
)
self.assertEqual(weighted_catalog.cosmos_path, COSMOS_TEST_PATH)
np.testing.assert_array_almost_equal(
catalog_weights_cdf, weighted_catalog.catalog_weights_cdf
)

def test_modify_cosmology_params(self):
# Test that the weights are saved to the cosmology params.
catalog_weights, catalog_weights_cdf = _prepare_catalog_weights()
weighted_catalog = source_models.WeightedCatalog(
COSMOS_TEST_PATH, catalog_weights
)
cosmology_params = {}
cosmology_params = weighted_catalog.modify_cosmology_params(
cosmology_params
)

self.assertEqual(cosmology_params['cosmos_n_images'], 2)
np.testing.assert_array_almost_equal(
cosmology_params['catalog_weights_cdf'], catalog_weights_cdf
)

# Makes sure that when the number of weights don't match the
# number of sources, an error is raised
weighted_catalog = source_models.WeightedCatalog(
COSMOS_TEST_PATH, catalog_weights[1:]
)
with self.assertRaises(ValueError):
cosmology_params = weighted_catalog.modify_cosmology_params(
cosmology_params
)


@chex.all_variants
def test_convert_to_angular(self):
# Test that we sample accoding to the weights
catalog_weights, _ = _prepare_catalog_weights()
weighted_catalog = source_models.WeightedCatalog(
COSMOS_TEST_PATH, catalog_weights
)
cosmology_params = _prepare_cosmology_params(
COSMOLOGY_PARAMS_INIT, 1.0, 0.01
)
cosmology_params = weighted_catalog.modify_cosmology_params(
cosmology_params
)
all_kwargs = {
'galaxy_index': 0.01,
'amp': 1.0,
'z_source': cosmology_params['cosmos_redshifts'][0],
'output_ab_zeropoint': 23.5,
'catalog_ab_zeropoint': 23.5
}
convert_to_angular = self.variant(weighted_catalog.convert_to_angular)

# Makes sure that the first image is returned when the galaxy index
# is below 0.02
angular_kwargs = convert_to_angular(all_kwargs, cosmology_params)
np.testing.assert_array_almost_equal(
angular_kwargs['image'],
(cosmology_params['cosmos_images'][0] /
cosmology_params['cosmos_pixel_sizes'][0] ** 2),
decimal=4)

# Makes sure that the second image is returned when the galaxy index
# is above 0.02
all_kwargs['galaxy_index'] = 0.1
angular_kwargs = convert_to_angular(all_kwargs, cosmology_params)
np.testing.assert_array_almost_equal(
angular_kwargs['image'],
(cosmology_params['cosmos_images'][1] /
cosmology_params['cosmos_pixel_sizes'][1] ** 2),
decimal=4)


if __name__ == '__main__':
absltest.main()

0 comments on commit 7996c83

Please sign in to comment.