@@ -555,73 +555,9 @@ def attribute_future(
555555                        ]
556556                    )
557557
558-                     def  eval_fut_to_ablated_out_fut (
559-                         # pyre-ignore Invalid type parameters [24] 
560-                         eval_futs : Future [List [Future [List [object ]]]],
561-                         current_inputs : Tuple [Tensor , ...],
562-                         current_mask : Tensor ,
563-                         i : int ,
564-                         perturbations_per_eval : int ,
565-                         num_examples : int ,
566-                         formatted_inputs : Tuple [Tensor , ...],
567-                     ) ->  Tuple [List [Tensor ], List [Tensor ]]:
568-                         try :
569-                             modified_eval  =  cast (Tensor , eval_futs .value ()[1 ].value ())
570-                             initial_eval_tuple  =  cast (
571-                                 Tuple [
572-                                     List [Tensor ],
573-                                     List [Tensor ],
574-                                     Tensor ,
575-                                     Tensor ,
576-                                     int ,
577-                                     dtype ,
578-                                 ],
579-                                 eval_futs .value ()[0 ].value (),
580-                             )
581-                             if  len (initial_eval_tuple ) !=  6 :
582-                                 raise  AssertionError (
583-                                     "eval_fut_to_ablated_out_fut: " 
584-                                     "initial_eval_tuple should have 6 elements: " 
585-                                     "total_attrib, weights, initial_eval, " 
586-                                     "flattened_initial_eval, n_outputs, attrib_type " 
587-                                 )
588-                             if  not  isinstance (modified_eval , Tensor ):
589-                                 raise  AssertionError (
590-                                     "eval_fut_to_ablated_out_fut: " 
591-                                     "modified eval should be a Tensor" 
592-                                 )
593-                             (
594-                                 total_attrib ,
595-                                 weights ,
596-                                 initial_eval ,
597-                                 flattened_initial_eval ,
598-                                 n_outputs ,
599-                                 attrib_type ,
600-                             ) =  initial_eval_tuple 
601-                             result  =  self ._process_ablated_out (  # type: ignore # noqa: E501 line too long 
602-                                 modified_eval = modified_eval ,
603-                                 current_inputs = current_inputs ,
604-                                 current_mask = current_mask ,
605-                                 perturbations_per_eval = perturbations_per_eval ,
606-                                 num_examples = num_examples ,
607-                                 initial_eval = initial_eval ,
608-                                 flattened_initial_eval = flattened_initial_eval ,
609-                                 inputs = formatted_inputs ,
610-                                 n_outputs = n_outputs ,
611-                                 total_attrib = total_attrib ,
612-                                 weights = weights ,
613-                                 i = i ,
614-                                 attrib_type = attrib_type ,
615-                             )
616-                         except  FeatureAblationFutureError  as  e :
617-                             raise  FeatureAblationFutureError (
618-                                 "eval_fut_to_ablated_out_fut func failed)" 
619-                             ) from  e 
620-                         return  result 
621- 
622558                    ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] =  (
623559                        eval_futs .then (
624-                             lambda  eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : eval_fut_to_ablated_out_fut (  # type: ignore # noqa: E501 line too long 
560+                             lambda  eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self . _eval_fut_to_ablated_out_fut (  # type: ignore # noqa: E501 line too long 
625561                                eval_futs = eval_futs ,
626562                                current_inputs = current_inputs ,
627563                                current_mask = current_mask ,
@@ -660,6 +596,70 @@ def _attribute_progress_setup(
660596        )
661597        return  attr_progress 
662598
599+     def  _eval_fut_to_ablated_out_fut (
600+         self ,
601+         # pyre-ignore Invalid type parameters [24] 
602+         eval_futs : Future [List [Future [List [object ]]]],
603+         current_inputs : Tuple [Tensor , ...],
604+         current_mask : Tensor ,
605+         i : int ,
606+         perturbations_per_eval : int ,
607+         num_examples : int ,
608+         formatted_inputs : Tuple [Tensor , ...],
609+     ) ->  Tuple [List [Tensor ], List [Tensor ]]:
610+         try :
611+             modified_eval  =  cast (Tensor , eval_futs .value ()[1 ].value ())
612+             initial_eval_tuple  =  cast (
613+                 Tuple [
614+                     List [Tensor ],
615+                     List [Tensor ],
616+                     Tensor ,
617+                     Tensor ,
618+                     int ,
619+                     dtype ,
620+                 ],
621+                 eval_futs .value ()[0 ].value (),
622+             )
623+             if  len (initial_eval_tuple ) !=  6 :
624+                 raise  AssertionError (
625+                     "eval_fut_to_ablated_out_fut: " 
626+                     "initial_eval_tuple should have 6 elements: " 
627+                     "total_attrib, weights, initial_eval, " 
628+                     "flattened_initial_eval, n_outputs, attrib_type " 
629+                 )
630+             if  not  isinstance (modified_eval , Tensor ):
631+                 raise  AssertionError (
632+                     "eval_fut_to_ablated_out_fut: "  "modified eval should be a Tensor" 
633+                 )
634+             (
635+                 total_attrib ,
636+                 weights ,
637+                 initial_eval ,
638+                 flattened_initial_eval ,
639+                 n_outputs ,
640+                 attrib_type ,
641+             ) =  initial_eval_tuple 
642+             result  =  self ._process_ablated_out (  # type: ignore # noqa: E501 line too long 
643+                 modified_eval = modified_eval ,
644+                 current_inputs = current_inputs ,
645+                 current_mask = current_mask ,
646+                 perturbations_per_eval = perturbations_per_eval ,
647+                 num_examples = num_examples ,
648+                 initial_eval = initial_eval ,
649+                 flattened_initial_eval = flattened_initial_eval ,
650+                 inputs = formatted_inputs ,
651+                 n_outputs = n_outputs ,
652+                 total_attrib = total_attrib ,
653+                 weights = weights ,
654+                 i = i ,
655+                 attrib_type = attrib_type ,
656+             )
657+         except  FeatureAblationFutureError  as  e :
658+             raise  FeatureAblationFutureError (
659+                 "eval_fut_to_ablated_out_fut func failed)" 
660+             ) from  e 
661+         return  result 
662+ 
663663    # pyre-fixme[3]: Return type must be specified as type that does not contain `Any` 
664664    def  _ith_input_ablation_generator (
665665        self ,
0 commit comments