@@ -418,7 +418,7 @@ def forward(self, input1, input2, input3=None):
418418        return  self .linear2 (self .relu (self .linear1 (embeddings ))).sum (1 )
419419
420420
421- class  GradientUnsupportedLayerOutput (nn .Module ):
421+ class  PassThroughLayerOutput (nn .Module ):
422422    """ 
423423    This layer is used to test the case where the model returns a layer that 
424424    is not supported by the gradient computation. 
@@ -428,10 +428,8 @@ def __init__(self) -> None:
428428        super ().__init__ ()
429429
430430    @no_type_check  
431-     def  forward (
432-         self , unsupported_layer_output : PassThroughOutputType 
433-     ) ->  PassThroughOutputType :
434-         return  unsupported_layer_output 
431+     def  forward (self , output : PassThroughOutputType ) ->  PassThroughOutputType :
432+         return  output 
435433
436434
437435class  BasicModel_GradientLayerAttribution (nn .Module ):
@@ -456,7 +454,7 @@ def __init__(
456454
457455        self .relu  =  nn .ReLU (inplace = inplace )
458456        self .relu_alt  =  nn .ReLU (inplace = False )
459-         self .unsupportedLayer  =  GradientUnsupportedLayerOutput ()
457+         self .unsupported_layer  =  PassThroughLayerOutput ()
460458
461459        self .linear2  =  nn .Linear (4 , 2 )
462460        self .linear2 .weight  =  nn .Parameter (torch .ones (2 , 4 ))
@@ -466,6 +464,8 @@ def __init__(
466464        self .linear3 .weight  =  nn .Parameter (torch .ones (2 , 4 ))
467465        self .linear3 .bias  =  nn .Parameter (torch .tensor ([- 1.0 , 1.0 ]))
468466
467+         self .int_layer  =  PassThroughLayerOutput ()  # sample layer with an int ouput 
468+ 
469469    @no_type_check  
470470    def  forward (
471471        self , x : Tensor , add_input : Optional [Tensor ] =  None 
@@ -476,7 +476,7 @@ def forward(
476476        lin1_out_alt  =  self .linear1_alt (lin0_out )
477477
478478        if  self .unsupported_layer_output  is  not None :
479-             self .unsupportedLayer (self .unsupported_layer_output )
479+             self .unsupported_layer (self .unsupported_layer_output )
480480            # unsupportedLayer is unused in the forward func. 
481481        self .relu_alt (
482482            lin1_out_alt 
@@ -485,9 +485,10 @@ def forward(
485485        relu_out  =  self .relu (lin1_out )
486486        lin2_out  =  self .linear2 (relu_out )
487487
488-         lin3_out  =  self .linear3 (lin1_out_alt ).to (torch .int64 )
488+         lin3_out  =  self .linear3 (lin1_out_alt )
489+         int_output  =  self .int_layer (lin3_out .to (torch .int64 ))
489490
490-         output_tensors  =  torch .cat ((lin2_out , lin3_out ), dim = 1 )
491+         output_tensors  =  torch .cat ((lin2_out , int_output ), dim = 1 )
491492
492493        # we return a dictionary of tensors as an output to test the case 
493494        # where an output accessor is required 
0 commit comments