1616# under the License.
1717
1818import collections .abc
19- from typing import Dict
19+ from copy import deepcopy
20+ from typing import Any , ClassVar , Dict , MutableMapping , Optional , Union , overload
2021
21- from .utils import DslBase
22+ from .utils import DslBase , JSONType
2223
2324
24- # Incomplete annotation to not break query.py tests
25- def SF (name_or_sf , ** params ) -> "ScoreFunction" :
25+ @overload
26+ def SF (name_or_sf : MutableMapping [str , Any ]) -> "ScoreFunction" : ...
27+
28+
29+ @overload
30+ def SF (name_or_sf : "ScoreFunction" ) -> "ScoreFunction" : ...
31+
32+
33+ @overload
34+ def SF (name_or_sf : str , ** params : Any ) -> "ScoreFunction" : ...
35+
36+
37+ def SF (
38+ name_or_sf : Union [str , "ScoreFunction" , MutableMapping [str , Any ]],
39+ ** params : Any ,
40+ ) -> "ScoreFunction" :
2641 # {"script_score": {"script": "_score"}, "filter": {}}
27- if isinstance (name_or_sf , collections .abc .Mapping ):
42+ if isinstance (name_or_sf , collections .abc .MutableMapping ):
2843 if params :
2944 raise ValueError ("SF() cannot accept parameters when passing in a dict." )
30- kwargs = {}
31- sf = name_or_sf .copy ()
45+
46+ kwargs : Dict [str , Any ] = {}
47+ sf = deepcopy (name_or_sf )
3248 for k in ScoreFunction ._param_defs :
3349 if k in name_or_sf :
3450 kwargs [k ] = sf .pop (k )
3551
3652 # not sf, so just filter+weight, which used to be boost factor
53+ sf_params = params
3754 if not sf :
3855 name = "boost_factor"
3956 # {'FUNCTION': {...}}
4057 elif len (sf ) == 1 :
41- name , params = sf .popitem ()
58+ name , sf_params = sf .popitem ()
4259 else :
4360 raise ValueError (f"SF() got an unexpected fields in the dictionary: { sf !r} " )
4461
4562 # boost factor special case, see elasticsearch #6343
46- if not isinstance (params , collections .abc .Mapping ):
47- params = {"value" : params }
63+ if not isinstance (sf_params , collections .abc .Mapping ):
64+ sf_params = {"value" : sf_params }
4865
4966 # mix known params (from _param_defs) and from inside the function
50- kwargs .update (params )
67+ kwargs .update (sf_params )
5168 return ScoreFunction .get_dsl_class (name )(** kwargs )
5269
5370 # ScriptScore(script="_score", filter=Q())
@@ -70,14 +87,16 @@ class ScoreFunction(DslBase):
7087 "filter" : {"type" : "query" },
7188 "weight" : {},
7289 }
73- name = None
90+ name : ClassVar [ Optional [ str ]] = None
7491
75- def to_dict (self ):
92+ def to_dict (self ) -> Dict [ str , JSONType ] :
7693 d = super ().to_dict ()
7794 # filter and query dicts should be at the same level as us
7895 for k in self ._param_defs :
79- if k in d [self .name ]:
80- d [k ] = d [self .name ].pop (k )
96+ if self .name is not None :
97+ val = d [self .name ]
98+ if isinstance (val , dict ) and k in val :
99+ d [k ] = val .pop (k )
81100 return d
82101
83102
@@ -88,12 +107,15 @@ class ScriptScore(ScoreFunction):
88107class BoostFactor (ScoreFunction ):
89108 name = "boost_factor"
90109
91- def to_dict (self ) -> Dict [str , int ]:
110+ def to_dict (self ) -> Dict [str , JSONType ]:
92111 d = super ().to_dict ()
93- if "value" in d [self .name ]:
94- d [self .name ] = d [self .name ].pop ("value" )
95- else :
96- del d [self .name ]
112+ if self .name is not None :
113+ val = d [self .name ]
114+ if isinstance (val , dict ):
115+ if "value" in val :
116+ d [self .name ] = val .pop ("value" )
117+ else :
118+ del d [self .name ]
97119 return d
98120
99121
0 commit comments