1616from  torch .fx  import  Node 
1717
1818_ScalarType  =  Union [int , bool , float ]
19- _Argument  =  Union [torch . fx . Node , int , bool , float , str ]
19+ _Argument  =  Union [Node , int , bool , float , str ]
2020
2121
2222class  VkGraphBuilder :
@@ -29,7 +29,7 @@ def __init__(self, program: ExportedProgram) -> None:
2929        self .output_ids  =  []
3030        self .const_tensors  =  []
3131
32-         # Mapping from torch.fx. Node to VkValue id 
32+         # Mapping from Node to VkValue id 
3333        self .node_to_value_ids  =  {}
3434
3535    @staticmethod  
@@ -39,18 +39,18 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
3939        else :
4040            raise  AssertionError (f"Invalid dtype for vulkan_preprocess ({ torch_dtype }  )
4141
42-     def  is_constant (self , node : torch . fx . Node ):
42+     def  is_constant (self , node : Node ):
4343        return  (
4444            node .name  in  self .program .graph_signature .inputs_to_lifted_tensor_constants 
4545        )
4646
47-     def  is_get_attr_node (self , node : torch . fx . Node ) ->  bool :
47+     def  is_get_attr_node (self , node : Node ) ->  bool :
4848        """ 
4949        Returns true if the given node is a get attr node for a tensor of the model 
5050        """ 
51-         return  isinstance (node , torch . fx . Node ) and  node .op  ==  "get_attr" 
51+         return  isinstance (node , Node ) and  node .op  ==  "get_attr" 
5252
53-     def  is_param_node (self , node : torch . fx . Node ) ->  bool :
53+     def  is_param_node (self , node : Node ) ->  bool :
5454        """ 
5555        Check if the given node is a parameter within the exported program 
5656        """ 
@@ -61,7 +61,7 @@ def is_param_node(self, node: torch.fx.Node) -> bool:
6161            or  self .is_constant (node )
6262        )
6363
64-     def  get_constant (self , node : torch . fx . Node ) ->  Optional [torch .Tensor ]:
64+     def  get_constant (self , node : Node ) ->  Optional [torch .Tensor ]:
6565        """ 
6666        Returns the constant associated with the given node in the exported program. 
6767        Returns None if the node is not a constant within the exported program 
@@ -79,7 +79,7 @@ def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:
7979
8080        return  None 
8181
82-     def  get_param_tensor (self , node : torch . fx . Node ) ->  torch .Tensor :
82+     def  get_param_tensor (self , node : Node ) ->  torch .Tensor :
8383        tensor  =  None 
8484        if  node  is  None :
8585            raise  RuntimeError ("node is None" )
@@ -168,7 +168,7 @@ def create_string_value(self, string: str) -> int:
168168        return  new_id 
169169
170170    def  get_or_create_value_for (self , arg : _Argument ):
171-         if  isinstance (arg , torch . fx . Node ):
171+         if  isinstance (arg , Node ):
172172            # If the value has already been created, return the existing id 
173173            if  arg  in  self .node_to_value_ids :
174174                return  self .node_to_value_ids [arg ]
0 commit comments