Skip to content

Commit 1856276

Browse files
dagardner-nvyczhang-nv
authored andcommitted
Allow overriding configuration values not set in the YAML (NVIDIA#85)
* This removes the need to repeat default config values in a YAML for the sole purpose allowing users to override the values with the `--override` flag. * Remove unnecessary `assert` in `src/aiq/data_models/common.py` as this is actually already handled by the code, allowing users to prototype config classes directly in the interpreter. * Add unittests for the `LayeredConfig` class. Closes NVIDIA#83 ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/AgentIQ/blob/develop/docs/source/advanced/contributing.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - David Gardner (https://github.com/dagardner-nv) Approvers: - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) URL: NVIDIA#85 Signed-off-by: Yuchen Zhang <[email protected]>
1 parent 6a8d36b commit 1856276

File tree

4 files changed

+118
-9
lines changed

4 files changed

+118
-9
lines changed

docs/source/concepts/evaluate.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,9 @@ The output of the evaluators are stored in distinct files in the same `output_di
317317
## Customizing the output
318318
You can customize the output of the pipeline by providing custom scripts. One or more Python scripts can be provided in the `eval.general.output_scripts` section of the `config.yml` file.
319319

320-
The custom scripts are executed after the evaluation is complete. They are executed as Python scripts with the kwargs provided in the `eval.general.output.custom_scripts.<script_name>.kwargs` section.
320+
The custom scripts are executed after the evaluation is complete. They are executed as Python scripts with the `kwargs` provided in the `eval.general.output.custom_scripts.<script_name>.kwargs` section.
321321

322-
The kwargs typically include the file or directory to operate on. To avoid overwriting contents it is recommended to provide a unique output file or directory name for the customization. It is also recommended that changes be limited to the contents of the output directory to avoid unintended side effects.
322+
The `kwargs` typically include the file or directory to operate on. To avoid overwriting contents it is recommended to provide a unique output file or directory name for the customization. It is also recommended that changes be limited to the contents of the output directory to avoid unintended side effects.
323323

324324
**Example:**
325325
```yaml

src/aiq/cli/cli_utils/config_override.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
logger = logging.getLogger(__name__)
2828

2929

30+
class _Placeholder:
31+
"""Placeholder class to represent a value that is not set yet."""
32+
pass
33+
34+
3035
class LayeredConfig:
3136

3237
def __init__(self, base_config: dict[str, Any]):
@@ -46,7 +51,11 @@ def validate_path(self, path: str) -> None:
4651
current_path = '.'.join(parts[:i])
4752
raise click.BadParameter(f"Cannot navigate through non-dictionary value at '{current_path}'")
4853
if part not in current:
49-
raise click.BadParameter(f"Path '{path}' not found in config. '{part}' is invalid.")
54+
if i == len(parts) - 1:
55+
current[part] = _Placeholder()
56+
else:
57+
current[part] = {}
58+
5059
current = current[part]
5160

5261
def set_override(self, path: str, value: str) -> None:
@@ -70,9 +79,10 @@ def set_override(self, path: str, value: str) -> None:
7079
# Convert string value to appropriate type
7180
try:
7281
if isinstance(original_value, bool):
73-
if value.lower() not in ['true', 'false']:
82+
lower_value = value.lower().strip()
83+
if lower_value not in ['true', 'false']:
7484
raise ValueError(f"Boolean value must be 'true' or 'false', got '{value}'")
75-
value = value.lower() == 'true'
85+
value = lower_value == 'true'
7686
elif isinstance(original_value, (int, float)):
7787
value = type(original_value)(value)
7888
elif isinstance(original_value, list):
@@ -86,7 +96,12 @@ def set_override(self, path: str, value: str) -> None:
8696
# Store converted value
8797
self.overrides[path] = value
8898
self._effective_config = None
89-
logger.info("Successfully set override for %s with value: %s", path, value)
99+
100+
log_msg = f"Successfully set override for {path} with value: {value}"
101+
if not isinstance(original_value, _Placeholder):
102+
log_msg += f" with type {type(value)})"
103+
104+
logger.info(log_msg)
90105

91106
except Exception as e:
92107
logger.error("Failed to set override for %s: %s", path, str(e))

src/aiq/data_models/common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,7 @@ def __init_subclass__(cls, name: str | None = None):
109109
module = inspect.getmodule(cls)
110110

111111
assert module is not None, f"Module not found for class {cls} when registering {name}"
112-
assert module.__package__ is not None, f"Package not found for class {cls} when registering {name}"
113-
114-
package_name: str = module.__package__
112+
package_name: str | None = module.__package__
115113

116114
# If the package name is not set, then we use the module name. Must have some namespace which will be unique
117115
if (not package_name):
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import click
17+
import pytest
18+
19+
from aiq.cli.cli_utils import config_override
20+
from aiq.data_models.function import FunctionBaseConfig
21+
22+
23+
@pytest.fixture(name="base_config")
24+
def fixture_base_config() -> dict:
25+
return {"a": {"b": 1, "c": 2}, "d": 3, "bool_val": True}
26+
27+
28+
def test_layered_config_set_override(base_config: dict):
29+
layered_config = config_override.LayeredConfig(base_config)
30+
31+
# Override a value that already exists
32+
layered_config.set_override("a.b", '10')
33+
34+
# Override a value that doesn't exist
35+
layered_config.set_override("a.e", '20')
36+
37+
# override a nested value
38+
layered_config.set_override("f.g", '30')
39+
40+
layered_config.set_override("bool_val", '\tfALse ')
41+
42+
assert layered_config.get_effective_config() == {
43+
"a": {
44+
"b": 10, "c": 2, "e": '20'
45+
}, "d": 3, "f": {
46+
"g": '30'
47+
}, "bool_val": False
48+
}
49+
50+
51+
def test_layered_config_set_override_error(base_config: dict):
52+
layered_config = config_override.LayeredConfig(base_config)
53+
54+
# Attempt to set an override with an invalid path
55+
with pytest.raises(click.BadParameter, match="Cannot navigate through non-dictionary value at 'a.b'"):
56+
layered_config.set_override("a.b.c", '10')
57+
58+
# Attempt to set an override a boolean value with an invalid string
59+
with pytest.raises(click.BadParameter, match="Boolean value must be 'true' or 'false', got 'not_a_bool'"):
60+
layered_config.set_override("bool_val", 'not_a_bool')
61+
62+
# Attempt to set a value with a type that doesn't match the original
63+
with pytest.raises(click.BadParameter, match=r"Type mismatch for 'a\.b'"):
64+
layered_config.set_override("a.b", 'not_a_number')
65+
66+
67+
def test_layered_config_constructor_error(base_config: dict):
68+
# Attempt to set an override with an invalid base config
69+
with pytest.raises(ValueError, match="Base config must be a dictionary"):
70+
config_override.LayeredConfig("invalid_base_config")
71+
72+
73+
def test_config_casting():
74+
"""
75+
Test to verify that pydantic's casting works as expected in situations where LayeredConfig
76+
is unable to determine the type of the value being set.
77+
"""
78+
79+
class TestConfig(FunctionBaseConfig, name="TestConfig"):
80+
a: bool
81+
b: int
82+
c: float
83+
84+
layered_config = config_override.LayeredConfig({})
85+
for (field, value) in (
86+
("a", "false"),
87+
("b", "45"),
88+
("c", "5.6"),
89+
):
90+
layered_config.set_override(field, value)
91+
92+
effective_config = layered_config.get_effective_config()
93+
config = TestConfig(**effective_config)
94+
assert config.a is False
95+
assert config.b == 45
96+
assert config.c == 5.6

0 commit comments

Comments
 (0)