@@ -79,13 +79,13 @@ def fn_(obj: Any) -> None:
7979            obj .detach_ ().requires_grad_ (requires_grad )
8080
8181    if  isinstance (target , ModuleState ):
82-         true_target  =  cast (TensorTree , (target .params , target .buffers ))
82+         true_target  =  cast (' TensorTree'  , (target .params , target .buffers ))
8383    elif  isinstance (target , nn .Module ):
84-         true_target  =  cast (TensorTree , tuple (target .parameters ()))
84+         true_target  =  cast (' TensorTree'  , tuple (target .parameters ()))
8585    elif  isinstance (target , MetaOptimizer ):
86-         true_target  =  cast (TensorTree , target .state_dict ())
86+         true_target  =  cast (' TensorTree'  , target .state_dict ())
8787    else :
88-         true_target  =  cast (TensorTree , target )  # tree of tensors 
88+         true_target  =  cast (' TensorTree'  , target )  # tree of tensors 
8989
9090    pytree .tree_map_ (fn_ , true_target )
9191
@@ -325,7 +325,7 @@ def recover_state_dict(
325325    from  torchopt .optim .meta .base  import  MetaOptimizer 
326326
327327    if  isinstance (target , nn .Module ):
328-         params , buffers , * _  =  state  =  cast (ModuleState , state )
328+         params , buffers , * _  =  state  =  cast (' ModuleState'  , state )
329329        params_containers , buffers_containers  =  extract_module_containers (target , with_buffers = True )
330330
331331        if  state .detach_buffers :
@@ -343,7 +343,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
343343        ):
344344            tgt .update (src )
345345    elif  isinstance (target , MetaOptimizer ):
346-         state  =  cast (Sequence [OptState ], state )
346+         state  =  cast (' Sequence[OptState]'  , state )
347347        target .load_state_dict (state )
348348    else :
349349        raise  TypeError (f'Unexpected class of { target }  ' )
@@ -422,9 +422,9 @@ def module_clone(  # noqa: C901
422422
423423    if  isinstance (target , (nn .Module , MetaOptimizer )):
424424        if  isinstance (target , nn .Module ):
425-             containers  =  cast (TensorTree , extract_module_containers (target , with_buffers = True ))
425+             containers  =  cast (' TensorTree'  , extract_module_containers (target , with_buffers = True ))
426426        else :
427-             containers  =  cast (TensorTree , target .state_dict ())
427+             containers  =  cast (' TensorTree'  , target .state_dict ())
428428        tensors  =  pytree .tree_leaves (containers )
429429        memo  =  {id (t ): t  for  t  in  tensors }
430430        cloned  =  copy .deepcopy (target , memo = memo )
@@ -476,7 +476,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
476476    else :
477477        replicate  =  clone_detach_ 
478478
479-     return  pytree .tree_map (replicate , cast (TensorTree , target ))
479+     return  pytree .tree_map (replicate , cast (' TensorTree'  , target ))
480480
481481
482482@overload  
0 commit comments