Skip to content

Commit

Permalink
Fix command line help from argparse formatting problem (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
scottstanie authored and hramezani committed Jun 10, 2024
1 parent 813ac94 commit 321d36d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
25 changes: 18 additions & 7 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,8 +597,11 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:

return True, allow_parse_failure

# Default value of `case_sensitive` is `False`, because we don't want to break existing behavior.
# We have to change the method to a non-static method and use
# `self.case_sensitive` instead in V3.
@staticmethod
def next_field(field: FieldInfo | Any | None, key: str) -> FieldInfo | None:
def next_field(field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None) -> FieldInfo | None:
"""
Find the field in a sub model by key(env name)
Expand All @@ -623,6 +626,7 @@ class Cfg(BaseSettings):
Args:
field: The field.
key: The key (env name).
case_sensitive: Whether to search for key case sensitively.
Returns:
Field if it finds the next field otherwise `None`.
Expand All @@ -633,11 +637,18 @@ class Cfg(BaseSettings):
annotation = field.annotation if isinstance(field, FieldInfo) else field
if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes):
for type_ in get_args(annotation):
type_has_key = EnvSettingsSource.next_field(type_, key)
type_has_key = EnvSettingsSource.next_field(type_, key, case_sensitive)
if type_has_key:
return type_has_key
elif is_model_class(annotation) and annotation.model_fields.get(key):
return annotation.model_fields[key]
elif is_model_class(annotation):
# `case_sensitive is None` is here to be compatible with the old behavior.
# Has to be removed in V3.
if (case_sensitive is None or case_sensitive) and annotation.model_fields.get(key):
return annotation.model_fields[key]
elif not case_sensitive:
for field_name, f in annotation.model_fields.items():
if field_name.lower() == key.lower():
return f

return None

Expand Down Expand Up @@ -670,12 +681,12 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
env_var = result
target_field: FieldInfo | None = field
for key in keys:
target_field = self.next_field(target_field, key)
target_field = self.next_field(target_field, key, self.case_sensitive)
if isinstance(env_var, dict):
env_var = env_var.setdefault(key, {})

# get proper field with last_key
target_field = self.next_field(target_field, last_key)
target_field = self.next_field(target_field, last_key, self.case_sensitive)

# check if env_val maps to a complex field and if so, parse the env_val
if (target_field or is_dict) and env_val:
Expand Down Expand Up @@ -1407,7 +1418,7 @@ def _help_format(self, field_info: FieldInfo) -> str:
elif field_info.default_factory is not None:
default = f'(default: {field_info.default_factory})'
_help += f' {default}' if _help else default
return _help
return _help.replace('%', '%%') if issubclass(type(self._root_parser), ArgumentParser) else _help


class ConfigFileSourceMixin(ABC):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,29 @@ class Cfg(BaseSettings):
)


def test_cli_help_string_format(capsys, monkeypatch):
class Cfg(BaseSettings):
date_str: str = '%Y-%m-%d'

argparse_options_text = 'options' if sys.version_info >= (3, 10) else 'optional arguments'

with monkeypatch.context() as m:
m.setattr(sys, 'argv', ['example.py', '--help'])

with pytest.raises(SystemExit):
Cfg(_cli_parse_args=True)

assert (
re.sub(r'0x\w+', '0xffffffff', capsys.readouterr().out, re.MULTILINE)
== f"""usage: example.py [-h] [--date_str str]
{argparse_options_text}:
-h, --help show this help message and exit
--date_str str (default: %Y-%m-%d)
"""
)


def test_cli_nested_dataclass_arg():
@pydantic_dataclasses.dataclass
class MyDataclass:
Expand Down Expand Up @@ -3836,3 +3859,17 @@ class Settings(BaseSettings):
env.set('nested__bar', '123')
s = Settings()
assert s.model_dump() == {'nested': {'BaR': 123, 'FOO': 'string'}}


def test_case_insensitive_nested_list(env):
class NestedSettings(BaseModel):
FOO: list[str]

class Settings(BaseSettings):
model_config = SettingsConfigDict(env_nested_delimiter='__', case_sensitive=False)

nested: Optional[NestedSettings]

env.set('nested__FOO', '["string1", "string2"]')
s = Settings()
assert s.model_dump() == {'nested': {'FOO': ['string1', 'string2']}}

0 comments on commit 321d36d

Please sign in to comment.