@@ -51,6 +51,12 @@ def can_operate(self, inp: Any) -> bool:
if inp.get("in_generation"):
return True

if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
raise RuntimeError(
"Not enough kv_cache capacity to run generation. Please use a larger "
"sequence_length or a shorter prompt"
)

remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens
can_process = (
remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length
@@ -17,7 +17,6 @@
from pydantic import BaseModel, Field

from deepsparse.operators import Operator
from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
from deepsparse.utils import InferenceState


@@ -43,9 +42,6 @@ def run(self, inference_state: InferenceState, **kwargs):
generated_logits = inference_state.current_state.get("generated_logits")
finished_reason = inference_state.current_state.get("finished_reason")

if len(finished_reason) == 0:
finished_reason.append(FinishReason.LENGTH)

generated_tokens = numpy.array([generated_tokens])
generated_logits = numpy.concatenate(generated_logits, axis=1)
return {
@@ -19,10 +19,7 @@
from deepsparse.transformers.pipelines.text_generation.nl_engine_operator import (
NLEngineOutputs,
)
from deepsparse.transformers.schemas.text_generation_schemas import (
FinishReason,
PromptLogitsNoKVCacheInference,
)
from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
from deepsparse.utils import InferenceState


@@ -36,14 +33,16 @@ def __init__(
self.force_max_tokens = force_max_tokens
self.tokenizer = tokenizer

def can_operate(self, inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs]):
def can_operate(
self, inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"] # noqa: F821
):
if inp.in_generation:
return True
return False

def run(
self,
inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs],
inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"], # noqa: F821
inference_state: InferenceState,
**kwargs,
):
@@ -52,21 +51,26 @@ def run(
if isinstance(inp, NLEngineOutputs)
else inp.prompt_logits
)
kv_cache = inp.kv_cache if isinstance(inp, NLEngineOutputs) else None
kv_cache = inp.kv_cache

max_tokens = inference_state.current_state.get("max_tokens")
length_finish_reason = inference_state.current_state.get("length_finish_reason")
generated_tokens = inference_state.current_state.get("generated_tokens")
num_generated_tokens = len(generated_tokens)

token_generator = inference_state.current_state.get("token_generator")
token = token_generator.generate(logits=logits[0, -1, :])
finish_reason = None

callback = inference_state.current_state.get("callback")
stop = inference_state.current_state.get("stop")

if (
kv_cache is not None
and kv_cache.total_num_processed_tokens >= kv_cache.capacity
):
finish_reason = FinishReason.CAPACITY

callback = inference_state.current_state.get("callback")
stop = inference_state.current_state.get("stop")

if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
finish_reason = FinishReason.STOP

@@ -84,9 +88,11 @@
)
finish_reason = FinishReason.CALLBACK

max_tokens = inference_state.current_state.get("max_tokens")
if len(inference_state.current_state.get("generated_tokens")) + 1 >= max_tokens:
finish_reason = inference_state.current_state.get("length_finish_reason")
# Note: this is +1 as the inference state variable keeping track of all the
# generated tokens has not yet been updated with the most recently generated
# token from this operator
if num_generated_tokens + 1 == max_tokens:
finish_reason = length_finish_reason
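# Worked example (illustrative): with max_tokens == 3 and two tokens already stored
# in generated_tokens, the token produced by this call is the third and final one,
# so num_generated_tokens + 1 == 3 and the length finish reason is recorded on this pass.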

state_update = {
"token_generator": token_generator,
@@ -42,6 +42,12 @@ def can_operate(self, inp: Any):
kv_cache = inp.get("kv_cache")
tokens = inp.get("tokens")

if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
raise RuntimeError(
"Not enough kv_cache capacity to run generation. Please use a larger "
"sequence_length or a shorter prompt"
)

if len(tokens) < self.prompt_sequence_length:
return False

@@ -239,7 +239,6 @@ def __init__(
sequence_length=sequence_length,
prompt_sequence_length=prompt_sequence_length,
token_generator=token_generator,
process_output_operator=process_output,
)

# TODO: do we want to support lists for different engines?
@@ -286,7 +285,7 @@ def __init__(
"compile_logits",
"generate_new_token",
],
"prep_for_generation": "autoregressive_preprocess",
"prep_for_generation": "generate_new_token",
"generate_new_token": "compile_generated_tokens",
}
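# Explanatory note (not in the original source): prep_for_generation no longer emits
# the first generated token itself (PrepareGeneration's token_generator.generate call
# is removed later in this diff), so its route now points at generate_new_token
# rather than autoregressive_preprocess.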

@@ -19,6 +19,7 @@
from deepsparse.routers import GraphRouter
from deepsparse.schedulers import OperatorScheduler
from deepsparse.transformers.pipelines.text_generation import (
CompileGeneratedTokens,
CompileGenerations,
GenerateNewTokenOperator,
JoinOutput,
@@ -73,6 +74,7 @@ def __init__(
tokenizer=self.tokenizer, force_max_tokens=True
)
compile_generations = CompileGenerations()
compile_generated_tokens = CompileGeneratedTokens()
join_output = JoinOutput(tokenizer=self.tokenizer)
process_outputs = ProcessOutputs(tokenizer=self.tokenizer)

@@ -82,6 +84,7 @@
"engine_operator": engine_operator,
"prepare_generation": prepare_generation,
"generate_new_token": generate_new_token,
"compile_generated_tokens": compile_generated_tokens,
"compile_generations": compile_generations,
"join_output": join_output,
"process_outputs": process_outputs,
@@ -92,7 +95,8 @@
"SPLIT": "engine_operator",
"engine_operator": "prepare_generation",
"prepare_generation": "generate_new_token",
"generate_new_token": "compile_generations",
"generate_new_token": "compile_generated_tokens",
"compile_generated_tokens": "compile_generations",
"compile_generations": "JOIN",
"JOIN": "join_output",
"join_output": "process_outputs",
@@ -15,35 +15,38 @@
from typing import Any, Optional

import numpy
from pydantic import BaseModel, Field

from deepsparse.operators import Operator
from deepsparse.subgraph_execute import StreamingOutput
from deepsparse.transformers.pipelines.text_generation import TokenGeneratorOperator
from deepsparse.transformers.schemas.text_generation_schemas import (
FinishReason,
PromptLogitsNoKVCacheInference,
)
from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
from deepsparse.transformers.utils.helpers import set_generated_length
from deepsparse.utils import InferenceState


__all__ = ["PrepareGeneration"]
__all__ = ["PrepareGeneration", "PrepareForGenerationOutput"]


class PrepareForGenerationOutput(BaseModel):
prompt_logits: Any = Field(
description="A set of prompt logits generated during prefill"
)
kv_cache: Optional[Any] = Field(description="kv cache")
in_generation: Optional[bool] = Field(description="in_generation flag")
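
# Explanatory note: this schema replaces the deleted PromptLogitsNoKVCacheInference
# (removed further down in this diff). GenerateNewToken now consumes it directly,
# gating on `in_generation` and reading `kv_cache` (left as None in the
# non-kv-cache pipeline) and `prompt_logits`.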


class PrepareGeneration(Operator):
output_schema = PrepareForGenerationOutput

def __init__(
self,
token_generator: TokenGeneratorOperator,
prompt_sequence_length: int,
sequence_length: int,
process_output_operator: Optional[Operator] = None,
):
self.sequence_length = sequence_length
self.token_generator_creator = token_generator
self.prompt_sequence_length = prompt_sequence_length
# Needed for streaming as currently both setting up generation and generating
# Will split this up soon
self.process_output_operator = process_output_operator

def can_operate(self, inp: Any):
kv_cache = inp.get("kv_cache")
@@ -79,7 +82,6 @@ def run(
**inference_state.current_state,
)
token_generator = token_generator_creator_output.get("token_generator")
token_generator.generate(prompt_logits[0, -1, :])

max_tokens, length_finish_reason = set_generated_length(
max_length=generation_config.max_length,
@@ -93,43 +95,21 @@
state_update = {
"max_tokens": max_tokens,
"length_finish_reason": length_finish_reason,
"generated_tokens": [token_generator.tokens[-1]],
"generated_logits": [prompt_logits]
"generated_tokens": [],
"generated_logits": [prompt_logits[:, 0:-1, :]]
if include_prompt_logits
else [numpy.expand_dims(prompt_logits[:, -1, :], 0)],
else [],
"finished_reason": [],
"token_generator": token_generator,
}

if kv_cache is None:
output = PromptLogitsNoKVCacheInference(prompt_logits=prompt_logits)
output = {"prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0)}
else:
output = {
"tokens": token_generator.tokens,
"kv_cache": kv_cache,
"in_generation": True,
"prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0),
}
# TODO: maybe break this operator up since it is both generating and setting
# up values needed for generation? Holding off on this as this will change
# routes slighty and want to confirm wont break anything for non-kv cache
if inference_state.current_state.get("streaming") and max_tokens >= 1:
finished_reason = [length_finish_reason] if max_tokens == 1 else [None]

if self.process_output_operator is None:
raise ValueError(
"An operator must be provided to process outputs"
"while streaming."
)
data_to_yield = self.process_output_operator.run(
generated_tokens=state_update.get("generated_tokens"),
finished_reason=finished_reason,
inference_state=inference_state,
generated_logits=prompt_logits[0, -1, :],
)
output = StreamingOutput(
data_to_yield=self.process_output_operator.output_schema(
**data_to_yield
),
data_to_return=output,
)

return output, state_update
@@ -166,11 +166,3 @@ class TextGenerationOutput(BaseModel):
class Config:
arbitrary_types_allowed = True
extra = "allow"


class PromptLogitsNoKVCacheInference(BaseModel):
prompt_logits: Any = Field(
description="A set of prompt logits generated "
"during the inference pass with a "
"non-kv cache model"
)
4 changes: 3 additions & 1 deletion src/deepsparse/transformers/utils/helpers.py
@@ -104,8 +104,10 @@ def set_generated_length(
:param max_new_tokens: the max_new_tokens attribute, which may be provided
as part of the input during inference
"""
if max_length:
if max_length is not None:
# if max_length provided, use that to cap total tokens generated
if max_length == 0:
raise ValueError("max_length must be greater than 0")
max_tokens = max_length
finish_reason = finish_reason_choices.LENGTH
else:
52 changes: 52 additions & 0 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -143,6 +143,7 @@ def test_stop_inference_kv_cache_full(prompt):
expected_generated_tokens_length=max_new_tokens_plus_one,
expected_finished_reason="capacity",
)

"""
Check the following structure of the kv cache:
minus_one | full | plus_one | plus_two
@@ -152,6 +153,7 @@
[row B] | [row C] | [row D] | [row D]
... | ... | ... | ...
"""

# check for the "free" space in the kv cache
assert kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 0, :].sum() == 0
# check for the row A
@@ -282,3 +284,53 @@ def test_streaming_non_streaming_generate_same_tokens(pipeline, prompt):
tokens.append(g.generations[0].text)
output_2 = "".join(tokens)
assert output_1 == output_2


def test_edge_cases(pipeline, prompt):
# total length of the generated sequence is just 1 token; this should just use
# the last prompt logit
output = pipeline(prompt=prompt, max_length=1, output_scores=True)
assert len(output.generations[0].score) == 1

output = pipeline(
prompt=prompt, max_length=1, output_scores=True, include_prompt_logits=True
)
assert len(output.generations[0].score) == 11

# max_new_tokens == 0 and max_length == 1 should result in the same behaviour
# the generation is only dependent on the prompt logit, not any new generated logit
output = pipeline(prompt=prompt, max_new_tokens=0, output_scores=True)
assert len(output.generations[0].score) == 1

output = pipeline(
prompt=prompt, max_new_tokens=0, output_scores=True, include_prompt_logits=True
)
assert len(output.generations[0].score) == 11

# expect total scores/length of the generation to be 2: 1 for the token generated
# from the last prompt logit and the rest generated from the value provided
# using the max_new_tokens argument (which in this case is 1)
output = pipeline(prompt=prompt, max_new_tokens=1, output_scores=True)
assert len(output.generations[0].score) == 2

output = pipeline(
prompt=prompt, max_new_tokens=1, output_scores=True, include_prompt_logits=True
)
assert len(output.generations[0].score) == 12
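# The score counts above assume the shared `prompt` fixture tokenizes to 11 tokens:
# include_prompt_logits adds the 10 non-final prompt logits on top of the last prompt
# logit and any newly generated logits (1 + 10 = 11, 2 + 10 = 12).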

# don't support max_length == 0; raise a ValueError
with pytest.raises(ValueError):
pipeline(prompt=prompt, max_length=0)


def test_kv_cache_too_small_for_prefill(prompt):
for i in range(10):
prompt += prompt

pipeline = Pipeline.create(
task="text_generation",
model_path="hf:mgoin/TinyStories-1M-deepsparse",
sequence_length=25,
)
with pytest.raises(RuntimeError):
pipeline(prompt=prompt)
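
For reference, a minimal sketch of how the new prefill capacity guard surfaces to a pipeline user (model path and sequence length copied from the test above; the prompt itself is illustrative):

from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="text_generation",
    model_path="hf:mgoin/TinyStories-1M-deepsparse",
    sequence_length=25,  # too small to hold the prompt during prefill
)

# The capacity check added in this PR raises before any tokens are generated.
try:
    pipeline(prompt="a long prompt " * 100)
except RuntimeError as err:
    print(err)  # "Not enough kv_cache capacity to run generation. ..."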