diff --git a/src/deepsparse/debug_analysis.py b/src/deepsparse/debug_analysis.py
index e2786b4f21..81ef625af2 100644
--- a/src/deepsparse/debug_analysis.py
+++ b/src/deepsparse/debug_analysis.py
@@ -141,12 +141,6 @@ def parse_args():
         type=str,
         default="",
     )
-    parser.add_argument(
-        "--disable-batch-override",
-        help="Ignores the batch_size parameter",
-        action="store_true",
-        default=False,
-    )
     parser.add_argument(
         "--use-kvcache", help="Enable KVCache", action="store_true", default=False
     )
@@ -316,10 +310,6 @@ def main():
     print("Analyzing model: {}".format(orig_model_path))
 
     batch_size = args.batch_size
-    if args.disable_batch_override:
-        batch_size = None
-        os.environ["NM_DISABLE_BATCH_OVERRIDE"] = "1"
-        print("Disable batch override: ON")
 
     if input_shapes:
         with override_onnx_input_shapes(model_path, input_shapes) as tmp_path:
@@ -357,7 +347,6 @@ def main():
             num_iterations=args.num_iterations,
             num_warmup_iterations=args.num_warmup_iterations,
             optimization_level=int(args.optimization),
-            disable_batch_override=args.disable_batch_override,
             imposed_ks=imposed_kernel_sparsity,
             input_shapes=input_shapes,
             kv_cache_params=kv_cache_params,
