# pyre-strict
# pyre-ignore-all-errors[29]
+ # pyre-ignore-all-errors[53]
# flake8: noqa F401

@@ -205,6 +206,40 @@ def schema_bool_arg(name: str, default: bool = False) -> str:
    return f"bool {name} = {default}"


+ def list_arg(ty: str) -> str:
+     """
+     Returns a C++ argument for a list of optimizer arguments of the given type.
+
+     Parameters:
+         ty (str) - type of the list, e.g., "int", "float", "tensor"
+     Returns:
+         C++ argument for a list of the given type, e.g., for a list of int returns "std::vector<int64_t> optim_int"
+     """
+     return {
+         "tensor": "std::vector<std::optional<at::Tensor>> optim_tensor",
+         "int": "std::vector<int64_t> optim_int",
+         "float": "std::vector<double> optim_float",
+         "bool": "c10::List<bool> optim_bool",
+     }[ty]
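+ # Editor's illustration (not part of this PR): with the mapping above,
+ #   list_arg("int")  -> "std::vector<int64_t> optim_int"
+ #   list_arg("bool") -> "c10::List<bool> optim_bool"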
+
+
+ def schema_list_arg(ty: str) -> str:
+     """
+     Returns a C++ schema for a list of optimizer arguments of the given type.
+
+     Parameters:
+         ty (str) - type of the list, e.g., "int", "float", "tensor"
+     Returns:
+         C++ schema for a list of the given type, e.g., for a list of int returns "int[] optim_int"
+     """
+     return {
+         "tensor": "Tensor?[] optim_tensor",
+         "int": "int[] optim_int",
+         "float": "float[] optim_float",
+         "bool": "bool[] optim_bool",
+     }[ty]
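+ # Editor's illustration (not part of this PR): schema_list_arg is the schema
+ # counterpart of list_arg, e.g., schema_list_arg("int") -> "int[] optim_int".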
+
+
def optional_tensor_arg(name: str) -> str:
    return f"std::optional<Tensor> {name} = std::nullopt"
@@ -230,7 +265,6 @@ def schema_optional_tensorlist_arg(name: str) -> str:
def make_kernel_arg(
-     # pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
    ty: ArgType,
    name: str,
    default: Union[int, float, None],
@@ -505,6 +539,10 @@ class PT2ArgsSet:
    split_function_schemas: List[str]
    split_saved_tensorlist: List[str]
    split_saved_tensorlist_optional: List[str]
+     split_saved_data: List[dict[str, str]]
+     split_variables: List[str]
+     split_unpacked_arg_names: List[str]
+     split_args_dict: Dict[str, List[str]]

    @staticmethod
    # pyre-ignore[3]
@@ -525,59 +563,178 @@ def create(
        Returns:
            PT2ArgsSet object with the following attributes:
                split_function_args: List[str] - List of function arguments used in unified lookup and autograd functions
-                     Tensors will be packed and pass as TensorList
-                     e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay'].
+                     Tensors will be packed and passed as a TensorList. Auxiliary arguments will be packed in a dict.
+                     e.g., ['at::TensorList momentum1', 'at::Dict<std::string, int> optim_int'].
                split_function_arg_names: List[str] - List of argument names used in unified lookup and autograd functions
-                     e.g., ['momentum1', 'eps', 'weight_decay'].
+                     e.g., ['momentum1', 'optim_int', 'optim_float'].
                split_function_schemas: List[str] - List of arguments used in unified lookup and autograd functions in the schema format
                    e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay'].
                split_saved_tensorlist: List[str] - List of tensor names that are packed into a tensorlist and will be unpacked in the
                    PT2 autograd function, e.g., ['momentum1'].
                split_saved_tensorlist_optional: List[str] - List of tensor names that are packed into a tensorlist but are optional
                    and will be unpacked in the PT2 autograd function, e.g., ['row_counter'].
+                 split_saved_data: List[dict[str, str]] - List of non-tensor arguments that are saved for backward
+                 split_variables: List[str] - List of C++ statements that push placeholder Variable() gradients for
+                     non-differentiable arguments, e.g., 'ret.push_back(Variable()); // momentum1'
+                 split_unpacked_arg_names: List[str] - List of argument names, unrolled from lists
+                     e.g., ['momentum1', 'eps', 'weight_decay', 'iter'].
+                 split_args_dict: Dict[str, List[str]] - Dict mapping each optim argument type to the argument names of that type
+                     e.g., an optimizer whose only int argument is counter yields
+                     {'optim_tensor': [], 'optim_int': ['counter'], 'optim_float': [], 'optim_bool': []}
+                     (note that iter is special-cased: it is passed via aux_int, not optim_int)
        """
        split_function_arg_names = []
        split_function_args = []
        split_function_schemas = []
        split_saved_tensorlist = []
        split_saved_tensorlist_optional = []
+         split_saved_data = []
+         split_variables = []
+         split_unpacked_arg_names = []
+         has_optim_tensor = False  # optim tensors here are optional tensors
+         has_optim_int = False
+         has_optim_float = False
+         has_optim_bool = False
+         split_args_dict = {
+             "optim_tensor": [],
+             "optim_int": [],
+             "optim_float": [],
+             "optim_bool": [],
+         }
+         # list of symint args to be appended after the optim_xxx args,
+         # since they have default values
+         symint_list: List[OptimItem] = []
+
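+         # Editor's note: each split_saved_data entry appended below is a tuple of
+         # (arg name, access expression into the packed optim list or "" if none,
+         # IValue cast expression, C++ type).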
        for s in arg_spec:
            if s.name == "learning_rate_tensor":
                split_function_arg_names.append(s.name)
+                 split_unpacked_arg_names.append(s.name)
                split_function_args.append(tensor_arg(s.name))
                split_function_schemas.append(tensor_arg(s.name))
+                 split_variables.append(f"ret.push_back(Variable()); // {s.name}")
            elif s.ty in (
                ArgType.TENSOR,
                ArgType.INT_TENSOR,
                ArgType.LONG_TENSOR,
                ArgType.PLACEHOLDER_TENSOR,
            ):
                name = s.name
-                 split_function_arg_names.append(name)
+                 split_unpacked_arg_names.append(name)
                if s.is_optional:
-                     split_function_args.append(optional_tensorlist_arg(name))
-                     split_function_schemas.append(schema_optional_tensorlist_arg(name))
                    split_saved_tensorlist_optional.append(name)
+                     split_args_dict["optim_tensor"].append(s.name)
+                     has_optim_tensor = True
                else:
                    split_function_args.append(
                        tensor_list_arg_no_default(name, pass_by_ref=False)
                    )
+                     split_function_arg_names.append(name)
                    split_function_schemas.append(
                        schema_tensor_list_arg_no_default(name)
                    )
                    split_saved_tensorlist.append(name)
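+                     # Editor's note: one placeholder gradient is pushed per unpacked
+                     # component of the packed tensor: dev/host, placements, offsets,
+                     # and uvm (emitted only when the host tensor is empty).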
+                     split_variables.append(
+                         f"ret.push_back(Variable()); // {s.name}_dev or host"
+                     )
+                     split_variables.append(
+                         f"ret.push_back(Variable()); // {s.name}_placements"
+                     )
+                     split_variables.append(
+                         f"ret.push_back(Variable()); // {s.name}_offsets"
+                     )
+                     split_variables.append("if (" + name + "_host.numel() == 0) {")
+                     split_variables.append(
+                         f"ret.push_back(Variable()); // {s.name}_uvm"
+                     )
+                     split_variables.append("}")
            else:
-                 split_function_arg_names.append(s.name)
-                 split_function_args.append(make_function_arg(s.ty, s.name, s.default))
-                 split_function_schemas.append(
-                     make_function_schema_arg(s.ty, s.name, s.default)
-                 )
+                 if s.ty == ArgType.INT:
+                     # iter is passed in aux_int
+                     if s.name != "iter":
+                         split_args_dict["optim_int"].append(s.name)
+                         split_saved_data.append(
+                             (
+                                 s.name,
+                                 f'optim_int[{len(split_args_dict["optim_int"]) - 1}]',
+                                 make_ivalue_cast(s.ty),
+                                 "int64_t",
+                             )
+                         )
+                         has_optim_int = True
+                 elif s.ty == ArgType.SYM_INT:
+                     symint_list.append(s)
+                     split_saved_data.append(
+                         (
+                             s.name,
+                             "",
+                             make_ivalue_cast(s.ty),
+                             "c10::SymInt",
+                         )
+                     )
+                 elif s.ty == ArgType.FLOAT:
+                     split_args_dict["optim_float"].append(s.name)
+                     split_saved_data.append(
+                         (
+                             s.name,
+                             f'optim_float[{len(split_args_dict["optim_float"]) - 1}]',
+                             make_ivalue_cast(s.ty),
+                             "double",
+                         )
+                     )
+                     has_optim_float = True
+                 elif s.ty == ArgType.BOOL:
+                     split_args_dict["optim_bool"].append(s.name)
+                     split_saved_data.append(
+                         (
+                             s.name,
+                             f'optim_bool[{len(split_args_dict["optim_bool"]) - 1}]',
+                             make_ivalue_cast(s.ty),
+                             "bool",
+                         )
+                     )
+                     has_optim_bool = True
+                 else:
+                     raise ValueError(f"Unsupported type {s.ty}")
+                 split_unpacked_arg_names.append(s.name)
+
+         def append_lists(type_name: str) -> None:
+             """
+             Append the packed list as one argument to the lists of function arguments, schemas, argument names and saved variables.
+             e.g., if type_name is "tensor", optim_tensor will be appended with the corresponding syntax.
+
+             Parameters:
+                 type_name (str) - type name of the list to be appended
+
+             Returns:
+                 None
+             """
+             split_function_args.append(list_arg(type_name))
+             split_function_schemas.append(schema_list_arg(type_name))
+             split_function_arg_names.append(f"optim_{type_name}")
+             split_variables.append(f"ret.push_back(Variable()); // optim_{type_name}")
+
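+         # Editor's note: packed lists are appended in a fixed order (tensor, int,
+         # float, bool); symint args go last because they carry default values.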
+         if has_optim_tensor:
+             append_lists("tensor")
+         if has_optim_int:
+             append_lists("int")
+         if has_optim_float:
+             append_lists("float")
+         if has_optim_bool:
+             append_lists("bool")
+         for s in symint_list:
+             split_function_arg_names.append(s.name)
+             split_function_args.append(make_function_arg(s.ty, s.name, s.default))
+             split_function_schemas.append(
+                 make_function_schema_arg(s.ty, s.name, s.default)
+             )
+             split_variables.append(f"ret.push_back(Variable()); // {s.name}")
        return PT2ArgsSet(
            split_function_args=split_function_args,
            split_function_arg_names=split_function_arg_names,
            split_function_schemas=split_function_schemas,
            split_saved_tensorlist=split_saved_tensorlist,
            split_saved_tensorlist_optional=split_saved_tensorlist_optional,
+             split_saved_data=split_saved_data,
+             split_variables=split_variables,
+             split_unpacked_arg_names=split_unpacked_arg_names,
+             split_args_dict=split_args_dict,
        )

@@ -637,12 +794,14 @@ def create(
            if s.is_optional:
                has_optional_tensors = True

-         # Optional tensors are converted to tensor in autograd functions
-         # Reorganize arguments for wrapper, backend and kernel functions
+         # Optim arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors.
+         # The optional tensors are converted to Tensor in autograd functions.
+         # Hence, we need to reorganize so that the tensors come before the non-tensors, which have default values.
+         # This ordering is used in the wrapper, backend and kernel functions.
        if has_optional_tensors:
-             # Arg order: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors,
+             # reordered args for split_arg_spec: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors
            split_arg_spec = reorder_args(split_arg_spec)
-             # Arg order: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
+             # reordered args for kernel_split_arg_spec: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
            kernel_split_arg_spec = reorder_args(kernel_split_arg_spec)
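+         # Editor's illustration (hypothetical args: m1 (tensor), row_counter
+         # (optional tensor), eps (float)):
+         #   split_arg_spec        -> [m1, learning_rate_tensor, row_counter, eps]
+         #   kernel_split_arg_spec -> [m1, row_counter, learning_rate, eps]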

        # Compute placeholder tensor combinations