Skip to content

Commit fd91404

Browse files
hamelphiTorax team
authored and
Torax team
committed
Remove QLKKNN model path flag
This flag is not necessary, the model path can be set using the `TORAX_QLKNN_MODEL_PATH` environment variable instead. This simplifies the logic to update the config and build the sim, and removes dependencies on qlknn from the main interface. PiperOrigin-RevId: 713417605
1 parent e536221 commit fd91404

File tree

4 files changed

+7
-84
lines changed

4 files changed

+7
-84
lines changed

docs/quickstart.rst

+1-13
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ TORAX_QLKNN_MODEL_PATH
3131
^^^^^^^^^^^^^^^^^^^^^^^
3232
Path to the QuaLiKiz-neural-network parameters. The path specified here
3333
will be ignored if the ``model_path`` field in the ``qlknn_params`` section of
34-
the run config file or the ``qlknn_model_path`` flag are set.
34+
the run config file is set.
3535

3636
.. code-block:: console
3737
@@ -105,18 +105,6 @@ Provide a reference run to compare against in post-simulation plotting.
105105
--config='torax.examples.basic_config' \
106106
--reference_run=<path_to_reference_run>
107107
108-
qlknn_model_path
109-
^^^^^^^^^^^^^^^^
110-
Provide a path to load the QLKNN model from. This flag supersedes
111-
the path set in the config file and the ``TORAX_QLKNN_MODEL_PATH`` environment
112-
variable.
113-
114-
.. code-block:: console
115-
116-
python3 run_simulation_main.py \
117-
--config='torax.examples.basic_config' \
118-
--qlknn_model_path=<path_to_qlknn_model>
119-
120108
output_dir
121109
^^^^^^^^^^
122110
Override the default output directory. If not provided, it will be set to

docs/running.rst

+1-13
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ TORAX_QLKNN_MODEL_PATH
3030
^^^^^^^^^^^^^^^^^^^^^^^
3131
Path to the QuaLiKiz-neural-network parameters. The path specified here
3232
will be ignored if the ``model_path`` field in the ``qlknn_params`` section of
33-
the run config file or the ``qlknn_model_path`` flag are set.
33+
the run config file is set.
3434

3535
.. code-block:: console
3636
@@ -104,18 +104,6 @@ Provide a reference run to compare against in post-simulation plotting.
104104
--config='torax.examples.basic_config' \
105105
--reference_run=<path_to_reference_run>
106106
107-
qlknn_model_path
108-
^^^^^^^^^^^^^^^^
109-
Provide a path to load the QLKNN model from. This flag supersedes
110-
the path set in the config file and the ``TORAX_QLKNN_MODEL_PATH`` environment
111-
variable.
112-
113-
.. code-block:: console
114-
115-
python3 run_simulation_main.py \
116-
--config='torax.examples.basic_config' \
117-
--qlknn_model_path=<path_to_qlknn_model>
118-
119107
output_dir
120108
^^^^^^^^^^
121109
Override the default output directory. If not provided, it will be set to

run_simulation_main.py

+5-25
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from torax.config import config_loader
3636
from torax.config import runtime_params
3737
from torax.plotting import plotruns_lib
38-
from torax.transport_model import qlknn_transport_model
3938

4039

4140
# String used when prompting the user to make a choice of command
@@ -101,17 +100,6 @@
101100
'If True, quits after the first operation (no interactive mode).',
102101
)
103102