diff --git a/src/deepsparse/engine.py b/src/deepsparse/engine.py
index 79da1ed51d..9699da6026 100644
--- a/src/deepsparse/engine.py
+++ b/src/deepsparse/engine.py
@@ -17,12 +17,12 @@
 """
 
 import logging
+import os
 import time
 from enum import Enum
 from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy
-import onnx
 from tqdm.auto import tqdm
 
 from deepsparse.analytics import deepsparse_analytics as _analytics
@@ -105,9 +105,11 @@ def from_str(key: str):
         raise ValueError(f"unsupported Scheduler: {key}")
 
 
-def _validate_batch_size(batch_size: int) -> int:
-    if batch_size < 1:
-        raise ValueError("batch_size must be greater than 0")
+def _validate_batch_size(batch_size: Optional[int]) -> Optional[int]:
+    if batch_size is None or batch_size < 1:
+        _LOGGER.warning("batch_size is None or < 1, disabling batch size override")
+        os.environ["NM_DISABLE_BATCH_OVERRIDE"] = "1"
+        return None
 
     return batch_size
 
@@ -225,12 +227,11 @@ class BaseEngine(object):
     def construct(
         self,
         model: Union[str, "Model", "File"],
-        batch_size: int = 1,
-        num_cores: int = None,
-        num_streams: int = None,
-        scheduler: Scheduler = None,
-        input_shapes: List[List[int]] = None,
-        disable_batch_override: bool = False,
+        batch_size: Optional[int] = 1,
+        num_cores: Optional[int] = None,
+        num_streams: Optional[int] = None,
+        scheduler: Optional[Scheduler] = None,
+        input_shapes: Optional[List[List[int]]] = None,
         kv_cache_params: Optional[KVCacheParams] = None,
     ):
         _analytics.send_event("python__engine__init")
@@ -240,7 +241,6 @@ def construct(
         self._num_streams = _validate_num_streams(num_streams, self._num_cores)
         self._scheduler = _validate_scheduler(scheduler)
         self._input_shapes = input_shapes
-        self._disable_batch_override = disable_batch_override
         self._kv_cache_params = kv_cache_params
         self._cpu_avx_type = AVX_TYPE
         self._cpu_vnni = VNNI
@@ -248,10 +248,9 @@
     def construct_with_context(
         self,
         model: Union[str, "Model", "File"],
-        batch_size: int,
+        batch_size: Optional[int],
         context: Context,
-        input_shapes: List[List[int]] = None,
-        disable_batch_override: bool = False,
+        input_shapes: Optional[List[List[int]]] = None,
         kv_cache_params: Optional[KVCacheParams] = None,
     ):
         _analytics.send_event("python__engine__init")
@@ -261,7 +260,6 @@
         self._num_streams = context.num_streams
         self._scheduler = _validate_scheduler(context.scheduler)
         self._input_shapes = input_shapes
-        self._disable_batch_override = disable_batch_override
         self._kv_cache_params = kv_cache_params
         self._cpu_avx_type = AVX_TYPE
         self._cpu_vnni = VNNI
@@ -297,24 +295,28 @@ class Engine(BaseEngine):
     def __init__(
         self,
         model: Union[str, "Model", "File"],
-        batch_size: int = 1,
+        batch_size: Optional[int] = 1,
         num_cores: int = None,
         num_streams: int = None,
         scheduler: Scheduler = None,
-        input_shapes: List[List[int]] = None,
-        cached_outputs: List[bool] = None,
+        input_shapes: Optional[List[List[int]]] = None,
+        cached_outputs: Optional[List[bool]] = None,
     ):
         BaseEngine.construct(
             self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
         )
 
+        # self._batch_size is allowed to be None to disable setting a batch size,
+        # but the engine needs to be passed an integer. The value is arbitrary and ignored.
+        engine_batch_size = self._batch_size if self._batch_size else 1
+
         if self._input_shapes:
             with override_onnx_input_shapes(
                 self._model_path, self._input_shapes
             ) as model_path:
                 self._eng_net = LIB.deepsparse_engine(
                     model_path,
-                    self._batch_size,
+                    engine_batch_size,
                     self._num_cores,
                     self._num_streams,
                     self._scheduler.value,
@@ -324,7 +326,7 @@
         else:
             self._eng_net = LIB.deepsparse_engine(
                 self._model_path,
-                self._batch_size,
+                engine_batch_size,
                 self._num_cores,
                 self._num_streams,
                 self._scheduler.value,
@@ -332,6 +334,9 @@
                 cached_outputs,
             )
 
+        if self._batch_size is None:
+            os.environ.pop("NM_DISABLE_BATCH_OVERRIDE", None)
+
     def __call__(
         self, inp: List[numpy.ndarray], val_inp: bool = True
     ) -> List[numpy.ndarray]:
@@ -704,14 +709,13 @@ def _validate_inputs(self, inp: List[numpy.ndarray]):
             raise ValueError("inp must be a list, given {}".format(type(inp)))
 
         for arr in inp:
-            if not self._disable_batch_override:
-                if arr.shape[0] != self._batch_size:
-                    raise ValueError(
-                        (
-                            "array batch size of {} must match the batch size "
-                            "the model was instantiated with {}"
-                        ).format(arr.shape[0], self._batch_size)
-                    )
+            if self._batch_size and arr.shape[0] != self._batch_size:
+                raise ValueError(
+                    (
+                        "array batch size of {} must match the batch size "
+                        "the model was instantiated with {}"
+                    ).format(arr.shape[0], self._batch_size)
+                )
 
             if not arr.flags["C_CONTIGUOUS"]:
                 raise ValueError(
@@ -767,14 +771,13 @@ class DebugAnalysisEngine(Engine):
     def __init__(
         self,
         model: Union[str, "Model", "File"],
-        batch_size: int = 1,
-        num_cores: int = None,
-        scheduler: Scheduler = None,
+        batch_size: Optional[int] = 1,
+        num_cores: Optional[int] = None,
+        scheduler: Optional[Scheduler] = None,
         input_shapes: List[List[int]] = None,
         num_iterations: int = 20,
         num_warmup_iterations: int = 5,
         optimization_level: int = 1,
-        disable_batch_override: bool = False,
         imposed_as: Optional[float] = None,
         imposed_ks: Optional[float] = None,
         kv_cache_params: Optional[KVCacheParams] = None,
@@ -787,12 +790,15 @@ def __init__(
             None,
             scheduler,
             input_shapes,
-            disable_batch_override,
            kv_cache_params,
         )
 
         # Helper
         def make_engine(self, model_path):
+            # self._batch_size is allowed to be None to disable setting a batch size,
+            # but the engine needs to be passed an integer. The value is arbitrary and ignored.
+            engine_batch_size = self._batch_size if self._batch_size else 1
+
             if self._kv_cache_params:
                 self._kv_cache = LIB.kv_cache(
                     self._kv_cache_params.prev_num_tokens,
@@ -801,7 +807,7 @@
 
                 self._eng_net = LIB.deepsparse_engine(
                     model_path,
-                    self._batch_size,
+                    engine_batch_size,
                     self._num_cores,
                     self._num_streams,
                     self._scheduler.value,
@@ -819,7 +825,7 @@
 
                 self._eng_net = LIB.deepsparse_engine(
                     model_path,
-                    self._batch_size,
+                    engine_batch_size,
                     self._num_cores,
                     self._num_streams,
                     self._scheduler.value,
@@ -840,6 +846,9 @@
         else:
             make_engine(self, self._model_path)
 
+        if self._batch_size is None:
+            os.environ.pop("NM_DISABLE_BATCH_OVERRIDE", None)
+
     def analyze(
         self, inp: List[numpy.ndarray], val_inp: bool = True
     ) -> List[numpy.ndarray]:
@@ -887,22 +896,26 @@ class MultiModelEngine(Engine):
     def __init__(
         self,
         model: Union[str, "Model", "File"],
-        batch_size: int,
+        batch_size: Optional[int],
         context: Context,
-        input_shapes: List[List[int]] = None,
-        cached_outputs: List[bool] = None,
+        input_shapes: Optional[List[List[int]]] = None,
+        cached_outputs: Optional[List[bool]] = None,
     ):
         BaseEngine.construct_with_context(
             self, model, batch_size, context, input_shapes
         )
 
+        # self._batch_size is allowed to be None to disable setting a batch size,
+        # but the engine needs to be passed an integer. The value is arbitrary and ignored.
+        engine_batch_size = self._batch_size if self._batch_size else 1
+
         if self._input_shapes:
             with override_onnx_input_shapes(
                 self._model_path, self._input_shapes
             ) as model_path:
                 self._eng_net = LIB.deepsparse_engine(
                     model_path,
-                    self._batch_size,
+                    engine_batch_size,
                     self._num_cores,
                     self._num_streams,
                     self._scheduler.value,
@@ -912,7 +925,7 @@
         else:
             self._eng_net = LIB.deepsparse_engine(
                 self._model_path,
-                self._batch_size,
+                engine_batch_size,
                 self._num_cores,
                 self._num_streams,
                 self._scheduler.value,
@@ -920,14 +933,17 @@
                 cached_outputs,
             )
 
+        if self._batch_size is None:
+            os.environ.pop("NM_DISABLE_BATCH_OVERRIDE", None)
+
 
 def compile_model(
     model: Union[str, "Model", "File"],
-    batch_size: int = 1,
-    num_cores: int = None,
-    num_streams: int = None,
-    scheduler: Scheduler = None,
-    input_shapes: List[List[int]] = None,
+    batch_size: Optional[int] = 1,
+    num_cores: Optional[int] = None,
+    num_streams: Optional[int] = None,
+    scheduler: Optional[Scheduler] = None,
+    input_shapes: Optional[List[List[int]]] = None,
 ) -> Engine:
     """
     Convenience function to compile a model in the DeepSparse Engine
