Skip to content

Commit 44eb5df

Browse files
committed
- Allow providing a config with init_args but no class_path #113.
- When a class_path is overridden, now only the config values that the new subclass doesn't accept are discarded. - Single dash '-' incorrectly parsed as [None].
1 parent 0209c1f commit 44eb5df

File tree

10 files changed

+160
-71
lines changed

10 files changed

+160
-71
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ Added
2424
<https://github.com/PyTorchLightning/pytorch-lightning/issues/11653>`__.
2525
- Support init args for unresolved parameters in subclasses `#114
2626
<https://github.com/omni-us/jsonargparse/issues/114>`__.
27+
- Allow providing a config with ``init_args`` but no ``class_path`` `#113
28+
<https://github.com/omni-us/jsonargparse/issues/113>`__.
2729

2830
Fixed
2931
^^^^^
@@ -35,6 +37,7 @@ Fixed
3537
- In some cases ``print_config`` could output invalid values. Now a lenient
3638
check is done while dumping.
3739
- Resolved some issues related to the logger property and reconplogger.
40+
- Single dash ``'-'`` incorrectly parsed as ``[None]``.
3841

3942
-Changed
4043
^^^^^^^
@@ -45,6 +48,8 @@ Fixed
4548
- ``JSONARGPARSE_DEBUG`` now also sets the reconplogger level to ``DEBUG``.
4649
- Renamed the test files to follow the more standard ``test_*.py`` pattern.
4750
- Now ``bool(Namespace())`` evaluates to ``False``.
51+
- When a ``class_path`` is overridden, now only the config values that the new
52+
subclass doesn't accept are discarded.
4853

4954
Deprecated
5055
^^^^^^^^^^

