Skip to content

Commit 5c9e468

Browse files
committed
bugfix in collector batch to tensor when top_k > number of elements
1 parent 7b64458 commit 5c9e468

File tree

2 files changed

+5
-81
lines changed

2 files changed

+5
-81
lines changed

thoi/collectors.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,12 @@ def batch_to_tensor(nplets_idxs: torch.Tensor,
376376
metric,
377377
largest)
378378

379+
metric_func = partial(_get_string_metric, metric=metric) if isinstance(metric, str) else metric
380+
379381
# If not top_k or len(nplets_measuresa) > top_k return the original values
380-
# |k x D x 4|, |k x N|
381-
return (nplets_measures, nplets_idxs, None)
382+
# |k x D x 4|, |k x N|, |k|
383+
values = metric_func(nplets_measures).to(nplets_measures.device)
384+
return (nplets_measures, nplets_idxs, values)
382385

383386

384387
def concat_batched_tensors(batched_tensors: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],

thoi/graph.py

-79
This file was deleted.

0 commit comments

Comments
 (0)