
Commit 252944a

avbokovoy authored and facebook-github-bot committed

Optimized backward pass for ROCm devices (#3367)

Summary: Added an optimized implementation of the backward pass for ROCm devices. It currently supports the **not nobag** mode and the **rowwise_adagrad** optimizer with non-mixed embedding dimensions in [64, 128, 160, 192].

Pull Request resolved: #3367
Differential Revision: D66310520
Pulled By: leitian
1 parent 38518bf · commit 252944a

14 files changed: +1643 −10 lines
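To make the restrictions in the summary concrete, the sketch below (not part of the commit; the function name and arguments are hypothetical, not FBGEMM APIs) expresses the eligibility condition it describes: the optimized ROCm backward applies only in pooled ("not nobag") mode, with the rowwise_adagrad optimizer, and when every table uses the same embedding dimension drawn from [64, 128, 160, 192].

```python
# Hypothetical helper expressing the eligibility rule from the commit
# summary; illustrative only, not an FBGEMM API.
SUPPORTED_ROCM_DIMS = {64, 128, 160, 192}

def can_use_optimized_rocm_backward(
    nobag: bool,                # the optimized path requires "not nobag"
    optimizer: str,             # must be "rowwise_adagrad"
    embedding_dims: list[int],  # per-table embedding dimensions
) -> bool:
    # All tables must share a single (non-mixed) supported dimension.
    non_mixed = len(set(embedding_dims)) == 1
    dim_supported = bool(embedding_dims) and embedding_dims[0] in SUPPORTED_ROCM_DIMS
    return (not nobag) and optimizer == "rowwise_adagrad" and non_mixed and dim_supported

# Four tables, all of dimension 128, trained with rowwise_adagrad: eligible.
assert can_use_optimized_rocm_backward(False, "rowwise_adagrad", [128, 128, 128, 128])
```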

fbgemm_gpu/codegen/genscript/generate_backward_split.py (+22)
@@ -310,6 +310,27 @@ def generate_backward_indices() -> None:
             ssd=ssd,
         )
 
+    @staticmethod
+    def generate_rocm_backward_split(**kwargs: Any) -> None:
+        # Generate backward device kernels based on weighted (True/False), VBE
+        # (True/False), no bag (True/False)
+        template_filepath = (
+            "training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
+        )
+
+        BackwardSplitGenerator.render_backward_templates(
+            template_filepath,
+            "",
+            "{}gen_embedding_backward_{}_device_kernel_hip.hip",
+            {
+                "has_gpu_support": True,
+                "has_vbe_support": False,
+                "has_ssd_support": False,
+                "dense": False,
+                "gen_once": False,
+            },
+        )
+
     @staticmethod
     def generate_python_sources(
         all_optimizers: List[str], ssd_optimizers: List[str]

@@ -369,6 +390,7 @@ def generate() -> None:
         BackwardSplitGenerator.generate_backward_split(
             ssd_tensors=ssd_tensors, **optimizer
         )
+        BackwardSplitGenerator.generate_rocm_backward_split(**optimizer)
 
         # Generate common device kernels for backwards
         BackwardSplitGenerator.generate_backward_device()
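`render_backward_templates` expands the listed Jinja template into per-optimizer generated sources, naming them with the format string shown in the diff. The sketch below only illustrates how that filename format string is filled in; the prefix and descriptor values are assumptions, and the real generator also renders the template body, which is omitted here.

```python
# Illustration of the filename format string added in this diff; the exact
# prefix/descriptor substituted by render_backward_templates is an assumption.
filename_format = "{}gen_embedding_backward_{}_device_kernel_hip.hip"

def generated_filename(prefix: str, descriptor: str) -> str:
    # First placeholder: output prefix; second placeholder: optimizer/variant
    # descriptor used for the generated HIP device kernel.
    return filename_format.format(prefix, descriptor)

# For the rowwise_adagrad optimizer this would yield something like:
print(generated_filename("", "rowwise_adagrad"))
# gen_embedding_backward_rowwise_adagrad_device_kernel_hip.hip
```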

fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp (+4 −3)
@@ -171,7 +171,8 @@ Tensor split_embedding_codegen_lookup_dense_function(
         Tensor>& /* vbe_B_offsets_rank_per_feature = std::nullopt */,
     c10::SymInt /* max_B = -1 */,
     c10::SymInt /* max_B_feature_rank = -1 */,
-    c10::SymInt /* vbe_output_size = -1 */) {
+    c10::SymInt /* vbe_output_size = -1 */,
+    bool /* mixed_D = false */) {
   return SplitLookupFunction_Dense_Op::apply(
       host_weights,
       weights_offsets,

@@ -190,15 +191,15 @@ Tensor split_embedding_codegen_lookup_dense_function(
 // Deprecated for fb namespace! Please use fbgemm namespace instead!
 TORCH_LIBRARY_FRAGMENT(fb, m) {
   m.def(
-      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor");
+      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
   DISPATCH_TO_CPU(
       "dense_embedding_codegen_lookup_function",
       split_embedding_codegen_lookup_dense_function);
 }
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
-      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor");
+      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
   DISPATCH_TO_CPU(
       "dense_embedding_codegen_lookup_function",
       split_embedding_codegen_lookup_dense_function);
fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp (+12 −3)
@@ -152,6 +152,7 @@ enum SSDTensor {
     {%- else %}
     D_offsets,
     max_D,
+    mixed_D,
     {%- endif %} {# /* if nobag */ #}
     hash_size_cumsum,
     total_hash_size_bits,

@@ -224,6 +225,7 @@ enum SSDTensor {
         Variable(), // D_offsets
         Variable(), // total_D
         Variable(), // max_D
+        Variable(), // mixed_D
     {%- endif %}
         Variable(), // hash_size_cumsum
         Variable(), //total_hash_size_bits

@@ -304,6 +306,7 @@ enum SSDTensor {
     D_offsets,
     total_D,
     max_D,
+    mixed_D,
     {%- endif %}
     hash_size_cumsum,
     total_hash_size_bits,

@@ -484,6 +487,7 @@ Tensor
     {%- else %}
     const Tensor& D_offsets,
     const c10::SymInt max_D,
+    const bool mixed_D,
     {%- endif %}
     const Tensor& hash_size_cumsum,
     const int64_t total_hash_size_bits,

@@ -566,6 +570,7 @@ class {{ autograd_func }} :
       const Tensor& D_offsets,
       const c10::SymInt total_D,
       const c10::SymInt max_D,
+      const bool mixed_D,
     {%- else %}
       const c10::SymInt D,
     {%- endif %}

@@ -762,6 +767,7 @@ class {{ autograd_func }} :
 
     {%- if not nobag %}
     ctx->saved_data["max_D"] = max_D;
+    ctx->saved_data["mixed_D"] = mixed_D;
     ctx->saved_data["pooling_mode"] = pooling_mode;
     {%- else %}
     ctx->saved_data["D"] = D;

@@ -877,6 +883,7 @@ class {{ autograd_func }} :
 
     {%- if not nobag %}
     auto max_D = ctx->saved_data["max_D"].toSymInt();
+    const auto mixed_D = ctx->saved_data["mixed_D"].toBool();
     auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
     {%- else %}
     auto D = ctx->saved_data["D"].toSymInt();

@@ -1072,10 +1079,11 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
     {%- if ssd %}
     const std::optional<at::TensorList>& ssd_tensors = std::nullopt,
     {%- endif %}
-    const double gwd_lower_bound = 0
+    const double gwd_lower_bound = 0,
     {%- else %}
-    const c10::SymInt vbe_output_size = -1
+    const c10::SymInt vbe_output_size = -1,
     {%- endif %}
+    const bool mixed_D = false
 ) {
   // TODO: refactor into macro
   {%- if has_gpu_support %}

@@ -1191,7 +1199,8 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
     {%- if ssd %}
     " Tensor[]? ssd_tensors=None,"
     {%- endif %}
-    " float gwd_lower_bound=0 "
+    " float gwd_lower_bound=0, "
+    " bool mixed_D=False"
     ") -> Tensor",
     {PT2_COMPLIANT_TAG});
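The template changes thread `mixed_D` through the generated autograd function the same way `max_D` already is: stored in `ctx->saved_data` during forward and read back in backward, where the flag becomes available to the backward implementation (the ROCm-optimized path requires non-mixed dimensions). The Python analogue below (illustrative only, not the generated code) shows the same save-and-restore pattern for a non-tensor flag.

```python
import torch

class LookupLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weights: torch.Tensor, mixed_D: bool) -> torch.Tensor:
        # Non-tensor arguments are stashed on the context, mirroring
        # ctx->saved_data["mixed_D"] = mixed_D in the C++ template.
        ctx.mixed_D = mixed_D
        ctx.save_for_backward(weights)
        return weights * 2

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (weights,) = ctx.saved_tensors  # tensor path, shown for completeness
        # Read the flag back, mirroring ctx->saved_data["mixed_D"].toBool();
        # a real implementation could branch on it when choosing a kernel.
        _ = ctx.mixed_D
        # d/dw of (w * 2) is 2, and None is returned for the bool argument.
        return grad_output * 2, None

w = torch.randn(4, 64, requires_grad=True)
LookupLike.apply(w, False).sum().backward()
```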
