@@ -210,14 +210,24 @@ def __init__(
210210 if baselines is None :
211211 # default baseline is to remove the element
212212 baselines = ["" ] * len (values )
213- elif dict_keys :
214- assert isinstance (baselines , dict ), (
215- "if values is dict, the baselines must also be a dict, "
216- f"received: { type (baselines )} "
217- )
213+ elif not callable (baselines ):
214+ if dict_keys :
215+ assert isinstance (baselines , dict ), (
216+ "if values is a dict, the baselines must also be a dict "
217+ "or a callable which return a dict, "
218+ f"received: { type (baselines )} "
219+ )
218220
219- # convert dict to list
220- baselines = [baselines [k ] for k in self .dict_keys ]
221+ # convert dict to list
222+ baselines = [baselines [k ] for k in dict_keys ]
223+ else :
224+ assert isinstance (baselines , list ), (
225+ "if values is a list, the baselines must also be a list "
226+ "or a callable which return a list, "
227+ f"received: { type (baselines )} "
228+ )
229+
230+ self .baselines = baselines
221231
222232 if mask is None :
223233 n_itp_features = n_features
@@ -247,14 +257,13 @@ def __init__(
247257 if isinstance (template , str ):
248258 template = template .format
249259 else :
250- assert isinstance (template , Callable ), (
260+ assert callable (template ), (
251261 "the template must be either a string or a callable, "
252262 f"received: { type (template )} "
253263 )
254264 template = template
255265 self .format_fn = template
256266
257- self .baselines = baselines
258267 self .mask = mask
259268
260269 def to_tensor (self ) -> torch .Tensor :
@@ -265,13 +274,23 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
265274 values = list (self .values ) # clone
266275
267276 if perturbed_tensor is not None :
268- baselines = self .baselines
269- if isinstance (baselines , Callable ):
277+ if callable (self .baselines ):
270278 # a placeholder for advanced baselines
271279 # TODO: support callable baselines
272280 baselines = self .baselines ()
273281 if self .dict_keys :
282+ assert isinstance (baselines , dict ), (
283+ "if values is a dict and the baselines is a callable"
284+ f"it must return a dict, received: { type (baselines )} "
285+ )
274286 baselines = [baselines [k ] for k in self .dict_keys ]
287+ else :
288+ assert isinstance (baselines , list ), (
289+ "if values is a list and the baselines is a callable"
290+ f"it must return a list, received: { type (baselines )} "
291+ )
292+ else :
293+ baselines = self .baselines
275294
276295 for i in range (len (values )):
277296 itp_idx = i
@@ -284,8 +303,8 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
284303 values [i ] = baselines [i ]
285304
286305 if self .dict_keys :
287- values = dict (zip (self .dict_keys , values ))
288- input_str = self .format_fn (** values )
306+ dict_values = dict (zip (self .dict_keys , values ))
307+ input_str = self .format_fn (** dict_values )
289308 else :
290309 input_str = self .format_fn (* values )
291310
0 commit comments