diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index 38e3ec4a4c..648bdef9cf 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -14,7 +14,7 @@ import logging import pathlib import uuid -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy from transformers import AutoTokenizer, GenerationConfig @@ -33,6 +33,7 @@ "override_config", "process_generation_config", "validate_session_ids", + "compute_engine_inputs", "set_generated_length", ] @@ -82,6 +83,95 @@ def set_generated_length( ) +def compute_engine_inputs(onnx_input_names: str, **kwargs) -> List[numpy.ndarray]: + """ + Given the names of the onnx inputs, compute the inputs + to the engine. The inputs will be calculated from the + passed kwargs. The information about the required kwargs + can be found in the docstring of the individual compute + functions. 
+ + :param onnx_input_names: The names of the onnx inputs + :param kwargs: The kwargs to compute the inputs from + :return: The computed inputs to the engine + """ + engine_inputs = [] + for input_name in onnx_input_names: + if input_name == "causal_mask": + # delay the computation of the causal mask + continue + # fetch the compute function for the + # given input_name + compute_func = _get_compute_func(input_name) + # compute the engine input from the kwargs + # and append it to the engine_inputs + engine_inputs.append(compute_func(**kwargs)) + + if "causal_mask" in onnx_input_names: + # compute the causal mask and append it to the engine_inputs + input_ids, attention_mask, *_ = engine_inputs + engine_inputs.append(create_causal_mask(input_ids, attention_mask)) + + return engine_inputs + + +def _get_compute_func(input_name: str) -> Callable[..., numpy.ndarray]: + # given the input_name, return the appropriate compute function + compute_func = { + "input_ids": _compute_input_ids, + "attention_mask": _compute_attention_mask, + "positions": _compute_positions, + }.get(input_name) + if compute_func is None: + raise ValueError( + "Could not find compute function " f"for the input_name: {input_name}" + ) + return compute_func + + +def _compute_input_ids(token_batch: List[int], **kwargs) -> numpy.ndarray: + # convert the token_batch to a numpy array + return numpy.array([token_batch]) + + +def _compute_attention_mask( + sequence_length: int, + prompt_sequence_length: int, + num_total_processed_tokens: int, + **kwargs, +) -> numpy.ndarray: + # create a fully masked attention mask with the appropriate + # shape (equal to the sequence_length) + attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64) + # unmask the appropriate number of tokens, the sum of + # - the number of tokens already processed and cached (num_total_processed_tokens) + # - the number of tokens currently processed (prompt_sequence_length) + # the sum cannot exceed the maximum length of the 
attention_mask + num_attention_entries_to_unmask = min( + num_total_processed_tokens + prompt_sequence_length, sequence_length + ) + # unmask the bits from the right-hand side + attention_mask[:, -num_attention_entries_to_unmask:] = 1 + return attention_mask + + +def _compute_positions( + num_total_processed_tokens: int, prompt_sequence_length: int, **kwargs +): + # create the positions array with the appropriate shape + # positions count starts from the number of tokens already processed + # and ends at the number of tokens already processed + the number of tokens + # currently processed + return ( + numpy.arange( + num_total_processed_tokens, + num_total_processed_tokens + prompt_sequence_length, + ) + .reshape(1, -1) + .astype(numpy.int64) + ) + + def validate_session_ids( session_ids: Optional[str], other_attributes: Dict[str, Any] ) -> Optional[List[str]]: diff --git a/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py index 6e97412e43..17d8dd662c 100644 --- a/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py +++ b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py @@ -15,9 +15,7 @@ import logging from typing import Any -import numpy - -from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.transformers.utils.helpers import compute_engine_inputs from deepsparse.v2.operators import Operator from deepsparse.v2.utils import PipelineState @@ -66,30 +64,16 @@ def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwarg num_total_processed_tokens = kv_cache.total_num_processed_tokens new_token = tokens[num_total_processed_tokens] - engine_input_names = pipeline_state.current_state.get( - "onnx_input_names_no_cache" - ) - - # padding is added to left, so attention mask is 1s from the - # right up to the number of total tokens (prompt + generated) - attention_mask = numpy.zeros((1, 
self.sequence_length), dtype=numpy.int64) - num_attention_entries_to_unmask = min( - num_total_processed_tokens + 1, self.sequence_length - ) # cap by seq len - attention_mask[:, -num_attention_entries_to_unmask:] = 1 - positions = numpy.array([[num_total_processed_tokens]], dtype=numpy.int64) - input_ids = numpy.array([[new_token]]) - causal_mask = create_causal_mask(input_ids, attention_mask) - engine_inputs_map = dict( - input_ids=input_ids, - attention_mask=attention_mask, - causal_mask=causal_mask, - positions=positions, + engine_inputs = compute_engine_inputs( + onnx_input_names=pipeline_state.current_state.get( + "onnx_input_names_no_cache" + ), + token_batch=[new_token], + prompt_sequence_length=1, + sequence_length=self.sequence_length, + num_total_processed_tokens=num_total_processed_tokens, ) - - engine_inputs = [engine_inputs_map[name] for name in engine_input_names] - return { "engine_inputs": engine_inputs, "kv_cache": kv_cache, diff --git a/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py index 9a885c2355..513c34dfc2 100644 --- a/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py +++ b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py @@ -13,12 +13,9 @@ # limitations under the License. 
import logging -from enum import Enum from typing import Any -import numpy - -from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.transformers.utils.helpers import compute_engine_inputs from deepsparse.v2.operators import Operator from deepsparse.v2.utils import PipelineState @@ -28,34 +25,14 @@ __all__ = ["MultiEnginePrefill"] -class OnnxInputNames(Enum): - INPUT_IDS = "input_ids" - ATTN_MASK = "attention_mask" - CAUSAL_MASK = "causal_mask" - POSITIONS = "positions" - - -# NOTE: A possible clean-up could involve combining this Operator and the -# autoregressive_preprocess_operator - - class MultiEnginePrefill(Operator): def __init__(self, prompt_sequence_length, sequence_length): """ Prepare the tokens for the multi-token engine. This requires creating the - attention mask, positions, and causal mask. The output contains these three - arrays to be passed into the multi-token engine. + appropriate engine_inputs to be passed into the multi-token engine. """ self.prompt_sequence_length = prompt_sequence_length self.sequence_length = sequence_length - self.cases = { - OnnxInputNames.ATTN_MASK.value: self._case_attn_mask, - OnnxInputNames.POSITIONS.value: self._case_positions, - } - _LOGGER.warn( - "This operator requires the PipelineState to be set-up with the " - "onnx_input_names_no_cache attribute set from the NLEngineOperator." 
- ) def can_operate(self, inp: Any): """ @@ -75,59 +52,23 @@ def can_operate(self, inp: Any): return True return False - def _case_attn_mask(self, num_total_processed_tokens: int): - # create an empty attention mask - engine_input = numpy.zeros((1, self.sequence_length), dtype=numpy.int64) - # calculate the number of entries in attention mask that should be set to 1 - num_attention_entries_to_unmask = min( - num_total_processed_tokens + self.prompt_sequence_length, - self.sequence_length, - ) - engine_input[:, -num_attention_entries_to_unmask:] = 1 - return engine_input - - def _case_positions(self, num_total_processed_tokens: int): - return ( - numpy.arange( - num_total_processed_tokens, - num_total_processed_tokens + self.prompt_sequence_length, - ) - .reshape(1, -1) - .astype(numpy.int64) - ) - def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs): kv_cache.set_capacity(self.sequence_length - self.prompt_sequence_length) - onnx_input_names_no_cache = pipeline_state.current_state.get( - "onnx_input_names_no_cache" - ) - num_total_processed_tokens = kv_cache.total_num_processed_tokens start = num_total_processed_tokens end = start + self.prompt_sequence_length token_batch = tokens[start:end] - engine_inputs = [] - for name in onnx_input_names_no_cache: - if name == OnnxInputNames.INPUT_IDS.value: - engine_input = numpy.array([token_batch]) - elif ( - name == OnnxInputNames.ATTN_MASK.value - or name == OnnxInputNames.POSITIONS.value - ): - engine_input = self.cases[name](num_total_processed_tokens) - elif name == OnnxInputNames.CAUSAL_MASK.value: - continue - - engine_inputs.append(engine_input) - - if OnnxInputNames.CAUSAL_MASK.value in onnx_input_names_no_cache: - causal_mask = create_causal_mask( - input_ids=engine_inputs[0], - attention_mask=engine_inputs[1], - ) - engine_inputs.append(causal_mask) + engine_inputs = compute_engine_inputs( + onnx_input_names=pipeline_state.current_state.get( + "onnx_input_names_no_cache" + ), + 
token_batch=token_batch, + prompt_sequence_length=self.prompt_sequence_length, + sequence_length=self.sequence_length, + num_total_processed_tokens=num_total_processed_tokens, + ) return { "engine_inputs": engine_inputs, diff --git a/tests/deepsparse/transformers/utils/test_helpers.py b/tests/deepsparse/transformers/utils/test_helpers.py index 7fcadcbf9c..95e4ee7fa7 100644 --- a/tests/deepsparse/transformers/utils/test_helpers.py +++ b/tests/deepsparse/transformers/utils/test_helpers.py @@ -16,12 +16,86 @@ import pytest from deepsparse.transformers.utils.helpers import ( + compute_engine_inputs, create_causal_mask, initialize_kv_cache_state, validate_session_ids, ) +@pytest.mark.parametrize( + "onnx_input_names, " + "token_batch, " + "prompt_sequence_length, " + "sequence_length, " + "num_total_processed_tokens, " + "expected_engine_inputs", + [ + ( + ["input_ids", "attention_mask", "positions"], + [1, 2, 3], + 3, + 6, + 2, + [ + numpy.array([[1, 2, 3]]), + numpy.array([[0, 1, 1, 1, 1, 1]]), + numpy.array([[2, 3, 4]]), + ], + ), + ( + ["input_ids", "attention_mask", "positions", "causal_mask"], + [1, 2, 3], + 3, + 6, + 2, + [ + numpy.array([[1, 2, 3]]), + numpy.array([[0, 1, 1, 1, 1, 1]]), + numpy.array([[2, 3, 4]]), + create_causal_mask( + input_ids=numpy.array([[1, 2, 3]]), + attention_mask=numpy.array([[0, 1, 1, 1, 1, 1]]), + ), + ], + ), + ( + ["input_ids", "attention_mask", "positions", "causal_mask"], + [15], + 1, + 5, + 3, + [ + numpy.array([[15]]), + numpy.array([[0, 1, 1, 1, 1]]), + numpy.array([[3]]), + create_causal_mask( + input_ids=numpy.array([[15]]), + attention_mask=numpy.array([[0, 1, 1, 1, 1]]), + ), + ], + ), + ], +) +def test_compute_engine_inputs( + onnx_input_names, + token_batch, + prompt_sequence_length, + sequence_length, + num_total_processed_tokens, + expected_engine_inputs, +): + engine_inputs = compute_engine_inputs( + onnx_input_names=onnx_input_names, + token_batch=token_batch, + prompt_sequence_length=prompt_sequence_length, + 
sequence_length=sequence_length, + num_total_processed_tokens=num_total_processed_tokens, + ) + for x, y in zip(engine_inputs, expected_engine_inputs): + assert numpy.array_equal(x, y) + + @pytest.mark.parametrize( "input_ids, attention_mask, expected_causal_mask", [