8 changes: 7 additions & 1 deletion tests/utils.py
@@ -42,13 +42,19 @@
init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import FlexibleArgumentParser, GB_bytes, get_open_port

from vllm_ascend.utils import vllm_version_is

from .model_utils import TextTextLogprobs

if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
from vllm.model_executor.model_loader.loader import get_model_loader # type: ignore[import] # isort: skip
else:
from vllm.model_executor.model_loader import get_model_loader

VLLM_PATH = Path(__file__).parent.parent
"""Path to root of the vLLM repository."""

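For context, the gate relies on `vllm_version_is` from `vllm_ascend.utils`. A minimal sketch of the pattern, assuming the helper simply compares the installed `vllm.__version__` against the given release string (the real implementation in `vllm_ascend/utils.py` may differ):

```python
import vllm


def vllm_version_is(target: str) -> bool:
    # Illustrative stand-in: compare the installed vLLM version to a target release.
    return vllm.__version__ == target


if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
    # vLLM 0.8.5.x exposes the loader under the .loader submodule.
    from vllm.model_executor.model_loader.loader import get_model_loader
else:
    # Newer releases re-export it from the model_loader package directly.
    from vllm.model_executor.model_loader import get_model_loader
```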
1 change: 1 addition & 0 deletions vllm_ascend/ops/attention.py
@@ -131,6 +131,7 @@ def vanilla_chunked_prefill(

attn_output = (attn_output[q_mask].view([-1, num_query_heads,
head_dim]).to(output.dtype))
output = output.view_as(attn_output)
output.copy_(attn_output)
return attn_output

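The change above reshapes the caller's pre-allocated `output` buffer to match `attn_output` before copying, so the in-place copy writes through to the caller's storage even when the two tensors are declared with different shapes. A minimal sketch with hypothetical shapes:

```python
import torch

num_tokens, num_heads, head_dim = 4, 8, 16
attn_output = torch.randn(num_tokens, num_heads, head_dim)
output = torch.empty(num_tokens, num_heads * head_dim)  # caller-provided buffer

out_view = output.view_as(attn_output)  # same storage, reshaped to match
out_view.copy_(attn_output)             # in-place copy writes through to `output`

assert torch.equal(output.view(num_tokens, num_heads, head_dim), attn_output)
```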
14 changes: 12 additions & 2 deletions vllm_ascend/worker/model_runner.py
@@ -64,6 +64,8 @@
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)

from vllm_ascend.utils import vllm_version_is

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

@@ -1007,7 +1009,10 @@ def save_sharded_state(
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from vllm.model_executor.model_loader.loader import ShardedStateLoader
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
from vllm.model_executor.model_loader.loader import ShardedStateLoader # type: ignore[import] # isort: skip # noqa
Collaborator

I think this is a reasonable approach to skipping static code checks. We can't install two versions of vLLM in the same Python environment, so I agree with skipping this for v0.8.5.

Collaborator

Should we ensure that the vllm-ascend main branch is compatible with both vLLM 0.8.5 and 0.8.5.post1?

Collaborator

Yeah, but this only skips the static format check for 0.8.5 in the main branch CI. It has no impact on the features.

Collaborator

Without this version check, running on vLLM v0.8.5 would take the else branch and fail, since the newer import path does not exist there.

else:
from vllm.model_executor.model_loader import ShardedStateLoader
ShardedStateLoader.save_model(
self.model,
path,
@@ -1019,7 +1024,12 @@ def save_tensorized_model(
self,
tensorizer_config: TensorizerConfig,
) -> None:
from vllm.model_executor.model_loader.loader import TensorizerLoader
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
from vllm.model_executor.model_loader.loader import \
TensorizerLoader # type: ignore # noqa
else:
from vllm.model_executor.model_loader import \
TensorizerLoader # type: ignore # noqa
TensorizerLoader.save_model(
self.model,
tensorizer_config=tensorizer_config,
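As the review thread notes, the else branch would fail on vLLM 0.8.5 / 0.8.5.post1 because `ShardedStateLoader` and `TensorizerLoader` only live in the `.loader` submodule there. An alternative, version-agnostic sketch (not what this PR does) that falls back on `ImportError` instead of an explicit version gate:

```python
# Sketch of a try/except fallback. The PR prefers vllm_version_is(), which
# keeps the supported releases explicit rather than probing at import time.
try:
    # Import path on newer vLLM releases.
    from vllm.model_executor.model_loader import (ShardedStateLoader,
                                                  TensorizerLoader)
except ImportError:
    # Import path on vLLM 0.8.5 / 0.8.5.post1.
    from vllm.model_executor.model_loader.loader import (ShardedStateLoader,
                                                         TensorizerLoader)
```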