diff --git a/src/deepsparse/v2/operators/engine_operator.py b/src/deepsparse/v2/operators/engine_operator.py
index c2fc562c63..bd58aefafa 100644
--- a/src/deepsparse/v2/operators/engine_operator.py
+++ b/src/deepsparse/v2/operators/engine_operator.py
@@ -20,7 +20,7 @@
 from deepsparse import Context as EngineContext
 from deepsparse import Engine, MultiModelEngine, Scheduler
 from deepsparse.benchmark import ORTEngine
-from deepsparse.utils import join_engine_outputs, model_to_path, split_engine_inputs
+from deepsparse.utils import model_to_path
 from deepsparse.v2.operators import Operator
@@ -145,18 +145,6 @@ def run(self, inp: EngineOperatorInputs, **kwargs) -> Dict:
             # planned refactor
             engine_outputs = inp.engine(inp.engine_inputs)
             return {"engine_outputs": engine_outputs}
-        inp = inp.engine_inputs
-        batches, orig_batch_size = self.expand_inputs(engine_inputs=inp)
-        batches_outputs = list(map(self.engine, batches))
-        engine_outputs = self.condense_inputs(
-            batch_outputs=batches_outputs, orig_batch_size=orig_batch_size
-        )
-        return {"engine_outputs": engine_outputs}
-
-    def expand_inputs(self, **kwargs):
-        return split_engine_inputs(kwargs["engine_inputs"], self._batch_size)
-
-    def condense_inputs(self, **kwargs):
-        batch_outputs = kwargs["batch_outputs"]
-        orig_batch_size = kwargs["orig_batch_size"]
-        return join_engine_outputs(batch_outputs, orig_batch_size)
+        engine_outputs = self.engine(inp.engine_inputs)
+        return {"engine_outputs": engine_outputs}
diff --git a/src/deepsparse/v2/operators/operator.py b/src/deepsparse/v2/operators/operator.py
index b3963d8223..5bb0be841a 100644
--- a/src/deepsparse/v2/operators/operator.py
+++ b/src/deepsparse/v2/operators/operator.py
@@ -99,7 +99,6 @@ def __call__(
             pipeline_state=pipeline_state,
             **kwargs,
         )
-
         if self.has_output_schema():
             return self.output_schema(**run_output)
         return run_output
@@ -117,18 +116,6 @@ def can_operate(self, inp: Any) -> bool:
         """
         return True
 
-    def expand_inputs(self, **kwargs):
-        """
-        Generic function to handle expanding values.
-        """
-        raise NotImplementedError
-
-    def condense_inputs(self, **kwargs):
-        """
-        Generic function to handle condensing values.
-        """
-        raise NotImplementedError
-
     def yaml(self):
         pass
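
Note on the removals above: batch splitting and joining move out of the individual
operators and up into the Pipeline (next file). For reference, a minimal numpy
sketch of the behavior the removed expand_inputs/condense_inputs delegated to;
the helper names below (split_batch, join_batches) are illustrative, not the
deepsparse.utils API:

    import numpy

    def split_batch(engine_inputs, batch_size=1):
        # split each input array along axis 0 into chunks of `batch_size`
        orig_batch_size = engine_inputs[0].shape[0]
        batches = [
            [array[i : i + batch_size] for array in engine_inputs]
            for i in range(0, orig_batch_size, batch_size)
        ]
        return batches, orig_batch_size

    def join_batches(batch_outputs, orig_batch_size):
        # concatenate the i-th output of every batch back into full arrays
        return [
            numpy.concatenate(arrays, axis=0)[:orig_batch_size]
            for arrays in zip(*batch_outputs)
        ]
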
diff --git a/src/deepsparse/v2/pipeline.py b/src/deepsparse/v2/pipeline.py
index 0a8c8b2f93..f56680d2b9 100644
--- a/src/deepsparse/v2/pipeline.py
+++ b/src/deepsparse/v2/pipeline.py
@@ -13,7 +13,10 @@
 # limitations under the License.
 
-from typing import Dict, List, Union
+import copy
+from concurrent.futures import Future
+from functools import partial
+from typing import Any, Callable, Dict, List, Union
 
 from deepsparse.v2.operators import Operator
 from deepsparse.v2.routers import Router
@@ -56,9 +59,88 @@ def __init__(
         self.pipeline_state = pipeline_state
         self.validate()
 
-        # SchedulerGroup handles running all schedulers in order of priority
         self._scheduler_group = SchedulerGroup(self.schedulers)
 
+    def _run_sequential(
+        self,
+        inp: Any,
+        inference_state: InferenceState,
+        pipeline_state: PipelineState,
+        start: str,
+        end: str,
+    ):
+        next_step = start
+        while next_step != end:
+            outputs = self._run_next_step(
+                func=self.ops[next_step],
+                next_step=next_step,
+                input=inp,
+                pipeline_state=pipeline_state,
+                inference_state=inference_state,
+            )
+            next_step, operator_output, state_update = outputs
+            if state_update:
+                inference_state.update_state(state_update)
+            inp = operator_output
+        return inp
+
+    def _apply_split(self, inp: Any, inference_state: InferenceState):
+        """
+        Split inputs using the pipeline's expand_inputs function. Inputs are
+        split into a batch size of one when a SPLIT_ROUTE node is found in a
+        given pipeline's provided router. The split batches are run
+        asynchronously and then joined when a JOIN_ROUTE node is found, using
+        the pipeline's condense_inputs function.
+        """
+        batches, orig_batch_size = self.expand_inputs(inp, 1)
+        run_with_state = partial(
+            self._run_sequential,
+            pipeline_state=self.pipeline_state,
+            start=self.router.route[self.router.SPLIT_ROUTE],
+            end=self.router.JOIN_ROUTE,
+        )
+        inference_state_list = [
+            copy.deepcopy(inference_state) for _ in range(len(batches))
+        ]
+        futures = self._scheduler_group.map(
+            batches,
+            inference_state_list,
+            func=run_with_state,
+        )
+        return self.condense_inputs([x.result() for x in futures])
+
+    def _run_next_step(
+        self,
+        *args,
+        func: Callable,
+        next_step: Union[str, int],
+        input: Any = None,
+        **kwargs,
+    ):
+        """
+        Generic function to run a given func, process the output, and
+        determine the next step.
+        """
+        if input:
+            operator_output = (
+                func(*args, **kwargs, **input)
+                if isinstance(input, dict)
+                else func(input, *args, **kwargs)
+            )
+        else:
+            operator_output = func(*args, **kwargs)
+
+        if isinstance(operator_output, Future):
+            operator_output = operator_output.result()
+
+        state_update = None
+        if isinstance(operator_output, tuple):
+            state_update = operator_output[-1]
+            operator_output = operator_output[0]
+
+        next_step = self.router.next(next_step, self.ops, operator_output)
+        return next_step, operator_output, state_update
+
     def run(
         self,
         *args,
@@ -78,40 +160,34 @@ def run(
         operator_output = None
 
         while next_step != self.router.END_ROUTE:
-            # Either a dictionary key or valid index
-            operator = self.ops[next_step]
+            # NOTE: split_route should only appear after the start route node
+            if next_step == self.router.SPLIT_ROUTE:
+                operator_output = self._apply_split(operator_output, inference_state)
+                next_step = self.router.route[self.router.JOIN_ROUTE]
+
             if next_step == self.router.START_ROUTE:
-                output_future = self._scheduler_group.submit(
+                outputs = self._run_next_step(
                     *args,
+                    next_step=next_step,
+                    func=self._scheduler_group.submit,
                     inference_state=inference_state,
-                    operator=operator,
+                    operator=self.ops[next_step],
                     pipeline_state=pipeline_state,
                     **kwargs,
                 )
             else:
-                if isinstance(operator_output, dict):
-                    output_future = self._scheduler_group.submit(
-                        inference_state=inference_state,
-                        operator=operator,
-                        pipeline_state=pipeline_state,
-                        **operator_output,
-                    )
-                else:
-                    output_future = self._scheduler_group.submit(
-                        operator_output,
-                        inference_state=inference_state,
-                        pipeline_state=pipeline_state,
-                        operator=operator,
-                    )
-
-            operator_output = output_future.result()
-            if isinstance(operator_output, tuple):
-                state_update = operator_output[-1]
-                operator_output = operator_output[0]
-                inference_state.update_state(state_update)
-
-            next_step = self.router.next(next_step, self.ops, operator_output)
+                outputs = self._run_next_step(
+                    func=self._scheduler_group.submit,
+                    input=operator_output,
+                    next_step=next_step,
+                    inference_state=inference_state,
+                    operator=self.ops[next_step],
+                    pipeline_state=pipeline_state,
+                )
+
+            next_step, operator_output, state_update = outputs
+            if state_update:
+                inference_state.update_state(state_update)
         return operator_output
 
     def __call__(self, *args, **kwargs):
@@ -136,6 +212,27 @@ def __call__(self, *args, **kwargs):
 
         return self.run(*args, **kwargs)
 
+    def expand_inputs(self, *args, **kwargs):
+        """
+        Generic function to handle expanding values.
+        """
+        raise NotImplementedError(
+            "This function should be implemented for any router with split or join "
+            "nodes. expand_inputs will be called prior to the split node (stored in "
+            "the router's SPLIT_ROUTE attribute), splitting the input so that there "
+            "is a batch size of one per thread."
+        )
+
+    def condense_inputs(self, *args, **kwargs):
+        """
+        Generic function to handle condensing values.
+        """
+        raise NotImplementedError(
+            "This function should be implemented for any router with split or join "
+            "nodes. condense_inputs will be called after the join node (stored in the "
+            "router's JOIN_ROUTE attribute), condensing outputs from multiple threads."
+        )
+
     def validate(self):
         """
         Validate that compatability of the router and operators provided.
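
To make the new contract concrete, a hypothetical Pipeline subclass wiring up
expand_inputs/condense_inputs. The class, key names, and data layout below are
illustrative only; note that _apply_split always calls expand_inputs with a
batch_size of 1:

    from deepsparse.v2.pipeline import Pipeline

    class MySplitJoinPipeline(Pipeline):
        def expand_inputs(self, items, batch_size):
            # turn one dict of per-prompt lists into a list of single-item
            # dicts; batch_size is fixed to 1 by _apply_split
            keys = list(items.keys())
            batches = [
                dict(zip(keys, values))
                for values in zip(*(items[k] for k in keys))
            ]
            return batches, len(batches)

        def condense_inputs(self, outputs):
            # outputs is the list of per-batch results gathered at JOIN
            return {"results": outputs}
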
diff --git a/src/deepsparse/v2/routers/router.py b/src/deepsparse/v2/routers/router.py
index d1110d4ca7..1b70164002 100644
--- a/src/deepsparse/v2/routers/router.py
+++ b/src/deepsparse/v2/routers/router.py
@@ -41,9 +41,13 @@ def __init__(
         end_route: Union[str, int],
         start_route: Union[str, int],
         route: Optional[Dict] = None,
+        split_route: str = "SPLIT",
+        join_route: str = "JOIN",
     ):
         self.START_ROUTE = start_route
         self.END_ROUTE = end_route
+        self.SPLIT_ROUTE = split_route
+        self.JOIN_ROUTE = join_route
         self.route = route
 
     @abstractmethod
@@ -79,6 +83,9 @@ class LinearRouter(Router):
 
     def __init__(self, end_route: int, start_route: int = 0):
         super().__init__(end_route=end_route, start_route=start_route)
+        self.SPLIT_ROUTE = None
+        self.JOIN_ROUTE = None
+        _LOGGER.warning("SPLIT and JOIN are not yet supported for the LinearRouter.")
 
     def next(
         self, past: int, ops: Optional[List[Operator]] = None, inp: Optional[Any] = None
@@ -128,8 +135,10 @@ class GraphRouter(Router):
     where `can_operate` returns True will run. Paths should be deterministic.
     """
 
-    def __init__(self, end_route: str, start_route: str, route: Dict):
-        super().__init__(end_route=end_route, start_route=start_route, route=route)
+    def __init__(self, end_route: str, start_route: str, route: Dict, **kwargs):
+        super().__init__(
+            end_route=end_route, start_route=start_route, route=route, **kwargs
+        )
 
     def next(
         self,
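
A minimal sketch of a route using the new virtual nodes. The op names here are
illustrative; "SPLIT" and "JOIN" are the default split_route/join_route values,
and the Pipeline intercepts them rather than looking them up in ops:

    from deepsparse.v2.routers import GraphRouter

    router = GraphRouter(
        start_route="preprocess",
        end_route="STOP",
        route={
            "preprocess": "SPLIT",    # everything after SPLIT runs per item
            "SPLIT": "per_item_op",
            "per_item_op": "JOIN",    # per-item results are condensed at JOIN
            "JOIN": "postprocess",
            "postprocess": "STOP",
        },
    )
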
diff --git a/src/deepsparse/v2/schedulers/scheduler.py b/src/deepsparse/v2/schedulers/scheduler.py
index 78a58e3389..5313683107 100644
--- a/src/deepsparse/v2/schedulers/scheduler.py
+++ b/src/deepsparse/v2/schedulers/scheduler.py
@@ -14,6 +14,7 @@
 
 from concurrent.futures import Future, ThreadPoolExecutor
+from typing import Callable
 
 from deepsparse.v2.operators import Operator
 
@@ -64,3 +65,13 @@ def can_process(
         Base OperatorScheduler always returns True
         """
         return True
+
+    def map(self, *args, func: Callable):
+        """
+        :param func: generic callable run for each zipped set of args
+        :return: list of futures, one per submit
+        """
+        futures = []
+        for values in zip(*args):
+            futures.append(self.submit(*values, operator=func))
+        return futures
diff --git a/src/deepsparse/v2/schedulers/scheduler_group.py b/src/deepsparse/v2/schedulers/scheduler_group.py
index 40b5695f22..14d869a0f2 100644
--- a/src/deepsparse/v2/schedulers/scheduler_group.py
+++ b/src/deepsparse/v2/schedulers/scheduler_group.py
@@ -55,23 +55,3 @@ def submit(
             operator=operator,
             **kwargs,
         )
-
-    def can_process(
-        self,
-        *args,
-        operator: Operator,
-        **kwargs,
-    ) -> bool:
-        """
-        :param operator: operator to check
-        :return: True if this Operator can process the given operator and input.
-            SchedulerGroup always returns True
-        """
-        return any(
-            scheduler.can_process(
-                *args,
-                operator=operator,
-                **kwargs,
-            )
-            for scheduler in self.schedulers
-        )
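
How the new map helper is used, sketched under the assumption that submit
forwards positional args to the callable on the thread pool (which is what
_apply_split relies on); the toy payloads are illustrative:

    from deepsparse.v2.schedulers import OperatorScheduler

    scheduler = OperatorScheduler()
    futures = scheduler.map(
        [{"a": 1}, {"a": 2}],      # one batch per future
        ["state_0", "state_1"],    # matching per-batch inference state
        func=lambda batch, state: (batch["a"], state),
    )
    results = [f.result() for f in futures]
    # results == [(1, "state_0"), (2, "state_1")]
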
diff --git a/src/deepsparse/v2/text_generation/__init__.py b/src/deepsparse/v2/text_generation/__init__.py
index 21cd7e2acd..08836b8bbe 100644
--- a/src/deepsparse/v2/text_generation/__init__.py
+++ b/src/deepsparse/v2/text_generation/__init__.py
@@ -17,6 +17,7 @@
 from .compile_generations import *
 from .compile_logits import *
 from .generate_new_token import *
+from .join_output import *
 from .kv_cache_operator import *
 from .multi_engine_prefill_operator import *
 from .nl_engine_operator import *
diff --git a/src/deepsparse/v2/text_generation/join_output.py b/src/deepsparse/v2/text_generation/join_output.py
new file mode 100644
index 0000000000..8a6c77a2f1
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/join_output.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import numpy
+
+from deepsparse.transformers.utils.helpers import pad_to_fixed_length
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.text_generation.compile_generations import CompileGenerationsOutput
+
+
+__all__ = ["JoinOutput"]
+
+
+class JoinOutput(Operator):
+    """
+    Run this operator to combine the results from multiple prompts.
+    """
+
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def run(self, inp: List[CompileGenerationsOutput], **kwargs):
+
+        batch_outputs = [x for x in inp[0]]
+        generated_tokens = [x.generated_tokens for x in batch_outputs]
+        generated_logits = [x.generated_logits for x in batch_outputs]
+        finished_reason = [x.finished_reason for x in batch_outputs]
+
+        # find the longest sequence in the batch of tokens
+        max_len = max(token.shape[1] for token in generated_tokens)
+
+        # pad all tokens to the same length
+        tokens = [
+            pad_to_fixed_length(
+                array=prediction,
+                max_len=max_len,
+                value=self.tokenizer.pad_token_id,
+                axis=1,
+            )
+            for prediction in generated_tokens
+        ]
+
+        # find the longest sequence in the batch of logits
+        max_len = max(logits.shape[1] for logits in generated_logits)
+
+        # pad all logits to the same length
+        logits = [
+            pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
+            for single_logits in generated_logits
+        ]
+
+        tokens = numpy.concatenate(tokens)
+        logits = numpy.concatenate(logits)
+
+        return {
+            "generated_tokens": tokens,
+            "generated_logits": logits,
+            "finished_reason": finished_reason,
+        }
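
Why the padding step exists: ragged per-prompt outputs cannot be concatenated
along the batch axis. A minimal equivalent of the pad-then-concatenate logic
above, using numpy.pad directly on toy arrays:

    import numpy

    a = numpy.ones((1, 3), dtype=numpy.int64)   # 3 generated tokens
    b = numpy.ones((1, 5), dtype=numpy.int64)   # 5 generated tokens
    max_len = max(x.shape[1] for x in (a, b))

    padded = [
        numpy.pad(x, ((0, 0), (0, max_len - x.shape[1])), constant_values=0)
        for x in (a, b)
    ]
    batch = numpy.concatenate(padded)           # shape (2, 5)
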
diff --git a/src/deepsparse/v2/text_generation/pipeline.py b/src/deepsparse/v2/text_generation/pipeline.py
index 49826b8af7..240da04907 100644
--- a/src/deepsparse/v2/text_generation/pipeline.py
+++ b/src/deepsparse/v2/text_generation/pipeline.py
@@ -15,6 +15,7 @@
 from typing import Dict
 
 from deepsparse.transformers.utils.helpers import process_generation_config
+from deepsparse.utils import split_engine_inputs
 from deepsparse.v2.pipeline import Pipeline
 from deepsparse.v2.routers import GraphRouter
 from deepsparse.v2.schedulers import OperatorScheduler
@@ -24,6 +25,7 @@
     CompileGenerations,
     CompilePromptLogits,
     GenerateNewTokenOperator,
+    JoinOutput,
     KVCacheCreator,
     MultiEnginePrefill,
     NLEngineOperator,
@@ -131,6 +133,7 @@ def __init__(
         process_output = ProcessOutputs(tokenizer=self.tokenizer)
         compile_generations = CompileGenerations()
         compile_generated_tokens = CompileGeneratedTokens()
+        join_output = JoinOutput(tokenizer=self.tokenizer)
 
         ops = {
             "process_input": process_inputs,
@@ -146,10 +149,12 @@
             "process_outputs": process_output,
             "compile_generations": compile_generations,
             "compile_generated_tokens": compile_generated_tokens,
+            "join_output": join_output,
         }
 
         routes = {
-            "process_input": "prepare_prefill",
+            "process_input": "SPLIT",
+            "SPLIT": "prepare_prefill",
             "prepare_prefill": ["multi_engine_prefill", "autoregressive_preprocess"],
             "multi_engine_prefill": "multi_engine",
             "multi_engine": "compile_logits",
@@ -169,7 +174,9 @@
                 "autoregressive_preprocess",
                 "compile_generations",
             ],
-            "compile_generations": "process_outputs",
+            "compile_generations": "JOIN",
+            "JOIN": "join_output",
+            "join_output": "process_outputs",
            "process_outputs": "STOP",
         }
 
@@ -181,6 +188,15 @@ def __init__(
             ops=ops, router=router, schedulers=scheduler, pipeline_state=pipeline_state
         )
 
+    def expand_inputs(self, items, batch_size):
+        items = [items.get(key) for key in items.keys()]
+        out, orig_batch_size = split_engine_inputs(items, batch_size)
+        combined_batches = [{"input_ids": b[0], "attention_mask": b[1]} for b in out]
+        return combined_batches, orig_batch_size
+
+    def condense_inputs(self, *args, **kwargs):
+        return args[0], kwargs
+
     # TODO: Move to be part of a generic transformers set-up Operator.
     def setup_onnx_file_path(self, model_path, sequence_length) -> str:
         import logging
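
The effect of the new expand_inputs on tokenized text-generation inputs,
sketched with dummy arrays (shapes illustrative): a batch-2 tokenization
becomes two batch-1 dicts, one per thread between SPLIT and JOIN. Equivalent
slicing for batch_size=1:

    import numpy

    items = {
        "input_ids": numpy.zeros((2, 16), dtype=numpy.int64),
        "attention_mask": numpy.ones((2, 16), dtype=numpy.int64),
    }

    combined_batches = [
        {key: value[i : i + 1] for key, value in items.items()}
        for i in range(items["input_ids"].shape[0])
    ]
    # len(combined_batches) == 2; each value has shape (1, 16)
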
diff --git a/src/deepsparse/v2/text_generation/prep_for_prefill.py b/src/deepsparse/v2/text_generation/prep_for_prefill.py
index 2f9eb15797..2e5fecb3e8 100644
--- a/src/deepsparse/v2/text_generation/prep_for_prefill.py
+++ b/src/deepsparse/v2/text_generation/prep_for_prefill.py
@@ -42,13 +42,20 @@ def __init__(self, kv_cache_creator: Operator):
             "from the NLEngineOperator"
         )
 
-    def run(self, tokens: Any, pipeline_state: PipelineState, **kwargs):
+    def run(
+        self,
+        input_ids: Any,
+        attention_mask: Any,
+        pipeline_state: PipelineState,
+        **kwargs,
+    ):
         # NOTE: Can potentially just be class attributes instead of relying on
         # pipeline state.
         cache_shape = pipeline_state.current_state.get("cache_shape")
         data_type = pipeline_state.current_state.get("kv_cache_data_type")
         output_names = pipeline_state.current_state.get("output_names")
 
+        tokens = input_ids[attention_mask.nonzero()].tolist()
         kv_cache = self.kv_cache_creator.run(
             cache_shape=cache_shape,
             kv_cache_data_type=data_type,
diff --git a/src/deepsparse/v2/text_generation/process_inputs.py b/src/deepsparse/v2/text_generation/process_inputs.py
index e57e402983..5d47c8ff39 100644
--- a/src/deepsparse/v2/text_generation/process_inputs.py
+++ b/src/deepsparse/v2/text_generation/process_inputs.py
@@ -114,8 +114,7 @@ def run(self, inp: TextGenerationInput, **kwargs):
             frequency_penalty=generation_config.repetition_penalty,
         )
 
-        # TODO: move this step to prep_for_prefill and add attention mask to the output
-        # this will allow us to split/join more easily when processing multiple prompts
-        # in parallel
-        tokens = input_ids[attention_mask.nonzero()].tolist()
-        return {"tokens": tokens}, inference_state_update
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }, inference_state_update
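
The relocated line on a toy example: nonzero() on the attention mask keeps
only the non-padded positions of input_ids (values here are illustrative):

    import numpy

    input_ids = numpy.array([[0, 0, 101, 7592, 102]])
    attention_mask = numpy.array([[0, 0, 1, 1, 1]])
    tokens = input_ids[attention_mask.nonzero()].tolist()  # [101, 7592, 102]
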
diff --git a/src/deepsparse/v2/text_generation/process_outputs.py b/src/deepsparse/v2/text_generation/process_outputs.py
index ca1cf78521..7173b8e256 100644
--- a/src/deepsparse/v2/text_generation/process_outputs.py
+++ b/src/deepsparse/v2/text_generation/process_outputs.py
@@ -22,7 +22,6 @@
     TextGenerationOutput,
 )
 from deepsparse.v2.operators import Operator
-from deepsparse.v2.text_generation.compile_generations import CompileGenerationsOutput
 from deepsparse.v2.utils import InferenceState
 
 
@@ -52,19 +51,20 @@ def _create_generated_text_output(
         )
 
     def run(
-        self, inp: CompileGenerationsOutput, inference_state: InferenceState, **kwargs
+        self,
+        generated_tokens: numpy.ndarray,
+        generated_logits: numpy.ndarray,
+        finished_reason: list,
+        inference_state: InferenceState,
+        **kwargs,
     ):
         generation_config = inference_state.current_state.get("generation_config")
-        generated_tokens = inp.generated_tokens
-        generated_logits = (
-            inp.generated_logits if generation_config.output_scores else None
-        )
-        finished_reason = inp.finished_reason
+        generated_logits = generated_logits if generation_config.output_scores else None
         sequences = self.tokenizer.batch_decode(
             generated_tokens, skip_special_tokens=True
         )
-        finished_reason = [f for f in finished_reason if f]
+        finished_reason = [f[-1] for f in finished_reason]
 
         if generated_logits is not None:
             generations = list(
@@ -79,6 +79,15 @@ def run(
             generations = list(
                 map(self._create_generated_text_output, sequences, finished_reason)
             )
+
+        num_preds = generation_config.num_return_sequences
+        if num_preds > 1:
+            grouped_generations = [
+                generations[n : n + num_preds]
+                for n in range(0, len(generations), num_preds)
+            ]
+            generations = grouped_generations
+
         outputs = dict(
             created=datetime.datetime.now(),
             prompts=inference_state.current_state.get("prompts"),
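
For reference, the grouping added in the final hunk above, run on a toy list:
with num_return_sequences == 2, six generations for three prompts are
regrouped into three pairs.

    generations = ["a1", "a2", "b1", "b2", "c1", "c2"]
    num_preds = 2
    grouped = [
        generations[n : n + num_preds]
        for n in range(0, len(generations), num_preds)
    ]
    # grouped == [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]]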