diff --git a/.vscode/settings.json b/.vscode/settings.json index 5cd5a878d..7fa0c1ff8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -48,6 +48,7 @@ "Bigl", "Bigr", "bijective", + "Bivariate", "bmatrix", "boldsymbol", "boxplot", diff --git a/CHANGELOG.md b/CHANGELOG.md index 90e4911f1..2a0c9b976 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). Attention: The newest changes should be on top --> ### Added - +- ENH: allow users to provide custom samplers [#803](https://github.com/RocketPy-Team/RocketPy/pull/803) - ENH: Implement Multivariate Rejection Sampling (MRS) [#738] (https://github.com/RocketPy-Team/RocketPy/pull/738) - ENH: Create a rocketpy file to store flight simulations [#800](https://github.com/RocketPy-Team/RocketPy/pull/800) - ENH: Support for the RSE file format has been added to the library [#798](https://github.com/RocketPy-Team/RocketPy/pull/798) diff --git a/docs/reference/classes/monte_carlo/stochastic_models/custom_sampler.rst b/docs/reference/classes/monte_carlo/stochastic_models/custom_sampler.rst new file mode 100644 index 000000000..54e612c3d --- /dev/null +++ b/docs/reference/classes/monte_carlo/stochastic_models/custom_sampler.rst @@ -0,0 +1,5 @@ +Custom Sampler +--------------------------- + +.. autoclass:: rocketpy.stochastic.CustomSampler + :members: \ No newline at end of file diff --git a/docs/reference/classes/monte_carlo/stochastic_models/index.rst b/docs/reference/classes/monte_carlo/stochastic_models/index.rst index 7bf705347..ca8b2b1e2 100644 --- a/docs/reference/classes/monte_carlo/stochastic_models/index.rst +++ b/docs/reference/classes/monte_carlo/stochastic_models/index.rst @@ -24,3 +24,4 @@ input parameters, enabling robust Monte Carlo simulations. stochastic_rocket stochastic_parachute stochastic_flight + custom_sampler diff --git a/docs/user/custom_sampler.rst b/docs/user/custom_sampler.rst new file mode 100644 index 000000000..8994c9573 --- /dev/null +++ b/docs/user/custom_sampler.rst @@ -0,0 +1,354 @@ +.. _custom_sampler: + +Implementing custom sampler for Stochastic objects +================================================== + +The :ref:`stochastic_usage` documentation teaches how to work with stochastic +objects, discussing the standard initializations, how to create objects and interpret +outputs. Our goal here is to show how to build a custom sampler that gives complete +control of the distributions used. + +Custom Sampler +-------------- + +Rocketpy provides a ``CustomSampler`` abstract class which works as the backbone of +a custom sampler. We begin by first importing it and some other useful modules + +.. jupyter-execute:: + + from rocketpy import CustomSampler + from matplotlib import pyplot as plt + import numpy as np + +In order to use it, we must create a new class that inherits from +it and it **must** implement two methods: *sample* and *reset_seed*. The *sample* +method has one argument, *n_samples*, and must return a list with *n_samples* +entries, each of which is a sample of the desired variable. The *reset_seed* method +has one argument, *seed*, which is used to reset the pseudorandom generators in order +to avoid unwanted dependency across samples. This is especially important when the +``MonteCarlo`` class is used in parallel mode. + +Below, we give an example of how to implement a mixture of two Gaussian +distributions. + +.. jupyter-execute:: + + class TwoGaussianMixture(CustomSampler): + """Class to sample from a mixture of two Gaussian distributions + """ + + def __init__(self, means_tuple, sd_tuple, prob_tuple, seed = None): + """ Creates a sampler for a mixture of two Gaussian distributions + + Parameters + ---------- + means_tuple : 2-tuple + 2-Tuple that contains the mean of each normal distribution of the + mixture + sd_tuple : 2-tuple + 2-Tuple that contains the sd of each normal distribution of the + mixture + prob_tuple : 2-tuple + 2-Tuple that contains the probability of each normal distribution + of the mixture. Its entries should be non-negative and sum up to 1. + """ + np.random.default_rng(seed) + self.means_tuple = means_tuple + self.sd_tuple = sd_tuple + self.prob_tuple = prob_tuple + + def sample(self, n_samples = 1): + """Samples from a mixture of two Gaussian + + Parameters + ---------- + n_samples : int, optional + Number of samples, by default 1 + + Returns + ------- + samples_list + List containing n_samples samples + """ + samples_list = [0] * n_samples + mixture_id_list = np.random.binomial(1, self.prob_tuple[0], n_samples) + for i, mixture_id in enumerate(mixture_id_list): + if mixture_id: + samples_list[i] = np.random.normal(self.means_tuple[0], self.sd_tuple[0]) + else: + samples_list[i] = np.random.normal(self.means_tuple[1], self.sd_tuple[1]) + + return samples_list + + def reset_seed(self, seed=None): + """Resets all associated random number generators + + Parameters + ---------- + seed : int, optional + Seed for the random number generator. + """ + np.random.default_rng(seed) + +This is an example of a distribution that is not implemented in numpy. Note that it is +a general distribution, so we can use it for many different variables. + +.. note:: + You can check more information about the mixture of Gaussian distributions + `here `. + Intuitively, if you think of a single Gaussian as a bell curve distribution, + the mixture distribution resembles two bell curves superimposed, each with mode at their + respective mean. + +Example: Gaussian Mixture for Total Impulse +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to use the new created sampler in a stochastic object, we first need +to build an object. In this example, we will consider a case where the distribution of +the total impulse is a mixture of two gaussian with mean parameters +:math:`(6000, 7000)`, standard deviations :math:`(300, 100)` and mixture probabilities +:math:`(0.7, 0.3)`. + +.. jupyter-execute:: + + means_tuple = (6000, 7000) + sd_tuple = (300, 100) + prob_tuple = (0.7, 0.3) + total_impulse_sampler = TwoGaussianMixture(means_tuple, sd_tuple, prob_tuple) + +Finally, we can create ``StochasticSolidMotor`` object as we did in the example of +:ref:`stochastic_usage`, but we pass the sampler object instead for the *total_impulse* +argument + +.. jupyter-execute:: + + from rocketpy import SolidMotor, StochasticSolidMotor + + motor = SolidMotor( + thrust_source="../data/motors/cesaroni/Cesaroni_M1670.eng", + dry_mass=1.815, + dry_inertia=(0.125, 0.125, 0.002), + nozzle_radius=33 / 1000, + grain_number=5, + grain_density=1815, + grain_outer_radius=33 / 1000, + grain_initial_inner_radius=15 / 1000, + grain_initial_height=120 / 1000, + grain_separation=5 / 1000, + grains_center_of_mass_position=0.397, + center_of_dry_mass_position=0.317, + nozzle_position=0, + burn_time=3.9, + throat_radius=11 / 1000, + coordinate_system_orientation="nozzle_to_combustion_chamber", + ) + + stochastic_motor = StochasticSolidMotor( + solid_motor=motor, + burn_start_time=(0, 0.1, "binomial"), + grains_center_of_mass_position=0.001, + grain_density=10, + grain_separation=1 / 1000, + grain_initial_height=1 / 1000, + grain_initial_inner_radius=0.375 / 1000, + grain_outer_radius=0.375 / 1000, + throat_radius=0.5 / 1000, + nozzle_radius=0.5 / 1000, + nozzle_position=0.001, + total_impulse=total_impulse_sampler, # total impulse using custom sampler! + ) + + stochastic_motor.visualize_attributes() + +Let's generate some random motors and check the distribution of the total impulse + +.. jupyter-execute:: + + total_impulse_samples = [ + stochastic_motor.create_object().total_impulse for _ in range(200) + ] + plt.hist(total_impulse_samples, density = True, bins = 30) + +Introducing dependency between parameters +----------------------------------------- + +Although probabilistic **independency between samples**, i.e. different flight runs, +is desired for Monte Carlo simulations, it is often important to be able to introduce +**dependency between parameters**. A clear example of this is wind speed: if we know +the wind speed in the x-axis, then our forecast model might tells us that the wind +speed y-axis is more likely to be positive than negative, or vice-versa. These +parameters are then correlated, and using samplers that model these correlations make +the Monte Carlo analysis more robust. + +When we use the default numpy samplers, the Monte Carlo analysis samples the +parameters independently from each other. However, using custom samplers, we can +introduce dependency and correlation! It might be a bit tricky, but we will show how +it can be done. First, let us import the modules required + +.. jupyter-execute:: + + from rocketpy import Environment, StochasticEnvironment + from datetime import datetime, timedelta + +Assume we want to model the x and y axis wind speed as a Bivariate Gaussian with +parameters :math:`\mu = (1, 1)` and variance-covariance matrix +:math:`\Sigma = \begin{bmatrix} 0.2 & 0.17 \\ 0.17 & 0.3 \end{bmatrix}`. This will +make the correlation between the speeds be of :math:`0.7`. + +Now, in order to correlate the parameters using different custom samplers, +**the key trick is to create a common generator that will be used by both.** The code +below implements an example of such a generator + +.. jupyter-execute:: + + class BivariateGaussianGenerator: + """Bivariate Gaussian generator used across custom samplers + """ + def __init__(self, mean, cov, seed = None): + """Initializes the generator + + Parameters + ---------- + mean : tuple, list + Tuple or list with mean of bivariate Gaussian + cov : np.array + Variance-Covariance matrix + seed : int, optional + Number to seed random generator, by default None + """ + np.random.default_rng(seed) + self.samples_list = [] + self.samples_generated = 0 + self.used_samples_x = 0 + self.used_samples_y = 0 + self.mean = mean + self.cov = cov + self.generate_samples(1000) + + def generate_samples(self, n_samples = 1): + """Generate samples from bivariate Gaussian and append to sample list + + Parameters + ---------- + n_samples : int, optional + Number of samples to be generated, by default 1 + """ + samples_generated = np.random.multivariate_normal(self.mean, self.cov, n_samples) + self.samples_generated += n_samples + self.samples_list += list(samples_generated) + + def reset_seed(self, seed=None): + np.random.default_rng(seed) + + def get_samples(self, n_samples, axis): + if axis == "x": + if self.samples_generated < self.used_samples_x: + self.generate_samples(n_samples) + samples_list = [ + sample[0] for sample in self.samples_list[self.used_samples_x:(self.used_samples_x + n_samples)] + ] + if axis == "y": + if self.samples_generated < self.used_samples_y: + self.generate_samples(n_samples) + samples_list = [ + sample[1] for sample in self.samples_list[self.used_samples_y:(self.used_samples_y + n_samples)] + ] + self.update_info(n_samples, axis) + return samples_list + + def update_info(self, n_samples, axis): + """Updates the information of the used samples + + Parameters + ---------- + n_samples : int + Number of samples used + axis : str + Which axis was sampled + """ + if axis == "x": + self.used_samples_x += n_samples + if axis == "y": + self.used_samples_y += n_samples + +This generator samples from the bivariate Gaussian and stores then in a *samples_list* +attribute. Then, the idea is to create two samplers for the wind speeds that will share +an object of this class and their sampling methods only get samples from the stored +sample list. + +.. jupyter-execute:: + + class WindXSampler(CustomSampler): + """Samples from X""" + + def __init__(self, bivariate_gaussian_generator): + self.generator = bivariate_gaussian_generator + + def sample(self, n_samples=1): + samples_list = self.generator.get_samples(n_samples, "x") + return samples_list + + def reset_seed(self, seed=None): + self.generator.reset_seed(seed) + + class WindYSampler(CustomSampler): + """Samples from Y""" + + def __init__(self, bivariate_gaussian_generator): + self.generator = bivariate_gaussian_generator + + def sample(self, n_samples=1): + samples_list = self.generator.get_samples(n_samples, "y") + return samples_list + + def reset_seed(self, seed=None): + self.generator.reset_seed(seed) + +Then, we create the objects + +.. jupyter-execute:: + + mean = [1, 2] + cov_mat = [[0.2, 0.171], [0.171, 0.3]] + bivariate_gaussian_generator = BivariateGaussianGenerator(mean, cov_mat) + wind_x_sampler = WindXSampler(bivariate_gaussian_generator) + wind_y_sampler = WindYSampler(bivariate_gaussian_generator) + +With the sample objects created, we only need to create the stochastic objects and +pass them as argument + +.. jupyter-execute:: + + spaceport_env = Environment( + latitude=32.990254, + longitude=-106.974998, + elevation=1400, + datum="WGS84", + ) + spaceport_env.set_atmospheric_model("custom_atmosphere", wind_u = 1, wind_v = 1) + spaceport_env.set_date(datetime.now() + timedelta(days=1)) + + stochastic_environment = StochasticEnvironment( + environment=spaceport_env, + elevation=(1400, 10, "normal"), + gravity=None, + latitude=None, + longitude=None, + ensemble_member=None, + wind_velocity_x_factor=wind_x_sampler, + wind_velocity_y_factor=wind_y_sampler + ) + +Finally, let us check that if there is a correlation between the wind speed in the +two axis + +.. jupyter-execute:: + + wind_velocity_x_samples = [0] * 200 + wind_velocity_y_samples = [0] * 200 + for i in range(200): + stochastic_environment.create_object() + wind_velocity_x_samples[i] = stochastic_environment.obj.wind_velocity_x(0) + wind_velocity_y_samples[i] = stochastic_environment.obj.wind_velocity_y(0) + + np.corrcoef(wind_velocity_x_samples, wind_velocity_y_samples) diff --git a/docs/user/index.rst b/docs/user/index.rst index f23eae25a..a09872ad5 100644 --- a/docs/user/index.rst +++ b/docs/user/index.rst @@ -32,6 +32,7 @@ RocketPy's User Guide :caption: Monte Carlo Simulations Stochastic Classes + Custom Sampler ../notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb ../notebooks/monte_carlo_analysis/monte_carlo_analysis.ipynb ../notebooks/monte_carlo_analysis/parachute_drop_from_helicopter.ipynb diff --git a/docs/user/stochastic.rst b/docs/user/stochastic.rst index 0985ca1da..062b034e2 100644 --- a/docs/user/stochastic.rst +++ b/docs/user/stochastic.rst @@ -88,6 +88,11 @@ passed in a few different ways: used as the parameter value during the simulation. You cannot assign standard \ deviations when using lists, nor can you assign different distribution types. +5. **A CustomSampler object**: \ + An object from a class that inherits from ``CustomSampler``. This object \ + gives you the full control of how the samples are generated. See + :ref:`custom_sampler` for more details. + .. note:: In statistics, the terms "Normal" and "Gaussian" refer to the same type of \ distribution. This distribution is commonly used and is the default for the \ diff --git a/rocketpy/__init__.py b/rocketpy/__init__.py index 1e7731073..f99a70f28 100644 --- a/rocketpy/__init__.py +++ b/rocketpy/__init__.py @@ -44,6 +44,7 @@ from .sensors import Accelerometer, Barometer, GnssReceiver, Gyroscope from .simulation import Flight, MonteCarlo, MultivariateRejectionSampler from .stochastic import ( + CustomSampler, StochasticAirBrakes, StochasticEllipticalFins, StochasticEnvironment, diff --git a/rocketpy/stochastic/__init__.py b/rocketpy/stochastic/__init__.py index b1e146246..ffadfaaaf 100644 --- a/rocketpy/stochastic/__init__.py +++ b/rocketpy/stochastic/__init__.py @@ -5,6 +5,7 @@ associated with each input parameter. """ +from .custom_sampler import CustomSampler from .stochastic_aero_surfaces import ( StochasticAirBrakes, StochasticEllipticalFins, diff --git a/rocketpy/stochastic/custom_sampler.py b/rocketpy/stochastic/custom_sampler.py new file mode 100644 index 000000000..82a06dd9f --- /dev/null +++ b/rocketpy/stochastic/custom_sampler.py @@ -0,0 +1,38 @@ +""" +Provides an abstract class so that users can build custom samplers upon +""" + +from abc import ABC, abstractmethod + + +class CustomSampler(ABC): + """Abstract subclass for user defined samplers""" + + @abstractmethod + def sample(self, n_samples=1): + """Generates samples from the custom distribution + + Parameters + ---------- + n_samples : int, optional + Numbers of samples to be generated + + Returns + ------- + samples_list : list + A list with n_samples elements, each of which is a valid sample + """ + + @abstractmethod + def reset_seed(self, seed=None): + """Resets the seeds of all associated stochastic generators + + Parameters + ---------- + seed : int, optional + Seed for the random number generator. The default is None + + Returns + ------- + None + """ diff --git a/rocketpy/stochastic/stochastic_model.py b/rocketpy/stochastic/stochastic_model.py index fe47a252a..879b61e70 100644 --- a/rocketpy/stochastic/stochastic_model.py +++ b/rocketpy/stochastic/stochastic_model.py @@ -8,6 +8,7 @@ import numpy as np from rocketpy.mathutils.function import Function +from rocketpy.stochastic.custom_sampler import CustomSampler from ..tools import get_distribution @@ -88,7 +89,9 @@ def _set_stochastic(self, seed=None): attr_value = None if input_value is not None: if "factor" in input_name: - attr_value = self._validate_factors(input_name, input_value) + attr_value = self._validate_factors( + input_name, input_value, seed + ) elif input_name not in self.exception_list: if isinstance(input_value, tuple): attr_value = self._validate_tuple(input_name, input_value) @@ -96,9 +99,14 @@ def _set_stochastic(self, seed=None): attr_value = self._validate_list(input_name, input_value) elif isinstance(input_value, (int, float)): attr_value = self._validate_scalar(input_name, input_value) + elif isinstance(input_value, CustomSampler): + attr_value = self._validate_custom_sampler( + input_name, input_value, seed + ) else: raise AssertionError( f"'{input_name}' must be a tuple, list, int, or float" + "or a custom sampler" ) else: attr_value = [getattr(self.obj, input_name)] @@ -280,7 +288,7 @@ def _validate_scalar(self, input_name, input_value, getattr=getattr): # pylint: get_distribution("normal", self.__random_number_generator), ) - def _validate_factors(self, input_name, input_value): + def _validate_factors(self, input_name, input_value, seed): """ Validate factor arguments. @@ -308,8 +316,12 @@ def _validate_factors(self, input_name, input_value): return self._validate_tuple_factor(input_name, input_value) elif isinstance(input_value, list): return self._validate_list_factor(input_name, input_value) + elif isinstance(input_value, CustomSampler): + return self._validate_custom_sampler(input_name, input_value, seed) else: - raise AssertionError(f"`{input_name}`: must be either a tuple or list") + raise AssertionError( + f"`{input_name}`: must be either a tuple or listor a custom sampler" + ) def _validate_tuple_factor(self, input_name, factor_tuple): """ @@ -436,6 +448,33 @@ def _validate_positive_int_list(self, input_name, input_value): isinstance(member, int) and member >= 0 for member in input_value ), f"`{input_name}` must be a list of positive integers" + def _validate_custom_sampler(self, input_name, sampler, seed=None): + """ + Validate a custom sampler. + + Parameters + ---------- + input_name : str + Name of the input argument. + sampler : CustomSampler object + Custom sampler provided by the user + seed : int, optional + Seed for the random number generator. The default is None + + Raises + ------ + AssertionError + If the input is not in a valid format. + """ + try: + sampler.reset_seed(seed) + except RuntimeError as e: + raise RuntimeError( + f"An error occurred in the 'reset_seed' method of {input_name} CustomSampler" + ) from e + + return sampler + def _validate_airfoil(self, airfoil): """ Validate airfoil input. @@ -490,9 +529,17 @@ def dict_generator(self): generated_dict = {} for arg, value in self.__dict__.items(): if isinstance(value, tuple): - generated_dict[arg] = value[-1](value[0], value[1]) + dist_sampler = value[-1] + generated_dict[arg] = dist_sampler(value[0], value[1]) elif isinstance(value, list): generated_dict[arg] = choice(value) if value else value + elif isinstance(value, CustomSampler): + try: + generated_dict[arg] = value.sample(n_samples=1)[0] + except RuntimeError as e: + raise RuntimeError( + f"An error occurred in the 'sample' method of {arg} CustomSampler" + ) from e self.last_rnd_dict = generated_dict yield generated_dict @@ -527,6 +574,12 @@ def format_attribute(attr, value): f"{nominal_value:.5f} ± " f"{std_dev:.5f} ({dist_func.__name__})" ) + elif isinstance(value, CustomSampler): + sampler_name = type(value).__name__ + return ( + f"\t{attr.ljust(max_str_length)} " + f"\t{sampler_name.ljust(max_str_length)} " + ) return None attributes = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} @@ -550,6 +603,9 @@ def format_attribute(attr, value): list_attributes = [ attr for attr, val in items if isinstance(val, list) and len(val) > 1 ] + custom_attributes = [ + attr for attr, val in items if isinstance(val, CustomSampler) + ] if constant_attributes: report.append("\nConstant Attributes:") @@ -568,5 +624,10 @@ def format_attribute(attr, value): report.extend( format_attribute(attr, attributes[attr]) for attr in list_attributes ) + if custom_attributes: + report.append("\nStochastic Attributes with Custom user samplers:") + report.extend( + format_attribute(attr, attributes[attr]) for attr in custom_attributes + ) print("\n".join(filter(None, report))) diff --git a/tests/conftest.py b/tests/conftest.py index 370721cf8..980e0b6ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ "tests.fixtures.units.numerical_fixtures", "tests.fixtures.monte_carlo.monte_carlo_fixtures", "tests.fixtures.monte_carlo.stochastic_fixtures", + "tests.fixtures.monte_carlo.custom_sampler_fixtures", "tests.fixtures.monte_carlo.stochastic_motors_fixtures", "tests.fixtures.sensors.sensors_fixtures", "tests.fixtures.generic_surfaces.generic_surfaces_fixtures", diff --git a/tests/fixtures/monte_carlo/custom_sampler_fixtures.py b/tests/fixtures/monte_carlo/custom_sampler_fixtures.py new file mode 100644 index 000000000..8a4ff497d --- /dev/null +++ b/tests/fixtures/monte_carlo/custom_sampler_fixtures.py @@ -0,0 +1,76 @@ +"""This file contains fixtures of CustomSampler used in stochastic classes.""" + +import numpy as np +import pytest + +from rocketpy import CustomSampler + + +@pytest.fixture +def elevation_sampler(): + """Fixture to create mixture of two gaussian sampler""" + means_tuple = [1400, 1500] + sd_tuple = [40, 50] + prob_tuple = [0.4, 0.6] + return TwoGaussianMixture(means_tuple, sd_tuple, prob_tuple) + + +class TwoGaussianMixture(CustomSampler): + """Class to sample from a mixture of two Gaussian distributions""" + + def __init__(self, means_tuple, sd_tuple, prob_tuple, seed=None): + """Creates a sampler for a mixture of two Gaussian distributions + + Parameters + ---------- + means_tuple : 2-tuple + 2-Tuple that contains the mean of each normal distribution of the + mixture + sd_tuple : 2-tuple + 2-Tuple that contains the sd of each normal distribution of the + mixture + prob_tuple : 2-tuple + 2-Tuple that contains the probability of each normal distribution of the + mixture. Its entries should be non-negative and sum up to 1. + """ + np.random.default_rng(seed) + self.means_tuple = means_tuple + self.sd_tuple = sd_tuple + self.prob_tuple = prob_tuple + + def sample(self, n_samples=1): + """Samples from a mixture of two Gaussian + + Parameters + ---------- + n_samples : int, optional + Number of samples, by default 1 + + Returns + ------- + samples_list + List containing n_samples samples + """ + samples_list = [0] * n_samples + mixture_id_list = np.random.binomial(1, self.prob_tuple[0], n_samples) + for i, mixture_id in enumerate(mixture_id_list): + if mixture_id: + samples_list[i] = np.random.normal( + self.means_tuple[0], self.sd_tuple[0] + ) + else: + samples_list[i] = np.random.normal( + self.means_tuple[1], self.sd_tuple[1] + ) + + return samples_list + + def reset_seed(self, seed=None): + """Resets all associated random number generators + + Parameters + ---------- + seed : int, optional + Seed for the random number generator. + """ + np.random.default_rng(seed) diff --git a/tests/fixtures/monte_carlo/stochastic_fixtures.py b/tests/fixtures/monte_carlo/stochastic_fixtures.py index bf576e5ed..6610666cf 100644 --- a/tests/fixtures/monte_carlo/stochastic_fixtures.py +++ b/tests/fixtures/monte_carlo/stochastic_fixtures.py @@ -43,6 +43,36 @@ def stochastic_environment(example_spaceport_env): ) +@pytest.fixture +def stochastic_environment_custom_sampler(example_spaceport_env, elevation_sampler): + """This fixture is used to create a stochastic environment object for the + Calisto flight using a custom sampler for the elevation. + + Parameters + ---------- + example_spaceport_env : Environment + This is another fixture. + + elevation_sampler: CustomSampler + This is another fixture. + + Returns + ------- + StochasticEnvironment + The stochastic environment object + """ + return StochasticEnvironment( + environment=example_spaceport_env, + elevation=elevation_sampler, + gravity=None, + latitude=None, + longitude=None, + ensemble_member=None, + wind_velocity_x_factor=(1.0, 0.033, "normal"), + wind_velocity_y_factor=(1.0, 0.033, "normal"), + ) + + @pytest.fixture def stochastic_nose_cone(calisto_nose_cone): """This fixture is used to create a StochasticNoseCone object for the diff --git a/tests/unit/stochastic/test_custom_sampler.py b/tests/unit/stochastic/test_custom_sampler.py new file mode 100644 index 000000000..90774ac50 --- /dev/null +++ b/tests/unit/stochastic/test_custom_sampler.py @@ -0,0 +1,21 @@ +from rocketpy.environment.environment import Environment + + +def test_create_object(stochastic_environment_custom_sampler): + """Test create object method of StochasticEnvironment class. + + This test checks if the create_object method of the StochasticEnvironment + class creates a StochasticEnvironment object from the randomly generated + input arguments. + + Parameters + ---------- + stochastic_environment : StochasticEnvironment + StochasticEnvironment object to be tested. + + Returns + ------- + None + """ + obj = stochastic_environment_custom_sampler.create_object() + assert isinstance(obj, Environment) diff --git a/tests/unit/test_stochastic_model.py b/tests/unit/test_stochastic_model.py index 77c94fb40..0d0a13311 100644 --- a/tests/unit/test_stochastic_model.py +++ b/tests/unit/test_stochastic_model.py @@ -7,6 +7,7 @@ "stochastic_rail_buttons", "stochastic_main_parachute", "stochastic_environment", + "stochastic_environment_custom_sampler", "stochastic_tail", "stochastic_calisto", ],