Skip to content

Commit

Permalink
MNT: Allow for encoding customization of MonteCarlo.
Browse files Browse the repository at this point in the history
  • Loading branch information
phmbressan committed Dec 7, 2024
1 parent 715338e commit 7834947
Show file tree
Hide file tree
Showing 11 changed files with 2,206 additions and 1,797 deletions.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

167 changes: 90 additions & 77 deletions docs/notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb

Large diffs are not rendered by default.

48 changes: 10 additions & 38 deletions rocketpy/_encoders.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
"""Defines a custom JSON encoder for RocketPy objects."""

import base64
import json
from datetime import datetime
from importlib import import_module

import dill
import numpy as np

from rocketpy.mathutils.function import Function


class RocketPyEncoder(json.JSONEncoder):
"""Custom JSON encoder for RocketPy objects. It defines how to encode
different types of objects to a JSON supported format."""

def __init__(self, *args, **kwargs):
self.include_outputs = kwargs.pop("include_outputs", False)
self.include_function_data = kwargs.pop("include_function_data", True)
super().__init__(*args, **kwargs)

def default(self, o):
Expand Down Expand Up @@ -43,6 +44,13 @@ def default(self, o):
return [o.year, o.month, o.day, o.hour]

Check warning on line 44 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L44

Added line #L44 was not covered by tests
elif hasattr(o, "__iter__") and not isinstance(o, str):
return list(o)
elif isinstance(o, Function):
if not self.include_function_data:
return str(o)

Check warning on line 49 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L49

Added line #L49 was not covered by tests
else:
encoding = o.to_dict(self.include_outputs)
encoding["signature"] = get_class_signature(o)
return encoding
elif hasattr(o, "to_dict"):
encoding = o.to_dict(self.include_outputs)
encoding = remove_circular_references(encoding)
Expand Down Expand Up @@ -155,39 +163,3 @@ def remove_circular_references(obj_dict):
obj_dict.pop("plots", None)

return obj_dict


def to_hex_encode(obj, encoder=base64.b85encode):
"""Converts an object to hex representation using dill.
Parameters
----------
obj : object
Object to be converted to hex.
encoder : callable, optional
Function to encode the bytes. Default is base64.b85encode.
Returns
-------
bytes
Object converted to bytes.
"""
return encoder(dill.dumps(obj)).hex()


def from_hex_decode(obj_bytes, decoder=base64.b85decode):
"""Converts an object from hex representation using dill.
Parameters
----------
obj_bytes : str
Hex string to be converted to object.
decoder : callable, optional
Function to decode the bytes. Default is base64.b85decode.
Returns
-------
object
Object converted from bytes.
"""
return dill.loads(decoder(bytes.fromhex(obj_bytes)))
4 changes: 2 additions & 2 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
RBFInterpolator,
)

from ..plots.plot_helpers import show_or_save_plot
from rocketpy.tools import from_hex_decode, to_hex_encode

from rocketpy._encoders import from_hex_decode, to_hex_encode
from ..plots.plot_helpers import show_or_save_plot

# Numpy 1.x compatibility,
# TODO: remove these lines when all dependencies support numpy>=2.0.0
Expand Down
2 changes: 1 addition & 1 deletion rocketpy/rocket/parachute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from rocketpy._encoders import from_hex_decode, to_hex_encode
from rocketpy.tools import from_hex_decode, to_hex_encode