104-
_QLKNN_MODEL_PATH = flags.DEFINE_string(
105-
'qlknn_model_path',
106-
None,
107-
'Path to the qlknn model network parameters (if using a QLKNN transport'
108-
' model). If not set, then it will use the value from the config in the'
109-
' "model_path" field in the qlknn_params. If that is not set, it will look'
110-
f' for the "{qlknn_transport_model.MODEL_PATH_ENV_VAR}" env variable.'
111-
' Finally, if this is also not set, it uses a hardcoded default path'
112-
f' "{qlknn_transport_model.DEFAULT_MODEL_PATH}".',
113-
)
114-
115103
_OUTPUT_DIR = flags.DEFINE_string(
116104
'output_dir',
117105
None,
@@ -204,7 +192,6 @@ def maybe_update_config_module(
204192
def change_config(
205193
sim: sim_lib.Sim,
206194
config_module_str: str,
207-
qlknn_model_path: str | None,
208195
) -> tuple[sim_lib.Sim, runtime_params.GeneralRuntimeParams] | None:
209196
"""Returns a new Sim with the updated config but same SimulationStepFn.
210197
@@ -219,7 +206,6 @@ def change_config(
219206
Args:
220207
sim: Sim object used in the previous run.
221208
config_module_str: Config module being used.
222-
qlknn_model_path: QLKNN model path set by flag.
223209
224210
Returns:
225211
Tuple with:
@@ -257,9 +243,6 @@ def change_config(
257243
if hasattr(config_module, 'CONFIG'):
258244
# Assume that the config module uses the basic config dict to build Sim.
259245
sim_config = config_module.CONFIG
260-
config_loader.maybe_update_config_with_qlknn_model_path(
261-
sim_config, qlknn_model_path
262-
)
263246
new_runtime_params = build_sim.build_runtime_params_from_config(
264247
sim_config['runtime_params']
265248
)
@@ -316,7 +299,7 @@ def change_config(
316299

317300

318301
def change_sim_obj(
319-
config_module_str: str, qlknn_model_path: str | None
302+
config_module_str: str
320303
) -> tuple[sim_lib.Sim, runtime_params.GeneralRuntimeParams, str]:
321304
"""Builds a new Sim from the config module.
322305
@@ -327,7 +310,6 @@ def change_sim_obj(
327310
Args:
328311
config_module_str: Config module used previously. User will have the
329312
opportunity to update which module to load.
330-
qlknn_model_path: QLKNN model path set by flag.
331313
332314
Returns:
333315
Tuple with:
@@ -344,7 +326,7 @@ def change_sim_obj(
344326
input('Press Enter when done changing the module.')
345327
sim, new_runtime_params = (
346328
config_loader.build_sim_and_runtime_params_from_config_module(
347-
config_module_str, qlknn_model_path, _PYTHON_CONFIG_PACKAGE.value
329+
config_module_str, _PYTHON_CONFIG_PACKAGE.value
348330
)
349331
)
350332
return sim, new_runtime_params, config_module_str
@@ -482,15 +464,14 @@ def main(_):
482464
log_sim_progress = _LOG_SIM_PROGRESS.value
483465
plot_sim_progress = _PLOT_SIM_PROGRESS.value
484466
log_sim_output = _LOG_SIM_OUTPUT.value
485-
qlknn_model_path = _QLKNN_MODEL_PATH.value
486467
sim = None
487468
new_runtime_params = None
488469
output_files = []
489470
try:
490471
start_time = time.time()
491472
sim, new_runtime_params = (
492473
config_loader.build_sim_and_runtime_params_from_config_module(
493-
config_module_str, qlknn_model_path, _PYTHON_CONFIG_PACKAGE.value
474+
config_module_str, _PYTHON_CONFIG_PACKAGE.value
494475
)
495476
)
496477
output_dir = (
@@ -573,8 +554,7 @@ def main(_):
573554
try:
574555
start_time = time.time()
575556
sim_and_runtime_params_or_none = change_config(
576-
sim, config_module_str, qlknn_model_path
577-
)
557+
sim, config_module_str)
578558
if sim_and_runtime_params_or_none is not None:
579559
sim, new_runtime_params = sim_and_runtime_params_or_none
580560
config_change_time = time.time() - start_time
@@ -595,7 +575,7 @@ def main(_):
595575
try:
596576
start_time = time.time()
597577
sim, new_runtime_params, config_module_str = change_sim_obj(
598-
config_module_str, qlknn_model_path
578+
config_module_str
599579
)
600580
sim_change_time = time.time() - start_time
601581
simulation_app.log_to_stdout(

torax/config/config_loader.py

-33
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import importlib
1818
import logging
19-
from typing import Any
2019

2120
from torax import sim
2221
from torax.config import build_sim
@@ -28,16 +27,13 @@
2827

2928
def build_sim_and_runtime_params_from_config_module(
3029
config_module_str: str,
31-
qlknn_model_path: str | None,
3230
config_package: str | None = None,
3331
) -> tuple[sim.Sim, runtime_params.GeneralRuntimeParams]:
3432
"""Returns a Sim and RuntimeParams from the config module.
3533
3634
Args:
3735
config_module_str: Python package path to config module. E.g.
3836
torax.examples.iterhybrid_predictor_corrector.
39-
qlknn_model_path: QLKNN model path set by flag. See qlknn_model_path flag
40-
docs.
4137
config_package: Optional, base package config is imported from. See
4238
config_package flag docs.
4339
"""
@@ -46,7 +42,6 @@ def build_sim_and_runtime_params_from_config_module(
4642
# The module likely uses the "basic" config setup which has a single CONFIG
4743
# dictionary defining the full simulation.
4844
config = config_module.CONFIG
49-
maybe_update_config_with_qlknn_model_path(config, qlknn_model_path)
5045
new_runtime_params = build_sim.build_runtime_params_from_config(
5146
config['runtime_params']
5247
)
@@ -56,8 +51,6 @@ def build_sim_and_runtime_params_from_config_module(
5651
):
5752
# The module is likely using the "advances", more Python-forward
5853
# configuration setup.
59-
if qlknn_model_path is not None:
60-
logging.warning('Cannot override qlknn model for this type of config.')
6154
new_runtime_params = config_module.get_runtime_params()
6255
simulator = config_module.get_sim()
6356
else:
@@ -68,32 +61,6 @@ def build_sim_and_runtime_params_from_config_module(
6861
return simulator, new_runtime_params
6962

7063

71-
def maybe_update_config_with_qlknn_model_path(
72-
config: dict[str, Any], qlknn_model_path: str | None
73-
) -> None:
74-
"""Sets the qlknn_model_path in the config if needed."""
75-
if qlknn_model_path is None:
76-
return
77-
if (
78-
'transport' not in config
79-
or 'transport_model' not in config['transport']
80-
or config['transport']['transport_model'] != 'qlknn'
81-
):
82-
return
83-
qlknn_params = config['transport'].get('qlknn_params', {})
84-
config_model_path = qlknn_params.get('model_path', '')
85-
if config_model_path:
86-
logging.info(
87-
'Overriding QLKNN model path from "%s" to "%s"',
88-
config_model_path,
89-
qlknn_model_path,
90-
)
91-
else:
92-
logging.info('Setting QLKNN model path to "%s".', qlknn_model_path)
93-
qlknn_params['model_path'] = qlknn_model_path
94-
config['transport']['qlknn_params'] = qlknn_params
95-
96-
9764
def import_module(module_name: str, config_package: str | None = None):
9865
"""Imports a module."""
9966
try:

0 commit comments

Comments
 (0)