Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions src/deepsparse/v2/operators/engine_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from deepsparse import Context as EngineContext
from deepsparse import Engine, MultiModelEngine, Scheduler
from deepsparse.benchmark import ORTEngine
from deepsparse.utils import join_engine_outputs, model_to_path, split_engine_inputs
from deepsparse.utils import model_to_path
from deepsparse.v2.operators import Operator


Expand Down Expand Up @@ -145,18 +145,6 @@ def run(self, inp: EngineOperatorInputs, **kwargs) -> Dict:
# planned refactor
engine_outputs = inp.engine(inp.engine_inputs)
return {"engine_outputs": engine_outputs}
inp = inp.engine_inputs
batches, orig_batch_size = self.expand_inputs(engine_inputs=inp)
batches_outputs = list(map(self.engine, batches))
engine_outputs = self.condense_inputs(
batch_outputs=batches_outputs, orig_batch_size=orig_batch_size
)
return {"engine_outputs": engine_outputs}

def expand_inputs(self, **kwargs):
return split_engine_inputs(kwargs["engine_inputs"], self._batch_size)

def condense_inputs(self, **kwargs):
batch_outputs = kwargs["batch_outputs"]
orig_batch_size = kwargs["orig_batch_size"]
return join_engine_outputs(batch_outputs, orig_batch_size)
engine_outputs = self.engine(inp.engine_inputs)
return {"engine_outputs": engine_outputs}
13 changes: 0 additions & 13 deletions src/deepsparse/v2/operators/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __call__(
pipeline_state=pipeline_state,
**kwargs,
)

if self.has_output_schema():
return self.output_schema(**run_output)
return run_output
Expand All @@ -117,18 +116,6 @@ def can_operate(self, inp: Any) -> bool:
"""
return True

def expand_inputs(self, **kwargs):
"""
Generic function to handle expanding values.
"""
raise NotImplementedError

def condense_inputs(self, **kwargs):
"""
Generic function to handle condensing values.
"""
raise NotImplementedError

def yaml(self):
pass

Expand Down
153 changes: 125 additions & 28 deletions src/deepsparse/v2/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
# limitations under the License.


from typing import Dict, List, Union
import copy
from concurrent.futures import Future
from functools import partial
from typing import Any, Callable, Dict, List, Union

from deepsparse.v2.operators import Operator
from deepsparse.v2.routers import Router
Expand Down Expand Up @@ -56,9 +59,88 @@ def __init__(
self.pipeline_state = pipeline_state
self.validate()

# SchedulerGroup handles running all schedulers in order of priority
self._scheduler_group = SchedulerGroup(self.schedulers)

def _run_sequential(
self,
inp: Any,
inference_state: InferenceState,
pipeline_state: PipelineState,
start: str,
end: str,
):
next_step = start
while next_step != end:
outputs = self._run_next_step(
func=self.ops[next_step],
next_step=next_step,
input=inp,
pipeline_state=pipeline_state,
inference_state=inference_state,
)
next_step, operator_output, state_update = outputs
if state_update:
inference_state.update_state(state_update)
inp = operator_output
return inp

def _apply_split(self, inp: Any, inference_state: InferenceState):
"""
Split inputs using the pipeline's expand_inputs function. Inputs are split
into a batch size of one when a SPLIT_ROUTE node is found in a given pipeline's
provided router. The split batches are run asynchronously and then joined when
a JOIN_ROUTE node is found, using the pipeline's condense_inputs function.
"""

batches, orig_batch_size = self.expand_inputs(inp, 1)
run_with_state = partial(
self._run_sequential,
pipeline_state=self.pipeline_state,
start=self.router.route[self.router.SPLIT_ROUTE],
end=self.router.JOIN_ROUTE,
)
inference_state_list = [
copy.deepcopy(inference_state) for x in range(len(batches))
]
futures = self._scheduler_group.map(
batches,
inference_state_list,
func=run_with_state,
)
return self.condense_inputs([x.result() for x in futures])

def _run_next_step(
self,
*args,
func: Callable,
next_step: Union[str, int],
input: Any = None,
**kwargs,
):
"""
Generic function to run a given func, process the output and determine the next
step.
"""
if input:
operator_output = (
func(*args, **kwargs, **input)
if isinstance(input, dict)
else func(input, *args, **kwargs)
)
else:
operator_output = func(*args, **kwargs)

if isinstance(operator_output, Future):
operator_output = operator_output.result()

state_update = None
if isinstance(operator_output, tuple):
state_update = operator_output[-1]
operator_output = operator_output[0]

next_step = self.router.next(next_step, self.ops, operator_output)
return next_step, operator_output, state_update

def run(
self,
*args,
Expand All @@ -78,40 +160,34 @@ def run(
operator_output = None

while next_step != self.router.END_ROUTE:
# Either a dictionary key or valid index
operator = self.ops[next_step]
# NOTE: split_route should only appear after the start route node
if next_step == self.router.SPLIT_ROUTE:
operator_output = self._apply_split(operator_output, inference_state)
next_step = self.router.route[self.router.JOIN_ROUTE]

if next_step == self.router.START_ROUTE:
output_future = self._scheduler_group.submit(
outputs = self._run_next_step(
*args,
next_step=next_step,
func=self._scheduler_group.submit,
inference_state=inference_state,
operator=operator,
operator=self.ops[next_step],
pipeline_state=pipeline_state,
**kwargs,
)
else:
if isinstance(operator_output, dict):
output_future = self._scheduler_group.submit(
inference_state=inference_state,
operator=operator,
pipeline_state=pipeline_state,
**operator_output,
)
else:
output_future = self._scheduler_group.submit(
operator_output,
inference_state=inference_state,
pipeline_state=pipeline_state,
operator=operator,
)

operator_output = output_future.result()
if isinstance(operator_output, tuple):
state_update = operator_output[-1]
operator_output = operator_output[0]
inference_state.update_state(state_update)

next_step = self.router.next(next_step, self.ops, operator_output)
outputs = self._run_next_step(
func=self._scheduler_group.submit,
input=operator_output,
next_step=next_step,
inference_state=inference_state,
operator=self.ops[next_step],
pipeline_state=pipeline_state,
)

next_step, operator_output, state_update = outputs
if state_update:
inference_state.update_state(state_update)
return operator_output

def __call__(self, *args, **kwargs):
Expand All @@ -136,6 +212,27 @@ def __call__(self, *args, **kwargs):

return self.run(*args, **kwargs)

def expand_inputs(self, *args, **kwargs):
"""
Generic function to handle expanding values.
"""
raise NotImplementedError(
"This function should be implemented for any router with split or join"
"nodes. expand_inputs will be called prior to the split node (stored in "
"the router's SPLIT_ROUTE attribute), expanding outputs for each output "
"such that there is a batch size of one per thread."
)

def condense_inputs(self, *args, **kwargs):
"""
Generic function to handle condensing values.
"""
raise NotImplementedError(
"This function should be implemented for any router with split or join "
"nodes. condense_inputs will be called after the join node (stored in the "
"router's JOIN_ROUTE attribute), condensing outputs from multiple threads."
)

def validate(self):
"""
Validate that compatability of the router and operators provided.
Expand Down
13 changes: 11 additions & 2 deletions src/deepsparse/v2/routers/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@ def __init__(
end_route: Union[str, int],
start_route: Union[str, int],
route: Optional[Dict] = None,
split_route: str = "SPLIT",
join_route: str = "JOIN",
):
self.START_ROUTE = start_route
self.END_ROUTE = end_route
self.SPLIT_ROUTE = split_route
self.JOIN_ROUTE = join_route
self.route = route

