From 5806919a106e4e6fcf9abf39c06c3e71d32ba96b Mon Sep 17 00:00:00 2001
From: Chong Gu
Date: Thu, 12 Sep 2024 16:46:53 -0700
Subject: [PATCH] Allow Passing in Size of Dynamic Dimensions to Inference
 Function

Summary: Add a new param `dynamic_size` to lower settings, that allows
explicitly setting the number to look for that corresponds to the dynamic
dimension in inputs.

Reviewed By: frank-wei

Differential Revision: D62448015
---
 fx2ait/fx2ait/find_batch_size_dim.py  | 88 ++++++++++++++-------------
 fx2ait/fx2ait/lower/lower_settings.py |  1 +
 fx2ait/fx2ait/tensor_spec.py          |  6 +-
 3 files changed, 52 insertions(+), 43 deletions(-)

diff --git a/fx2ait/fx2ait/find_batch_size_dim.py b/fx2ait/fx2ait/find_batch_size_dim.py
index f5072f767..131b6533d 100644
--- a/fx2ait/fx2ait/find_batch_size_dim.py
+++ b/fx2ait/fx2ait/find_batch_size_dim.py
@@ -21,55 +21,59 @@ def find_batch_size_dim(
     inputs: Any,
     can_non_first_dim_be_dynamic: bool = True,
     can_dim_value_one_be_dynamic: bool = True,
+    dynamic_size: int = -1,
     # pyre-fixme Invalid type [31]
 ) -> []:
     if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
         return [0]
 
-    shapes = [i.shape for i in inputs]
-    frequency_map = {}
-    position_scores = {}
-    first_dims = set()
-    for shape in shapes:
-        if len(shape) < 2:
-            # By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
-            continue
-        # Dedup shape value for single tensor
-        first_dims.add(shape[0])
-        seen_dims = set()
-        valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
-        for i in range(valid_len):
-            dim = shape[i]
-            if dim not in seen_dims:
-                frequency_map[dim] = frequency_map.get(dim, 0) + 1
-                position_scores[dim] = position_scores.get(dim, 0) + i
-                seen_dims.add(dim)
-    if len(first_dims) == 1:
-        # first dim is the same in every input: we use it as batch_size
-        batch_size = first_dims.pop()
-    elif frequency_map:
-        # first dims are different: we use the most frequent dim as batch_size
-        # if there is more than 1 most frequent dim, we choose the one with the
-        # lowest position score (i.e., the leftmost of the most frequent ones)
-        sorted_frequency = sorted(
-            frequency_map.items(),
-            key=lambda x: (-x[1], position_scores[x[0]]),
-        )
-        if len(sorted_frequency) > 1:
-            if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
-                # It's often that dim value one indicates a non-dynamic dimension.
-                # If the user says so, we pick the second most frequent value.
-                batch_size = sorted_frequency[1][0]
+    if dynamic_size > 0:
+        batch_size = dynamic_size
+    else:
+        shapes = [i.shape for i in inputs]
+        frequency_map = {}
+        position_scores = {}
+        first_dims = set()
+        for shape in shapes:
+            if len(shape) < 2:
+                # By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
+                continue
+            # Dedup shape value for single tensor
+            first_dims.add(shape[0])
+            seen_dims = set()
+            valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
+            for i in range(valid_len):
+                dim = shape[i]
+                if dim not in seen_dims:
+                    frequency_map[dim] = frequency_map.get(dim, 0) + 1
+                    position_scores[dim] = position_scores.get(dim, 0) + i
+                    seen_dims.add(dim)
+        if len(first_dims) == 1:
+            # first dim is the same in every input: we use it as batch_size
+            batch_size = first_dims.pop()
+        elif frequency_map:
+            # first dims are different: we use the most frequent dim as batch_size
+            # if there is more than 1 most frequent dim, we choose the one with the
+            # lowest position score (i.e., the leftmost of the most frequent ones)
+            sorted_frequency = sorted(
+                frequency_map.items(),
+                key=lambda x: (-x[1], position_scores[x[0]]),
+            )
+            if len(sorted_frequency) > 1:
+                if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
+                    # It's often that dim value one indicates a non-dynamic dimension.
+                    # If the user says so, we pick the second most frequent value.
+                    batch_size = sorted_frequency[1][0]
+                else:
+                    batch_size = sorted_frequency[0][0]
             else:
-                batch_size = sorted_frequency[0][0]
+                if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
+                    batch_size = -1
+                else:
+                    batch_size = sorted_frequency[0][0]
         else:
-            if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
-                batch_size = -1
-            else:
-                batch_size = sorted_frequency[0][0]
-    else:
-        # no dims to sort: no batch_size
-        batch_size = -1
+            # no dims to sort: no batch_size
+            batch_size = -1
 
     bs_dim = []
     for i in inputs:
diff --git a/fx2ait/fx2ait/lower/lower_settings.py b/fx2ait/fx2ait/lower/lower_settings.py
index 1c1ffbb21..340a0e00e 100644
--- a/fx2ait/fx2ait/lower/lower_settings.py
+++ b/fx2ait/fx2ait/lower/lower_settings.py
@@ -68,6 +68,7 @@ class LowerSettings:
     name: str = ""
     dll_name: str = "ait_engine.so"
     dynamic_profile_strategy: DynamicProfileStrategy = DynamicProfileStrategy.MAX
+    dynamic_size: int = -1
     profile_devs: Any = None
     # If None, infer the dtypes from the sample inputs.
     precision: Optional[LowerPrecision] = LowerPrecision.FP16
diff --git a/fx2ait/fx2ait/tensor_spec.py b/fx2ait/fx2ait/tensor_spec.py
index f6d8fe114..4a1977f41 100644
--- a/fx2ait/fx2ait/tensor_spec.py
+++ b/fx2ait/fx2ait/tensor_spec.py
@@ -477,10 +477,14 @@ def find_batch_size_dim(
         inputs: Any,
         can_non_first_dim_be_dynamic: bool = True,
         can_dim_value_one_be_dynamic: bool = True,
+        dynamic_size: int = -1,
         # pyre-fixme Invalid type [31]
     ) -> []:
         return find_batch_size_dim_impl(
-            inputs, can_non_first_dim_be_dynamic, can_dim_value_one_be_dynamic
+            inputs,
+            can_non_first_dim_be_dynamic,
+            can_dim_value_one_be_dynamic,
+            dynamic_size,
         )
 
     @classmethod