@@ -51,6 +51,12 @@ def can_operate(self, inp: Any) -> bool:
         if inp.get("in_generation"):
             return True

+        if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
+            raise RuntimeError(
+                "Not enough kv_cache capacity to run generation. Please use a larger "
+                "sequence_length or a shorter prompt"
+            )
+
         remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens
         can_process = (
             remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length

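The guard above (which appears again in a second can_operate hunk further down) turns a kv-cache overflow into an explicit error. A minimal usage sketch, assuming the same Pipeline.create arguments as the new test_kv_cache_too_small_for_prefill test at the bottom of this diff; the import path and the prompt string are assumptions, not taken from the PR:

from deepsparse import Pipeline

# A sequence_length deliberately too small for the prompt: prefill cannot fit in the
# kv cache, so the pipeline now raises instead of silently truncating.
pipeline = Pipeline.create(
    task="text_generation",
    model_path="hf:mgoin/TinyStories-1M-deepsparse",
    sequence_length=25,
)

try:
    pipeline(prompt="Once upon a time there was a very, very long prompt. " * 20)
except RuntimeError as err:
    print(err)  # "Not enough kv_cache capacity to run generation. ..."
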
@@ -17,7 +17,6 @@
 from pydantic import BaseModel, Field

 from deepsparse.operators import Operator
-from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.utils import InferenceState


@@ -43,9 +42,6 @@ def run(self, inference_state: InferenceState, **kwargs):
         generated_logits = inference_state.current_state.get("generated_logits")
         finished_reason = inference_state.current_state.get("finished_reason")

-        if len(finished_reason) == 0:
-            finished_reason.append(FinishReason.LENGTH)
-
         generated_tokens = numpy.array([generated_tokens])
         generated_logits = numpy.concatenate(generated_logits, axis=1)
         return {

@@ -19,10 +19,7 @@
 from deepsparse.transformers.pipelines.text_generation.nl_engine_operator import (
     NLEngineOutputs,
 )
-from deepsparse.transformers.schemas.text_generation_schemas import (
-    FinishReason,
-    PromptLogitsNoKVCacheInference,
-)
+from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.utils import InferenceState


@@ -36,14 +33,16 @@ def __init__(
         self.force_max_tokens = force_max_tokens
         self.tokenizer = tokenizer

-    def can_operate(self, inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs]):
+    def can_operate(
+        self, inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"]  # noqa: F821
+    ):
         if inp.in_generation:
             return True
         return False

     def run(
         self,
-        inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs],
+        inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"],  # noqa: F821
         inference_state: InferenceState,
         **kwargs,
     ):
@@ -52,21 +51,26 @@ def run(
             if isinstance(inp, NLEngineOutputs)
             else inp.prompt_logits
         )
-        kv_cache = inp.kv_cache if isinstance(inp, NLEngineOutputs) else None
+        kv_cache = inp.kv_cache

+        max_tokens = inference_state.current_state.get("max_tokens")
+        length_finish_reason = inference_state.current_state.get("length_finish_reason")
+        generated_tokens = inference_state.current_state.get("generated_tokens")
+        num_generated_tokens = len(generated_tokens)
+
         token_generator = inference_state.current_state.get("token_generator")
         token = token_generator.generate(logits=logits[0, -1, :])
         finish_reason = None

+        callback = inference_state.current_state.get("callback")
+        stop = inference_state.current_state.get("stop")
+
         if (
             kv_cache is not None
             and kv_cache.total_num_processed_tokens >= kv_cache.capacity
         ):
             finish_reason = FinishReason.CAPACITY

-        callback = inference_state.current_state.get("callback")
-        stop = inference_state.current_state.get("stop")
-
         if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
             finish_reason = FinishReason.STOP

@@ -84,9 +88,11 @@ def run(
             )
             finish_reason = FinishReason.CALLBACK

-        max_tokens = inference_state.current_state.get("max_tokens")
-        if len(inference_state.current_state.get("generated_tokens")) + 1 >= max_tokens:
-            finish_reason = inference_state.current_state.get("length_finish_reason")
+        # Note: this is +1 as the inference state variable keeping track of all the
+        # generated tokens has not yet been updated with the most recently generated
+        # token from this operator
+        if num_generated_tokens + 1 == max_tokens:
+            finish_reason = length_finish_reason

         state_update = {
             "token_generator": token_generator,

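A tiny standalone illustration of the length check above; the values are made up and the string "length" stands in for the length_finish_reason pulled from the inference state. Because the state's generated_tokens list is updated only after this operator runs, the token produced by the current call is accounted for with the +1:

max_tokens = 3
generated_tokens = [101, 202]               # tokens already recorded in the inference state
num_generated_tokens = len(generated_tokens)

finish_reason = None
if num_generated_tokens + 1 == max_tokens:  # +1 for the token generated in this call
    finish_reason = "length"

print(finish_reason)                        # "length": generation stops at exactly max_tokens
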
@@ -42,6 +42,12 @@ def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
         tokens = inp.get("tokens")

+        if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
+            raise RuntimeError(
+                "Not enough kv_cache capacity to run generation. Please use a larger "
+                "sequence_length or a shorter prompt"
+            )
+
         if len(tokens) < self.prompt_sequence_length:
             return False

@@ -239,7 +239,6 @@ def __init__(
             sequence_length=sequence_length,
             prompt_sequence_length=prompt_sequence_length,
             token_generator=token_generator,
-            process_output_operator=process_output,
         )

         # TODO: do we want to support lists for different engines?
@@ -286,7 +285,7 @@ def __init__(
                 "compile_logits",
                 "generate_new_token",
             ],
-            "prep_for_generation": "autoregressive_preprocess",
+            "prep_for_generation": "generate_new_token",
             "generate_new_token": "compile_generated_tokens",
         }

@@ -19,6 +19,7 @@
 from deepsparse.routers import GraphRouter
 from deepsparse.schedulers import OperatorScheduler
 from deepsparse.transformers.pipelines.text_generation import (
+    CompileGeneratedTokens,
     CompileGenerations,
     GenerateNewTokenOperator,
     JoinOutput,
@@ -73,6 +74,7 @@ def __init__(
             tokenizer=self.tokenizer, force_max_tokens=True
         )
         compile_generations = CompileGenerations()
+        compile_generated_tokens = CompileGeneratedTokens()
         join_output = JoinOutput(tokenizer=self.tokenizer)
         process_outputs = ProcessOutputs(tokenizer=self.tokenizer)

@@ -82,6 +84,7 @@ def __init__(
             "engine_operator": engine_operator,
             "prepare_generation": prepare_generation,
             "generate_new_token": generate_new_token,
+            "compile_generated_tokens": compile_generated_tokens,
             "compile_generations": compile_generations,
             "join_output": join_output,
             "process_outputs": process_outputs,
@@ -92,7 +95,8 @@ def __init__(
             "SPLIT": "engine_operator",
             "engine_operator": "prepare_generation",
             "prepare_generation": "generate_new_token",
-            "generate_new_token": "compile_generations",
+            "generate_new_token": "compile_generated_tokens",
+            "compile_generated_tokens": "compile_generations",
             "compile_generations": "JOIN",
             "JOIN": "join_output",
             "join_output": "process_outputs",

@@ -15,35 +15,38 @@
 from typing import Any, Optional

 import numpy
+from pydantic import BaseModel, Field

 from deepsparse.operators import Operator
-from deepsparse.subgraph_execute import StreamingOutput
 from deepsparse.transformers.pipelines.text_generation import TokenGeneratorOperator
-from deepsparse.transformers.schemas.text_generation_schemas import (
-    FinishReason,
-    PromptLogitsNoKVCacheInference,
-)
+from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.transformers.utils.helpers import set_generated_length
 from deepsparse.utils import InferenceState


-__all__ = ["PrepareGeneration"]
+__all__ = ["PrepareGeneration", "PrepareForGenerationOutput"]


+class PrepareForGenerationOutput(BaseModel):
+    prompt_logits: Any = Field(
+        description="A set of prompt logits generated during prefill"
+    )
+    kv_cache: Optional[Any] = Field(description="kv cache")
+    in_generation: Optional[bool] = Field(description="in_generation flag")
+
+
 class PrepareGeneration(Operator):
+    output_schema = PrepareForGenerationOutput
+
     def __init__(
         self,
         token_generator: TokenGeneratorOperator,
         prompt_sequence_length: int,
         sequence_length: int,
-        process_output_operator: Optional[Operator] = None,
     ):
         self.sequence_length = sequence_length
         self.token_generator_creator = token_generator
         self.prompt_sequence_length = prompt_sequence_length
-        # Needed for streaming as currently both setting up generation and generating
-        # Will split this up soon
-        self.process_output_operator = process_output_operator

     def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
@@ -79,7 +82,6 @@ def run(
             **inference_state.current_state,
         )
         token_generator = token_generator_creator_output.get("token_generator")
-        token_generator.generate(prompt_logits[0, -1, :])

         max_tokens, length_finish_reason = set_generated_length(
             max_length=generation_config.max_length,
@@ -93,43 +95,21 @@
         state_update = {
             "max_tokens": max_tokens,
             "length_finish_reason": length_finish_reason,
-            "generated_tokens": [token_generator.tokens[-1]],
-            "generated_logits": [prompt_logits]
+            "generated_tokens": [],
+            "generated_logits": [prompt_logits[:, 0:-1, :]]
             if include_prompt_logits
-            else [numpy.expand_dims(prompt_logits[:, -1, :], 0)],
+            else [],
             "finished_reason": [],
             "token_generator": token_generator,
         }

         if kv_cache is None:
-            output = PromptLogitsNoKVCacheInference(prompt_logits=prompt_logits)
+            output = {"prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0)}
         else:
             output = {
                 "tokens": token_generator.tokens,
                 "kv_cache": kv_cache,
                 "in_generation": True,
+                "prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0),
             }
-            # TODO: maybe break this operator up since it is both generating and setting
-            # up values needed for generation? Holding off on this as this will change
-            # routes slighty and want to confirm wont break anything for non-kv cache
-            if inference_state.current_state.get("streaming") and max_tokens >= 1:
-                finished_reason = [length_finish_reason] if max_tokens == 1 else [None]
-
-                if self.process_output_operator is None:
-                    raise ValueError(
-                        "An operator must be provided to process outputs"
-                        "while streaming."
-                    )
-                data_to_yield = self.process_output_operator.run(
-                    generated_tokens=state_update.get("generated_tokens"),
-                    finished_reason=finished_reason,
-                    inference_state=inference_state,
-                    generated_logits=prompt_logits[0, -1, :],
-                )
-                output = StreamingOutput(
-                    data_to_yield=self.process_output_operator.output_schema(
-                        **data_to_yield
-                    ),
-                    data_to_return=output,
-                )

         return output, state_update

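A small shape sketch of the new prompt-logit bookkeeping above; the batch size, prompt length, and vocabulary size are made up. The "all but last" slice kept under generated_logits and the re-expanded last position forwarded as prompt_logits are both 3-D, which appears to be what lets them line up with the downstream concatenation along axis=1 (see the numpy.concatenate call in the compile-generations hunk near the top of this diff):

import numpy

prompt_logits = numpy.random.rand(1, 5, 7)   # (batch, prompt_len, vocab), illustrative only

kept = prompt_logits[:, 0:-1, :]                           # (1, 4, 7)
forwarded = numpy.expand_dims(prompt_logits[:, -1, :], 0)  # (1, 1, 7)

print(kept.shape, forwarded.shape)
print(numpy.concatenate([kept, forwarded], axis=1).shape)  # (1, 5, 7)
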
@@ -166,11 +166,3 @@ class TextGenerationOutput(BaseModel):
     class Config:
         arbitrary_types_allowed = True
         extra = "allow"
-
-
-class PromptLogitsNoKVCacheInference(BaseModel):
-    prompt_logits: Any = Field(
-        description="A set of prompt logits generated "
-        "during the inference pass with a "
-        "non-kv cache model"
-    )

src/deepsparse/transformers/utils/helpers.py (3 additions, 1 deletion)

@@ -104,8 +104,10 @@ def set_generated_length(
     :param max_new_tokens: the max_new_tokens attribute, which may be provided
         as part of the input during inference
     """
-    if max_length:
+    if max_length is not None:
         # if max_length provided, use that to cap total tokens generated
+        if max_length == 0:
+            raise ValueError("max_length must be greater than 0")
         max_tokens = max_length
         finish_reason = finish_reason_choices.LENGTH
     else:

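A short standalone sketch of why the truthiness check changes (this is not the library helper, just its branching): with "if max_length:", a caller passing max_length=0 silently fell through to the max_new_tokens branch, whereas the new check treats 0 as an error, matching the new test_edge_cases expectation below:

def resolve_max_tokens(max_length, max_new_tokens):
    # simplified stand-in for set_generated_length, which also returns a finish reason
    if max_length is not None:
        if max_length == 0:
            raise ValueError("max_length must be greater than 0")
        return max_length
    return max_new_tokens


print(resolve_max_tokens(max_length=3, max_new_tokens=5))     # 3
print(resolve_max_tokens(max_length=None, max_new_tokens=5))  # 5
# resolve_max_tokens(max_length=0, max_new_tokens=5)          # raises ValueError
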
tests/deepsparse/transformers/pipelines/test_text_generation.py (52 additions, 0 deletions)

@@ -143,6 +143,7 @@ def test_stop_inference_kv_cache_full(prompt):
         expected_generated_tokens_length=max_new_tokens_plus_one,
         expected_finished_reason="capacity",
     )
+
     """
     Check the following structure ok the kv cache:
     minus_one | full | plus_one | plus_two
@@ -152,6 +153,7 @@
     [row B] | [row C] | [row D] | [row D]
     ... | ... | ... | ...
     """
+
     # check for the "free" space in the kv cache
     assert kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 0, :].sum() == 0
     # check for the row A
@@ -282,3 +284,53 @@ def test_streaming_non_streaming_generate_same_tokens(pipeline, prompt):
         tokens.append(g.generations[0].text)
     output_2 = "".join(tokens)
     assert output_1 == output_2
+
+
+def test_edge_cases(pipeline, prompt):
+    # total length of the generated sequence is just 1 token; this should just use
+    # the last prompt logit
+    output = pipeline(prompt=prompt, max_length=1, output_scores=True)
+    assert len(output.generations[0].score) == 1
+
+    output = pipeline(
+        prompt=prompt, max_length=1, output_scores=True, include_prompt_logits=True
+    )
+    assert len(output.generations[0].score) == 11
+
+    # max_new_tokens == 0 and max_length == 1 should result in the same behaviour
+    # the generation is only dependent on the prompt logit, not any new generated logit
+    output = pipeline(prompt=prompt, max_new_tokens=0, output_scores=True)
+    assert len(output.generations[0].score) == 1
+
+    output = pipeline(
+        prompt=prompt, max_new_tokens=0, output_scores=True, include_prompt_logits=True
+    )
+    assert len(output.generations[0].score) == 11
+
+    # expect total scores/length of the generation to be 2: 1 for the token generated
+    # from the last prompt logit and the rest generated from the value provided
+    # using the max_new_tokens argument (which in this case is 1)
+    output = pipeline(prompt=prompt, max_new_tokens=1, output_scores=True)
+    assert len(output.generations[0].score) == 2
+
+    output = pipeline(
+        prompt=prompt, max_new_tokens=1, output_scores=True, include_prompt_logits=True
+    )
+    assert len(output.generations[0].score) == 12
+
+    # dont support max_length == 0; raise value error
+    with pytest.raises(ValueError):
+        pipeline(prompt=prompt, max_length=0)
+
+
+def test_kv_cache_too_small_for_prefill(prompt):
+    for i in range(10):
+        prompt += prompt
+
+    pipeline = Pipeline.create(
+        task="text_generation",
+        model_path="hf:mgoin/TinyStories-1M-deepsparse",
+        sequence_length=25,
+    )
+    with pytest.raises(RuntimeError):
+        pipeline(prompt=prompt)