@@ -70,8 +70,7 @@ def r_precision(
7070    matches_per_row  =  torch .sum (same_label  *  relevance_mask , dim = 1 )
7171    max_possible_matches_per_row  =  torch .sum (relevance_mask , dim = 1 )
7272    accuracy_per_sample  =  (
73-         c_f .to_dtype (matches_per_row , dtype = torch .float64 )
74-         /  max_possible_matches_per_row 
73+         matches_per_row .type (torch .float64 ) /  max_possible_matches_per_row 
7574    )
7675    return  maybe_get_avg_of_avgs (
7776        accuracy_per_sample , gt_labels , avg_of_avgs , return_per_class 
@@ -99,9 +98,7 @@ def mean_average_precision(
9998    equality  =  is_same_label  *  relevance_mask 
10099    cumulative_correct  =  torch .cumsum (equality , dim = 1 )
101100    k_idx  =  torch .arange (1 , num_k  +  1 , device = device ).repeat (num_samples , 1 )
102-     precision_at_ks  =  (
103-         c_f .to_dtype (cumulative_correct  *  equality , dtype = torch .float64 ) /  k_idx 
104-     )
101+     precision_at_ks  =  (cumulative_correct  *  equality ).type (torch .float64 ) /  k_idx 
105102    summed_precision_per_row  =  torch .sum (precision_at_ks  *  relevance_mask , dim = 1 )
106103    if  at_r :
107104        max_possible_matches_per_row  =  torch .sum (relevance_mask , dim = 1 )
@@ -172,9 +169,7 @@ def precision_at_k(
172169):
173170    curr_knn_labels  =  knn_labels [:, :k ]
174171    same_label  =  label_comparison_fn (gt_labels , curr_knn_labels )
175-     accuracy_per_sample  =  (
176-         c_f .to_dtype (torch .sum (same_label , dim = 1 ), dtype = torch .float64 ) /  k 
177-     )
172+     accuracy_per_sample  =  torch .sum (same_label , dim = 1 ).type (torch .float64 ) /  k 
178173    return  maybe_get_avg_of_avgs (
179174        accuracy_per_sample , gt_labels , avg_of_avgs , return_per_class 
180175    )
@@ -209,9 +204,7 @@ def get_lone_query_labels(
209204    unique_labels , match_counts  =  label_counts 
210205    if  embeddings_come_from_same_source :
211206        label_matches_itself  =  label_comparison_fn (unique_labels , unique_labels )
212-         lone_condition  =  (
213-             match_counts  -  c_f .to_dtype (label_matches_itself , dtype = torch .long ) <=  0 
214-         )
207+         lone_condition  =  match_counts  -  label_matches_itself .type (torch .long ) <=  0 
215208    else :
216209        lone_condition  =  match_counts  ==  0 
217210    lone_query_labels  =  unique_labels [lone_condition ]
0 commit comments