Support int32_t indices/offsets for caching handling logics (pytorch#811)

Summary:
Pull Request resolved: pytorch#811

In training, we assume the indices/offsets are int64_t for table-batched embedding (TBE), but in inference we assume they are int32_t.

This Diff enables both int32_t and int64_t support in the caching logic so that the same functions can be reused for training and inference, avoiding the extra overhead of converting indices/offsets from int to long or vice versa.
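The caching ops themselves live in FBGEMM's C++/CUDA backend. As a minimal sketch of the general pattern (not the actual code from this Diff), PyTorch's `AT_DISPATCH_INDEX_TYPES` macro instantiates one kernel body for both int32_t and int64_t, binding `index_t` to whatever dtype the caller passes; the function name and placeholder body below are hypothetical:

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical sketch: one entry point that accepts either index dtype,
// so training (int64) and inference (int32) callers can share it
// without an indices.long() conversion first.
at::Tensor to_linear_indices_sketch(const at::Tensor& indices) {
  auto out = at::empty(indices.sizes(), indices.options().dtype(at::kLong));
  AT_DISPATCH_INDEX_TYPES(
      indices.scalar_type(), "to_linear_indices_sketch", [&] {
        // The macro binds index_t to int32_t or int64_t to match `indices`.
        const index_t* src = indices.data_ptr<index_t>();
        int64_t* dst = out.data_ptr<int64_t>();
        for (int64_t i = 0; i < indices.numel(); ++i) {
          dst[i] = static_cast<int64_t>(src[i]); // placeholder kernel body
        }
      });
  return out;
}
```

Dispatching once at the op boundary keeps a single templated kernel per operation rather than separate int and long variants.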

Reviewed By: jspark1105

Differential Revision: D33045589

fbshipit-source-id: 4cdc7cec15e07c51af999276bf5366199eb216b5
jianyuh authored and facebook-github-bot committed Dec 28, 2021
1 parent e3bff30 commit ae22b20
Showing 2 changed files with 321 additions and 272 deletions.
```diff
@@ -1824,7 +1824,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         if not self.lxu_cache_weights.numel():
             return
 
-        (indices, offsets) = indices.long(), offsets.long()
         linear_cache_indices = torch.ops.fb.linearize_cache_indices(
             self.cache_hash_size_cumsum,
             indices,
```
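With the explicit upcast removed, int32 indices/offsets from the inference path flow straight into torch.ops.fb.linearize_cache_indices and the rest of the caching ops, which this Diff extends to accept either index type.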
