diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index c60afc2dd33..b6d1f3708a4 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -14,6 +14,7 @@ exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, exir_ops.edge.aten.copy.default, + exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] allow_list_operator = [ diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index c3afc23daeb..8080947f929 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -40,16 +40,7 @@ def __init__( ): self.node_visitors = node_visitor.get_node_visitors(edge_program) - self.skip_node_op_builder_set = set() - if skip_node_op_set is not None: - self.skip_node_op_builder_set = set( - [ - self.node_visitors[val] - for val in skip_node_op_set - if val in self.node_visitors - ] - ) - + self.skip_node_op_set = skip_node_op_set self.skip_node_id_set = skip_node_id_set self.nodes_to_wrappers = defaultdict(dict) self.qnn_manager = PyQnnManager.QnnManager( @@ -69,11 +60,7 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped") return False - if ( - self.skip_node_op_builder_set is not None - and self.node_visitors[node.target.__name__] - in self.skip_node_op_builder_set - ): + if self.skip_node_op_set is not None and node.target.__name__ in self.skip_node_op_set: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped") return False diff --git a/backends/qualcomm/passes/layout_transform.py b/backends/qualcomm/passes/layout_transform.py index bdee2c8196a..bd898e2decd 100644 --- a/backends/qualcomm/passes/layout_transform.py +++ b/backends/qualcomm/passes/layout_transform.py @@ -53,6 +53,8 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.hardswish.default, exir_ops.edge.aten.hardsigmoid.default, exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.index.Tensor, + exir_ops.edge.aten.index_put.default, exir_ops.edge.aten.leaky_relu.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten._log_softmax.default, diff --git a/backends/qualcomm/passes/replace_inf_buffer.py b/backends/qualcomm/passes/replace_inf_buffer.py index bafa3fdb18b..1dc06630ca3 100644 --- a/backends/qualcomm/passes/replace_inf_buffer.py +++ b/backends/qualcomm/passes/replace_inf_buffer.py @@ -14,8 +14,9 @@ def __init__(self): def call(self, graph_module: torch.fx.GraphModule): for buf_name, tensor in graph_module.named_buffers(): if tensor.is_floating_point(): - tensor[tensor == float("inf")] = torch.finfo(torch.float32).max - tensor[tensor == float("-inf")] = torch.finfo(torch.float32).min + # An arbitrary number + tensor[tensor == float("inf")] = 1000 + tensor[tensor == float("-inf")] = -1000 setattr(graph_module, buf_name, tensor) graph_module.recompile() diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index d3148c95421..6e88277af68 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -52,7 +52,9 @@ from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis from .source_transformation.sdpa import ( replace_causal_mask, + replace_kv_cache_with_simple_kv_cache, replace_sdpa_with_custom_op, + replace_sdpa_with_flex_sdpa, replace_sdpa_with_simple_sdpa, ) @@ -118,7 +120,7 @@ 
def build_args_parser() -> argparse.ArgumentParser:
         "--embedding-quantize",
         default=None,
         type=str,
-        help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
+        help="type of embedding quantization, '<bitwidth>,<groupsize>,<dtype>', e.g., '8,1024,32'.",
     )
     parser.add_argument(
         "--pt2e_quantize",
@@ -197,6 +199,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether to use sdpa_with_kv_cache update op when using kv cache",
     )
+    parser.add_argument(
+        "--num_sharding",
+        type=int,
+        default=None,
+        help="Specify the number of splits to generate with the fallback custom op; it is expected to evenly divide the number of layers.",
+    )
     parser.add_argument(
         "--disable_dynamic_shape",
         dest="enable_dynamic_shape",
@@ -385,7 +393,12 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.use_kv_cache:
-        if args.qnn or args.coreml or args.mps:
+        if args.qnn:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+
+        elif args.coreml or args.mps:
             # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
@@ -486,11 +499,11 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"coreml_{modelname}"
 
     if args.qnn:
+        from executorch.extension.llm.custom_ops import model_sharding
+
         partitioners.append(
             get_qnn_partitioner(
-                quant_dtype,
-                args.use_kv_cache,
-                args.pt2e_quantize,
+                args.use_kv_cache, args.pt2e_quantize, args.num_sharding
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
@@ -498,6 +511,12 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
+        if args.num_sharding is not None:
+            model_sharding.split_graph(
+                builder_exported_to_edge.edge_manager.exported_program(),
+                builder_exported_to_edge.metadata["get_n_layers"],
+                shares=args.num_sharding,
+            )
 
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
@@ -506,7 +525,12 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
         logging.info("Generating etrecord")
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager) - builder = builder_exported_to_edge.to_backend(partitioners).to_executorch() + builder = builder_exported_to_edge.to_backend(partitioners) + if args.num_sharding is not None: + from executorch.backends.qualcomm.utils.utils import canonicalize_program + + canonicalize_program(builder.edge_manager.exported_program()) + builder = builder.to_executorch() # Generate ETRecord if edge_manager_copy: @@ -517,7 +541,12 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 ) logging.info("Generated etrecord.bin") else: - builder = builder_exported_to_edge.to_backend(partitioners).to_executorch() + builder = builder_exported_to_edge.to_backend(partitioners) + if args.num_sharding is not None: + from executorch.backends.qualcomm.utils.utils import canonicalize_program + + canonicalize_program(builder.edge_manager.exported_program()) + builder = builder.to_executorch() if args.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 56bf4a96c39..dacf9eb1fdc 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -161,6 +161,9 @@ def __init__( else: cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim self.transpose_cache = transpose_cache self.enable_dynamic_shape = enable_dynamic_shape self.register_buffer( diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index fdf0dc707e4..a5755571c15 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -212,7 +212,7 @@ def get_example_inputs_kvcache_sdpa(self): if self.enable_dynamic_shape: return ( torch.tensor([[2, 3, 4]], dtype=torch.long), - torch.tensor([0], dtype=torch.long), + torch.tensor([0, 1, 2], dtype=torch.long), ) else: return ( diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index bb014145bd8..de2d39f422d 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -384,7 +384,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: def replace_embedding_weight_only_grouped_int8_per_channel( - module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False + module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False, qge_dtype=torch.half ): for name, child in module.named_children(): # print(f"name: {name}") @@ -400,11 +400,12 @@ def replace_embedding_weight_only_grouped_int8_per_channel( embedding_dim=child.weight.shape[1], group_size=group_size, packed=packed, + dtype=qge_dtype, ), ) else: replace_embedding_weight_only_grouped_int8_per_channel( - child, device, bitwidth, group_size, packed + child, device, bitwidth, group_size, packed, qge_dtype ) @@ -417,6 +418,7 @@ def __init__( bitwidth: int = 8, group_size: Optional[int] = None, packed=False, + qge_dtype=torch.half, ): if isinstance(packed, str): packed = packed == "True" @@ -425,6 +427,7 @@ def __init__( self.group_size = group_size self.bitwidth = bitwidth self.packed = packed + self.qge_dtype = qge_dtype if (bitwidth != 4) and packed: raise RuntimeError("pack only works with bitsize 4") @@ -484,7 +487,7 @@ def create_quantized_state_dict(self, packed=False) 
-> Dict:
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.device, self.bitwidth, self.group_size, self.packed
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed, self.qge_dtype
         )
         return self.mod
 
@@ -554,17 +557,28 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 
 
 def get_quant_embedding_transform(args):
-    bitwidth, group_size = args.embedding_quantize.split(",")
+    quant_args = [a.strip() for a in args.embedding_quantize.split(",")]
+    bitwidth, group_size = quant_args[:2]
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
     else:
         group_size = int(group_size)
     bitwidth = int(bitwidth)
+
+    # The optional third field selects the dequantized embedding dtype; default is fp16.
+    qge_dtype = torch.half
+    if len(quant_args) == 3:
+        if quant_args[2] in ("32", "torch.float32"):
+            qge_dtype = torch.float32
+        else:
+            print(f"Unrecognized dtype '{quant_args[2]}', using default qge_dtype {torch.half}")
+
     return lambda model: EmbeddingQuantHandler(
         model,
         bitwidth=bitwidth,
         group_size=group_size,
         packed=(bitwidth == 4),
+        qge_dtype=qge_dtype,
     ).quantized_model()
diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py
index 4e0ac718689..3f8f19d890d 100644
--- a/examples/models/llama2/source_transformation/sdpa.py
+++ b/examples/models/llama2/source_transformation/sdpa.py
@@ -9,12 +9,27 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
+from typing import Tuple, List
 
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+
+    new_kv = []
+    batch, n_heads, seqlen, head_dim = hidden_states.shape
+    n_heads *= n_rep
+    for h in hidden_states[0]:
+        new_kv += [h] * n_rep
+    return torch.cat(new_kv, 0).reshape(batch, n_heads, seqlen, head_dim)
+
+
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
@@ -112,6 +127,46 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
 
+class SDPAFlex(torch.nn.Module):
+
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        dim: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.dim = dim
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+
+        k_repeat_num = q.shape[1] // k.shape[1]
+        v_repeat_num = q.shape[1] // v.shape[1]
+        k = repeat_kv(k, k_repeat_num)
+        v = repeat_kv(v, v_repeat_num)
+        attn_mask = mask[input_pos]
+
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn_weight = q @ k.transpose(-2, -1) * scale_factor
+        attn_weight += attn_mask
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        y = attn_weight @ v
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
@@ -125,6 +180,71 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
     return module
 
 
+def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPAFlex(child.kv_cache, child.dim),
+            )
+        else:
+
replace_sdpa_with_flex_sdpa(child) + return module + + +class KVCacheSimple(torch.nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype=torch.float32, + ): + super().__init__() + cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + self.register_buffer( + "past_k_caches", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + self.register_buffer( + "past_v_caches", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val) + v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val) + + k_out = k_out.transpose(1, 2) + v_out = v_out.transpose(1, 2) + return k_out, v_out + + +def replace_kv_cache_with_simple_kv_cache(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, KVCache): + setattr( + module, + name, + KVCacheSimple( + child.max_batch_size, + child.max_seq_length, + child.n_heads, + child.head_dim, + child.k_cache.dtype, + ), + ) + else: + replace_kv_cache_with_simple_kv_cache(child) + return module + + def replace_causal_mask(module: torch.nn.Module): for buffer_fqn_name, buffer in module.named_buffers(): buffer_name = buffer_fqn_name.split(".")[-1] @@ -132,7 +252,7 @@ def replace_causal_mask(module: torch.nn.Module): max_seq_len = buffer.shape[-1] mask = torch.full( (max_seq_len, max_seq_len), - float("-inf"), + float("-255"), device="cpu", ) diff --git a/extension/llm/custom_ops/model_sharding.py b/extension/llm/custom_ops/model_sharding.py new file mode 100644 index 00000000000..5d3bcc1ee32 --- /dev/null +++ b/extension/llm/custom_ops/model_sharding.py @@ -0,0 +1,93 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import re +from typing import List + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.export.exported_program import ExportedProgram +from torch.library import impl, Library + + +fallback_op_lib = Library("llama", "DEF") +# registering an operator. +fallback_op_lib.define("fallback(Tensor input) -> Tensor") + + +@impl(fallback_op_lib, "fallback") +def fallback_impl(a: torch.Tensor) -> torch.Tensor: + return a + + +# registering the out variant. +fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)") + + +@impl(fallback_op_lib, "fallback.out") +def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(a) + return out + + +class SplitGraph(ExportPass): + """ + Handle to split the llama model to multiple partitions. + Because there are limited memory on the device, it could + not load all llama model in one pte. 
+ """ + + def __init__(self, shard_layers: List[int]): + super().__init__() + self.shard_layers = shard_layers + + def _insert_fallback_op( + self, graph_module: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + pattern = r"layers.(\d+)" + prev_node = None + prev_layer = None + for node in graph_module.graph.nodes: + if node.op != "call_function" or "nn_module_stack" not in node.meta: + continue + + module_values_list = list(node.meta["nn_module_stack"].values()) + full_qualified_name = module_values_list[-1][0] + match = re.search(pattern, full_qualified_name) + if match is None: + continue + + cur_layer = int(match.group(1)) + # Check the current node which is the last node of the layer + if cur_layer in self.shard_layers and prev_layer == cur_layer - 1: + with graph_module.graph.inserting_after(prev_node): + users = list(prev_node.users.keys()) + inserted_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.llama.fallback.default, + (prev_node,), + ) + inserted_node.meta["val"] = prev_node.meta["val"] + if prev_node.meta.get("quant_attrs", None): + inserted_node.meta["quant_attrs"] = prev_node.meta[ + "quant_attrs" + ] + for user in users: + user.replace_input_with(prev_node, inserted_node) + + prev_layer = cur_layer + prev_node = node + + def call(self, graph_module: torch.fx.GraphModule): + self._insert_fallback_op(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) + + +def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int): + graph_module = edge_program.graph_module + shard_layers = list(range(0, num_layers, int(num_layers / shares))) + return SplitGraph(shard_layers)(graph_module) diff --git a/extension/llm/custom_ops/op_fallback.cpp b/extension/llm/custom_ops/op_fallback.cpp new file mode 100644 index 00000000000..11a1b4e7faf --- /dev/null +++ b/extension/llm/custom_ops/op_fallback.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +namespace torch { +namespace executor { + +namespace native { + +// Copy from op_clone.cpp +Tensor& fallback_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { + (void)ctx; + + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, in.sizes()) == torch::executor::Error::Ok, + InvalidArgument, + out); + + // The input and out shall share same dtype and size + ET_KERNEL_CHECK( + ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out); + + if (in.nbytes() > 0) { + // Note that this check is important. It's valid for a tensor with numel 0 + // to have a null data pointer, but in some environments it's invalid to + // pass a null pointer to memcpy() even when the size is zero. + memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes()); + } + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +EXECUTORCH_LIBRARY( + llama, + "fallback.out", + torch::executor::native::fallback_out); diff --git a/extension/llm/custom_ops/op_fallback.h b/extension/llm/custom_ops/op_fallback.h new file mode 100644 index 00000000000..62a2c0d53eb --- /dev/null +++ b/extension/llm/custom_ops/op_fallback.h @@ -0,0 +1,20 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace torch { +namespace executor { + +namespace native { +Tensor& fallback_out(RuntimeContext& ctx, const Tensor& in, Tensor& out); +} // namespace native +} // namespace executor +} // namespace torch diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index 8c38eb6a0a0..e3ed9fe0a99 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -8,8 +8,8 @@ def define_common_targets(): """ runtime.cxx_library( name = "custom_ops", - srcs = ["op_sdpa.cpp"], - exported_headers = ["op_sdpa.h"], + srcs = ["op_sdpa.cpp", "op_fallback.cpp"], + exported_headers = ["op_sdpa.h", "op_fallback.h"], exported_deps = [ "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/portable/cpu:scalar_utils", diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 264e1e95ad3..0760f80e08b 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -11,8 +11,11 @@ import logging from enum import Enum from typing import Any, Callable, List, Optional +from functools import partial import torch +from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken +from sentencepiece import SentencePieceProcessor from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( DuplicateDynamicQuantChainPass, ) @@ -166,7 +169,28 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager": ) return self - def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager": + def calibrate(self, module: torch.fx.GraphModule): + tokenizer = SimpleTokenizer("tokenizer.model") + + # TODO: change criteria & support batch inputs if necessary + pos = torch.tensor(0, dtype=torch.int32) + token_list = [tokenizer.bos_id] + tokenizer.encode("Once upon a time") + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < 128: + logits = module( + torch.full((1, 1), token_list[pos]), + torch.tensor((pos, )), + ) + pos += 1 + if pos >= len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + + print(f"calibration data:\n{tokenizer.decode(token_list)}") + + def pt2e_quantize( + self, quantizers: Optional[List[Quantizer]] + ) -> "LlamaEdgeManager": """ Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model. Args: @@ -189,7 +213,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage ), "Please run capture_pre_autograd_graph first" m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer) # Calibrate - m(*self.example_inputs) + self.calibrate(m) + # m(*self.example_inputs) m = convert_pt2e(m) DuplicateDynamicQuantChainPass()(m) self.pre_autograd_graph_module = m @@ -294,3 +319,19 @@ def get_saved_pte_filename(self) -> Optional[str]: Return the filename of the most recenet saved .pte file. Return None if the model is not saved. 
""" return self._saved_pte_filename + +class SimpleTokenizer: + def __init__(self, model_path): + try: + module = SentencePieceProcessor(model_file=model_path) + self.bos_id = module.bos_id() + self.eos_id = module.eos_id() + self.encode = module.encode + self.decode = module.decode + except Exception: + print("Using Tiktokenizer") + module = Tiktoken(model_path=model_path) + self.bos_id = module.bos_id + self.eos_id = module.eos_id + self.encode = partial(module.encode, bos=False, eos=False) + self.decode = module.decode diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index bcbeeeee159..dfdab1e0b19 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -105,7 +105,9 @@ def get_coreml_partitioner( def get_qnn_partitioner( - quant_dtype, use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None + use_kv_cache: bool = False, + pt2e_quantize: Optional[str] = None, + num_sharding: int = None, ): assert ( use_kv_cache is True @@ -116,9 +118,6 @@ def get_qnn_partitioner( QnnPartitioner, ) - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` - from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qnn_compile_spec_schema` from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, @@ -135,27 +134,20 @@ def get_qnn_partitioner( ) use_fp16 = True - skip_node_op_set = {} + skip_node_op_set = {"llama.fallback.default"} if pt2e_quantize is not None: use_fp16 = False - # TODO: fix the lowering error without skipping nodes - - if quant_dtype == QuantDtype.use_8a8w: - raise NotImplementedError("8a8w for llama is still under development") - - elif quant_dtype == QuantDtype.use_16a16w: - raise NotImplementedError("16a16w for llama is still under development") - - elif quant_dtype == QuantDtype.use_16a4w: - raise NotImplementedError("16a4w for llama is still under development") return QnnPartitioner( generate_qnn_executorch_compiler_spec( soc_model=QcomChipset.SM8650, # default to SM8650 - backend_options=generate_htp_compiler_spec(use_fp16=use_fp16), + backend_options=generate_htp_compiler_spec( + use_fp16=use_fp16, + use_multi_contexts=num_sharding is not None, + ), debug=False, saver=False, ), - skip_node_id_set={}, + skip_node_id_set=None, skip_node_op_set=skip_node_op_set, ) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index fe6ad1c201a..7e4f237bb9e 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -8,7 +8,7 @@ import logging from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Sequence import torch @@ -153,12 +153,135 @@ def get_qnn_quantizer( QnnQuantizer, QuantDtype, ) + from torch.ao.quantization.observer import MinMaxObserver except ImportError: raise ImportError( "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html" ) + def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: + """ + This function is specific for matmul op 16a8w. 
+ """ + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + QuantizationConfig, + ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY + from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + SharedQuantizationSpec, + ) + from torch.fx import Node + + def annotate_matmul(node: Node, quantization_config: QuantizationConfig): + input_qspec_map = {} + input_act = node.args[0] + input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + input_act1 = node.args[1] + input_spec1 = quantization_config.weight + input_qspec_map[input_act1] = input_spec1 + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_index_put( + node: Node, quantization_config: QuantizationConfig + ) -> None: + input = node.args[0] + value = node.args[2] + + input_qspec_map = {} + input_qspec_map[input] = quantization_config.input_activation + input_qspec_map[value] = SharedQuantizationSpec((input, node)) + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=SharedQuantizationSpec((input, node)), + _annotated=True, + ) + + def annotate_single_in_single_out( + node: Node, quantization_config: QuantizationConfig + ) -> None: + + input_qspec_map = {} + input_act = node.args[0] + input_qspec_map[input_act] = quantization_config.input_activation + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_cat(node: Node, quantization_config: QuantizationConfig): + input_nodes = node.args[0] + + assert isinstance(input_nodes, Sequence) + + first_input_node = input_nodes[0] + input_qspec_map = {} + assert isinstance(first_input_node, Node) + assert isinstance(node, Node) + input_qspec_map[first_input_node] = quantization_config.input_activation + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + assert isinstance(input_node, Node) + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, + ) + + def is_edge_condition(node: Node): + if not isinstance(node, Node) or node.op != "call_function": + return True + return False + + def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): + if is_edge_condition(node): + return + + if node.target == torch.ops.aten.index_put_.default: + annotate_index_put(node, quantization_config) + annotate_matmul_input1(node.args[0], quantization_config) + elif node.target == torch.ops.aten.cat.default: + annotate_cat(node, quantization_config) + # Expect that the inputs of the cat op are select ops + for arg in node.args[0][1:]: + annotate_single_in_single_out(arg, quantization_config) + annotate_matmul_input1(node.args[0][0], quantization_config) + else: + annotate_single_in_single_out(node, quantization_config) + annotate_matmul_input1(node.args[0], quantization_config) + + # Annotate 16a8w for matmul op to get better performance + quantization_config_16a8w = get_16a8w_qnn_ptq_config() + # Annotate 8a8w for 
second input of matmul until past_kv_cache + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.matmul.default + ): + annotate_matmul(node, quantization_config_16a8w) + annotate_matmul_input1(node.args[1], quantization_config_8a8w) + backend, quant_config = pt2e_quantize.split("_") assert ( backend == "qnn" @@ -167,22 +290,25 @@ def get_qnn_quantizer( qnn_quantizer.set_per_channel_conv_quant(enable=True) qnn_quantizer.set_per_channel_linear_quant(enable=True) # more custom quantization are supported including 16a4w etc. default to 8bit quantized - custom_annotations = () + if quant_config == "8a8w": - raise NotImplementedError("8a8w for llama is still under development") quant_dtype = QuantDtype.use_8a8w - pass + custom_annotations = () elif quant_config == "16a16w": - raise NotImplementedError("16a16w for llama is still under development") quant_dtype = QuantDtype.use_16a16w qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) - qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) + qnn_quantizer.set_bit16_op_quant_config( + get_default_16bit_qnn_ptq_config(act_observer=MinMaxObserver) + ) + custom_annotations = () elif quant_config == "16a4w": - raise NotImplementedError("16a4w for llama is still under development") quant_dtype = QuantDtype.use_16a4w qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) - qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) + qnn_quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) + ) qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + custom_annotations = (annotate_matmul_16a8w, ) else: raise AssertionError( f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w."
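
A standalone sketch (not part of the patch): the `repeat_kv` helper added to sdpa.py expands grouped KV heads by iterating over `hidden_states[0]`, so it implicitly assumes batch size 1, matching the exported Llama configuration. The snippet below checks the equivalence with `torch.repeat_interleave` that its docstring claims:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the helper in the patched sdpa.py: each KV head is
    # repeated n_rep times; the loop over hidden_states[0] assumes batch == 1.
    new_kv = []
    batch, n_heads, seqlen, head_dim = hidden_states.shape
    n_heads *= n_rep
    for h in hidden_states[0]:
        new_kv += [h] * n_rep
    return torch.cat(new_kv, 0).reshape(batch, n_heads, seqlen, head_dim)


k = torch.randn(1, 4, 16, 8)  # (batch, n_kv_heads, seqlen, head_dim)
expanded = repeat_kv(k, n_rep=3)
assert expanded.shape == (1, 12, 16, 8)
assert torch.equal(expanded, torch.repeat_interleave(k, repeats=3, dim=1))
```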
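`KVCacheSimple` keeps the cache in (bsz, max_seq_len, n_heads, head_dim) layout and writes new entries with `aten.index_put_`, which this patch also adds to the QNN layout-transform and annotation lists. A small sketch of the update semantics, using illustrative shapes rather than values from the patch:

```python
import torch

# Illustrative shapes: bsz=1, max_seq_len=8, n_heads=2, head_dim=4.
cache = torch.zeros(1, 8, 2, 4)
k_val = torch.randn(1, 3, 2, 4)      # three new tokens
input_pos = torch.tensor([0, 1, 2])

# Same call pattern as KVCacheSimple.update(): [None, input_pos] skips the
# batch dim, so this is an in-place cache[:, input_pos] = k_val.
torch.ops.aten.index_put_(cache, [None, input_pos], k_val)
assert torch.equal(cache[:, :3], k_val)

# The cache is then transposed to (bsz, n_heads, seq, head_dim) for SDPAFlex.
k_out = cache.transpose(1, 2)
assert k_out.shape == (1, 2, 8, 4)
```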
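For the new `--num_sharding` flow, `split_graph` derives the shard boundary layers from the layer count, and the `SplitGraph` pass inserts `llama.fallback` at each boundary; the QNN partitioner skips those nodes (they are in `skip_node_op_set`) and the HTP spec enables `use_multi_contexts`, so each shard is lowered separately. The numbers below are illustrative (a hypothetical 32-layer model with `--num_sharding 4`), not taken from the patch:

```python
# Same arithmetic as split_graph() in model_sharding.py, with made-up values.
num_layers, shares = 32, 4  # e.g. metadata["get_n_layers"] and --num_sharding
shard_layers = list(range(0, num_layers, int(num_layers / shares)))
print(shard_layers)  # [0, 8, 16, 24]

# SplitGraph inserts a fallback op after the last node of layer N - 1 whenever
# layer N is in shard_layers (layer 0 has no predecessor), i.e. at the
# boundaries 7->8, 15->16 and 23->24, yielding 4 partitions.
```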