jsonargparse/core.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .loaders_dumpers import check_valid_dump_format, dump_using_format, get_loader_exceptions, loaders, load_value, load_value_context, yaml_load
1818
from .namespace import is_meta_key, Namespace, split_key, split_key_leaf, strip_meta
1919
from .signatures import is_pure_dataclass, SignatureArguments
20-
from .typehints import ActionTypeHint, is_class_object
20+
from .typehints import ActionTypeHint, is_subclass_spec
2121
from .typing import is_final_class
2222
from .actions import (
2323
ActionParser,
@@ -570,7 +570,7 @@ def parse_string(
570570
with_meta=with_meta,
571571
skip_check=_skip_check,
572572
fail_no_subcommand=_fail_no_subcommand,
573-
log_message=(f'Parsed {self.parser_mode} string.'),
573+
log_message=f'Parsed {self.parser_mode} string.',
574574
)
575575

576576
except (TypeError, KeyError) as ex:
@@ -740,7 +740,7 @@ def _dump_delete_default_entries(self, subcfg, subdefaults):
740740
val = subcfg[key]
741741
default = subdefaults[key]
742742
class_object_val = None
743-
if is_class_object(val):
743+
if is_subclass_spec(val):
744744
if val['class_path'] != default.get('class_path'):
745745
parser = ActionTypeHint.get_class_parser(val['class_path'])
746746
default = {'init_args': parser.get_defaults().as_dict()}
@@ -1218,8 +1218,7 @@ def _apply_actions(self, cfg: Union[Namespace, Dict[str, Any]], parent_key: str
12181218
return cfg[parent_key] if parent_key else cfg
12191219

12201220

1221-
@staticmethod
1222-
def merge_config(cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
1221+
def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
12231222
"""Merges the first configuration into the second configuration.
12241223
12251224
Args:
@@ -1229,14 +1228,8 @@ def merge_config(cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
12291228
Returns:
12301229
A new object with the merged configuration.
12311230
"""
1232-
del_keys = []
1233-
for key_class_path in [k for k in cfg_from.keys() if k.endswith('.class_path')]:
1234-
key_init_args = key_class_path[:-len('class_path')] + 'init_args'
1235-
if key_class_path in cfg_to and key_init_args in cfg_to and cfg_from[key_class_path] != cfg_to[key_class_path]:
1236-
del_keys.append(key_init_args)
12371231
cfg = cfg_to.clone()
1238-
for key in reversed(del_keys):
1239-
del cfg[key]
1232+
ActionTypeHint.discard_init_args_on_class_path_change(self, cfg, cfg_from)
12401233
cfg.update(cfg_from)
12411234
return cfg
12421235

jsonargparse/loaders_dumpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def remove_implicit_resolver(cls, tag_to_remove):
6464

6565

6666
def yaml_load(stream):
67-
value = yaml.load(stream, Loader=DefaultLoader)
67+
if stream.strip() == '-':
68+
value = stream
69+
else:
70+
value = yaml.load(stream, Loader=DefaultLoader)
6871
if isinstance(value, dict) and all(v is None for v in value.values()):
6972
keys = {k for k in regex_curly_comma.split(stream) if k}
7073
if len(keys) > 0 and keys == set(value.keys()):

jsonargparse/parameter_resolvers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class ParamData:
2323

2424

2525
ParamList = List[ParamData]
26-
parameter_attributes = [s[1:] for s in inspect.Parameter.__slots__]
26+
parameter_attributes = [s[1:] for s in inspect.Parameter.__slots__] # type: ignore
2727
kinds = inspect._ParameterKind
2828

2929

jsonargparse/typehints.py

Lines changed: 85 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
is_subclass,
5050
iter_to_set_str,
5151
lenient_check,
52+
lenient_check_context,
5253
NoneType,
5354
object_path_serializer,
5455
ParserError,
@@ -240,6 +241,15 @@ def parse_subclass_arg(arg_string):
240241
subclass_arg_parser.set(subparser)
241242

242243

244+
@staticmethod
245+
def discard_init_args_on_class_path_change(parser, cfg_to, cfg_from):
246+
for action in [a for a in parser._actions if ActionTypeHint.is_subclass_typehint(a)]:
247+
val_to = cfg_to.get(action.dest)
248+
val_from = cfg_from.get(action.dest)
249+
if is_subclass_spec(val_to) and is_subclass_spec(val_from):
250+
discard_init_args_on_class_path_change(action, val_to, val_from)
251+
252+
243253
@staticmethod
244254
@contextmanager
245255
def subclass_arg_context(parser):
@@ -585,7 +595,7 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
585595
# Callable
586596
elif typehint_origin in callable_origin_types or typehint in callable_origin_types:
587597
if serialize:
588-
if is_class_object(val):
598+
if is_subclass_spec(val):
589599
val = adapt_class_type(val, True, False, sub_add_kwargs)
590600
else:
591601
val = object_path_serializer(val)
@@ -600,12 +610,11 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
600610
else:
601611
raise ImportError(f'Unexpected import object {val_obj}')
602612
if isinstance(val, (dict, Namespace)):
603-
if not is_class_object(val):
613+
if not is_subclass_spec(val):
604614
raise ImportError(f'Dict must include a class_path and optionally init_args, but got {val}')
605-
val = Namespace(val)
606-
val_class = import_object(val.class_path)
615+
val_class = import_object(val['class_path'])
607616
if not (inspect.isclass(val_class) and callable_instances(val_class)):
608-
raise ImportError(f'"{val.class_path}" is not a callable class.')
617+
raise ImportError(f'{val["class_path"]!r} is not a callable class.')
609618
val['class_path'] = get_import_path(val_class)
610619
val = adapt_class_type(val, False, instantiate_classes, sub_add_kwargs)
611620
except (ImportError, AttributeError, ParserError) as ex:
@@ -619,25 +628,25 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
619628
return val
620629
if serialize and isinstance(val, str):
621630
return val
622-
if not (isinstance(val, str) or is_class_object(val)):
623-
raise ValueError(f'Type {typehint} expects an str or a Dict/Namespace with a class_path entry but got "{val}"')
631+
632+
val_input = val
633+
val = subclass_spec_as_namespace(val)
634+
if not is_subclass_spec(val) and prev_val and 'class_path' in prev_val:
635+
if 'init_args' in val:
636+
val['class_path'] = prev_val['class_path']
637+
else:
638+
val = Namespace(class_path=prev_val['class_path'], init_args=val)
639+
if not is_subclass_spec(val):
640+
raise ValueError(
641+
f'Type {typehint} expects: a class path (str); or a dict with a class_path entry; '
642+
f'or a dict with init_args (if class path given previously). Got "{val_input}".'
643+
)
644+
624645
try:
625-
if isinstance(val, str):
626-
val = Namespace(class_path=val)
627-
elif isinstance(val, dict):
628-
val = Namespace(val)
629646
val_class = import_object(resolve_class_path_by_name(typehint, val['class_path']))
630647
if not is_subclass(val_class, typehint):
631648
raise ValueError(f'"{val["class_path"]}" is not a subclass of {typehint}')
632-
val['class_path'] = class_path = get_import_path(val_class)
633-
if isinstance(prev_val, Namespace) and 'class_path' in prev_val and 'init_args' not in val:
634-
prev_class_path = prev_val['class_path']
635-
prev_init_args = prev_val.get('init_args')
636-
if prev_class_path != class_path and prev_init_args:
637-
warnings.warn(
638-
f'Changing class_path to {class_path} implies discarding init_args {prev_init_args.as_dict()} '
639-
f'defined for class_path {prev_class_path}.'
640-
)
649+
val['class_path'] = get_import_path(val_class)
641650
val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
642651
except (ImportError, AttributeError, AssertionError, ParserError) as ex:
643652
class_path = val if isinstance(val, str) else val['class_path']
@@ -647,14 +656,24 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
647656
return val
648657

649658

650-
def is_class_object(val):
659+
def is_subclass_spec(val):
651660
is_class = isinstance(val, (dict, Namespace)) and 'class_path' in val
652661
if is_class:
653662
keys = getattr(val, '__dict__', val).keys()
654663
is_class = len(set(keys)-{'class_path', 'init_args', 'dict_kwargs', '__path__'}) == 0
655664
return is_class
656665

657666

667+
def subclass_spec_as_namespace(val):
668+
if isinstance(val, str):
669+
return Namespace(class_path=val)
670+
if isinstance(val, dict):
671+
val = Namespace(val)
672+
if 'init_args' in val and isinstance(val['init_args'], dict):
673+
val['init_args'] = Namespace(val['init_args'])
674+
return val
675+
676+
658677
class NestedArg:
659678
def __init__(self, key, val):
660679
self.key = key
@@ -716,8 +735,34 @@ def dump_kwargs_context(kwargs):
716735
yield
717736

718737

738+
def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):
739+
if prev_val and 'init_args' in prev_val and prev_val['class_path'] != value.class_path:
740+
parser = parser_or_action
741+
if isinstance(parser_or_action, ActionTypeHint):
742+
sub_add_kwargs = getattr(parser_or_action, 'sub_add_kwargs', {})
743+
parser = ActionTypeHint.get_class_parser(value.class_path, sub_add_kwargs)
744+
prev_val = subclass_spec_as_namespace(prev_val)
745+
del_args = {}
746+
for key, val in list(prev_val.init_args.__dict__.items()):
747+
action = _find_action(parser, key)
748+
if action:
749+
with lenient_check_context(lenient=False):
750+
try:
751+
parser._check_value_key(action, val, key, Namespace())
752+
except Exception:
753+
action = None
754+
if not action:
755+
del_args[key] = prev_val.init_args.pop(key)
756+
if del_args:
757+
warnings.warn(
758+
f'Due to class_path change from {prev_val.class_path!r} to {value.class_path!r}, '
759+
f'discarding init_args: {del_args}.'
760+
)
761+
762+
719763
def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev_val=None):
720-
val_class = import_object(value['class_path'])
764+
value = subclass_spec_as_namespace(value)
765+
val_class = import_object(value.class_path)
721766
parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs)
722767

