Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix alias resolution to use preferred key. #481

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
a flag to determine whether value is complex.
"""

for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
field_infos = self._extract_field_info(field, field_name)
preferred_key, *_ = field_infos[0]
for field_key, env_name, value_is_complex in field_infos:
# paths reversed to match the last-wins behaviour of `env_file`
for secrets_path in reversed(self.secrets_paths):
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
Expand All @@ -670,14 +672,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
continue

if path.is_file():
return path.read_text().strip(), field_key, value_is_complex
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
preferred_key = field_key
return path.read_text().strip(), preferred_key, value_is_complex
else:
warnings.warn(
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
stacklevel=4,
)

return None, field_key, value_is_complex
return None, preferred_key, value_is_complex

def __repr__(self) -> str:
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
Expand Down Expand Up @@ -725,12 +729,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
"""

env_val: str | None = None
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
field_infos = self._extract_field_info(field, field_name)
preferred_key, *_ = field_infos[0]
for field_key, env_name, value_is_complex in field_infos:
env_val = self.env_vars.get(env_name)
if env_val is not None:
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
preferred_key = field_key
break

return env_val, field_key, value_is_complex
return env_val, preferred_key, value_is_complex

def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
"""
Expand Down
33 changes: 32 additions & 1 deletion tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import typing_extensions
from pydantic import (
AliasChoices,
AliasGenerator,
AliasPath,
BaseModel,
ConfigDict,
Expand Down Expand Up @@ -107,7 +108,7 @@ def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:
return self.parser.parse_args(*args, **kwargs)


def test_validation_alias_with_cli_prefix():
def test_cli_validation_alias_with_cli_prefix():
class Settings(BaseSettings, cli_exit_on_error=False):
foobar: str = Field(validation_alias='foo')

Expand All @@ -119,6 +120,36 @@ class Settings(BaseSettings, cli_exit_on_error=False):
assert CliApp.run(Settings, cli_args=['--p.foo', 'bar']).foobar == 'bar'


@pytest.mark.parametrize(
'alias_generator',
[
AliasGenerator(validation_alias=lambda s: AliasChoices(s, s.replace('_', '-'))),
AliasGenerator(validation_alias=lambda s: AliasChoices(s.replace('_', '-'), s)),
],
)
def test_cli_alias_resolution_consistency_with_env(env, alias_generator):
class SubModel(BaseModel):
v1: str = 'model default'

class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_nested_delimiter='__',
nested_model_default_partial_update=True,
alias_generator=alias_generator,
)

sub_model: SubModel = SubModel(v1='top default')

assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'top default'}}

env.set('SUB_MODEL__V1', 'env default')
assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'env default'}}

assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli default']).model_dump() == {
'sub_model': {'v1': 'cli default'}
}


def test_cli_nested_arg():
class SubSubValue(BaseModel):
v6: str
Expand Down
Loading