diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index 55accff68..7745321f4 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -9,10 +9,10 @@ import functools -from typing import Any, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple import torch -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor """ Prepares KJT for PT2 tracing. @@ -28,6 +28,7 @@ def kjt_for_pt2_tracing( kjt: KeyedJaggedTensor, convert_to_vb: bool = False, + mark_length: bool = False, ) -> KeyedJaggedTensor: # Breaking dependency cycle from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -78,8 +79,15 @@ def kjt_for_pt2_tracing( weights = kjt.weights_or_none() if weights is not None: torch._dynamo.decorators.mark_unbacked(weights, 0) + if mark_length: + torch._dynamo.decorators.mark_unbacked(lengths, 0) - return KeyedJaggedTensor( + length_per_key_marked_dynamic = [] + + for length in kjt.length_per_key(): + length_per_key_marked_dynamic.append(length) + + return PT2KeyedJaggedTensor( keys=kjt.keys(), values=values, lengths=lengths, @@ -87,9 +95,81 @@ def kjt_for_pt2_tracing( stride=stride if not is_vb else None, stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None, inverse_indices=inverse_indices, + length_per_key=(length_per_key_marked_dynamic if is_vb else None), ) +class PT2KeyedJaggedTensor(KeyedJaggedTensor): + """ + This subclass of KeyedJaggedTensor is used to support PT2 tracing. + We can apply some modifications to make KJT friendly for PT2 tracing. + """ + + def __init__( + self, + keys: List[str], + values: torch.Tensor, + weights: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + stride: Optional[int] = None, + stride_per_key_per_rank: Optional[List[List[int]]] = None, + # Below exposed to ensure torch.script-able + stride_per_key: Optional[List[int]] = None, + length_per_key: Optional[List[int]] = None, + lengths_offset_per_key: Optional[List[int]] = None, + offset_per_key: Optional[List[int]] = None, + index_per_key: Optional[Dict[str, int]] = None, + jt_dict: Optional[Dict[str, JaggedTensor]] = None, + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, + ) -> None: + super().__init__( + keys=keys, + values=values, + weights=weights, + lengths=lengths, + offsets=offsets, + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=stride_per_key, + length_per_key=None, + lengths_offset_per_key=lengths_offset_per_key, + offset_per_key=offset_per_key, + index_per_key=index_per_key, + jt_dict=jt_dict, + inverse_indices=inverse_indices, + ) + self.length_per_key_tensors: List[torch.Tensor] = [] + for length in length_per_key or []: + t = torch.empty((length, 0)) + torch._dynamo.mark_dynamic(t, 0) + self.length_per_key_tensors.append(t) + self.stride_per_key_per_rank_tensor: List[List[torch.Tensor]] = [] + for strides_per_key in stride_per_key_per_rank or []: + strides_per_key_list: List[torch.Tensor] = [] + for s in strides_per_key: + t = torch.empty((s, 0)) + torch._dynamo.mark_dynamic(t, 0) + strides_per_key_list.append(t) + self.stride_per_key_per_rank_tensor.append(strides_per_key_list) + + def length_per_key(self) -> List[int]: + if len(self.length_per_key_tensors) > 0: + self._length_per_key = [t.size(0) for t in self.length_per_key_tensors] + else: + self._length_per_key = super().length_per_key() + return self._length_per_key + + def stride_per_key_per_rank(self) -> List[List[int]]: + if len(self.stride_per_key_per_rank_tensor) > 0: + self._stride_per_key_per_rank = [ + [t.size(0) for t in strides_per_key_list] + for strides_per_key_list in self.stride_per_key_per_rank_tensor + ] + stride_per_key_per_rank = self._stride_per_key_per_rank + return stride_per_key_per_rank if stride_per_key_per_rank is not None else [] + + # pyre-ignore def default_pipeline_input_transformer(inp): for attr_name in ["id_list_features", "id_score_list_features"]: @@ -97,6 +177,29 @@ def default_pipeline_input_transformer(inp): attr = getattr(inp, attr_name) if isinstance(attr, KeyedJaggedTensor): setattr(inp, attr_name, kjt_for_pt2_tracing(attr)) + for attr_name in [ + "uhm_history_timestamps", + "raw_uhm_history_timestamps", + "event_id_list_feature_invert_indexes", + ]: + if hasattr(inp, attr_name): + attr = getattr(inp, attr_name) + if isinstance(attr, dict): + for key in attr: + torch._dynamo.decorators.mark_dynamic(attr[key], 0) + torch._dynamo.decorators.mark_dynamic(inp.supervision_label["keys"], 0) + torch._dynamo.decorators.mark_dynamic(inp.supervision_label["values"], 0) + + for attr_name in ["event_id_list_features_seqs"]: + if hasattr(inp, attr_name): + attr = getattr(inp, attr_name) + if isinstance(attr, dict): + for key in attr: + if isinstance(attr[key], KeyedJaggedTensor): + attr[key] = kjt_for_pt2_tracing(attr[key], mark_length=True) + + setattr(inp, attr_name, attr) + return inp