@@ -676,47 +676,62 @@ def _validate_args(args):
         )
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-
-    builder_exported.run_canonical_optimizations()
-
-    if args.export_only:
-        exit()
-
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
-
-    modelname = builder_exported_to_edge.modelname
-
-    # to_backend
+def _to_edge_and_lower_llama_xnnpack(
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+) -> LLMEdgeManager:  # noqa: C901
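+    """Quantize the exported model and lower it with the XNNPACK delegate,
+    using to_edge_transform_and_lower() rather than the export_to_edge() +
+    to_backend() sequence used by the other backends."""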
     partitioners = []
 
     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    if (
-        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
-    ) or (args.xnnpack):
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
-        )
+    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
 
-        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-        modelname = f"xnnpack_dq_{modelname}"
+    modelname = f"xnnpack_dq_{modelname}"
 
     if args.xnnpack_extended_ops:
-        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
         modelname = f"xnnpack_{modelname}"
 
+    logging.info("Lowering model using following partitioner(s): ")
+    for partitioner in partitioners:
+        logging.info(f"--> {partitioner.__class__.__name__}")
+
+    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
+    if args.generate_etrecord:
+        raise NotImplementedError(
+            "export_llama does not support XNNPack and generating ETRecord at the moment."
+        )
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+    if args.verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
+def _to_edge_and_lower_llama(  # noqa: C901
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+):
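+    """Quantize the exported model and lower it for the non-XNNPACK backends
+    (e.g. Vulkan, MPS, QNN) via export_to_edge() followed by to_backend()."""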
+    builder_exported_to_edge = builder_exported.pt2e_quantize(
+        quantizers
+    ).export_to_edge()
+
+    # to_backend
+    partitioners = []
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -731,7 +746,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         modelname = f"vulkan_{modelname}"
 
     # Need to remove asserts from the graph to prevent graph breaks
-    # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
     remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.mps:
@@ -760,13 +774,11 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
 
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
 
         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -792,19 +804,15 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
                     atten.head_dim,
                 )
             )
-            # pyre-ignore
             tag_quant_io(
                 builder_exported_to_edge.edge_manager.exported_program().graph_module,
-                partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
+                partial(get_custom_quant_ios_dtype, cache_shape),
             )
 
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -818,7 +826,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(
@@ -840,11 +847,55 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(passes=additional_passes)
 
+    return builder
+
+
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
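+    """Export the llama model to an ExecuTorch program, dispatching to the
+    XNNPACK-specific lowering path or the generic backend path."""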
+    _validate_args(args)
+
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    additional_passes = []
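+    # Torchtune-defined models carry a mutable "kv_cache_pos" buffer that the
+    # extra pass initializes in the exported program.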
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported.run_canonical_optimizations()
+    modelname = builder_exported.modelname
+
+    if args.export_only:
+        exit()
+
+    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
+        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+
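+    # XNNPACK gets its own lowering path (to_edge_transform_and_lower());
+    # all other backends go through export_to_edge() + to_backend().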
+    if args.xnnpack:
+        builder = _to_edge_and_lower_llama_xnnpack(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+    else:
+        builder = _to_edge_and_lower_llama(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
@@ -866,7 +917,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         output_file = f"{builder.output_dir}/{modelname}.pte"
 
     builder.save_to_pte(output_file)
-
     return builder
 
 