@@ -962,16 +978,16 @@
 def benchmark_model(
     model: Union[str, "Model", "File"],
     inp: List[numpy.ndarray],
-    batch_size: int = 1,
-    num_cores: int = None,
-    num_streams: int = None,
+    batch_size: Optional[int] = 1,
+    num_cores: Optional[int] = None,
+    num_streams: Optional[int] = None,
     num_iterations: int = 20,
     num_warmup_iterations: int = 5,
     include_inputs: bool = False,
     include_outputs: bool = False,
     show_progress: bool = False,
-    scheduler: Scheduler = None,
-    input_shapes: List[List[int]] = None,
+    scheduler: Optional[Scheduler] = None,
+    input_shapes: Optional[List[List[int]]] = None,
 ) -> BenchmarkResults:
     """
     Convenience function to benchmark a model in the DeepSparse Engine
@@ -1029,16 +1045,15 @@
 def model_debug_analysis(
     model: Union[str, "Model", "File"],
     inp: List[numpy.ndarray],
-    batch_size: int = 1,
-    num_cores: int = None,
+    batch_size: Optional[int] = 1,
+    num_cores: Optional[int] = None,
     num_iterations: int = 20,
     num_warmup_iterations: int = 5,
     optimization_level: int = 1,
-    disable_batch_override: bool = False,
     imposed_as: Optional[float] = None,
    imposed_ks: Optional[float] = None,
-    scheduler: Scheduler = None,
-    input_shapes: List[List[int]] = None,
+    scheduler: Optional[Scheduler] = None,
+    input_shapes: Optional[List[List[int]]] = None,
     kv_cache_params: Optional[KVCacheParams] = None,
 ) -> dict:
     """
@@ -1054,7 +1069,8 @@
         object that defines the neural network graph definition to analyze
     :param inp: The list of inputs to pass to the engine for analyzing inference.
         The expected order is the inputs order as defined in the ONNX graph.
-    :param batch_size: The batch size of the inputs to be used with the model
+    :param batch_size: The batch size of the inputs to be used with the model;
+        None or a value < 1 disables the batch override.
     :param num_cores: The number of physical cores to run the model on.
         Pass None or 0 to run on the max number of cores
         for the current machine; default None