from ..mathutils.function import Function
from ..prints.parachute_prints import _ParachutePrints
Expand Down
36 changes: 31 additions & 5 deletions rocketpy/simulation/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __init__(

# pylint: disable=consider-using-with
def simulate(
self, number_of_simulations, append=False
self, number_of_simulations, append=False, **kwargs
): # pylint: disable=too-many-statements
"""
Runs the Monte Carlo simulation and saves all data.
Expand All @@ -185,6 +185,17 @@ def simulate(
append : bool, optional
If True, the results will be appended to the existing files. If
False, the files will be overwritten. Default is False.
kwargs : dict
Custom arguments for simulation export of the ``inputs`` file. Options
are:
* ``include_outputs``: whether to also include outputs data of the
simulation. Default is ``False``.
* ``include_function_data``: whether to include ``rocketpy.Function``
results into the export. Default is ``True``.
See ``rocketpy._encoders.RocketPyEncoder`` for more information.
Returns
-------
Expand All @@ -204,6 +215,7 @@ def simulate(
overwritten. Make sure to save the files with the results before
running the simulation again with `append=False`.
"""
self._export_config = kwargs

Check warning on line 218 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L218

Added line #L218 was not covered by tests
# Create data files for inputs, outputs and error logging
open_mode = "a" if append else "w"
input_file = open(self._input_file, open_mode, encoding="utf-8")
Expand All @@ -224,11 +236,21 @@ def simulate(
self.__run_single_simulation(input_file, output_file)
except KeyboardInterrupt:
print("Keyboard Interrupt, files saved.")
error_file.write(json.dumps(self._inputs_dict, cls=RocketPyEncoder) + "\n")
error_file.write(

Check warning on line 239 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L239

Added line #L239 was not covered by tests
json.dumps(
self._inputs_dict, cls=RocketPyEncoder, **self._export_config
)
+ "\n"
)
self.__close_files(input_file, output_file, error_file)
except Exception as error:
print(f"Error on iteration {self.__iteration_count}: {error}")
error_file.write(json.dumps(self._inputs_dict, cls=RocketPyEncoder) + "\n")
error_file.write(

Check warning on line 248 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L248

Added line #L248 was not covered by tests
json.dumps(
self._inputs_dict, cls=RocketPyEncoder, **self._export_config
)
+ "\n"
)
self.__close_files(input_file, output_file, error_file)
raise error
finally:
Expand Down Expand Up @@ -393,8 +415,12 @@ def __export_flight_data(
) from e
results = results | additional_exports

input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n")
output_file.write(json.dumps(results, cls=RocketPyEncoder) + "\n")
input_file.write(

Check warning on line 418 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L418

Added line #L418 was not covered by tests
json.dumps(inputs_dict, cls=RocketPyEncoder, **self._export_config) + "\n"
)
output_file.write(

Check warning on line 421 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L421

Added line #L421 was not covered by tests
json.dumps(results, cls=RocketPyEncoder, **self._export_config) + "\n"
)

def __check_export_list(self, export_list):
"""
Expand Down
38 changes: 38 additions & 0 deletions rocketpy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
between minor versions if necessary, although this will be always avoided.
"""

import base64
import functools
import importlib
import importlib.metadata
Expand All @@ -15,6 +16,7 @@
import time
from bisect import bisect_left

import dill
import matplotlib.pyplot as plt
import numpy as np
import pytz
Expand Down Expand Up @@ -1167,6 +1169,42 @@ def get_matplotlib_supported_file_endings():
return filetypes


def to_hex_encode(obj, encoder=base64.b85encode):
"""Converts an object to hex representation using dill.
Parameters
----------
obj : object
Object to be converted to hex.
encoder : callable, optional
Function to encode the bytes. Default is base64.b85encode.
Returns
-------
bytes
Object converted to bytes.
"""
return encoder(dill.dumps(obj)).hex()


def from_hex_decode(obj_bytes, decoder=base64.b85decode):
"""Converts an object from hex representation using dill.
Parameters
----------
obj_bytes : str
Hex string to be converted to object.
decoder : callable, optional
Function to decode the bytes. Default is base64.b85decode.
Returns
-------
object
Object converted from bytes.
"""
return dill.loads(decoder(bytes.fromhex(obj_bytes)))


if __name__ == "__main__":
import doctest

Expand Down
8 changes: 3 additions & 5 deletions tests/fixtures/motor/tanks_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
LevelBasedTank,
MassBasedTank,
MassFlowRateBasedTank,
SphericalTank,
TankGeometry,
UllageBasedTank,
)
Expand Down Expand Up @@ -430,9 +431,7 @@ def oxidizer_tank(oxidizer_fluid, oxidizer_pressurant, propellant_tank_geometry)


@pytest.fixture
def spherical_oxidizer_tank(
oxidizer_fluid, oxidizer_pressurant, spherical_oxidizer_geometry
):
def spherical_oxidizer_tank(oxidizer_fluid, oxidizer_pressurant):
"""An example of a oxidizer spherical tank.
Parameters
Expand All @@ -447,12 +446,11 @@ def spherical_oxidizer_tank(
-------
rocketpy.LevelBasedTank
"""
geometry = SphericalTank(0.051)
liquid_level = Function(lambda t: 0.1 * np.exp(-t / 2) - 0.05)
oxidizer_tank = LevelBasedTank(
name="Lox Tank",
flux_time=10,
geometry=spherical_oxidizer_geometry,
geometry=SphericalTank(0.0501),
liquid=oxidizer_fluid,
gas=oxidizer_pressurant,
liquid_height=liquid_level,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_tank.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_mass_based_tank_fluid_mass(params, request):
expected_gas_mass[:, 1],
tank.gas_mass(expected_gas_mass[:, 0]),
rtol=1e-1,
atol=1e-4,
atol=1e-3,
)


Expand Down

0 comments on commit 7834947

Please sign in to comment.