2024-12-31 nightly release (455de88)
pytorchbot committed Dec 31, 2024
1 parent 634a0a8 commit 9b778c5
Showing 8 changed files with 282 additions and 218 deletions.
examples/retrieval/tests/test_two_tower_retrieval.py: 4 changes (2 additions & 2 deletions)
@@ -21,8 +21,8 @@ class InferTest(unittest.TestCase):
     @skip_if_asan
     # pyre-ignore[56]
     @unittest.skipIf(
-        not torch.cuda.is_available(),
-        "this test requires a GPU",
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
     )
     def test_infer_function(self) -> None:
         infer(
torchrec/distributed/batched_embedding_kernel.py: 57 changes (47 additions & 10 deletions)
@@ -900,7 +900,7 @@ def __init__(
             pg,
         )
         self._param_per_table: Dict[str, nn.Parameter] = dict(
-            _gen_named_parameters_by_table_ssd(
+            _gen_named_parameters_by_table_ssd_pmt(
                 emb_module=self._emb_module,
                 table_name_to_count=self.table_name_to_count.copy(),
                 config=self._config,
@@ -933,11 +933,31 @@ def state_dict(
         destination: Optional[Dict[str, Any]] = None,
         prefix: str = "",
         keep_vars: bool = False,
+        no_snapshot: bool = True,
     ) -> Dict[str, Any]:
-        if destination is None:
-            destination = OrderedDict()
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. This argument controls whether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
+                PartiallyMaterializedTensor has a RocksDB snapshot handle.
+        """
+        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
+        # ShardedEmbeddingBagCollection._pre_state_dict_hook()

-        return destination
+        emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        ret = get_state_dict(
+            emb_table_config_copy,
+            emb_tables,
+            self._pg,
+            destination,
+            prefix,
+        )
+        return ret

     def named_parameters(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
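For context, a minimal usage sketch (not part of this commit) of the new no_snapshot keyword together with flush(), both of which appear in the diff above. The kv_embedding argument and the helper name are hypothetical stand-ins for an already-constructed SSD-backed embedding kernel of the kind modified here.

    from typing import Any, Dict


    def checkpoint_ready_state_dict(kv_embedding) -> Dict[str, Any]:
        # With no_snapshot=False, the tensors in the returned dict are
        # PartiallyMaterializedTensors that own a RocksDB snapshot handle,
        # so the SSD cache is flushed first (normally handled by
        # ShardedEmbeddingBagCollection._pre_state_dict_hook()).
        kv_embedding.flush()
        return kv_embedding.state_dict(no_snapshot=False)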
@@ -950,14 +970,16 @@ def named_parameters(
         ):
             # hack before we support optimizer on sharded parameter level
             # can delete after PEA deprecation
+            # pyre-ignore [6]
             param = nn.Parameter(tensor)
             # pyre-ignore
             param._in_backward_optimizers = [EmptyFusedOptimizer()]
             yield name, param

+    # pyre-ignore [15]
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
-    ) -> Iterator[Tuple[str, torch.Tensor]]:
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
         assert (
             remove_duplicate
         ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -968,6 +990,21 @@ def named_split_embedding_weights(
             key = append_prefix(prefix, f"{config.name}.weight")
             yield key, tensor

+    def get_named_split_embedding_weights_snapshot(
+        self, prefix: str = ""
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+        """
+        Return an iterator over embedding tables, yielding both the table name and the embedding
+        table itself. The embedding table is in the form of a PartiallyMaterializedTensor with a valid
+        RocksDB snapshot to support windowed access.
+        """
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights(no_snapshot=False),
+        ):
+            key = append_prefix(prefix, f"{config.name}")
+            yield key, tensor

     def flush(self) -> None:
         """
         Flush the embeddings in cache back to SSD. Should be pretty expensive.
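A second sketch (also not part of this commit), assuming the same hypothetical kv_embedding module: it collects the snapshot-backed tables yielded by get_named_split_embedding_weights_snapshot, whose signature is given in the hunk above; the helper name and the prefix value are arbitrary.

    from typing import Dict


    def collect_snapshot_weights(kv_embedding, prefix: str = "") -> Dict[str, object]:
        # Each value is a PartiallyMaterializedTensor holding a valid RocksDB
        # snapshot, so rows can be read in windows instead of materializing
        # the whole embedding table in memory.
        return dict(kv_embedding.get_named_split_embedding_weights_snapshot(prefix=prefix))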
@@ -982,11 +1019,11 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_weights.zero_()
         self.emb_module.lxu_cache_state.fill_(-1)

-    def split_embedding_weights(self) -> List[torch.Tensor]:
-        """
-        Return fake tensors.
-        """
-        return [param.data for param in self._param_per_table.values()]
+    # pyre-ignore [15]
+    def split_embedding_weights(
+        self, no_snapshot: bool = True
+    ) -> List[PartiallyMaterializedTensor]:
+        return self.emb_module.split_embedding_weights(no_snapshot)


 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
(Diffs for the remaining 6 changed files are not shown.)