@@ -1064,7 +1080,6 @@
         before analyzing, default is 5
     :param optimization_level: The amount of graph optimizations to perform.
         The current choices are either 0 (minimal) or 1 (all), default is 1
-    :param disable_batch_override: Indicates whether disable_batch_override was used or not
     :param imposed_as: Imposed activation sparsity, defaults to None.
         Will force the activation sparsity from all ReLu layers in the graph
         to match this desired sparsity level (percentage of 0's in the tensor).
@@ -1087,7 +1102,6 @@
         num_iterations=num_iterations,
         num_warmup_iterations=num_warmup_iterations,
         optimization_level=optimization_level,
-        disable_batch_override=disable_batch_override,
         imposed_as=imposed_as,
         imposed_ks=imposed_ks,
         kv_cache_params=kv_cache_params,
diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py
index 24d2734d73..b7c295d27d 100644
--- a/src/deepsparse/utils/onnx.py
+++ b/src/deepsparse/utils/onnx.py
@@ -25,6 +25,7 @@
 from onnx import ModelProto
 from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 
+from deepsparse.utils import parse_input_shapes
 from deepsparse.utils.extractor import Extractor
 from sparsezoo.utils import save_onnx, validate_onnx
 
@@ -213,11 +214,16 @@
         if batch_size is not None:
             in_shape[0] = batch_size
 
-        _LOGGER.info(
-            "Generating input '{}', type = {}, shape = {}".format(
-                external_input.name, numpy.dtype(elem_type).name, in_shape
-            )
+        input_string = "input '{}', type = {}, shape = {}".format(
+            external_input.name, numpy.dtype(elem_type).name, in_shape
         )
+
+        assert not any(dim < 1 for dim in in_shape), (
+            f"Dynamic shape found in {input_string}. "
+            "All dimensions must be positive in order to generate random data"
+        )
+
+        _LOGGER.info(f"Generating {input_string}")
         input_data_list.append(numpy.random.rand(*in_shape).astype(elem_type))
 
     return input_data_list
@@ -244,6 +250,10 @@ def override_onnx_batch_size(
         model. Else the modified model will be saved to a temporary file.
     """
+
+    if batch_size is None:
+        return onnx_filepath
+
     model = onnx.load(onnx_filepath, load_external_data=not inplace)
     all_inputs = model.graph.input
     initializer_input_names = [node.name for node in model.graph.initializer]
@@ -269,7 +279,7 @@
 @contextlib.contextmanager
 def override_onnx_input_shapes(
     onnx_filepath: str,
-    input_shapes: Union[List[int], List[List[int]]],
+    input_shapes: Union[None, str, List[int], List[List[int]]],
     inplace: bool = True,
 ) -> str:
     """
@@ -298,6 +308,9 @@
         input for input in all_inputs if input.name not in initializer_input_names
     ]
 
+    if isinstance(input_shapes, str):
+        input_shapes = parse_input_shapes(input_shapes)
+
     # Input shapes should be a list of lists, even if there is only one input
     if not all(isinstance(inp, list) for inp in input_shapes):
         input_shapes = [input_shapes]
diff --git a/tests/test_engine.py b/tests/test_engine.py
index f76eef9ea0..b51653c47e 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -141,3 +141,26 @@ def test_analyze(self, model: Model, batch_size: int, engine_io):
             num_warmup_iterations=0,
         )
         assert "layer_info" in results
+
+
+@pytest.mark.smoke
+class TestBatchedEngine:
+    def test_batched(self):
+        model_stub = (
+            "zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/base-none"
+        )
+
+        # batch_size=None disables the batch override
+        engine = Engine(model_stub, batch_size=None, input_shapes=[3, 3, 224, 224])
+        assert engine.input_shapes[0] == (3, 3, 224, 224)
+        assert engine.generate_random_inputs()[0].shape == (3, 3, 224, 224)
+
+        # The engine implicitly assumes batch size 1
+        engine = Engine(model_stub)
+        assert engine.input_shapes[0] == (1, 3, 224, 224)
+        assert engine.generate_random_inputs()[0].shape == (1, 3, 224, 224)
+
+        # The engine first applies input_shapes, then applies the batch override to the model
+        engine = Engine(model_stub, batch_size=5, input_shapes=[1, 3, 224, 224])
+        assert engine.input_shapes[0] == (5, 3, 224, 224)
+        assert engine.generate_random_inputs()[0].shape == (5, 3, 224, 224)