#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
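
"""Unsharded, single-GPU DLRM predict module and its PredictFactory for the
TorchRec inference stack (OSS example)."""
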
import logging
from typing import Dict

import torch
from dlrm_predict import DLRMModelConfig, DLRMPredictModule
from torchrec.inference.model_packager import load_pickle_config
from torchrec.inference.modules import PredictFactory
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fused_embedding_modules import fuse_embedding_optimizer

logger: logging.Logger = logging.getLogger(__name__)


# OSS Only
class DLRMPredictSingleGPUModule(DLRMPredictModule):
    """
    nn.Module for an unsharded, single-GPU DLRM predict module.
    DistributedModelParallel (the TorchRec sharding API) is not expected to
    wrap this module.
    """

    # TODO: Determine a cleaner way to remove copy().
    # This is needed because the server expects a copy method to exist on the
    # predict module.
    def copy(self, device: torch.device) -> "DLRMPredictSingleGPUModule":
        return self


class DLRMPredictSingleGPUFactory(PredictFactory):
    def __init__(self) -> None:
        self.model_config: DLRMModelConfig = load_pickle_config(
            "config.pkl", DLRMModelConfig
        )

    def create_predict_module(self, world_size: int) -> torch.nn.Module:
        logging.basicConfig(level=logging.INFO)

        default_cuda_rank = 0
        device = torch.device("cuda", default_cuda_rank)
        torch.cuda.set_device(device)

        eb_configs = [
            EmbeddingBagConfig(
                name=f"t_{feature_name}",
                embedding_dim=self.model_config.embedding_dim,
                num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
                if self.model_config.num_embeddings is None
                else self.model_config.num_embeddings,
                feature_names=[feature_name],
            )
            for feature_idx, feature_name in enumerate(
                self.model_config.id_list_features_keys
            )
        ]
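        # The EmbeddingBagCollection is built on the "meta" device so no
        # embedding storage is allocated up front; the tables are materialized
        # on the GPU later (here by fuse_embedding_optimizer below).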
        ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))

        module = DLRMPredictSingleGPUModule(
            embedding_bag_collection=ebc,
            dense_in_features=self.model_config.dense_in_features,
            dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes,
            over_arch_layer_sizes=self.model_config.over_arch_layer_sizes,
            id_list_features_keys=self.model_config.id_list_features_keys,
            dense_device=device,
        )
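
        # Swap the embedding lookups for fused FBGEMM kernels. Using SGD with
        # lr=0.0 means the fused kernels never update the weights, which is
        # the behavior we want at inference time.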
        module = fuse_embedding_optimizer(
            module,
            optimizer_type=torch.optim.SGD,
            optimizer_kwargs={"lr": 0.0},
            device=torch.device("cuda"),
        )

        # TensorRT lowering - use torch_tensorrt.fx (https://github.com/pytorch/TensorRT)
        # to lower the dense module. Follow
        # https://github.com/pytorch/TensorRT/blob/master/py/torch_tensorrt/fx/example/fx2trt_example.py
        # for a fully detailed example of splitting and lowering a submodule.
        #
        # Sketch for lowering the dense part of this DLRMPredictSingleGPUModule
        # (import paths may vary across torch_tensorrt versions):
        #
        # import torch.fx
        # import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
        # from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
        # from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
        #
        # sample_input = {
        #     "float_features": torch.ones(13),
        #     "id_list_features.lengths": torch.ones(26),
        #     "id_list_features.values": torch.ones(26),
        # }
        # traced = acc_tracer.trace(module, sample_input)
        # splitter = TRTSplitter(traced, sample_input)
        # split_mod = splitter()
        #
        # # Lower the dense part (_run_on_acc_0, the subgraph TensorRT can
        # # handle); `inputs` are the sample tensors that feed that subgraph.
        # interp = TRTInterpreter(
        #     split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs)
        # )
        # r = interp.run()
        # trt_mod = TRTModule(r.engine, r.input_names, r.output_names)
        # split_mod._run_on_acc_0 = trt_mod
        # return split_mod

        return module
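
    # The metadata below tells the TorchRec inference server how to batch each
    # input feature ("dense" float tensors vs. "sparse" jagged id-list
    # features) and what result format to expect from the predict module.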
    def batching_metadata(self) -> Dict[str, str]:
        return {
            "float_features": "dense",
            "id_list_features": "sparse",
        }

    def result_metadata(self) -> str:
        return "dict_of_tensor"