723768
# No need to re-create the linked arg but just "inform" the corresponding parser actions that it exists upstream.
@@ -736,18 +781,18 @@ def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev
736781

737782
break
738783

739-
unresolved = value.pop('dict_kwargs', {})
784+
discard_init_args_on_class_path_change(parser, prev_val, value)
785+
786+
dict_kwargs = value.pop('dict_kwargs', {})
740787
init_args = value.get('init_args', Namespace())
741-
if isinstance(init_args, dict):
742-
value['init_args'] = init_args = Namespace(init_args)
743788

744789
if instantiate_classes:
745790
init_args = parser.instantiate_classes(init_args)
746791
if not sub_add_kwargs.get('instantiate', True):
747792
if init_args:
748793
value['init_args'] = init_args
749794
return value
750-
return val_class(**{**init_args, **unresolved})
795+
return val_class(**{**init_args, **dict_kwargs})
751796

752797
if isinstance(init_args, NestedArg):
753798
value['init_args'] = parser.parse_args(
@@ -761,29 +806,29 @@ def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev
761806
if init_args:
762807
value['init_args'] = load_value(parser.dump(init_args, **dump_kwargs.get()))
763808
else:
764-
if isinstance(unresolved, str):
765-
unresolved = load_value(unresolved)
766-
if isinstance(unresolved, (dict, Namespace)):
767-
for key in list(unresolved.keys()):
809+
if isinstance(dict_kwargs, str):
810+
dict_kwargs = load_value(dict_kwargs)
811+
if isinstance(dict_kwargs, (dict, Namespace)):
812+
for key in list(dict_kwargs.keys()):
768813
if _find_action(parser, key):
769-
init_args[key] = unresolved.pop(key)
770-
if isinstance(unresolved, Namespace):
771-
unresolved = dict(unresolved)
772-
elif unresolved:
773-
init_args['dict_kwargs'] = unresolved
774-
unresolved = None
814+
init_args[key] = dict_kwargs.pop(key)
815+
if isinstance(dict_kwargs, Namespace):
816+
dict_kwargs = dict(dict_kwargs)
817+
elif dict_kwargs:
818+
init_args['dict_kwargs'] = dict_kwargs
819+
dict_kwargs = None
775820
init_args = parser.parse_object(init_args, defaults=sub_defaults.get())
776821
if init_args:
777822
value['init_args'] = init_args
778-
if unresolved:
779-
value['dict_kwargs'] = unresolved
823+
if dict_kwargs:
824+
value['dict_kwargs'] = dict_kwargs
780825
return value
781826

782827

783828
def not_append_diff(val1, val2):
784829
if isinstance(val1, list) and isinstance(val2, list):
785-
val1 = [x.get('class_path') if is_class_object(x) else x for x in val1]
786-
val2 = [x.get('class_path') if is_class_object(x) else x for x in val2]
830+
val1 = [x.get('class_path') if is_subclass_spec(x) else x for x in val1]
831+
val2 = [x.get('class_path') if is_subclass_spec(x) else x for x in val2]
787832
return val1 != val2
788833

789834

jsonargparse/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ def object_path_serializer(value):
252252

253253

254254
@contextmanager
255-
def lenient_check_context(caller=None):
256-
t = lenient_check.set(False if caller == 'argcomplete' else True)
255+
def lenient_check_context(caller=None, lenient=True):
256+
t = lenient_check.set(False if caller == 'argcomplete' else lenient)
257257
try:
258258
yield
259259
finally:

jsonargparse_tests/test_core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from contextlib import redirect_stderr, redirect_stdout
1414
from collections import OrderedDict
1515
from random import randint, shuffle
16+
from typing import Optional
1617
from jsonargparse import (
1718
ActionConfigFile,
1819
ActionJsonnet,
@@ -447,7 +448,6 @@ def test_subcommands(self):
447448
self.assertRaises(ParserError, lambda: parser.parse_string('{"a": {"ap1": "ap1_cfg", "unk": "unk_cfg"}}'))
448449

449450
with warnings.catch_warnings(record=True) as w:
450-
warnings.simplefilter("always")
451451
cfg = parser.parse_string('{"a": {"ap1": "ap1_cfg"}, "b": {"nums": {"val1": 2}}}')
452452
self.assertEqual(cfg.subcommand, 'a')
453453
self.assertFalse(hasattr(cfg, 'b'))
@@ -1057,9 +1057,12 @@ def test_strip_unknown(self):
10571057

10581058

10591059
def test_merge_config(self):
1060+
parser = ArgumentParser()
1061+
for key in [1, 2, 3]:
1062+
parser.add_argument(f'--op{key}', type=Optional[int])
10601063
cfg_from = Namespace(op1=1, op2=None)
10611064
cfg_to = Namespace(op1=None, op2=2, op3=3)
1062-
cfg = ArgumentParser.merge_config(cfg_from, cfg_to)
1065+
cfg = parser.merge_config(cfg_from, cfg_to)
10631066
self.assertEqual(cfg, Namespace(op1=1, op2=None, op3=3))
10641067

10651068

jsonargparse_tests/test_loaders_dumpers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from importlib.util import find_spec
66
from typing import List
77
from jsonargparse import ActionConfigFile, ArgumentParser, set_dumper, set_loader, ParserError
8-
from jsonargparse.loaders_dumpers import yaml_dump
8+
from jsonargparse.loaders_dumpers import load_value, load_value_context, yaml_dump
99
from jsonargparse.optionals import dump_preserve_order_support
1010

1111

@@ -120,5 +120,11 @@ def test_dump_header_invalid(self):
120120
parser.dump_header = True
121121

122122

123+
def test_load_value_dash(self):
124+
with load_value_context('yaml'):
125+
self.assertEqual('-', load_value('-'))
126+
self.assertEqual(' - ', load_value(' - '))
127+
128+
123129
if __name__ == '__main__':
124130
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)