@@ -240,6 +240,29 @@ def _call(self,
240240 """
241241 raise NotImplementedError ()
242242
243+ def __input_type_check (
244+ self , objs : Iterable [drgn .Object ]) -> Iterable [drgn .Object ]:
245+ assert self .input_type is not None
246+ valid_types = [type_canonicalize_name (self .input_type )]
247+
248+ #
249+ # Some commands support multiple input types. Check for InputHandler
250+ # implementations to expand the set of valid input types.
251+ #
252+ for (_ , method ) in inspect .getmembers (self , inspect .ismethod ):
253+ if not hasattr (method , "input_typename_handled" ):
254+ continue
255+ valid_types .append (
256+ type_canonicalize_name (method .input_typename_handled ))
257+
258+ for obj in objs :
259+ if type_canonical_name (obj .type_ ) not in valid_types :
260+ raise CommandError (
261+ self .name ,
262+ f'exepected input of type { self .input_type } , but received '
263+ f'type { obj .type_ } ' )
264+ yield obj
265+
243266 def __invalid_memory_objects_check (self , objs : Iterable [drgn .Object ],
244267 fatal : bool ) -> Iterable [drgn .Object ]:
245268 """
@@ -277,7 +300,11 @@ def call(self, objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]:
277300 # the command is running.
278301 #
279302 try :
280- result = self ._call (objs )
303+ if self .input_type and objs :
304+ result = self ._call (self .__input_type_check (objs ))
305+ else :
306+ result = self ._call (objs )
307+
281308 if result is not None :
282309 #
283310 # The whole point of the SingleInputCommands are that
@@ -637,17 +664,7 @@ def _call( # type: ignore[return]
637664 This function will call pretty_print() on each input object,
638665 verifying the types as we go.
639666 """
640-
641- assert self .input_type is not None
642- type_name = type_canonicalize_name (self .input_type )
643- for obj in objs :
644- if type_canonical_name (obj .type_ ) != type_name :
645- raise CommandError (
646- self .name ,
647- f'exepected input of type { self .input_type } , but received '
648- f'type { obj .type_ } ' )
649-
650- self .pretty_print ([obj ])
667+ self .pretty_print (objs )
651668
652669
653670class Locator (Command ):
0 commit comments