@abstractmethod
Expand Down Expand Up @@ -79,6 +83,9 @@ class LinearRouter(Router):

def __init__(self, end_route: int, start_route: int = 0):
super().__init__(end_route=end_route, start_route=start_route)
self.SPLIT_ROUTE = None
self.JOIN_ROUTE = None
_LOGGER.warn("SPLIT and JOIN are not yet supported for the LinearRouter.")

def next(
self, past: int, ops: Optional[List[Operator]] = None, inp: Optional[Any] = None
Expand Down Expand Up @@ -128,8 +135,10 @@ class GraphRouter(Router):
where `can_operate` returns True will run. Paths should be deterministic.
"""

def __init__(self, end_route: str, start_route: str, route: Dict):
super().__init__(end_route=end_route, start_route=start_route, route=route)
def __init__(self, end_route: str, start_route: str, route: Dict, **kwargs):
super().__init__(
end_route=end_route, start_route=start_route, route=route, **kwargs
)

def next(
self,
Expand Down
11 changes: 11 additions & 0 deletions src/deepsparse/v2/schedulers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable

from deepsparse.v2.operators import Operator

Expand Down Expand Up @@ -64,3 +65,13 @@ def can_process(
Base OperatorScheduler always returns True
"""
return True

def map(self, *args, func: Callable):
"""
:param func: generic callable run for each arg
:return: list of futures for each submit
"""
futures = []
for _, values in enumerate(zip(*args)):
futures.append(self.submit(*values, operator=func))
return futures
20 changes: 0 additions & 20 deletions src/deepsparse/v2/schedulers/scheduler_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,3 @@ def submit(
operator=operator,
**kwargs,
)

def can_process(
self,
*args,
operator: Operator,
**kwargs,
) -> bool:
"""
:param operator: operator to check
:return: True if this Operator can process the given operator and input.
SchedulerGroup always returns True
"""
return any(
scheduler.can_process(
*args,
operator=operator,
**kwargs,
)
for scheduler in self.schedulers
)
1 change: 1 addition & 0 deletions src/deepsparse/v2/text_generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .compile_generations import *
from .compile_logits import *
from .generate_new_token import *
from .join_output import *
from .kv_cache_operator import *
from .multi_engine_prefill_operator import *
from .nl_engine_operator import *
Expand Down
Loading