diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index d5fc5ed438..34591d8a64 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -30,11 +30,7 @@ from onnx import ModelProto from deepsparse.log import get_main_logger -from deepsparse.utils.onnx import ( - _MODEL_DIR_ONNX_NAME, - model_to_path, - truncate_onnx_model, -) +from deepsparse.utils.onnx import MODEL_ONNX_NAME, model_to_path, truncate_onnx_model from sparsezoo.utils import save_onnx @@ -55,6 +51,7 @@ def setup_transformers_pipeline( sequence_length: int, tokenizer_padding_side: str = "left", engine_kwargs: Optional[Dict] = None, + onnx_model_name: Optional[str] = None, ) -> Tuple[ str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer, Dict[str, Any] ]: @@ -65,10 +62,14 @@ def setup_transformers_pipeline( :param sequence_length: The sequence length to use for the model :param tokenizer_padding_side: The side to pad on for the tokenizer, either "left" or "right" + :param onnx_model_name: The name of the onnx model to be loaded. + If not specified, defaults are used (see fetch_onnx_file_path) :param engine_kwargs: The kwargs to pass to the engine :return The model path, config, tokenizer, and engine kwargs """ - model_path, config, tokenizer = fetch_onnx_file_path(model_path, sequence_length) + model_path, config, tokenizer = fetch_onnx_file_path( + model_path, sequence_length, onnx_model_name + ) tokenizer.padding_side = tokenizer_padding_side if not tokenizer.pad_token: @@ -89,6 +90,7 @@ def setup_transformers_pipeline( def fetch_onnx_file_path( model_path: str, sequence_length: int, + onnx_model_name: Optional[str] = None, task: Optional[str] = None, ) -> Tuple[str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer]: """ @@ -97,9 +99,13 @@ def fetch_onnx_file_path( derived from the `model_path` provided. :param model_path: path to the model to be parsed :param sequence_length: maximum sequence length of the model + :param onnx_model_name: optionally, the precise name of the ONNX model + of interest may be specified. If not specified, the default ONNX model + name will be used (refer to `get_deployment_path` for details) + :param task: task to use for the config. Defaults to None :return: file path to the processed ONNX file for the engine to compile """ - deployment_path, onnx_path = get_deployment_path(model_path) + deployment_path, onnx_path = get_deployment_path(model_path, onnx_model_name) hf_logger = logging.getLogger("transformers") hf_logger_level = hf_logger.level @@ -126,7 +132,9 @@ def fetch_onnx_file_path( return onnx_path, config, tokenizer -def get_deployment_path(model_path: str) -> Tuple[str, str]: +def get_deployment_path( + model_path: str, onnx_model_name: Optional[str] = None +) -> Tuple[str, str]: """ Returns the path to the deployment directory for the given model path and the path to the mandatory @@ -135,9 +143,13 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: for running the transformers model in the deepsparse pipeline :param model_path: path to model directory, sparsezoo stub, or ONNX file + :param onnx_model_name: name of the ONNX file to look for in the deployment + directory. 
Defaults to MODEL_ONNX_NAME :return: path to the deployment directory and path to the ONNX file inside the deployment directory """ + onnx_model_name = onnx_model_name or MODEL_ONNX_NAME + if os.path.isfile(model_path): # return the parent directory of the ONNX file return os.path.dirname(model_path), model_path @@ -145,17 +157,19 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: if os.path.isdir(model_path): model_files = os.listdir(model_path) - if _MODEL_DIR_ONNX_NAME not in model_files: + if onnx_model_name not in model_files: raise ValueError( - f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory " + f"{onnx_model_name} not found in transformers model directory " f"{model_path}. Be sure that an export of the model is written to " - f"{os.path.join(model_path, _MODEL_DIR_ONNX_NAME)}" + f"{os.path.join(model_path, onnx_model_name)}" ) - return model_path, os.path.join(model_path, _MODEL_DIR_ONNX_NAME) + return model_path, os.path.join(model_path, onnx_model_name) elif model_path.startswith("zoo:") or model_path.startswith("hf:"): onnx_model_path = model_to_path(model_path) - return os.path.dirname(onnx_model_path), onnx_model_path + return os.path.dirname(onnx_model_path), onnx_model_path.replace( + MODEL_ONNX_NAME, onnx_model_name + ) else: raise ValueError( f"model_path {model_path} is not a valid file, directory, or zoo stub" diff --git a/src/deepsparse/transformers/pipelines/text_generation/__init__.py b/src/deepsparse/transformers/pipelines/text_generation/__init__.py index bbae1278e5..7914b0b361 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/__init__.py +++ b/src/deepsparse/transformers/pipelines/text_generation/__init__.py @@ -21,6 +21,7 @@ from .kv_cache_operator import * from .multi_engine_prefill_operator import * from .nl_engine_operator import * +from .nl_engine_operator_no_kv_cache import * from .parse_inputs import * from .prep_for_prefill import * from .process_inputs import * @@ -31,3 +32,4 @@ from .prep_for_generation import * # isort:skip from .pipeline import * # isort:skip +from .pipeline_no_kv_cache import * # isort:skip diff --git a/src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py b/src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py index 830a3e20bd..471a3d8dd2 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py +++ b/src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py @@ -19,7 +19,10 @@ from deepsparse.transformers.pipelines.text_generation.nl_engine_operator import ( NLEngineOutputs, ) -from deepsparse.transformers.schemas.text_generation_schemas import FinishReason +from deepsparse.transformers.schemas.text_generation_schemas import ( + FinishReason, + PromptLogitsNoKVCacheInference, +) from deepsparse.utils import InferenceState @@ -33,14 +36,23 @@ def __init__( self.force_max_tokens = force_max_tokens self.tokenizer = tokenizer - def can_operate(self, inp: NLEngineOutputs): + def can_operate(self, inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs]): if inp.in_generation: return True return False - def run(self, inp: NLEngineOutputs, inference_state: InferenceState, **kwargs): - logits = inp.engine_outputs - kv_cache = inp.kv_cache + def run( + self, + inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs], + inference_state: InferenceState, + **kwargs, + ): + logits = ( + inp.engine_outputs + if isinstance(inp, NLEngineOutputs) + else inp.prompt_logits + ) + kv_cache = inp.kv_cache if 
isinstance(inp, NLEngineOutputs) else None token_generator = inference_state.current_state.get("token_generator") token = token_generator.generate(logits=logits[0, -1, :]) @@ -49,7 +61,10 @@ def run(self, inp: NLEngineOutputs, inference_state: InferenceState, **kwargs): callback = inference_state.current_state.get("callback") stop = inference_state.current_state.get("stop") - if kv_cache.total_num_processed_tokens >= kv_cache.capacity: + if ( + kv_cache is not None + and kv_cache.total_num_processed_tokens >= kv_cache.capacity + ): finish_reason = FinishReason.CAPACITY if token == self.tokenizer.eos_token_id and not self.force_max_tokens: diff --git a/src/deepsparse/transformers/pipelines/text_generation/join_output.py b/src/deepsparse/transformers/pipelines/text_generation/join_output.py index b8176c19db..ecfdc4b30e 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/join_output.py +++ b/src/deepsparse/transformers/pipelines/text_generation/join_output.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Dict, List, Tuple import numpy @@ -34,7 +34,8 @@ class JoinOutput(Operator): def __init__(self, tokenizer): self.tokenizer = tokenizer - def run(self, inp: List[CompileGenerationsOutput], **kwargs): + def run(self, inp: Tuple[List[CompileGenerationsOutput], Dict], **kwargs): + batch_outputs = [x for x in inp[0]] generated_tokens = [x.generated_tokens for x in batch_outputs] generated_logits = [x.generated_logits for x in batch_outputs] diff --git a/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py b/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py index ece095d8d3..f78d4306d1 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py +++ b/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py @@ -32,7 +32,7 @@ ) -__all__ = ["NLEngineOperator", "NLEngineInputs"] +__all__ = ["NLEngineOperator", "NLEngineInputs", "NLEngineOutputs"] class NLEngineInputs(BaseModel): diff --git a/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator_no_kv_cache.py b/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator_no_kv_cache.py new file mode 100644 index 0000000000..c6ae6c51f3 --- /dev/null +++ b/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator_no_kv_cache.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any + +import numpy +from pydantic import BaseModel + +from deepsparse.operators.engine_operator import EngineOperator, EngineOperatorInputs +from deepsparse.transformers.helpers import overwrite_transformer_onnx_model_inputs + + +__all__ = [ + "NLEngineOperatorNoCache", + "NLEngineInputsNoCache", +] + + +class NLEngineInputsNoCache(BaseModel): + input_ids: Any + attention_mask: Any + + +class NLEngineOperatorNoCache(EngineOperator): + """ + Operator the Natural Language Engine, that operates without + KV Cache. This means that this operator merely maps input_ids + and attention_mask to logits + """ + + input_schema = NLEngineInputsNoCache + output_schema = None + + def __init__(self, sequence_length: int, **kwargs): + overwrite_transformer_onnx_model_inputs( + path=kwargs.get("model_path"), + batch_size=kwargs.get("batch_size", 1), + max_length=sequence_length, + ) + super().__init__(**kwargs) + + def run(self, inp: NLEngineInputsNoCache, **kwargs) -> Any: + engine_inputs = [inp.input_ids, inp.attention_mask] + logits = ( + super() + .run(EngineOperatorInputs(engine_inputs=engine_inputs), **kwargs) + .get("engine_outputs") + ) + + # By default, the engine outputs logits for all tokens in the sequence. + # Let's filter out the logits for the padding tokens. + logits = numpy.compress(inp.attention_mask.flatten(), logits[0], axis=1) + + return {"logits": [logits], "kv_cache": None, "tokens": None}, { + "prompt_logits": [logits] + } diff --git a/src/deepsparse/transformers/pipelines/text_generation/pipeline_no_kv_cache.py b/src/deepsparse/transformers/pipelines/text_generation/pipeline_no_kv_cache.py new file mode 100644 index 0000000000..ffa8eeebac --- /dev/null +++ b/src/deepsparse/transformers/pipelines/text_generation/pipeline_no_kv_cache.py @@ -0,0 +1,153 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Dict, Optional + +from deepsparse.pipeline import Pipeline +from deepsparse.routers import GraphRouter +from deepsparse.schedulers import OperatorScheduler +from deepsparse.transformers.helpers import setup_transformers_pipeline +from deepsparse.transformers.pipelines.text_generation import ( + CompileGenerations, + GenerateNewTokenOperator, + JoinOutput, + NLEngineOperatorNoCache, + ParseTextGenerationInputs, + PrepareGeneration, + ProcessInputsTextGeneration, + ProcessOutputs, + TokenGeneratorOperator, +) +from deepsparse.transformers.utils.helpers import process_generation_config +from deepsparse.utils import split_engine_inputs +from deepsparse.utils.onnx import default_cached_outputs + + +_LOGGER = logging.getLogger(__name__) + + +class TextGenerationPipelineNoCache(Pipeline): + def __init__( + self, + model_path: str, + sequence_length: int = 1024, + onnx_model_name: Optional[str] = None, + generation_config=None, + engine_kwargs: Optional[Dict] = None, + **kwargs, + ): + + ( + self.model_path, + self.config, + self.tokenizer, + engine_kwargs, + ) = setup_transformers_pipeline( + model_path, + sequence_length, + tokenizer_padding_side="right", + onnx_model_name=onnx_model_name, + engine_kwargs=engine_kwargs, + ) + self.verify_no_kv_cache_present() + + token_generator = TokenGeneratorOperator() + + parse_inputs = ParseTextGenerationInputs() + + process_inputs = ProcessInputsTextGeneration( + generation_config=process_generation_config(generation_config), + sequence_length=sequence_length, + tokenizer=self.tokenizer, + ) + engine_operator = NLEngineOperatorNoCache( + sequence_length=sequence_length, + **engine_kwargs, + ) + prepare_generation = PrepareGeneration( + sequence_length=sequence_length, + prompt_sequence_length=1, + token_generator=token_generator, + ) + generate_new_token = GenerateNewTokenOperator( + tokenizer=self.tokenizer, force_max_tokens=True + ) + compile_generations = CompileGenerations() + join_output = JoinOutput(tokenizer=self.tokenizer) + process_outputs = ProcessOutputs(tokenizer=self.tokenizer) + + ops = { + "parse_inputs": parse_inputs, + "process_input": process_inputs, + "engine_operator": engine_operator, + "prepare_generation": prepare_generation, + "generate_new_token": generate_new_token, + "compile_generations": compile_generations, + "join_output": join_output, + "process_outputs": process_outputs, + } + routes = { + "parse_inputs": "process_input", + "process_input": "SPLIT", + "SPLIT": "engine_operator", + "engine_operator": "prepare_generation", + "prepare_generation": "generate_new_token", + "generate_new_token": "compile_generations", + "compile_generations": "JOIN", + "JOIN": "join_output", + "join_output": "process_outputs", + "process_outputs": "STOP", + } + + # TODO: Using the GraphRouter, but should use + # LinearRouter with appropriate split/join support + router = GraphRouter( + end_route="STOP", start_route="process_input", route=routes + ) + scheduler = [OperatorScheduler()] + super().__init__( + ops=ops, + router=router, + schedulers=scheduler, + ) + + def run(self, *args, **kwargs): + # we need to set the fixed_sequences_length flag to True + # for the non-kv cache pipeline + kwargs.update(dict(fixed_sequences_length=True, max_new_tokens=1)) + return super().run(*args, **kwargs) + + def condense_inputs(self, *args, **kwargs): + return args[0], kwargs + + def expand_inputs(self, items, batch_size): + items = [items.get(key) for key in items.keys()] + out, orig_batch_size = split_engine_inputs(items, 
batch_size) + combined_batches = [{"input_ids": b[0], "attention_mask": b[1]} for b in out] + return combined_batches, orig_batch_size + + def verify_no_kv_cache_present(self) -> bool: + """ + Verifies that the ONNX model does not have + KV cache inputs/outputs present. + :return: True if compatible, False otherwise + """ + is_kv_cache_present = any(default_cached_outputs(self.model_path)) + if is_kv_cache_present: + raise ValueError( + f"The model: {self.model_path} has KV cache inputs/outputs present. " + "Please use the TextGenerationPipeline instead." + ) + return not is_kv_cache_present diff --git a/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py b/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py index 0ac010aedf..572840d13e 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py @@ -18,7 +18,10 @@ from deepsparse.operators import Operator from deepsparse.transformers.pipelines.text_generation import TokenGeneratorOperator -from deepsparse.transformers.schemas.text_generation_schemas import FinishReason +from deepsparse.transformers.schemas.text_generation_schemas import ( + FinishReason, + PromptLogitsNoKVCacheInference, +) from deepsparse.transformers.utils.helpers import set_generated_length from deepsparse.utils import InferenceState @@ -41,10 +44,11 @@ def can_operate(self, inp: Any): kv_cache = inp.get("kv_cache") tokens = inp.get("tokens") - # If the number of prompt tokens is greater than what we've processed, - # don't start generation. Should be equal when started as all prompt logits - # should be accounted for and we should have updated the kv_cache for the single - # token engine. + # If the number of prompt tokens is greater + # than what we've processed, don't start generation. + # Should be equal when started as all prompt logits + # should be accounted for, and we should have updated + # the kv_cache for the single token engine. if len(tokens) == kv_cache.total_num_processed_tokens: return True return False @@ -90,9 +94,12 @@ def run( "finished_reason": [], "token_generator": token_generator, } - output = { - "tokens": token_generator.tokens, - "kv_cache": kv_cache, - "in_generation": True, - } + if kv_cache is None: + output = PromptLogitsNoKVCacheInference(prompt_logits=prompt_logits) + else: + output = { + "tokens": token_generator.tokens, + "kv_cache": kv_cache, + "in_generation": True, + } return output, state_update diff --git a/src/deepsparse/transformers/pipelines/text_generation/process_inputs.py b/src/deepsparse/transformers/pipelines/text_generation/process_inputs.py index 77a30ebe32..37b1bd29d7 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/process_inputs.py +++ b/src/deepsparse/transformers/pipelines/text_generation/process_inputs.py @@ -36,8 +36,8 @@ class ProcessInputsTextGeneration(Operator): """ Input processing operator. Responsible for tokenizing the input, handling the generation_config (if provided), updating the inference_state for later use, - and returning the tokens for prompt inferece. The expected input is defined by - the input_schema, which for this operator is TextGeneratioInput. + and returning the tokens for prompt inference. The expected input is defined by + the input_schema, which for this operator is TextGenerationInput. 
""" input_schema = TextGenerationInput diff --git a/src/deepsparse/transformers/schemas/text_generation_schemas.py b/src/deepsparse/transformers/schemas/text_generation_schemas.py index c3d9e28229..7c08aa8d80 100644 --- a/src/deepsparse/transformers/schemas/text_generation_schemas.py +++ b/src/deepsparse/transformers/schemas/text_generation_schemas.py @@ -165,3 +165,11 @@ class TextGenerationOutput(BaseModel): class Config: arbitrary_types_allowed = True extra = "allow" + + +class PromptLogitsNoKVCacheInference(BaseModel): + prompt_logits: Any = Field( + description="A set of prompt logits generated " + "during the inference pass with a " + "non-kv cache model" + ) diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py index ae0913ffd7..e4b41f3286 100644 --- a/src/deepsparse/utils/onnx.py +++ b/src/deepsparse/utils/onnx.py @@ -56,12 +56,12 @@ "has_model_kv_cache", "CACHE_INPUT_PREFIX", "CACHE_OUTPUT_PREFIX", - "_MODEL_DIR_ONNX_NAME", + "MODEL_ONNX_NAME", ] _LOGGER = logging.getLogger(__name__) -_MODEL_DIR_ONNX_NAME = "model.onnx" +MODEL_ONNX_NAME = "model.onnx" CACHE_INPUT_PREFIX = "past_key_values" CACHE_OUTPUT_PREFIX = "present" @@ -132,7 +132,7 @@ def model_to_path(model: Union[str, Model, File]) -> str: model.deployment.path # default to the main onnx file for the model - model = model.deployment.get_file(_MODEL_DIR_ONNX_NAME).path + model = model.deployment.get_file(MODEL_ONNX_NAME).path elif File is not object and isinstance(model, File): # get the downloaded_path -- will auto download if not on local system @@ -143,10 +143,10 @@ def model_to_path(model: Union[str, Model, File]) -> str: from huggingface_hub import snapshot_download deployment_path = snapshot_download(repo_id=model.replace("hf:", "", 1)) - onnx_path = os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME) + onnx_path = os.path.join(deployment_path, MODEL_ONNX_NAME) if not os.path.isfile(onnx_path): raise ValueError( - f"Could not find the ONNX model file '{_MODEL_DIR_ONNX_NAME}' in the " + f"Could not find the ONNX model file '{MODEL_ONNX_NAME}' in the " f"Hugging Face Hub repository located at {deployment_path}. Please " f"ensure the model has been correctly exported to ONNX format and " f"exists in the repository." 
@@ -161,7 +161,7 @@ def model_to_path(model: Union[str, Model, File]) -> str: model_path = Path(model) if model_path.is_dir(): - return str(model_path / _MODEL_DIR_ONNX_NAME) + return str(model_path / MODEL_ONNX_NAME) return model diff --git a/tests/deepsparse/transformers/text_generation/integration_tests/configs/codegen.yaml b/tests/deepsparse/transformers/text_generation/integration_tests/configs/codegen.yaml index 904358b55f..9ec212a6cc 100644 --- a/tests/deepsparse/transformers/text_generation/integration_tests/configs/codegen.yaml +++ b/tests/deepsparse/transformers/text_generation/integration_tests/configs/codegen.yaml @@ -1,6 +1,7 @@ cadence: "nightly" model_path: "zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none" torch_model_name: "salesforce/codegen-350m-mono" +model_name_no_kv_cache: None prompt: "\ndef Fibonacci(n):\n # Check if input is 0 then it will\n # print incorrect input" precision: 0.0001 internal_kv_cache: [True, False] \ No newline at end of file diff --git a/tests/deepsparse/transformers/text_generation/integration_tests/configs/gpt_neo.yaml b/tests/deepsparse/transformers/text_generation/integration_tests/configs/gpt_neo.yaml index b422efc831..71c57e1f97 100644 --- a/tests/deepsparse/transformers/text_generation/integration_tests/configs/gpt_neo.yaml +++ b/tests/deepsparse/transformers/text_generation/integration_tests/configs/gpt_neo.yaml @@ -1,6 +1,7 @@ cadence: "commit" model_path: "hf:mgoin/TinyStories-1M-ds" torch_model_name: "roneneldan/TinyStories-1M" +model_name_no_kv_cache: "model-orig.onnx" prompt: "Didn't know what time it was, the lights were low\n I leaned back on my radio" precision: 0.001 internal_kv_cache: [True, False] \ No newline at end of file diff --git a/tests/deepsparse/transformers/text_generation/integration_tests/configs/opt.yaml b/tests/deepsparse/transformers/text_generation/integration_tests/configs/opt.yaml index ff2350dbe7..216d4c03ca 100644 --- a/tests/deepsparse/transformers/text_generation/integration_tests/configs/opt.yaml +++ b/tests/deepsparse/transformers/text_generation/integration_tests/configs/opt.yaml @@ -1,6 +1,7 @@ cadence: "nightly" model_path: "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/base-none" torch_model_name: "facebook/opt-1.3b" +model_name_no_kv_cache: None prompt: "Didn't know what time it was, the lights were low\n I leaned back on my radio" precision: 0.0001 internal_kv_cache: [True, False] \ No newline at end of file diff --git a/tests/deepsparse/transformers/text_generation/integration_tests/test_llms.py b/tests/deepsparse/transformers/text_generation/integration_tests/test_llms.py index 45ba1135b7..82a81d611c 100644 --- a/tests/deepsparse/transformers/text_generation/integration_tests/test_llms.py +++ b/tests/deepsparse/transformers/text_generation/integration_tests/test_llms.py @@ -15,7 +15,7 @@ This test suite consumes config files to test the text generation pipeline for various scenarios. -A sample config file is a yaml that r_equires the following fields: +A sample config file is a yaml that requires the following fields: cadence: The cadence of the tests. The available options are: "nightly", "weekly" and "commit". By default, only the tests that have cadence "commit" will be run @@ -23,6 +23,8 @@ list of strings. 
model_path: The path to the model to be tested (sparsezoo stub/hf model path/local_path) + model_name_no_kv_cache: The name of the onnx model without + the KV cache support torch_model_name: The name of the torch model (to generate ground truth info) prompt: The prompt to use for testing @@ -33,17 +35,18 @@ values: [True], [False] or [True, False] (to test both external and internal KV cache management) """ -import os from typing import List, Tuple import numpy import pytest -from deepsparse import Pipeline -from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline +from deepsparse.pipeline import Pipeline +from deepsparse.transformers.pipelines.text_generation import ( + TextGenerationPipeline, + TextGenerationPipelineNoCache, +) from deepsparse.transformers.schemas.text_generation_schemas import TextGenerationOutput -from sparsezoo import Model -from tests.deepsparse.transformers.pipelines.legacy.integration_tests.helpers import ( +from tests.deepsparse.transformers.text_generation.integration_tests.helpers import ( TorchGroundTruthSource, parse_params, validate_internal_kv_cache, @@ -71,7 +74,7 @@ class TestsIntegrationLLMsPipelines: the text generation pipeline. """ - def get_pipeline(self, **kwargs) -> Pipeline: + def get_pipeline(self, kv_cache_support=True, **kwargs) -> Pipeline: """ If no kwargs provided, returns the cached "default" pipeline that is used for most of the tests. @@ -84,9 +87,17 @@ def get_pipeline(self, **kwargs) -> Pipeline: "default" pipeline is returned) :return: the appropriate pipeline """ + # TODO: This if statement should disappear once + # the TextGenerationPipeline contains the + # non-kv-cache version of the pipeline + text_generation_pipeline_class = ( + TextGenerationPipeline + if kv_cache_support + else TextGenerationPipelineNoCache + ) if not kwargs: if self.default_pipeline is None: - self.default_pipeline = TextGenerationPipeline( + self.default_pipeline = text_generation_pipeline_class( **self.default_pipeline_kwargs ) return self.default_pipeline @@ -94,7 +105,7 @@ def get_pipeline(self, **kwargs) -> Pipeline: # return a pipeline with the updated default kwargs updated_kwargs = self.default_pipeline_kwargs.copy() updated_kwargs.update(kwargs) - return TextGenerationPipeline(**updated_kwargs) + return text_generation_pipeline_class(**updated_kwargs) @pytest.fixture def setup(self, params_dict, max_new_tokens, internal_kv_cache): @@ -164,9 +175,7 @@ def test_ort_multi_token_prefill(self, setup): pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." 
) - pipeline = self.get_pipeline( - engine_type="onnxruntime", - ) + pipeline = self.get_pipeline(engine_type="onnxruntime") output = pipeline( prompt=self.prompt, include_prompt_logits=True, @@ -229,42 +238,32 @@ def test_deepsparse_multi_token_prefill(self, setup): run_kv_cache_validation=not self.internal_kv_cache, ) - @pytest.mark.skip( - "This test is skipped because we do " - "not have support for non-kv-cache models yet" - ) def test_inference_no_kv_cache_deepsparse(self, setup): self._test_inference_no_kv_cache(engine_type="deepsparse") - @pytest.mark.skip( - "This test is skipped because we do " - "not have support for non-kv-cache models yet" - ) def test_inference_no_kv_cache_ort(self, setup): self._test_inference_no_kv_cache(engine_type="onnxruntime") def _test_inference_no_kv_cache(self, engine_type): - model_path_no_cache = self._get_model_path_no_cache() pipeline = self.get_pipeline( - model_path=model_path_no_cache, engine_type=engine_type - ) - assert not pipeline.cache_support_enabled, ( - "This pipeline test inference using non-kv cache " - "model and thus should not support kv cache" + onnx_model_name=self.model_name_no_kv_cache, + kv_cache_support=False, + engine_type=engine_type, ) output = pipeline( - self.prompt, max_length=1, output_scores=True, include_prompt_logits=True + prompt=[self.prompt, self.prompt], + include_prompt_logits=True, + generation_kwargs=dict(output_scores=True), ) - prompt_length = self.torch_ground_truth[1].shape[1] - # prompt logits + one logit for the new generated token - logits = output.generations[0].score[-(prompt_length + 1) :, :] - # compute ground truth logits analogously + + # logits -> prompt logits + one logit for the new generated token generated_logits, prompt_logits, *_ = self.torch_ground_truth logits_gt = numpy.concatenate( [prompt_logits[0], generated_logits[0, :1, :]], axis=0 ) - assert numpy.allclose(logits, logits_gt, atol=self.precision) + for gen in output.generations: + assert numpy.allclose(gen.score, logits_gt, atol=self.precision) def _test_output( self, @@ -320,51 +319,3 @@ def _test_kv_cache_state( assert numpy.allclose( x[:, :, -start_index:-end_index, :], y, atol=self.precision ) - - def _get_model_path_no_cache(self): - if not self.model_path.startswith("zoo:"): - pytest.skip("For this test, for now only the zoo model is supported") - model = Model(self.model_path) - # fetch the necessary file names for pipeline creation - required_file_names = [ - os.path.basename(file.name) for file in model.deployment.files - ] - training_directory = model.training - onnx_model_name_no_cache = [ - os.path.basename(file.name) - for file in model.training.files - if file.name.endswith(".onnx") - ][0] - - # check if 'training' exists, - # if not, download the files - if "training" not in os.listdir(model._path): - for filename in required_file_names: - # download the files to a training directory - if filename.endswith(".data"): - # data files are typically stored in a deployment directory - # download them to training - file = model.deployment.get_file(filename) - assert ( - file is not None - ), f"Unable to find file {filename} in model {model}" - file.name = file.name.replace("deployment", "training") - file.download() - continue - - if filename.endswith(".onnx"): - # instead of `model.onnx` the onnx_model_name_no_cache - # should be downloaded - filename = filename.replace("model.onnx", onnx_model_name_no_cache) - - file = training_directory.get_file(filename) - assert ( - file is not None - ), f"Unable to find file 
{filename} in model {model}" - file.download() - # rename the model file to `model.onnx` - os.rename( - os.path.join(training_directory.path, onnx_model_name_no_cache), - os.path.join(training_directory.path, "model.onnx"), - ) - return training_directory._path diff --git a/tests/deepsparse/transformers/text_generation/unit/text_generation/test_pipeline_no_kv_cache.py b/tests/deepsparse/transformers/text_generation/unit/text_generation/test_pipeline_no_kv_cache.py new file mode 100644 index 0000000000..de12d0e709 --- /dev/null +++ b/tests/deepsparse/transformers/text_generation/unit/text_generation/test_pipeline_no_kv_cache.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +from deepsparse.transformers.pipelines.text_generation.pipeline_no_kv_cache import ( + TextGenerationPipelineNoCache, +) + + +@pytest.mark.parametrize( + "onnx_model_name, raise_error", + [("model.onnx", True), (None, True), ("model-orig.onnx", False)], +) +def test_verify_no_kv_cache_present(model_attributes, onnx_model_name, raise_error): + _, model_path = model_attributes + # model_path points to .../directory/model.onnx + # we need to go up one level to .../directory + model_path = os.path.dirname(model_path) + + if raise_error: + with pytest.raises(ValueError): + if onnx_model_name is None: + TextGenerationPipelineNoCache(model_path=model_path) + else: + TextGenerationPipelineNoCache( + model_path=model_path, onnx_model_name=onnx_model_name + ) + return + else: + TextGenerationPipelineNoCache( + model_path=model_path, onnx_model_name=onnx_model_name + )