20 changes: 4 additions & 16 deletions vllm_ascend/compilation/compiler_interface.py
@@ -74,20 +74,6 @@ def npugraph_ex_compile(
     compile_range: Range,
     key: str | None = None,
 ) -> tuple[Callable | None, Any | None]:
-    # When currently using the FULL_DECODE_ONLY mode,
-    # the piecewise compilation level slicing process
-    # in vllm is also encountered.
-    # This process causes the output to no longer be
-    # wrapped as a tuple when the fx graph has a single
-    # output, but torch.compile has a mandatory check.
-    fx_graph = graph.graph
-    if not graph_returns_tuple(graph):
-        output_node = fx_graph.output_node()
-        with fx_graph.inserting_before(output_node):
-            return_value = output_node.args[0]
-            tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
-        output_node.args = (tuple_node,)
-        graph.recompile()
     import torchair

     # TODO: use a better way to lazy register replacement, instead of import one by one
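Both the branch removed above and the replacement in the next hunk gate on graph_returns_tuple. As a rough sketch of what such a check does (an assumption about its shape; the real helper, e.g. torch._inductor.compile_fx.graph_returns_tuple, also handles fx's immutable container types), it only needs to inspect the graph's output node:

import torch.fx as fx

def returns_tuple(gm: fx.GraphModule) -> bool:
    # An fx graph has exactly one output node, whose single args entry
    # holds the returned value(s).
    output = next(n for n in reversed(gm.graph.nodes) if n.op == "output")
    (rv,) = output.args
    return isinstance(rv, (tuple, list))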
@@ -118,8 +104,10 @@ def npugraph_ex_compile(
 
     npugraph_ex = torchair.get_npu_backend(compiler_config=config)
 
-    compile_graph = npugraph_ex(graph, example_inputs)
-    return compile_graph, None
+    # torch.compile requires the output of the fx graph to be a tuple
+    if not graph_returns_tuple(graph):
+        return make_graph_return_tuple(graph, example_inputs, npugraph_ex), None
+    return npugraph_ex(graph, example_inputs), None
 
 
 class AscendCompiler(CompilerInterface):
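The rewrite above drops the hand-rolled graph surgery and defers to make_graph_return_tuple. A minimal sketch of the idea behind such a helper, assuming it works like torch._inductor.compile_fx.make_graph_return_tuple (flatten the output into a tuple so torch.compile's check passes, compile, then restore the original structure for callers) rather than being the exact implementation:

import functools

import torch.fx as fx
import torch.utils._pytree as pytree


def sketch_make_graph_return_tuple(gm: fx.GraphModule, example_inputs, compile_gm):
    output_node = next(n for n in reversed(gm.graph.nodes) if n.op == "output")
    (rv,) = output_node.args
    flat_rv, spec = pytree.tree_flatten(rv)
    # Rewrite the graph so it always returns a flat tuple of values.
    output_node.args = (tuple(flat_rv),)
    gm.recompile()
    compiled_fn = compile_gm(gm, example_inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args, **kwargs):
        # Undo the flattening so callers see the original output structure.
        return pytree.tree_unflatten(list(compiled_fn(*args, **kwargs)), spec)

    return wrapper

Because the wrapper restores the original structure, downstream code never observes the tuple that was inserted for the compiler's benefit, which is what makes the model-runner cleanup below possible.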
16 changes: 2 additions & 14 deletions vllm_ascend/worker/model_runner_v1.py
@@ -1582,12 +1582,6 @@ def execute_model(
             self.debugger.stop()
             self.debugger.step()
             return pool_output
-        # Sometimes, after the model is compiled through the AOT backend,
-        # the model output may become a list containing only one Tensor object.
-        if isinstance(hidden_states, list) and \
-                len(hidden_states) == 1 and \
-                isinstance(hidden_states[0], torch.Tensor):
-            hidden_states = hidden_states[0]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if broadcast_pp_output:
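The guard deleted above existed because a hidden_states that arrives as a one-element list breaks tensor-style indexing; with the output structure now restored inside the compiler interface, it is dead code. A toy illustration of the failure mode it papered over (shapes and values are made up):

import torch

hidden_states = [torch.randn(8, 16)]   # what the AOT backend used to hand back
logits_indices = torch.tensor([3, 7])

# hidden_states[logits_indices] would raise here: a Python list accepts only an
# integer index, and a multi-element tensor cannot be converted to one.
hidden_states = hidden_states[0]       # what the removed guard effectively did
sample_hidden_states = hidden_states[logits_indices]   # tensor indexing works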
@@ -2296,14 +2290,8 @@ def _dummy_sampler_run(
                 dtype=np.int32)
             logit_indices = np.cumsum(num_scheduled_tokens) - 1
             # TODO: need to run a dummy sampler for generate task
-            # Sometimes, after the model is compiled through the AOT backend,
-            # the model output may become a list containing only one Tensor object.
-            if isinstance(hidden_states, list) and \
-                    len(hidden_states) == 1 and \
-                    isinstance(hidden_states[0], torch.Tensor):
-                hidden_states = hidden_states[0]
-            hidden_states = hidden_states[logit_indices]
-            output = self.model.compute_logits(hidden_states)
+            hidden_states = hidden_states[logit_indices]
+            output = self.model.compute_logits(hidden_states)
             return output
 
     def profile_run(self) -> None:
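Taken together, the two files move the single-output special case out of the model runner and into the compilation path. A toy end-to-end sketch, reusing sketch_make_graph_return_tuple from the sketch above (the model and the pass-through "compiler" are made up for illustration):

import torch
import torch.fx as fx


def toy_forward(x):
    return x * 2   # a single, non-tuple output


gm = fx.symbolic_trace(toy_forward)
compiled = sketch_make_graph_return_tuple(gm, [torch.randn(4, 8)],
                                          lambda g, _inputs: g.forward)

hidden_states = compiled(torch.randn(4, 8))   # a Tensor again, not [Tensor]
logit_indices = torch.tensor([0, 3])
sample = hidden_states[logit_indices]         # no list-unwrapping needed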