Commit 8d6fa66

spcyppt authored and facebook-github-bot committed
Unifying TBE API using List (Backend) (#3563)
Summary:
X-link: facebookresearch/FBGEMM#649

As the number of arguments in TBE keeps growing, some of the optimizers run into the argument-count limit (64) during PyTorch operator registration. **For long-term growth and maintenance, we hence redesign the TBE API by packing some of the arguments into lists. Note that not all arguments are packed.**

We pack the arguments as one list per type. For **common** arguments, we pack
- weights and arguments of type `Momentum` into a TensorList
- other tensors and optional tensors into a list of optional tensors, `aux_tensor`
- `int` arguments into `aux_int`
- `float` arguments into `aux_float`
- `bool` arguments into `aux_bool`

Similarly, for **optimizer-specific** arguments, we pack
- arguments of type `Momentum` that are *__not__ optional* into a TensorList
- *optional* tensors into a list of optional tensors, `optim_tensor`
- `int` arguments into `optim_int`
- `float` arguments into `optim_float`
- `bool` arguments into `optim_bool`

We saw issues with PyTorch registration when packing SymInt arguments across the Python/C++ boundary, so SymInt arguments are unrolled and passed individually.

**This significantly reduces the number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments, has only 26 arguments with this API design.

Please refer to the design doc for which arguments are packed and for the full signatures.
Design doc: https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb

The full signature for each optimizer lookup function will be provided shortly.

Reviewed By: sryap

Differential Revision: D68054868
1 parent 3e0db25, commit 8d6fa66

11 files changed (+631, -297 lines)
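To make the packing scheme concrete, the sketch below shows how the common (`aux_*`) arguments could be assembled on the Python side, following the ordering defined by the `aux_args` dict added in `generate_backward_split.py` (shown in the diffs below). The values and local variable names are illustrative only; the actual lookup signatures are produced by the codegen and documented in the design doc.

```python
from typing import List, Optional

import torch

# Illustrative only: each auxiliary argument type is packed into one list,
# with positions matching the aux_args ordering used by the codegen.
aux_tensor: List[Optional[torch.Tensor]] = [
    None,  # 0: B_offsets
    None,  # 1: vbe_output_offsets_feature_rank
    None,  # 2: vbe_B_offsets_rank_per_feature
    None,  # 3: lxu_cache_locations
    None,  # 4: uvm_cache_stats
    None,  # 5: prev_iter_dev
]
aux_int: List[int] = [0]             # 0: iter
aux_float: List[float] = [0.0, 1.0]  # 0: gwd_lower_bound, 1: max_gradient
aux_bool: List[bool] = [False] * 7   # 0: is_experimental_tbe ... 6: mixed_D
```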

fbgemm_gpu/cmake/tbe_sources.py (+3, -2)

@@ -191,6 +191,9 @@
     + [
         "gen_embedding_backward_split_common_device_kernel.cuh",
     ]
+    + [
+        "pt2_arg_utils.h",
+    ]
 )

 gen_defused_optim_templates = [
@@ -502,15 +505,13 @@
         for optimizer in COMMON_OPTIMIZERS + CPU_ONLY_OPTIMIZERS + GPU_ONLY_OPTIMIZERS
         for fstring in [
             "lookup_{}.py",
-            "lookup_{}_pt2.py",
         ]
     ]
     + [
         fstring.format(optimizer)
         for optimizer in SSD_OPTIMIZERS
         for fstring in [
             "lookup_{}_ssd.py",
-            "lookup_{}_ssd_pt2.py",
         ]
     ]
     + [

fbgemm_gpu/codegen/genscript/generate_backward_split.py (+54, -1)

@@ -197,7 +197,6 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
             )
             for filename in [
                 f"lookup_{optimizer}{sdesc}.py",
-                f"lookup_{optimizer}{sdesc}_pt2.py",
             ]:
                 template.write(
                     filename, is_fbcode=args.is_fbcode, ssd=ssd, **kwargs
@@ -331,6 +330,23 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
             },
         )

+    @staticmethod
+    def generate_backward_header(
+        aux_args: Dict[str, List[str]], aux_names: List[str]
+    ) -> None:
+        """
+        Generate a header file that contains enum of argument order from the dict
+
+        Parameters:
+            aux_args (Dict[str, List[str]]): a dict containing a list of arguments
+            aux_names (List[str]): names of the argument types (e.g. aux_tensor, aux_int, etc.)
+        Return:
+            None
+        """
+        # Generate backward header for PT2 Autograd
+        template = CodeTemplate.load("training/pt2/pt2_arg_utils_template.h")
+        template.write(f"pt2_arg_utils.h", aux_args=aux_args, aux_names=aux_names)
+
     @staticmethod
     def generate_python_sources(
         all_optimizers: List[str], ssd_optimizers: List[str]
@@ -375,6 +391,40 @@ def generate() -> None:
         "actions_count",
     ]

+    aux_names = ["aux_tensor", "aux_int", "aux_float", "aux_bool"]
+    # This is a dict of auxilary arguments used in TBE PT2 interface where the aux
+    # arguments of a type are packed into a list for that type. This dict maintains the
+    # order of the arguments of each type.
+    aux_args: Dict[str, List[str]] = {
+        "aux_tensor": [
+            "B_offsets",  # 0
+            "vbe_output_offsets_feature_rank",  # 1
+            "vbe_B_offsets_rank_per_feature",  # 2
+            "lxu_cache_locations",  # 3
+            "uvm_cache_stats",  # 4
+            "prev_iter_dev",  # 5
+        ],
+        "aux_int": [
+            "iter",  # 0
+        ],
+        "aux_float": [
+            "gwd_lower_bound",  # 0
+            "max_gradient",  # 1
+        ],
+        "aux_bool": [
+            "is_experimental_tbe",  # 0
+            "use_uniq_cache_locations_bwd",  # 1
+            "use_homogeneous_placements",  # 2
+            "apply_global_weight_decay",  # 3
+            "gradient_clipping",  # 4
+            "stochastic_rounding",  # 5
+            "mixed_D",  # 6
+        ],
+    }
+    assert (
+        list(aux_args.keys()) == aux_names
+    ), f"{aux_names} must match {aux_args.keys()}"
+
     all_optimizers = []
     ssd_optimizers = []

@@ -399,6 +449,9 @@ def generate() -> None:
     BackwardSplitGenerator.generate_backward_grad()
     BackwardSplitGenerator.generate_backward_indices()

+    # Generate headers for backwards
+    BackwardSplitGenerator.generate_backward_header(aux_args, aux_names)
+
     BackwardSplitGenerator.generate_python_sources(all_optimizers, ssd_optimizers)

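Per its docstring, the generated `pt2_arg_utils.h` header encodes the order of each auxiliary argument within its per-type list, so C++ code can index the packed lists by name rather than by hard-coded position. The snippet below is a rough Python analogue of that idea, not the generated C++ header; an argument's index is simply its position in `aux_args`.

```python
# Sketch only (not the generated header): an argument's index is its
# position in the aux_args dict defined in generate().
aux_args = {
    "aux_int": ["iter"],
    "aux_float": ["gwd_lower_bound", "max_gradient"],
}
# Build {arg_name: index} per type.
arg_index = {ty: {name: i for i, name in enumerate(names)} for ty, names in aux_args.items()}

aux_float = [0.0, 10.0]  # packed as [gwd_lower_bound, max_gradient]
assert aux_float[arg_index["aux_float"]["max_gradient"]] == 10.0
```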

fbgemm_gpu/codegen/genscript/optimizer_args.py (+175, -16)

@@ -7,6 +7,7 @@

 # pyre-strict
 # pyre-ignore-all-errors[29]
+# pyre-ignore-all-errors[53]
 # flake8: noqa F401


@@ -205,6 +206,40 @@ def schema_bool_arg(name: str, default: bool = False) -> str:
     return f"bool {name} = {default}"


+def list_arg(ty: str) -> str:
+    """
+    Returns a C++ argument for a list of optimizer arguments the given type.
+
+    Parameters:
+        ty (str) - type of the list e.g., "int", "float", "tensor"
+    Returns:
+        C++ arguemnt for a list of the given type e.g., for a list of int returns "std::vector<int> optim_int",
+    """
+    return {
+        "tensor": "std::vector<std::optional<at::Tensor>> optim_tensor",
+        "int": "std::vector<int64_t> optim_int",
+        "float": "std::vector<double> optim_float",
+        "bool": "c10::List<bool> optim_bool",
+    }[ty]
+
+
+def schema_list_arg(ty: str) -> str:
+    """
+    Returns a C++ schema for a list of optimizer arguments the given type.
+
+    Parameters:
+        ty (str) - type of the list e.g., "int", "float", "tensor"
+    Returns:
+        C++ arguemnt for a list of the given type e.g., for a list of int returns "int[] optim_int",
+    """
+    return {
+        "tensor": "Tensor?[] optim_tensor",
+        "int": "int[] optim_int",
+        "float": "float[] optim_float",
+        "bool": "bool[] optim_bool",
+    }[ty]
+
+
 def optional_tensor_arg(name: str) -> str:
     return f"std::optional<Tensor> {name} = std::nullopt"
@@ -230,7 +265,6 @@ def schema_optional_tensorlist_arg(name: str) -> str:


 def make_kernel_arg(
-    # pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
     ty: ArgType,
     name: str,
     default: Union[int, float, None],
@@ -505,6 +539,10 @@ class PT2ArgsSet:
     split_function_schemas: List[str]
     split_saved_tensorlist: List[str]
     split_saved_tensorlist_optional: List[str]
+    split_saved_data: List[dict[str, str]]
+    split_variables: List[str]
+    split_unpacked_arg_names: List[str]
+    split_args_dict: Dict[str, List[str]]

     @staticmethod
     # pyre-ignore[3]
@@ -525,59 +563,178 @@ def create(
         Returns:
             PT2ArgsSet object with the following attributes:
                 split_function_args: List[str] - List of function arguments used in unified lookup and autograd functions
-                    Tensors will be packed and pass as TensorList
-                    e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay'].
+                    Tensors will be packed and pass as TensorList. Auxillary arguments will be packed in dict.
+                    e.g., ['at::TensorList momentum1', 'at::Dict<std:string, int> optim_int'].
                 split_function_arg_names: List[str] - List of argument names used in unified lookup and autograd functions
-                    e.g., ['momentum1', 'eps', 'weight_decay'].
+                    e.g., ['momentum1', 'optim_int', 'optim_float'].
                 split_function_schemas: List[str] - List of arguments used in unified lookup and autograd functions in the schema format
                     e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay'].
                 split_saved_tensorlist: List[str] - List of tensor names that are packed into tensorlist and will be unpacked in
                     PT2 autograd function. e.g., ['momentum1'].
                 split_saved_tensorlist_optional: List[str] - List of tensor names that are packed into tensorlist but are optional
                     and will be unpacked in PT2 autograd function e.g., ['row_counter'].
+                split_saved_data: List[dict[str, str]] - List of non-tensor arguments that are saved for backward
+                split_unpacked_arg_names: List[str] - List of argument names, unrolled from list
+                    e.g., ['momentum1', 'eps', 'weight_decay', 'iter'].
+                split_args_dict: Dict[str, List[str]] - Dict of optim arguments' types containing the argument names of that type.
+                    e.g., if an optimizer only has an int argument called iter, the dict will look like:
+                    {'optim_tensor': [], 'optim_int': ['iter'], 'optim_float': [], 'optim_bool': []}
         """
         split_function_arg_names = []
         split_function_args = []
         split_function_schemas = []
         split_saved_tensorlist = []
         split_saved_tensorlist_optional = []
+        split_saved_data = []
+        split_variables = []
+        split_unpacked_arg_names = []
+        has_optim_tensor = False  # optim tensors here are optional tensor
+        has_optim_int = False
+        has_optim_float = False
+        has_optim_bool = False
+        split_args_dict = {
+            "optim_tensor": [],
+            "optim_int": [],
+            "optim_float": [],
+            "optim_bool": [],
+        }
+        # list of symint args to be appended after optim_xxx args
+        # since they have default values
+        symint_list: List[OptimItem] = []
+
         for s in arg_spec:
             if s.name == "learning_rate_tensor":
                 split_function_arg_names.append(s.name)
+                split_unpacked_arg_names.append(s.name)
                 split_function_args.append(tensor_arg(s.name))
                 split_function_schemas.append(tensor_arg(s.name))
+                split_variables.append(f"ret.push_back(Variable()); // {s.name}")
             elif s.ty in (
                 ArgType.TENSOR,
                 ArgType.INT_TENSOR,
                 ArgType.LONG_TENSOR,
                 ArgType.PLACEHOLDER_TENSOR,
             ):
                 name = s.name
-                split_function_arg_names.append(name)
+                split_unpacked_arg_names.append(name)
                 if s.is_optional:
-                    split_function_args.append(optional_tensorlist_arg(name))
-                    split_function_schemas.append(schema_optional_tensorlist_arg(name))
                     split_saved_tensorlist_optional.append(name)
+                    split_args_dict["optim_tensor"].append(s.name)
+                    has_optim_tensor = True
                 else:
                     split_function_args.append(
                         tensor_list_arg_no_default(name, pass_by_ref=False)
                     )
+                    split_function_arg_names.append(name)
                     split_function_schemas.append(
                         schema_tensor_list_arg_no_default(name)
                     )
                     split_saved_tensorlist.append(name)
+                    split_variables.append(
+                        f"ret.push_back(Variable()); // {s.name}_dev or host"
+                    )
+                    split_variables.append(
+                        f"ret.push_back(Variable()); // {s.name}_placements"
+                    )
+                    split_variables.append(
+                        f"ret.push_back(Variable()); // {s.name}_offsets"
+                    )
+                    split_variables.append("if (" + name + "_host.numel() == 0) {")
+                    split_variables.append(
+                        f"ret.push_back(Variable()); // {s.name}_uvm"
+                    )
+                    split_variables.append("}")
             else:
-                split_function_arg_names.append(s.name)
-                split_function_args.append(make_function_arg(s.ty, s.name, s.default))
-                split_function_schemas.append(
-                    make_function_schema_arg(s.ty, s.name, s.default)
-                )
+                if s.ty == ArgType.INT:
+                    # iter is passed in aux_int
+                    if s.name != "iter":
+                        split_args_dict["optim_int"].append(s.name)
+                        split_saved_data.append(
+                            (
+                                s.name,
+                                f'optim_int[{len(split_args_dict["optim_int"]) - 1}]',
+                                make_ivalue_cast(s.ty),
+                                "int64_t",
+                            )
+                        )
+                        has_optim_int = True
+                elif s.ty == ArgType.SYM_INT:
+                    symint_list.append(s)
+                    split_saved_data.append(
+                        (
+                            s.name,
+                            "",
+                            make_ivalue_cast(s.ty),
+                            "c10::SymInt",
+                        )
+                    )
+                elif s.ty == ArgType.FLOAT:
+                    split_args_dict["optim_float"].append(s.name)
+                    split_saved_data.append(
+                        (
+                            s.name,
+                            f'optim_float[{len(split_args_dict["optim_float"])- 1}]',
+                            make_ivalue_cast(s.ty),
+                            "double",
+                        )
+                    )
+                    has_optim_float = True
+                elif s.ty == ArgType.BOOL:
+                    split_args_dict["optim_bool"].append(s.name)
+                    split_saved_data.append(
+                        (
+                            s.name,
+                            f'optim_bool[{len(split_args_dict["optim_bool"])- 1}]',
+                            make_ivalue_cast(s.ty),
+                            "bool",
+                        )
+                    )
+                    has_optim_bool = True
+                else:
+                    raise ValueError(f"Unsupported type {s.ty}")
+                split_unpacked_arg_names.append(s.name)
+
+        def append_lists(type_name: str) -> None:
+            """
+            Append the list as one argument to the list of function arguments, schemas, names and saved_variables.
+            e.g., if type_name is "tensor", optim_tensor will be appended with the corresponding syntax.
+
+            Parameters:
+                type_name (str) - type name of the list to be appended
+
+            Returns:
+                None
+            """
+            split_function_args.append(list_arg(type_name))
+            split_function_schemas.append(schema_list_arg(type_name))
+            split_function_arg_names.append(f"optim_{type_name}")
+            split_variables.append(f"ret.push_back(Variable()); // optim_{type_name}")
+
+        if has_optim_tensor:
+            append_lists("tensor")
+        if has_optim_int:
+            append_lists("int")
+        if has_optim_float:
+            append_lists("float")
+        if has_optim_bool:
+            append_lists("bool")
+        for s in symint_list:
+            split_function_arg_names.append(s.name)
+            split_function_args.append(make_function_arg(s.ty, s.name, s.default))
+            split_function_schemas.append(
+                make_function_schema_arg(s.ty, s.name, s.default)
+            )
+            split_variables.append(f"ret.push_back(Variable()); // {s.name}")
         return PT2ArgsSet(
             split_function_args=split_function_args,
             split_function_arg_names=split_function_arg_names,
             split_function_schemas=split_function_schemas,
             split_saved_tensorlist=split_saved_tensorlist,
             split_saved_tensorlist_optional=split_saved_tensorlist_optional,
+            split_saved_data=split_saved_data,
+            split_variables=split_variables,
+            split_unpacked_arg_names=split_unpacked_arg_names,
+            split_args_dict=split_args_dict,
         )
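To illustrate the packing that `PT2ArgsSet.create` performs, consider a hypothetical optimizer whose arg spec contains a non-optional tensor `momentum1`, an optional tensor `row_counter`, floats `eps` and `weight_decay`, and a bool `use_decay` (names chosen for illustration, not taken from a specific optimizer in this diff). Tracing the logic above, the resulting fields would look roughly like this:

```python
# Hypothetical spec: momentum1 (tensor), row_counter (optional tensor),
# eps (float), weight_decay (float), use_decay (bool).
# Approximate result of PT2ArgsSet.create for such a spec:
split_function_arg_names = ["momentum1", "optim_tensor", "optim_float", "optim_bool"]
split_args_dict = {
    "optim_tensor": ["row_counter"],
    "optim_int": [],
    "optim_float": ["eps", "weight_decay"],
    "optim_bool": ["use_decay"],
}
split_unpacked_arg_names = ["momentum1", "row_counter", "eps", "weight_decay", "use_decay"]
# Non-tensor values are later recovered from the packed lists by position,
# e.g. eps -> optim_float[0], weight_decay -> optim_float[1], use_decay -> optim_bool[0].
```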
@@ -637,12 +794,14 @@ def create(
             if s.is_optional:
                 has_optional_tensors = True

-        # Optional tensors are converted to tensor in autograd functions
-        # Reorganize arguments for wrapper, backend and kernel functions
+        # Optim arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors
+        # The optional tensors are converted to Tensor in autograd functions
+        # Hence, need to reorganize such that the tensors come before non-tensors which have default values values
+        # This is used in wrapper, backend and kernel functions
         if has_optional_tensors:
-            # Arg order: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors,
+            # reordered args for split_arg_spec: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors
             split_arg_spec = reorder_args(split_arg_spec)
-            # Arg order: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
+            # reordered args for kernel_split_arg_spec: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
             kernel_split_arg_spec = reorder_args(kernel_split_arg_spec)

         # Compute placeholder tensor combinations