Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion src/deepsparse/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import pathlib
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy
from transformers import AutoTokenizer, GenerationConfig
Expand All @@ -33,6 +33,7 @@
"override_config",
"process_generation_config",
"validate_session_ids",
"compute_engine_inputs",
"set_generated_length",
]

Expand Down Expand Up @@ -82,6 +83,95 @@ def set_generated_length(
)


def compute_engine_inputs(
    onnx_input_names: List[str], **kwargs
) -> List[numpy.ndarray]:
    """
    Given the names of the onnx inputs, compute the inputs
    to the engine. The inputs are calculated from the passed
    kwargs: all kwargs are forwarded to the compute function of
    every input. The information about the required kwargs
    can be found in the individual compute functions.

    Every input other than "causal_mask" is computed in the order
    given by `onnx_input_names`. If "causal_mask" is requested, it is
    computed afterwards from the already computed "input_ids" and
    "attention_mask" and is appended as the last element of the result.

    :param onnx_input_names: The names of the onnx inputs
    :param kwargs: The kwargs to compute the inputs from
    :return: The computed inputs to the engine
    :raises ValueError: if an input name has no compute function, or if
        "causal_mask" is requested without both "input_ids" and
        "attention_mask" being part of `onnx_input_names`
    """
    engine_inputs = []
    # remember each computed input by name, so that the causal mask
    # does not depend on the order of the onnx input names
    computed_by_name = {}
    causal_mask_requested = False

    for input_name in onnx_input_names:
        if input_name == "causal_mask":
            # delay the computation of the causal mask, it is
            # derived from other engine inputs
            causal_mask_requested = True
            continue
        # fetch the compute function for the given input_name
        compute_func = _get_compute_func(input_name)
        # compute the engine input from the kwargs
        # and append it to the engine_inputs
        engine_input = compute_func(**kwargs)
        computed_by_name[input_name] = engine_input
        engine_inputs.append(engine_input)

    if causal_mask_requested:
        missing_names = [
            name
            for name in ("input_ids", "attention_mask")
            if name not in computed_by_name
        ]
        if missing_names:
            raise ValueError(
                "Computing the causal_mask requires the following "
                f"onnx inputs to be present: {missing_names}"
            )
        # compute the causal mask and append it to the engine_inputs
        engine_inputs.append(
            create_causal_mask(
                input_ids=computed_by_name["input_ids"],
                attention_mask=computed_by_name["attention_mask"],
            )
        )

    return engine_inputs


def _get_compute_func(input_name: str) -> Callable[..., numpy.ndarray]:
    """
    Look up the function that computes the engine input
    with the given name.

    :param input_name: The name of the onnx input
    :return: The compute function registered for the input_name
    :raises ValueError: if no compute function is registered
        for the input_name
    """
    registry = {
        "input_ids": _compute_input_ids,
        "attention_mask": _compute_attention_mask,
        "positions": _compute_positions,
    }
    if input_name in registry:
        return registry[input_name]

    message = f"Could not find compute function for the input_name: {input_name}"
    raise ValueError(message)


def _compute_input_ids(token_batch: List[int], **kwargs) -> numpy.ndarray:
    """
    Convert the batch of tokens into the `input_ids` engine input.

    :param token_batch: The tokens to be processed in the current pass
    :param kwargs: Ignored. Accepted so that every compute function
        can be called with the same set of keyword arguments
    :return: An int64 array holding the token_batch as its single row
    """
    # pin the dtype: without it the result depends on the platform's
    # default integer type and silently becomes float64 for an empty
    # batch, while the attention mask and positions are always int64
    return numpy.array([token_batch], dtype=numpy.int64)


def _compute_attention_mask(
    sequence_length: int,
    prompt_sequence_length: int,
    num_total_processed_tokens: int,
    **kwargs,
) -> numpy.ndarray:
    """
    Create the attention mask for the current pass through the engine.

    The mask is an int64 array of shape (1, sequence_length). Its
    right-most entries are set to 1 (unmasked), all the remaining
    entries are 0 (masked).

    :param sequence_length: The length of the attention mask
    :param prompt_sequence_length: The number of tokens
        processed in the current pass
    :param num_total_processed_tokens: The number of tokens
        already processed and cached
    :param kwargs: Ignored. Accepted so that every compute function
        can be called with the same set of keyword arguments
    :return: The attention mask
    """
    # create a fully masked attention mask with the appropriate
    # shape (equal to the sequence_length)
    attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64)
    # unmask the appropriate number of tokens, the sum of
    # - the number of tokens already processed and cached (num_total_processed_tokens)
    # - the number of tokens currently processed (prompt_sequence_length)
    # the sum is clamped to the range [0, sequence_length]
    num_attention_entries_to_unmask = max(
        0,
        min(num_total_processed_tokens + prompt_sequence_length, sequence_length),
    )
    # unmask the bits from the right-hand side; the start index is
    # computed from the left, because the slice `-0:` would select
    # (and unmask) the whole row when there is nothing to unmask
    first_unmasked_index = sequence_length - num_attention_entries_to_unmask
    attention_mask[:, first_unmasked_index:] = 1
    return attention_mask


def _compute_positions(
    num_total_processed_tokens: int, prompt_sequence_length: int, **kwargs
):
    """
    Create the `positions` engine input for the current pass.

    The positions are consecutive integers that start at the number
    of tokens already processed and contain one entry for every token
    processed in the current pass.

    :param num_total_processed_tokens: The number of tokens
        already processed
    :param prompt_sequence_length: The number of tokens
        processed in the current pass
    :param kwargs: Ignored. Accepted so that every compute function
        can be called with the same set of keyword arguments
    :return: An int64 array with a single row holding the positions
    """
    first_position = num_total_processed_tokens
    end_position = first_position + prompt_sequence_length
    positions = numpy.arange(first_position, end_position)
    # add the leading batch dimension
    positions = positions.reshape(1, -1)
    return positions.astype(numpy.int64)


def validate_session_ids(
session_ids: Optional[str], other_attributes: Dict[str, Any]
) -> Optional[List[str]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import logging
from typing import Any

import numpy

from deepsparse.transformers.utils.helpers import create_causal_mask
from deepsparse.transformers.utils.helpers import compute_engine_inputs
from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import PipelineState

Expand Down Expand Up @@ -66,30 +64,16 @@ def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwarg

num_total_processed_tokens = kv_cache.total_num_processed_tokens
new_token = tokens[num_total_processed_tokens]
engine_input_names = pipeline_state.current_state.get(
"onnx_input_names_no_cache"
)

# padding is added to left, so attention mask is 1s from the
# right up to the number of total tokens (prompt + generated)
attention_mask = numpy.zeros((1, self.sequence_length), dtype=numpy.int64)
num_attention_entries_to_unmask = min(
num_total_processed_tokens + 1, self.sequence_length
) # cap by seq len
attention_mask[:, -num_attention_entries_to_unmask:] = 1
positions = numpy.array([[num_total_processed_tokens]], dtype=numpy.int64)
input_ids = numpy.array([[new_token]])
causal_mask = create_causal_mask(input_ids, attention_mask)

engine_inputs_map = dict(
input_ids=input_ids,
attention_mask=attention_mask,
causal_mask=causal_mask,
positions=positions,
engine_inputs = compute_engine_inputs(
onnx_input_names=pipeline_state.current_state.get(
"onnx_input_names_no_cache"
),
token_batch=[new_token],
prompt_sequence_length=1,
sequence_length=self.sequence_length,
num_total_processed_tokens=num_total_processed_tokens,
)

engine_inputs = [engine_inputs_map[name] for name in engine_input_names]

return {
"engine_inputs": engine_inputs,
"kv_cache": kv_cache,
Expand Down
81 changes: 11 additions & 70 deletions src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
# limitations under the License.

import logging
from enum import Enum
from typing import Any

import numpy

from deepsparse.transformers.utils.helpers import create_causal_mask
from deepsparse.transformers.utils.helpers import compute_engine_inputs
from deepsparse.v2.operators import Operator
from deepsparse.v2.utils import PipelineState

Expand All @@ -28,34 +25,14 @@
__all__ = ["MultiEnginePrefill"]


class OnnxInputNames(Enum):
    """
    Names of the onnx inputs that the prefill operator
    knows how to compute. The value of every member is the
    name of the input as it appears in the onnx input names.
    """

    INPUT_IDS = "input_ids"
    ATTN_MASK = "attention_mask"
    CAUSAL_MASK = "causal_mask"
    POSITIONS = "positions"


# NOTE: A possible clean-up could involve combining this Operator and the
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully this PR resolves this todo note.

# autoregressive_preprocess_operator


class MultiEnginePrefill(Operator):
def __init__(self, prompt_sequence_length, sequence_length):
    """
    Prepare the tokens for the multi-token engine. This requires creating the
    appropriate engine_inputs to be passed into the multi-token engine.

    :param prompt_sequence_length: The number of tokens that are
        processed in a single pass through the multi-token engine
    :param sequence_length: The length of the attention mask
        created for the engine
    """
    self.prompt_sequence_length = prompt_sequence_length
    self.sequence_length = sequence_length
    # maps the name of an onnx input to the method that computes it
    self.cases = {
        OnnxInputNames.ATTN_MASK.value: self._case_attn_mask,
        OnnxInputNames.POSITIONS.value: self._case_positions,
    }
    # `Logger.warn` is a deprecated alias, use `Logger.warning` instead
    _LOGGER.warning(
        "This operator requires the PipelineState to be set-up with the "
        "onnx_input_names_no_cache attribute set from the NLEngineOperator."
    )

def can_operate(self, inp: Any):
"""
Expand All @@ -75,59 +52,23 @@ def can_operate(self, inp: Any):
return True
return False

def _case_attn_mask(self, num_total_processed_tokens: int):
# create an empty attention mask
engine_input = numpy.zeros((1, self.sequence_length), dtype=numpy.int64)
# calculate the number of entries in attention mask that should be set to 1
num_attention_entries_to_unmask = min(
num_total_processed_tokens + self.prompt_sequence_length,
self.sequence_length,
)
engine_input[:, -num_attention_entries_to_unmask:] = 1
return engine_input

def _case_positions(self, num_total_processed_tokens: int):
return (
numpy.arange(
num_total_processed_tokens,
num_total_processed_tokens + self.prompt_sequence_length,
)
.reshape(1, -1)
.astype(numpy.int64)
)

def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs):
kv_cache.set_capacity(self.sequence_length - self.prompt_sequence_length)

onnx_input_names_no_cache = pipeline_state.current_state.get(
"onnx_input_names_no_cache"
)

num_total_processed_tokens = kv_cache.total_num_processed_tokens
start = num_total_processed_tokens
end = start + self.prompt_sequence_length
token_batch = tokens[start:end]

engine_inputs = []
for name in onnx_input_names_no_cache:
if name == OnnxInputNames.INPUT_IDS.value:
engine_input = numpy.array([token_batch])
elif (
name == OnnxInputNames.ATTN_MASK.value
or name == OnnxInputNames.POSITIONS.value
):
engine_input = self.cases[name](num_total_processed_tokens)
elif name == OnnxInputNames.CAUSAL_MASK.value:
continue

engine_inputs.append(engine_input)

if OnnxInputNames.CAUSAL_MASK.value in onnx_input_names_no_cache:
causal_mask = create_causal_mask(
input_ids=engine_inputs[0],
attention_mask=engine_inputs[1],
)
engine_inputs.append(causal_mask)
engine_inputs = compute_engine_inputs(
onnx_input_names=pipeline_state.current_state.get(
"onnx_input_names_no_cache"
),
token_batch=token_batch,
prompt_sequence_length=self.prompt_sequence_length,
sequence_length=self.sequence_length,
num_total_processed_tokens=num_total_processed_tokens,
)

return {
"engine_inputs": engine_inputs,
Expand Down
74 changes: 74 additions & 0 deletions tests/deepsparse/transformers/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,86 @@

import pytest
from deepsparse.transformers.utils.helpers import (
compute_engine_inputs,
create_causal_mask,
initialize_kv_cache_state,
validate_session_ids,
)


@pytest.mark.parametrize(
    "onnx_input_names, "
    "token_batch, "
    "prompt_sequence_length, "
    "sequence_length, "
    "num_total_processed_tokens, "
    "expected_engine_inputs",
    [
        # multi-token pass, no causal mask requested
        (
            ["input_ids", "attention_mask", "positions"],
            [1, 2, 3],
            3,
            6,
            2,
            [
                numpy.array([[1, 2, 3]]),
                numpy.array([[0, 1, 1, 1, 1, 1]]),
                numpy.array([[2, 3, 4]]),
            ],
        ),
        # multi-token pass, causal mask requested
        (
            ["input_ids", "attention_mask", "positions", "causal_mask"],
            [1, 2, 3],
            3,
            6,
            2,
            [
                numpy.array([[1, 2, 3]]),
                numpy.array([[0, 1, 1, 1, 1, 1]]),
                numpy.array([[2, 3, 4]]),
                create_causal_mask(
                    input_ids=numpy.array([[1, 2, 3]]),
                    attention_mask=numpy.array([[0, 1, 1, 1, 1, 1]]),
                ),
            ],
        ),
        # single-token pass, causal mask requested
        (
            ["input_ids", "attention_mask", "positions", "causal_mask"],
            [15],
            1,
            5,
            3,
            [
                numpy.array([[15]]),
                numpy.array([[0, 1, 1, 1, 1]]),
                numpy.array([[3]]),
                create_causal_mask(
                    input_ids=numpy.array([[15]]),
                    attention_mask=numpy.array([[0, 1, 1, 1, 1]]),
                ),
            ],
        ),
    ],
)
def test_compute_engine_inputs(
    onnx_input_names,
    token_batch,
    prompt_sequence_length,
    sequence_length,
    num_total_processed_tokens,
    expected_engine_inputs,
):
    """
    Check that compute_engine_inputs returns the expected
    engine inputs for the given onnx input names.
    """
    engine_inputs = compute_engine_inputs(
        onnx_input_names=onnx_input_names,
        token_batch=token_batch,
        prompt_sequence_length=prompt_sequence_length,
        sequence_length=sequence_length,
        num_total_processed_tokens=num_total_processed_tokens,
    )
    # zip stops at the shorter sequence, so without this check
    # a missing engine input would go unnoticed
    assert len(engine_inputs) == len(expected_engine_inputs)
    for x, y in zip(engine_inputs, expected_engine_inputs):
        assert numpy.array_equal(x, y)


@pytest.mark.parametrize(
"input_ids, attention_mask, expected_causal_mask",
[
Expand Down